Explorar o código

Feat/model as tool (#2744)

Yeuoly hai 1 ano
pai
achega
40c646cf7a

+ 26 - 0
api/controllers/console/workspace/tool_providers.py

@@ -82,6 +82,30 @@ class ToolBuiltinProviderIconApi(Resource):
         icon_bytes, minetype = ToolManageService.get_builtin_tool_provider_icon(provider)
         return send_file(io.BytesIO(icon_bytes), mimetype=minetype)
 
+class ToolModelProviderIconApi(Resource):
+    """Serve the icon of a model tool provider."""
+
+    @setup_required
+    def get(self, provider):
+        """
+        Return the icon of the given model tool provider as a file response.
+
+        :param provider: the provider name
+        :return: the response built by `send_file`, using the mimetype reported by the service
+        """
+        # only `setup_required` is applied here, no login decorator
+        icon_bytes, mimetype = ToolManageService.get_model_tool_provider_icon(provider)
+        return send_file(io.BytesIO(icon_bytes), mimetype=mimetype)
+    
+class ToolModelProviderListToolsApi(Resource):
+    """List the tools of a model tool provider for the current user's tenant."""
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        """
+        Return the tools of the model provider named by the `provider` query argument.
+
+        :return: the result of `ToolManageService.list_model_tool_provider_tools`
+        """
+        user_id = current_user.id
+        tenant_id = current_user.current_tenant_id
+
+        # `provider` is a required, non-null argument read from the query string
+        parser = reqparse.RequestParser()
+        parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
+
+        args = parser.parse_args()
+
+        return ToolManageService.list_model_tool_provider_tools(
+            user_id,
+            tenant_id,
+            args['provider'],
+        )
 
 class ToolApiProviderAddApi(Resource):
     @setup_required
@@ -283,6 +307,8 @@ api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provide
 api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
 api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
 api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
+api.add_resource(ToolModelProviderIconApi, '/workspaces/current/tool-provider/model/<provider>/icon')
+api.add_resource(ToolModelProviderListToolsApi, '/workspaces/current/tool-provider/model/tools')
 api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
 api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
 api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')

+ 5 - 2
api/core/model_runtime/entities/model_entities.py

@@ -17,7 +17,7 @@ class ModelType(Enum):
     SPEECH2TEXT = "speech2text"
     MODERATION = "moderation"
     TTS = "tts"
-    # TEXT2IMG = "text2img"
+    TEXT2IMG = "text2img"
 
     @classmethod
     def value_of(cls, origin_model_type: str) -> "ModelType":
@@ -36,6 +36,8 @@ class ModelType(Enum):
             return cls.SPEECH2TEXT
         elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value:
             return cls.TTS
+        elif origin_model_type == 'text2img' or origin_model_type == cls.TEXT2IMG.value:
+            return cls.TEXT2IMG
         elif origin_model_type == cls.MODERATION.value:
             return cls.MODERATION
         else:
@@ -59,10 +61,11 @@ class ModelType(Enum):
             return 'tts'
         elif self == self.MODERATION:
             return 'moderation'
+        elif self == self.TEXT2IMG:
+            return 'text2img'
         else:
             raise ValueError(f'invalid model type {self}')
 
-
 class FetchFrom(Enum):
     """
     Enum class for fetch from.

+ 48 - 0
api/core/model_runtime/model_providers/__base/text2img_model.py

@@ -0,0 +1,48 @@
+from abc import abstractmethod
+from typing import IO, Optional
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.model_providers.__base.ai_model import AIModel
+
+
+class Text2ImageModel(AIModel):
+    """
+    Model class for text2img model.
+
+    Subclasses implement `_invoke`; callers use `invoke`, which passes any exception
+    raised by `_invoke` through `_transform_invoke_error` before re-raising it.
+    """
+    model_type: ModelType = ModelType.TEXT2IMG
+
+    def invoke(self, model: str, credentials: dict, prompt: str, 
+               model_parameters: dict, user: Optional[str] = None) \
+            -> list[IO[bytes]]:
+        """
+        Invoke Text2Image model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt: prompt for image generation
+        :param model_parameters: model parameters
+        :param user: unique user id
+
+        :return: list of binary streams (presumably one per generated image, confirm with implementations)
+        :raises Exception: whatever `_transform_invoke_error` returns for the original error
+        """
+        try:
+            return self._invoke(model, credentials, prompt, model_parameters, user)
+        except Exception as e:
+            # every failure is converted by the shared error transformer
+            raise self._transform_invoke_error(e)
+
+    @abstractmethod
+    def _invoke(self, model: str, credentials: dict, prompt: str, 
+                model_parameters: dict, user: Optional[str] = None) \
+            -> list[IO[bytes]]:
+        """
+        Provider specific implementation of the Text2Image invocation
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt: prompt for image generation
+        :param model_parameters: model parameters
+        :param user: unique user id
+
+        :return: list of binary streams (presumably one per generated image, confirm with implementations)
+        """
+        raise NotImplementedError

+ 5 - 1
api/core/tools/entities/common_entities.py

@@ -8,15 +8,19 @@ class I18nObject(BaseModel):
     Model class for i18n object.
     """
     zh_Hans: Optional[str] = None
+    pt_BR: Optional[str] = None
     en_US: str
 
     def __init__(self, **data):
         super().__init__(**data)
         if not self.zh_Hans:
             self.zh_Hans = self.en_US
+        if not self.pt_BR:
+            self.pt_BR = self.en_US
 
     def to_dict(self) -> dict:
         return {
             'zh_Hans': self.zh_Hans,
             'en_US': self.en_US,
-        }
+            'pt_BR': self.pt_BR
+        }

+ 21 - 1
api/core/tools/entities/tool_entities.py

@@ -304,4 +304,24 @@ class ToolRuntimeVariablePool(BaseModel):
             value=value,
         )
 
-        self.pool.append(variable)
+        self.pool.append(variable)
+
+class ModelToolPropertyKey(Enum):
+    """
+    Keys accepted in the `properties` mapping of a model tool configuration
+    """
+    # property holding the name of the tool parameter that identifies the image
+    IMAGE_PARAMETER_NAME = "image_parameter_name"
+
+class ModelToolConfiguration(BaseModel):
+    """
+    Model tool configuration
+
+    Describes a single model that is exposed as a tool: its type, its model name,
+    its display label and a mapping of extra properties keyed by `ModelToolPropertyKey`.
+    """
+    type: str = Field(..., description="The type of the model tool")
+    model: str = Field(..., description="The model")
+    label: I18nObject = Field(..., description="The label of the model tool")
+    properties: dict[ModelToolPropertyKey, Any] = Field(..., description="The properties of the model tool")
+
+class ModelToolProviderConfiguration(BaseModel):
+    """
+    Model tool provider configuration
+
+    Groups the model tool configurations of one provider together with the
+    provider name and a display label.
+    """
+    provider: str = Field(..., description="The provider of the model tool")
+    models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool")
+    # NOTE(review): the description text below says "model tool", but this label sits on the provider
+    label: I18nObject = Field(..., description="The label of the model tool")

+ 1 - 0
api/core/tools/entities/user_entities.py

@@ -13,6 +13,7 @@ class UserToolProvider(BaseModel):
         BUILTIN = "builtin"
         APP = "app"
         API = "api"
+        MODEL = "model"
 
     id: str
     author: str

+ 20 - 0
api/core/tools/model_tools/anthropic.yaml

@@ -0,0 +1,20 @@
+provider: anthropic
+label:
+  en_US: Anthropic Model Tools
+  zh_Hans: Anthropic 模型能力
+  pt_BR: Anthropic Model Tools
+models:
+  - type: llm
+    model: claude-3-sonnet-20240229
+    label:
+      zh_Hans: Claude3 Sonnet 视觉
+      en_US: Claude3 Sonnet Vision
+    properties:
+      image_parameter_name: image_id
+  - type: llm
+    model: claude-3-opus-20240229
+    label:
+      zh_Hans: Claude3 Opus 视觉
+      en_US: Claude3 Opus Vision
+    properties:
+      image_parameter_name: image_id

+ 13 - 0
api/core/tools/model_tools/google.yaml

@@ -0,0 +1,13 @@
+provider: google
+label:
+  en_US: Google Model Tools
+  zh_Hans: Google 模型能力
+  pt_BR: Google Model Tools
+models:
+  - type: llm
+    model: gemini-pro-vision
+    label:
+      zh_Hans: Gemini Pro 视觉
+      en_US: Gemini Pro Vision
+    properties:
+      image_parameter_name: image_id

+ 13 - 0
api/core/tools/model_tools/openai.yaml

@@ -0,0 +1,13 @@
+provider: openai
+label:
+  en_US: OpenAI Model Tools
+  zh_Hans: OpenAI 模型能力
+  pt_BR: OpenAI Model Tools
+models:
+  - type: llm
+    model: gpt-4-vision-preview
+    label:
+      zh_Hans: GPT-4 视觉
+      en_US: GPT-4 Vision
+    properties:
+      image_parameter_name: image_id

+ 13 - 0
api/core/tools/model_tools/zhipuai.yaml

@@ -0,0 +1,13 @@
+provider: zhipuai
+label:
+  en_US: ZhipuAI Model Tools
+  zh_Hans: ZhipuAI 模型能力
+  pt_BR: ZhipuAI Model Tools
+models:
+  - type: llm
+    model: glm-4v
+    label:
+      zh_Hans: GLM-4 视觉
+      en_US: GLM-4 Vision
+    properties:
+      image_parameter_name: image_id

+ 7 - 3
api/core/tools/provider/_position.yaml

@@ -1,14 +1,18 @@
 - google
 - bing
 - duckduckgo
-- yahoo
+- dalle
+- azuredalle
 - wikipedia
+- model.openai
+- model.google
+- model.anthropic
+- yahoo
 - arxiv
 - pubmed
-- dalle
-- azuredalle
 - stablediffusion
 - webscraper
+- model.zhipuai
 - aippt
 - youtube
 - wolframalpha

+ 9 - 9
api/core/tools/provider/builtin/_positions.py

@@ -4,24 +4,24 @@ from yaml import FullLoader, load
 
 from core.tools.entities.user_entities import UserToolProvider
 
-position = {}
 
 class BuiltinToolProviderSort:
-    @staticmethod
-    def sort(providers: list[UserToolProvider]) -> list[UserToolProvider]:
-        global position
-        if not position:
+    _position = {}
+
+    @classmethod
+    def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
+        if not cls._position:
             tmp_position = {}
             file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
             with open(file_path) as f:
                 for pos, val in enumerate(load(f, Loader=FullLoader)):
                     tmp_position[val] = pos
-            position = tmp_position
+            cls._position = tmp_position
 
         def sort_compare(provider: UserToolProvider) -> int:
-            # if provider.type == UserToolProvider.ProviderType.MODEL:
-            #     return position.get(f'model_provider.{provider.name}', 10000)
-            return position.get(provider.name, 10000)
+            if provider.type == UserToolProvider.ProviderType.MODEL:
+                return cls._position.get(f'model.{provider.name}', 10000)
+            return cls._position.get(provider.name, 10000)
         
         sorted_providers = sorted(providers, key=sort_compare)
 

+ 237 - 0
api/core/tools/provider/model_tool_provider.py

@@ -0,0 +1,237 @@
+from typing import Any
+
+from core.entities.model_entities import ModelStatus
+from core.errors.error import ProviderTokenNotInitError
+from core.model_manager import ModelInstance
+from core.model_runtime.entities.model_entities import ModelFeature, ModelType
+from core.provider_manager import ProviderConfiguration, ProviderManager, ProviderModelBundle
+from core.tools.entities.common_entities import I18nObject
+from core.tools.entities.tool_entities import (
+    ModelToolPropertyKey,
+    ToolDescription,
+    ToolIdentity,
+    ToolParameter,
+    ToolProviderCredentials,
+    ToolProviderIdentity,
+    ToolProviderType,
+)
+from core.tools.errors import ToolNotFoundError
+from core.tools.provider.tool_provider import ToolProviderController
+from core.tools.tool.model_tool import ModelTool
+from core.tools.tool.tool import Tool
+from core.tools.utils.configuration import ModelToolConfigurationManager
+
+
+class ModelToolProviderController(ToolProviderController):
+    configuration: ProviderConfiguration = None
+    is_active: bool = False
+
+    def __init__(self, configuration: ProviderConfiguration = None, **kwargs):
+        """
+            init the provider
+
+            :param configuration: the model provider configuration bound to this controller, may be None
+            :param kwargs: all remaining fields, forwarded to the parent initializer
+        """
+        super().__init__(**kwargs)
+        self.configuration = configuration
+
+    @staticmethod
+    def from_db(configuration: ProviderConfiguration = None) -> 'ModelToolProviderController':
+        """
+            build a provider controller from a model provider configuration
+
+            :param configuration: the configuration of the provider
+            :return: the controller, or None when no configuration is given
+            :raises RuntimeError: when no model tool configuration exists for the provider
+        """
+        if configuration is None:
+            return None
+
+        # the provider counts as active only when every one of its models is active
+        is_active = all(
+            provider_model.status == ModelStatus.ACTIVE
+            for provider_model in configuration.get_provider_models()
+        )
+
+        # look the model tool configuration of this provider up
+        tool_configuration = ModelToolConfigurationManager.get_configuration(configuration.provider.provider)
+        if tool_configuration is None:
+            raise RuntimeError(f'no configuration found for provider {configuration.provider.provider}')
+
+        # labels declared by the model tool configuration replace the provider labels in place
+        label_override = tool_configuration.label
+        if label_override:
+            if label_override.en_US:
+                configuration.provider.label.en_US = label_override.en_US
+            if label_override.zh_Hans:
+                configuration.provider.label.zh_Hans = label_override.zh_Hans
+
+        provider_label = configuration.provider.label
+        identity = ToolProviderIdentity(
+            author='Dify',
+            name=configuration.provider.provider,
+            description=I18nObject(
+                zh_Hans=f'{provider_label.zh_Hans} 模型能力提供商',
+                en_US=f'{provider_label.en_US} model capability provider'
+            ),
+            label=I18nObject(
+                zh_Hans=provider_label.zh_Hans,
+                en_US=provider_label.en_US
+            ),
+            icon=configuration.provider.icon_small.en_US,
+        )
+
+        return ModelToolProviderController(
+            is_active=is_active,
+            identity=identity,
+            configuration=configuration,
+            credentials_schema={},
+        )
+    
+    @staticmethod
+    def is_configuration_valid(configuration: ProviderConfiguration) -> bool:
+        """
+            tell whether the configuration offers at least one model usable as a tool,
+            that is an LLM model with the vision feature
+
+            :param configuration: the configuration of the provider
+            :return: True when such a model exists, False otherwise
+        """
+        return any(
+            candidate.model_type == ModelType.LLM and ModelFeature.VISION in (candidate.features or [])
+            for candidate in configuration.get_provider_models()
+        )
+
+    def _get_model_tools(self, tenant_id: str = None) -> list[ModelTool]:
+        """
+            returns a list of tools that the provider can provide
+
+            :param tenant_id: the tenant used to look the configuration up when none is bound yet,
+                              a placeholder id is used when omitted
+            :return: list of tools, the list is also stored on `self.tools`
+            :raises RuntimeError: when no model tool configuration exists for the provider
+        """
+        tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
+        provider_manager = ProviderManager()
+        if self.configuration is None:
+            configurations = provider_manager.get_configurations(tenant_id=tenant_id).values()
+            # `identity.name` holds the provider name string (it is filled from
+            # `configuration.provider.provider`), so it has to be compared with the name
+            # and not with the provider entity itself
+            self.configuration = next(
+                (x for x in configurations if x.provider.provider == self.identity.name),
+                None
+            )
+        # get all tools
+        tools: list[ModelTool] = []
+        # without a configuration there is nothing to expose
+        if not self.configuration:
+            return tools
+        configuration = self.configuration
+
+        provider_configuration = ModelToolConfigurationManager.get_configuration(configuration.provider.provider)
+        if provider_configuration is None:
+            raise RuntimeError(f'no configuration found for provider {configuration.provider.provider}')
+
+        for model in configuration.get_provider_models():
+            # only models declared in the model tool configuration are exposed
+            model_configuration = ModelToolConfigurationManager.get_model_configuration(configuration.provider.provider, model.model)
+            if model_configuration is None:
+                continue
+
+            if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []):
+                provider_instance = configuration.get_provider_instance()
+                model_type_instance = provider_instance.get_model_instance(model.model_type)
+                provider_model_bundle = ProviderModelBundle(
+                    configuration=configuration,
+                    provider_instance=provider_instance,
+                    model_type_instance=model_type_instance
+                )
+
+                # a provider without credentials still yields a tool, but without a model instance
+                try:
+                    model_instance = ModelInstance(provider_model_bundle, model.model)
+                except ProviderTokenNotInitError:
+                    model_instance = None
+
+                tools.append(ModelTool(
+                    identity=ToolIdentity(
+                        author='Dify',
+                        name=model.model,
+                        label=model_configuration.label,
+                    ),
+                    parameters=[
+                        ToolParameter(
+                            # NOTE(review): the parameter is named after the property key
+                            # ("image_parameter_name"), not after the configured property value
+                            # (e.g. "image_id") that the tool looks up when invoked - confirm intent
+                            name=ModelToolPropertyKey.IMAGE_PARAMETER_NAME.value,
+                            label=I18nObject(zh_Hans='图片ID', en_US='Image ID'),
+                            human_description=I18nObject(zh_Hans='图片ID', en_US='Image ID'),
+                            type=ToolParameter.ToolParameterType.STRING,
+                            form=ToolParameter.ToolParameterForm.LLM,
+                            required=True,
+                            default=Tool.VARIABLE_KEY.IMAGE.value
+                        )
+                    ],
+                    description=ToolDescription(
+                        human=I18nObject(zh_Hans='图生文工具', en_US='Convert image to text'),
+                        llm='Vision tool used to extract text and other visual information from images, can be used for OCR, image captioning, etc.',
+                    ),
+                    is_team_authorization=model.status == ModelStatus.ACTIVE,
+                    tool_type=ModelTool.ModelToolType.VISION,
+                    model_instance=model_instance,
+                    model=model.model,
+                    # hand the configured properties over, the tool reads them when it is invoked
+                    properties=model_configuration.properties,
+                ))
+
+        self.tools = tools
+        return tools
+    
+    def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:
+        """
+            returns the credentials schema of the provider
+
+            :return: the credentials schema, always an empty dict for this provider type
+        """
+        return {}
+
+    def get_tools(self, user_id: str, tenant_id: str) -> list[ModelTool]:
+        """
+            returns a list of tools that the provider can provide
+
+            :param user_id: the id of the user, not used by this implementation
+            :param tenant_id: the id of the tenant, forwarded to `_get_model_tools`
+            :return: list of tools
+        """
+        return self._get_model_tools(tenant_id=tenant_id)
+    
+    def get_tool(self, tool_name: str) -> ModelTool:
+        """
+            look a tool up by name, loading the tool list first when it is not loaded yet
+
+            :param tool_name: the name of the tool
+            :return: the tool
+            :raises ValueError: when no tool carries that name
+        """
+        if self.tools is None:
+            self.get_tools(user_id='', tenant_id=self.configuration.tenant_id)
+
+        matched = next(
+            (candidate for candidate in self.tools if candidate.identity.name == tool_name),
+            None
+        )
+        if matched is None:
+            raise ValueError(f'tool {tool_name} not found')
+        return matched
+
+    def get_parameters(self, tool_name: str) -> list[ToolParameter]:
+        """
+            returns the parameters of the tool
+
+            :param tool_name: the name of the tool, defined in `get_tools`
+            :return: list of parameters
+            :raises ToolNotFoundError: when the provider has no tool with that name
+        """
+        # `get_tools` requires both user_id and tenant_id, calling it without arguments
+        # raised a TypeError. The tenant comes from the bound configuration when there is
+        # one, otherwise None lets `_get_model_tools` fall back to its placeholder tenant id.
+        tenant_id = self.configuration.tenant_id if self.configuration else None
+        tools = self.get_tools(user_id='', tenant_id=tenant_id)
+
+        tool = next((candidate for candidate in tools if candidate.identity.name == tool_name), None)
+        if tool is None:
+            raise ToolNotFoundError(f'tool {tool_name} not found')
+        return tool.parameters
+
+    @property
+    def app_type(self) -> ToolProviderType:
+        """
+            returns the type of the provider
+
+            :return: type of the provider, always `ToolProviderType.MODEL`
+        """
+        return ToolProviderType.MODEL
+    
+    def validate_credentials(self, credentials: dict[str, Any]) -> None:
+        """
+            validate the credentials of the provider
+
+            this implementation performs no check and accepts any input
+
+            :param credentials: the credentials of the provider, ignored
+        """
+        pass
+
+    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
+        """
+            validate the credentials of the provider
+
+            this implementation performs no check and accepts any input
+
+            :param credentials: the credentials of the provider, ignored
+        """
+        pass

+ 156 - 0
api/core/tools/tool/model_tool.py

@@ -0,0 +1,156 @@
+from base64 import b64encode
+from enum import Enum
+from typing import Any, cast
+
+from core.model_manager import ModelInstance
+from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.message_entities import (
+    PromptMessageContent,
+    PromptMessageContentType,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.tools.entities.tool_entities import ModelToolPropertyKey, ToolInvokeMessage
+from core.tools.tool.tool import Tool
+
+VISION_PROMPT = """## Image Recognition Task
+### Task Description
+I require a powerful vision language model for an image recognition task. The model should be capable of extracting various details from the images, including but not limited to text content, layout distribution, color distribution, main subjects, and emotional expressions.
+### Specific Requirements
+1. **Text Content Extraction:** Ensure that the model accurately recognizes and extracts text content from the images, regardless of text size, font, or color.
+2. **Layout Distribution Analysis:** The model should analyze the layout structure of the images, capturing the relationships between various elements and providing detailed information about the image layout.
+3. **Color Distribution Analysis:** Extract information about color distribution in the images, including primary colors, color combinations, and other relevant details.
+4. **Main Subject Recognition:** The model should accurately identify the main subjects in the images and provide detailed descriptions of these subjects.
+5. **Emotional Expression Analysis:** Analyze and describe the emotions or expressions conveyed in the images based on facial expressions, postures, and other relevant features.
+### Additional Considerations
+- Ensure that the extracted information is as comprehensive and accurate as possible.
+- For each task, provide confidence scores or relevance scores for the model outputs to assess the reliability of the results.
+- If necessary, pose specific questions for different tasks to guide the model in better understanding the images and providing relevant information."""
+
+class ModelTool(Tool):
+    class ModelToolType(Enum):
+        """
+            the type of the model tool
+        """
+        VISION = 'vision'
+
+    model_configuration: dict[str, Any] = None
+    tool_type: ModelToolType
+    
+    def __init__(self, model_instance: ModelInstance = None, model: str = None, 
+                 tool_type: ModelToolType = ModelToolType.VISION, 
+                 properties: dict[ModelToolPropertyKey, Any] = None,
+                 **kwargs):
+        """
+            init the tool
+
+            :param model_instance: the model instance used to run the tool, may be None
+            :param model: the name of the model
+            :param tool_type: the type of the model tool
+            :param properties: the properties of the model tool, may be None
+            :param kwargs: all remaining fields, forwarded to the parent initializer
+        """
+        # the model related arguments are bundled into the `model_configuration` field
+        kwargs['model_configuration'] = {
+            'model_instance': model_instance,
+            'model': model,
+            'properties': properties
+        }
+        kwargs['tool_type'] = tool_type
+        super().__init__(**kwargs)
+
+    """
+    Model tool
+    """
+    def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
+        """
+            fork a new tool with meta data
+
+            :param meta: the meta data of a tool call processing, tenant_id is required
+            :return: the new tool
+        """
+        return self.__class__(
+            identity=self.identity.copy() if self.identity else None,
+            parameters=self.parameters.copy() if self.parameters else None,
+            description=self.description.copy() if self.description else None,
+            model_instance=self.model_configuration['model_instance'],
+            model=self.model_configuration['model'],
+            # carry the properties over as well, otherwise the forked tool is always
+            # built with properties=None and loses its configuration
+            properties=self.model_configuration.get('properties'),
+            tool_type=self.tool_type,
+            runtime=Tool.Runtime(**meta)
+        )
+
+    def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> None:
+        """
+            validate the credentials for Model tool
+
+            this implementation performs no check, all arguments are ignored
+        """
+        pass
+
+    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
+        """
+            invoke the tool, dispatching on the tool type
+
+            :param user_id: the id of the user
+            :param tool_parameters: the parameters of the tool call
+            :return: the result of the vision invocation, or a text message saying the tool
+                     is not configured correctly when there is no model instance or the
+                     tool type is not handled
+        """
+        model_instance = self.model_configuration['model_instance']
+        if not model_instance:
+            return self.create_text_message('the tool is not configured correctly')
+        
+        if self.tool_type == ModelTool.ModelToolType.VISION:
+            return self._invoke_llm_vision(user_id, tool_parameters)
+        else:
+            return self.create_text_message('the tool is not configured correctly')
+        
+    def _invoke_llm_vision(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
+        """
+            ask the vision model to describe an image
+
+            :param user_id: the id of the user, forwarded to the model call
+            :param tool_parameters: the parameters of the tool call, the image parameter is removed
+                                    from it and the remainder is passed on as model parameters
+            :return: a text message holding the model output or an error description
+        """
+        # `properties` is None when the tool was built without it, fall back to the
+        # default parameter name instead of failing on `None.get`
+        properties = self.model_configuration.get('properties') or {}
+        image_parameter_name = properties.get(ModelToolPropertyKey.IMAGE_PARAMETER_NAME, 'image_id')
+        image_id = tool_parameters.pop(image_parameter_name, '')
+
+        # resolve the requested variable first, then fall back to the default image variable
+        image = self.get_variable(image_id) if image_id else None
+        if not image:
+            image = self.get_default_image_variable()
+        if not image:
+            return self.create_text_message('Please upload an image or input image_id')
+
+        # load the binary content of the image
+        image_binary = self.get_variable_file(image.name)
+        if not image_binary:
+            return self.create_text_message('Failed to get image')
+
+        # organize prompt messages
+        prompt_messages = [
+            SystemPromptMessage(
+                content=VISION_PROMPT
+            ),
+            UserPromptMessage(
+                content=[
+                    PromptMessageContent(
+                        type=PromptMessageContentType.TEXT,
+                        data='Recognize the image and extract the information from the image.'
+                    ),
+                    PromptMessageContent(
+                        type=PromptMessageContentType.IMAGE,
+                        # NOTE(review): the mime type is fixed to image/png whatever the
+                        # real format of the file is - confirm the providers accept this
+                        data=f'data:image/png;base64,{b64encode(image_binary).decode("utf-8")}'
+                    )
+                ]
+            )
+        ]
+
+        # NOTE(review): the stored object is annotated as ModelInstance but is used through
+        # the LargeLanguageModel interface here - confirm `invoke` exists with this signature
+        llm_instance = cast(LargeLanguageModel, self.model_configuration['model_instance'])
+        result: LLMResult = llm_instance.invoke(
+            model=self.model_configuration['model'],
+            credentials=self.runtime.credentials,
+            prompt_messages=prompt_messages,
+            model_parameters=tool_parameters,
+            tools=[],
+            stop=[],
+            stream=False,
+            user=user_id,
+        )
+
+        if not result:
+            return self.create_text_message('Failed to extract information from the image')
+
+        # get result
+        content = result.message.content
+        if not content:
+            return self.create_text_message('Failed to extract information from the image')
+
+        return self.create_text_message(content)

+ 82 - 2
api/core/tools/tool_manager.py

@@ -7,6 +7,7 @@ from typing import Any, Union
 
 from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
 from core.model_runtime.entities.message_entities import PromptMessage
+from core.provider_manager import ProviderManager
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.constant import DEFAULT_PROVIDERS
 from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeMessage, ToolProviderCredentials
@@ -16,10 +17,11 @@ from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
 from core.tools.provider.app_tool_provider import AppBasedToolProviderEntity
 from core.tools.provider.builtin._positions import BuiltinToolProviderSort
 from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+from core.tools.provider.model_tool_provider import ModelToolProviderController
 from core.tools.provider.tool_provider import ToolProviderController
 from core.tools.tool.api_tool import ApiTool
 from core.tools.tool.builtin_tool import BuiltinTool
-from core.tools.utils.configuration import ToolConfiguration
+from core.tools.utils.configuration import ModelToolConfigurationManager, ToolConfiguration
 from core.tools.utils.encoder import serialize_base_model_dict
 from extensions.ext_database import db
 from models.tools import ApiToolProvider, BuiltinToolProvider
@@ -135,7 +137,7 @@ class ToolManager:
             raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
         
     @staticmethod
-    def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id, 
+    def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str, 
                          agent_callback: DifyAgentCallbackHandler = None) \
         -> Union[BuiltinTool, ApiTool]:
         """
@@ -194,6 +196,19 @@ class ToolManager:
                 'tenant_id': tenant_id,
                 'credentials': decrypted_credentials,
             })
+        elif provider_type == 'model':
+            if tenant_id is None:
+                raise ValueError('tenant id is required for model provider')
+            # get model provider
+            model_provider = ToolManager.get_model_provider(tenant_id, provider_name)
+
+            # get tool
+            model_tool = model_provider.get_tool(tool_name)
+
+            return model_tool.fork_tool_runtime(meta={
+                'tenant_id': tenant_id,
+                'credentials': model_tool.model_configuration['model_instance'].credentials
+            })
         elif provider_type == 'app':
             raise NotImplementedError('app provider not implemented')
         else:
@@ -266,6 +281,49 @@ class ToolManager:
 
         return builtin_providers
     
+    @staticmethod
+    def list_model_providers(tenant_id: str = None) -> list[ModelToolProviderController]:
+        """
+            list the model providers that can be used as tool providers
+
+            a provider is listed when it has a model tool configuration and
+            `ModelToolProviderController.is_configuration_valid` accepts it
+
+            :param tenant_id: the id of the tenant, a placeholder id is used when omitted
+            :return: the list of the model providers
+        """
+        tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
+        # providers that ship a model tool configuration
+        model_configurations = ModelToolConfigurationManager.get_all_configuration()
+        # provider configurations of the tenant
+        configurations = ProviderManager().get_configurations(tenant_id).values()
+
+        return [
+            ModelToolProviderController.from_db(configuration)
+            for configuration in configurations
+            if configuration.provider.provider in model_configurations
+            and ModelToolProviderController.is_configuration_valid(configuration)
+        ]
+    
+    @staticmethod
+    def get_model_provider(tenant_id: str, provider_name: str) -> ModelToolProviderController:
+        """
+            get the model provider
+
+            :param tenant_id: the id of the tenant whose provider configurations are searched
+            :param provider_name: the name of the provider
+
+            :return: the provider
+            :raises ToolProviderNotFoundError: when the tenant has no configuration for the provider
+        """
+        # get configurations
+        provider_manager = ProviderManager()
+        configurations = provider_manager.get_configurations(tenant_id)
+        configuration = configurations.get(provider_name)
+        if configuration is None:
+            raise ToolProviderNotFoundError(f'model provider {provider_name} not found')
+        
+        return ModelToolProviderController.from_db(configuration)
+
     @staticmethod
     def get_tool_label(tool_name: str) -> Union[I18nObject, None]:
         """
@@ -345,6 +403,28 @@ class ToolManager:
 
             result_providers[provider_name].team_credentials = masked_credentials
 
+        # get model tool providers
+        model_providers = ToolManager.list_model_providers(tenant_id=tenant_id)
+        # append model providers
+        for provider in model_providers:
+            result_providers[f'model_provider.{provider.identity.name}'] = UserToolProvider(
+                id=provider.identity.name,
+                author=provider.identity.author,
+                name=provider.identity.name,
+                description=I18nObject(
+                    en_US=provider.identity.description.en_US,
+                    zh_Hans=provider.identity.description.zh_Hans,
+                ),
+                icon=provider.identity.icon,
+                label=I18nObject(
+                    en_US=provider.identity.label.en_US,
+                    zh_Hans=provider.identity.label.zh_Hans,
+                ),
+                type=UserToolProvider.ProviderType.MODEL,
+                team_credentials={},
+                is_team_authorization=provider.is_active,
+            )
+
         # get db api providers
         db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
             filter(ApiToolProvider.tenant_id == tenant_id).all()

+ 70 - 2
api/core/tools/utils/configuration.py

@@ -1,10 +1,16 @@
-from typing import Any
+import os
+from typing import Any, Union
 
 from pydantic import BaseModel
+from yaml import FullLoader, load
 
 from core.helper import encrypter
 from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
-from core.tools.entities.tool_entities import ToolProviderCredentials
+from core.tools.entities.tool_entities import (
+    ModelToolConfiguration,
+    ModelToolProviderConfiguration,
+    ToolProviderCredentials,
+)
 from core.tools.provider.tool_provider import ToolProviderController
 
 
@@ -94,3 +100,65 @@ class ToolConfiguration(BaseModel):
             cache_type=ToolProviderCredentialsCacheType.PROVIDER
         )
         cache.delete()
+
+class ModelToolConfigurationManager:
+    """
+    Model as tool configuration
+    """
+    _configurations: dict[str, ModelToolProviderConfiguration] = {}
+    _model_configurations: dict[str, ModelToolConfiguration] = {}
+    _inited = False
+
+    @classmethod
+    def _init_configuration(cls):
+        """
+        init configuration
+        """
+        
+        absolute_path = os.path.abspath(os.path.dirname(__file__))
+        model_tools_path = os.path.join(absolute_path, '..', 'model_tools')
+
+        # get all .yaml file
+        files = [f for f in os.listdir(model_tools_path) if f.endswith('.yaml')]
+
+        for file in files:
+            provider = file.split('.')[0]
+            with open(os.path.join(model_tools_path, file), encoding='utf-8') as f:
+                configurations = ModelToolProviderConfiguration(**load(f, Loader=FullLoader))
+                models = configurations.models or []
+                for model in models:
+                    model_key = f'{provider}.{model.model}'
+                    cls._model_configurations[model_key] = model
+
+                cls._configurations[provider] = configurations
+        cls._inited = True
+
+    @classmethod
+    def get_configuration(cls, provider: str) -> Union[ModelToolProviderConfiguration, None]:
+        """
+        get configuration by provider
+        """
+        if not cls._inited:
+            cls._init_configuration()
+        return cls._configurations.get(provider, None)
+    
+    @classmethod
+    def get_all_configuration(cls) -> dict[str, ModelToolProviderConfiguration]:
+        """
+        get all configurations
+        """
+        if not cls._inited:
+            cls._init_configuration()
+        return cls._configurations
+    
+    @classmethod
+    def get_model_configuration(cls, provider: str, model: str) -> Union[ModelToolConfiguration, None]:
+        """
+        get model configuration
+        """
+        key = f'{provider}.{model}'
+
+        if not cls._inited:
+            cls._init_configuration()
+
+        return cls._model_configurations.get(key, None)

+ 45 - 2
api/services/tools_manage_service.py

@@ -22,6 +22,7 @@ from core.tools.utils.encoder import serialize_base_model_array, serialize_base_
 from core.tools.utils.parser import ApiBasedToolSchemaParser
 from extensions.ext_database import db
 from models.tools import ApiToolProvider, BuiltinToolProvider
+from services.model_provider_service import ModelProviderService
 
 
 class ToolManageService:
@@ -50,11 +51,13 @@ class ToolManageService:
             :param provider: the provider dict
         """
         url_prefix = (current_app.config.get("CONSOLE_API_URL")
-                      + "/console/api/workspaces/current/tool-provider/builtin/")
+                      + "/console/api/workspaces/current/tool-provider/")
         
         if 'icon' in provider:
             if provider['type'] == UserToolProvider.ProviderType.BUILTIN.value:
-                provider['icon'] = url_prefix + provider['name'] + '/icon'
+                provider['icon'] = url_prefix + 'builtin/' + provider['name'] + '/icon'
+            elif provider['type'] == UserToolProvider.ProviderType.MODEL.value:
+                provider['icon'] = url_prefix + 'model/' + provider['name'] + '/icon'
             elif provider['type'] == UserToolProvider.ProviderType.API.value:
                 try:
                     provider['icon'] = json.loads(provider['icon'])
@@ -505,6 +508,46 @@ class ToolManageService:
 
         return icon_bytes, mime_type
     
+    @staticmethod
+    def get_model_tool_provider_icon(
+        provider: str
+    ):
+        """
+            get tool provider icon and it's mimetype
+        """
+        
+        service = ModelProviderService()
+        icon_bytes, mime_type = service.get_model_provider_icon(provider=provider, icon_type='icon_small', lang='en_US')
+
+        if icon_bytes is None:
+            raise ValueError(f'provider {provider} does not exists')
+
+        return icon_bytes, mime_type
+    
+    @staticmethod
+    def list_model_tool_provider_tools(
+        user_id: str, tenant_id: str, provider: str
+    ):
+        """
+            list model tool provider tools
+        """
+        provider_controller = ToolManager.get_model_provider(tenant_id=tenant_id, provider_name=provider)
+        tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
+
+        result = [
+            UserTool(
+                author=tool.identity.author,
+                name=tool.identity.name,
+                label=tool.identity.label,
+                description=tool.description.human,
+                parameters=tool.parameters or []
+            ) for tool in tools
+        ]
+
+        return json.loads(
+            serialize_base_model_array(result)
+        )
+    
     @staticmethod
     def delete_api_tool_provider(
         user_id: str, tenant_id: str, provider_name: str

+ 1 - 1
web/app/components/app/configuration/config/agent/agent-tools/index.tsx

@@ -34,7 +34,7 @@ const AgentTools: FC = () => {
   const [selectedProviderId, setSelectedProviderId] = useState<string | undefined>(undefined)
   const [isShowSettingTool, setIsShowSettingTool] = useState(false)
   const tools = (modelConfig?.agentConfig?.tools as AgentTool[] || []).map((item) => {
-    const collection = collectionList.find(collection => collection.id === item.provider_id)
+    const collection = collectionList.find(collection => collection.id === item.provider_id && collection.type === item.provider_type)
     const icon = collection?.icon
     return {
       ...item,

+ 8 - 2
web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx

@@ -8,7 +8,7 @@ import Drawer from '@/app/components/base/drawer-plus'
 import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form'
 import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
 import type { Collection, Tool } from '@/app/components/tools/types'
-import { fetchBuiltInToolList, fetchCustomToolList } from '@/service/tools'
+import { fetchBuiltInToolList, fetchCustomToolList, fetchModelToolList } from '@/service/tools'
 import I18n from '@/context/i18n'
 import Button from '@/app/components/base/button'
 import Loading from '@/app/components/base/loading'
@@ -19,6 +19,7 @@ import AppIcon from '@/app/components/base/app-icon'
 type Props = {
   collection: Collection
   isBuiltIn?: boolean
+  isModel?: boolean
   toolName: string
   setting?: Record<string, any>
   readonly?: boolean
@@ -29,6 +30,7 @@ type Props = {
 const SettingBuiltInTool: FC<Props> = ({
   collection,
   isBuiltIn = true,
+  isModel = true,
   toolName,
   setting = {},
   readonly,
@@ -56,7 +58,11 @@ const SettingBuiltInTool: FC<Props> = ({
     (async () => {
       setIsLoading(true)
       try {
-        const list = isBuiltIn ? await fetchBuiltInToolList(collection.name) : await fetchCustomToolList(collection.name)
+        const list = isBuiltIn
+          ? await fetchBuiltInToolList(collection.name)
+          : isModel
+            ? await fetchModelToolList(collection.name)
+            : await fetchCustomToolList(collection.name)
         setTools(list)
         const currTool = list.find(tool => tool.name === toolName)
         if (currTool) {

+ 12 - 6
web/app/components/tools/index.tsx

@@ -18,7 +18,7 @@ import NoSearchRes from './info/no-search-res'
 import NoCustomToolPlaceholder from './no-custom-tool-placeholder'
 import { useTabSearchParams } from '@/hooks/use-tab-searchparams'
 import TabSlider from '@/app/components/base/tab-slider'
-import { createCustomCollection, fetchCollectionList as doFetchCollectionList, fetchBuiltInToolList, fetchCustomToolList } from '@/service/tools'
+import { createCustomCollection, fetchCollectionList as doFetchCollectionList, fetchBuiltInToolList, fetchCustomToolList, fetchModelToolList } from '@/service/tools'
 import type { AgentTool } from '@/types/app'
 
 type Props = {
@@ -89,9 +89,11 @@ const Tools: FC<Props> = ({
   const showCollectionList = (() => {
     let typeFilteredList: Collection[] = []
     if (collectionType === CollectionType.all)
-      typeFilteredList = collectionList
-    else
-      typeFilteredList = collectionList.filter(item => item.type === collectionType)
+      typeFilteredList = collectionList.filter(item => item.type !== CollectionType.model)
+    else if (collectionType === CollectionType.builtIn)
+      typeFilteredList = collectionList.filter(item => item.type === CollectionType.builtIn)
+    else if (collectionType === CollectionType.custom)
+      typeFilteredList = collectionList.filter(item => item.type === CollectionType.custom)
     if (query)
       return typeFilteredList.filter(item => item.name.includes(query))
 
@@ -122,6 +124,10 @@ const Tools: FC<Props> = ({
           const list = await fetchBuiltInToolList(currCollection.name)
           setCurrentTools(list)
         }
+        else if (currCollection.type === CollectionType.model) {
+          const list = await fetchModelToolList(currCollection.name)
+          setCurrentTools(list)
+        }
         else {
           const list = await fetchCustomToolList(currCollection.name)
           setCurrentTools(list)
@@ -130,7 +136,7 @@ const Tools: FC<Props> = ({
       catch (e) { }
       setIsDetailLoading(false)
     })()
-  }, [currCollection?.name])
+  }, [currCollection?.name, currCollection?.type])
 
   const [isShowEditCollectionToolModal, setIsShowEditCollectionToolModal] = useState(false)
   const handleCreateToolCollection = () => {
@@ -197,7 +203,7 @@ const Tools: FC<Props> = ({
               (showCollectionList.length > 0 || !query)
                 ? <ToolNavList
                   className='mt-2 grow height-0 overflow-y-auto'
-                  currentName={currCollection?.name || ''}
+                  currentIndex={currCollectionIndex || 0}
                   list={showCollectionList}
                   onChosen={setCurrCollectionIndex}
                 />

+ 6 - 4
web/app/components/tools/tool-list/header.tsx

@@ -29,9 +29,8 @@ const Header: FC<Props> = ({
   const { t } = useTranslation()
   const isInToolsPage = loc === LOC.tools
   const isInDebugPage = !isInToolsPage
-  const needAuth = collection?.allow_delete
 
-  // const isBuiltIn = collection.type === CollectionType.builtIn
+  const needAuth = collection?.allow_delete || collection?.type === CollectionType.model
   const isAuthed = collection.is_team_authorization
   return (
     <div className={cn(isInToolsPage ? 'py-4 px-6' : 'py-[11px] pl-4 pr-3', 'flex justify-between items-start border-b border-gray-200')}>
@@ -50,10 +49,13 @@ const Header: FC<Props> = ({
           )}
         </div>
       </div>
-      {collection.type === CollectionType.builtIn && needAuth && (
+      {(collection.type === CollectionType.builtIn || collection.type === CollectionType.model) && needAuth && (
         <div
           className={cn('cursor-pointer', 'ml-1 shrink-0 flex items-center h-8 border border-gray-200 rounded-lg px-3 space-x-2 shadow-xs')}
-          onClick={() => onShowAuth()}
+          onClick={() => {
+            if (collection.type === CollectionType.builtIn || collection.type === CollectionType.model)
+              onShowAuth()
+          }}
         >
           <div className={cn(isAuthed ? 'border-[#12B76A] bg-[#32D583]' : 'border-gray-400 bg-gray-300', 'rounded h-2 w-2 border')}></div>
           <div className='leading-5 text-sm font-medium text-gray-700'>{t(`tools.auth.${isAuthed ? 'authorized' : 'unauthorized'}`)}</div>

+ 30 - 4
web/app/components/tools/tool-list/index.tsx

@@ -8,6 +8,7 @@ import type { Collection, CustomCollectionBackend, Tool } from '../types'
 import Loading from '../../base/loading'
 import { ArrowNarrowRight } from '../../base/icons/src/vender/line/arrows'
 import Toast from '../../base/toast'
+import { ConfigurateMethodEnum } from '../../header/account-setting/model-provider-page/declarations'
 import Header from './header'
 import Item from './item'
 import AppIcon from '@/app/components/base/app-icon'
@@ -16,6 +17,8 @@ import { fetchCustomCollection, removeBuiltInToolCredential, removeCustomCollect
 import EditCustomToolModal from '@/app/components/tools/edit-custom-collection-modal'
 import type { AgentTool } from '@/types/app'
 import { MAX_TOOLS_NUM } from '@/config'
+import { useModalContext } from '@/context/modal-context'
+import { useProviderContext } from '@/context/provider-context'
 
 type Props = {
   collection: Collection | null
@@ -42,9 +45,32 @@ const ToolList: FC<Props> = ({
   const { t } = useTranslation()
   const isInToolsPage = loc === LOC.tools
   const isBuiltIn = collection?.type === CollectionType.builtIn
+  const isModel = collection?.type === CollectionType.model
   const needAuth = collection?.allow_delete
 
+  const { setShowModelModal } = useModalContext()
   const [showSettingAuth, setShowSettingAuth] = useState(false)
+  const { modelProviders: providers } = useProviderContext()
+  // Open the credential UI for the collection: the model provider modal for
+  // model collections, the tool auth setting for every other type.
+  const showSettingAuthModal = () => {
+    if (!isModel) {
+      setShowSettingAuth(true)
+      return
+    }
+    const provider = providers.find(item => item.provider === collection?.id)
+    // no model provider matches the collection id: nothing to open
+    if (!provider)
+      return
+    setShowModelModal({
+      payload: {
+        currentProvider: provider,
+        currentConfigurateMethod: ConfigurateMethodEnum.predefinedModel,
+        currentCustomConfigrationModelFixedFields: undefined,
+      },
+      onSaveCallback: () => { onRefreshData() },
+    })
+  }
 
   const [customCollection, setCustomCollection] = useState<CustomCollectionBackend | null>(null)
   useEffect(() => {
@@ -116,7 +142,7 @@ const ToolList: FC<Props> = ({
         icon={icon}
         collection={collection}
         loc={loc}
-        onShowAuth={() => setShowSettingAuth(true)}
+        onShowAuth={() => showSettingAuthModal()}
         onShowEditCustomCollection={() => setIsShowEditCustomCollectionModal(true)}
       />
       <div className={cn(isInToolsPage ? 'px-6 pt-4' : 'px-4 pt-3')}>
@@ -124,12 +150,12 @@ const ToolList: FC<Props> = ({
           <div className=''>{t('tools.includeToolNum', {
             num: list.length,
           })}</div>
-          {needAuth && isBuiltIn && !collection.is_team_authorization && (
+          {needAuth && (isBuiltIn || isModel) && !collection.is_team_authorization && (
             <>
               <div>·</div>
               <div
                 className='flex items-center text-[#155EEF] cursor-pointer'
-                onClick={() => setShowSettingAuth(true)}
+                onClick={() => showSettingAuthModal()}
               >
                 <div>{t('tools.auth.setup')}</div>
                 <ArrowNarrowRight className='ml-0.5 w-3 h-3' />
@@ -149,7 +175,7 @@ const ToolList: FC<Props> = ({
               collection={collection}
               isInToolsPage={isInToolsPage}
               isToolNumMax={(addedTools?.length || 0) >= MAX_TOOLS_NUM}
-              added={!!addedTools?.find(v => v.provider_id === collection.id && v.tool_name === item.name)}
+              added={!!addedTools?.find(v => v.provider_id === collection.id && v.provider_type === collection.type && v.tool_name === item.name)}
               onAdd={!isInToolsPage ? tool => onAddTool?.(collection as Collection, tool) : undefined}
             />
           ))}

+ 2 - 0
web/app/components/tools/tool-list/item.tsx

@@ -35,6 +35,7 @@ const Item: FC<Props> = ({
   const language = getLanguage(locale)
 
   const isBuiltIn = collection.type === CollectionType.builtIn
+  const isModel = collection.type === CollectionType.model
   const canShowDetail = isInToolsPage
   const [showDetail, setShowDetail] = useState(false)
   const addBtn = <Button className='shrink-0 flex items-center h-7 !px-3 !text-xs !font-medium !text-gray-700' disabled={added || !collection.is_team_authorization} onClick={() => onAdd?.(payload)}>{t(`common.operation.${added ? 'added' : 'add'}`)}</Button>
@@ -73,6 +74,7 @@ const Item: FC<Props> = ({
             setShowDetail(false)
           }}
           isBuiltIn={isBuiltIn}
+          isModel={isModel}
         />
       )}
     </>

+ 3 - 3
web/app/components/tools/tool-nav-list/index.tsx

@@ -6,21 +6,21 @@ import Item from './item'
 import type { Collection } from '@/app/components/tools/types'
 type Props = {
   className?: string
-  currentName: string
+  currentIndex: number
   list: Collection[]
   onChosen: (index: number) => void
 }
 
 const ToolNavList: FC<Props> = ({
   className,
-  currentName,
+  currentIndex,
   list,
   onChosen,
 }) => {
   return (
     <div className={cn(className)}>
       {list.map((item, index) => (
-        <Item isCurrent={item.name === currentName} key={item.name} payload={item} onClick={() => onChosen(index)}></Item>
+        <Item isCurrent={index === currentIndex} key={index} payload={item} onClick={() => onChosen(index)}></Item>
       ))}
     </div>
   )

+ 1 - 0
web/app/components/tools/types.ts

@@ -26,6 +26,7 @@ export enum CollectionType {
   all = 'all',
   builtIn = 'builtin',
   custom = 'api',
+  model = 'model',
 }
 
 export type Emoji = {

+ 5 - 0
web/service/tools.ts

@@ -12,6 +12,11 @@ export const fetchBuiltInToolList = (collectionName: string) => {
 export const fetchCustomToolList = (collectionName: string) => {
   return get<Tool[]>(`/workspaces/current/tool-provider/api/tools?provider=${collectionName}`)
 }
+
+// Fetch the tools of a model tool provider, identified by its collection name.
+export const fetchModelToolList = (collectionName: string) =>
+  get<Tool[]>(`/workspaces/current/tool-provider/model/tools?provider=${collectionName}`)
+
 export const fetchBuiltInToolCredentialSchema = (collectionName: string) => {
   return get<ToolCredential[]>(`/workspaces/current/tool-provider/builtin/${collectionName}/credentials_schema`)
 }