|
@@ -1,5 +1,3 @@
|
|
|
-import base64
|
|
|
-import io
|
|
|
import json
|
|
|
import random
|
|
|
import uuid
|
|
@@ -8,7 +6,7 @@ import httpx
|
|
|
from websocket import WebSocket
|
|
|
from yarl import URL
|
|
|
|
|
|
-from core.file.file_manager import _get_encoded_string
|
|
|
+from core.file.file_manager import download
|
|
|
from core.file.models import File
|
|
|
|
|
|
|
|
@@ -29,8 +27,7 @@ class ComfyUiClient:
|
|
|
return response.content
|
|
|
|
|
|
def upload_image(self, image_file: File) -> dict:
|
|
|
- image_content = base64.b64decode(_get_encoded_string(image_file))
|
|
|
- file = io.BytesIO(image_content)
|
|
|
+ file = download(image_file)
|
|
|
files = {"image": (image_file.filename, file, image_file.mime_type), "overwrite": "true"}
|
|
|
res = httpx.post(str(self.base_url / "upload/image"), files=files)
|
|
|
return res.json()
|
|
@@ -47,12 +44,7 @@ class ComfyUiClient:
|
|
|
ws.connect(ws_address)
|
|
|
return ws, client_id
|
|
|
|
|
|
- def set_prompt(
|
|
|
- self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "", image_name: str = ""
|
|
|
- ) -> dict:
|
|
|
- """
|
|
|
- find the first KSampler, then can find the prompt node through it.
|
|
|
- """
|
|
|
+ def set_prompt_by_ksampler(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "") -> dict:
|
|
|
prompt = origin_prompt.copy()
|
|
|
id_to_class_type = {id: details["class_type"] for id, details in prompt.items()}
|
|
|
k_sampler = [key for key, value in id_to_class_type.items() if value == "KSampler"][0]
|
|
@@ -64,9 +56,20 @@ class ComfyUiClient:
|
|
|
negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0]
|
|
|
prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt
|
|
|
|
|
|
- if image_name != "":
|
|
|
- image_loader = [key for key, value in id_to_class_type.items() if value == "LoadImage"][0]
|
|
|
- prompt.get(image_loader)["inputs"]["image"] = image_name
|
|
|
+ return prompt
|
|
|
+
|
|
|
+ def set_prompt_images_by_ids(self, origin_prompt: dict, image_names: list[str], image_ids: list[str]) -> dict:
|
|
|
+ prompt = origin_prompt.copy()
|
|
|
+ for index, image_node_id in enumerate(image_ids):
|
|
|
+ prompt[image_node_id]["inputs"]["image"] = image_names[index]
|
|
|
+ return prompt
|
|
|
+
|
|
|
+ def set_prompt_images_by_default(self, origin_prompt: dict, image_names: list[str]) -> dict:
|
|
|
+ prompt = origin_prompt.copy()
|
|
|
+ id_to_class_type = {id: details["class_type"] for id, details in prompt.items()}
|
|
|
+ load_image_nodes = [key for key, value in id_to_class_type.items() if value == "LoadImage"]
|
|
|
+ for load_image, image_name in zip(load_image_nodes, image_names):
|
|
|
+ prompt.get(load_image)["inputs"]["image"] = image_name
|
|
|
return prompt
|
|
|
|
|
|
def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str):
|