|
@@ -5,7 +5,10 @@ import requests
|
|
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
|
from core.tools.tool.builtin_tool import BuiltinTool
|
|
|
|
|
|
-FLUX_URL = "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image"
|
|
|
+FLUX_URL = {
|
|
|
+ "schnell": "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image",
|
|
|
+ "dev": "https://api.siliconflow.cn/v1/image/generations",
|
|
|
+}
|
|
|
|
|
|
|
|
|
class FluxTool(BuiltinTool):
|
|
@@ -24,8 +27,12 @@ class FluxTool(BuiltinTool):
|
|
|
"seed": tool_parameters.get("seed"),
|
|
|
"num_inference_steps": tool_parameters.get("num_inference_steps", 20),
|
|
|
}
|
|
|
+ model = tool_parameters.get("model", "schnell")
|
|
|
+ url = FLUX_URL.get(model)
|
|
|
+ if model == "dev":
|
|
|
+ payload["model"] = "black-forest-labs/FLUX.1-dev"
|
|
|
|
|
|
- response = requests.post(FLUX_URL, json=payload, headers=headers)
|
|
|
+ response = requests.post(url, json=payload, headers=headers)
|
|
|
if response.status_code != 200:
|
|
|
return self.create_text_message(f"Got Error Response:{response.text}")
|
|
|
|