Parcourir la source

Feat/add azure dalle tool (#2276)

Co-authored-by: lux@njuelectronics.com <lux@njuelectronics.com>
Co-authored-by: crazywoola <427733928@qq.com>
呆萌闷油瓶 il y a 1 an
Parent
commit
c97b7f6748

+ 0 - 0
api/core/tools/provider/builtin/azuredalle/__init__.py


BIN
api/core/tools/provider/builtin/azuredalle/_assets/icon.png


+ 23 - 0
api/core/tools/provider/builtin/azuredalle/azuredalle.py

@@ -0,0 +1,23 @@
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+from core.tools.provider.builtin.azuredalle.tools.dalle3 import DallE3Tool
+from core.tools.errors import ToolProviderCredentialValidationError
+
+from typing import Any, Dict
+
+class AzureDALLEProvider(BuiltinToolProviderController):
+    def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
+        try:
+            DallE3Tool().fork_tool_runtime(
+                meta={
+                    "credentials": credentials,
+                }
+            ).invoke(
+                user_id='',
+                tool_paramters={
+                    "prompt": "cute girl, blue eyes, white hair, anime style",
+                    "size": "square",
+                    "n": 1
+                },
+            )
+        except Exception as e:
+            raise ToolProviderCredentialValidationError(str(e))

+ 73 - 0
api/core/tools/provider/builtin/azuredalle/azuredalle.yaml

@@ -0,0 +1,73 @@
+identity:
+  author: Leslie
+  name: azuredalle
+  label:
+    en_US: AZURE DALL-E
+    zh_Hans: AZURE DALL-E 绘画
+    pt_BR: AZURE DALL-E
+  description:
+    en_US: AZURE DALL-E art
+    zh_Hans: AZURE DALL-E 绘画
+    pt_BR: AZURE DALL-E art
+  icon: icon.png
+credentials_for_provider:
+  azure_openai_api_key:
+    type: secret-input
+    required: true
+    label:
+      en_US: API key
+      zh_Hans: 密钥
+      pt_BR: API key
+    help:
+      en_US: Please input your Azure OpenAI API key
+      zh_Hans: 请输入你的 Azure OpenAI API key
+      pt_BR: Please input your Azure OpenAI API key
+    placeholder:
+      en_US: Please input your Azure OpenAI API key
+      zh_Hans: 请输入你的 Azure OpenAI API key
+      pt_BR: Please input your Azure OpenAI API key
+  azure_openai_api_model_name:
+    type: text-input
+    required: true
+    label:
+      en_US: Deployment Name
+      zh_Hans: 部署名称
+      pt_BR: Deployment Name
+    help:
+      en_US: Please input the name of your Azure OpenAI DALL-E API deployment
+      zh_Hans: 请输入你的 Azure OpenAI DALL-E API 部署名称
+      pt_BR: Please input the name of your Azure OpenAI DALL-E API deployment
+    placeholder:
+      en_US: Please input the name of your Azure OpenAI DALL-E API deployment
+      zh_Hans: 请输入你的 Azure OpenAI DALL-E API 部署名称
+      pt_BR: Please input the name of your Azure OpenAI DALL-E API deployment
+  azure_openai_base_url:
+    type: text-input
+    required: true
+    label:
+      en_US: API Endpoint URL
+      zh_Hans: API 域名
+      pt_BR: API Endpoint URL
+    help:
+      en_US: Please input your Azure OpenAI Endpoint URL, e.g. https://xxx.openai.azure.com/
+      zh_Hans: 请输入你的 Azure OpenAI API域名,例如:https://xxx.openai.azure.com/
+      pt_BR: Please input your Azure OpenAI Endpoint URL, e.g. https://xxx.openai.azure.com/
+    placeholder:
+      en_US: Please input your Azure OpenAI Endpoint URL, e.g. https://xxx.openai.azure.com/
+      zh_Hans: 请输入你的 Azure OpenAI API域名,例如:https://xxx.openai.azure.com/
+      pt_BR: Please input your Azure OpenAI Endpoint URL, e.g. https://xxx.openai.azure.com/
+  azure_openai_api_version:
+    type: text-input
+    required: true
+    label:
+      en_US: API Version
+      zh_Hans: API 版本
+      pt_BR: API Version
+    help:
+      en_US: Please input your Azure OpenAI API Version, e.g. 2023-12-01-preview
+      zh_Hans: 请输入你的 Azure OpenAI API 版本,例如:2023-12-01-preview
+      pt_BR: Please input your Azure OpenAI API Version, e.g. 2023-12-01-preview
+    placeholder:
+      en_US: Please input your Azure OpenAI API Version, e.g. 2023-12-01-preview
+      zh_Hans: 请输入你的 Azure OpenAI API 版本,例如:2023-12-01-preview
+      pt_BR: Please input your Azure OpenAI API Version, e.g. 2023-12-01-preview

+ 66 - 0
api/core/tools/provider/builtin/azuredalle/tools/dalle3.py

@@ -0,0 +1,66 @@
+from typing import Any, Dict, List, Union
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+from base64 import b64decode
+from os.path import join
+
+from openai import AzureOpenAI
+
+class DallE3Tool(BuiltinTool):
+    def _invoke(self, 
+                user_id: str, 
+               tool_paramters: Dict[str, Any], 
+        ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
+        """
+            invoke tools
+        """
+        client = AzureOpenAI(
+            api_version=self.runtime.credentials['azure_openai_api_version'],
+            azure_endpoint=self.runtime.credentials['azure_openai_base_url'],
+            api_key=self.runtime.credentials['azure_openai_api_key'],
+        )
+
+        SIZE_MAPPING = {
+            'square': '1024x1024',
+            'vertical': '1024x1792',
+            'horizontal': '1792x1024',
+        }
+
+        # prompt
+        prompt = tool_paramters.get('prompt', '')
+        if not prompt:
+            return self.create_text_message('Please input prompt')
+        # get size
+        size = SIZE_MAPPING[tool_paramters.get('size', 'square')]
+        # get n
+        n = tool_paramters.get('n', 1)
+        # get quality
+        quality = tool_paramters.get('quality', 'standard')
+        if quality not in ['standard', 'hd']:
+            return self.create_text_message('Invalid quality')
+        # get style
+        style = tool_paramters.get('style', 'vivid')
+        if style not in ['natural', 'vivid']:
+            return self.create_text_message('Invalid style')
+
+        # call openapi dalle3
+        model=self.runtime.credentials['azure_openai_api_model_name']
+        response = client.images.generate(
+            prompt=prompt,
+            model=model,
+            size=size,
+            n=n,
+            style=style,
+            quality=quality,
+            response_format='b64_json'
+        )
+
+        result = []
+
+        for image in response.data:
+            result.append(self.create_blob_message(blob=b64decode(image.b64_json), 
+                                                   meta={ 'mime_type': 'image/png' },
+                                                    save_as=self.VARIABLE_KEY.IMAGE.value))
+
+        return result

+ 123 - 0
api/core/tools/provider/builtin/azuredalle/tools/dalle3.yaml

@@ -0,0 +1,123 @@
+identity:
+  name: dalle3
+  author: Leslie
+  label:
+    en_US: DALL-E 3
+    zh_Hans: DALL-E 3 绘画
+    pt_BR: DALL-E 3
+  description:
+    en_US: DALL-E 3 is a powerful drawing tool that can draw the image you want based on your prompt, compared to DallE 2, DallE 3 has stronger drawing ability, but it will consume more resources
+    zh_Hans: DALL-E 3 是一个强大的绘画工具,它可以根据您的提示词绘制出您想要的图像,相比于DallE 2, DallE 3拥有更强的绘画能力,但会消耗更多的资源
+    pt_BR: DALL-E 3 is a powerful drawing tool that can draw the image you want based on your prompt, compared to DallE 2, DallE 3 has stronger drawing ability, but it will consume more resources
+description:
+  human:
+    en_US: DALL-E is a text to image tool
+    zh_Hans: DALL-E 是一个文本到图像的工具
+    pt_BR: DALL-E is a text to image tool
+  llm: DALL-E is a tool used to generate images from text
+parameters:
+  - name: prompt
+    type: string
+    required: true
+    label:
+      en_US: Prompt
+      zh_Hans: 提示词
+      pt_BR: Prompt
+    human_description:
+      en_US: Image prompt, you can check the official documentation of DallE 3
+      zh_Hans: 图像提示词,您可以查看DallE 3 的官方文档
+      pt_BR: Image prompt, you can check the official documentation of DallE 3
+    llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words, as detailed as possible
+    form: llm
+  - name: size
+    type: select
+    required: true
+    human_description:
+      en_US: selecting the image size
+      zh_Hans: 选择图像大小
+      pt_BR: selecting the image size
+    label:
+      en_US: Image size
+      zh_Hans: 图像大小
+      pt_BR: Image size
+    form: form
+    options:
+      - value: square
+        label:
+          en_US: Square(1024x1024)
+          zh_Hans: 方(1024x1024)
+          pt_BR: Square(1024x1024)
+      - value: vertical
+        label:
+          en_US: Vertical(1024x1792)
+          zh_Hans: 竖屏(1024x1792)
+          pt_BR: Vertical(1024x1792)
+      - value: horizontal
+        label:
+          en_US: Horizontal(1792x1024)
+          zh_Hans: 横屏(1792x1024)
+          pt_BR: Horizontal(1792x1024)
+    default: square
+  - name: n
+    type: number
+    required: true
+    human_description:
+      en_US: selecting the number of images
+      zh_Hans: 选择图像数量
+      pt_BR: selecting the number of images
+    label:
+      en_US: Number of images
+      zh_Hans: 图像数量
+      pt_BR: Number of images
+    form: form
+    min: 1
+    max: 1
+    default: 1
+  - name: quality
+    type: select
+    required: true
+    human_description:
+      en_US: selecting the image quality
+      zh_Hans: 选择图像质量
+      pt_BR: selecting the image quality
+    label:
+      en_US: Image quality
+      zh_Hans: 图像质量
+      pt_BR: Image quality
+    form: form
+    options:
+      - value: standard
+        label:
+          en_US: Standard
+          zh_Hans: 标准
+          pt_BR: Standard
+      - value: hd
+        label:
+          en_US: HD
+          zh_Hans: 高清
+          pt_BR: HD
+    default: standard
+  - name: style
+    type: select
+    required: true
+    human_description:
+      en_US: selecting the image style
+      zh_Hans: 选择图像风格
+      pt_BR: selecting the image style
+    label:
+      en_US: Image style
+      zh_Hans: 图像风格
+      pt_BR: Image style
+    form: form
+    options:
+      - value: vivid
+        label:
+          en_US: Vivid
+          zh_Hans: 生动
+          pt_BR: Vivid
+      - value: natural
+        label:
+          en_US: Natural
+          zh_Hans: 自然
+          pt_BR: Natural
+    default: vivid