|
@@ -0,0 +1,69 @@
|
|
|
+import random
|
|
|
+from typing import Any, Union
|
|
|
+
|
|
|
+from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
|
|
|
+from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
|
+from core.tools.tool.builtin_tool import BuiltinTool
|
|
|
+
|
|
|
+
|
|
|
+class CogView3Tool(BuiltinTool):
|
|
|
+ """ CogView3 Tool """
|
|
|
+
|
|
|
+ def _invoke(self,
|
|
|
+ user_id: str,
|
|
|
+ tool_parameters: dict[str, Any]
|
|
|
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
|
|
+ """
|
|
|
+ Invoke CogView3 tool
|
|
|
+ """
|
|
|
+ client = ZhipuAI(
|
|
|
+ base_url=self.runtime.credentials['zhipuai_base_url'],
|
|
|
+ api_key=self.runtime.credentials['zhipuai_api_key'],
|
|
|
+ )
|
|
|
+ size_mapping = {
|
|
|
+ 'square': '1024x1024',
|
|
|
+ 'vertical': '1024x1792',
|
|
|
+ 'horizontal': '1792x1024',
|
|
|
+ }
|
|
|
+ # prompt
|
|
|
+ prompt = tool_parameters.get('prompt', '')
|
|
|
+ if not prompt:
|
|
|
+ return self.create_text_message('Please input prompt')
|
|
|
+ # get size
|
|
|
+ print(tool_parameters.get('prompt', 'square'))
|
|
|
+ size = size_mapping[tool_parameters.get('size', 'square')]
|
|
|
+ # get n
|
|
|
+ n = tool_parameters.get('n', 1)
|
|
|
+ # get quality
|
|
|
+ quality = tool_parameters.get('quality', 'standard')
|
|
|
+ if quality not in ['standard', 'hd']:
|
|
|
+ return self.create_text_message('Invalid quality')
|
|
|
+ # get style
|
|
|
+ style = tool_parameters.get('style', 'vivid')
|
|
|
+ if style not in ['natural', 'vivid']:
|
|
|
+ return self.create_text_message('Invalid style')
|
|
|
+ # set extra body
|
|
|
+ seed_id = tool_parameters.get('seed_id', self._generate_random_id(8))
|
|
|
+ extra_body = {'seed': seed_id}
|
|
|
+ response = client.images.generations(
|
|
|
+ prompt=prompt,
|
|
|
+ model="cogview-3",
|
|
|
+ size=size,
|
|
|
+ n=n,
|
|
|
+ extra_body=extra_body,
|
|
|
+ style=style,
|
|
|
+ quality=quality,
|
|
|
+ response_format='b64_json'
|
|
|
+ )
|
|
|
+ result = []
|
|
|
+ for image in response.data:
|
|
|
+ result.append(self.create_image_message(image=image.url))
|
|
|
+ result.append(self.create_text_message(
|
|
|
+ f'\nGenerate image source to Seed ID: {seed_id}'))
|
|
|
+ return result
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _generate_random_id(length=8):
|
|
|
+ characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
|
|
|
+ random_id = ''.join(random.choices(characters, k=length))
|
|
|
+ return random_id
|