Quellcode durchsuchen

fix: image text when retrieve chat histories (#3220)

takatost vor 1 Jahr
Ursprung
Commit
4ad3f2cdc2

+ 13 - 2
api/core/memory/token_buffer_memory.py

@@ -3,6 +3,7 @@ from core.file.message_file_parser import MessageFileParser
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
+    ImagePromptMessageContent,
     PromptMessage,
     PromptMessageRole,
     TextPromptMessageContent,
@@ -124,7 +125,17 @@ class TokenBufferMemory:
             else:
                 continue
 
-            message = f"{role}: {m.content}"
-            string_messages.append(message)
+            if isinstance(m.content, list):
+                inner_msg = ""
+                for content in m.content:
+                    if isinstance(content, TextPromptMessageContent):
+                        inner_msg += f"{content.data}\n"
+                    elif isinstance(content, ImagePromptMessageContent):
+                        inner_msg += "[image]\n"
+
+                string_messages.append(f"{role}: {inner_msg.strip()}")
+            else:
+                message = f"{role}: {m.content}"
+                string_messages.append(message)
 
         return "\n".join(string_messages)

+ 36 - 4
api/core/workflow/workflow_engine_manager.py

@@ -1,9 +1,10 @@
 import logging
 import time
-from typing import Optional
+from typing import Optional, cast
 
+from core.app.app_config.entities import FileExtraConfig
 from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
-from core.file.file_obj import FileVar
+from core.file.file_obj import FileTransferMethod, FileType, FileVar
 from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool, VariableValue
@@ -16,6 +17,7 @@ from core.workflow.nodes.end.end_node import EndNode
 from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
 from core.workflow.nodes.if_else.if_else_node import IfElseNode
 from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
+from core.workflow.nodes.llm.entities import LLMNodeData
 from core.workflow.nodes.llm.llm_node import LLMNode
 from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
 from core.workflow.nodes.start.start_node import StartNode
@@ -219,7 +221,8 @@ class WorkflowEngineManager:
             raise ValueError('node id not found in workflow graph')
 
         # Get node class
-        node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type')))
+        node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
+        node_cls = node_classes.get(node_type)
 
         # init workflow run state
         node_instance = node_cls(
@@ -252,11 +255,40 @@ class WorkflowEngineManager:
                 variable_node_id = variable_selector[0]
                 variable_key_list = variable_selector[1:]
 
+                # get value
+                value = user_inputs.get(variable_key)
+
+                # temp fix for image type
+                if node_type == NodeType.LLM:
+                    new_value = []
+                    if isinstance(value, list):
+                        node_data = node_instance.node_data
+                        node_data = cast(LLMNodeData, node_data)
+
+                        detail = node_data.vision.configs.detail if node_data.vision.configs else None
+
+                        for item in value:
+                            if isinstance(item, dict) and 'type' in item and item['type'] == 'image':
+                                transfer_method = FileTransferMethod.value_of(item.get('transfer_method'))
+                                file = FileVar(
+                                    tenant_id=workflow.tenant_id,
+                                    type=FileType.IMAGE,
+                                    transfer_method=transfer_method,
+                                    url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
+                                    related_id=item.get(
+                                        'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
+                                    extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None),
+                                )
+                                new_value.append(file)
+
+                    if new_value:
+                        value = new_value
+
                 # append variable and value to variable pool
                 variable_pool.append_variable(
                     node_id=variable_node_id,
                     variable_key_list=variable_key_list,
-                    value=user_inputs.get(variable_key)
+                    value=value
                 )
             # run node
             node_run_result = node_instance.run(

+ 1 - 1
api/models/model.py

@@ -815,7 +815,7 @@ class Message(db.Model):
     @property
     def workflow_run(self):
         if self.workflow_run_id:
-            from api.models.workflow import WorkflowRun
+            from .workflow import WorkflowRun
             return db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first()
 
         return None

+ 4 - 0
api/models/workflow.py

@@ -299,6 +299,10 @@ class WorkflowRun(db.Model):
             Message.workflow_run_id == self.id
         ).first()
 
+    @property
+    def workflow(self):
+        return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first()
+
 
 class WorkflowNodeExecutionTriggeredFrom(Enum):
     """