|
@@ -4,19 +4,15 @@ from typing import Any, Union
|
|
|
|
|
|
from novita_client import (
|
|
|
NovitaClient,
|
|
|
- Txt2ImgV3Embedding,
|
|
|
- Txt2ImgV3HiresFix,
|
|
|
- Txt2ImgV3LoRA,
|
|
|
- Txt2ImgV3Refiner,
|
|
|
- V3TaskImage,
|
|
|
)
|
|
|
|
|
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
|
from core.tools.errors import ToolProviderCredentialValidationError
|
|
|
+from core.tools.provider.builtin.novitaai._novita_tool_base import NovitaAiToolBase
|
|
|
from core.tools.tool.builtin_tool import BuiltinTool
|
|
|
|
|
|
|
|
|
-class NovitaAiTxt2ImgTool(BuiltinTool):
|
|
|
+class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase):
|
|
|
def _invoke(self,
|
|
|
user_id: str,
|
|
|
tool_parameters: dict[str, Any],
|
|
@@ -73,65 +69,19 @@ class NovitaAiTxt2ImgTool(BuiltinTool):
|
|
|
|
|
|
# process loras
|
|
|
if 'loras' in res_parameters:
|
|
|
- loras_ori_list = res_parameters.get('loras').strip().split(';')
|
|
|
- locals_list = []
|
|
|
- for lora_str in loras_ori_list:
|
|
|
- lora_info = lora_str.strip().split(',')
|
|
|
- lora = Txt2ImgV3LoRA(
|
|
|
- model_name=lora_info[0].strip(),
|
|
|
- strength=float(lora_info[1]),
|
|
|
- )
|
|
|
- locals_list.append(lora)
|
|
|
-
|
|
|
- res_parameters['loras'] = locals_list
|
|
|
+ res_parameters['loras'] = self._extract_loras(res_parameters.get('loras'))
|
|
|
|
|
|
# process embeddings
|
|
|
if 'embeddings' in res_parameters:
|
|
|
- embeddings_ori_list = res_parameters.get('embeddings').strip().split(';')
|
|
|
- locals_list = []
|
|
|
- for embedding_str in embeddings_ori_list:
|
|
|
- embedding = Txt2ImgV3Embedding(
|
|
|
- model_name=embedding_str.strip()
|
|
|
- )
|
|
|
- locals_list.append(embedding)
|
|
|
-
|
|
|
- res_parameters['embeddings'] = locals_list
|
|
|
+ res_parameters['embeddings'] = self._extract_embeddings(res_parameters.get('embeddings'))
|
|
|
|
|
|
# process hires_fix
|
|
|
if 'hires_fix' in res_parameters:
|
|
|
- hires_fix_ori = res_parameters.get('hires_fix')
|
|
|
- hires_fix_info = hires_fix_ori.strip().split(',')
|
|
|
- if 'upscaler' in hires_fix_info:
|
|
|
- hires_fix = Txt2ImgV3HiresFix(
|
|
|
- target_width=int(hires_fix_info[0]),
|
|
|
- target_height=int(hires_fix_info[1]),
|
|
|
- strength=float(hires_fix_info[2]),
|
|
|
- upscaler=hires_fix_info[3].strip()
|
|
|
- )
|
|
|
- else:
|
|
|
- hires_fix = Txt2ImgV3HiresFix(
|
|
|
- target_width=int(hires_fix_info[0]),
|
|
|
- target_height=int(hires_fix_info[1]),
|
|
|
- strength=float(hires_fix_info[2])
|
|
|
- )
|
|
|
-
|
|
|
- res_parameters['hires_fix'] = hires_fix
|
|
|
-
|
|
|
- if 'refiner_switch_at' in res_parameters:
|
|
|
- refiner = Txt2ImgV3Refiner(
|
|
|
- switch_at=float(res_parameters.get('refiner_switch_at'))
|
|
|
- )
|
|
|
- del res_parameters['refiner_switch_at']
|
|
|
- res_parameters['refiner'] = refiner
|
|
|
+ res_parameters['hires_fix'] = self._extract_hires_fix(res_parameters.get('hires_fix'))
|
|
|
|
|
|
- return res_parameters
|
|
|
+ # process refiner
|
|
|
+ if 'refiner_switch_at' in res_parameters:
|
|
|
+ res_parameters['refiner'] = self._extract_refiner(res_parameters.get('refiner_switch_at'))
|
|
|
+ del res_parameters['refiner_switch_at']
|
|
|
|
|
|
- def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool:
|
|
|
- """
|
|
|
- is hit nsfw
|
|
|
- """
|
|
|
- if image.nsfw_detection_result is None:
|
|
|
- return False
|
|
|
- if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold:
|
|
|
- return True
|
|
|
- return False
|
|
|
+ return res_parameters
|