Przeglądaj źródła

fix the issue of the refine_switches at param being invalid in the Novita.AI tool (#7485)

Xiao Ley 8 miesięcy temu
rodzic
commit
0c99a3d0c5

+ 73 - 0
api/core/tools/provider/builtin/novitaai/_novita_tool_base.py

@@ -0,0 +1,73 @@
+from novita_client import (
+    Txt2ImgV3Embedding,
+    Txt2ImgV3HiresFix,
+    Txt2ImgV3LoRA,
+    Txt2ImgV3Refiner,
+    V3TaskImage,
+)
+
+
+class NovitaAiToolBase:
+    def _extract_loras(self, loras_str: str):
+        if not loras_str:
+            return []
+
+        loras_ori_list = lora_str.strip().split(';')
+        result_list = []
+        for lora_str in loras_ori_list:
+            lora_info = lora_str.strip().split(',')
+            lora = Txt2ImgV3LoRA(
+                model_name=lora_info[0].strip(),
+                strength=float(lora_info[1]),
+            )
+            result_list.append(lora)
+
+        return result_list
+
+    def _extract_embeddings(self, embeddings_str: str):
+        if not embeddings_str:
+            return []
+
+        embeddings_ori_list = embeddings_str.strip().split(';')
+        result_list = []
+        for embedding_str in embeddings_ori_list:
+            embedding = Txt2ImgV3Embedding(
+                model_name=embedding_str.strip()
+            )
+            result_list.append(embedding)
+
+        return result_list
+
+    def _extract_hires_fix(self, hires_fix_str: str):
+        hires_fix_info = hires_fix_str.strip().split(',')
+        if 'upscaler' in hires_fix_info:
+            hires_fix = Txt2ImgV3HiresFix(
+                target_width=int(hires_fix_info[0]),
+                target_height=int(hires_fix_info[1]),
+                strength=float(hires_fix_info[2]),
+                upscaler=hires_fix_info[3].strip()
+            )
+        else:
+            hires_fix = Txt2ImgV3HiresFix(
+                target_width=int(hires_fix_info[0]),
+                target_height=int(hires_fix_info[1]),
+                strength=float(hires_fix_info[2])
+            )
+
+        return hires_fix
+
+    def _extract_refiner(self, switch_at: str):
+        refiner = Txt2ImgV3Refiner(
+            switch_at=float(switch_at)
+        )
+        return refiner
+
+    def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool:
+        """
+            is hit nsfw
+        """
+        if image.nsfw_detection_result is None:
+            return False
+        if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold:
+            return True
+        return False

+ 10 - 60
api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py

@@ -4,19 +4,15 @@ from typing import Any, Union
 
 from novita_client import (
     NovitaClient,
-    Txt2ImgV3Embedding,
-    Txt2ImgV3HiresFix,
-    Txt2ImgV3LoRA,
-    Txt2ImgV3Refiner,
-    V3TaskImage,
 )
 
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.provider.builtin.novitaai._novita_tool_base import NovitaAiToolBase
 from core.tools.tool.builtin_tool import BuiltinTool
 
 
-class NovitaAiTxt2ImgTool(BuiltinTool):
+class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase):
     def _invoke(self,
                 user_id: str,
                 tool_parameters: dict[str, Any],
@@ -73,65 +69,19 @@ class NovitaAiTxt2ImgTool(BuiltinTool):
 
         # process loras
         if 'loras' in res_parameters:
-            loras_ori_list = res_parameters.get('loras').strip().split(';')
-            locals_list = []
-            for lora_str in loras_ori_list:
-                lora_info = lora_str.strip().split(',')
-                lora = Txt2ImgV3LoRA(
-                    model_name=lora_info[0].strip(),
-                    strength=float(lora_info[1]),
-                )
-                locals_list.append(lora)
-
-            res_parameters['loras'] = locals_list
+            res_parameters['loras'] = self._extract_loras(res_parameters.get('loras'))
 
         # process embeddings
         if 'embeddings' in res_parameters:
-            embeddings_ori_list = res_parameters.get('embeddings').strip().split(';')
-            locals_list = []
-            for embedding_str in embeddings_ori_list:
-                embedding = Txt2ImgV3Embedding(
-                    model_name=embedding_str.strip()
-                )
-                locals_list.append(embedding)
-
-            res_parameters['embeddings'] = locals_list
+            res_parameters['embeddings'] = self._extract_embeddings(res_parameters.get('embeddings'))
 
         # process hires_fix
         if 'hires_fix' in res_parameters:
-            hires_fix_ori = res_parameters.get('hires_fix')
-            hires_fix_info = hires_fix_ori.strip().split(',')
-            if 'upscaler' in hires_fix_info:
-                hires_fix = Txt2ImgV3HiresFix(
-                    target_width=int(hires_fix_info[0]),
-                    target_height=int(hires_fix_info[1]),
-                    strength=float(hires_fix_info[2]),
-                    upscaler=hires_fix_info[3].strip()
-                )
-            else:
-                hires_fix = Txt2ImgV3HiresFix(
-                    target_width=int(hires_fix_info[0]),
-                    target_height=int(hires_fix_info[1]),
-                    strength=float(hires_fix_info[2])
-                )
-
-            res_parameters['hires_fix'] = hires_fix
-
-            if 'refiner_switch_at' in res_parameters:
-                refiner = Txt2ImgV3Refiner(
-                    switch_at=float(res_parameters.get('refiner_switch_at'))
-                )
-                del res_parameters['refiner_switch_at']
-                res_parameters['refiner'] = refiner
+            res_parameters['hires_fix'] = self._extract_hires_fix(res_parameters.get('hires_fix'))
 
-        return res_parameters
+        # process refiner
+        if 'refiner_switch_at' in res_parameters:
+            res_parameters['refiner'] = self._extract_refiner(res_parameters.get('refiner_switch_at'))
+            del res_parameters['refiner_switch_at']
 
-    def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool:
-        """
-            is hit nsfw
-        """
-        if image.nsfw_detection_result is None:
-            return False
-        if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold:
-            return True
-        return False
+        return res_parameters

Plik diff jest za duży
+ 417 - 405
api/poetry.lock


+ 1 - 1
api/pyproject.toml

@@ -153,7 +153,7 @@ langfuse = "^2.36.1"
 langsmith = "^0.1.77"
 mailchimp-transactional = "~1.0.50"
 markdown = "~3.5.1"
-novita-client = "^0.5.6"
+novita-client = "^0.5.7"
 numpy = "~1.26.4"
 openai = "~1.29.0"
 openpyxl = "~3.1.5"

Niektóre pliki nie zostały wyświetlone z powodu dużej ilości zmienionych plików