
fix: respect vision resolution settings for the basic chatbot, text generator, and parameter extractor nodes (#16041)

kurokobo authored 1 month ago
commit 86d3fff666

+ 8 - 1
api/core/app/apps/base_app_runner.py

@@ -17,7 +17,11 @@ from core.external_data_tool.external_data_fetch import ExternalDataFetch
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
-from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+)
 from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.errors.invoke import InvokeBadRequestError
 from core.moderation.input_moderation import InputModeration
@@ -141,6 +145,7 @@ class AppRunner:
         query: Optional[str] = None,
         context: Optional[str] = None,
         memory: Optional[TokenBufferMemory] = None,
+        image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
     ) -> tuple[list[PromptMessage], Optional[list[str]]]:
         """
         Organize prompt messages
@@ -167,6 +172,7 @@ class AppRunner:
                 context=context,
                 memory=memory,
                 model_config=model_config,
+                image_detail_config=image_detail_config,
             )
         else:
             memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
@@ -201,6 +207,7 @@ class AppRunner:
                 memory_config=memory_config,
                 memory=memory,
                 model_config=model_config,
+                image_detail_config=image_detail_config,
             )
             stop = model_config.stop
 

+ 13 - 0
api/core/app/apps/chat/app_runner.py

@@ -11,6 +11,7 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent
 from core.moderation.base import ModerationError
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from extensions.ext_database import db
@@ -50,6 +51,16 @@ class ChatAppRunner(AppRunner):
         query = application_generate_entity.query
         files = application_generate_entity.files
 
+        image_detail_config = (
+            application_generate_entity.file_upload_config.image_config.detail
+            if (
+                application_generate_entity.file_upload_config
+                and application_generate_entity.file_upload_config.image_config
+            )
+            else None
+        )
+        image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
+
         # Pre-calculate the number of tokens of the prompt messages,
         # and return the rest number of tokens by model context token size limit and max token size limit.
         # If the rest number of tokens is not enough, raise exception.
@@ -85,6 +96,7 @@ class ChatAppRunner(AppRunner):
             files=files,
             query=query,
             memory=memory,
+            image_detail_config=image_detail_config,
         )
 
         # moderation
@@ -182,6 +194,7 @@ class ChatAppRunner(AppRunner):
             query=query,
             context=context,
             memory=memory,
+            image_detail_config=image_detail_config,
         )
 
         # check hosting moderation

+ 13 - 0
api/core/app/apps/completion/app_runner.py

@@ -9,6 +9,7 @@ from core.app.entities.app_invoke_entities import (
 )
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.model_manager import ModelInstance
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent
 from core.moderation.base import ModerationError
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from extensions.ext_database import db
@@ -43,6 +44,16 @@ class CompletionAppRunner(AppRunner):
         query = application_generate_entity.query
         files = application_generate_entity.files
 
+        image_detail_config = (
+            application_generate_entity.file_upload_config.image_config.detail
+            if (
+                application_generate_entity.file_upload_config
+                and application_generate_entity.file_upload_config.image_config
+            )
+            else None
+        )
+        image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
+
         # Pre-calculate the number of tokens of the prompt messages,
         # and return the rest number of tokens by model context token size limit and max token size limit.
         # If the rest number of tokens is not enough, raise exception.
@@ -66,6 +77,7 @@ class CompletionAppRunner(AppRunner):
             inputs=inputs,
             files=files,
             query=query,
+            image_detail_config=image_detail_config,
         )
 
         # moderation
@@ -140,6 +152,7 @@ class CompletionAppRunner(AppRunner):
             files=files,
             query=query,
             context=context,
+            image_detail_config=image_detail_config,
         )
 
         # check hosting moderation
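
Both ChatAppRunner and CompletionAppRunner derive the setting the same way: take the detail from the app's image upload configuration when it is present, otherwise fall back to LOW. A minimal, self-contained sketch of that precedence, using simplified stand-in classes rather than Dify's real entities:

from dataclasses import dataclass
from enum import Enum
from typing import Optional


class Detail(Enum):
    LOW = "low"
    HIGH = "high"


@dataclass
class ImageConfig:
    detail: Optional[Detail] = None


@dataclass
class FileUploadConfig:
    image_config: Optional[ImageConfig] = None


def resolve_image_detail(file_upload_config: Optional[FileUploadConfig]) -> Detail:
    """Use the configured image detail when present, otherwise fall back to LOW."""
    detail = (
        file_upload_config.image_config.detail
        if file_upload_config and file_upload_config.image_config
        else None
    )
    return detail or Detail.LOW


assert resolve_image_detail(FileUploadConfig(ImageConfig(Detail.HIGH))) is Detail.HIGH
assert resolve_image_detail(FileUploadConfig()) is Detail.LOW  # image_config missing
assert resolve_image_detail(None) is Detail.LOW  # no upload config at all

Because of the explicit `or ImagePromptMessageContent.DETAIL.LOW` fallback in both runners, the prompt transforms below always receive a concrete detail value for these two app types.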

+ 20 - 5
api/core/prompt/advanced_prompt_transform.py

@@ -46,6 +46,7 @@ class AdvancedPromptTransform(PromptTransform):
         memory_config: Optional[MemoryConfig],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
+        image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
     ) -> list[PromptMessage]:
         prompt_messages = []
 
@@ -59,6 +60,7 @@ class AdvancedPromptTransform(PromptTransform):
                 memory_config=memory_config,
                 memory=memory,
                 model_config=model_config,
+                image_detail_config=image_detail_config,
             )
         elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template):
             prompt_messages = self._get_chat_model_prompt_messages(
@@ -70,6 +72,7 @@ class AdvancedPromptTransform(PromptTransform):
                 memory_config=memory_config,
                 memory=memory,
                 model_config=model_config,
+                image_detail_config=image_detail_config,
             )
 
         return prompt_messages
@@ -84,6 +87,7 @@ class AdvancedPromptTransform(PromptTransform):
         memory_config: Optional[MemoryConfig],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
+        image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
     ) -> list[PromptMessage]:
         """
         Get completion model prompt messages.
@@ -124,7 +128,9 @@ class AdvancedPromptTransform(PromptTransform):
             prompt_message_contents: list[PromptMessageContent] = []
             prompt_message_contents.append(TextPromptMessageContent(data=prompt))
             for file in files:
-                prompt_message_contents.append(file_manager.to_prompt_message_content(file))
+                prompt_message_contents.append(
+                    file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
+                )
 
             prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
         else:
@@ -142,6 +148,7 @@ class AdvancedPromptTransform(PromptTransform):
         memory_config: Optional[MemoryConfig],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
+        image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
     ) -> list[PromptMessage]:
         """
         Get chat model prompt messages.
@@ -197,7 +204,9 @@ class AdvancedPromptTransform(PromptTransform):
                 prompt_message_contents: list[PromptMessageContent] = []
                 prompt_message_contents.append(TextPromptMessageContent(data=query))
                 for file in files:
-                    prompt_message_contents.append(file_manager.to_prompt_message_content(file))
+                    prompt_message_contents.append(
+                        file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
+                    )
                 prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
             else:
                 prompt_messages.append(UserPromptMessage(content=query))
@@ -209,19 +218,25 @@ class AdvancedPromptTransform(PromptTransform):
                     # get last user message content and add files
                     prompt_message_contents = [TextPromptMessageContent(data=cast(str, last_message.content))]
                     for file in files:
-                        prompt_message_contents.append(file_manager.to_prompt_message_content(file))
+                        prompt_message_contents.append(
+                            file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
+                        )
 
                     last_message.content = prompt_message_contents
                 else:
                     prompt_message_contents = [TextPromptMessageContent(data="")]  # not for query
                     for file in files:
-                        prompt_message_contents.append(file_manager.to_prompt_message_content(file))
+                        prompt_message_contents.append(
+                            file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
+                        )
 
                     prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
             else:
                 prompt_message_contents = [TextPromptMessageContent(data=query)]
                 for file in files:
-                    prompt_message_contents.append(file_manager.to_prompt_message_content(file))
+                    prompt_message_contents.append(
+                        file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
+                    )
 
                 prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
         elif query:
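
Every place this transform builds image content parts now forwards image_detail_config into file_manager.to_prompt_message_content. For OpenAI-compatible vision models that hint eventually surfaces as the per-image "detail" field on the request, which is what controls how coarsely the image is processed and how many tokens it costs. A standalone illustration of that wire format, purely for context (the actual serialization happens in the model runtime, not in this diff):

from typing import Any


def openai_image_part(url: str, detail: str = "low") -> dict[str, Any]:
    # Mirrors the OpenAI chat-completions content-part shape for images;
    # detail may be "low", "high", or "auto".
    return {"type": "image_url", "image_url": {"url": url, "detail": detail}}


def openai_text_part(text: str) -> dict[str, Any]:
    return {"type": "text", "text": text}


user_message = {
    "role": "user",
    "content": [
        openai_text_part("What is in this picture?"),
        openai_image_part("https://example.com/diagram.png", detail="high"),
    ],
}
assert user_message["content"][1]["image_url"]["detail"] == "high"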

+ 18 - 5
api/core/prompt/simple_prompt_transform.py

@@ -9,6 +9,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
 from core.file import file_manager
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import (
+    ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContent,
     SystemPromptMessage,
@@ -60,6 +61,7 @@ class SimplePromptTransform(PromptTransform):
         context: Optional[str],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
+        image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
     ) -> tuple[list[PromptMessage], Optional[list[str]]]:
         inputs = {key: str(value) for key, value in inputs.items()}
 
@@ -74,6 +76,7 @@ class SimplePromptTransform(PromptTransform):
                 context=context,
                 memory=memory,
                 model_config=model_config,
+                image_detail_config=image_detail_config,
             )
         else:
             prompt_messages, stops = self._get_completion_model_prompt_messages(
@@ -85,6 +88,7 @@ class SimplePromptTransform(PromptTransform):
                 context=context,
                 memory=memory,
                 model_config=model_config,
+                image_detail_config=image_detail_config,
             )
 
         return prompt_messages, stops
@@ -175,6 +179,7 @@ class SimplePromptTransform(PromptTransform):
         files: Sequence["File"],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
+        image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
     ) -> tuple[list[PromptMessage], Optional[list[str]]]:
         prompt_messages: list[PromptMessage] = []
 
@@ -204,9 +209,9 @@ class SimplePromptTransform(PromptTransform):
             )
 
         if query:
-            prompt_messages.append(self.get_last_user_message(query, files))
+            prompt_messages.append(self.get_last_user_message(query, files, image_detail_config))
         else:
-            prompt_messages.append(self.get_last_user_message(prompt, files))
+            prompt_messages.append(self.get_last_user_message(prompt, files, image_detail_config))
 
         return prompt_messages, None
 
@@ -220,6 +225,7 @@ class SimplePromptTransform(PromptTransform):
         files: Sequence["File"],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
+        image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
     ) -> tuple[list[PromptMessage], Optional[list[str]]]:
         # get prompt
         prompt, prompt_rules = self.get_prompt_str_and_rules(
@@ -262,14 +268,21 @@ class SimplePromptTransform(PromptTransform):
         if stops is not None and len(stops) == 0:
             stops = None
 
-        return [self.get_last_user_message(prompt, files)], stops
+        return [self.get_last_user_message(prompt, files, image_detail_config)], stops
 
-    def get_last_user_message(self, prompt: str, files: Sequence["File"]) -> UserPromptMessage:
+    def get_last_user_message(
+        self,
+        prompt: str,
+        files: Sequence["File"],
+        image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
+    ) -> UserPromptMessage:
         if files:
             prompt_message_contents: list[PromptMessageContent] = []
             prompt_message_contents.append(TextPromptMessageContent(data=prompt))
             for file in files:
-                prompt_message_contents.append(file_manager.to_prompt_message_content(file))
+                prompt_message_contents.append(
+                    file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
+                )
 
             prompt_message = UserPromptMessage(content=prompt_message_contents)
         else:
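
get_last_user_message now threads the detail through to every image part it appends after the text part. A condensed stand-in sketch of that assembly; TextPart and ImagePart below are simplified placeholders rather than the real prompt-message entities, and default handling for a missing detail belongs to file_manager, not this helper:

from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union


class Detail(Enum):
    LOW = "low"
    HIGH = "high"


@dataclass
class TextPart:
    data: str


@dataclass
class ImagePart:
    data: str
    detail: Optional[Detail] = None  # None means "let the file manager decide"


def build_last_user_message(
    prompt: str,
    image_urls: list[str],
    image_detail_config: Optional[Detail] = None,
) -> list[Union[TextPart, ImagePart]]:
    # Same shape as get_last_user_message: the text first, then one image part
    # per file, each stamped with the forwarded detail setting.
    parts: list[Union[TextPart, ImagePart]] = [TextPart(prompt)]
    for url in image_urls:
        parts.append(ImagePart(data=url, detail=image_detail_config))
    return parts


msg = build_last_user_message("Describe these.", ["https://example.com/a.png"], Detail.HIGH)
assert isinstance(msg[0], TextPart) and msg[1].detail is Detail.HIGH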

+ 12 - 0
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py

@@ -7,6 +7,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
 from core.file import File
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
+from core.model_runtime.entities import ImagePromptMessageContent
 from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
@@ -129,6 +130,7 @@ class ParameterExtractorNode(LLMNode):
                 model_config=model_config,
                 memory=memory,
                 files=files,
+                vision_detail=node_data.vision.configs.detail,
             )
         else:
             # use prompt engineering
@@ -139,6 +141,7 @@ class ParameterExtractorNode(LLMNode):
                 model_config=model_config,
                 memory=memory,
                 files=files,
+                vision_detail=node_data.vision.configs.detail,
             )
 
             prompt_message_tools = []
@@ -267,6 +270,7 @@ class ParameterExtractorNode(LLMNode):
         model_config: ModelConfigWithCredentialsEntity,
         memory: Optional[TokenBufferMemory],
         files: Sequence[File],
+        vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
     ) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
         """
         Generate function call prompt.
@@ -289,6 +293,7 @@ class ParameterExtractorNode(LLMNode):
             memory_config=node_data.memory,
             memory=None,
             model_config=model_config,
+            image_detail_config=vision_detail,
         )
 
         # find last user message
@@ -347,6 +352,7 @@ class ParameterExtractorNode(LLMNode):
         model_config: ModelConfigWithCredentialsEntity,
         memory: Optional[TokenBufferMemory],
         files: Sequence[File],
+        vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
     ) -> list[PromptMessage]:
         """
         Generate prompt engineering prompt.
@@ -361,6 +367,7 @@ class ParameterExtractorNode(LLMNode):
                 model_config=model_config,
                 memory=memory,
                 files=files,
+                vision_detail=vision_detail,
             )
         elif model_mode == ModelMode.CHAT:
             return self._generate_prompt_engineering_chat_prompt(
@@ -370,6 +377,7 @@ class ParameterExtractorNode(LLMNode):
                 model_config=model_config,
                 memory=memory,
                 files=files,
+                vision_detail=vision_detail,
             )
         else:
             raise InvalidModelModeError(f"Invalid model mode: {model_mode}")
@@ -382,6 +390,7 @@ class ParameterExtractorNode(LLMNode):
         model_config: ModelConfigWithCredentialsEntity,
         memory: Optional[TokenBufferMemory],
         files: Sequence[File],
+        vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
     ) -> list[PromptMessage]:
         """
         Generate completion prompt.
@@ -402,6 +411,7 @@ class ParameterExtractorNode(LLMNode):
             memory_config=node_data.memory,
             memory=memory,
             model_config=model_config,
+            image_detail_config=vision_detail,
         )
 
         return prompt_messages
@@ -414,6 +424,7 @@ class ParameterExtractorNode(LLMNode):
         model_config: ModelConfigWithCredentialsEntity,
         memory: Optional[TokenBufferMemory],
         files: Sequence[File],
+        vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
     ) -> list[PromptMessage]:
         """
         Generate chat prompt.
@@ -441,6 +452,7 @@ class ParameterExtractorNode(LLMNode):
             memory_config=node_data.memory,
             memory=None,
             model_config=model_config,
+            image_detail_config=vision_detail,
         )
 
         # find last user message
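
Unlike the chat and completion runners, which read the detail from the app's file upload settings, this node takes it from its own vision configuration (node_data.vision.configs.detail) and passes it down as vision_detail / image_detail_config. A rough stand-in sketch of that lookup; the class names and the HIGH default below are placeholders, not Dify's actual node-data models:

from dataclasses import dataclass, field
from enum import Enum
from typing import Optional


class Detail(Enum):
    LOW = "low"
    HIGH = "high"


@dataclass
class VisionConfigs:
    detail: Optional[Detail] = Detail.HIGH  # placeholder default


@dataclass
class VisionConfig:
    enabled: bool = False
    configs: VisionConfigs = field(default_factory=VisionConfigs)


@dataclass
class NodeData:
    vision: VisionConfig = field(default_factory=VisionConfig)


def node_vision_detail(node_data: NodeData) -> Optional[Detail]:
    # Mirrors the call sites above: the detail comes from the node's own vision
    # configs, not from the app-level upload configuration.
    return node_data.vision.configs.detail


assert node_vision_detail(NodeData()) is Detail.HIGH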