Prechádzať zdrojové kódy

dalle3 add style consistency parameter (#5067)

Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Charlie.Wei 10 mesiacov pred
rodič
commit
b7c72f7a97

+ 19 - 8
api/core/tools/provider/builtin/azuredalle/tools/dalle3.py

@@ -8,10 +8,10 @@ from core.tools.tool.builtin_tool import BuiltinTool
 
 
 class DallE3Tool(BuiltinTool):
-    def _invoke(self, 
-                user_id: str, 
-               tool_parameters: dict[str, Any], 
-        ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(self,
+                user_id: str,
+                tool_parameters: dict[str, Any],
+                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
             invoke tools
         """
@@ -43,14 +43,18 @@ class DallE3Tool(BuiltinTool):
         style = tool_parameters.get('style', 'vivid')
         if style not in ['natural', 'vivid']:
             return self.create_text_message('Invalid style')
+        # set extra body
+        seed_id = tool_parameters.get('seed_id', self._generate_random_id(8))
+        extra_body = {'seed': seed_id}
 
         # call openapi dalle3
-        model=self.runtime.credentials['azure_openai_api_model_name']
+        model = self.runtime.credentials['azure_openai_api_model_name']
         response = client.images.generate(
             prompt=prompt,
             model=model,
             size=size,
             n=n,
+            extra_body=extra_body,
             style=style,
             quality=quality,
             response_format='b64_json'
@@ -59,8 +63,15 @@ class DallE3Tool(BuiltinTool):
         result = []
 
         for image in response.data:
-            result.append(self.create_blob_message(blob=b64decode(image.b64_json), 
-                                                   meta={ 'mime_type': 'image/png' },
-                                                    save_as=self.VARIABLE_KEY.IMAGE.value))
+            result.append(self.create_blob_message(blob=b64decode(image.b64_json),
+                                                   meta={'mime_type': 'image/png'},
+                                                   save_as=self.VARIABLE_KEY.IMAGE.value))
+        result.append(self.create_text_message(f'\nGenerate image source to Seed ID: {seed_id}'))
 
         return result
+
+    @staticmethod
+    def _generate_random_id(length=8):
+        characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
+        random_id = ''.join(random.choices(characters, k=length))
+        return random_id

+ 13 - 0
api/core/tools/provider/builtin/azuredalle/tools/dalle3.yaml

@@ -29,6 +29,19 @@ parameters:
       pt_BR: Imagem prompt, você pode verificar a documentação oficial do DallE 3
     llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed
     form: llm
+  - name: seed_id
+    type: string
+    required: false
+    label:
+      en_US: Seed ID
+      zh_Hans: 种子ID
+      pt_BR: ID da semente
+    human_description:
+      en_US: Image generation seed ID to ensure consistency of series generated images
+      zh_Hans: 图像生成种子ID,确保系列生成图像的一致性
+      pt_BR: ID de semente de geração de imagem para garantir a consistência das imagens geradas em série
+    llm_description: If the user requests image consistency, extract the seed ID from the user's question or context. The seed ID is an 8-character string containing uppercase and lowercase letters and numbers.
+    form: llm
   - name: size
     type: select
     required: true

+ 19 - 7
api/core/tools/provider/builtin/dalle/tools/dalle3.py

@@ -1,3 +1,4 @@
+import random
 from base64 import b64decode
 from typing import Any, Union
 
@@ -9,10 +10,10 @@ from core.tools.tool.builtin_tool import BuiltinTool
 
 
 class DallE3Tool(BuiltinTool):
-    def _invoke(self, 
-                user_id: str, 
-               tool_parameters: dict[str, Any], 
-        ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(self,
+                user_id: str,
+                tool_parameters: dict[str, Any],
+                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
             invoke tools
         """
@@ -53,6 +54,9 @@ class DallE3Tool(BuiltinTool):
         style = tool_parameters.get('style', 'vivid')
         if style not in ['natural', 'vivid']:
             return self.create_text_message('Invalid style')
+        # set extra body
+        seed_id = tool_parameters.get('seed_id', self._generate_random_id(8))
+        extra_body = {'seed': seed_id}
 
         # call openapi dalle3
         response = client.images.generate(
@@ -60,6 +64,7 @@ class DallE3Tool(BuiltinTool):
             model='dall-e-3',
             size=size,
             n=n,
+            extra_body=extra_body,
             style=style,
             quality=quality,
             response_format='b64_json'
@@ -68,8 +73,15 @@ class DallE3Tool(BuiltinTool):
         result = []
 
         for image in response.data:
-            result.append(self.create_blob_message(blob=b64decode(image.b64_json), 
-                                                   meta={ 'mime_type': 'image/png' },
-                                                    save_as=self.VARIABLE_KEY.IMAGE.value))
+            result.append(self.create_blob_message(blob=b64decode(image.b64_json),
+                                                   meta={'mime_type': 'image/png'},
+                                                   save_as=self.VARIABLE_KEY.IMAGE.value))
+        result.append(self.create_text_message(f'\nGenerate image source to Seed ID: {seed_id}'))
 
         return result
+
+    @staticmethod
+    def _generate_random_id(length=8):
+        characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
+        random_id = ''.join(random.choices(characters, k=length))
+        return random_id

+ 13 - 0
api/core/tools/provider/builtin/dalle/tools/dalle3.yaml

@@ -29,6 +29,19 @@ parameters:
       pt_BR: Image prompt, you can check the official documentation of DallE 3
     llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed
     form: llm
+  - name: seed_id
+    type: string
+    required: false
+    label:
+      en_US: Seed ID
+      zh_Hans: 种子ID
+      pt_BR: ID da semente
+    human_description:
+      en_US: Image generation seed ID to ensure consistency of series generated images
+      zh_Hans: 图像生成种子ID,确保系列生成图像的一致性
+      pt_BR: ID de semente de geração de imagem para garantir a consistência das imagens geradas em série
+    llm_description: If the user requests image consistency, extract the seed ID from the user's question or context. The seed ID is an 8-character string containing uppercase and lowercase letters and numbers.
+    form: llm
   - name: size
     type: select
     required: true