Browse source

feat: Allow files to be included in the system prompt even when the model does not support them. (#11111)

-LAN- 4 months ago
Parent
commit
5b7b328193

+ 1 - 1
api/core/model_runtime/model_providers/anthropic/llm/llm.py

@@ -453,7 +453,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
 
         return credentials_kwargs
 
-    def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
+    def _convert_prompt_messages(self, prompt_messages: Sequence[PromptMessage]) -> tuple[str, list[dict]]:
         """
         Convert prompt messages to dict list and system
         """

+ 3 - 0
api/core/model_runtime/model_providers/openai/llm/llm.py

@@ -943,6 +943,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                 }
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
+            if isinstance(message.content, list):
+                text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content)
+                message.content = "".join(c.data for c in text_contents)
             message_dict = {"role": "system", "content": message.content}
         elif isinstance(message, ToolPromptMessage):
             message = cast(ToolPromptMessage, message)

+ 15 - 8
api/core/workflow/nodes/llm/node.py

@@ -20,6 +20,7 @@ from core.model_runtime.entities import (
 from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
+    PromptMessageContent,
     PromptMessageRole,
     SystemPromptMessage,
     UserPromptMessage,
@@ -828,14 +829,14 @@ class LLMNode(BaseNode[LLMNodeData]):
         }
 
 
-def _combine_text_message_with_role(*, text: str, role: PromptMessageRole):
+def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
     match role:
         case PromptMessageRole.USER:
-            return UserPromptMessage(content=[TextPromptMessageContent(data=text)])
+            return UserPromptMessage(content=contents)
         case PromptMessageRole.ASSISTANT:
-            return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)])
+            return AssistantPromptMessage(content=contents)
         case PromptMessageRole.SYSTEM:
-            return SystemPromptMessage(content=[TextPromptMessageContent(data=text)])
+            return SystemPromptMessage(content=contents)
     raise NotImplementedError(f"Role {role} is not supported")
 
 
@@ -877,7 +878,9 @@ def _handle_list_messages(
                 jinjia2_variables=jinja2_variables,
                 variable_pool=variable_pool,
             )
-            prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
+            prompt_message = _combine_message_content_with_role(
+                contents=[TextPromptMessageContent(data=result_text)], role=message.role
+            )
             prompt_messages.append(prompt_message)
         else:
             # Get segment group from basic message
@@ -908,12 +911,14 @@ def _handle_list_messages(
             # Create message with text from all segments
             plain_text = segment_group.text
             if plain_text:
-                prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
+                prompt_message = _combine_message_content_with_role(
+                    contents=[TextPromptMessageContent(data=plain_text)], role=message.role
+                )
                 prompt_messages.append(prompt_message)
 
             if file_contents:
                 # Create message with image contents
-                prompt_message = UserPromptMessage(content=file_contents)
+                prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
                 prompt_messages.append(prompt_message)
 
     return prompt_messages
@@ -1018,6 +1023,8 @@ def _handle_completion_template(
         else:
             template_text = template.text
         result_text = variable_pool.convert_template(template_text).text
-    prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
+    prompt_message = _combine_message_content_with_role(
+        contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER
+    )
     prompt_messages.append(prompt_message)
     return prompt_messages