Bladeren bron

feat: Add Vanna.AI as a builtin tool (#4878)

Co-authored-by: Yeuoly <admin@srmxy.cn>
Henry Lu 10 maanden geleden
bovenliggende
commit
2d9f55b632

BIN
api/core/tools/provider/builtin/vanna/_assets/icon.png


+ 119 - 0
api/core/tools/provider/builtin/vanna/tools/vanna.py

@@ -0,0 +1,119 @@
+from typing import Any, Union
+
+from vanna.remote import VannaDefault
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class VannaTool(BuiltinTool):
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        """
+        invoke tools
+        """
+        api_key = self.runtime.credentials.get("api_key", None)
+        if not api_key:
+            raise ToolProviderCredentialValidationError("Please input api key")
+
+        model = tool_parameters.get("model", "")
+        if not model:
+            return self.create_text_message("Please input RAG model")
+
+        prompt = tool_parameters.get("prompt", "")
+        if not prompt:
+            return self.create_text_message("Please input prompt")
+
+        url = tool_parameters.get("url", "")
+        if not url:
+            return self.create_text_message("Please input URL/Host/DSN")
+
+        db_name = tool_parameters.get("db_name", "")
+        username = tool_parameters.get("username", "")
+        password = tool_parameters.get("password", "")
+        port = tool_parameters.get("port", 0)
+
+        vn = VannaDefault(model=model, api_key=api_key)
+
+        db_type = tool_parameters.get("db_type", "")
+        if db_type in ["Postgres", "MySQL", "Hive", "ClickHouse"]:
+            if not db_name:
+                return self.create_text_message("Please input database name")
+            if not username:
+                return self.create_text_message("Please input username")
+            if port < 1:
+                return self.create_text_message("Please input port")
+
+        schema_sql = "SELECT * FROM INFORMATION_SCHEMA.COLUMNS"
+        match db_type:
+            case "SQLite":
+                schema_sql = "SELECT type, sql FROM sqlite_master WHERE sql is not null"
+                vn.connect_to_sqlite(url)
+            case "Postgres":
+                vn.connect_to_postgres(host=url, dbname=db_name, user=username, password=password, port=port)
+            case "DuckDB":
+                vn.connect_to_duckdb(url=url)
+            case "SQLServer":
+                vn.connect_to_mssql(url)
+            case "MySQL":
+                vn.connect_to_mysql(host=url, dbname=db_name, user=username, password=password, port=port)
+            case "Oracle":
+                vn.connect_to_oracle(user=username, password=password, dsn=url)
+            case "Hive":
+                vn.connect_to_hive(host=url, dbname=db_name, user=username, password=password, port=port)
+            case "ClickHouse":
+                vn.connect_to_clickhouse(host=url, dbname=db_name, user=username, password=password, port=port)
+
+        enable_training = tool_parameters.get("enable_training", False)
+        reset_training_data = tool_parameters.get("reset_training_data", False)
+        if enable_training:
+            if reset_training_data:
+                existing_training_data = vn.get_training_data()
+                if len(existing_training_data) > 0:
+                    for _, training_data in existing_training_data.iterrows():
+                        vn.remove_training_data(training_data["id"])
+
+            ddl = tool_parameters.get("ddl", "")
+            question = tool_parameters.get("question", "")
+            sql = tool_parameters.get("sql", "")
+            memos = tool_parameters.get("memos", "")
+            training_metadata = tool_parameters.get("training_metadata", False)
+
+            if training_metadata:
+                if db_type == "SQLite":
+                    df_ddl = vn.run_sql(schema_sql)
+                    for ddl in df_ddl["sql"].to_list():
+                        vn.train(ddl=ddl)
+                else:
+                    df_information_schema = vn.run_sql(schema_sql)
+                    plan = vn.get_training_plan_generic(df_information_schema)
+                    vn.train(plan=plan)
+
+            if ddl:
+                vn.train(ddl=ddl)
+
+            if sql:
+                if question:
+                    vn.train(question=question, sql=sql)
+                else:
+                    vn.train(sql=sql)
+            if memos:
+                vn.train(documentation=memos)
+
+        generate_chart = tool_parameters.get("generate_chart", True)
+        res = vn.ask(prompt, False, True, generate_chart)
+
+        result = []
+
+        if res is not None:
+            result.append(self.create_text_message(res[0]))
+            if len(res) > 1 and res[1] is not None:
+                result.append(self.create_text_message(res[1].to_markdown()))
+            if len(res) > 2 and res[2] is not None:
+                result.append(
+                    self.create_blob_message(blob=res[2].to_image(format="svg"), meta={"mime_type": "image/svg+xml"})
+                )
+
+        return result

+ 213 - 0
api/core/tools/provider/builtin/vanna/tools/vanna.yaml

@@ -0,0 +1,213 @@
+identity:
+  name: vanna
+  author: QCTC
+  label:
+    en_US: Vanna.AI
+    zh_Hans: Vanna.AI
+description:
+  human:
+    en_US: The fastest way to get actionable insights from your database just by asking questions.
+    zh_Hans: 一个基于大模型和RAG的Text2SQL工具。
+  llm: A tool for converting text to SQL.
+parameters:
+  - name: prompt
+    type: string
+    required: true
+    label:
+      en_US: Prompt
+      zh_Hans: 提示词
+      pt_BR: Prompt
+    human_description:
+      en_US: used for generating SQL
+      zh_Hans: 用于生成SQL
+    llm_description: key words for generating SQL
+    form: llm
+  - name: model
+    type: string
+    required: true
+    label:
+      en_US: RAG Model
+      zh_Hans: RAG模型
+    human_description:
+      en_US: RAG Model for your database DDL
+      zh_Hans: 存储数据库训练数据的RAG模型
+    llm_description: RAG Model for generating SQL
+    form: form
+  - name: db_type
+    type: select
+    required: true
+    options:
+      - value: SQLite
+        label:
+          en_US: SQLite
+          zh_Hans: SQLite
+      - value: Postgres
+        label:
+          en_US: Postgres
+          zh_Hans: Postgres
+      - value: DuckDB
+        label:
+          en_US: DuckDB
+          zh_Hans: DuckDB
+      - value: SQLServer
+        label:
+          en_US: Microsoft SQL Server
+          zh_Hans: 微软 SQL Server
+      - value: MySQL
+        label:
+          en_US: MySQL
+          zh_Hans: MySQL
+      - value: Oracle
+        label:
+          en_US: Oracle
+          zh_Hans: Oracle
+      - value: Hive
+        label:
+          en_US: Hive
+          zh_Hans: Hive
+      - value: ClickHouse
+        label:
+          en_US: ClickHouse
+          zh_Hans: ClickHouse
+    default: SQLite
+    label:
+      en_US: DB Type
+      zh_Hans: 数据库类型
+    human_description:
+      en_US: Database type.
+      zh_Hans: 选择要链接的数据库类型。
+    form: form
+  - name: url
+    type: string
+    required: true
+    label:
+      en_US: URL/Host/DSN
+      zh_Hans: URL/Host/DSN
+    human_description:
+      en_US: Please input depending on DB type, visit https://vanna.ai/docs/ for more specification
+      zh_Hans: 请根据数据库类型,填入对应值,详情参考https://vanna.ai/docs/
+    form: form
+  - name: db_name
+    type: string
+    required: false
+    label:
+      en_US: DB name
+      zh_Hans: 数据库名
+    human_description:
+      en_US: Database name
+      zh_Hans: 数据库名
+    form: form
+  - name: username
+    type: string
+    required: false
+    label:
+      en_US: Username
+      zh_Hans: 用户名
+    human_description:
+      en_US: Username
+      zh_Hans: 用户名
+    form: form
+  - name: password
+    type: secret-input
+    required: false
+    label:
+      en_US: Password
+      zh_Hans: 密码
+    human_description:
+      en_US: Password
+      zh_Hans: 密码
+    form: form
+  - name: port
+    type: number
+    required: false
+    label:
+      en_US: Port
+      zh_Hans: 端口
+    human_description:
+      en_US: Port
+      zh_Hans: 端口
+    form: form
+  - name: ddl
+    type: string
+    required: false
+    label:
+      en_US: Training DDL
+      zh_Hans: 训练DDL
+    human_description:
+      en_US: DDL statements for training data
+      zh_Hans: 用于训练RAG Model的建表语句
+    form: form
+  - name: question
+    type: string
+    required: false
+    label:
+      en_US: Training Question
+      zh_Hans: 训练问题
+    human_description:
+      en_US: Question-SQL Pairs
+      zh_Hans: Question-SQL中的问题
+    form: form
+  - name: sql
+    type: string
+    required: false
+    label:
+      en_US: Training SQL
+      zh_Hans: 训练SQL
+    human_description:
+      en_US: SQL queries to your training data
+      zh_Hans: 用于训练RAG Model的SQL语句
+    form: form
+  - name: memos
+    type: string
+    required: false
+    label:
+      en_US: Training Memos
+      zh_Hans: 训练说明
+    human_description:
+      en_US: Sometimes you may want to add documentation about your business terminology or definitions
+      zh_Hans: 添加更多关于数据库的业务说明
+    form: form
+  - name: enable_training
+    type: boolean
+    required: false
+    default: false
+    label:
+      en_US: Training Data
+      zh_Hans: 训练数据
+    human_description:
+      en_US: You only need to train once. Do not train again unless you want to add more training data
+      zh_Hans: 训练数据无更新时,训练一次即可
+    form: form
+  - name: reset_training_data
+    type: boolean
+    required: false
+    default: false
+    label:
+      en_US: Reset Training Data
+      zh_Hans: 重置训练数据
+    human_description:
+      en_US: Remove all training data in the current RAG Model
+      zh_Hans: 删除当前RAG Model中的所有训练数据
+    form: form
+  - name: training_metadata
+    type: boolean
+    required: false
+    default: false
+    label:
+      en_US: Training Metadata
+      zh_Hans: 训练元数据
+    human_description:
+      en_US: If enabled, it will attempt to train on the metadata of that database
+      zh_Hans: 是否自动从数据库获取元数据来训练
+    form: form
+  - name: generate_chart
+    type: boolean
+    required: false
+    default: True
+    label:
+      en_US: Generate Charts
+      zh_Hans: 生成图表
+    human_description:
+      en_US: Generate Charts
+      zh_Hans: 是否生成图表
+    form: form

+ 25 - 0
api/core/tools/provider/builtin/vanna/vanna.py

@@ -0,0 +1,25 @@
+from typing import Any
+
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.provider.builtin.vanna.tools.vanna import VannaTool
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+class VannaProvider(BuiltinToolProviderController):
+    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
+        try:
+            VannaTool().fork_tool_runtime(
+                runtime={
+                    "credentials": credentials,
+                }
+            ).invoke(
+                user_id='',
+                tool_parameters={
+                    "model": "chinook",
+                    "db_type": "SQLite",
+                    "url": "https://vanna.ai/Chinook.sqlite",
+                    "query": "What are the top 10 customers by sales?"
+                },
+            )
+        except Exception as e:
+            raise ToolProviderCredentialValidationError(str(e))

+ 25 - 0
api/core/tools/provider/builtin/vanna/vanna.yaml

@@ -0,0 +1,25 @@
+identity:
+  author: QCTC
+  name: vanna
+  label:
+    en_US: Vanna.AI
+    zh_Hans: Vanna.AI
+  description:
+    en_US: The fastest way to get actionable insights from your database just by asking questions.
+    zh_Hans: 一个基于大模型和RAG的Text2SQL工具。
+  icon: icon.png
+credentials_for_provider:
+  api_key:
+    type: secret-input
+    required: true
+    label:
+      en_US: API key
+      zh_Hans: API key
+    placeholder:
+      en_US: Please input your API key
+      zh_Hans: 请输入你的 API key
+      pt_BR: Please input your API key
+    help:
+      en_US: Get your API key from Vanna.AI
+      zh_Hans: 从 Vanna.AI 获取你的 API key
+    url: https://vanna.ai/account/profile

+ 1 - 0
api/requirements.txt

@@ -82,3 +82,4 @@ firecrawl-py==0.0.5
 oss2==2.18.5
 pgvector==0.2.5
 google-cloud-aiplatform==1.49.0
+vanna[postgres,mysql,clickhouse,duckdb]==0.5.5