浏览代码

Update stable_diffusion.py (#7536)

Jie.F 8 月之前
父节点
当前提交
70d6ab0bf5
共有 1 个文件被更改,包括 39 次插入2 次删除
  1. 39 2
      api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py

+ 39 - 2
api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py

@@ -27,7 +27,7 @@ DRAW_TEXT_OPTIONS = {
     "seed_resize_from_w": -1,
 
     # Samplers
-    # "sampler_name": "DPM++ 2M",
+    "sampler_name": "DPM++ 2M",
     # "scheduler": "",
     # "sampler_index": "Automatic",
 
@@ -178,6 +178,23 @@ class StableDiffusionTool(BuiltinTool):
                 return [d['model_name'] for d in response.json()]
         except Exception as e:
             return []
+        
+    def get_sample_methods(self) -> list[str]:
+        """
+            get sample method
+        """
+        try:
+            base_url = self.runtime.credentials.get('base_url', None)
+            if not base_url:
+                return []
+            api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'samplers')
+            response = get(url=api_url, timeout=(2, 10))
+            if response.status_code != 200:
+                return []
+            else:
+                return [d['name'] for d in response.json()]
+        except Exception as e:
+            return []
 
     def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \
         -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
@@ -339,7 +356,27 @@ class StableDiffusionTool(BuiltinTool):
                                          label=I18nObject(en_US=i, zh_Hans=i)
                                      ) for i in models])
                     )
+                    
             except:
                 pass
-
+            
+            sample_methods = self.get_sample_methods()
+            if len(sample_methods) != 0:
+                parameters.append(
+                        ToolParameter(name='sampler_name',
+                                     label=I18nObject(en_US='Sampling method', zh_Hans='Sampling method'),
+                                     human_description=I18nObject(
+                                        en_US='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion',
+                                        zh_Hans='Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档',
+                                     ),
+                                     type=ToolParameter.ToolParameterType.SELECT,
+                                     form=ToolParameter.ToolParameterForm.FORM,
+                                     llm_description='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion',
+                                     required=True,
+                                     default=sample_methods[0],
+                                     options=[ToolParameterOption(
+                                         value=i,
+                                         label=I18nObject(en_US=i, zh_Hans=i)
+                                     ) for i in sample_methods])
+                    )
         return parameters