Browse Source

Feat/add zhipu CogView 3 tool (#6210)

Waffle 9 tháng trước cách đây
mục cha
commit
07add06c59

+ 7 - 2
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional
 import httpx
 
 from ..core._base_api import BaseAPI
-from ..core._base_type import NOT_GIVEN, Headers, NotGiven
+from ..core._base_type import NOT_GIVEN, Body, Headers, NotGiven
 from ..core._http_client import make_user_request_input
 from ..types.image import ImagesResponded
 
@@ -28,7 +28,9 @@ class Images(BaseAPI):
             size: Optional[str] | NotGiven = NOT_GIVEN,
             style: Optional[str] | NotGiven = NOT_GIVEN,
             user: str | NotGiven = NOT_GIVEN,
+            request_id: Optional[str] | NotGiven = NOT_GIVEN,
             extra_headers: Headers | None = None,
+            extra_body: Body | None = None,
             disable_strict_validation: Optional[bool] | None = None,
             timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
     ) -> ImagesResponded:
@@ -46,9 +48,12 @@ class Images(BaseAPI):
                 "size": size,
                 "style": style,
                 "user": user,
+                "request_id": request_id,
             },
             options=make_user_request_input(
-                extra_headers=extra_headers, timeout=timeout
+                extra_headers=extra_headers,
+                extra_body=extra_body,
+                timeout=timeout
             ),
             cast_type=_cast_type,
             enable_stream=False,

+ 4 - 1
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py

@@ -11,7 +11,7 @@ from tenacity import retry
 from tenacity.stop import stop_after_attempt
 
 from . import _errors
-from ._base_type import NOT_GIVEN, Body, Data, Headers, NotGiven, Query, RequestFiles, ResponseT
+from ._base_type import NOT_GIVEN, AnyMapping, Body, Data, Headers, NotGiven, Query, RequestFiles, ResponseT
 from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError
 from ._files import make_httpx_files
 from ._request_opt import ClientRequestParam, UserRequestInput
@@ -358,6 +358,7 @@ def make_user_request_input(
         max_retries: int | None = None,
         timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
         extra_headers: Headers = None,
+        extra_body: Body | None = None,
         query: Query | None = None,
 ) -> UserRequestInput:
     options: UserRequestInput = {}
@@ -370,5 +371,7 @@ def make_user_request_input(
         options['timeout'] = timeout
     if query is not None:
         options["params"] = query
+    if extra_body is not None:
+        options["extra_json"] = cast(AnyMapping, extra_body)
 
     return options

+ 0 - 0
api/core/tools/provider/builtin/cogview/__init__.py


BIN
api/core/tools/provider/builtin/cogview/_assets/icon.png


+ 27 - 0
api/core/tools/provider/builtin/cogview/cogview.py

@@ -0,0 +1,27 @@
+""" Provide the input parameters type for the cogview provider class """
+from typing import Any
+
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.provider.builtin.cogview.tools.cogview3 import CogView3Tool
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+class COGVIEWProvider(BuiltinToolProviderController):
+    """ cogview provider """
+    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
+        try:
+            CogView3Tool().fork_tool_runtime(
+                runtime={
+                    "credentials": credentials,
+                }
+            ).invoke(
+                user_id='',
+                tool_parameters={
+                    "prompt": "一个城市在水晶瓶中欢快生活的场景,水彩画风格,展现出微观与珠宝般的美丽。",
+                    "size": "square",
+                    "n": 1
+                },
+            )
+        except Exception as e:
+            raise ToolProviderCredentialValidationError(str(e)) from e
+        

+ 61 - 0
api/core/tools/provider/builtin/cogview/cogview.yaml

@@ -0,0 +1,61 @@
+identity:
+  author: Waffle
+  name: cogview
+  label:
+    en_US: CogView
+    zh_Hans: CogView 绘画
+    pt_BR: CogView
+  description:
+    en_US: CogView art
+    zh_Hans: CogView 绘画
+    pt_BR: CogView art
+  icon: icon.png
+  tags:
+    - image
+    - productivity
+credentials_for_provider:
+  zhipuai_api_key:
+    type: secret-input
+    required: true
+    label:
+      en_US: ZhipuAI API key
+      zh_Hans: ZhipuAI API key
+      pt_BR: ZhipuAI API key
+    help:
+      en_US: Please input your ZhipuAI API key
+      zh_Hans: 请输入你的 ZhipuAI API key
+      pt_BR: Please input your ZhipuAI API key
+    placeholder:
+      en_US: Please input your ZhipuAI API key
+      zh_Hans: 请输入你的 ZhipuAI API key
+      pt_BR: Please input your ZhipuAI API key
+  zhipuai_organizaion_id:
+    type: text-input
+    required: false
+    label:
+      en_US: ZhipuAI organization ID
+      zh_Hans: ZhipuAI organization ID
+      pt_BR: ZhipuAI organization ID
+    help:
+      en_US: Please input your ZhipuAI organization ID
+      zh_Hans: 请输入你的 ZhipuAI organization ID
+      pt_BR: Please input your ZhipuAI organization ID
+    placeholder:
+      en_US: Please input your ZhipuAI organization ID
+      zh_Hans: 请输入你的 ZhipuAI organization ID
+      pt_BR: Please input your ZhipuAI organization ID
+  zhipuai_base_url:
+    type: text-input
+    required: false
+    label:
+      en_US: ZhipuAI base URL
+      zh_Hans: ZhipuAI base URL
+      pt_BR: ZhipuAI base URL
+    help:
+      en_US: Please input your ZhipuAI base URL
+      zh_Hans: 请输入你的 ZhipuAI base URL
+      pt_BR: Please input your ZhipuAI base URL
+    placeholder:
+      en_US: Please input your ZhipuAI base URL
+      zh_Hans: 请输入你的 ZhipuAI base URL
+      pt_BR: Please input your ZhipuAI base URL

+ 69 - 0
api/core/tools/provider/builtin/cogview/tools/cogview3.py

@@ -0,0 +1,69 @@
+import random
+from typing import Any, Union
+
+from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class CogView3Tool(BuiltinTool):
+    """ CogView3 Tool """
+
+    def _invoke(self,
+                user_id: str,
+                tool_parameters: dict[str, Any]
+                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        """
+        Invoke CogView3 tool
+        """
+        client = ZhipuAI(
+            base_url=self.runtime.credentials['zhipuai_base_url'],
+            api_key=self.runtime.credentials['zhipuai_api_key'],
+        )
+        size_mapping = {
+            'square': '1024x1024',
+            'vertical': '1024x1792',
+            'horizontal': '1792x1024',
+        }
+        # prompt
+        prompt = tool_parameters.get('prompt', '')
+        if not prompt:
+            return self.create_text_message('Please input prompt')
+        # get size
+        print(tool_parameters.get('prompt', 'square'))
+        size = size_mapping[tool_parameters.get('size', 'square')]
+        # get n
+        n = tool_parameters.get('n', 1)
+        # get quality
+        quality = tool_parameters.get('quality', 'standard')
+        if quality not in ['standard', 'hd']:
+            return self.create_text_message('Invalid quality')
+        # get style
+        style = tool_parameters.get('style', 'vivid')
+        if style not in ['natural', 'vivid']:
+            return self.create_text_message('Invalid style')
+        # set extra body
+        seed_id = tool_parameters.get('seed_id', self._generate_random_id(8))
+        extra_body = {'seed': seed_id}
+        response = client.images.generations(
+            prompt=prompt,
+            model="cogview-3",
+            size=size,
+            n=n,
+            extra_body=extra_body,
+            style=style,
+            quality=quality,
+            response_format='b64_json'
+        )
+        result = []
+        for image in response.data:
+            result.append(self.create_image_message(image=image.url))
+        result.append(self.create_text_message(
+            f'\nGenerate image source to Seed ID: {seed_id}'))
+        return result
+
+    @staticmethod
+    def _generate_random_id(length=8):
+        characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
+        random_id = ''.join(random.choices(characters, k=length))
+        return random_id

+ 123 - 0
api/core/tools/provider/builtin/cogview/tools/cogview3.yaml

@@ -0,0 +1,123 @@
+identity:
+  name: cogview3
+  author: Waffle
+  label:
+    en_US: CogView 3
+    zh_Hans: CogView 3 绘画
+    pt_BR: CogView 3
+  description:
+    en_US: CogView 3 is a powerful drawing tool that can draw the image you want based on your prompt
+    zh_Hans: CogView 3 是一个强大的绘画工具,它可以根据您的提示词绘制出您想要的图像
+    pt_BR: CogView 3 is a powerful drawing tool that can draw the image you want based on your prompt
+description:
+  human:
+    en_US: CogView 3 is a text to image tool
+    zh_Hans: CogView 3 是一个文本到图像的工具
+    pt_BR: CogView 3 is a text to image tool
+  llm: CogView 3 is a tool used to generate images from text
+parameters:
+  - name: prompt
+    type: string
+    required: true
+    label:
+      en_US: Prompt
+      zh_Hans: 提示词
+      pt_BR: Prompt
+    human_description:
+      en_US: Image prompt, you can check the official documentation of CogView 3
+      zh_Hans: 图像提示词,您可以查看CogView 3 的官方文档
+      pt_BR: Image prompt, you can check the official documentation of CogView 3
+    llm_description: Image prompt of CogView 3, you should describe the image you want to generate as a list of words as possible as detailed
+    form: llm
+  - name: size
+    type: select
+    required: true
+    human_description:
+      en_US: selecting the image size
+      zh_Hans: 选择图像大小
+      pt_BR: selecting the image size
+    label:
+      en_US: Image size
+      zh_Hans: 图像大小
+      pt_BR: Image size
+    form: form
+    options:
+      - value: square
+        label:
+          en_US: Squre(1024x1024)
+          zh_Hans: 方(1024x1024)
+          pt_BR: Squre(1024x1024)
+      - value: vertical
+        label:
+          en_US: Vertical(1024x1792)
+          zh_Hans: 竖屏(1024x1792)
+          pt_BR: Vertical(1024x1792)
+      - value: horizontal
+        label:
+          en_US: Horizontal(1792x1024)
+          zh_Hans: 横屏(1792x1024)
+          pt_BR: Horizontal(1792x1024)
+    default: square
+  - name: n
+    type: number
+    required: true
+    human_description:
+      en_US: selecting the number of images
+      zh_Hans: 选择图像数量
+      pt_BR: selecting the number of images
+    label:
+      en_US: Number of images
+      zh_Hans: 图像数量
+      pt_BR: Number of images
+    form: form
+    min: 1
+    max: 1
+    default: 1
+  - name: quality
+    type: select
+    required: true
+    human_description:
+      en_US: selecting the image quality
+      zh_Hans: 选择图像质量
+      pt_BR: selecting the image quality
+    label:
+      en_US: Image quality
+      zh_Hans: 图像质量
+      pt_BR: Image quality
+    form: form
+    options:
+      - value: standard
+        label:
+          en_US: Standard
+          zh_Hans: 标准
+          pt_BR: Standard
+      - value: hd
+        label:
+          en_US: HD
+          zh_Hans: 高清
+          pt_BR: HD
+    default: standard
+  - name: style
+    type: select
+    required: true
+    human_description:
+      en_US: selecting the image style
+      zh_Hans: 选择图像风格
+      pt_BR: selecting the image style
+    label:
+      en_US: Image style
+      zh_Hans: 图像风格
+      pt_BR: Image style
+    form: form
+    options:
+      - value: vivid
+        label:
+          en_US: Vivid
+          zh_Hans: 生动
+          pt_BR: Vivid
+      - value: natural
+        label:
+          en_US: Natural
+          zh_Hans: 自然
+          pt_BR: Natural
+    default: vivid