|
@@ -4,7 +4,7 @@ from core.tools.entities.common_entities import I18nObject
|
|
|
from core.tools.errors import ToolProviderCredentialValidationError
|
|
|
|
|
|
from typing import Any, Dict, List, Union
|
|
|
-from httpx import post
|
|
|
+from httpx import post, get
|
|
|
from os.path import join
|
|
|
from base64 import b64decode, b64encode
|
|
|
from PIL import Image
|
|
@@ -59,6 +59,7 @@ DRAW_TEXT_OPTIONS = {
|
|
|
"alwayson_scripts": {}
|
|
|
}
|
|
|
|
|
|
+
|
|
|
class StableDiffusionTool(BuiltinTool):
|
|
|
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
|
|
|
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
|
@@ -136,7 +137,31 @@ class StableDiffusionTool(BuiltinTool):
|
|
|
width=width,
|
|
|
height=height,
|
|
|
steps=steps)
|
|
|
-
|
|
|
+
|
|
|
+ def validate_models(self) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
|
|
+ """
|
|
|
+ validate models
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ base_url = self.runtime.credentials.get('base_url', None)
|
|
|
+ if not base_url:
|
|
|
+ raise ToolProviderCredentialValidationError('Please input base_url')
|
|
|
+ model = self.runtime.credentials.get('model', None)
|
|
|
+ if not model:
|
|
|
+ raise ToolProviderCredentialValidationError('Please input model')
|
|
|
+
|
|
|
+ response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=120)
|
|
|
+ if response.status_code != 200:
|
|
|
+ raise ToolProviderCredentialValidationError('Failed to get models')
|
|
|
+ else:
|
|
|
+ models = [d['model_name'] for d in response.json()]
|
|
|
+ if len([d for d in models if d == model]) > 0:
|
|
|
+ return self.create_text_message(json.dumps(models))
|
|
|
+ else:
|
|
|
+ raise ToolProviderCredentialValidationError(f'model {model} does not exist')
|
|
|
+ except Exception as e:
|
|
|
+ raise ToolProviderCredentialValidationError(f'Failed to get models, {e}')
|
|
|
+
|
|
|
def img2img(self, base_url: str, lora: str, image_binary: bytes,
|
|
|
prompt: str, negative_prompt: str,
|
|
|
width: int, height: int, steps: int) \
|
|
@@ -211,10 +236,9 @@ class StableDiffusionTool(BuiltinTool):
|
|
|
except Exception as e:
|
|
|
return self.create_text_message('Failed to generate image')
|
|
|
|
|
|
-
|
|
|
def get_runtime_parameters(self) -> List[ToolParameter]:
|
|
|
parameters = [
|
|
|
- ToolParameter(name='prompt',
|
|
|
+ ToolParameter(name='prompt',
|
|
|
label=I18nObject(en_US='Prompt', zh_Hans='Prompt'),
|
|
|
human_description=I18nObject(
|
|
|
en_US='Image prompt, you can check the official documentation of Stable Diffusion',
|
|
@@ -227,7 +251,7 @@ class StableDiffusionTool(BuiltinTool):
|
|
|
]
|
|
|
if len(self.list_default_image_variables()) != 0:
|
|
|
parameters.append(
|
|
|
- ToolParameter(name='image_id',
|
|
|
+ ToolParameter(name='image_id',
|
|
|
label=I18nObject(en_US='image_id', zh_Hans='image_id'),
|
|
|
human_description=I18nObject(
|
|
|
en_US='Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.',
|