|
@@ -1,3 +1,4 @@
|
|
|
+import random
|
|
|
from base64 import b64decode
|
|
|
from typing import Any, Union
|
|
|
|
|
@@ -9,10 +10,10 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
|
|
|
|
|
|
|
|
class DallE3Tool(BuiltinTool):
|
|
|
- def _invoke(self,
|
|
|
- user_id: str,
|
|
|
- tool_parameters: dict[str, Any],
|
|
|
- ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
|
|
+ def _invoke(self,
|
|
|
+ user_id: str,
|
|
|
+ tool_parameters: dict[str, Any],
|
|
|
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
|
|
"""
|
|
|
invoke tools
|
|
|
"""
|
|
@@ -53,6 +54,9 @@ class DallE3Tool(BuiltinTool):
|
|
|
style = tool_parameters.get('style', 'vivid')
|
|
|
if style not in ['natural', 'vivid']:
|
|
|
return self.create_text_message('Invalid style')
|
|
|
+ # set extra body
|
|
|
+ seed_id = tool_parameters.get('seed_id', self._generate_random_id(8))
|
|
|
+ extra_body = {'seed': seed_id}
|
|
|
|
|
|
# call openapi dalle3
|
|
|
response = client.images.generate(
|
|
@@ -60,6 +64,7 @@ class DallE3Tool(BuiltinTool):
|
|
|
model='dall-e-3',
|
|
|
size=size,
|
|
|
n=n,
|
|
|
+ extra_body=extra_body,
|
|
|
style=style,
|
|
|
quality=quality,
|
|
|
response_format='b64_json'
|
|
@@ -68,8 +73,15 @@ class DallE3Tool(BuiltinTool):
|
|
|
result = []
|
|
|
|
|
|
for image in response.data:
|
|
|
- result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
|
|
- meta={ 'mime_type': 'image/png' },
|
|
|
- save_as=self.VARIABLE_KEY.IMAGE.value))
|
|
|
+ result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
|
|
+ meta={'mime_type': 'image/png'},
|
|
|
+ save_as=self.VARIABLE_KEY.IMAGE.value))
|
|
|
+ result.append(self.create_text_message(f'\nGenerate image source to Seed ID: {seed_id}'))
|
|
|
|
|
|
return result
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _generate_random_id(length=8):
|
|
|
+ characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
|
|
|
+ random_id = ''.join(random.choices(characters, k=length))
|
|
|
+ return random_id
|