Przeglądaj źródła

Feat: AIPPT & DynamicToolParamter (#2725)

Yeuoly 1 rok temu
rodzic
commit
27e678480e

+ 1 - 0
api/core/tools/provider/_position.yaml

@@ -9,6 +9,7 @@
 - azuredalle
 - azuredalle
 - stablediffusion
 - stablediffusion
 - webscraper
 - webscraper
+- aippt
 - youtube
 - youtube
 - wolframalpha
 - wolframalpha
 - maths
 - maths

BIN
api/core/tools/provider/builtin/aippt/_assets/icon.png


+ 11 - 0
api/core/tools/provider/builtin/aippt/aippt.py

@@ -0,0 +1,11 @@
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.provider.builtin.aippt.tools.aippt import AIPPTGenerateTool
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+class AIPPTProvider(BuiltinToolProviderController):
+    def _validate_credentials(self, credentials: dict) -> None:
+        try:
+            AIPPTGenerateTool._get_api_token(credentials, user_id='__dify_system__')
+        except Exception as e:
+            raise ToolProviderCredentialValidationError(str(e))

+ 42 - 0
api/core/tools/provider/builtin/aippt/aippt.yaml

@@ -0,0 +1,42 @@
+identity:
+  author: Dify
+  name: aippt
+  label:
+    en_US: AIPPT
+    zh_Hans: AIPPT
+  description:
+    en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop
+    zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底
+  icon: icon.png
+credentials_for_provider:
+  aippt_access_key:
+    type: secret-input
+    required: true
+    label:
+      en_US: AIPPT API key
+      zh_Hans: AIPPT API key
+      pt_BR: AIPPT API key
+    help:
+      en_US: Please input your AIPPT API key
+      zh_Hans: 请输入你的 AIPPT API key
+      pt_BR: Please input your AIPPT API key
+    placeholder:
+      en_US: Please input your AIPPT API key
+      zh_Hans: 请输入你的 AIPPT API key
+      pt_BR: Please input your AIPPT API key
+    url: https://www.aippt.cn
+  aippt_secret_key:
+    type: secret-input
+    required: true
+    label:
+      en_US: AIPPT Secret key
+      zh_Hans: AIPPT Secret key
+      pt_BR: AIPPT Secret key
+    help:
+      en_US: Please input your AIPPT Secret key
+      zh_Hans: 请输入你的 AIPPT Secret key
+      pt_BR: Please input your AIPPT Secret key
+    placeholder:
+      en_US: Please input your AIPPT Secret key
+      zh_Hans: 请输入你的 AIPPT Secret key
+      pt_BR: Please input your AIPPT Secret key

+ 509 - 0
api/core/tools/provider/builtin/aippt/tools/aippt.py

@@ -0,0 +1,509 @@
+from base64 import b64encode
+from hashlib import sha1
+from hmac import new as hmac_new
+from json import loads as json_loads
+from threading import Lock
+from time import sleep, time
+from typing import Any
+
+from httpx import get, post
+from requests import get as requests_get
+from yarl import URL
+
+from core.tools.entities.common_entities import I18nObject
+from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class AIPPTGenerateTool(BuiltinTool):
+    """
+    A tool for generating a ppt
+    """
+
+    _api_base_url = URL('https://co.aippt.cn/api')
+    _api_token_cache = {}
+    _api_token_cache_lock = Lock()
+
+    _task = {}
+    _task_type_map = {
+        'auto': 1,
+        'markdown': 7,
+    }
+
+    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
+        """
+        Invokes the AIPPT generate tool with the given user ID and tool parameters.
+
+        Args:
+            user_id (str): The ID of the user invoking the tool.
+            tool_parameters (dict[str, Any]): The parameters for the tool
+
+        Returns:
+            ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages.
+        """
+        title = tool_parameters.get('title', '')
+        if not title:
+            return self.create_text_message('Please provide a title for the ppt')
+        
+        model = tool_parameters.get('model', 'aippt')
+        if not model:
+            return self.create_text_message('Please provide a model for the ppt')
+        
+        outline = tool_parameters.get('outline', '')
+
+        # create task
+        task_id = self._create_task(
+            type=self._task_type_map['auto' if not outline else 'markdown'],
+            title=title,
+            content=outline,
+            user_id=user_id
+        )
+
+        # get suit
+        color = tool_parameters.get('color')
+        style = tool_parameters.get('style')
+
+        if color == '__default__':
+            color_id = ''
+        else:
+            color_id = int(color.split('-')[1])
+
+        if style == '__default__':
+            style_id = ''
+        else:
+            style_id = int(style.split('-')[1])
+
+        suit_id = self._get_suit(style_id=style_id, colour_id=color_id)
+
+        # generate outline
+        if not outline:
+            self._generate_outline(
+                task_id=task_id,
+                model=model,
+                user_id=user_id
+            )
+
+            # generate content
+            self._generate_content(
+                task_id=task_id,
+                model=model,
+                user_id=user_id
+            )
+
+        # generate ppt
+        _, ppt_url = self._generate_ppt(
+            task_id=task_id,
+            suit_id=suit_id,
+            user_id=user_id
+        )
+
+        return self.create_text_message('''the ppt has been created successfully,'''
+                                 f'''the ppt url is {ppt_url}'''
+                                 '''please give the ppt url to user and direct user to download it.''')
+
+    def _create_task(self, type: int, title: str, content: str, user_id: str) -> str:
+        """
+        Create a task
+
+        :param type: the task type
+        :param title: the task title
+        :param content: the task content
+
+        :return: the task ID
+        """
+        headers = {
+            'x-channel': '',
+            'x-api-key': self.runtime.credentials['aippt_access_key'],
+            'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
+        }
+        response = post(
+            str(self._api_base_url / 'ai' / 'chat' / 'v2' / 'task'),
+            headers=headers,
+            files={
+                'type': ('', str(type)),
+                'title': ('', title),
+                'content': ('', content)
+            }
+        )
+
+        if response.status_code != 200:
+            raise Exception(f'Failed to connect to aippt: {response.text}')
+        
+        response = response.json()
+        if response.get('code') != 0:
+            raise Exception(f'Failed to create task: {response.get("msg")}')
+
+        return response.get('data', {}).get('id')
+    
+    def _generate_outline(self, task_id: str, model: str, user_id: str) -> str:
+        api_url = self._api_base_url / 'ai' / 'chat' / 'outline' if model == 'aippt' else \
+            self._api_base_url / 'ai' / 'chat' / 'wx' / 'outline'
+        api_url %= {'task_id': task_id}
+
+        headers = {
+            'x-channel': '',
+            'x-api-key': self.runtime.credentials['aippt_access_key'],
+            'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
+        }
+
+        response = requests_get(
+            url=api_url,
+            headers=headers,
+            stream=True,
+            timeout=(10, 60)
+        )
+
+        if response.status_code != 200:
+            raise Exception(f'Failed to connect to aippt: {response.text}')
+        
+        outline = ''
+        for chunk in response.iter_lines(delimiter=b'\n\n'):
+            if not chunk:
+                continue
+            
+            event = ''
+            lines = chunk.decode('utf-8').split('\n')
+            for line in lines:
+                if line.startswith('event:'):
+                    event = line[6:]
+                elif line.startswith('data:'):
+                    data = line[5:]
+                    if event == 'message':
+                        try:
+                            data = json_loads(data)
+                            outline += data.get('content', '')
+                        except Exception as e:
+                            pass
+                    elif event == 'close':
+                        break
+                    elif event == 'error' or event == 'filter':
+                        raise Exception(f'Failed to generate outline: {data}')
+                    
+        return outline
+    
+    def _generate_content(self, task_id: str, model: str, user_id: str) -> str:
+        api_url = self._api_base_url / 'ai' / 'chat' / 'content' if model == 'aippt' else \
+            self._api_base_url / 'ai' / 'chat' / 'wx' / 'content'
+        api_url %= {'task_id': task_id}
+
+        headers = {
+            'x-channel': '',
+            'x-api-key': self.runtime.credentials['aippt_access_key'],
+            'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
+        }
+
+        response = requests_get(
+            url=api_url,
+            headers=headers,
+            stream=True,
+            timeout=(10, 60)
+        )
+
+        if response.status_code != 200:
+            raise Exception(f'Failed to connect to aippt: {response.text}')
+        
+        if model == 'aippt':
+            content = ''
+            for chunk in response.iter_lines(delimiter=b'\n\n'):
+                if not chunk:
+                    continue
+                
+                event = ''
+                lines = chunk.decode('utf-8').split('\n')
+                for line in lines:
+                    if line.startswith('event:'):
+                        event = line[6:]
+                    elif line.startswith('data:'):
+                        data = line[5:]
+                        if event == 'message':
+                            try:
+                                data = json_loads(data)
+                                content += data.get('content', '')
+                            except Exception as e:
+                                pass
+                        elif event == 'close':
+                            break
+                        elif event == 'error' or event == 'filter':
+                            raise Exception(f'Failed to generate content: {data}')
+                        
+            return content
+        elif model == 'wenxin':
+            response = response.json()
+            if response.get('code') != 0:
+                raise Exception(f'Failed to generate content: {response.get("msg")}')
+            
+            return response.get('data', '')
+        
+        return ''
+
+    def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]:
+        """
+        Generate a ppt
+
+        :param task_id: the task ID
+        :param suit_id: the suit ID
+        :return: the cover url of the ppt and the ppt url
+        """
+        headers = {
+            'x-channel': '',
+            'x-api-key': self.runtime.credentials['aippt_access_key'],
+            'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
+        }
+
+        response = post(
+            str(self._api_base_url / 'design' / 'v2' / 'save'),
+            headers=headers,
+            data={
+                'task_id': task_id,
+                'template_id': suit_id
+            }
+        )
+
+        if response.status_code != 200:
+            raise Exception(f'Failed to connect to aippt: {response.text}')
+        
+        response = response.json()
+        if response.get('code') != 0:
+            raise Exception(f'Failed to generate ppt: {response.get("msg")}')
+        
+        id = response.get('data', {}).get('id')
+        cover_url = response.get('data', {}).get('cover_url')
+
+        response = post(
+            str(self._api_base_url / 'download' / 'export' / 'file'),
+            headers=headers,
+            data={
+                'id': id,
+                'format': 'ppt',
+                'files_to_zip': False,
+                'edit': True
+            }
+        )
+
+        if response.status_code != 200:
+            raise Exception(f'Failed to connect to aippt: {response.text}')
+        
+        response = response.json()
+        if response.get('code') != 0:
+            raise Exception(f'Failed to generate ppt: {response.get("msg")}')
+        
+        export_code = response.get('data')
+        if not export_code:
+            raise Exception('Failed to generate ppt, the export code is empty')
+        
+        current_iteration = 0
+        while current_iteration < 50:
+            # get ppt url
+            response = post(
+                str(self._api_base_url / 'download' / 'export' / 'file' / 'result'),
+                headers=headers,
+                data={
+                    'task_key': export_code
+                }
+            )
+
+            if response.status_code != 200:
+                raise Exception(f'Failed to connect to aippt: {response.text}')
+            
+            response = response.json()
+            if response.get('code') != 0:
+                raise Exception(f'Failed to generate ppt: {response.get("msg")}')
+            
+            if response.get('msg') == '导出中':
+                current_iteration += 1
+                sleep(2)
+                continue
+            
+            ppt_url = response.get('data', [])
+            if len(ppt_url) == 0:
+                raise Exception('Failed to generate ppt, the ppt url is empty')
+            
+            return cover_url, ppt_url[0]
+        
+        raise Exception('Failed to generate ppt, the export is timeout')
+        
+    @classmethod
+    def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str:
+        """
+        Get API token
+
+        :param credentials: the credentials
+        :return: the API token
+        """
+        access_key = credentials['aippt_access_key']
+        secret_key = credentials['aippt_secret_key']
+
+        cache_key = f'{access_key}#@#{user_id}'
+
+        with cls._api_token_cache_lock:
+            # clear expired tokens
+            now = time()
+            for key in list(cls._api_token_cache.keys()):
+                if cls._api_token_cache[key]['expire'] < now:
+                    del cls._api_token_cache[key]
+
+            if cache_key in cls._api_token_cache:
+                return cls._api_token_cache[cache_key]['token']
+            
+        # get token
+        headers = {
+            'x-api-key': access_key,
+            'x-timestamp': str(int(now)),
+            'x-signature': cls._calculate_sign(access_key, secret_key, int(now))
+        }
+
+        param = {
+            'uid': user_id,
+            'channel': ''
+        }
+
+        response = get(
+            str(cls._api_base_url / 'grant' / 'token'),
+            params=param,
+            headers=headers
+        )
+
+        if response.status_code != 200:
+            raise Exception(f'Failed to connect to aippt: {response.text}')
+        response = response.json()
+        if response.get('code') != 0:
+            raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
+        
+        token = response.get('data', {}).get('token')
+        expire = response.get('data', {}).get('time_expire')
+
+        with cls._api_token_cache_lock:
+            cls._api_token_cache[cache_key] = {
+                'token': token,
+                'expire': now + expire
+            }
+
+        return token
+
+    @classmethod
+    def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str:
+        return b64encode(
+            hmac_new(
+                key=secret_key.encode('utf-8'), 
+                msg=f'GET@/api/grant/token/@{timestamp}'.encode(),
+                digestmod=sha1
+            ).digest()
+        ).decode('utf-8')
+
+    def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
+        """
+        Get styles
+
+        :param credentials: the credentials
+        :return: Tuple[list[dict[id, color]], list[dict[id, style]]
+        """
+        headers = {
+            'x-channel': '',
+            'x-api-key': self.runtime.credentials['aippt_access_key'],
+            'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id)
+        }
+        response = get(
+            str(self._api_base_url / 'template_component' / 'suit' / 'select'),
+            headers=headers
+        )
+
+        if response.status_code != 200:
+            raise Exception(f'Failed to connect to aippt: {response.text}')
+        
+        response = response.json()
+
+        if response.get('code') != 0:
+            raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
+        
+        colors = [{
+            'id': f'id-{item.get("id")}',
+            'name': item.get('name'),
+            'en_name': item.get('en_name', item.get('name')),
+        } for item in response.get('data', {}).get('colour') or []]
+        styles = [{
+            'id': f'id-{item.get("id")}',
+            'name': item.get('title'),
+        } for item in response.get('data', {}).get('suit_style') or []]
+
+        return colors, styles
+    
+    def _get_suit(self, style_id: int, colour_id: int) -> int:
+        """
+        Get suit
+        """
+        headers = {
+            'x-channel': '',
+            'x-api-key': self.runtime.credentials['aippt_access_key'],
+            'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id='__dify_system__')
+        }
+        response = get(
+            str(self._api_base_url / 'template_component' / 'suit' / 'search'),
+            headers=headers,
+            params={
+                'style_id': style_id,
+                'colour_id': colour_id,
+                'page': 1,
+                'page_size': 1
+            }
+        )
+
+        if response.status_code != 200:
+            raise Exception(f'Failed to connect to aippt: {response.text}')
+        
+        response = response.json()
+
+        if response.get('code') != 0:
+            raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
+        
+        if len(response.get('data', {}).get('list') or []) > 0:
+            return response.get('data', {}).get('list')[0].get('id')
+        
+        raise Exception('Failed to get suit, the suit does not exist, please check the style and color')
+    
+    def get_runtime_parameters(self) -> list[ToolParameter]:
+        """
+        Get runtime parameters
+
+        Override this method to add runtime parameters to the tool.
+        """
+        try:
+            colors, styles = self.get_styles(user_id='__dify_system__')
+        except Exception as e:
+            colors, styles = [
+                {'id': -1, 'name': '__default__'}
+            ], [
+                {'id': -1, 'name': '__default__'}
+            ]
+
+        return [
+            ToolParameter(
+                name='color',
+                label=I18nObject(zh_Hans='颜色', en_US='Color'),
+                human_description=I18nObject(zh_Hans='颜色', en_US='Color'),
+                type=ToolParameter.ToolParameterType.SELECT,
+                form=ToolParameter.ToolParameterForm.FORM,
+                required=False,
+                default=colors[0]['id'],
+                options=[
+                    ToolParameterOption(
+                        value=color['id'],
+                        label=I18nObject(zh_Hans=color['name'], en_US=color['en_name'])
+                    ) for color in colors
+                ]
+            ),
+            ToolParameter(
+                name='style',
+                label=I18nObject(zh_Hans='风格', en_US='Style'),
+                human_description=I18nObject(zh_Hans='风格', en_US='Style'),
+                type=ToolParameter.ToolParameterType.SELECT,
+                form=ToolParameter.ToolParameterForm.FORM,
+                required=False,
+                default=styles[0]['id'],
+                options=[
+                    ToolParameterOption(
+                        value=style['id'],
+                        label=I18nObject(zh_Hans=style['name'], en_US=style['name'])
+                    ) for style in styles
+                ]
+            ),
+        ]

+ 54 - 0
api/core/tools/provider/builtin/aippt/tools/aippt.yaml

@@ -0,0 +1,54 @@
+identity:
+  name: aippt
+  author: Dify
+  label:
+    en_US: AIPPT
+    zh_Hans: AIPPT
+description:
+  human:
+    en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop
+    zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底
+  llm: A tool used to generate PPT with AI, input your content topic, and let AI generate PPT for you.
+parameters:
+  - name: title
+    type: string
+    required: true
+    label:
+      en_US: Title
+      zh_Hans: 标题
+    human_description:
+      en_US: The title of the PPT.
+      zh_Hans: PPT的标题。
+    llm_description: The title of the PPT, which will be used to generate the PPT outline.
+    form: llm
+  - name: outline
+    type: string
+    required: false
+    label:
+      en_US: Outline
+      zh_Hans: 大纲
+    human_description:
+      en_US: The outline of the PPT
+      zh_Hans: PPT的大纲
+    llm_description: The outline of the PPT, which will be used to generate the PPT content. provide it if you have.
+    form: llm
+  - name: llm
+    type: select
+    required: true
+    label:
+      en_US: LLM model
+      zh_Hans: 生成大纲的LLM
+    options:
+      - value: aippt
+        label:
+          en_US: AIPPT default model
+          zh_Hans: AIPPT默认模型
+      - value: wenxin
+        label:
+          en_US: Wenxin ErnieBot
+          zh_Hans: 文心一言
+    default: aippt
+    human_description:
+      en_US: The LLM model used for generating PPT outline.
+      zh_Hans: 用于生成PPT大纲的LLM模型。
+    form: form

+ 60 - 6
api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py

@@ -2,11 +2,11 @@ import io
 import json
 import json
 from base64 import b64decode, b64encode
 from base64 import b64decode, b64encode
 from copy import deepcopy
 from copy import deepcopy
-from os.path import join
 from typing import Any, Union
 from typing import Any, Union
 
 
 from httpx import get, post
 from httpx import get, post
 from PIL import Image
 from PIL import Image
+from yarl import URL
 
 
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
@@ -79,7 +79,7 @@ class StableDiffusionTool(BuiltinTool):
         
         
         # set model
         # set model
         try:
         try:
-            url = join(base_url, 'sdapi/v1/options')
+            url = str(URL(base_url) / 'sdapi' / 'v1' / 'options')
             response = post(url, data=json.dumps({
             response = post(url, data=json.dumps({
                 'sd_model_checkpoint': model
                 'sd_model_checkpoint': model
             }))
             }))
@@ -153,8 +153,21 @@ class StableDiffusionTool(BuiltinTool):
             if not model:
             if not model:
                 raise ToolProviderCredentialValidationError('Please input model')
                 raise ToolProviderCredentialValidationError('Please input model')
 
 
-            response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=120)
-            if response.status_code != 200:
+            api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
+            response = get(url=api_url, timeout=10)
+            if response.status_code == 404:
+                # try draw a picture
+                self._invoke(
+                    user_id='test',
+                    tool_parameters={
+                        'prompt': 'a cat',
+                        'width': 1024,
+                        'height': 1024,
+                        'steps': 1,
+                        'lora': '',
+                    }
+                )
+            elif response.status_code != 200:
                 raise ToolProviderCredentialValidationError('Failed to get models')
                 raise ToolProviderCredentialValidationError('Failed to get models')
             else:
             else:
                 models = [d['model_name'] for d in response.json()]
                 models = [d['model_name'] for d in response.json()]
@@ -165,6 +178,23 @@ class StableDiffusionTool(BuiltinTool):
         except Exception as e:
         except Exception as e:
             raise ToolProviderCredentialValidationError(f'Failed to get models, {e}')
             raise ToolProviderCredentialValidationError(f'Failed to get models, {e}')
 
 
+    def get_sd_models(self) -> list[str]:
+        """
+            get sd models
+        """
+        try:
+            base_url = self.runtime.credentials.get('base_url', None)
+            if not base_url:
+                return []
+            api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
+            response = get(url=api_url, timeout=10)
+            if response.status_code != 200:
+                return []
+            else:
+                return [d['model_name'] for d in response.json()]
+        except Exception as e:
+            return []
+
     def img2img(self, base_url: str, lora: str, image_binary: bytes, 
     def img2img(self, base_url: str, lora: str, image_binary: bytes, 
                 prompt: str, negative_prompt: str,
                 prompt: str, negative_prompt: str,
                 width: int, height: int, steps: int) \
                 width: int, height: int, steps: int) \
@@ -192,7 +222,7 @@ class StableDiffusionTool(BuiltinTool):
             draw_options['prompt'] = prompt
             draw_options['prompt'] = prompt
 
 
         try:
         try:
-            url = join(base_url, 'sdapi/v1/img2img')
+            url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img')
             response = post(url, data=json.dumps(draw_options), timeout=120)
             response = post(url, data=json.dumps(draw_options), timeout=120)
             if response.status_code != 200:
             if response.status_code != 200:
                 return self.create_text_message('Failed to generate image')
                 return self.create_text_message('Failed to generate image')
@@ -225,7 +255,7 @@ class StableDiffusionTool(BuiltinTool):
         draw_options['negative_prompt'] = negative_prompt
         draw_options['negative_prompt'] = negative_prompt
         
         
         try:
         try:
-            url = join(base_url, 'sdapi/v1/txt2img')
+            url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img')
             response = post(url, data=json.dumps(draw_options), timeout=120)
             response = post(url, data=json.dumps(draw_options), timeout=120)
             if response.status_code != 200:
             if response.status_code != 200:
                 return self.create_text_message('Failed to generate image')
                 return self.create_text_message('Failed to generate image')
@@ -269,5 +299,29 @@ class StableDiffusionTool(BuiltinTool):
                                  label=I18nObject(en_US=i.name, zh_Hans=i.name)
                                  label=I18nObject(en_US=i.name, zh_Hans=i.name)
                              ) for i in self.list_default_image_variables()])
                              ) for i in self.list_default_image_variables()])
             )
             )
+        
+        if self.runtime.credentials:
+            try:
+                models = self.get_sd_models()
+                if len(models) != 0:
+                    parameters.append(
+                        ToolParameter(name='model',
+                                     label=I18nObject(en_US='Model', zh_Hans='Model'),
+                                     human_description=I18nObject(
+                                        en_US='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion',
+                                        zh_Hans='Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档',
+                                     ),
+                                     type=ToolParameter.ToolParameterType.SELECT,
+                                     form=ToolParameter.ToolParameterForm.FORM,
+                                     llm_description='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion',
+                                     required=True,
+                                     default=models[0],
+                                     options=[ToolParameterOption(
+                                         value=i,
+                                         label=I18nObject(en_US=i, zh_Hans=i)
+                                     ) for i in models])
+                    )
+            except:
+                pass
 
 
         return parameters
         return parameters

+ 43 - 5
api/services/tools_manage_service.py

@@ -9,6 +9,7 @@ from core.tools.entities.tool_entities import (
     ApiProviderAuthType,
     ApiProviderAuthType,
     ApiProviderSchemaType,
     ApiProviderSchemaType,
     ToolCredentialsOption,
     ToolCredentialsOption,
+    ToolParameter,
     ToolProviderCredentials,
     ToolProviderCredentials,
 )
 )
 from core.tools.entities.user_entities import UserTool, UserToolProvider
 from core.tools.entities.user_entities import UserTool, UserToolProvider
@@ -73,15 +74,52 @@ class ToolManageService:
         provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
         provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
         tools = provider_controller.get_tools()
         tools = provider_controller.get_tools()
 
 
-        result = [
-            UserTool(
+        tool_provider_configurations = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
+        # check if user has added the provider
+        builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
+            BuiltinToolProvider.tenant_id == tenant_id,
+            BuiltinToolProvider.provider == provider,
+        ).first()
+
+        credentials = {}
+        if builtin_provider is not None:
+            # get credentials
+            credentials = builtin_provider.credentials
+            credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
+
+        result = []
+        for tool in tools:
+            # fork tool runtime
+            tool = tool.fork_tool_runtime(meta={
+                'credentials': credentials,
+                'tenant_id': tenant_id,
+            })
+
+            # get tool parameters
+            parameters = tool.parameters or []
+            # get tool runtime parameters
+            runtime_parameters = tool.get_runtime_parameters()
+            # override parameters
+            current_parameters = parameters.copy()
+            for runtime_parameter in runtime_parameters:
+                found = False
+                for index, parameter in enumerate(current_parameters):
+                    if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
+                        current_parameters[index] = runtime_parameter
+                        found = True
+                        break
+
+                if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
+                    current_parameters.append(runtime_parameter)
+
+            user_tool = UserTool(
                 author=tool.identity.author,
                 author=tool.identity.author,
                 name=tool.identity.name,
                 name=tool.identity.name,
                 label=tool.identity.label,
                 label=tool.identity.label,
                 description=tool.description.human,
                 description=tool.description.human,
-                parameters=tool.parameters or []
-            ) for tool in tools
-        ]
+                parameters=current_parameters
+            )
+            result.append(user_tool)
 
 
         return json.loads(
         return json.loads(
             serialize_base_model_array(result)
             serialize_base_model_array(result)