Sfoglia il codice sorgente

fix: create_blob_message of tool will always create image type file (#10701)

非法操作 5 mesi fa
parent
commit
4b2abf8ac2
2 ha cambiato i file con 16 aggiunte e 10 eliminazioni
  1. 0 9
      api/core/workflow/nodes/tool/tool_node.py
  2. 16 1
      api/factories/file_factory.py

+ 0 - 9
api/core/workflow/nodes/tool/tool_node.py

@@ -1,5 +1,4 @@
 from collections.abc import Mapping, Sequence
-from os import path
 from typing import Any
 
 from sqlalchemy import select
@@ -180,7 +179,6 @@ class ToolNode(BaseNode[ToolNodeData]):
         for response in tool_response:
             if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
                 url = str(response.message) if response.message else None
-                ext = path.splitext(url)[1] if url else ".bin"
                 tool_file_id = str(url).split("/")[-1].split(".")[0]
                 transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
 
@@ -202,7 +200,6 @@ class ToolNode(BaseNode[ToolNodeData]):
                 )
                 result.append(file)
             elif response.type == ToolInvokeMessage.MessageType.BLOB:
-                # get tool file id
                 tool_file_id = str(response.message).split("/")[-1].split(".")[0]
                 with Session(db.engine) as session:
                     stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
@@ -211,7 +208,6 @@ class ToolNode(BaseNode[ToolNodeData]):
                         raise ValueError(f"tool file {tool_file_id} not exists")
                 mapping = {
                     "tool_file_id": tool_file_id,
-                    "type": FileType.IMAGE,
                     "transfer_method": FileTransferMethod.TOOL_FILE,
                 }
                 file = file_factory.build_from_mapping(
@@ -228,13 +224,8 @@ class ToolNode(BaseNode[ToolNodeData]):
                     tool_file = session.scalar(stmt)
                     if tool_file is None:
                         raise ToolFileError(f"Tool file {tool_file_id} does not exist")
-                if "." in url:
-                    extension = "." + url.split("/")[-1].split(".")[1]
-                else:
-                    extension = ".bin"
                 mapping = {
                     "tool_file_id": tool_file_id,
-                    "type": FileType.IMAGE,
                     "transfer_method": transfer_method,
                     "url": url,
                 }

+ 16 - 1
api/factories/file_factory.py

@@ -180,6 +180,20 @@ def _get_remote_file_info(url: str):
     return mime_type, filename, file_size
 
 
+def _get_file_type_by_mimetype(mime_type: str) -> FileType:
+    if "image" in mime_type:
+        file_type = FileType.IMAGE
+    elif "video" in mime_type:
+        file_type = FileType.VIDEO
+    elif "audio" in mime_type:
+        file_type = FileType.AUDIO
+    elif "text" in mime_type or "pdf" in mime_type:
+        file_type = FileType.DOCUMENT
+    else:
+        file_type = FileType.CUSTOM
+    return file_type
+
+
 def _build_from_tool_file(
     *,
     mapping: Mapping[str, Any],
@@ -199,12 +213,13 @@ def _build_from_tool_file(
         raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
 
     extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
+    file_type = mapping.get("type", _get_file_type_by_mimetype(tool_file.mimetype))
 
     return File(
         id=mapping.get("id"),
         tenant_id=tenant_id,
         filename=tool_file.name,
-        type=FileType.value_of(mapping.get("type")),
+        type=file_type,
         transfer_method=transfer_method,
         remote_url=tool_file.original_url,
         related_id=tool_file.id,