Ver código fonte

Add custom tools (#2259)

Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Charlie.Wei 1 ano atrás
pai
commit
76cc19f525

+ 4 - 1
api/core/tools/provider/builtin/_positions.py

@@ -13,8 +13,11 @@ position = {
     'stablediffusion': 9,
     'vectorizer': 10,
     'youtube': 11,
+    'github': 12,
+    'gaode': 13
 }
 
+
 class BuiltinToolProviderSort:
     @staticmethod
     def sort(providers: List[UserToolProvider]) -> List[UserToolProvider]:
@@ -23,4 +26,4 @@ class BuiltinToolProviderSort:
         
         sorted_providers = sorted(providers, key=sort_compare)
 
-        return sorted_providers
+        return sorted_providers

BIN
api/core/tools/provider/builtin/gaode/_assets/icon.png


+ 24 - 0
api/core/tools/provider/builtin/gaode/gaode.py

@@ -0,0 +1,24 @@
+import requests
+import urllib.parse
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+from core.tools.errors import ToolProviderCredentialValidationError
+
+
+class GaodeProvider(BuiltinToolProviderController):
+    def _validate_credentials(self, credentials: dict) -> None:
+        try:
+            if 'api_key' not in credentials or not credentials.get('api_key'):
+                raise ToolProviderCredentialValidationError("Gaode API key is required.")
+
+            try:
+                response = requests.get(url="https://restapi.amap.com/v3/geocode/geo?address={address}&key={apikey}"
+                                            "".format(address=urllib.parse.quote('广东省广州市天河区广州塔'),
+                                                      apikey=credentials.get('api_key')))
+                if response.status_code == 200 and (response.json()).get('info') == 'OK':
+                    pass
+                else:
+                    raise ToolProviderCredentialValidationError((response.json()).get('info'))
+            except Exception as e:
+                raise ToolProviderCredentialValidationError("Gaode API Key is invalid. {}".format(e))
+        except Exception as e:
+            raise ToolProviderCredentialValidationError(str(e))

+ 29 - 0
api/core/tools/provider/builtin/gaode/gaode.yaml

@@ -0,0 +1,29 @@
+identity:
+  author: CharlirWei
+  name: gaode
+  label:
+    en_US: GaoDe
+    zh_Hans: 高德
+    pt_BR: GaoDe
+  description:
+    en_US: Autonavi Open Platform service toolkit.
+    zh_Hans: 高德开放平台服务工具包。
+    pt_BR: Kit de ferramentas de serviço Autonavi Open Platform.
+  icon: icon.png
+credentials_for_provider:
+  api_key:
+    type: secret-input
+    required: true
+    label:
+      en_US: API Key
+      zh_Hans: API Key
+      pt_BR: Fogo a chave
+    placeholder:
+      en_US: Please enter your GaoDe API Key
+      zh_Hans: 请输入你的高德开放平台 API Key
+      pt_BR: Insira sua chave de API GaoDe
+    help:
+      en_US: Get your API Key from GaoDe
+      zh_Hans: 从高德获取您的 API Key
+      pt_BR: Obtenha sua chave de API do GaoDe
+    url: https://console.amap.com/dev/key/app

+ 55 - 0
api/core/tools/provider/builtin/gaode/tools/gaode_weather.py

@@ -0,0 +1,55 @@
+import json
+import requests
+from core.tools.tool.builtin_tool import BuiltinTool
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from typing import Any, Dict, List, Union
+
+
+class GaodeRepositoriesTool(BuiltinTool):
+    def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
+        """
+            invoke tools
+        """
+        city = tool_paramters.get('city', '')
+        if not city:
+            return self.create_text_message('Please tell me your city')
+
+        if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'):
+            return self.create_text_message("Gaode API key is required.")
+
+        try:
+            s = requests.session()
+            api_domain = 'https://restapi.amap.com/v3'
+            city_response = s.request(method='GET', headers={"Content-Type": "application/json; charset=utf-8"},
+                                      url="{url}/config/district?keywords={keywords}"
+                                          "&subdistrict=0&extensions=base&key={apikey}"
+                                          "".format(url=api_domain, keywords=city,
+                                                    apikey=self.runtime.credentials.get('api_key')))
+            City_data = city_response.json()
+            if city_response.status_code == 200 and City_data.get('info') == 'OK':
+                if len(City_data.get('districts')) > 0:
+                    CityCode = City_data['districts'][0]['adcode']
+                    weatherInfo_response = s.request(method='GET',
+                                                     url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json"
+                                                         "".format(url=api_domain, citycode=CityCode,
+                                                                   apikey=self.runtime.credentials.get('api_key')))
+                    weatherInfo_data = weatherInfo_response.json()
+                    if weatherInfo_response.status_code == 200 and weatherInfo_data.get('info') == 'OK':
+                        contents = list()
+                        if len(weatherInfo_data.get('forecasts')) > 0:
+                            for item in weatherInfo_data['forecasts'][0]['casts']:
+                                content = dict()
+                                content['date'] = item.get('date')
+                                content['week'] = item.get('week')
+                                content['dayweather'] = item.get('dayweather')
+                                content['daytemp_float'] = item.get('daytemp_float')
+                                content['daywind'] = item.get('daywind')
+                                content['nightweather'] = item.get('nightweather')
+                                content['nighttemp_float'] = item.get('nighttemp_float')
+                                contents.append(content)
+                            s.close()
+                            return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False)))
+            s.close()
+            return self.create_text_message(f'No weather information for {city} was found.')
+        except Exception as e:
+            return self.create_text_message("Github API Key and Api Version is invalid. {}".format(e))

+ 28 - 0
api/core/tools/provider/builtin/gaode/tools/gaode_weather.yaml

@@ -0,0 +1,28 @@
+identity:
+  name: gaode_weather
+  author: CharlieWei
+  label:
+    en_US: Weather Forecast
+    zh_Hans: 天气预报
+    pt_BR: Previsão do tempo
+  icon: icon.svg
+description:
+  human:
+    en_US: Weather forecast inquiry
+    zh_Hans: 天气预报查询。
+    pt_BR: Inquérito sobre previsão meteorológica.
+  llm: A tool when you want to ask about the weather or weather-related question.
+parameters:
+  - name: city
+    type: string
+    required: true
+    label:
+      en_US: city
+      zh_Hans: 城市
+      pt_BR: cidade
+    human_description:
+      en_US: Target city for weather forecast query.
+      zh_Hans: 天气预报查询的目标城市。
+      pt_BR: Cidade de destino para consulta de previsão do tempo.
+    llm_description: If you don't know you can extract the city name from the question or you can reply:Please tell me your city. You have to extract the Chinese city name from the question.
+    form: llm

BIN
api/core/tools/provider/builtin/github/_assets/icon.png


+ 31 - 0
api/core/tools/provider/builtin/github/github.py

@@ -0,0 +1,31 @@
+import requests
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+from core.tools.errors import ToolProviderCredentialValidationError
+
+
+class GihubProvider(BuiltinToolProviderController):
+    def _validate_credentials(self, credentials: dict) -> None:
+        try:
+            if 'access_tokens' not in credentials or not credentials.get('access_tokens'):
+                raise ToolProviderCredentialValidationError("Github API Access Tokens is required.")
+            if 'api_version' not in credentials or not credentials.get('api_version'):
+                api_version = '2022-11-28'
+            else:
+                api_version = credentials.get('api_version')
+
+            try:
+                headers = {
+                    "Content-Type": "application/vnd.github+json",
+                    "Authorization": f"Bearer {credentials.get('access_tokens')}",
+                    "X-GitHub-Api-Version": api_version
+                }
+
+                response = requests.get(
+                    url="https://api.github.com/search/users?q={account}".format(account='charli117'),
+                    headers=headers)
+                if response.status_code != 200:
+                    raise ToolProviderCredentialValidationError((response.json()).get('message'))
+            except Exception as e:
+                raise ToolProviderCredentialValidationError("Github API Key and Api Version is invalid. {}".format(e))
+        except Exception as e:
+            raise ToolProviderCredentialValidationError(str(e))

+ 46 - 0
api/core/tools/provider/builtin/github/github.yaml

@@ -0,0 +1,46 @@
+identity:
+  author: CharlirWei
+  name: github
+  label:
+    en_US: Github
+    zh_Hans: Github
+    pt_BR: Github
+  description:
+    en_US: GitHub is an online software source code hosting service.
+    zh_Hans: GitHub是一个在线软件源代码托管服务平台。
+    pt_BR: GitHub é uma plataforma online para serviços de hospedagem de código fonte de software.
+  icon: icon.png
+credentials_for_provider:
+  access_tokens:
+    type: secret-input
+    required: true
+    label:
+      en_US: Access Tokens
+      zh_Hans: Access Tokens
+      pt_BR: Tokens de acesso
+    placeholder:
+      en_US: Please input your Github Access Tokens
+      zh_Hans: 请输入你的 Github Access Tokens
+      pt_BR: Insira seus Tokens de Acesso do Github
+    help:
+      en_US: Get your Access Tokens from Github
+      zh_Hans: 从 Github 获取您的 Access Tokens
+      pt_BR: Obtenha sua chave da API do Google no Google
+    url: https://github.com/settings/tokens?type=beta
+  api_version:
+    type: text-input
+    required: false
+    default: '2022-11-28'
+    label:
+      en_US: API Version
+      zh_Hans: API Version
+      pt_BR: Versão da API
+    placeholder:
+      en_US: Please input your Github API Version
+      zh_Hans: 请输入你的 Github API Version
+      pt_BR: Insira sua versão da API do Github
+    help:
+      en_US: Get your API Version from Github
+      zh_Hans: 从 Github 获取您的 API Version
+      pt_BR: Obtenha sua versão da API do Github
+    url: https://docs.github.com/en/rest/about-the-rest-api/api-versions?apiVersion=2022-11-28

+ 61 - 0
api/core/tools/provider/builtin/github/tools/repositories.py

@@ -0,0 +1,61 @@
+import json
+import requests
+from datetime import datetime
+from urllib.parse import quote
+from core.tools.tool.builtin_tool import BuiltinTool
+from core.tools.entities.tool_entities import ToolInvokeMessage
+
+from typing import Any, Dict, List, Union
+
+
+class GihubRepositoriesTool(BuiltinTool):
+    def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
+        """
+            invoke tools
+        """
+        top_n = tool_paramters.get('top_n', 5)
+        query = tool_paramters.get('query', '')
+        if not query:
+            return self.create_text_message('Please input symbol')
+
+        if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'):
+            return self.create_text_message("Github API Access Tokens is required.")
+        if 'api_version' not in self.runtime.credentials or not self.runtime.credentials.get('api_version'):
+            api_version = '2022-11-28'
+        else:
+            api_version = self.runtime.credentials.get('api_version')
+
+        try:
+            headers = {
+                "Content-Type": "application/vnd.github+json",
+                "Authorization": f"Bearer {self.runtime.credentials.get('access_tokens')}",
+                "X-GitHub-Api-Version": api_version
+            }
+            s = requests.session()
+            api_domain = 'https://api.github.com'
+            response = s.request(method='GET', headers=headers,
+                                 url=f"{api_domain}/search/repositories?"
+                                     f"q={quote(query)}&sort=stars&per_page={top_n}&order=desc")
+            response_data = response.json()
+            if response.status_code == 200 and isinstance(response_data.get('items'), list):
+                contents = list()
+                if len(response_data.get('items')) > 0:
+                    for item in response_data.get('items'):
+                        content = dict()
+                        updated_at_object = datetime.strptime(item['updated_at'], "%Y-%m-%dT%H:%M:%SZ")
+                        content['owner'] = item['owner']['login']
+                        content['name'] = item['name']
+                        content['description'] = item['description'][:100] + '...' if len(item['description']) > 100 else item['description']
+                        content['url'] = item['html_url']
+                        content['star'] = item['watchers']
+                        content['forks'] = item['forks']
+                        content['updated'] = updated_at_object.strftime("%Y-%m-%d")
+                        contents.append(content)
+                    s.close()
+                    return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False)))
+                else:
+                    return self.create_text_message(f'No items related to {query} were found.')
+            else:
+                return self.create_text_message((response.json()).get('message'))
+        except Exception as e:
+            return self.create_text_message("Github API Key and Api Version is invalid. {}".format(e))

+ 42 - 0
api/core/tools/provider/builtin/github/tools/repositories.yaml

@@ -0,0 +1,42 @@
+identity:
+  name: repositories
+  author: CharlieWei
+  label:
+    en_US: Search Repositories
+    zh_Hans: 仓库搜索
+    pt_BR: Pesquisar Repositórios
+  icon: icon.svg
+description:
+  human:
+    en_US: Search the Github repository to retrieve the open source projects you need
+    zh_Hans: 搜索Github仓库,检索你需要的开源项目。
+    pt_BR: Pesquise o repositório do Github para recuperar os projetos de código aberto necessários.
+  llm: A tool when you wants to search for popular warehouses or open source projects for any keyword. format query condition like "keywords+language:js", language can be other dev languages.
+parameters:
+  - name: query
+    type: string
+    required: true
+    label:
+      en_US: query
+      zh_Hans: 关键字
+      pt_BR: consulta
+    human_description:
+      en_US: You want to find the project development language, keywords, For example. Find 10 Python developed PDF document parsing projects.
+      zh_Hans: 你想要找的项目开发语言、关键字,如:找10个Python开发的PDF文档解析项目。
+      pt_BR: Você deseja encontrar a linguagem de desenvolvimento do projeto, palavras-chave, Por exemplo. Encontre 10 projetos de análise de documentos PDF desenvolvidos em Python.
+    llm_description: The query of you want to search, format query condition like "keywords+language:js", language can be other dev languages, por exemplo. Procuro um projeto de análise de documentos PDF desenvolvido em Python.
+    form: llm
+  - name: top_n
+    type: number
+    default: 5
+    required: true
+    label:
+      en_US: Top N
+      zh_Hans: Top N
+      pt_BR: Topo N
+    human_description:
+      en_US: Number of records returned by sorting based on stars. 5 is returned by default.
+      zh_Hans: 基于stars排序返回的记录数, 默认返回5条。
+      pt_BR: Número de registros retornados por classificação com base em estrelas. 5 é retornado por padrão.
+    llm_description: Extract the first N records from the returned result.
+    form: llm