Browse Source

Feat/zhipuai function calling (#2199)

Co-authored-by: Joel <iamjoel007@gmail.com>
Yeuoly 1 năm trước cách đây
mục cha
commit
b921c55677
46 tập tin đã thay đổi với 2115 bổ sung138 xóa
  1. 0 61
      api/core/model_runtime/model_providers/zhipuai/_client.py
  2. 152 59
      api/core/model_runtime/model_providers/zhipuai/llm/llm.py
  3. 11 12
      api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py
  4. 17 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py
  5. 2 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py
  6. 71 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py
  7. 5 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py
  8. 0 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/__init__.py
  9. 87 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py
  10. 16 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py
  11. 71 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py
  12. 49 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py
  13. 78 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py
  14. 0 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/__init__.py
  15. 15 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py
  16. 115 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py
  17. 55 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py
  18. 0 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/__init__.py
  19. 17 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py
  20. 115 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py
  21. 90 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py
  22. 46 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py
  23. 377 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py
  24. 30 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py
  25. 54 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py
  26. 121 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py
  27. 149 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py
  28. 18 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py
  29. 0 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/__init__.py
  30. 0 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/__init__.py
  31. 23 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py
  32. 45 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py
  33. 55 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py
  34. 8 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completions_create_param.py
  35. 20 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py
  36. 24 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py
  37. 5 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py
  38. 52 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py
  39. 36 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py
  40. 15 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py
  41. 18 0
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py
  42. 47 1
      api/tests/integration_tests/model_runtime/zhipuai/test_llm.py
  43. 1 1
      api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py
  44. 1 1
      web/app/components/app/chat/answer/index.tsx
  45. 2 3
      web/app/components/app/configuration/index.tsx
  46. 2 0
      web/config/index.ts

+ 0 - 61
api/core/model_runtime/model_providers/zhipuai/_client.py

@@ -1,61 +0,0 @@
-"""Wrapper around ZhipuAI APIs."""
-from __future__ import annotations
-
-import logging
-import posixpath
-
-from pydantic import BaseModel, Extra
-from zhipuai.model_api.api import InvokeType
-from zhipuai.utils import jwt_token
-from zhipuai.utils.http_client import post, stream
-from zhipuai.utils.sse_client import SSEClient
-
-logger = logging.getLogger(__name__)
-
-
-class ZhipuModelAPI(BaseModel):
-    base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
-    api_key: str
-    api_timeout_seconds = 60
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-
-    def invoke(self, **kwargs):
-        url = self._build_api_url(kwargs, InvokeType.SYNC)
-        response = post(url, self._generate_token(), kwargs, self.api_timeout_seconds)
-        if not response['success']:
-            raise ValueError(
-                f"Error Code: {response['code']}, Message: {response['msg']} "
-            )
-        return response
-
-    def sse_invoke(self, **kwargs):
-        url = self._build_api_url(kwargs, InvokeType.SSE)
-        data = stream(url, self._generate_token(), kwargs, self.api_timeout_seconds)
-        return SSEClient(data)
-
-    def _build_api_url(self, kwargs, *path):
-        if kwargs:
-            if "model" not in kwargs:
-                raise Exception("model param missed")
-            model = kwargs.pop("model")
-        else:
-            model = "-"
-
-        return posixpath.join(self.base_url, model, *path)
-
-    def _generate_token(self):
-        if not self.api_key:
-            raise Exception(
-                "api_key not provided, you could provide it."
-            )
-
-        try:
-            return jwt_token.generate_token(self.api_key)
-        except Exception:
-            raise ValueError(
-                f"Your api_key is invalid, please check it."
-            )

+ 152 - 59
api/core/model_runtime/model_providers/zhipuai/llm/llm.py

@@ -3,13 +3,15 @@ from typing import Any, Dict, Generator, List, Optional, Union
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole,
-                                                          PromptMessageTool, SystemPromptMessage, UserPromptMessage,
+                                                          PromptMessageTool, SystemPromptMessage, UserPromptMessage, ToolPromptMessage,
                                                           TextPromptMessageContent, ImagePromptMessageContent, PromptMessageContentType)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.utils import helper
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.model_providers.zhipuai._client import ZhipuModelAPI
 from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
-
+from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
+from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk
+from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import Completion
 
 class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
 
@@ -35,7 +37,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         credentials_kwargs = self._to_credential_kwargs(credentials)
 
         # invoke model
-        return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, stop, stream, user)
+        return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user)
 
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -48,7 +50,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         :param tools: tools for tool calling
         :return:
         """
-        prompt = self._convert_messages_to_prompt(prompt_messages)
+        prompt = self._convert_messages_to_prompt(prompt_messages, tools)
 
         return self._get_num_tokens_by_gpt2(prompt)
 
@@ -72,6 +74,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                 model_parameters={
                     "temperature": 0.5,
                 },
+                tools=[],
                 stream=False
             )
         except Exception as ex:
@@ -79,6 +82,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
 
     def _generate(self, model: str, credentials_kwargs: dict,
                   prompt_messages: list[PromptMessage], model_parameters: dict,
+                  tools: Optional[list[PromptMessageTool]] = None,
                   stop: Optional[List[str]] = None, stream: bool = True,
                   user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
@@ -97,7 +101,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         if stop:
             extra_model_kwargs['stop_sequences'] = stop
 
-        client = ZhipuModelAPI(
+        client = ZhipuAI(
             api_key=credentials_kwargs['api_key']
         )
 
@@ -128,11 +132,17 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                     # not support image message
                     continue
 
-                if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER:
+                if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER and \
+                    copy_prompt_message.role == PromptMessageRole.USER:
                     new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
                 else:
                     if copy_prompt_message.role == PromptMessageRole.USER:
                         new_prompt_messages.append(copy_prompt_message)
+                    elif copy_prompt_message.role == PromptMessageRole.TOOL:
+                        new_prompt_messages.append(copy_prompt_message)
+                    elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
+                        new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
+                        new_prompt_messages.append(new_prompt_message)
                     else:
                         new_prompt_message = UserPromptMessage(content=copy_prompt_message.content)
                         new_prompt_messages.append(new_prompt_message)
@@ -145,7 +155,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         if model == 'glm-4v':
             params = {
                 'model': model,
-                'prompt': [{
+                'messages': [{
                     'role': prompt_message.role.value,
                     'content': 
                         [
@@ -171,23 +181,63 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         else:
             params = {
                 'model': model,
-                'prompt': [{
-                    'role': prompt_message.role.value,
-                    'content': prompt_message.content,
-                } for prompt_message in new_prompt_messages],
+                'messages': [],
                 **model_parameters
             }
+            # glm model
+            if not model.startswith('chatglm'):
+
+                for prompt_message in new_prompt_messages:
+                    if prompt_message.role == PromptMessageRole.TOOL:
+                        params['messages'].append({
+                            'role': 'tool',
+                            'content': prompt_message.content,
+                            'tool_call_id': prompt_message.tool_call_id
+                        })
+                    else:
+                        params['messages'].append({
+                            'role': prompt_message.role.value,
+                            'content': prompt_message.content
+                        })
+            else:
+                # chatglm model
+                for prompt_message in new_prompt_messages:
+                    # merge system message to user message
+                    if prompt_message.role == PromptMessageRole.SYSTEM or \
+                        prompt_message.role == PromptMessageRole.TOOL or \
+                        prompt_message.role == PromptMessageRole.USER:
+                        if len(params['messages']) > 0 and params['messages'][-1]['role'] == 'user':
+                            params['messages'][-1]['content'] += "\n\n" + prompt_message.content
+                        else:
+                            params['messages'].append({
+                                'role': 'user',
+                                'content': prompt_message.content
+                            })
+                    else:
+                        params['messages'].append({
+                            'role': prompt_message.role.value,
+                            'content': prompt_message.content
+                        })
+
+        if tools and len(tools) > 0:
+            params['tools'] = [
+                {
+                    'type': 'function',
+                    'function': helper.dump_model(tool)
+                } for tool in tools
+            ]
 
         if stream:
-            response = client.sse_invoke(incremental=True, **params).events()
-            return self._handle_generate_stream_response(model, credentials_kwargs, response, prompt_messages)
+            response = client.chat.completions.create(stream=stream, **params)
+            return self._handle_generate_stream_response(model, credentials_kwargs, tools, response, prompt_messages)
 
-        response = client.invoke(**params)
-        return self._handle_generate_response(model, credentials_kwargs, response, prompt_messages)
+        response = client.chat.completions.create(**params)
+        return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages)
         
     def _handle_generate_response(self, model: str, 
                                   credentials: dict,
-                                  response: Dict[str, Any],
+                                  tools: Optional[list[PromptMessageTool]],
+                                  response: Completion,
                                   prompt_messages: list[PromptMessage]) -> LLMResult:
         """
         Handle llm response
@@ -197,26 +247,39 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: llm response
         """
-        data = response["data"]
         text = ''
-        for res in data["choices"]:
-            text += res['content']
+        assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = []
+        for choice in response.choices:
+            if choice.message.tool_calls:
+                for tool_call in choice.message.tool_calls:
+                    if tool_call.type == 'function':
+                        assistant_tool_calls.append(
+                            AssistantPromptMessage.ToolCall(
+                                id=tool_call.id,
+                                type=tool_call.type,
+                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                    name=tool_call.function.name,
+                                    arguments=tool_call.function.arguments,
+                                )
+                            )
+                        )
+
+            text += choice.message.content or ''
           
-        token_usage = data.get("usage")
-        if token_usage is not None:
-            if 'prompt_tokens' not in token_usage:
-                token_usage['prompt_tokens'] = 0
-            if 'completion_tokens' not in token_usage:
-                token_usage['completion_tokens'] = token_usage['total_tokens']
+        prompt_usage = response.usage.prompt_tokens
+        completion_usage = response.usage.completion_tokens
 
         # transform usage
-        usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens'])
+        usage = self._calc_response_usage(model, credentials, prompt_usage, completion_usage)
 
         # transform response
         result = LLMResult(
             model=model,
             prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(content=text),
+            message=AssistantPromptMessage(
+                content=text,
+                tool_calls=assistant_tool_calls
+            ),
             usage=usage,
         )
 
@@ -224,7 +287,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
 
     def _handle_generate_stream_response(self, model: str, 
                                          credentials: dict,
-                                         responses: list[Generator],
+                                         tools: Optional[list[PromptMessageTool]],
+                                         responses: Generator[ChatCompletionChunk, None, None],
                                          prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm stream response
@@ -234,39 +298,64 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: llm response chunk generator result
         """
-        for index, event in enumerate(responses):
-            if event.event == "add":
+        full_assistant_content = ''
+        for chunk in responses:
+            if len(chunk.choices) == 0:
+                continue
+
+            delta = chunk.choices[0]
+
+            if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
+                continue
+            
+            assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = []
+            for tool_call in delta.delta.tool_calls or []:
+                if tool_call.type == 'function':
+                    assistant_tool_calls.append(
+                        AssistantPromptMessage.ToolCall(
+                            id=tool_call.id,
+                            type=tool_call.type,
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=tool_call.function.name,
+                                arguments=tool_call.function.arguments,
+                            )
+                        )
+                    )
+
+            # transform assistant message to prompt message
+            assistant_prompt_message = AssistantPromptMessage(
+                content=delta.delta.content if delta.delta.content else '',
+                tool_calls=assistant_tool_calls
+            )
+
+            full_assistant_content += delta.delta.content if delta.delta.content else ''
+
+            if delta.finish_reason is not None and chunk.usage is not None:
+                completion_tokens = chunk.usage.completion_tokens
+                prompt_tokens = chunk.usage.prompt_tokens
+
+                # transform usage
+                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
                 yield LLMResultChunk(
+                    model=chunk.model,
                     prompt_messages=prompt_messages,
-                    model=model,
+                    system_fingerprint='',
                     delta=LLMResultChunkDelta(
-                        index=index,
-                        message=AssistantPromptMessage(content=event.data)
+                        index=delta.index,
+                        message=assistant_prompt_message,
+                        finish_reason=delta.finish_reason,
+                        usage=usage
                     )
                 )
-            elif event.event == "error" or event.event == "interrupted":
-                raise ValueError(
-                    f"{event.data}"
-                )
-            elif event.event == "finish":
-                meta = json.loads(event.meta)
-                token_usage = meta['usage']
-                if token_usage is not None:
-                    if 'prompt_tokens' not in token_usage:
-                        token_usage['prompt_tokens'] = 0
-                    if 'completion_tokens' not in token_usage:
-                        token_usage['completion_tokens'] = token_usage['total_tokens']
-
-                usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens'])
-
+            else:
                 yield LLMResultChunk(
-                    model=model,
+                    model=chunk.model,
                     prompt_messages=prompt_messages,
+                    system_fingerprint='',
                     delta=LLMResultChunkDelta(
-                        index=index,
-                        message=AssistantPromptMessage(content=event.data),
-                        finish_reason='finish',
-                        usage=usage
+                        index=delta.index,
+                        message=assistant_prompt_message,
                     )
                 )
 
@@ -291,11 +380,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             raise ValueError(f"Got unknown type {message}")
 
         return message_text
-    
-    def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
-        """
-        Format a list of messages into a full prompt for the Anthropic model
 
+
+    def _convert_messages_to_prompt(self, messages: List[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str:
+        """
         :param messages: List of PromptMessage to combine.
         :return: Combined string with necessary human_prompt and ai_prompt tags.
         """
@@ -306,5 +394,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             for message in messages
         )
 
+        if tools and len(tools) > 0:
+            text += "\n\nTools:"
+            for tool in tools:
+                text += f"\n{tool.json()}"
+
         # trim off the trailing ' ' that might come from the "Assistant: "
-        return text.rstrip()
+        return text.rstrip()

+ 11 - 12
api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py

@@ -5,7 +5,7 @@ from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from core.model_runtime.model_providers.zhipuai._client import ZhipuModelAPI
+from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
 from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
 from langchain.schema.language_model import _get_token_ids_default_method
 
@@ -28,7 +28,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
         :return: embeddings result
         """
         credentials_kwargs = self._to_credential_kwargs(credentials)
-        client = ZhipuModelAPI(
+        client = ZhipuAI(
             api_key=credentials_kwargs['api_key']
         )
 
@@ -69,7 +69,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
         try:
             # transform credentials to kwargs for model instance
             credentials_kwargs = self._to_credential_kwargs(credentials)
-            client = ZhipuModelAPI(
+            client = ZhipuAI(
                 api_key=credentials_kwargs['api_key']
             )
 
@@ -82,7 +82,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
-    def embed_documents(self, model: str, client: ZhipuModelAPI, texts: List[str]) -> Tuple[List[List[float]], int]:
+    def embed_documents(self, model: str, client: ZhipuAI, texts: List[str]) -> Tuple[List[List[float]], int]:
         """Call out to ZhipuAI's embedding endpoint.
 
         Args:
@@ -91,17 +91,16 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
         Returns:
             List of embeddings, one for each text.
         """
-        
-
         embeddings = []
-        for text in texts:
-            response = client.invoke(model=model, prompt=text)
-            data = response["data"]
-            embeddings.append(data.get('embedding'))
+        embedding_used_tokens = 0
 
-        embedding_used_tokens = data.get('usage')
+        for text in texts:
+            response = client.embeddings.create(model=model, input=text)
+            data = response.data[0]
+            embeddings.append(data.embedding)
+            embedding_used_tokens += response.usage.total_tokens
 
-        return [list(map(float, e)) for e in embeddings], embedding_used_tokens['total_tokens'] if embedding_used_tokens else 0
+        return [list(map(float, e)) for e in embeddings], embedding_used_tokens
     
     def embed_query(self, text: str) -> List[float]:
         """Call out to ZhipuAI's embedding endpoint.

+ 17 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py

@@ -0,0 +1,17 @@
+
+from ._client import ZhipuAI
+
+from .core._errors import (
+    ZhipuAIError,
+    APIStatusError,
+    APIRequestFailedError,
+    APIAuthenticationError,
+    APIReachLimitError,
+    APIInternalError,
+    APIServerFlowExceedError,
+    APIResponseError,
+    APIResponseValidationError,
+    APITimeoutError,
+)
+
+from .__version__ import __version__

+ 2 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py

@@ -0,0 +1,2 @@
+
+__version__ = 'v2.0.1'

+ 71 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py

@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+from typing import Union, Mapping
+
+from typing_extensions import override
+
+from .core import _jwt_token
+from .core._errors import ZhipuAIError
+from .core._http_client import HttpClient, ZHIPUAI_DEFAULT_MAX_RETRIES
+from .core._base_type import NotGiven, NOT_GIVEN
+from . import api_resource
+import os
+import httpx
+from httpx import Timeout
+
+
+class ZhipuAI(HttpClient):
+    chat: api_resource.chat
+    api_key: str
+
+    def __init__(
+            self,
+            *,
+            api_key: str | None = None,
+            base_url: str | httpx.URL | None = None,
+            timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
+            max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
+            http_client: httpx.Client | None = None,
+            custom_headers: Mapping[str, str] | None = None
+    ) -> None:
+        # if api_key is None:
+        #     api_key = os.environ.get("ZHIPUAI_API_KEY")
+        if api_key is None:
+            raise ZhipuAIError("未提供api_key,请通过参数或环境变量提供")
+        self.api_key = api_key
+
+        if base_url is None:
+            base_url = os.environ.get("ZHIPUAI_BASE_URL")
+        if base_url is None:
+            base_url = f"https://open.bigmodel.cn/api/paas/v4"
+        from .__version__ import __version__
+        super().__init__(
+            version=__version__,
+            base_url=base_url,
+            timeout=timeout,
+            custom_httpx_client=http_client,
+            custom_headers=custom_headers,
+        )
+        self.chat = api_resource.chat.Chat(self)
+        self.images = api_resource.images.Images(self)
+        self.embeddings = api_resource.embeddings.Embeddings(self)
+        self.files = api_resource.files.Files(self)
+        self.fine_tuning = api_resource.fine_tuning.FineTuning(self)
+
+    @property
+    @override
+    def _auth_headers(self) -> dict[str, str]:
+        api_key = self.api_key
+        return {"Authorization": f"{_jwt_token.generate_token(api_key)}"}
+
+    def __del__(self) -> None:
+        if (not hasattr(self, "_has_custom_http_client")
+                or not hasattr(self, "close")
+                or not hasattr(self, "_client")):
+            # if the '__init__' method raised an error, self would not have client attr
+            return
+
+        if self._has_custom_http_client:
+            return
+
+        self.close()

+ 5 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py

@@ -0,0 +1,5 @@
+from .chat import chat
+from .images import Images
+from .embeddings import Embeddings
+from .files import Files
+from .fine_tuning import fine_tuning

+ 0 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/__init__.py


+ 87 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py

@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+from typing import Union, List, Optional, TYPE_CHECKING
+
+import httpx
+from typing_extensions import Literal
+
+from ...core._base_api import BaseAPI
+from ...core._base_type import NotGiven, NOT_GIVEN, Headers
+from ...core._http_client import make_user_request_input
+from ...types.chat.async_chat_completion import AsyncTaskStatus, AsyncCompletion
+
+if TYPE_CHECKING:
+    from ..._client import ZhipuAI
+
+
+class AsyncCompletions(BaseAPI):
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+
+
+    def create(
+            self,
+            *,
+            model: str,
+            request_id: Optional[str] | NotGiven = NOT_GIVEN,
+            do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
+            temperature: Optional[float] | NotGiven = NOT_GIVEN,
+            top_p: Optional[float] | NotGiven = NOT_GIVEN,
+            max_tokens: int | NotGiven = NOT_GIVEN,
+            seed: int | NotGiven = NOT_GIVEN,
+            messages: Union[str, List[str], List[int], List[List[int]], None],
+            stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN,
+            sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
+            tools: Optional[object] | NotGiven = NOT_GIVEN,
+            tool_choice: str | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            disable_strict_validation: Optional[bool] | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> AsyncTaskStatus:
+        _cast_type = AsyncTaskStatus
+
+        if disable_strict_validation:
+            _cast_type = object
+        return self._post(
+            "/async/chat/completions",
+            body={
+                "model": model,
+                "request_id": request_id,
+                "temperature": temperature,
+                "top_p": top_p,
+                "do_sample": do_sample,
+                "max_tokens": max_tokens,
+                "seed": seed,
+                "messages": messages,
+                "stop": stop,
+                "sensitive_word_check": sensitive_word_check,
+                "tools": tools,
+                "tool_choice": tool_choice,
+            },
+            options=make_user_request_input(
+                extra_headers=extra_headers, timeout=timeout
+            ),
+            cast_type=_cast_type,
+            enable_stream=False,
+        )
+
+    def retrieve_completion_result(
+        self,
+        id: str,
+        extra_headers: Headers | None = None,
+        disable_strict_validation: Optional[bool] | None = None,
+        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> Union[AsyncCompletion, AsyncTaskStatus]:
+        _cast_type = Union[AsyncCompletion,AsyncTaskStatus]
+        if disable_strict_validation:
+            _cast_type = object
+        return self._get(
+            path=f"/async-result/{id}",
+            cast_type=_cast_type,
+            options=make_user_request_input(
+                extra_headers=extra_headers,
+                timeout=timeout
+            )
+        )
+
+

+ 16 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py

@@ -0,0 +1,16 @@
+from typing import TYPE_CHECKING
+from .completions import Completions
+from .async_completions import AsyncCompletions
+from ...core._base_api import BaseAPI
+
+if TYPE_CHECKING:
+    from ..._client import ZhipuAI
+
+
+class Chat(BaseAPI):
+    completions: Completions
+
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+        self.completions = Completions(client)
+        self.asyncCompletions = AsyncCompletions(client)

+ 71 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py

@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+from typing import Union, List, Optional, TYPE_CHECKING
+
+import httpx
+from typing_extensions import Literal
+
+from ...core._base_api import BaseAPI
+from ...core._base_type import NotGiven, NOT_GIVEN, Headers
+from ...core._http_client import make_user_request_input
+from ...core._sse_client import StreamResponse
+from ...types.chat.chat_completion import Completion
+from ...types.chat.chat_completion_chunk import ChatCompletionChunk
+
+if TYPE_CHECKING:
+    from ..._client import ZhipuAI
+
+
+class Completions(BaseAPI):
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+
+    def create(
+            self,
+            *,
+            model: str,
+            request_id: Optional[str] | NotGiven = NOT_GIVEN,
+            do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
+            stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
+            temperature: Optional[float] | NotGiven = NOT_GIVEN,
+            top_p: Optional[float] | NotGiven = NOT_GIVEN,
+            max_tokens: int | NotGiven = NOT_GIVEN,
+            seed: int | NotGiven = NOT_GIVEN,
+            messages: Union[str, List[str], List[int], object, None],
+            stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN,
+            sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
+            tools: Optional[object] | NotGiven = NOT_GIVEN,
+            tool_choice: str | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            disable_strict_validation: Optional[bool] | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> Completion | StreamResponse[ChatCompletionChunk]:
+        _cast_type = Completion
+        _stream_cls = StreamResponse[ChatCompletionChunk]
+        if disable_strict_validation:
+            _cast_type = object
+            _stream_cls = StreamResponse[object]
+        return self._post(
+            "/chat/completions",
+            body={
+                "model": model,
+                "request_id": request_id,
+                "temperature": temperature,
+                "top_p": top_p,
+                "do_sample": do_sample,
+                "max_tokens": max_tokens,
+                "seed": seed,
+                "messages": messages,
+                "stop": stop,
+                "sensitive_word_check": sensitive_word_check,
+                "stream": stream,
+                "tools": tools,
+                "tool_choice": tool_choice,
+            },
+            options=make_user_request_input(
+                extra_headers=extra_headers,
+            ),
+            cast_type=_cast_type,
+            enable_stream=stream or False,
+            stream_cls=_stream_cls,
+        )

+ 49 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py

@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+from typing import Union, List, Optional, TYPE_CHECKING
+
+import httpx
+
+from ..core._base_api import BaseAPI
+from ..core._base_type import NotGiven, NOT_GIVEN, Headers
+from ..core._http_client import make_user_request_input
+from ..types.embeddings import EmbeddingsResponded
+
+if TYPE_CHECKING:
+    from .._client import ZhipuAI
+
+
+class Embeddings(BaseAPI):
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+
+    def create(
+            self,
+            *,
+            input: Union[str, List[str], List[int], List[List[int]]],
+            model: Union[str],
+            encoding_format: str | NotGiven = NOT_GIVEN,
+            user: str | NotGiven = NOT_GIVEN,
+            sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            disable_strict_validation: Optional[bool] | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> EmbeddingsResponded:
+        _cast_type = EmbeddingsResponded
+        if disable_strict_validation:
+            _cast_type = object
+        return self._post(
+            "/embeddings",
+            body={
+                "input": input,
+                "model": model,
+                "encoding_format": encoding_format,
+                "user": user,
+                "sensitive_word_check": sensitive_word_check,
+            },
+            options=make_user_request_input(
+                extra_headers=extra_headers, timeout=timeout
+            ),
+            cast_type=_cast_type,
+            enable_stream=False,
+        )

+ 78 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py

@@ -0,0 +1,78 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import httpx
+
+from ..core._base_api import BaseAPI
+from ..core._base_type import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
+from ..core._files import is_file_content
+from ..core._http_client import (
+    make_user_request_input,
+)
+from ..types.file_object import FileObject, ListOfFileObject
+
+if TYPE_CHECKING:
+    from .._client import ZhipuAI
+
+__all__ = ["Files"]
+
+
+class Files(BaseAPI):
+
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+
+    def create(
+            self,
+            *,
+            file: FileTypes,
+            purpose: str,
+            extra_headers: Headers | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> FileObject:
+        if not is_file_content(file):
+            prefix = f"Expected file input `{file!r}`"
+            raise RuntimeError(
+                f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(file)} instead."
+            ) from None
+        files = [("file", file)]
+
+        extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
+
+        return self._post(
+            "/files",
+            body={
+                "purpose": purpose,
+            },
+            files=files,
+            options=make_user_request_input(
+                extra_headers=extra_headers, timeout=timeout
+            ),
+            cast_type=FileObject,
+        )
+
+    def list(
+            self,
+            *,
+            purpose: str | NotGiven = NOT_GIVEN,
+            limit: int  | NotGiven = NOT_GIVEN,
+            after: str | NotGiven = NOT_GIVEN,
+            order: str | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> ListOfFileObject:
+        return self._get(
+            "/files",
+            cast_type=ListOfFileObject,
+            options=make_user_request_input(
+                extra_headers=extra_headers,
+                timeout=timeout,
+                query={
+                    "purpose": purpose,
+                    "limit": limit,
+                    "after": after,
+                    "order": order,
+                },
+            ),
+        )

+ 0 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/__init__.py


+ 15 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py

@@ -0,0 +1,15 @@
+from typing import TYPE_CHECKING
+from .jobs import Jobs
+from ...core._base_api import BaseAPI
+
+if TYPE_CHECKING:
+    from ..._client import ZhipuAI
+
+
+class FineTuning(BaseAPI):
+    jobs: Jobs
+
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+        self.jobs = Jobs(client)
+

+ 115 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py

@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+from typing import Optional, TYPE_CHECKING
+
+import httpx
+
+from ...core._base_api import BaseAPI
+from ...core._base_type import NOT_GIVEN, Headers, NotGiven
+from ...core._http_client import (
+    make_user_request_input,
+)
+from ...types.fine_tuning import (
+    FineTuningJob,
+    job_create_params,
+    ListOfFineTuningJob,
+    FineTuningJobEvent,
+)
+
+if TYPE_CHECKING:
+    from ..._client import ZhipuAI
+
+__all__ = ["Jobs"]
+
+
+class Jobs(BaseAPI):
+
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+
+    def create(
+            self,
+            *,
+            model: str,
+            training_file: str,
+            hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN,
+            suffix: Optional[str] | NotGiven = NOT_GIVEN,
+            request_id: Optional[str] | NotGiven = NOT_GIVEN,
+            validation_file: Optional[str] | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> FineTuningJob:
+        return self._post(
+            "/fine_tuning/jobs",
+            body={
+                "model": model,
+                "training_file": training_file,
+                "hyperparameters": hyperparameters,
+                "suffix": suffix,
+                "validation_file": validation_file,
+                "request_id": request_id,
+            },
+            options=make_user_request_input(
+                extra_headers=extra_headers, timeout=timeout
+            ),
+            cast_type=FineTuningJob,
+        )
+
+    def retrieve(
+            self,
+            fine_tuning_job_id: str,
+            *,
+            extra_headers: Headers | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> FineTuningJob:
+        return self._get(
+            f"/fine_tuning/jobs/{fine_tuning_job_id}",
+            options=make_user_request_input(
+                extra_headers=extra_headers, timeout=timeout
+            ),
+            cast_type=FineTuningJob,
+        )
+
+    def list(
+            self,
+            *,
+            after: str | NotGiven = NOT_GIVEN,
+            limit: int | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> ListOfFineTuningJob:
+        return self._get(
+            "/fine_tuning/jobs",
+            cast_type=ListOfFineTuningJob,
+            options=make_user_request_input(
+                extra_headers=extra_headers,
+                timeout=timeout,
+                query={
+                    "after": after,
+                    "limit": limit,
+                },
+            ),
+        )
+
+    def list_events(
+        self,
+        fine_tuning_job_id: str,
+        *,
+        after: str | NotGiven = NOT_GIVEN,
+        limit: int | NotGiven = NOT_GIVEN,
+        extra_headers: Headers | None = None,
+        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> FineTuningJobEvent:
+
+        return self._get(
+            f"/fine_tuning/jobs/{fine_tuning_job_id}/events",
+            cast_type=FineTuningJobEvent,
+            options=make_user_request_input(
+                extra_headers=extra_headers,
+                timeout=timeout,
+                query={
+                    "after": after,
+                    "limit": limit,
+                },
+            ),
+        )

+ 55 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py

@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+from typing import Union, List, Optional, TYPE_CHECKING
+
+import httpx
+
+from ..core._base_api import BaseAPI
+from ..core._base_type import NotGiven, NOT_GIVEN, Headers
+from ..core._http_client import make_user_request_input
+from ..types.image import ImagesResponded
+
+if TYPE_CHECKING:
+    from .._client import ZhipuAI
+
+
+class Images(BaseAPI):
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+
+    def generations(
+            self,
+            *,
+            prompt: str,
+            model: str | NotGiven = NOT_GIVEN,
+            n: Optional[int] | NotGiven = NOT_GIVEN,
+            quality: Optional[str] | NotGiven = NOT_GIVEN,
+            response_format: Optional[str] | NotGiven = NOT_GIVEN,
+            size: Optional[str] | NotGiven = NOT_GIVEN,
+            style: Optional[str] | NotGiven = NOT_GIVEN,
+            user: str | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            disable_strict_validation: Optional[bool] | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> ImagesResponded:
+        _cast_type = ImagesResponded
+        if disable_strict_validation:
+            _cast_type = object
+        return self._post(
+            "/images/generations",
+            body={
+                "prompt": prompt,
+                "model": model,
+                "n": n,
+                "quality": quality,
+                "response_format": response_format,
+                "size": size,
+                "style": style,
+                "user": user,
+            },
+            options=make_user_request_input(
+                extra_headers=extra_headers, timeout=timeout
+            ),
+            cast_type=_cast_type,
+            enable_stream=False,
+        )

+ 0 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/__init__.py


+ 17 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py

@@ -0,0 +1,17 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from .._client import ZhipuAI
+
+
+class BaseAPI:
+    _client: ZhipuAI
+
+    def __init__(self, client: ZhipuAI) -> None:
+        self._client = client
+        self._delete = client.delete
+        self._get = client.get
+        self._post = client.post
+        self._put = client.put
+        self._patch = client.patch

+ 115 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py

@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+from os import PathLike
+from typing import (
+    TYPE_CHECKING,
+    Type,
+    Union,
+    Mapping,
+    TypeVar, IO, Tuple, Sequence, Any, List,
+)
+
+import pydantic
+from typing_extensions import (
+    Literal,
+    override,
+)
+
+
+Query = Mapping[str, object]
+Body = object
+AnyMapping = Mapping[str, object]
+PrimitiveData = Union[str, int, float, bool, None]
+Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"]
+ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
+_T = TypeVar("_T")
+
+if TYPE_CHECKING:
+    NoneType: Type[None]
+else:
+    NoneType = type(None)
+
+
+# Sentinel class used until PEP 0661 is accepted
+class NotGiven(pydantic.BaseModel):
+    """
+    A sentinel singleton class used to distinguish omitted keyword arguments
+    from those passed in with the value None (which may have different behavior).
+
+    For example:
+
+    ```py
+    def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ...
+
+    get(timeout=1) # 1s timeout
+    get(timeout=None) # No timeout
+    get() # Default timeout behavior, which may not be statically known at the method definition.
+    ```
+    """
+
+    def __bool__(self) -> Literal[False]:
+        return False
+
+    @override
+    def __repr__(self) -> str:
+        return "NOT_GIVEN"
+
+
+NotGivenOr = Union[_T, NotGiven]
+NOT_GIVEN = NotGiven()
+
+
+class Omit(pydantic.BaseModel):
+    """In certain situations you need to be able to represent a case where a default value has
+    to be explicitly removed and `None` is not an appropriate substitute, for example:
+
+    ```py
+    # as the default `Content-Type` header is `application/json` that will be sent
+    client.post('/upload/files', files={'file': b'my raw file content'})
+
+    # you can't explicitly override the header as it has to be dynamically generated
+    # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
+    client.post(..., headers={'Content-Type': 'multipart/form-data'})
+
+    # instead you can remove the default `application/json` header by passing Omit
+    client.post(..., headers={'Content-Type': Omit()})
+    ```
+    """
+
+    def __bool__(self) -> Literal[False]:
+        return False
+
+
+Headers = Mapping[str, Union[str, Omit]]
+
+ResponseT = TypeVar(
+    "ResponseT",
+    bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]",
+)
+
+# for user input files
+if TYPE_CHECKING:
+    FileContent = Union[IO[bytes], bytes, PathLike[str]]
+else:
+    FileContent = Union[IO[bytes], bytes, PathLike]
+
+FileTypes = Union[
+    FileContent,  # file content
+    Tuple[str, FileContent],  # (filename, file)
+    Tuple[str, FileContent, str],  # (filename, file , content_type)
+    Tuple[str, FileContent, str, Mapping[str, str]],  # (filename, file , content_type, headers)
+]
+
+RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]
+
+# for httpx client supported files
+
+HttpxFileContent = Union[bytes, IO[bytes]]
+HttpxFileTypes = Union[
+    FileContent,  # file content
+    Tuple[str, HttpxFileContent],  # (filename, file)
+    Tuple[str, HttpxFileContent, str],  # (filename, file , content_type)
+    Tuple[str, HttpxFileContent, str, Mapping[str, str]],  # (filename, file , content_type, headers)
+]
+
+HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]]

+ 90 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py

@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+import httpx
+
+__all__ = [
+    "ZhipuAIError",
+    "APIStatusError",
+    "APIRequestFailedError",
+    "APIAuthenticationError",
+    "APIReachLimitError",
+    "APIInternalError",
+    "APIServerFlowExceedError",
+    "APIResponseError",
+    "APIResponseValidationError",
+    "APITimeoutError",
+]
+
+
+class ZhipuAIError(Exception):
+    def __init__(self, message: str, ) -> None:
+        super().__init__(message)
+
+
+class APIStatusError(Exception):
+    response: httpx.Response
+    status_code: int
+
+    def __init__(self, message: str, *, response: httpx.Response) -> None:
+        super().__init__(message)
+        self.response = response
+        self.status_code = response.status_code
+
+
+class APIRequestFailedError(APIStatusError):
+    ...
+
+
+class APIAuthenticationError(APIStatusError):
+    ...
+
+
+class APIReachLimitError(APIStatusError):
+    ...
+
+
+class APIInternalError(APIStatusError):
+    ...
+
+
+class APIServerFlowExceedError(APIStatusError):
+    ...
+
+
+class APIResponseError(Exception):
+    message: str
+    request: httpx.Request
+    json_data: object
+
+    def __init__(self, message: str, request: httpx.Request, json_data: object):
+        self.message = message
+        self.request = request
+        self.json_data = json_data
+        super().__init__(message)
+
+
+class APIResponseValidationError(APIResponseError):
+    status_code: int
+    response: httpx.Response
+
+    def __init__(
+            self,
+            response: httpx.Response,
+            json_data: object | None, *,
+            message: str | None = None
+    ) -> None:
+        super().__init__(
+            message=message or "Data returned by API invalid for expected schema.",
+            request=response.request,
+            json_data=json_data
+        )
+        self.response = response
+        self.status_code = response.status_code
+
+
+class APITimeoutError(Exception):
+    request: httpx.Request
+
+    def __init__(self, request: httpx.Request):
+        self.request = request
+        super().__init__("Request Timeout")

+ 46 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py

@@ -0,0 +1,46 @@
+from __future__ import annotations
+
+import io
+import os
+from pathlib import Path
+from typing import Mapping, Sequence
+
+from ._base_type import (
+    FileTypes,
+    HttpxFileTypes,
+    HttpxRequestFiles,
+    RequestFiles,
+)
+
+
+def is_file_content(obj: object) -> bool:
+    return isinstance(obj, (bytes, tuple, io.IOBase, os.PathLike))
+
+
+def _transform_file(file: FileTypes) -> HttpxFileTypes:
+    if is_file_content(file):
+        if isinstance(file, os.PathLike):
+            path = Path(file)
+            return path.name, path.read_bytes()
+        else:
+            return file
+    if isinstance(file, tuple):
+        if isinstance(file[1], os.PathLike):
+            return (file[0], Path(file[1]).read_bytes(), *file[2:])
+        else:
+            return (file[0], file[1], *file[2:])
+    else:
+        raise TypeError(f"Unexpected input file with type {type(file)},Expected FileContent type or tuple type")
+
+
+def make_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
+    if files is None:
+        return None
+
+    if isinstance(files, Mapping):
+        files = {key: _transform_file(file) for key, file in files.items()}
+    elif isinstance(files, Sequence):
+        files = [(key, _transform_file(file)) for key, file in files]
+    else:
+        raise TypeError(f"Unexpected input file with type {type(files)}, excepted Mapping or Sequence")
+    return files

+ 377 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py

@@ -0,0 +1,377 @@
+# -*- coding:utf-8 -*-
+from __future__ import annotations
+
+import inspect
+from typing import (
+    Any,
+    Type,
+    Union,
+    cast,
+    Mapping,
+)
+
+import httpx
+import pydantic
+from httpx import URL, Timeout
+
+from . import _errors
+from ._base_type import NotGiven, ResponseT, Body, Headers, NOT_GIVEN, RequestFiles, Query, Data
+from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError
+from ._files import make_httpx_files
+from ._request_opt import ClientRequestParam, UserRequestInput
+from ._response import HttpResponse
+from ._sse_client import StreamResponse
+from ._utils import flatten
+
+headers = {
+    "Accept": "application/json",
+    "Content-Type": "application/json; charset=UTF-8",
+}
+
+
+def _merge_map(map1: Mapping, map2: Mapping) -> Mapping:
+    merged = {**map1, **map2}
+    return {key: val for key, val in merged.items() if val is not None}
+
+
+from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
+
+ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
+ZHIPUAI_DEFAULT_MAX_RETRIES = 3
+ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=50, max_keepalive_connections=10)
+
+
+class HttpClient:
+    _client: httpx.Client
+    _version: str
+    _base_url: URL
+
+    timeout: Union[float, Timeout, None]
+    _limits: httpx.Limits
+    _has_custom_http_client: bool
+    _default_stream_cls: type[StreamResponse[Any]] | None = None
+
+    def __init__(
+            self,
+            *,
+            version: str,
+            base_url: URL,
+            timeout: Union[float, Timeout, None],
+            custom_httpx_client: httpx.Client | None = None,
+            custom_headers: Mapping[str, str] | None = None,
+    ) -> None:
+        if timeout is None or isinstance(timeout, NotGiven):
+            if custom_httpx_client and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT:
+                timeout = custom_httpx_client.timeout
+            else:
+                timeout = ZHIPUAI_DEFAULT_TIMEOUT
+        self.timeout = cast(Timeout, timeout)
+        self._has_custom_http_client = bool(custom_httpx_client)
+        self._client = custom_httpx_client or httpx.Client(
+            base_url=base_url,
+            timeout=self.timeout,
+            limits=ZHIPUAI_DEFAULT_LIMITS,
+        )
+        self._version = version
+        url = URL(url=base_url)
+        if not url.raw_path.endswith(b"/"):
+            url = url.copy_with(raw_path=url.raw_path + b"/")
+        self._base_url = url
+        self._custom_headers = custom_headers or {}
+
+    def _prepare_url(self, url: str) -> URL:
+
+        sub_url = URL(url)
+        if sub_url.is_relative_url:
+            request_raw_url = self._base_url.raw_path + sub_url.raw_path.lstrip(b"/")
+            return self._base_url.copy_with(raw_path=request_raw_url)
+
+        return sub_url
+
+    @property
+    def _default_headers(self):
+        return \
+            {
+                "Accept": "application/json",
+                "Content-Type": "application/json; charset=UTF-8",
+                "ZhipuAI-SDK-Ver": self._version,
+                "source_type": "zhipu-sdk-python",
+                "x-request-sdk": "zhipu-sdk-python",
+                **self._auth_headers,
+                **self._custom_headers,
+            }
+
+    @property
+    def _auth_headers(self):
+        return {}
+
+    def _prepare_headers(self, request_param: ClientRequestParam) -> httpx.Headers:
+        custom_headers = request_param.headers or {}
+        headers_dict = _merge_map(self._default_headers, custom_headers)
+
+        httpx_headers = httpx.Headers(headers_dict)
+
+        return httpx_headers
+
+    def _prepare_request(
+            self,
+            request_param: ClientRequestParam
+    ) -> httpx.Request:
+        kwargs: dict[str, Any] = {}
+        json_data = request_param.json_data
+        headers = self._prepare_headers(request_param)
+        url = self._prepare_url(request_param.url)
+        json_data = request_param.json_data
+        if headers.get("Content-Type") == "multipart/form-data":
+            headers.pop("Content-Type")
+
+            if json_data:
+                kwargs["data"] = self._make_multipartform(json_data)
+
+        return self._client.build_request(
+            headers=headers,
+            timeout=self.timeout if isinstance(request_param.timeout, NotGiven) else request_param.timeout,
+            method=request_param.method,
+            url=url,
+            json=json_data,
+            files=request_param.files,
+            params=request_param.params,
+            **kwargs,
+        )
+
+    def _object_to_formfata(self, key: str, value: Data | Mapping[object, object]) -> list[tuple[str, str]]:
+        items = []
+
+        if isinstance(value, Mapping):
+            for k, v in value.items():
+                items.extend(self._object_to_formfata(f"{key}[{k}]", v))
+            return items
+        if isinstance(value, (list, tuple)):
+            for v in value:
+                items.extend(self._object_to_formfata(key + "[]", v))
+            return items
+
+        def _primitive_value_to_str(val) -> str:
+            # copied from httpx
+            if val is True:
+                return "true"
+            elif val is False:
+                return "false"
+            elif val is None:
+                return ""
+            return str(val)
+
+        str_data = _primitive_value_to_str(value)
+
+        if not str_data:
+            return []
+        return [(key, str_data)]
+
+    def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
+
+        items = flatten([self._object_to_formfata(k, v) for k, v in data.items()])
+
+        serialized: dict[str, object] = {}
+        for key, value in items:
+            if key in serialized:
+                raise ValueError(f"存在重复的键: {key};")
+            serialized[key] = value
+        return serialized
+
+    def _parse_response(
+            self,
+            *,
+            cast_type: Type[ResponseT],
+            response: httpx.Response,
+            enable_stream: bool,
+            request_param: ClientRequestParam,
+            stream_cls: type[StreamResponse[Any]] | None = None,
+    ) -> HttpResponse:
+
+        http_response = HttpResponse(
+            raw_response=response,
+            cast_type=cast_type,
+            client=self,
+            enable_stream=enable_stream,
+            stream_cls=stream_cls
+        )
+        return http_response.parse()
+
+    def _process_response_data(
+            self,
+            *,
+            data: object,
+            cast_type: type[ResponseT],
+            response: httpx.Response,
+    ) -> ResponseT:
+        if data is None:
+            return cast(ResponseT, None)
+
+        try:
+            if inspect.isclass(cast_type) and issubclass(cast_type, pydantic.BaseModel):
+                return cast(ResponseT, cast_type.validate(data))
+
+            return cast(ResponseT, pydantic.TypeAdapter(cast_type).validate_python(data))
+        except pydantic.ValidationError as err:
+            raise APIResponseValidationError(response=response, json_data=data) from err
+
+    def is_closed(self) -> bool:
+        return self._client.is_closed
+
+    def close(self):
+        self._client.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
+    def request(
+            self,
+            *,
+            cast_type: Type[ResponseT],
+            params: ClientRequestParam,
+            enable_stream: bool = False,
+            stream_cls: type[StreamResponse[Any]] | None = None,
+    ) -> ResponseT | StreamResponse:
+        request = self._prepare_request(params)
+
+        try:
+            response = self._client.send(
+                request,
+                stream=enable_stream,
+            )
+            response.raise_for_status()
+        except httpx.TimeoutException as err:
+            raise APITimeoutError(request=request) from err
+        except httpx.HTTPStatusError as err:
+            err.response.read()
+            # raise err
+            raise self._make_status_error(err.response) from None
+
+        except Exception as err:
+            raise err
+
+        return self._parse_response(
+            cast_type=cast_type,
+            request_param=params,
+            response=response,
+            enable_stream=enable_stream,
+            stream_cls=stream_cls,
+        )
+
+    def get(
+            self,
+            path: str,
+            *,
+            cast_type: Type[ResponseT],
+            options: UserRequestInput = {},
+            enable_stream: bool = False,
+    ) -> ResponseT | StreamResponse:
+        opts = ClientRequestParam.construct(method="get", url=path, **options)
+        return self.request(
+            cast_type=cast_type, params=opts,
+            enable_stream=enable_stream
+        )
+
+    def post(
+            self,
+            path: str,
+            *,
+            body: Body | None = None,
+            cast_type: Type[ResponseT],
+            options: UserRequestInput = {},
+            files: RequestFiles | None = None,
+            enable_stream: bool = False,
+            stream_cls: type[StreamResponse[Any]] | None = None,
+    ) -> ResponseT | StreamResponse:
+        opts = ClientRequestParam.construct(method="post", json_data=body, files=make_httpx_files(files), url=path,
+                                            **options)
+
+        return self.request(
+            cast_type=cast_type, params=opts,
+            enable_stream=enable_stream,
+            stream_cls=stream_cls
+        )
+
+    def patch(
+            self,
+            path: str,
+            *,
+            body: Body | None = None,
+            cast_type: Type[ResponseT],
+            options: UserRequestInput = {},
+    ) -> ResponseT:
+        opts = ClientRequestParam.construct(method="patch", url=path, json_data=body, **options)
+
+        return self.request(
+            cast_type=cast_type, params=opts,
+        )
+
+    def put(
+            self,
+            path: str,
+            *,
+            body: Body | None = None,
+            cast_type: Type[ResponseT],
+            options: UserRequestInput = {},
+            files: RequestFiles | None = None,
+    ) -> ResponseT | StreamResponse:
+        opts = ClientRequestParam.construct(method="put", url=path, json_data=body, files=make_httpx_files(files),
+                                            **options)
+
+        return self.request(
+            cast_type=cast_type, params=opts,
+        )
+
+    def delete(
+            self,
+            path: str,
+            *,
+            body: Body | None = None,
+            cast_type: Type[ResponseT],
+            options: UserRequestInput = {},
+    ) -> ResponseT | StreamResponse:
+        opts = ClientRequestParam.construct(method="delete", url=path, json_data=body, **options)
+
+        return self.request(
+            cast_type=cast_type, params=opts,
+        )
+
+    def _make_status_error(self, response) -> APIStatusError:
+        response_text = response.text.strip()
+        status_code = response.status_code
+        error_msg = f"Error code: {status_code}, with error text {response_text}"
+
+        if status_code == 400:
+            return _errors.APIRequestFailedError(message=error_msg, response=response)
+        elif status_code == 401:
+            return _errors.APIAuthenticationError(message=error_msg, response=response)
+        elif status_code == 429:
+            return _errors.APIReachLimitError(message=error_msg, response=response)
+        elif status_code == 500:
+            return _errors.APIInternalError(message=error_msg, response=response)
+        elif status_code == 503:
+            return _errors.APIServerFlowExceedError(message=error_msg, response=response)
+        return APIStatusError(message=error_msg, response=response)
+
+
+def make_user_request_input(
+        max_retries: int | None = None,
+        timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
+        extra_headers: Headers = None,
+        query: Query | None = None,
+) -> UserRequestInput:
+    options: UserRequestInput = {}
+
+    if extra_headers is not None:
+        options["headers"] = extra_headers
+    if max_retries is not None:
+        options["max_retries"] = max_retries
+    if not isinstance(timeout, NotGiven):
+        options['timeout'] = timeout
+    if query is not None:
+        options["params"] = query
+
+    return options

+ 30 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py

@@ -0,0 +1,30 @@
+# -*- coding:utf-8 -*-
+import time
+
+import cachetools.func
+import jwt
+
+API_TOKEN_TTL_SECONDS = 3 * 60
+
+CACHE_TTL_SECONDS = API_TOKEN_TTL_SECONDS - 30
+
+
+@cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS)
+def generate_token(apikey: str):
+    try:
+        api_key, secret = apikey.split(".")
+    except Exception as e:
+        raise Exception("invalid api_key", e)
+
+    payload = {
+        "api_key": api_key,
+        "exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000,
+        "timestamp": int(round(time.time() * 1000)),
+    }
+    ret = jwt.encode(
+        payload,
+        secret,
+        algorithm="HS256",
+        headers={"alg": "HS256", "sign_type": "SIGN"},
+    )
+    return ret

+ 54 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py

@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+from typing import Union, Any, cast
+
+import pydantic.generics
+from httpx import Timeout
+from pydantic import ConfigDict
+from typing_extensions import (
+    Unpack, ClassVar, TypedDict
+)
+
+from ._base_type import Body, NotGiven, Headers, HttpxRequestFiles, Query
+from ._utils import remove_notgiven_indict
+
+
+class UserRequestInput(TypedDict, total=False):
+    max_retries: int
+    timeout: float | Timeout | None
+    headers: Headers
+    params: Query | None
+
+
+class ClientRequestParam():
+    method: str
+    url: str
+    max_retries: Union[int, NotGiven] = NotGiven()
+    timeout: Union[float, NotGiven] = NotGiven()
+    headers: Union[Headers, NotGiven] = NotGiven()
+    json_data: Union[Body, None] = None
+    files: Union[HttpxRequestFiles, None] = None
+    params: Query = {}
+    model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
+
+    def get_max_retries(self, max_retries) -> int:
+        if isinstance(self.max_retries, NotGiven):
+            return max_retries
+        return self.max_retries
+
+    @classmethod
+    def construct(  # type: ignore
+            cls,
+            _fields_set: set[str] | None = None,
+            **values: Unpack[UserRequestInput],
+    ) -> ClientRequestParam :
+        kwargs: dict[str, Any] = {
+            key: remove_notgiven_indict(value) for key, value in values.items()
+        }
+        client = cls()
+        client.__dict__.update(kwargs)
+
+        return client
+
+    model_construct = construct
+

+ 121 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py

@@ -0,0 +1,121 @@
+from __future__ import annotations
+
+import datetime
+from typing import TypeVar, Generic, cast, Any, TYPE_CHECKING
+
+import httpx
+import pydantic
+from typing_extensions import ParamSpec, get_origin, get_args
+
+from ._base_type import NoneType
+from ._sse_client import StreamResponse
+
+if TYPE_CHECKING:
+    from ._http_client import HttpClient
+
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+class HttpResponse(Generic[R]):
+    _cast_type: type[R]
+    _client: "HttpClient"
+    _parsed: R | None
+    _enable_stream: bool
+    _stream_cls: type[StreamResponse[Any]]
+    http_response: httpx.Response
+
+    def __init__(
+            self,
+            *,
+            raw_response: httpx.Response,
+            cast_type: type[R],
+            client: "HttpClient",
+            enable_stream: bool = False,
+            stream_cls: type[StreamResponse[Any]] | None = None,
+    ) -> None:
+        self._cast_type = cast_type
+        self._client = client
+        self._parsed = None
+        self._stream_cls = stream_cls
+        self._enable_stream = enable_stream
+        self.http_response = raw_response
+
+    def parse(self) -> R:
+        self._parsed = self._parse()
+        return self._parsed
+
+    def _parse(self) -> R:
+        if self._enable_stream:
+            self._parsed = cast(
+                R,
+                self._stream_cls(
+                    cast_type=cast(type, get_args(self._stream_cls)[0]),
+                    response=self.http_response,
+                    client=self._client
+                )
+            )
+            return self._parsed
+        cast_type = self._cast_type
+        if cast_type is NoneType:
+            return cast(R, None)
+        http_response = self.http_response
+        if cast_type == str:
+            return cast(R, http_response.text)
+
+        content_type, *_ = http_response.headers.get("content-type", "application/json").split(";")
+        origin = get_origin(cast_type) or cast_type
+        if content_type != "application/json":
+            if issubclass(origin, pydantic.BaseModel):
+                data = http_response.json()
+                return self._client._process_response_data(
+                    data=data,
+                    cast_type=cast_type,  # type: ignore
+                    response=http_response,
+                )
+
+            return http_response.text
+
+        data = http_response.json()
+
+        return self._client._process_response_data(
+            data=data,
+            cast_type=cast_type,  # type: ignore
+            response=http_response,
+        )
+
+    @property
+    def headers(self) -> httpx.Headers:
+        return self.http_response.headers
+
+    @property
+    def http_request(self) -> httpx.Request:
+        return self.http_response.request
+
+    @property
+    def status_code(self) -> int:
+        return self.http_response.status_code
+
+    @property
+    def url(self) -> httpx.URL:
+        return self.http_response.url
+
+    @property
+    def method(self) -> str:
+        return self.http_request.method
+
+    @property
+    def content(self) -> bytes:
+        return self.http_response.content
+
+    @property
+    def text(self) -> str:
+        return self.http_response.text
+
+    @property
+    def http_version(self) -> str:
+        return self.http_response.http_version
+
+    @property
+    def elapsed(self) -> datetime.timedelta:
+        return self.http_response.elapsed

+ 149 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py

@@ -0,0 +1,149 @@
+# -*- coding:utf-8 -*-
+from __future__ import annotations
+
+import json
+from typing import Generic, Iterator, TYPE_CHECKING, Mapping
+
+import httpx
+
+from ._base_type import ResponseT
+from ._errors import APIResponseError
+
+_FIELD_SEPARATOR = ":"
+
+if TYPE_CHECKING:
+    from ._http_client import HttpClient
+
+
+class StreamResponse(Generic[ResponseT]):
+
+    response: httpx.Response
+    _cast_type: type[ResponseT]
+
+    def __init__(
+            self,
+            *,
+            cast_type: type[ResponseT],
+            response: httpx.Response,
+            client: HttpClient,
+    ) -> None:
+        self.response = response
+        self._cast_type = cast_type
+        self._data_process_func = client._process_response_data
+        self._stream_chunks = self.__stream__()
+
+    def __next__(self) -> ResponseT:
+        return self._stream_chunks.__next__()
+
+    def __iter__(self) -> Iterator[ResponseT]:
+        for item in self._stream_chunks:
+            yield item
+
+    def __stream__(self) -> Iterator[ResponseT]:
+
+        sse_line_parser = SSELineParser()
+        iterator = sse_line_parser.iter_lines(self.response.iter_lines())
+
+        for sse in iterator:
+            if sse.data.startswith("[DONE]"):
+                break
+
+            if sse.event is None:
+                data = sse.json_data()
+                if isinstance(data, Mapping) and data.get("error"):
+                    raise APIResponseError(
+                        message="An error occurred during streaming",
+                        request=self.response.request,
+                        json_data=data["error"],
+                    )
+
+                yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response)
+        for sse in iterator:
+            pass
+
+
+class Event(object):
+    def __init__(
+            self,
+            event: str | None = None,
+            data: str | None = None,
+            id: str | None = None,
+            retry: int | None = None
+    ):
+        self._event = event
+        self._data = data
+        self._id = id
+        self._retry = retry
+
+    def __repr__(self):
+        data_len = len(self._data) if self._data else 0
+        return f"Event(event={self._event}, data={self._data} ,data_length={data_len}, id={self._id}, retry={self._retry}"
+
+    @property
+    def event(self): return self._event
+
+    @property
+    def data(self): return self._data
+
+    def json_data(self): return json.loads(self._data)
+
+    @property
+    def id(self): return self._id
+
+    @property
+    def retry(self): return self._retry
+
+
+class SSELineParser:
+    _data: list[str]
+    _event: str | None
+    _retry: int | None
+    _id: str | None
+
+    def __init__(self):
+        self._event = None
+        self._data = []
+        self._id = None
+        self._retry = None
+
+    def iter_lines(self, lines: Iterator[str]) -> Iterator[Event]:
+        for line in lines:
+            line = line.rstrip('\n')
+            if not line:
+                if self._event is None and \
+                        not self._data and \
+                        self._id is None and \
+                        self._retry is None:
+                    continue
+                sse_event = Event(
+                    event=self._event,
+                    data='\n'.join(self._data),
+                    id=self._id,
+                    retry=self._retry
+                )
+                self._event = None
+                self._data = []
+                self._id = None
+                self._retry = None
+
+                yield sse_event
+            self.decode_line(line)
+
+    def decode_line(self, line: str):
+        if line.startswith(":") or not line:
+            return
+
+        field, _p, value = line.partition(":")
+
+        if value.startswith(' '):
+            value = value[1:]
+        if field == "data":
+            self._data.append(value)
+        elif field == "event":
+            self._event = value
+        elif field == "retry":
+            try:
+                self._retry = int(value)
+            except (TypeError, ValueError):
+                pass
+        return

+ 18 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py

@@ -0,0 +1,18 @@
+from __future__ import annotations
+
+from typing import Mapping, Iterable, TypeVar
+
+from ._base_type import NotGiven
+
+
+def remove_notgiven_indict(obj):
+    if obj is None or (not isinstance(obj, Mapping)):
+        return obj
+    return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
+
+
+_T = TypeVar("_T")
+
+
+def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
+    return [item for sublist in t for item in sublist]

+ 0 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/__init__.py


+ 0 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/__init__.py


+ 23 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py

@@ -0,0 +1,23 @@
+from typing import List, Optional
+
+from pydantic import BaseModel
+
+from .chat_completion import CompletionChoice, CompletionUsage
+
+__all__ = ["AsyncTaskStatus"]
+
+
+class AsyncTaskStatus(BaseModel):
+    id: Optional[str] = None
+    request_id: Optional[str] = None
+    model: Optional[str] = None
+    task_status: Optional[str] = None
+
+
+class AsyncCompletion(BaseModel):
+    id: Optional[str] = None
+    request_id: Optional[str] = None
+    model: Optional[str] = None
+    task_status: str
+    choices: List[CompletionChoice]
+    usage: CompletionUsage

+ 45 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py

@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from pydantic import BaseModel
+
+__all__ = ["Completion", "CompletionUsage"]
+
+
+class Function(BaseModel):
+    arguments: str
+    name: str
+
+
+class CompletionMessageToolCall(BaseModel):
+    id: str
+    function: Function
+    type: str
+
+
+class CompletionMessage(BaseModel):
+    content: Optional[str] = None
+    role: str
+    tool_calls: Optional[List[CompletionMessageToolCall]] = None
+
+
+class CompletionUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
+class CompletionChoice(BaseModel):
+    index: int
+    finish_reason: str
+    message: CompletionMessage
+
+
+class Completion(BaseModel):
+    model: Optional[str] = None
+    created: Optional[int] = None
+    choices: List[CompletionChoice]
+    request_id: Optional[str] = None
+    id: Optional[str] = None
+    usage: CompletionUsage
+
+

+ 55 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py

@@ -0,0 +1,55 @@
+from typing import List, Optional
+
+from pydantic import BaseModel
+
+__all__ = [
+    "ChatCompletionChunk",
+    "Choice",
+    "ChoiceDelta",
+    "ChoiceDeltaFunctionCall",
+    "ChoiceDeltaToolCall",
+    "ChoiceDeltaToolCallFunction",
+]
+
+
+class ChoiceDeltaFunctionCall(BaseModel):
+    arguments: Optional[str] = None
+    name: Optional[str] = None
+
+
+class ChoiceDeltaToolCallFunction(BaseModel):
+    arguments: Optional[str] = None
+    name: Optional[str] = None
+
+
+class ChoiceDeltaToolCall(BaseModel):
+    index: int
+    id: Optional[str] = None
+    function: Optional[ChoiceDeltaToolCallFunction] = None
+    type: Optional[str] = None
+
+
+class ChoiceDelta(BaseModel):
+    content: Optional[str] = None
+    role: Optional[str] = None
+    tool_calls: Optional[List[ChoiceDeltaToolCall]] = None
+
+
+class Choice(BaseModel):
+    delta: ChoiceDelta
+    finish_reason: Optional[str] = None
+    index: int
+
+
+class CompletionUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
+class ChatCompletionChunk(BaseModel):
+    id: Optional[str] = None
+    choices: List[Choice]
+    created: Optional[int] = None
+    model: Optional[str] = None
+    usage: Optional[CompletionUsage] = None

+ 8 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completions_create_param.py

@@ -0,0 +1,8 @@
+from typing import Optional
+
+from typing_extensions import TypedDict
+
+
+class Reference(TypedDict, total=False):
+    enable: Optional[bool]
+    search_query: Optional[str]

+ 20 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py

@@ -0,0 +1,20 @@
+from __future__ import annotations
+
+from typing import Optional, List
+
+from pydantic import BaseModel
+from .chat.chat_completion import CompletionUsage
+__all__ = ["Embedding", "EmbeddingsResponded"]
+
+
+class Embedding(BaseModel):
+    object: str
+    index: Optional[int] = None
+    embedding: List[float]
+
+
+class EmbeddingsResponded(BaseModel):
+    object: str
+    data: List[Embedding]
+    model: str
+    usage: CompletionUsage

+ 24 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py

@@ -0,0 +1,24 @@
+from typing import Optional, List
+
+from pydantic import BaseModel
+
+__all__ = ["FileObject"]
+
+
+class FileObject(BaseModel):
+
+    id: Optional[str] = None
+    bytes: Optional[int] = None
+    created_at: Optional[int] = None
+    filename: Optional[str] = None
+    object: Optional[str] = None
+    purpose: Optional[str] = None
+    status: Optional[str] = None
+    status_details: Optional[str] = None
+
+
+class ListOfFileObject(BaseModel):
+
+    object: Optional[str] = None
+    data: List[FileObject]
+    has_more: Optional[bool] = None

+ 5 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py

@@ -0,0 +1,5 @@
+from __future__ import annotations
+
+from .fine_tuning_job import FineTuningJob as FineTuningJob
+from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob
+from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent

+ 52 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py

@@ -0,0 +1,52 @@
+from typing import List, Union, Optional
+from typing_extensions import Literal
+
+from pydantic import BaseModel
+
+__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob" ]
+
+
+class Error(BaseModel):
+    code: str
+    message: str
+    param: Optional[str] = None
+
+
+class Hyperparameters(BaseModel):
+    n_epochs: Union[str, int, None] = None
+
+
+class FineTuningJob(BaseModel):
+    id: Optional[str] = None
+
+    request_id: Optional[str] = None
+
+    created_at: Optional[int] = None
+
+    error: Optional[Error] = None
+
+    fine_tuned_model: Optional[str] = None
+
+    finished_at: Optional[int] = None
+
+    hyperparameters: Optional[Hyperparameters] = None
+
+    model: Optional[str] = None
+
+    object: Optional[str] = None
+
+    result_files: List[str]
+
+    status: str
+
+    trained_tokens: Optional[int] = None
+
+    training_file: str
+
+    validation_file: Optional[str] = None
+
+
+class ListOfFineTuningJob(BaseModel):
+    object: Optional[str] = None
+    data: List[FineTuningJob]
+    has_more: Optional[bool] = None

+ 36 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py

@@ -0,0 +1,36 @@
+from typing import List, Union, Optional
+from typing_extensions import Literal
+
+from pydantic import BaseModel
+
+__all__ = ["FineTuningJobEvent", "Metric", "JobEvent"]
+
+
+class Metric(BaseModel):
+    epoch: Optional[Union[str, int, float]] = None
+    current_steps: Optional[int] = None
+    total_steps: Optional[int] = None
+    elapsed_time: Optional[str] = None
+    remaining_time: Optional[str] = None
+    trained_tokens: Optional[int] = None
+    loss: Optional[Union[str, int, float]] = None
+    eval_loss: Optional[Union[str, int, float]] = None
+    acc: Optional[Union[str, int, float]] = None
+    eval_acc: Optional[Union[str, int, float]] = None
+    learning_rate: Optional[Union[str, int, float]] = None
+
+
+class JobEvent(BaseModel):
+    object: Optional[str] = None
+    id: Optional[str] = None
+    type: Optional[str] = None
+    created_at: Optional[int] = None
+    level: Optional[str] = None
+    message: Optional[str] = None
+    data: Optional[Metric] = None
+
+
+class FineTuningJobEvent(BaseModel):
+    object: Optional[str] = None
+    data: List[JobEvent]
+    has_more: Optional[bool] = None

+ 15 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py

@@ -0,0 +1,15 @@
+from __future__ import annotations
+
+from typing import Union
+
+from typing_extensions import Literal, TypedDict
+
+__all__ = ["Hyperparameters"]
+
+
+class Hyperparameters(TypedDict, total=False):
+    batch_size: Union[Literal["auto"], int]
+
+    learning_rate_multiplier: Union[Literal["auto"], float]
+
+    n_epochs: Union[Literal["auto"], int]

+ 18 - 0
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py

@@ -0,0 +1,18 @@
+from __future__ import annotations
+
+from typing import Optional, List
+
+from pydantic import BaseModel
+
+__all__ = ["GeneratedImage", "ImagesResponded"]
+
+
+class GeneratedImage(BaseModel):
+    b64_json: Optional[str] = None
+    url: Optional[str] = None
+    revised_prompt: Optional[str] = None
+
+
+class ImagesResponded(BaseModel):
+    created: int
+    data: List[GeneratedImage]

+ 47 - 1
api/tests/integration_tests/model_runtime/zhipuai/test_llm.py

@@ -3,7 +3,8 @@ from typing import Generator
 
 import pytest
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
-from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
+from core.model_runtime.entities.message_entities import (AssistantPromptMessage, SystemPromptMessage, 
+                                                          UserPromptMessage, PromptMessageTool)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.zhipuai.llm.llm import ZhipuAILargeLanguageModel
 
@@ -102,3 +103,48 @@ def test_get_num_tokens():
     )
 
     assert num_tokens == 14
+
+def test_get_tools_num_tokens():
+    model = ZhipuAILargeLanguageModel()
+
+    num_tokens = model.get_num_tokens(
+        model='tools',
+        credentials={
+            'api_key': os.environ.get('ZHIPUAI_API_KEY')
+        },
+        tools=[
+            PromptMessageTool(
+                name='get_current_weather',
+                description='Get the current weather in a given location',
+                parameters={
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                        "type": "string",
+                            "description": "The city and state e.g. San Francisco, CA"
+                        },
+                        "unit": {
+                            "type": "string",
+                            "enum": [
+                                "c",
+                                "f"
+                            ]
+                        }
+                    },
+                    "required": [
+                        "location"
+                    ]
+                }
+            )
+        ],
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ]
+    )
+
+    assert num_tokens == 108

+ 1 - 1
api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py

@@ -42,7 +42,7 @@ def test_invoke_model():
 
     assert isinstance(result, TextEmbeddingResult)
     assert len(result.embeddings) == 2
-    assert result.usage.total_tokens == 2
+    assert result.usage.total_tokens > 0
 
 
 def test_get_num_tokens():

+ 1 - 1
web/app/components/app/chat/answer/index.tsx

@@ -229,7 +229,7 @@ const Answer: FC<IAnswerProps> = ({
             <Thought
               thought={item}
               allToolIcons={allToolIcons || {}}
-              isFinished={!!item.observation}
+              isFinished={!!item.observation || !isResponsing}
             />
           )}
 

+ 2 - 3
web/app/components/app/configuration/index.tsx

@@ -43,7 +43,7 @@ import { fetchDatasets } from '@/service/datasets'
 import { useProviderContext } from '@/context/provider-context'
 import { AgentStrategy, AppType, ModelModeType, RETRIEVE_TYPE, Resolution, TransferMethod } from '@/types/app'
 import { PromptMode } from '@/models/debug'
-import { ANNOTATION_DEFAULT, DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/config'
+import { ANNOTATION_DEFAULT, DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG, supportFunctionCallModels } from '@/config'
 import SelectDataSet from '@/app/components/app/configuration/dataset-config/select-dataset'
 import I18n from '@/context/i18n'
 import { useModalContext } from '@/context/modal-context'
@@ -163,8 +163,7 @@ const Configuration: FC = () => {
     doSetModelConfig(newModelConfig)
   }
   const isOpenAI = modelConfig.provider === 'openai'
-  const isFunctionCall = isOpenAI && modelConfig.mode === ModelModeType.chat
-
+  const isFunctionCall = (isOpenAI && modelConfig.mode === ModelModeType.chat) || supportFunctionCallModels.includes(modelConfig.model_id)
   const [collectionList, setCollectionList] = useState<Collection[]>([])
   useEffect(() => {
 

+ 2 - 0
web/config/index.ts

@@ -160,6 +160,8 @@ export const DEFAULT_AGENT_SETTING = {
   tools: [],
 }
 
+export const supportFunctionCallModels = ['glm-3-turbo', 'glm-4']
+
 export const DEFAULT_AGENT_PROMPT = {
   chat: `Respond to the human as helpfully and accurately as possible.