Procházet zdrojové kódy

feat: filter empty content messages in llm node (#3547)

takatost před 1 rokem
rodič
revize
b890c11c14

+ 29 - 0
api/core/model_runtime/entities/message_entities.py

@@ -88,6 +88,14 @@ class PromptMessage(ABC, BaseModel):
     content: Optional[str | list[PromptMessageContent]] = None
     name: Optional[str] = None
 
+    def is_empty(self) -> bool:
+        """
+        Check if prompt message is empty.
+
+        :return: True if prompt message is empty, False otherwise
+        """
+        return not self.content
+
 
 class UserPromptMessage(PromptMessage):
     """
@@ -118,6 +126,16 @@ class AssistantPromptMessage(PromptMessage):
     role: PromptMessageRole = PromptMessageRole.ASSISTANT
     tool_calls: list[ToolCall] = []
 
+    def is_empty(self) -> bool:
+        """
+        Check if prompt message is empty.
+
+        :return: True if prompt message is empty, False otherwise
+        """
+        if not super().is_empty() and not self.tool_calls:
+            return False
+
+        return True
 
 class SystemPromptMessage(PromptMessage):
     """
@@ -132,3 +150,14 @@ class ToolPromptMessage(PromptMessage):
     """
     role: PromptMessageRole = PromptMessageRole.TOOL
     tool_call_id: str
+
+    def is_empty(self) -> bool:
+        """
+        Check if prompt message is empty.
+
+        :return: True if prompt message is empty, False otherwise
+        """
+        if not super().is_empty() and not self.tool_call_id:
+            return False
+
+        return True

+ 11 - 1
api/core/workflow/nodes/llm/llm_node.py

@@ -438,7 +438,11 @@ class LLMNode(BaseNode):
         stop = model_config.stop
 
         vision_enabled = node_data.vision.enabled
+        filtered_prompt_messages = []
         for prompt_message in prompt_messages:
+            if prompt_message.is_empty():
+                continue
+
             if not isinstance(prompt_message.content, str):
                 prompt_message_content = []
                 for content_item in prompt_message.content:
@@ -453,7 +457,13 @@ class LLMNode(BaseNode):
                       and prompt_message_content[0].type == PromptMessageContentType.TEXT):
                     prompt_message.content = prompt_message_content[0].data
 
-        return prompt_messages, stop
+            filtered_prompt_messages.append(prompt_message)
+
+        if not filtered_prompt_messages:
+            raise ValueError("No prompt found in the LLM configuration. "
+                             "Please ensure a prompt is properly configured before proceeding.")
+
+        return filtered_prompt_messages, stop
 
     @classmethod
     def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: