Преглед на файлове

Feat/blocking function call (#2247)

Yeuoly преди 1 година
родител
ревизия
6d5b386394
променени са 33 файла, в които са добавени 430 реда и са изтрити 95 реда
  1. 11 3
      api/core/app_runner/assistant_app_runner.py
  2. 14 1
      api/core/features/assistant_base_runner.py
  3. 4 3
      api/core/features/assistant_cot_runner.py
  4. 105 23
      api/core/features/assistant_fc_runner.py
  5. 1 0
      api/core/model_runtime/entities/model_entities.py
  6. 5 0
      api/core/model_runtime/model_providers/azure_openai/_constant.py
  7. 24 4
      api/core/model_runtime/model_providers/azure_openai/llm/llm.py
  8. 5 1
      api/core/model_runtime/model_providers/chatglm/llm/llm.py
  9. 2 0
      api/core/model_runtime/model_providers/minimax/llm/abab5.5-chat.yaml
  10. 2 0
      api/core/model_runtime/model_providers/minimax/llm/abab6-chat.yaml
  11. 1 2
      api/core/model_runtime/model_providers/minimax/llm/chat_completion.py
  12. 37 9
      api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py
  13. 42 1
      api/core/model_runtime/model_providers/minimax/llm/llm.py
  14. 10 0
      api/core/model_runtime/model_providers/minimax/llm/types.py
  15. 1 0
      api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml
  16. 1 0
      api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml
  17. 1 0
      api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml
  18. 1 0
      api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml
  19. 1 0
      api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml
  20. 1 0
      api/core/model_runtime/model_providers/openai/llm/gpt-4-0125-preview.yaml
  21. 1 0
      api/core/model_runtime/model_providers/openai/llm/gpt-4-1106-preview.yaml
  22. 1 0
      api/core/model_runtime/model_providers/openai/llm/gpt-4-32k.yaml
  23. 1 0
      api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-preview.yaml
  24. 1 0
      api/core/model_runtime/model_providers/openai/llm/gpt-4.yaml
  25. 1 1
      api/core/model_runtime/model_providers/openai/llm/llm.py
  26. 27 4
      api/core/model_runtime/model_providers/xinference/llm/llm.py
  27. 19 4
      api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py
  28. 19 6
      api/core/model_runtime/model_providers/xinference/xinference_helper.py
  29. 4 0
      api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml
  30. 4 0
      api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml
  31. 21 0
      api/core/model_runtime/model_providers/zhipuai/llm/llm.py
  32. 1 1
      api/requirements.txt
  33. 61 32
      api/tests/integration_tests/model_runtime/__mock/xinference.py

+ 11 - 3
api/core/app_runner/assistant_app_runner.py

@@ -11,6 +11,7 @@ from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.moderation.base import ModerationException
 from core.tools.entities.tool_entities import ToolRuntimeVariablePool
@@ -194,6 +195,13 @@ class AssistantApplicationRunner(AppRunner):
             memory=memory,
         )
 
+        # change function call strategy based on LLM model
+        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
+        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
+
+        if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features):
+            agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
+
         # start agent runner
         if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
             assistant_cot_runner = AssistantCotApplicationRunner(
@@ -209,9 +217,9 @@ class AssistantApplicationRunner(AppRunner):
                 prompt_messages=prompt_message,
                 variables_pool=tool_variables,
                 db_variables=tool_conversation_variables,
+                model_instance=model_instance
             )
             invoke_result = assistant_cot_runner.run(
-                model_instance=model_instance,
                 conversation=conversation,
                 message=message,
                 query=query,
@@ -229,10 +237,10 @@ class AssistantApplicationRunner(AppRunner):
                 memory=memory,
                 prompt_messages=prompt_message,
                 variables_pool=tool_variables,
-                db_variables=tool_conversation_variables
+                db_variables=tool_conversation_variables,
+                model_instance=model_instance
             )
             invoke_result = assistant_fc_runner.run(
-                model_instance=model_instance,
                 conversation=conversation,
                 message=message,
                 query=query,

+ 14 - 1
api/core/features/assistant_base_runner.py

@@ -1,7 +1,7 @@
 import logging
 import json
 
-from typing import Optional, List, Tuple, Union
+from typing import Optional, List, Tuple, Union, cast
 from datetime import datetime
 from mimetypes import guess_extension
 
@@ -27,7 +27,10 @@ from core.entities.application_entities import ModelConfigEntity, \
     AgentEntity, AppOrchestrationConfigEntity, ApplicationGenerateEntity, InvokeFrom
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
 from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.utils.encoders import jsonable_encoder
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_manager import ModelInstance
 from core.file.message_file_parser import FileTransferMethod
 
 logger = logging.getLogger(__name__)
@@ -45,6 +48,7 @@ class BaseAssistantApplicationRunner(AppRunner):
                  prompt_messages: Optional[List[PromptMessage]] = None,
                  variables_pool: Optional[ToolRuntimeVariablePool] = None,
                  db_variables: Optional[ToolConversationVariables] = None,
+                 model_instance: ModelInstance = None
                  ) -> None:
         """
         Agent runner
@@ -71,6 +75,7 @@ class BaseAssistantApplicationRunner(AppRunner):
         self.history_prompt_messages = prompt_messages
         self.variables_pool = variables_pool
         self.db_variables_pool = db_variables
+        self.model_instance = model_instance
 
         # init callback
         self.agent_callback = DifyAgentCallbackHandler()
@@ -95,6 +100,14 @@ class BaseAssistantApplicationRunner(AppRunner):
             MessageAgentThought.message_id == self.message.id,
         ).count()
 
+        # check if model supports stream tool call
+        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
+        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
+        if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
+            self.stream_tool_call = True
+        else:
+            self.stream_tool_call = False
+
     def _repacket_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
         """
         Repacket app orchestration config

+ 4 - 3
api/core/features/assistant_cot_runner.py

@@ -20,8 +20,7 @@ from core.features.assistant_base_runner import BaseAssistantApplicationRunner
 from models.model import Conversation, Message
 
 class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
-    def run(self, model_instance: ModelInstance,
-        conversation: Conversation,
+    def run(self, conversation: Conversation,
         message: Message,
         query: str,
     ) -> Union[Generator, LLMResult]:
@@ -82,6 +81,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                 llm_usage.prompt_price += usage.prompt_price
                 llm_usage.completion_price += usage.completion_price
 
+        model_instance = self.model_instance
+
         while function_call_state and iteration_step <= max_iteration_steps:
             # continue to run until there is not any tool call
             function_call_state = False
@@ -390,7 +391,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                     # remove Action: xxx from agent thought
                     agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
 
-                    if action_name and action_input:
+                    if action_name and action_input is not None:
                         return AgentScratchpadUnit(
                             agent_response=content,
                             thought=agent_thought,

+ 105 - 23
api/core/features/assistant_fc_runner.py

@@ -5,7 +5,7 @@ from typing import Union, Generator, Dict, Any, Tuple, List
 
 from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage,\
       SystemPromptMessage, AssistantPromptMessage, ToolPromptMessage, PromptMessageTool
-from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage, LLMResultChunkDelta
 from core.model_manager import ModelInstance
 from core.application_queue_manager import PublishFrom
 
@@ -20,8 +20,7 @@ from models.model import Conversation, Message, MessageAgentThought
 logger = logging.getLogger(__name__)
 
 class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
-    def run(self, model_instance: ModelInstance,
-                conversation: Conversation,
+    def run(self, conversation: Conversation,
                 message: Message,
                 query: str,
     ) -> Generator[LLMResultChunk, None, None]:
@@ -81,6 +80,8 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
                 llm_usage.prompt_price += usage.prompt_price
                 llm_usage.completion_price += usage.completion_price
 
+        model_instance = self.model_instance
+
         while function_call_state and iteration_step <= max_iteration_steps:
             function_call_state = False
 
@@ -101,12 +102,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             # recale llm max tokens
             self.recale_llm_max_tokens(self.model_config, prompt_messages)
             # invoke model
-            chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
+            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                 prompt_messages=prompt_messages,
                 model_parameters=app_orchestration_config.model_config.parameters,
                 tools=prompt_messages_tools,
                 stop=app_orchestration_config.model_config.stop,
-                stream=True,
+                stream=self.stream_tool_call,
                 user=self.user_id,
                 callbacks=[],
             )
@@ -122,11 +123,41 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
 
             current_llm_usage = None
 
-            for chunk in chunks:
+            if self.stream_tool_call:
+                for chunk in chunks:
+                    # check if there is any tool call
+                    if self.check_tool_calls(chunk):
+                        function_call_state = True
+                        tool_calls.extend(self.extract_tool_calls(chunk))
+                        tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
+                        try:
+                            tool_call_inputs = json.dumps({
+                                tool_call[1]: tool_call[2] for tool_call in tool_calls
+                            }, ensure_ascii=False)
+                        except json.JSONDecodeError as e:
+                            # ensure ascii to avoid encoding error
+                            tool_call_inputs = json.dumps({
+                                tool_call[1]: tool_call[2] for tool_call in tool_calls
+                            })
+
+                    if chunk.delta.message and chunk.delta.message.content:
+                        if isinstance(chunk.delta.message.content, list):
+                            for content in chunk.delta.message.content:
+                                response += content.data
+                        else:
+                            response += chunk.delta.message.content
+
+                    if chunk.delta.usage:
+                        increase_usage(llm_usage, chunk.delta.usage)
+                        current_llm_usage = chunk.delta.usage
+
+                    yield chunk
+            else:
+                result: LLMResult = chunks
                 # check if there is any tool call
-                if self.check_tool_calls(chunk):
+                if self.check_blocking_tool_calls(result):
                     function_call_state = True
-                    tool_calls.extend(self.extract_tool_calls(chunk))
+                    tool_calls.extend(self.extract_blocking_tool_calls(result))
                     tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                     try:
                         tool_call_inputs = json.dumps({
@@ -138,18 +169,44 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
                             tool_call[1]: tool_call[2] for tool_call in tool_calls
                         })
 
-                if chunk.delta.message and chunk.delta.message.content:
-                    if isinstance(chunk.delta.message.content, list):
-                        for content in chunk.delta.message.content:
+                if result.usage:
+                    increase_usage(llm_usage, result.usage)
+                    current_llm_usage = result.usage
+
+                if result.message and result.message.content:
+                    if isinstance(result.message.content, list):
+                        for content in result.message.content:
                             response += content.data
                     else:
-                        response += chunk.delta.message.content
-
-                if chunk.delta.usage:
-                    increase_usage(llm_usage, chunk.delta.usage)
-                    current_llm_usage = chunk.delta.usage
+                        response += result.message.content
+
+                if not result.message.content:
+                    result.message.content = ''
+
+                yield LLMResultChunk(
+                    model=model_instance.model,
+                    prompt_messages=result.prompt_messages,
+                    system_fingerprint=result.system_fingerprint,
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=result.message,
+                        usage=result.usage,
+                    )
+                )
 
-                yield chunk
+            if tool_calls:
+                prompt_messages.append(AssistantPromptMessage(
+                    content='',
+                    name='',
+                    tool_calls=[AssistantPromptMessage.ToolCall(
+                        id=tool_call[0],
+                        type='function',
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=tool_call[1],
+                            arguments=json.dumps(tool_call[2], ensure_ascii=False)
+                        )
+                    ) for tool_call in tool_calls]
+                ))
 
             # save thought
             self.save_agent_thought(
@@ -167,6 +224,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             
             final_answer += response + '\n'
 
+            # update prompt messages
+            if response.strip():
+                prompt_messages.append(AssistantPromptMessage(
+                    content=response,
+                ))
+            
             # call tools
             tool_responses = []
             for tool_call_id, tool_call_name, tool_call_args in tool_calls:
@@ -256,12 +319,6 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
                 )
                 self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
 
-            # update prompt messages
-            if response.strip():
-                prompt_messages.append(AssistantPromptMessage(
-                    content=response,
-                ))
-
             # update prompt tool
             for prompt_tool in prompt_messages_tools:
                 self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
@@ -287,6 +344,14 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
         if llm_result_chunk.delta.message.tool_calls:
             return True
         return False
+    
+    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
+        """
+        Check if there is any blocking tool call in llm result
+        """
+        if llm_result.message.tool_calls:
+            return True
+        return False
 
     def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
         """
@@ -304,6 +369,23 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             ))
 
         return tool_calls
+    
+    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
+        """
+        Extract blocking tool calls from llm result
+
+        Returns:
+            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
+        """
+        tool_calls = []
+        for prompt_message in llm_result.message.tool_calls:
+            tool_calls.append((
+                prompt_message.id,
+                prompt_message.function.name,
+                json.loads(prompt_message.function.arguments),
+            ))
+
+        return tool_calls
 
     def organize_prompt_messages(self, prompt_template: str,
                                  query: str = None, 

+ 1 - 0
api/core/model_runtime/entities/model_entities.py

@@ -78,6 +78,7 @@ class ModelFeature(Enum):
     MULTI_TOOL_CALL = "multi-tool-call"
     AGENT_THOUGHT = "agent-thought"
     VISION = "vision"
+    STREAM_TOOL_CALL = "stream-tool-call"
 
 
 class DefaultParameterName(Enum):

+ 5 - 0
api/core/model_runtime/model_providers/azure_openai/_constant.py

@@ -36,6 +36,7 @@ LLM_BASE_MODELS = [
             features=[
                 ModelFeature.AGENT_THOUGHT,
                 ModelFeature.MULTI_TOOL_CALL,
+                ModelFeature.STREAM_TOOL_CALL,
             ],
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
@@ -80,6 +81,7 @@ LLM_BASE_MODELS = [
             features=[
                 ModelFeature.AGENT_THOUGHT,
                 ModelFeature.MULTI_TOOL_CALL,
+                ModelFeature.STREAM_TOOL_CALL,
             ],
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
@@ -124,6 +126,7 @@ LLM_BASE_MODELS = [
             features=[
                 ModelFeature.AGENT_THOUGHT,
                 ModelFeature.MULTI_TOOL_CALL,
+                ModelFeature.STREAM_TOOL_CALL,
             ],
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
@@ -198,6 +201,7 @@ LLM_BASE_MODELS = [
             features=[
                 ModelFeature.AGENT_THOUGHT,
                 ModelFeature.MULTI_TOOL_CALL,
+                ModelFeature.STREAM_TOOL_CALL,
             ],
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
@@ -272,6 +276,7 @@ LLM_BASE_MODELS = [
             features=[
                 ModelFeature.AGENT_THOUGHT,
                 ModelFeature.MULTI_TOOL_CALL,
+                ModelFeature.STREAM_TOOL_CALL,
             ],
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={

+ 24 - 4
api/core/model_runtime/model_providers/azure_openai/llm/llm.py

@@ -324,6 +324,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                                               tools: Optional[list[PromptMessageTool]] = None) -> Generator:
         index = 0
         full_assistant_content = ''
+        delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
         real_model = model
         system_fingerprint = None
         completion = ''
@@ -333,12 +334,32 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
 
             delta = chunk.choices[0]
 
-            if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
+            if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == '') and \
+                delta.delta.function_call is None:
                 continue
-
+            
             # assistant_message_tool_calls = delta.delta.tool_calls
             assistant_message_function_call = delta.delta.function_call
 
+            # extract tool calls from response
+            if delta_assistant_message_function_call_storage is not None:
+                # handle process of stream function call
+                if assistant_message_function_call:
+                    # message has not ended ever
+                    delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
+                    continue
+                else:
+                    # message has ended
+                    assistant_message_function_call = delta_assistant_message_function_call_storage
+                    delta_assistant_message_function_call_storage = None
+            else:
+                if assistant_message_function_call:
+                    # start of stream function call
+                    delta_assistant_message_function_call_storage = assistant_message_function_call
+                    if delta_assistant_message_function_call_storage.arguments is None:
+                        delta_assistant_message_function_call_storage.arguments = ''
+                    continue
+
             # extract tool calls from response
             # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
             function_call = self._extract_response_function_call(assistant_message_function_call)
@@ -489,7 +510,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         else:
             raise ValueError(f"Got unknown type {message}")
 
-        if message.name is not None:
+        if message.name:
             message_dict["name"] = message.name
 
         return message_dict
@@ -586,7 +607,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         num_tokens = 0
         for tool in tools:
             num_tokens += len(encoding.encode('type'))
-            num_tokens += len(encoding.encode(tool.get("type")))
             num_tokens += len(encoding.encode('function'))
 
             # calculate num tokens for function object

+ 5 - 1
api/core/model_runtime/model_providers/chatglm/llm/llm.py

@@ -5,7 +5,7 @@ from typing import Generator, List, Optional, cast
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageFunction,
-                                                          PromptMessageTool, SystemPromptMessage, UserPromptMessage)
+                                                          PromptMessageTool, SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -194,6 +194,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            # check if last message is user message
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"role": "function", "content": message.content}
         else:
             raise ValueError(f"Unknown message type {type(message)}")
         

+ 2 - 0
api/core/model_runtime/model_providers/minimax/llm/abab5.5-chat.yaml

@@ -4,6 +4,8 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 16384

+ 2 - 0
api/core/model_runtime/model_providers/minimax/llm/abab6-chat.yaml

@@ -4,6 +4,8 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 32768

+ 1 - 2
api/core/model_runtime/model_providers/minimax/llm/chat_completion.py

@@ -16,7 +16,7 @@ class MinimaxChatCompletion(object):
     """
     def generate(self, model: str, api_key: str, group_id: str, 
                  prompt_messages: List[MinimaxMessage], model_parameters: dict,
-                 tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \
+                 tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
         -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
         """
             generate chat completion
@@ -162,7 +162,6 @@ class MinimaxChatCompletion(object):
                 continue
 
             for choice in choices:
-                print(choice)
                 message = choice['delta']
                 yield MinimaxMessage(
                     content=message,

+ 37 - 9
api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py

@@ -17,7 +17,7 @@ class MinimaxChatCompletionPro(object):
     """
     def generate(self, model: str, api_key: str, group_id: str, 
                  prompt_messages: List[MinimaxMessage], model_parameters: dict,
-                 tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \
+                 tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
         -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
         """
             generate chat completion
@@ -82,6 +82,10 @@ class MinimaxChatCompletionPro(object):
             **extra_kwargs
         }
 
+        if tools:
+            body['functions'] = tools
+            body['function_call'] = { 'type': 'auto' }
+
         try:
             response = post(
                 url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
@@ -135,6 +139,7 @@ class MinimaxChatCompletionPro(object):
         """
             handle stream chat generate response
         """
+        function_call_storage = None
         for line in response.iter_lines():
             if not line:
                 continue
@@ -148,7 +153,7 @@ class MinimaxChatCompletionPro(object):
                 msg = data['base_resp']['status_msg']
                 self._handle_error(code, msg)
 
-            if data['reply']:
+            if data['reply'] or 'usage' in data and data['usage']:
                 total_tokens = data['usage']['total_tokens']
                 message =  MinimaxMessage(
                     role=MinimaxMessage.Role.ASSISTANT.value,
@@ -160,6 +165,12 @@ class MinimaxChatCompletionPro(object):
                     'total_tokens': total_tokens
                 }
                 message.stop_reason = data['choices'][0]['finish_reason']
+
+                if function_call_storage:
+                    function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
+                    function_call_message.function_call = function_call_storage
+                    yield function_call_message
+
                 yield message
                 return
 
@@ -168,11 +179,28 @@ class MinimaxChatCompletionPro(object):
                 continue
 
             for choice in choices:
-                message = choice['messages'][0]['text']
-                if not message:
-                    continue
+                message = choice['messages'][0]
+
+                if 'function_call' in message:
+                    if not function_call_storage:
+                        function_call_storage = message['function_call']
+                        if 'arguments' not in function_call_storage or not function_call_storage['arguments']:
+                            function_call_storage['arguments'] = ''
+                            continue
+                    else:
+                        function_call_storage['arguments'] += message['function_call']['arguments']
+                        continue
+                else:
+                    if function_call_storage:
+                        message['function_call'] = function_call_storage
+                        function_call_storage = None
                 
-                yield MinimaxMessage(
-                    content=message,
-                    role=MinimaxMessage.Role.ASSISTANT.value
-                )
+                minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
+
+                if 'function_call' in message:
+                    minimax_message.function_call = message['function_call']
+
+                if 'text' in message:
+                    minimax_message.content = message['text']
+
+                yield minimax_message

+ 42 - 1
api/core/model_runtime/model_providers/minimax/llm/llm.py

@@ -2,7 +2,7 @@ from typing import Generator, List
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
-                                                          SystemPromptMessage, UserPromptMessage)
+                                                          SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -84,6 +84,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
         """
         client: MinimaxChatCompletionPro = self.model_apis[model]()
 
+        if tools:
+            tools = [{
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": tool.parameters
+            } for tool in tools]
+
         response = client.generate(
             model=model,
             api_key=credentials['minimax_api_key'],
@@ -109,7 +116,19 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
         elif isinstance(prompt_message, UserPromptMessage):
             return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content)
         elif isinstance(prompt_message, AssistantPromptMessage):
+            if prompt_message.tool_calls:
+                message = MinimaxMessage(
+                    role=MinimaxMessage.Role.ASSISTANT.value,
+                    content=''
+                )
+                message.function_call={
+                    'name': prompt_message.tool_calls[0].function.name,
+                    'arguments': prompt_message.tool_calls[0].function.arguments
+                }
+                return message
             return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content)
+        elif isinstance(prompt_message, ToolPromptMessage):
+            return MinimaxMessage(role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content)
         else:
             raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported')
 
@@ -151,6 +170,28 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
                         finish_reason=message.stop_reason if message.stop_reason else None,
                     ),
                 )
+            elif message.function_call:
+                if 'name' not in message.function_call or 'arguments' not in message.function_call:
+                    continue
+
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=AssistantPromptMessage(
+                            content='',
+                            tool_calls=[AssistantPromptMessage.ToolCall(
+                                id='',
+                                type='function',
+                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                    name=message.function_call['name'],
+                                    arguments=message.function_call['arguments']
+                                )
+                            )]
+                        ),
+                    ),
+                )
             else:
                 yield LLMResultChunk(
                     model=model,

+ 10 - 0
api/core/model_runtime/model_providers/minimax/llm/types.py

@@ -7,13 +7,23 @@ class MinimaxMessage:
         USER = 'USER'
         ASSISTANT = 'BOT'
         SYSTEM = 'SYSTEM'
+        FUNCTION = 'FUNCTION'
 
     role: str = Role.USER.value
     content: str
     usage: Dict[str, int] = None
     stop_reason: str = ''
+    function_call: Dict[str, Any] = None
 
     def to_dict(self) -> Dict[str, Any]:
+        if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:
+            return {
+                'sender_type': 'BOT',
+                'sender_name': '专家',
+                'text': '',
+                'function_call': self.function_call
+            }
+        
         return {
             'sender_type': self.role,
             'sender_name': '我' if self.role == 'USER' else '专家',

+ 1 - 0
api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml

@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 4096

+ 1 - 0
api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml

@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 16385

+ 1 - 0
api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml

@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 16385

+ 1 - 0
api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml

@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 16385

+ 1 - 0
api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml

@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 4096

+ 1 - 0
api/core/model_runtime/model_providers/openai/llm/gpt-4-0125-preview.yaml

@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 128000

+ 1 - 0
api/core/model_runtime/model_providers/openai/llm/gpt-4-1106-preview.yaml

@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 128000

+ 1 - 0
api/core/model_runtime/model_providers/openai/llm/gpt-4-32k.yaml

@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 32768

+ 1 - 0
api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-preview.yaml

@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 128000

+ 1 - 0
api/core/model_runtime/model_providers/openai/llm/gpt-4.yaml

@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 8192

+ 1 - 1
api/core/model_runtime/model_providers/openai/llm/llm.py

@@ -671,7 +671,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         else:
             raise ValueError(f"Got unknown type {message}")
 
-        if message.name is not None:
+        if message.name:
             message_dict["name"] = message.name
 
         return message_dict

+ 27 - 4
api/core/model_runtime/model_providers/xinference/llm/llm.py

@@ -3,14 +3,14 @@ from typing import Generator, Iterator, List, Optional, Union, cast
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
-                                                          SystemPromptMessage, UserPromptMessage)
+                                                          SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
 from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType,
-                                                        ParameterRule, ParameterType)
+                                                        ParameterRule, ParameterType, ModelFeature)
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.model_providers.xinference.llm.xinference_helper import (XinferenceHelper,
+from core.model_runtime.model_providers.xinference.xinference_helper import (XinferenceHelper,
                                                                                  XinferenceModelExtraParameter)
 from core.model_runtime.utils import helper
 from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError,
@@ -33,6 +33,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
 
             see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke`
         """
+        if 'temperature' in model_parameters:
+            if model_parameters['temperature'] < 0.01:
+                model_parameters['temperature'] = 0.01
+            elif model_parameters['temperature'] > 1.0:
+                model_parameters['temperature'] = 0.99
+
         return self._generate(
             model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
             tools=tools, stop=stop, stream=stream, user=user,
@@ -65,6 +71,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                     credentials['completion_type'] = 'completion'
                 else:
                     raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported')
+                
+            if extra_param.support_function_call:
+                credentials['support_function_call'] = True
 
         except RuntimeError as e:
             raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
@@ -220,6 +229,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
         else:
             raise ValueError(f"Unknown message type {type(message)}")
         
@@ -237,7 +249,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 label=I18nObject(
                     zh_Hans='温度',
                     en_US='Temperature'
-                )
+                ),
             ),
             ParameterRule(
                 name='top_p',
@@ -282,6 +294,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 completion_type = LLMMode.COMPLETION.value
             else:
                 raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
+            
+        support_function_call = credentials.get('support_function_call', False)
 
         entity = AIModelEntity(
             model=model,
@@ -290,6 +304,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.LLM,
+            features=[
+                ModelFeature.TOOL_CALL
+            ] if support_function_call else [],
             model_properties={ 
                 ModelPropertyKey.MODE: completion_type,
             },
@@ -310,6 +327,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             
             extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter`
         """
+        if 'server_url' not in credentials:
+            raise CredentialsValidateFailedError('server_url is required in credentials')
+        
+        if credentials['server_url'].endswith('/'):
+            credentials['server_url'] = credentials['server_url'][:-1]
+
         client = OpenAI(
             base_url=f'{credentials["server_url"]}/v1',
             api_key='abc',

+ 19 - 4
api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py

@@ -2,7 +2,7 @@ import time
 from typing import Optional
 
 from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType, ModelPropertyKey
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
@@ -10,6 +10,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle, RESTfulModelHandle
 
+from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
 
 class XinferenceTextEmbeddingModel(TextEmbeddingModel):
     """
@@ -35,7 +36,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         """
         server_url = credentials['server_url']
         model_uid = credentials['model_uid']
-        
+
+        if server_url.endswith('/'):
+            server_url = server_url[:-1]
+
         client = Client(base_url=server_url)
         
         try:
@@ -102,8 +106,15 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         :return:
         """
         try:
+            server_url = credentials['server_url']
+            model_uid = credentials['model_uid']
+            extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
+
+            if extra_args.max_tokens:
+                credentials['max_tokens'] = extra_args.max_tokens
+
             self._invoke(model=model, credentials=credentials, texts=['ping'])
-        except InvokeAuthorizationError:
+        except (InvokeAuthorizationError, RuntimeError):
             raise CredentialsValidateFailedError('Invalid api key')
 
     @property
@@ -160,6 +171,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         """
             used to define customizable model schema
         """
+        
         entity = AIModelEntity(
             model=model,
             label=I18nObject(
@@ -167,7 +179,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.TEXT_EMBEDDING,
-            model_properties={},
+            model_properties={
+                ModelPropertyKey.MAX_CHUNKS: 1,
+                ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and credentials['max_tokens'] or 512,
+            },
             parameter_rules=[]
         )
 

+ 19 - 6
api/core/model_runtime/model_providers/xinference/llm/xinference_helper.py → api/core/model_runtime/model_providers/xinference/xinference_helper.py

@@ -1,6 +1,7 @@
 from threading import Lock
 from time import time
 from typing import List
+from os import path
 
 from requests import get
 from requests.adapters import HTTPAdapter
@@ -12,11 +13,16 @@ class XinferenceModelExtraParameter(object):
     model_format: str
     model_handle_type: str
     model_ability: List[str]
+    max_tokens: int = 512
+    support_function_call: bool = False
 
-    def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str]) -> None:
+    def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str], 
+                 support_function_call: bool, max_tokens: int) -> None:
         self.model_format = model_format
         self.model_handle_type = model_handle_type
         self.model_ability = model_ability
+        self.support_function_call = support_function_call
+        self.max_tokens = max_tokens
 
 cache = {}
 cache_lock = Lock()
@@ -49,7 +55,7 @@ class XinferenceHelper:
             get xinference model extra parameter like model_format and model_handle_type
         """
 
-        url = f'{server_url}/v1/models/{model_uid}'
+        url = path.join(server_url, 'v1/models', model_uid)
 
         # this methid is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
         session = Session()
@@ -66,10 +72,12 @@ class XinferenceHelper:
         
         response_json = response.json()
 
-        model_format = response_json['model_format']
-        model_ability = response_json['model_ability']
+        model_format = response_json.get('model_format', 'ggmlv3')
+        model_ability = response_json.get('model_ability', [])
 
-        if model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
+        if response_json.get('model_type') == 'embedding':
+            model_handle_type = 'embedding'
+        elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
             model_handle_type = 'chatglm'
         elif 'generate' in model_ability:
             model_handle_type = 'generate'
@@ -78,8 +86,13 @@ class XinferenceHelper:
         else:
             raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')
         
+        support_function_call = 'tools' in model_ability
+        max_tokens = response_json.get('max_tokens', 512)
+        
         return XinferenceModelExtraParameter(
             model_format=model_format,
             model_handle_type=model_handle_type,
-            model_ability=model_ability
+            model_ability=model_ability,
+            support_function_call=support_function_call,
+            max_tokens=max_tokens
         )

+ 4 - 0
api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml

@@ -2,6 +2,10 @@ model: glm-3-turbo
 label:
   en_US: glm-3-turbo
 model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
 parameter_rules:

+ 4 - 0
api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml

@@ -2,6 +2,10 @@ model: glm-4
 label:
   en_US: glm-4
 model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
 parameter_rules:

+ 21 - 0
api/core/model_runtime/model_providers/zhipuai/llm/llm.py

@@ -194,6 +194,27 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                             'content': prompt_message.content,
                             'tool_call_id': prompt_message.tool_call_id
                         })
+                    elif isinstance(prompt_message, AssistantPromptMessage):
+                        if prompt_message.tool_calls:
+                            params['messages'].append({
+                                'role': 'assistant',
+                                'content': prompt_message.content,
+                                'tool_calls': [
+                                    {
+                                        'id': tool_call.id,
+                                        'type': tool_call.type,
+                                        'function': {
+                                            'name': tool_call.function.name,
+                                            'arguments': tool_call.function.arguments
+                                        }
+                                    } for tool_call in prompt_message.tool_calls
+                                ]
+                            })
+                        else:
+                            params['messages'].append({
+                                'role': 'assistant',
+                                'content': prompt_message.content
+                            })
                     else:
                         params['messages'].append({
                             'role': prompt_message.role.value,

+ 1 - 1
api/requirements.txt

@@ -47,7 +47,7 @@ dashscope[tokenizer]~=1.14.0
 huggingface_hub~=0.16.4
 transformers~=4.31.0
 pandas==1.5.3
-xinference-client~=0.6.4
+xinference-client~=0.8.1
 safetensors==0.3.2
 zhipuai==1.0.7
 werkzeug~=3.0.1

+ 61 - 32
api/tests/integration_tests/model_runtime/__mock/xinference.py

@@ -19,58 +19,86 @@ class MockXinferenceClass(object):
             raise RuntimeError('404 Not Found')
         
         if 'generate' == model_uid:
-            return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         if 'chat' == model_uid:
-            return RESTfulChatModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         if 'embedding' == model_uid:
-            return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         if 'rerank' == model_uid:
-            return RESTfulRerankModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         raise RuntimeError('404 Not Found')
         
     def get(self: Session, url: str, **kwargs):
-        if '/v1/models/' in url:
-            response = Response()
-            
+        response = Response()
+        if 'v1/models/' in url:
             # get model uid
             model_uid = url.split('/')[-1]
             if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
                 model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
                 response.status_code = 404
-                raise ConnectionError('404 Not Found')
+                return response
 
             # check if url is valid
             if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
                 response.status_code = 404
-                raise ConnectionError('404 Not Found')
-
+                return response
+            
+            if model_uid in ['generate', 'chat']:
+                response.status_code = 200
+                response._content = b'''{
+        "model_type": "LLM",
+        "address": "127.0.0.1:43877",
+        "accelerators": [
+            "0",
+            "1"
+        ],
+        "model_name": "chatglm3-6b",
+        "model_lang": [
+            "en"
+        ],
+        "model_ability": [
+            "generate",
+            "chat"
+        ],
+        "model_description": "latest chatglm3",
+        "model_format": "pytorch",
+        "model_size_in_billions": 7,
+        "quantization": "none",
+        "model_hub": "huggingface",
+        "revision": null,
+        "context_length": 2048,
+        "replica": 1
+    }'''
+                return response
+            
+            elif model_uid == 'embedding':
+                response.status_code = 200
+                response._content = b'''{
+        "model_type": "embedding",
+        "address": "127.0.0.1:43877",
+        "accelerators": [
+            "0",
+            "1"
+        ],
+        "model_name": "bge",
+        "model_lang": [
+            "en"
+        ],
+        "revision": null,
+        "max_tokens": 512
+}'''
+                return response
+            
+        elif 'v1/cluster/auth' in url:
             response.status_code = 200
             response._content = b'''{
-    "model_type": "LLM",
-    "address": "127.0.0.1:43877",
-    "accelerators": [
-        "0",
-        "1"
-    ],
-    "model_name": "chatglm3-6b",
-    "model_lang": [
-        "en"
-    ],
-    "model_ability": [
-        "generate",
-        "chat"
-    ],
-    "model_description": "latest chatglm3",
-    "model_format": "pytorch",
-    "model_size_in_billions": 7,
-    "quantization": "none",
-    "model_hub": "huggingface",
-    "revision": null,
-    "context_length": 2048,
-    "replica": 1
+    "auth": true
 }'''
             return response
         
+    def _check_cluster_authenticated(self):
+        self._cluster_authed = True
+        
     def rerank(self: RESTfulRerankModelHandle, documents: List[str], query: str, top_n: int) -> dict:
         # check if self._model_uid is a valid uuid
         if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
@@ -133,6 +161,7 @@ MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
 def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
     if MOCK:
         monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
+        monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated)
         monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
         monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
         monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)