Bläddra i källkod

enhance: use override_settings for concurrent stable diffusion (#2818)

Qun 1 år sedan
förälder
incheckning
1e5455e266

+ 9 - 5
api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py

@@ -131,7 +131,8 @@ class StableDiffusionTool(BuiltinTool):
                                     negative_prompt=negative_prompt,
                                     width=width,
                                     height=height,
-                                    steps=steps)
+                                    steps=steps,
+                                    model=model)
             
         return self.text2img(base_url=base_url,
                              lora=lora,
@@ -139,7 +140,8 @@ class StableDiffusionTool(BuiltinTool):
                              negative_prompt=negative_prompt,
                              width=width,
                              height=height,
-                             steps=steps)
+                             steps=steps,
+                             model=model)
 
     def validate_models(self) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
@@ -197,7 +199,7 @@ class StableDiffusionTool(BuiltinTool):
 
     def img2img(self, base_url: str, lora: str, image_binary: bytes, 
                 prompt: str, negative_prompt: str,
-                width: int, height: int, steps: int) \
+                width: int, height: int, steps: int, model: str) \
         -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
             generate image
@@ -213,7 +215,8 @@ class StableDiffusionTool(BuiltinTool):
             "sampler_name": "Euler a",
             "restore_faces": False,
             "steps": steps,
-            "script_args": ["outpainting mk2"]
+            "script_args": ["outpainting mk2"],
+            "override_settings": {"sd_model_checkpoint": model}
         }
 
         if lora:
@@ -236,7 +239,7 @@ class StableDiffusionTool(BuiltinTool):
         except Exception as e:
             return self.create_text_message('Failed to generate image')
 
-    def text2img(self, base_url: str, lora: str, prompt: str, negative_prompt: str, width: int, height: int, steps: int) \
+    def text2img(self, base_url: str, lora: str, prompt: str, negative_prompt: str, width: int, height: int, steps: int, model: str) \
         -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
             generate image
@@ -253,6 +256,7 @@ class StableDiffusionTool(BuiltinTool):
         draw_options['height'] = height
         draw_options['steps'] = steps
         draw_options['negative_prompt'] = negative_prompt
+        draw_options['override_settings']['sd_model_checkpoint'] = model
         
         try:
             url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img')