Quellcode durchsuchen

Feat/support tool credentials bool schema (#2875)

Yeuoly vor 1 Jahr
Ursprung
Commit
95b74c211d

+ 2 - 1
api/core/tools/entities/tool_entities.py

@@ -171,6 +171,7 @@ class ToolProviderCredentials(BaseModel):
         SECRET_INPUT = "secret-input"
         TEXT_INPUT = "text-input"
         SELECT = "select"
+        BOOLEAN = "boolean"
 
         @classmethod
         def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType":
@@ -192,7 +193,7 @@ class ToolProviderCredentials(BaseModel):
     name: str = Field(..., description="The name of the credentials")
     type: CredentialsType = Field(..., description="The type of the credentials")
     required: bool = False
-    default: Optional[str] = None
+    default: Optional[Union[int, str]] = None
     options: Optional[list[ToolCredentialsOption]] = None
     label: Optional[I18nObject] = None
     help: Optional[I18nObject] = None

+ 2 - 3
api/core/tools/provider/builtin/bing/bing.py

@@ -12,12 +12,11 @@ class BingProvider(BuiltinToolProviderController):
                 meta={
                     "credentials": credentials,
                 }
-            ).invoke(
-                user_id='',
+            ).validate_credentials(
+                credentials=credentials,
                 tool_parameters={
                     "query": "test",
                     "result_type": "link",
-                    "enable_webpages": True,
                 },
             )
         except Exception as e:

+ 60 - 0
api/core/tools/provider/builtin/bing/bing.yaml

@@ -43,3 +43,63 @@ credentials_for_provider:
       zh_Hans: 例如 "https://api.bing.microsoft.com/v7.0/search"
       pt_BR: An endpoint is like "https://api.bing.microsoft.com/v7.0/search"
     default: https://api.bing.microsoft.com/v7.0/search
+  allow_entities:
+    type: boolean
+    required: false
+    label:
+      en_US: Allow Entities Search
+      zh_Hans: 支持实体搜索
+      pt_BR: Allow Entities Search
+    help:
+      en_US: Does your subscription plan allow entity search
+      zh_Hans: 您的订阅计划是否支持实体搜索
+      pt_BR: Does your subscription plan allow entity search
+    default: true
+  allow_web_pages:
+    type: boolean
+    required: false
+    label:
+      en_US: Allow Web Pages Search
+      zh_Hans: 支持网页搜索
+      pt_BR: Allow Web Pages Search
+    help:
+      en_US: Does your subscription plan allow web pages search
+      zh_Hans: 您的订阅计划是否支持网页搜索
+      pt_BR: Does your subscription plan allow web pages search
+    default: true
+  allow_computation:
+    type: boolean
+    required: false
+    label:
+      en_US: Allow Computation Search
+      zh_Hans: 支持计算搜索
+      pt_BR: Allow Computation Search
+    help:
+      en_US: Does your subscription plan allow computation search
+      zh_Hans: 您的订阅计划是否支持计算搜索
+      pt_BR: Does your subscription plan allow computation search
+    default: false
+  allow_news:
+    type: boolean
+    required: false
+    label:
+      en_US: Allow News Search
+      zh_Hans: 支持新闻搜索
+      pt_BR: Allow News Search
+    help:
+      en_US: Does your subscription plan allow news search
+      zh_Hans: 您的订阅计划是否支持新闻搜索
+      pt_BR: Does your subscription plan allow news search
+    default: false
+  allow_related_searches:
+    type: boolean
+    required: false
+    label:
+      en_US: Allow Related Searches
+      zh_Hans: 支持相关搜索
+      pt_BR: Allow Related Searches
+    help:
+      en_US: Does your subscription plan allow related searches
+      zh_Hans: 您的订阅计划是否支持相关搜索
+      pt_BR: Does your subscription plan allow related searches
+    default: false

+ 110 - 38
api/core/tools/provider/builtin/bing/tools/bing_web_search.py

@@ -10,53 +10,23 @@ from core.tools.tool.builtin_tool import BuiltinTool
 class BingSearchTool(BuiltinTool):
     url = 'https://api.bing.microsoft.com/v7.0/search'
 
-    def _invoke(self, 
-                user_id: str,
-               tool_parameters: dict[str, Any], 
-        ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke_bing(self, 
+                     user_id: str,
+                     subscription_key: str, query: str, limit: int, 
+                     result_type: str, market: str, lang: str, 
+                     filters: list[str]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
-            invoke tools
+            invoke bing search
         """
-
-        key = self.runtime.credentials.get('subscription_key', None)
-        if not key:
-            raise Exception('subscription_key is required')
-        
-        server_url = self.runtime.credentials.get('server_url', None)
-        if not server_url:
-            server_url = self.url
-        
-        query = tool_parameters.get('query', None)
-        if not query:
-            raise Exception('query is required')
-        
-        limit = min(tool_parameters.get('limit', 5), 10)
-        result_type = tool_parameters.get('result_type', 'text') or 'text'
-        
-        market = tool_parameters.get('market', 'US')
-        lang = tool_parameters.get('language', 'en')
-        filter = []
-
-        if tool_parameters.get('enable_computation', False):
-            filter.append('Computation')
-        if tool_parameters.get('enable_entities', False):
-            filter.append('Entities')
-        if tool_parameters.get('enable_news', False):
-            filter.append('News')
-        if tool_parameters.get('enable_related_search', False):
-            filter.append('RelatedSearches')
-        if tool_parameters.get('enable_webpages', False):
-            filter.append('WebPages')
-
         market_code = f'{lang}-{market}'
         accept_language = f'{lang},{market_code};q=0.9'
         headers = {
-            'Ocp-Apim-Subscription-Key': key,
+            'Ocp-Apim-Subscription-Key': subscription_key,
             'Accept-Language': accept_language
         }
 
         query = quote(query)
-        server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filter)}'
+        server_url = f'{self.url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}'
         response = get(server_url, headers=headers)
 
         if response.status_code != 200:
@@ -124,3 +94,105 @@ class BingSearchTool(BuiltinTool):
                     text += f'{related["displayText"]} - {related["webSearchUrl"]}\n'
 
             return self.create_text_message(text=self.summary(user_id=user_id, content=text))
+        
+
+    def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None:
+        key = credentials.get('subscription_key', None)
+        if not key:
+            raise Exception('subscription_key is required')
+        
+        server_url = credentials.get('server_url', None)
+        if not server_url:
+            server_url = self.url
+
+        query = tool_parameters.get('query', None)
+        if not query:
+            raise Exception('query is required')
+        
+        limit = min(tool_parameters.get('limit', 5), 10)
+        result_type = tool_parameters.get('result_type', 'text') or 'text'
+
+        market = tool_parameters.get('market', 'US')
+        lang = tool_parameters.get('language', 'en')
+        filter = []
+
+        if credentials.get('allow_entities', False):
+            filter.append('Entities')
+
+        if credentials.get('allow_computation', False):
+            filter.append('Computation')
+
+        if credentials.get('allow_news', False):
+            filter.append('News')
+
+        if credentials.get('allow_related_searches', False):
+            filter.append('RelatedSearches')
+
+        if credentials.get('allow_web_pages', False):
+            filter.append('WebPages')
+
+        if not filter:
+            raise Exception('At least one filter is required')
+        
+        self._invoke_bing(
+            user_id='test',
+            subscription_key=key,
+            query=query,
+            limit=limit,
+            result_type=result_type,
+            market=market,
+            lang=lang,
+            filters=filter
+        )
+        
+    def _invoke(self, 
+                user_id: str,
+               tool_parameters: dict[str, Any], 
+        ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        """
+            invoke tools
+        """
+
+        key = self.runtime.credentials.get('subscription_key', None)
+        if not key:
+            raise Exception('subscription_key is required')
+        
+        server_url = self.runtime.credentials.get('server_url', None)
+        if not server_url:
+            server_url = self.url
+        
+        query = tool_parameters.get('query', None)
+        if not query:
+            raise Exception('query is required')
+        
+        limit = min(tool_parameters.get('limit', 5), 10)
+        result_type = tool_parameters.get('result_type', 'text') or 'text'
+        
+        market = tool_parameters.get('market', 'US')
+        lang = tool_parameters.get('language', 'en')
+        filter = []
+
+        if tool_parameters.get('enable_computation', False):
+            filter.append('Computation')
+        if tool_parameters.get('enable_entities', False):
+            filter.append('Entities')
+        if tool_parameters.get('enable_news', False):
+            filter.append('News')
+        if tool_parameters.get('enable_related_search', False):
+            filter.append('RelatedSearches')
+        if tool_parameters.get('enable_webpages', False):
+            filter.append('WebPages')
+
+        if not filter:
+            raise Exception('At least one filter is required')
+        
+        return self._invoke_bing(
+            user_id=user_id,
+            subscription_key=key,
+            query=query,
+            limit=limit,
+            result_type=result_type,
+            market=market,
+            lang=lang,
+            filters=filter
+        )

+ 21 - 2
api/core/tools/provider/builtin_tool_provider.py

@@ -246,8 +246,27 @@ class BuiltinToolProviderController(ToolProviderController):
                 
                 if credentials[credential_name] not in [x.value for x in options]:
                     raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be one of {options}')
-            
-            if credentials[credential_name]:
+            elif credential_schema.type == ToolProviderCredentials.CredentialsType.BOOLEAN:
+                if isinstance(credentials[credential_name], bool):
+                    pass
+                elif isinstance(credentials[credential_name], str):
+                    if credentials[credential_name].lower() == 'true':
+                        credentials[credential_name] = True
+                    elif credentials[credential_name].lower() == 'false':
+                        credentials[credential_name] = False
+                    else:
+                        raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean')
+                elif isinstance(credentials[credential_name], int):
+                    if credentials[credential_name] == 1:
+                        credentials[credential_name] = True
+                    elif credentials[credential_name] == 0:
+                        credentials[credential_name] = False
+                    else:
+                        raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean')
+                else:
+                    raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean')
+
+            if credentials[credential_name] or credentials[credential_name] == False:
                 credentials_need_to_validate.pop(credential_name)
 
         for credential_name in credentials_need_to_validate:

+ 3 - 3
api/services/tools_manage_service.py

@@ -138,9 +138,9 @@ class ToolManageService:
             :return: the list of tool providers
         """
         provider = ToolManager.get_builtin_provider(provider_name)
-        return [
-            v.to_dict() for _, v in (provider.credentials_schema or {}).items()
-        ]
+        return json.loads(serialize_base_model_array([
+            v for _, v in (provider.credentials_schema or {}).items()
+        ]))
 
     @staticmethod
     def parser_api_schema(schema: str) -> list[ApiBasedToolBundle]:

+ 6 - 3
web/app/components/tools/setting/build-in/config-credentials.tsx

@@ -3,7 +3,7 @@ import type { FC } from 'react'
 import React, { useEffect, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import cn from 'classnames'
-import { toolCredentialToFormSchemas } from '../../utils/to-form-schema'
+import { addDefaultValue, toolCredentialToFormSchemas } from '../../utils/to-form-schema'
 import type { Collection } from '../../types'
 import Drawer from '@/app/components/base/drawer-plus'
 import Button from '@/app/components/base/button'
@@ -28,12 +28,15 @@ const ConfigCredential: FC<Props> = ({
   const { t } = useTranslation()
   const [credentialSchema, setCredentialSchema] = useState<any>(null)
   const { team_credentials: credentialValue, name: collectionName } = collection
+  const [tempCredential, setTempCredential] = React.useState<any>(credentialValue)
   useEffect(() => {
     fetchBuiltInToolCredentialSchema(collectionName).then((res) => {
-      setCredentialSchema(toolCredentialToFormSchemas(res))
+      const toolCredentialSchemas = toolCredentialToFormSchemas(res)
+      const defaultCredentials = addDefaultValue(credentialValue, toolCredentialSchemas)
+      setCredentialSchema(toolCredentialSchemas)
+      setTempCredential(defaultCredentials)
     })
   }, [])
-  const [tempCredential, setTempCredential] = React.useState<any>(credentialValue)
 
   return (
     <Drawer