Sfoglia il codice sorgente

Optimization stable diffusion verify (#2322)

Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Charlie.Wei 1 anno fa
parent
commit
5929e84036

+ 2 - 10
api/core/tools/provider/builtin/stablediffusion/stablediffusion.py

@@ -5,6 +5,7 @@ from core.tools.provider.builtin.stablediffusion.tools.stable_diffusion import S
 
 from typing import Any, Dict
 
+
 class StableDiffusionProvider(BuiltinToolProviderController):
     def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
         try:
@@ -12,15 +13,6 @@ class StableDiffusionProvider(BuiltinToolProviderController):
                 meta={
                     "credentials": credentials,
                 }
-            ).invoke(
-                user_id='',
-                tool_parameters={
-                    "prompt": "cat",
-                    "lora": "",
-                    "steps": 1,
-                    "width": 512,
-                    "height": 512,
-                },
-            )
+            ).validate_models()
         except Exception as e:
             raise ToolProviderCredentialValidationError(str(e))

+ 29 - 5
api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py

@@ -4,7 +4,7 @@ from core.tools.entities.common_entities import I18nObject
 from core.tools.errors import ToolProviderCredentialValidationError
 
 from typing import Any, Dict, List, Union
-from httpx import post
+from httpx import post, get
 from os.path import join
 from base64 import b64decode, b64encode
 from PIL import Image
@@ -59,6 +59,7 @@ DRAW_TEXT_OPTIONS = {
     "alwayson_scripts": {}
 }
 
+
 class StableDiffusionTool(BuiltinTool):
     def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
         -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
@@ -136,7 +137,31 @@ class StableDiffusionTool(BuiltinTool):
                              width=width,
                              height=height,
                              steps=steps)
-        
+
+    def validate_models(self) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
+        """
+            validate models
+        """
+        try:
+            base_url = self.runtime.credentials.get('base_url', None)
+            if not base_url:
+                raise ToolProviderCredentialValidationError('Please input base_url')
+            model = self.runtime.credentials.get('model', None)
+            if not model:
+                raise ToolProviderCredentialValidationError('Please input model')
+
+            response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=120)
+            if response.status_code != 200:
+                raise ToolProviderCredentialValidationError('Failed to get models')
+            else:
+                models = [d['model_name'] for d in response.json()]
+                if len([d for d in models if d == model]) > 0:
+                    return self.create_text_message(json.dumps(models))
+                else:
+                    raise ToolProviderCredentialValidationError(f'model {model} does not exist')
+        except Exception as e:
+            raise ToolProviderCredentialValidationError(f'Failed to get models, {e}')
+
     def img2img(self, base_url: str, lora: str, image_binary: bytes, 
                 prompt: str, negative_prompt: str,
                 width: int, height: int, steps: int) \
@@ -211,10 +236,9 @@ class StableDiffusionTool(BuiltinTool):
         except Exception as e:
             return self.create_text_message('Failed to generate image')
 
-
     def get_runtime_parameters(self) -> List[ToolParameter]:
         parameters = [
-            ToolParameter(name='prompt', 
+            ToolParameter(name='prompt',
                          label=I18nObject(en_US='Prompt', zh_Hans='Prompt'),
                          human_description=I18nObject(
                              en_US='Image prompt, you can check the official documentation of Stable Diffusion',
@@ -227,7 +251,7 @@ class StableDiffusionTool(BuiltinTool):
         ]
         if len(self.list_default_image_variables()) != 0:
             parameters.append(
-                ToolParameter(name='image_id', 
+                ToolParameter(name='image_id',
                              label=I18nObject(en_US='image_id', zh_Hans='image_id'),
                              human_description=I18nObject(
                                 en_US='Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.',