|
@@ -48,7 +48,6 @@ class ComfyUiClient:
|
|
|
prompt = origin_prompt.copy()
|
|
|
id_to_class_type = {id: details["class_type"] for id, details in prompt.items()}
|
|
|
k_sampler = [key for key, value in id_to_class_type.items() if value == "KSampler"][0]
|
|
|
- prompt.get(k_sampler)["inputs"]["seed"] = random.randint(10**14, 10**15 - 1)
|
|
|
positive_input_id = prompt.get(k_sampler)["inputs"]["positive"][0]
|
|
|
prompt.get(positive_input_id)["inputs"]["text"] = positive_prompt
|
|
|
|
|
@@ -72,6 +71,18 @@ class ComfyUiClient:
|
|
|
prompt.get(load_image)["inputs"]["image"] = image_name
|
|
|
return prompt
|
|
|
|
|
|
+ def set_prompt_seed_by_id(self, origin_prompt: dict, seed_id: str) -> dict:
|
|
|
+ prompt = origin_prompt.copy()
|
|
|
+ if seed_id not in prompt:
|
|
|
+ raise Exception("Not a valid seed node")
|
|
|
+ if "seed" in prompt[seed_id]["inputs"]:
|
|
|
+ prompt[seed_id]["inputs"]["seed"] = random.randint(10**14, 10**15 - 1)
|
|
|
+ elif "noise_seed" in prompt[seed_id]["inputs"]:
|
|
|
+ prompt[seed_id]["inputs"]["noise_seed"] = random.randint(10**14, 10**15 - 1)
|
|
|
+ else:
|
|
|
+ raise Exception("Not a valid seed node")
|
|
|
+ return prompt
|
|
|
+
|
|
|
def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str):
|
|
|
node_ids = list(prompt.keys())
|
|
|
finished_nodes = []
|