Pārlūkot izejas kodu

chore(api/core): Improve FileVar's type hint and imports. (#7290)

-LAN- 8 mēneši atpakaļ
vecāks
revīzija
8f16165f92

+ 8 - 6
api/core/app/apps/base_app_runner.py

@@ -1,6 +1,6 @@
 import time
 from collections.abc import Generator
-from typing import Optional, Union
+from typing import TYPE_CHECKING, Optional, Union
 
 from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -14,7 +14,6 @@ from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChu
 from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
 from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
 from core.external_data_tool.external_data_fetch import ExternalDataFetch
-from core.file.file_obj import FileVar
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -27,13 +26,16 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
 from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
 from models.model import App, AppMode, Message, MessageAnnotation
 
+if TYPE_CHECKING:
+    from core.file.file_obj import FileVar
+
 
 class AppRunner:
     def get_pre_calculate_rest_tokens(self, app_record: App,
                                       model_config: ModelConfigWithCredentialsEntity,
                                       prompt_template_entity: PromptTemplateEntity,
                                       inputs: dict[str, str],
-                                      files: list[FileVar],
+                                      files: list["FileVar"],
                                       query: Optional[str] = None) -> int:
         """
         Get pre calculate rest tokens
@@ -126,7 +128,7 @@ class AppRunner:
                                  model_config: ModelConfigWithCredentialsEntity,
                                  prompt_template_entity: PromptTemplateEntity,
                                  inputs: dict[str, str],
-                                 files: list[FileVar],
+                                 files: list["FileVar"],
                                  query: Optional[str] = None,
                                  context: Optional[str] = None,
                                  memory: Optional[TokenBufferMemory] = None) \
@@ -366,7 +368,7 @@ class AppRunner:
             message_id=message_id,
             trace_manager=app_generate_entity.trace_manager
         )
-    
+
     def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
                                  queue_manager: AppQueueManager,
                                  prompt_messages: list[PromptMessage]) -> bool:
@@ -418,7 +420,7 @@ class AppRunner:
             inputs=inputs,
             query=query
         )
-    
+
     def query_app_annotations_to_reply(self, app_record: App,
                                        message: Message,
                                        query: str,

+ 1 - 1
api/core/app/entities/app_invoke_entities.py

@@ -166,4 +166,4 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
         node_id: str
         inputs: dict
 
-    single_iteration_run: Optional[SingleIterationRunEntity] = None
+    single_iteration_run: Optional[SingleIterationRunEntity] = None

+ 2 - 2
api/core/file/message_file_parser.py

@@ -99,7 +99,7 @@ class MessageFileParser:
         # return all file objs
         return new_files
 
-    def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig) -> list[FileVar]:
+    def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig):
         """
         transform message files
 
@@ -144,7 +144,7 @@ class MessageFileParser:
 
         return type_file_objs
 
-    def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig) -> FileVar:
+    def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig):
         """
         transform file to file obj
 

+ 8 - 6
api/core/prompt/simple_prompt_transform.py

@@ -1,11 +1,10 @@
 import enum
 import json
 import os
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from core.app.app_config.entities import PromptTemplateEntity
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.file.file_obj import FileVar
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import (
     PromptMessage,
@@ -18,6 +17,9 @@ from core.prompt.prompt_transform import PromptTransform
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from models.model import AppMode
 
+if TYPE_CHECKING:
+    from core.file.file_obj import FileVar
+
 
 class ModelMode(enum.Enum):
     COMPLETION = 'completion'
@@ -50,7 +52,7 @@ class SimplePromptTransform(PromptTransform):
                    prompt_template_entity: PromptTemplateEntity,
                    inputs: dict,
                    query: str,
-                   files: list[FileVar],
+                   files: list["FileVar"],
                    context: Optional[str],
                    memory: Optional[TokenBufferMemory],
                    model_config: ModelConfigWithCredentialsEntity) -> \
@@ -163,7 +165,7 @@ class SimplePromptTransform(PromptTransform):
                                         inputs: dict,
                                         query: str,
                                         context: Optional[str],
-                                        files: list[FileVar],
+                                        files: list["FileVar"],
                                         memory: Optional[TokenBufferMemory],
                                         model_config: ModelConfigWithCredentialsEntity) \
             -> tuple[list[PromptMessage], Optional[list[str]]]:
@@ -206,7 +208,7 @@ class SimplePromptTransform(PromptTransform):
                                               inputs: dict,
                                               query: str,
                                               context: Optional[str],
-                                              files: list[FileVar],
+                                              files: list["FileVar"],
                                               memory: Optional[TokenBufferMemory],
                                               model_config: ModelConfigWithCredentialsEntity) \
             -> tuple[list[PromptMessage], Optional[list[str]]]:
@@ -255,7 +257,7 @@ class SimplePromptTransform(PromptTransform):
 
         return [self.get_last_user_message(prompt, files)], stops
 
-    def get_last_user_message(self, prompt: str, files: list[FileVar]) -> UserPromptMessage:
+    def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage:
         if files:
             prompt_message_contents = [TextPromptMessageContent(data=prompt)]
             for file in files:

+ 33 - 31
api/core/tools/tool/tool.py

@@ -2,13 +2,12 @@ from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from copy import deepcopy
 from enum import Enum
-from typing import Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 from pydantic import BaseModel, ConfigDict, field_validator
 from pydantic_core.core_schema import ValidationInfo
 
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.file.file_obj import FileVar
 from core.tools.entities.tool_entities import (
     ToolDescription,
     ToolIdentity,
@@ -23,6 +22,9 @@ from core.tools.entities.tool_entities import (
 from core.tools.tool_file_manager import ToolFileManager
 from core.tools.utils.tool_parameter_converter import ToolParameterConverter
 
+if TYPE_CHECKING:
+    from core.file.file_obj import FileVar
+
 
 class Tool(BaseModel, ABC):
     identity: Optional[ToolIdentity] = None
@@ -76,7 +78,7 @@ class Tool(BaseModel, ABC):
             description=self.description.model_copy() if self.description else None,
             runtime=Tool.Runtime(**runtime),
         )
-    
+
     @abstractmethod
     def tool_provider_type(self) -> ToolProviderType:
         """
@@ -84,7 +86,7 @@ class Tool(BaseModel, ABC):
 
             :return: the tool provider type
         """
-    
+
     def load_variables(self, variables: ToolRuntimeVariablePool):
         """
             load variables from database
@@ -99,7 +101,7 @@ class Tool(BaseModel, ABC):
         """
         if not self.variables:
             return
-        
+
         self.variables.set_file(self.identity.name, variable_name, image_key)
 
     def set_text_variable(self, variable_name: str, text: str) -> None:
@@ -108,9 +110,9 @@ class Tool(BaseModel, ABC):
         """
         if not self.variables:
             return
-        
+
         self.variables.set_text(self.identity.name, variable_name, text)
-        
+
     def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]:
         """
             get a variable
@@ -120,14 +122,14 @@ class Tool(BaseModel, ABC):
         """
         if not self.variables:
             return None
-        
+
         if isinstance(name, Enum):
             name = name.value
-        
+
         for variable in self.variables.pool:
             if variable.name == name:
                 return variable
-            
+
         return None
 
     def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]:
@@ -138,9 +140,9 @@ class Tool(BaseModel, ABC):
         """
         if not self.variables:
             return None
-        
+
         return self.get_variable(self.VARIABLE_KEY.IMAGE)
-    
+
     def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]:
         """
             get a variable file
@@ -151,7 +153,7 @@ class Tool(BaseModel, ABC):
         variable = self.get_variable(name)
         if not variable:
             return None
-        
+
         if not isinstance(variable, ToolRuntimeImageVariable):
             return None
 
@@ -160,9 +162,9 @@ class Tool(BaseModel, ABC):
         file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id)
         if not file_binary:
             return None
-        
+
         return file_binary[0]
-    
+
     def list_variables(self) -> list[ToolRuntimeVariable]:
         """
             list all variables
@@ -171,9 +173,9 @@ class Tool(BaseModel, ABC):
         """
         if not self.variables:
             return []
-        
+
         return self.variables.pool
-    
+
     def list_default_image_variables(self) -> list[ToolRuntimeVariable]:
         """
             list all image variables
@@ -182,9 +184,9 @@ class Tool(BaseModel, ABC):
         """
         if not self.variables:
             return []
-        
+
         result = []
-        
+
         for variable in self.variables.pool:
             if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value):
                 result.append(variable)
@@ -225,7 +227,7 @@ class Tool(BaseModel, ABC):
     @abstractmethod
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         pass
-    
+
     def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
         """
             validate the credentials
@@ -244,7 +246,7 @@ class Tool(BaseModel, ABC):
             :return: the runtime parameters
         """
         return self.parameters or []
-    
+
     def get_all_runtime_parameters(self) -> list[ToolParameter]:
         """
             get all runtime parameters
@@ -278,7 +280,7 @@ class Tool(BaseModel, ABC):
                 parameters.append(parameter)
 
         return parameters
-    
+
     def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
         """
             create an image message
@@ -286,18 +288,18 @@ class Tool(BaseModel, ABC):
             :param image: the url of the image
             :return: the image message
         """
-        return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, 
-                                 message=image, 
+        return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE,
+                                 message=image,
                                  save_as=save_as)
-    
-    def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage:
+
+    def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage:
         return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR,
                                  message='',
                                  meta={
                                      'file_var': file_var
                                  },
                                  save_as='')
-    
+
     def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
         """
             create a link message
@@ -305,10 +307,10 @@ class Tool(BaseModel, ABC):
             :param link: the url of the link
             :return: the link message
         """
-        return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, 
-                                 message=link, 
+        return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK,
+                                 message=link,
                                  save_as=save_as)
-    
+
     def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
         """
             create a text message
@@ -321,7 +323,7 @@ class Tool(BaseModel, ABC):
             message=text,
             save_as=save_as
         )
-    
+
     def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
         """
             create a blob message

+ 8 - 8
api/core/tools/utils/message_transformer.py

@@ -1,7 +1,7 @@
 import logging
 from mimetypes import guess_extension
 
-from core.file.file_obj import FileTransferMethod, FileType, FileVar
+from core.file.file_obj import FileTransferMethod, FileType
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool_file_manager import ToolFileManager
 
@@ -27,12 +27,12 @@ class ToolFileMessageTransformer:
                 # try to download image
                 try:
                     file = ToolFileManager.create_file_by_url(
-                        user_id=user_id, 
+                        user_id=user_id,
                         tenant_id=tenant_id,
                         conversation_id=conversation_id,
                         file_url=message.message
                     )
-                    
+
                     url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
 
                     result.append(ToolInvokeMessage(
@@ -55,14 +55,14 @@ class ToolFileMessageTransformer:
                 # if message is str, encode it to bytes
                 if isinstance(message.message, str):
                     message.message = message.message.encode('utf-8')
-                
+
                 file = ToolFileManager.create_file_by_raw(
                     user_id=user_id, tenant_id=tenant_id,
                     conversation_id=conversation_id,
                     file_binary=message.message,
                     mimetype=mimetype
                 )
-                                                            
+
                 url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
 
                 # check if file is image
@@ -81,7 +81,7 @@ class ToolFileMessageTransformer:
                         meta=message.meta.copy() if message.meta is not None else {},
                     ))
             elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
-                file_var: FileVar = message.meta.get('file_var')
+                file_var = message.meta.get('file_var')
                 if file_var:
                     if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
                         url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
@@ -103,7 +103,7 @@ class ToolFileMessageTransformer:
                 result.append(message)
 
         return result
-    
+
     @classmethod
     def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
-        return f'/files/tools/{tool_file_id}{extension or ".bin"}'
+        return f'/files/tools/{tool_file_id}{extension or ".bin"}'

+ 8 - 5
api/core/workflow/nodes/llm/llm_node.py

@@ -1,14 +1,13 @@
 import json
 from collections.abc import Generator
 from copy import deepcopy
-from typing import Optional, cast
+from typing import TYPE_CHECKING, Optional, cast
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
 from core.entities.model_entities import ModelStatus
 from core.entities.provider_entities import QuotaUnit
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
-from core.file.file_obj import FileVar
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.llm_entities import LLMUsage
@@ -39,6 +38,10 @@ from models.model import Conversation
 from models.provider import Provider, ProviderType
 from models.workflow import WorkflowNodeExecutionStatus
 
+if TYPE_CHECKING:
+    from core.file.file_obj import FileVar
+
+
 
 class LLMNode(BaseNode):
     _node_data_cls = LLMNodeData
@@ -71,7 +74,7 @@ class LLMNode(BaseNode):
             node_inputs = {}
 
             # fetch files
-            files: list[FileVar] = self._fetch_files(node_data, variable_pool)
+            files = self._fetch_files(node_data, variable_pool)
 
             if files:
                 node_inputs['#files#'] = [file.to_dict() for file in files]
@@ -322,7 +325,7 @@ class LLMNode(BaseNode):
 
         return inputs
 
-    def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
+    def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]:
         """
         Fetch files
         :param node_data: node data
@@ -521,7 +524,7 @@ class LLMNode(BaseNode):
                                query: Optional[str],
                                query_prompt_template: Optional[str],
                                inputs: dict[str, str],
-                               files: list[FileVar],
+                               files: list["FileVar"],
                                context: Optional[str],
                                memory: Optional[TokenBufferMemory],
                                model_config: ModelConfigWithCredentialsEntity) \