|
@@ -0,0 +1,119 @@
|
|
|
+from typing import Any, Union
|
|
|
+
|
|
|
+from vanna.remote import VannaDefault
|
|
|
+
|
|
|
+from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
|
+from core.tools.errors import ToolProviderCredentialValidationError
|
|
|
+from core.tools.tool.builtin_tool import BuiltinTool
|
|
|
+
|
|
|
+
|
|
|
+class VannaTool(BuiltinTool):
|
|
|
+ def _invoke(
|
|
|
+ self, user_id: str, tool_parameters: dict[str, Any]
|
|
|
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
|
|
+ """
|
|
|
+ invoke tools
|
|
|
+ """
|
|
|
+ api_key = self.runtime.credentials.get("api_key", None)
|
|
|
+ if not api_key:
|
|
|
+ raise ToolProviderCredentialValidationError("Please input api key")
|
|
|
+
|
|
|
+ model = tool_parameters.get("model", "")
|
|
|
+ if not model:
|
|
|
+ return self.create_text_message("Please input RAG model")
|
|
|
+
|
|
|
+ prompt = tool_parameters.get("prompt", "")
|
|
|
+ if not prompt:
|
|
|
+ return self.create_text_message("Please input prompt")
|
|
|
+
|
|
|
+ url = tool_parameters.get("url", "")
|
|
|
+ if not url:
|
|
|
+ return self.create_text_message("Please input URL/Host/DSN")
|
|
|
+
|
|
|
+ db_name = tool_parameters.get("db_name", "")
|
|
|
+ username = tool_parameters.get("username", "")
|
|
|
+ password = tool_parameters.get("password", "")
|
|
|
+ port = tool_parameters.get("port", 0)
|
|
|
+
|
|
|
+ vn = VannaDefault(model=model, api_key=api_key)
|
|
|
+
|
|
|
+ db_type = tool_parameters.get("db_type", "")
|
|
|
+ if db_type in ["Postgres", "MySQL", "Hive", "ClickHouse"]:
|
|
|
+ if not db_name:
|
|
|
+ return self.create_text_message("Please input database name")
|
|
|
+ if not username:
|
|
|
+ return self.create_text_message("Please input username")
|
|
|
+ if port < 1:
|
|
|
+ return self.create_text_message("Please input port")
|
|
|
+
|
|
|
+ schema_sql = "SELECT * FROM INFORMATION_SCHEMA.COLUMNS"
|
|
|
+ match db_type:
|
|
|
+ case "SQLite":
|
|
|
+ schema_sql = "SELECT type, sql FROM sqlite_master WHERE sql is not null"
|
|
|
+ vn.connect_to_sqlite(url)
|
|
|
+ case "Postgres":
|
|
|
+ vn.connect_to_postgres(host=url, dbname=db_name, user=username, password=password, port=port)
|
|
|
+ case "DuckDB":
|
|
|
+ vn.connect_to_duckdb(url=url)
|
|
|
+ case "SQLServer":
|
|
|
+ vn.connect_to_mssql(url)
|
|
|
+ case "MySQL":
|
|
|
+ vn.connect_to_mysql(host=url, dbname=db_name, user=username, password=password, port=port)
|
|
|
+ case "Oracle":
|
|
|
+ vn.connect_to_oracle(user=username, password=password, dsn=url)
|
|
|
+ case "Hive":
|
|
|
+ vn.connect_to_hive(host=url, dbname=db_name, user=username, password=password, port=port)
|
|
|
+ case "ClickHouse":
|
|
|
+ vn.connect_to_clickhouse(host=url, dbname=db_name, user=username, password=password, port=port)
|
|
|
+
|
|
|
+ enable_training = tool_parameters.get("enable_training", False)
|
|
|
+ reset_training_data = tool_parameters.get("reset_training_data", False)
|
|
|
+ if enable_training:
|
|
|
+ if reset_training_data:
|
|
|
+ existing_training_data = vn.get_training_data()
|
|
|
+ if len(existing_training_data) > 0:
|
|
|
+ for _, training_data in existing_training_data.iterrows():
|
|
|
+ vn.remove_training_data(training_data["id"])
|
|
|
+
|
|
|
+ ddl = tool_parameters.get("ddl", "")
|
|
|
+ question = tool_parameters.get("question", "")
|
|
|
+ sql = tool_parameters.get("sql", "")
|
|
|
+ memos = tool_parameters.get("memos", "")
|
|
|
+ training_metadata = tool_parameters.get("training_metadata", False)
|
|
|
+
|
|
|
+ if training_metadata:
|
|
|
+ if db_type == "SQLite":
|
|
|
+ df_ddl = vn.run_sql(schema_sql)
|
|
|
+ for ddl in df_ddl["sql"].to_list():
|
|
|
+ vn.train(ddl=ddl)
|
|
|
+ else:
|
|
|
+ df_information_schema = vn.run_sql(schema_sql)
|
|
|
+ plan = vn.get_training_plan_generic(df_information_schema)
|
|
|
+ vn.train(plan=plan)
|
|
|
+
|
|
|
+ if ddl:
|
|
|
+ vn.train(ddl=ddl)
|
|
|
+
|
|
|
+ if sql:
|
|
|
+ if question:
|
|
|
+ vn.train(question=question, sql=sql)
|
|
|
+ else:
|
|
|
+ vn.train(sql=sql)
|
|
|
+ if memos:
|
|
|
+ vn.train(documentation=memos)
|
|
|
+
|
|
|
+ generate_chart = tool_parameters.get("generate_chart", True)
|
|
|
+ res = vn.ask(prompt, False, True, generate_chart)
|
|
|
+
|
|
|
+ result = []
|
|
|
+
|
|
|
+ if res is not None:
|
|
|
+ result.append(self.create_text_message(res[0]))
|
|
|
+ if len(res) > 1 and res[1] is not None:
|
|
|
+ result.append(self.create_text_message(res[1].to_markdown()))
|
|
|
+ if len(res) > 2 and res[2] is not None:
|
|
|
+ result.append(
|
|
|
+ self.create_blob_message(blob=res[2].to_image(format="svg"), meta={"mime_type": "image/svg+xml"})
|
|
|
+ )
|
|
|
+
|
|
|
+ return result
|