Kaynağa Gözat

feat: support vision models from xinference (#4094)

Co-authored-by: Yeuoly <admin@srmxy.cn>
Minamiyama 11 ay önce
ebeveyn
işleme
f361c7004d

+ 63 - 30
api/core/model_runtime/model_providers/xinference/llm/llm.py

@@ -28,7 +28,10 @@ from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
+    ImagePromptMessageContent,
     PromptMessage,
+    PromptMessageContent,
+    PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
     ToolPromptMessage,
@@ -61,8 +64,8 @@ from core.model_runtime.utils import helper
 
 
 class XinferenceAILargeLanguageModel(LargeLanguageModel):
-    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], 
-                model_parameters: dict, tools: list[PromptMessageTool] | None = None, 
+    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                 stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
         -> LLMResult | Generator:
         """
@@ -99,7 +102,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         try:
             if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
                 raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
-            
+
             extra_param = XinferenceHelper.get_xinference_extra_parameter(
                 server_url=credentials['server_url'],
                 model_uid=credentials['model_uid']
@@ -111,10 +114,13 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                     credentials['completion_type'] = 'completion'
                 else:
                     raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type')
-                
+
             if extra_param.support_function_call:
                 credentials['support_function_call'] = True
 
+            if extra_param.support_vision:
+                credentials['support_vision'] = True
+
             if extra_param.context_length:
                 credentials['context_length'] = extra_param.context_length
 
@@ -135,7 +141,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         """
         return self._num_tokens_from_messages(prompt_messages, tools)
 
-    def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool], 
+    def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool],
                                   is_completion_model: bool = False) -> int:
         def tokens(text: str):
             return self._get_num_tokens_by_gpt2(text)
@@ -155,7 +161,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                     text = ''
                     for item in value:
                         if isinstance(item, dict) and item['type'] == 'text':
-                            text += item.text
+                            text += item['text']
 
                     value = text
 
@@ -191,7 +197,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             num_tokens += self._num_tokens_for_tools(tools)
 
         return num_tokens
-    
+
     def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
         """
         Calculate num tokens for tool calling
@@ -234,7 +240,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                     num_tokens += tokens(required_field)
 
         return num_tokens
-    
+
     def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str:
         """
             convert prompt message to text
@@ -260,7 +266,26 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             if isinstance(message.content, str):
                 message_dict = {"role": "user", "content": message.content}
             else:
-                raise ValueError("User message content must be str")
+                sub_messages = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(PromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "text",
+                            "text": message_content.data
+                        }
+                        sub_messages.append(sub_message_dict)
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": message_content.data,
+                                "detail": message_content.detail.value
+                            }
+                        }
+                        sub_messages.append(sub_message_dict)
+                message_dict = {"role": "user", "content": sub_messages}
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
@@ -277,7 +302,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
         else:
             raise ValueError(f"Unknown message type {type(message)}")
-        
+
         return message_dict
 
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
@@ -338,8 +363,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 completion_type = LLMMode.COMPLETION.value
             else:
                 raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
-            
+
+
+        features = []
+
         support_function_call = credentials.get('support_function_call', False)
+        if support_function_call:
+            features.append(ModelFeature.TOOL_CALL)
+
+        support_vision = credentials.get('support_vision', False)
+        if support_vision:
+            features.append(ModelFeature.VISION)
+
         context_length = credentials.get('context_length', 2048)
 
         entity = AIModelEntity(
@@ -349,10 +384,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.LLM,
-            features=[
-                ModelFeature.TOOL_CALL
-            ] if support_function_call else [],
-            model_properties={ 
+            features=features,
+            model_properties={
                 ModelPropertyKey.MODE: completion_type,
                 ModelPropertyKey.CONTEXT_SIZE: context_length
             },
@@ -360,22 +393,22 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         )
 
         return entity
-    
-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], 
+
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                  model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter,
-                 tools: list[PromptMessageTool] | None = None, 
+                 tools: list[PromptMessageTool] | None = None,
                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         """
             generate text from LLM
 
             see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate`
-            
+
             extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter`
         """
         if 'server_url' not in credentials:
             raise CredentialsValidateFailedError('server_url is required in credentials')
-        
+
         if credentials['server_url'].endswith('/'):
             credentials['server_url'] = credentials['server_url'][:-1]
 
@@ -408,11 +441,11 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                     'function': helper.dump_model(tool)
                 } for tool in tools
             ]
-
+        vision = credentials.get('support_vision', False)
         if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle):
             resp = client.chat.completions.create(
                 model=credentials['model_uid'],
-                messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages], 
+                messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages],
                 stream=stream,
                 user=user,
                 **generate_config,
@@ -497,7 +530,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         """
         if len(resp.choices) == 0:
             raise InvokeServerUnavailableError("Empty response")
-        
+
         assistant_message = resp.choices[0].message
 
         # convert tool call to assistant message tool call
@@ -527,7 +560,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         )
 
         return response
-
+    
     def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                         tools: list[PromptMessageTool],
                                         resp: Iterator[ChatCompletionChunk]) -> Generator:
@@ -544,7 +577,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
 
             if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
                 continue
-            
+
             # check if there is a tool call in the response
             function_call = None
             tool_calls = []
@@ -573,9 +606,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
                 completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
 
-                usage = self._calc_response_usage(model=model, credentials=credentials, 
+                usage = self._calc_response_usage(model=model, credentials=credentials,
                                                   prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
-                
+
                 yield LLMResultChunk(
                     model=model,
                     prompt_messages=prompt_messages,
@@ -608,7 +641,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         """
         if len(resp.choices) == 0:
             raise InvokeServerUnavailableError("Empty response")
-        
+
         assistant_message = resp.choices[0].text
 
         # transform assistant message to prompt message
@@ -670,9 +703,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 completion_tokens = self._num_tokens_from_messages(
                     messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True
                 )
-                usage = self._calc_response_usage(model=model, credentials=credentials, 
+                usage = self._calc_response_usage(model=model, credentials=credentials,
                                                   prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
-                
+
                 yield LLMResultChunk(
                     model=model,
                     prompt_messages=prompt_messages,

+ 9 - 5
api/core/model_runtime/model_providers/xinference/xinference_helper.py

@@ -14,13 +14,15 @@ class XinferenceModelExtraParameter:
     max_tokens: int = 512
     context_length: int = 2048
     support_function_call: bool = False
+    support_vision: bool = False
 
-    def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str], 
-                 support_function_call: bool, max_tokens: int, context_length: int) -> None:
+    def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str],
+                 support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int) -> None:
         self.model_format = model_format
         self.model_handle_type = model_handle_type
         self.model_ability = model_ability
         self.support_function_call = support_function_call
+        self.support_vision = support_vision
         self.max_tokens = max_tokens
         self.context_length = context_length
 
@@ -71,7 +73,7 @@ class XinferenceHelper:
             raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
         if response.status_code != 200:
             raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}')
-        
+
         response_json = response.json()
 
         model_format = response_json.get('model_format', 'ggmlv3')
@@ -87,17 +89,19 @@ class XinferenceHelper:
             model_handle_type = 'chat'
         else:
             raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')
-        
+
         support_function_call = 'tools' in model_ability
+        support_vision = 'vision' in model_ability
         max_tokens = response_json.get('max_tokens', 512)
 
         context_length = response_json.get('context_length', 2048)
-        
+
         return XinferenceModelExtraParameter(
             model_format=model_format,
             model_handle_type=model_handle_type,
             model_ability=model_ability,
             support_function_call=support_function_call,
+            support_vision=support_vision,
             max_tokens=max_tokens,
             context_length=context_length
         )