Explorar o código

Feat/Agent-Image-Processing (#3293)

Co-authored-by: Joel <iamjoel007@gmail.com>
Yeuoly hai 1 ano
pai
achega
14bb0b02ac

+ 48 - 4
api/core/agent/base_agent_runner.py

@@ -5,6 +5,7 @@ from datetime import datetime
 from typing import Optional, Union, cast
 
 from core.agent.entities import AgentEntity, AgentToolEntity
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.base_app_runner import AppRunner
@@ -14,6 +15,7 @@ from core.app.entities.app_invoke_entities import (
 )
 from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.file.message_file_parser import MessageFileParser
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
@@ -22,6 +24,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessage,
     PromptMessageTool,
     SystemPromptMessage,
+    TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
@@ -37,7 +40,7 @@ from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
 from core.tools.tool.tool import Tool
 from core.tools.tool_manager import ToolManager
 from extensions.ext_database import db
-from models.model import Message, MessageAgentThought
+from models.model import Conversation, Message, MessageAgentThought
 from models.tools import ToolConversationVariables
 
 logger = logging.getLogger(__name__)
@@ -45,6 +48,7 @@ logger = logging.getLogger(__name__)
 class BaseAgentRunner(AppRunner):
     def __init__(self, tenant_id: str,
                  application_generate_entity: AgentChatAppGenerateEntity,
+                 conversation: Conversation,
                  app_config: AgentChatAppConfig,
                  model_config: ModelConfigWithCredentialsEntity,
                  config: AgentEntity,
@@ -72,6 +76,7 @@ class BaseAgentRunner(AppRunner):
         """
         self.tenant_id = tenant_id
         self.application_generate_entity = application_generate_entity
+        self.conversation = conversation
         self.app_config = app_config
         self.model_config = model_config
         self.config = config
@@ -118,6 +123,12 @@ class BaseAgentRunner(AppRunner):
         else:
             self.stream_tool_call = False
 
+        # check if model supports vision
+        if model_schema and ModelFeature.VISION in (model_schema.features or []):
+            self.files = application_generate_entity.files
+        else:
+            self.files = []
+
     def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
             -> AgentChatAppGenerateEntity:
         """
@@ -412,15 +423,19 @@ class BaseAgentRunner(AppRunner):
         """
         result = []
         # check if there is a system message in the beginning of the conversation
-        if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
-            result.append(prompt_messages[0])
+        for prompt_message in prompt_messages:
+            if isinstance(prompt_message, SystemPromptMessage):
+                result.append(prompt_message)
 
         messages: list[Message] = db.session.query(Message).filter(
             Message.conversation_id == self.message.conversation_id,
         ).order_by(Message.created_at.asc()).all()
 
         for message in messages:
-            result.append(UserPromptMessage(content=message.query))
+            if message.id == self.message.id:
+                continue
+            
+            result.append(self.organize_agent_user_prompt(message))
             agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
             if agent_thoughts:
                 for agent_thought in agent_thoughts:
@@ -471,3 +486,32 @@ class BaseAgentRunner(AppRunner):
         db.session.close()
 
         return result
+
+    def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
+        message_file_parser = MessageFileParser(
+            tenant_id=self.tenant_id,
+            app_id=self.app_config.app_id,
+        )
+
+        files = message.message_files
+        if files:
+            file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
+
+            if file_extra_config:
+                file_objs = message_file_parser.transform_message_files(
+                    files,
+                    file_extra_config
+                )
+            else:
+                file_objs = []
+
+            if not file_objs:
+                return UserPromptMessage(content=message.query)
+            else:
+                prompt_message_contents = [TextPromptMessageContent(data=message.query)]
+                for file_obj in file_objs:
+                    prompt_message_contents.append(file_obj.prompt_message_content)
+
+                return UserPromptMessage(content=prompt_message_contents)
+        else:
+            return UserPromptMessage(content=message.query)

+ 2 - 3
api/core/agent/cot_agent_runner.py

@@ -19,15 +19,14 @@ from core.model_runtime.entities.message_entities import (
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
-from models.model import Conversation, Message
+from models.model import Message
 
 
 class CotAgentRunner(BaseAgentRunner):
     _is_first_iteration = True
     _ignore_observation_providers = ['wenxin']
 
-    def run(self, conversation: Conversation,
-        message: Message,
+    def run(self, message: Message,
         query: str,
         inputs: dict[str, str],
     ) -> Union[Generator, LLMResult]:

+ 67 - 30
api/core/agent/fc_agent_runner.py

@@ -1,6 +1,7 @@
 import json
 import logging
 from collections.abc import Generator
+from copy import deepcopy
 from typing import Any, Union
 
 from core.agent.base_agent_runner import BaseAgentRunner
@@ -10,20 +11,21 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
+    PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
+    TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
-from models.model import Conversation, Message, MessageAgentThought
+from models.model import Message
 
 logger = logging.getLogger(__name__)
 
 class FunctionCallAgentRunner(BaseAgentRunner):
-    def run(self, conversation: Conversation,
-                message: Message,
+    def run(self, message: Message,
                 query: str,
     ) -> Generator[LLMResultChunk, None, None]:
         """
@@ -35,11 +37,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
         prompt_template = app_config.prompt_template.simple_prompt_template or ''
         prompt_messages = self.history_prompt_messages
-        prompt_messages = self.organize_prompt_messages(
-            prompt_template=prompt_template,
-            query=query,
-            prompt_messages=prompt_messages
-        )
+        prompt_messages = self._init_system_message(prompt_template, prompt_messages)
+        prompt_messages = self._organize_user_query(query, prompt_messages)
 
         # convert tools into ModelRuntime Tool format
         prompt_messages_tools: list[PromptMessageTool] = []
@@ -68,7 +67,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
         # continue to run until there is not any tool call
         function_call_state = True
-        agent_thoughts: list[MessageAgentThought] = []
         llm_usage = {
             'usage': None
         }
@@ -287,9 +285,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                     }
                 
                 tool_responses.append(tool_response)
-                prompt_messages = self.organize_prompt_messages(
-                    prompt_template=prompt_template,
-                    query=None,
+                prompt_messages = self._organize_assistant_message(
                     tool_call_id=tool_call_id,
                     tool_call_name=tool_call_name,
                     tool_response=tool_response['tool_response'],
@@ -324,6 +320,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
             iteration_step += 1
 
+            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
+
         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
         self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
@@ -386,29 +384,68 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
         return tool_calls
 
-    def organize_prompt_messages(self, prompt_template: str,
-                                 query: str = None, 
-                                 tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
-                                 prompt_messages: list[PromptMessage] = None
-                                 ) -> list[PromptMessage]:
+    def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
-        Organize prompt messages
+        Initialize system message
         """
-        
-        if not prompt_messages:
-            prompt_messages = [
+        if not prompt_messages and prompt_template:
+            return [
                 SystemPromptMessage(content=prompt_template),
-                UserPromptMessage(content=query),
             ]
+        
+        if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
+            prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
+
+        return prompt_messages
+
+    def _organize_user_query(self, query,  prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+        """
+        Organize user query
+        """
+        if self.files:
+            prompt_message_contents = [TextPromptMessageContent(data=query)]
+            for file_obj in self.files:
+                prompt_message_contents.append(file_obj.prompt_message_content)
+
+            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
         else:
-            if tool_response:
-                prompt_messages = prompt_messages.copy()
-                prompt_messages.append(
-                    ToolPromptMessage(
-                        content=tool_response,
-                        tool_call_id=tool_call_id,
-                        name=tool_call_name,
-                    )
+            prompt_messages.append(UserPromptMessage(content=query))
+
+        return prompt_messages
+    
+    def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None, 
+                                    prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+        """
+        Organize assistant message
+        """
+        prompt_messages = deepcopy(prompt_messages)
+
+        if tool_response is not None:
+            prompt_messages.append(
+                ToolPromptMessage(
+                    content=tool_response,
+                    tool_call_id=tool_call_id,
+                    name=tool_call_name,
                 )
+            )
+
+        return prompt_messages
+    
+    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        As for now, gpt supports both fc and vision at the first iteration.
+        We need to remove the image messages from the prompt messages at the first iteration.
+        """
+        prompt_messages = deepcopy(prompt_messages)
+
+        for prompt_message in prompt_messages:
+            if isinstance(prompt_message, UserPromptMessage):
+                if isinstance(prompt_message.content, list):
+                    prompt_message.content = '\n'.join([
+                        content.data if content.type == PromptMessageContentType.TEXT else 
+                        '[image]' if content.type == PromptMessageContentType.IMAGE else
+                        '[file]' 
+                        for content in prompt_message.content 
+                    ])
 
         return prompt_messages

+ 2 - 2
api/core/app/apps/agent_chat/app_runner.py

@@ -210,6 +210,7 @@ class AgentChatAppRunner(AppRunner):
             assistant_cot_runner = CotAgentRunner(
                 tenant_id=app_config.tenant_id,
                 application_generate_entity=application_generate_entity,
+                conversation=conversation,
                 app_config=app_config,
                 model_config=application_generate_entity.model_config,
                 config=agent_entity,
@@ -223,7 +224,6 @@ class AgentChatAppRunner(AppRunner):
                 model_instance=model_instance
             )
             invoke_result = assistant_cot_runner.run(
-                conversation=conversation,
                 message=message,
                 query=query,
                 inputs=inputs,
@@ -232,6 +232,7 @@ class AgentChatAppRunner(AppRunner):
             assistant_fc_runner = FunctionCallAgentRunner(
                 tenant_id=app_config.tenant_id,
                 application_generate_entity=application_generate_entity,
+                conversation=conversation,
                 app_config=app_config,
                 model_config=application_generate_entity.model_config,
                 config=agent_entity,
@@ -245,7 +246,6 @@ class AgentChatAppRunner(AppRunner):
                 model_instance=model_instance
             )
             invoke_result = assistant_fc_runner.run(
-                conversation=conversation,
                 message=message,
                 query=query,
             )

+ 28 - 0
api/core/model_runtime/model_providers/openai/llm/llm.py

@@ -547,6 +547,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         if user:
             extra_model_kwargs['user'] = user
 
+        # clear illegal prompt messages
+        prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
+
         # chat model
         response = client.chat.completions.create(
             messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
@@ -757,6 +760,31 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
 
         return tool_call
 
+    def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        Clear illegal prompt messages for OpenAI API
+
+        :param model: model name
+        :param prompt_messages: prompt messages
+        :return: cleaned prompt messages
+        """
+        checklist = ['gpt-4-turbo', 'gpt-4-turbo-2024-04-09']
+
+        if model in checklist:
+            # count how many user messages are there
+            user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)])
+            if user_message_count > 1:
+                for prompt_message in prompt_messages:
+                    if isinstance(prompt_message, UserPromptMessage):
+                        if isinstance(prompt_message.content, list):
+                            prompt_message.content = '\n'.join([
+                                item.data if item.type == PromptMessageContentType.TEXT else
+                                '[IMAGE]' if item.type == PromptMessageContentType.IMAGE else ''
+                                for item in prompt_message.content
+                            ])
+
+        return prompt_messages
+
     def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
         """
         Convert PromptMessage to dict for OpenAI API

+ 1 - 1
web/app/components/base/chat/chat/hooks.ts

@@ -229,7 +229,7 @@ export const useChat = (
 
     // answer
     const responseItem: ChatItem = {
-      id: `${Date.now()}`,
+      id: placeholderAnswerId,
       content: '',
       agent_thoughts: [],
       message_files: [],