Browse Source

feat: add zhipuai (#1188)

takatost 1 year ago
parent
commit
827c97f0d3
36 changed files with 1089 additions and 113 deletions
  1. 8 2
      api/controllers/console/workspace/model_providers.py
  2. 12 1
      api/core/callback_handler/llm_callback_handler.py
  3. 7 6
      api/core/chain/sensitive_word_avoidance_chain.py
  4. 51 30
      api/core/completion.py
  5. 19 17
      api/core/helper/moderation.py
  6. 3 0
      api/core/model_providers/model_provider_factory.py
  7. 22 0
      api/core/model_providers/models/embedding/zhipuai_embedding.py
  8. 1 0
      api/core/model_providers/models/entity/model_params.py
  9. 61 0
      api/core/model_providers/models/llm/zhipuai_model.py
  10. 10 6
      api/core/model_providers/models/moderation/openai_moderation.py
  11. 3 3
      api/core/model_providers/providers/anthropic_provider.py
  12. 5 5
      api/core/model_providers/providers/azure_openai_provider.py
  13. 3 3
      api/core/model_providers/providers/chatglm_provider.py
  14. 18 0
      api/core/model_providers/providers/hosted.py
  15. 3 3
      api/core/model_providers/providers/huggingface_hub_provider.py
  16. 3 3
      api/core/model_providers/providers/localai_provider.py
  17. 3 3
      api/core/model_providers/providers/minimax_provider.py
  18. 5 5
      api/core/model_providers/providers/openai_provider.py
  19. 5 5
      api/core/model_providers/providers/openllm_provider.py
  20. 2 0
      api/core/model_providers/providers/replicate_provider.py
  21. 2 2
      api/core/model_providers/providers/spark_provider.py
  22. 2 2
      api/core/model_providers/providers/tongyi_provider.py
  23. 2 2
      api/core/model_providers/providers/wenxin_provider.py
  24. 11 11
      api/core/model_providers/providers/xinference_provider.py
  25. 176 0
      api/core/model_providers/providers/zhipuai_provider.py
  26. 1 0
      api/core/model_providers/rules/_providers.json
  27. 44 0
      api/core/model_providers/rules/zhipuai.json
  28. 64 0
      api/core/third_party/langchain/embeddings/zhipuai_embedding.py
  29. 315 0
      api/core/third_party/langchain/llms/zhipuai_llm.py
  30. 2 1
      api/requirements.txt
  31. 5 2
      api/services/provider_service.py
  32. 3 0
      api/tests/integration_tests/.env.example
  33. 50 0
      api/tests/integration_tests/models/embedding/test_zhipuai_embedding.py
  34. 79 0
      api/tests/integration_tests/models/llm/test_zhipuai_model.py
  35. 1 1
      api/tests/unit_tests/model_providers/test_spark_provider.py
  36. 88 0
      api/tests/unit_tests/model_providers/test_zhipuai_provider.py

+ 8 - 2
api/controllers/console/workspace/model_providers.py

@@ -246,7 +246,8 @@ class ModelProviderModelParameterRuleApi(Resource):
                 'enabled': v.enabled,
                 'min': v.min,
                 'max': v.max,
-                'default': v.default
+                'default': v.default,
+                'precision': v.precision
             }
             for k, v in vars(parameter_rules).items()
         }
@@ -290,10 +291,15 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, provider_name: str):
+        parser = reqparse.RequestParser()
+        parser.add_argument('token', type=str, required=False, nullable=True, location='args')
+        args = parser.parse_args()
+
         provider_service = ProviderService()
         result = provider_service.free_quota_qualification_verify(
             tenant_id=current_user.current_tenant_id,
-            provider_name=provider_name
+            provider_name=provider_name,
+            token=args['token']
         )
 
         return result

+ 12 - 1
api/core/callback_handler/llm_callback_handler.py

@@ -63,7 +63,18 @@ class LLMCallbackHandler(BaseCallbackHandler):
             self.conversation_message_task.append_message_text(response.generations[0][0].text)
             self.llm_message.completion = response.generations[0][0].text
 
-        self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)])
+        if response.llm_output and 'token_usage' in response.llm_output:
+            if 'prompt_tokens' in response.llm_output['token_usage']:
+                self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
+
+            if 'completion_tokens' in response.llm_output['token_usage']:
+                self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
+            else:
+                self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
+                    [PromptMessage(content=self.llm_message.completion)])
+        else:
+            self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
+                [PromptMessage(content=self.llm_message.completion)])
 
         self.conversation_message_task.save_message(self.llm_message)
 

+ 7 - 6
api/core/chain/sensitive_word_avoidance_chain.py

@@ -2,13 +2,8 @@ import enum
 import logging
 from typing import List, Dict, Optional, Any
 
-import openai
-from flask import current_app
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
-from openai import InvalidRequestError
-from openai.error import APIConnectionError, APIError, ServiceUnavailableError, Timeout, RateLimitError, \
-    AuthenticationError, OpenAIError
 from pydantic import BaseModel
 
 from core.model_providers.error import LLMBadRequestError
@@ -86,6 +81,12 @@ class SensitiveWordAvoidanceChain(Chain):
             result = self._check_moderation(text)
 
         if not result:
-            raise LLMBadRequestError(self.sensitive_word_avoidance_rule.canned_response)
+            raise SensitiveWordAvoidanceError(self.sensitive_word_avoidance_rule.canned_response)
 
         return {self.output_key: text}
+
+
+class SensitiveWordAvoidanceError(Exception):
+    def __init__(self, message):
+        super().__init__(message)
+        self.message = message

+ 51 - 30
api/core/completion.py

@@ -7,6 +7,7 @@ from requests.exceptions import ChunkedEncodingError
 from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
+from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
 from core.model_providers.error import LLMBadRequestError
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
@@ -76,28 +77,53 @@ class Completion:
             app_model_config=app_model_config
         )
 
-        # parse sensitive_word_avoidance_chain
-        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
-        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(final_model_instance, [chain_callback])
-        if sensitive_word_avoidance_chain:
-            query = sensitive_word_avoidance_chain.run(query)
-
-        # get agent executor
-        agent_executor = orchestrator_rule_parser.to_agent_executor(
-            conversation_message_task=conversation_message_task,
-            memory=memory,
-            rest_tokens=rest_tokens_for_context_and_memory,
-            chain_callback=chain_callback
-        )
-
-        # run agent executor
-        agent_execute_result = None
-        if agent_executor:
-            should_use_agent = agent_executor.should_use_agent(query)
-            if should_use_agent:
-                agent_execute_result = agent_executor.run(query)
-        # run the final llm
         try:
+            # parse sensitive_word_avoidance_chain
+            chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
+            sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(
+                final_model_instance, [chain_callback])
+            if sensitive_word_avoidance_chain:
+                try:
+                    query = sensitive_word_avoidance_chain.run(query)
+                except SensitiveWordAvoidanceError as ex:
+                    cls.run_final_llm(
+                        model_instance=final_model_instance,
+                        mode=app.mode,
+                        app_model_config=app_model_config,
+                        query=query,
+                        inputs=inputs,
+                        agent_execute_result=None,
+                        conversation_message_task=conversation_message_task,
+                        memory=memory,
+                        fake_response=ex.message
+                    )
+                    return
+
+            # get agent executor
+            agent_executor = orchestrator_rule_parser.to_agent_executor(
+                conversation_message_task=conversation_message_task,
+                memory=memory,
+                rest_tokens=rest_tokens_for_context_and_memory,
+                chain_callback=chain_callback,
+                retriever_from=retriever_from
+            )
+
+            # run agent executor
+            agent_execute_result = None
+            if agent_executor:
+                should_use_agent = agent_executor.should_use_agent(query)
+                if should_use_agent:
+                    agent_execute_result = agent_executor.run(query)
+
+            # When no extra pre prompt is specified,
+            # the output of the agent can be used directly as the main output content without calling LLM again
+            fake_response = None
+            if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
+                    and agent_execute_result.strategy not in [PlanningStrategy.ROUTER,
+                                                              PlanningStrategy.REACT_ROUTER]:
+                fake_response = agent_execute_result.output
+
+            # run the final llm
             cls.run_final_llm(
                 model_instance=final_model_instance,
                 mode=app.mode,
@@ -106,7 +132,8 @@ class Completion:
                 inputs=inputs,
                 agent_execute_result=agent_execute_result,
                 conversation_message_task=conversation_message_task,
-                memory=memory
+                memory=memory,
+                fake_response=fake_response
             )
         except ConversationTaskStoppedException:
             return
@@ -121,14 +148,8 @@ class Completion:
                       inputs: dict,
                       agent_execute_result: Optional[AgentExecuteResult],
                       conversation_message_task: ConversationMessageTask,
-                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
-        # When no extra pre prompt is specified,
-        # the output of the agent can be used directly as the main output content without calling LLM again
-        fake_response = None
-        if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
-                and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, PlanningStrategy.REACT_ROUTER]:
-            fake_response = agent_execute_result.output
-
+                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
+                      fake_response: Optional[str]):
         # get llm prompt
         prompt_messages, stop_words = model_instance.get_prompt(
             mode=mode,

+ 19 - 17
api/core/helper/moderation.py

@@ -1,32 +1,34 @@
 import logging
 
 import openai
-from flask import current_app
 
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.providers.base import BaseModelProvider
+from core.model_providers.providers.hosted import hosted_config, hosted_model_providers
 from models.provider import ProviderType
 
 
 def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
-    if current_app.config['HOSTED_MODERATION_ENABLED'] and current_app.config['HOSTED_MODERATION_PROVIDERS']:
-        moderation_providers = current_app.config['HOSTED_MODERATION_PROVIDERS'].split(',')
-
+    if hosted_config.moderation.enabled is True and hosted_model_providers.openai:
         if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
-                and model_provider.provider_name in moderation_providers:
+                and model_provider.provider_name in hosted_config.moderation.providers:
             # 2000 text per chunk
             length = 2000
-            chunks = [text[i:i + length] for i in range(0, len(text), length)]
-
-            try:
-                moderation_result = openai.Moderation.create(input=chunks,
-                                                             api_key=current_app.config['HOSTED_OPENAI_API_KEY'])
-            except Exception as ex:
-                logging.exception(ex)
-                raise LLMBadRequestError('Rate limit exceeded, please try again later.')
-
-            for result in moderation_result.results:
-                if result['flagged'] is True:
-                    return False
+            text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
+
+            max_text_chunks = 32
+            chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
+
+            for text_chunk in chunks:
+                try:
+                    moderation_result = openai.Moderation.create(input=text_chunk,
+                                                                 api_key=hosted_model_providers.openai.api_key)
+                except Exception as ex:
+                    logging.exception(ex)
+                    raise LLMBadRequestError('Rate limit exceeded, please try again later.')
+
+                for result in moderation_result.results:
+                    if result['flagged'] is True:
+                        return False
 
     return True

+ 3 - 0
api/core/model_providers/model_provider_factory.py

@@ -45,6 +45,9 @@ class ModelProviderFactory:
         elif provider_name == 'wenxin':
             from core.model_providers.providers.wenxin_provider import WenxinProvider
             return WenxinProvider
+        elif provider_name == 'zhipuai':
+            from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
+            return ZhipuAIProvider
         elif provider_name == 'chatglm':
             from core.model_providers.providers.chatglm_provider import ChatGLMProvider
             return ChatGLMProvider

+ 22 - 0
api/core/model_providers/models/embedding/zhipuai_embedding.py

@@ -0,0 +1,22 @@
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.providers.base import BaseModelProvider
+from core.model_providers.models.embedding.base import BaseEmbedding
+from core.third_party.langchain.embeddings.zhipuai_embedding import ZhipuAIEmbeddings
+
+
+class ZhipuAIEmbedding(BaseEmbedding):
+    def __init__(self, model_provider: BaseModelProvider, name: str):
+        credentials = model_provider.get_model_credentials(
+            model_name=name,
+            model_type=self.type
+        )
+
+        client = ZhipuAIEmbeddings(
+            model=name,
+            **credentials,
+        )
+
+        super().__init__(model_provider, client, name)
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        return LLMBadRequestError(f"ZhipuAI embedding: {str(ex)}")

+ 1 - 0
api/core/model_providers/models/entity/model_params.py

@@ -49,6 +49,7 @@ class KwargRule(Generic[T], BaseModel):
     max: Optional[T] = None
     default: Optional[T] = None
     alias: Optional[str] = None
+    precision: Optional[int] = None
 
 
 class ModelKwargsRules(BaseModel):

+ 61 - 0
api/core/model_providers/models/llm/zhipuai_model.py

@@ -0,0 +1,61 @@
+from typing import List, Optional, Any
+
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import LLMResult
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
+
+
+class ZhipuAIModel(BaseLLM):
+    model_mode: ModelMode = ModelMode.CHAT
+
+    def _init_client(self) -> Any:
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+        return ZhipuAIChatLLM(
+            streaming=self.streaming,
+            callbacks=self.callbacks,
+            **self.credentials,
+            **provider_model_kwargs
+        )
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        run predict by prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.generate([prompts], stop, callbacks)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        get num tokens of prompt messages.
+
+        :param messages:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return max(self._client.get_num_tokens_from_messages(prompts), 0)
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
+        for k, v in provider_model_kwargs.items():
+            if hasattr(self.client, k):
+                setattr(self.client, k, v)
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        return LLMBadRequestError(f"ZhipuAI: {str(ex)}")
+
+    @property
+    def support_streaming(self):
+        return True

+ 10 - 6
api/core/model_providers/models/moderation/openai_moderation.py

@@ -23,14 +23,18 @@ class OpenAIModeration(BaseModeration):
 
         # 2000 text per chunk
         length = 2000
-        chunks = [text[i:i + length] for i in range(0, len(text), length)]
+        text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
 
-        moderation_result = self._client.create(input=chunks,
-                                                api_key=credentials['openai_api_key'])
+        max_text_chunks = 32
+        chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
 
-        for result in moderation_result.results:
-            if result['flagged'] is True:
-                return False
+        for text_chunk in chunks:
+            moderation_result = self._client.create(input=text_chunk,
+                                                    api_key=credentials['openai_api_key'])
+
+            for result in moderation_result.results:
+                if result['flagged'] is True:
+                    return False
 
         return True
 

+ 3 - 3
api/core/model_providers/providers/anthropic_provider.py

@@ -69,11 +69,11 @@ class AnthropicProvider(BaseModelProvider):
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=1, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=0.7),
+            temperature=KwargRule[float](min=0, max=1, default=1, precision=2),
+            top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256),
+            max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256, precision=0),
         )
 
     @classmethod

+ 5 - 5
api/core/model_providers/providers/azure_openai_provider.py

@@ -164,14 +164,14 @@ class AzureOpenAIProvider(BaseModelProvider):
         model_credentials = self.get_model_credentials(model_name, model_type)
 
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=1),
-            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
-            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
+            temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
+            top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
+            presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
+            frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
             max_tokens=KwargRule[int](min=10, max=base_model_max_tokens.get(
                 model_credentials['base_model_name'],
                 4097
-            ), default=16),
+            ), default=16, precision=0),
         )
 
     @classmethod

+ 3 - 3
api/core/model_providers/providers/chatglm_provider.py

@@ -64,11 +64,11 @@ class ChatGLMProvider(BaseModelProvider):
         }
 
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=0.7),
+            temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
+            top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048),
+            max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
         )
 
     @classmethod

+ 18 - 0
api/core/model_providers/providers/hosted.py

@@ -45,6 +45,18 @@ class HostedModelProviders(BaseModel):
 hosted_model_providers = HostedModelProviders()
 
 
+class HostedModerationConfig(BaseModel):
+    enabled: bool = False
+    providers: list[str] = []
+
+
+class HostedConfig(BaseModel):
+    moderation = HostedModerationConfig()
+
+
+hosted_config = HostedConfig()
+
+
 def init_app(app: Flask):
     if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
         langchain.verbose = True
@@ -78,3 +90,9 @@ def init_app(app: Flask):
             paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
             paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
         )
+
+    if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"):
+        hosted_config.moderation = HostedModerationConfig(
+            enabled=app.config.get("HOSTED_MODERATION_ENABLED"),
+            providers=app.config.get("HOSTED_MODERATION_PROVIDERS").split(',')
+        )

+ 3 - 3
api/core/model_providers/providers/huggingface_hub_provider.py

@@ -47,11 +47,11 @@ class HuggingfaceHubProvider(BaseModelProvider):
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
-            top_p=KwargRule[float](min=0.01, max=0.99, default=0.7),
+            temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
+            top_p=KwargRule[float](min=0.01, max=0.99, default=0.7, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200),
+            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200, precision=0),
         )
 
     @classmethod

+ 3 - 3
api/core/model_providers/providers/localai_provider.py

@@ -52,9 +52,9 @@ class LocalAIProvider(BaseModelProvider):
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=0.7),
-            top_p=KwargRule[float](min=0, max=1, default=1),
-            max_tokens=KwargRule[int](min=10, max=4097, default=16),
+            temperature=KwargRule[float](min=0, max=2, default=0.7, precision=2),
+            top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
+            max_tokens=KwargRule[int](min=10, max=4097, default=16, precision=0),
         )
 
     @classmethod

+ 3 - 3
api/core/model_providers/providers/minimax_provider.py

@@ -74,11 +74,11 @@ class MinimaxProvider(BaseModelProvider):
         }
 
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0.01, max=1, default=0.9),
-            top_p=KwargRule[float](min=0, max=1, default=0.95),
+            temperature=KwargRule[float](min=0.01, max=1, default=0.9, precision=2),
+            top_p=KwargRule[float](min=0, max=1, default=0.95, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024),
+            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024, precision=0),
         )
 
     @classmethod

+ 5 - 5
api/core/model_providers/providers/openai_provider.py

@@ -133,11 +133,11 @@ class OpenAIProvider(BaseModelProvider):
         }
 
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=1),
-            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
-            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16),
+            temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
+            top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
+            presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
+            frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
+            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16, precision=0),
         )
 
     @classmethod

+ 5 - 5
api/core/model_providers/providers/openllm_provider.py

@@ -45,11 +45,11 @@ class OpenLLMProvider(BaseModelProvider):
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0.01, max=2, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=0.7),
-            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
-            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
+            temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
+            top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
+            presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
+            frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
+            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128, precision=0),
         )
 
     @classmethod

+ 2 - 0
api/core/model_providers/providers/replicate_provider.py

@@ -72,6 +72,7 @@ class ReplicateProvider(BaseModelProvider):
                         min=float(value.get('minimum')) if value.get('minimum') is not None else None,
                         max=float(value.get('maximum')) if value.get('maximum') is not None else None,
                         default=float(value.get('default')) if value.get('default') is not None else None,
+                        precision = 2
                     )
                     if key == 'temperature':
                         model_kwargs_rules.temperature = kwarg_rule
@@ -84,6 +85,7 @@ class ReplicateProvider(BaseModelProvider):
                         min=int(value.get('minimum')) if value.get('minimum') is not None else 1,
                         max=int(value.get('maximum')) if value.get('maximum') is not None else 8000,
                         default=int(value.get('default')) if value.get('default') is not None else 500,
+                        precision = 0
                     )
 
         return model_kwargs_rules

+ 2 - 2
api/core/model_providers/providers/spark_provider.py

@@ -62,11 +62,11 @@ class SparkProvider(BaseModelProvider):
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=1, default=0.5),
+            temperature=KwargRule[float](min=0, max=1, default=0.5, precision=2),
             top_p=KwargRule[float](enabled=False),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](min=10, max=4096, default=2048),
+            max_tokens=KwargRule[int](min=10, max=4096, default=2048, precision=0),
         )
 
     @classmethod

+ 2 - 2
api/core/model_providers/providers/tongyi_provider.py

@@ -64,10 +64,10 @@ class TongyiProvider(BaseModelProvider):
 
         return ModelKwargsRules(
             temperature=KwargRule[float](enabled=False),
-            top_p=KwargRule[float](min=0, max=1, default=0.8),
+            top_p=KwargRule[float](min=0, max=1, default=0.8, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024),
+            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024, precision=0),
         )
 
     @classmethod

+ 2 - 2
api/core/model_providers/providers/wenxin_provider.py

@@ -63,8 +63,8 @@ class WenxinProvider(BaseModelProvider):
         """
         if model_name in ['ernie-bot', 'ernie-bot-turbo']:
             return ModelKwargsRules(
-                temperature=KwargRule[float](min=0.01, max=1, default=0.95),
-                top_p=KwargRule[float](min=0.01, max=1, default=0.8),
+                temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
+                top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
                 presence_penalty=KwargRule[float](enabled=False),
                 frequency_penalty=KwargRule[float](enabled=False),
                 max_tokens=KwargRule[int](enabled=False),

+ 11 - 11
api/core/model_providers/providers/xinference_provider.py

@@ -53,27 +53,27 @@ class XinferenceProvider(BaseModelProvider):
         credentials = self.get_model_credentials(model_name, model_type)
         if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
             return ModelKwargsRules(
-                temperature=KwargRule[float](min=0.01, max=2, default=1),
-                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
+                top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
                 presence_penalty=KwargRule[float](enabled=False),
                 frequency_penalty=KwargRule[float](enabled=False),
-                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
             )
         elif credentials['model_format'] == "ggmlv3":
             return ModelKwargsRules(
-                temperature=KwargRule[float](min=0.01, max=2, default=1),
-                top_p=KwargRule[float](min=0, max=1, default=0.7),
-                presence_penalty=KwargRule[float](min=-2, max=2, default=0),
-                frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+                temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
+                top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
+                presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
+                frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
             )
         else:
             return ModelKwargsRules(
-                temperature=KwargRule[float](min=0.01, max=2, default=1),
-                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
+                top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
                 presence_penalty=KwargRule[float](enabled=False),
                 frequency_penalty=KwargRule[float](enabled=False),
-                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
             )
 
 

+ 176 - 0
api/core/model_providers/providers/zhipuai_provider.py

@@ -0,0 +1,176 @@
+import json
+from json import JSONDecodeError
+from typing import Type
+
+from langchain.schema import HumanMessage
+
+from core.helper import encrypter
+from core.model_providers.models.base import BaseProviderModel
+from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
+from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
+from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
+from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
+from models.provider import ProviderType, ProviderQuotaType
+
+
+class ZhipuAIProvider(BaseModelProvider):
+
+    @property
+    def provider_name(self):
+        """
+        Returns the name of a provider.
+        """
+        return 'zhipuai'
+
+    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
+        if model_type == ModelType.TEXT_GENERATION:
+            return [
+                {
+                    'id': 'chatglm_pro',
+                    'name': 'chatglm_pro',
+                },
+                {
+                    'id': 'chatglm_std',
+                    'name': 'chatglm_std',
+                },
+                {
+                    'id': 'chatglm_lite',
+                    'name': 'chatglm_lite',
+                },
+                {
+                    'id': 'chatglm_lite_32k',
+                    'name': 'chatglm_lite_32k',
+                }
+            ]
+        elif model_type == ModelType.EMBEDDINGS:
+            return [
+                {
+                    'id': 'text_embedding',
+                    'name': 'text_embedding',
+                }
+            ]
+        else:
+            return []
+
+    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
+        """
+        Returns the model class.
+
+        :param model_type:
+        :return:
+        """
+        if model_type == ModelType.TEXT_GENERATION:
+            model_class = ZhipuAIModel
+        elif model_type == ModelType.EMBEDDINGS:
+            model_class = ZhipuAIEmbedding
+        else:
+            raise NotImplementedError
+
+        return model_class
+
+    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
+        """
+        get model parameter rules.
+
+        :param model_name:
+        :param model_type:
+        :return:
+        """
+        return ModelKwargsRules(
+            temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
+            top_p=KwargRule[float](min=0.1, max=0.9, default=0.8, precision=1),
+            presence_penalty=KwargRule[float](enabled=False),
+            frequency_penalty=KwargRule[float](enabled=False),
+            max_tokens=KwargRule[int](enabled=False),
+        )
+
+    @classmethod
+    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
+        """
+        Validates the given credentials.
+        """
+        if 'api_key' not in credentials:
+            raise CredentialsValidateFailedError('ZhipuAI api_key must be provided.')
+
+        try:
+            credential_kwargs = {
+                'api_key': credentials['api_key']
+            }
+
+            llm = ZhipuAIChatLLM(
+                temperature=0.01,
+                **credential_kwargs
+            )
+
+            llm([HumanMessage(content='ping')])
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @classmethod
+    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
+        credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
+        return credentials
+
+    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
+        if self.provider.provider_type == ProviderType.CUSTOM.value \
+                or (self.provider.provider_type == ProviderType.SYSTEM.value
+                    and self.provider.quota_type == ProviderQuotaType.FREE.value):
+            try:
+                credentials = json.loads(self.provider.encrypted_config)
+            except JSONDecodeError:
+                credentials = {
+                    'api_key': None,
+                }
+
+            if credentials['api_key']:
+                credentials['api_key'] = encrypter.decrypt_token(
+                    self.provider.tenant_id,
+                    credentials['api_key']
+                )
+
+                if obfuscated:
+                    credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
+
+            return credentials
+        else:
+            return {}
+
+    def should_deduct_quota(self):
+        return True
+
+    @classmethod
+    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
+        """
+        check model credentials valid.
+
+        :param model_name:
+        :param model_type:
+        :param credentials:
+        """
+        return
+
+    @classmethod
+    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
+                                  credentials: dict) -> dict:
+        """
+        encrypt model credentials for save.
+
+        :param tenant_id:
+        :param model_name:
+        :param model_type:
+        :param credentials:
+        :return:
+        """
+        return {}
+
+    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
+        """
+        get credentials for llm use.
+
+        :param model_name:
+        :param model_type:
+        :param obfuscated:
+        :return:
+        """
+        return self.get_provider_credentials(obfuscated)

+ 1 - 0
api/core/model_providers/rules/_providers.json

@@ -6,6 +6,7 @@
   "tongyi",
   "spark",
   "wenxin",
+  "zhipuai",
   "chatglm",
   "replicate",
   "huggingface_hub",

+ 44 - 0
api/core/model_providers/rules/zhipuai.json

@@ -0,0 +1,44 @@
+{
+    "support_provider_types": [
+        "system",
+        "custom"
+    ],
+    "system_config": {
+        "supported_quota_types": [
+            "free"
+        ],
+        "quota_unit": "tokens"
+    },
+    "model_flexibility": "fixed",
+    "price_config": {
+        "chatglm_pro": {
+            "prompt": "0.01",
+            "completion": "0.01",
+            "unit": "0.001",
+            "currency": "RMB"
+        },
+        "chatglm_std": {
+            "prompt": "0.005",
+            "completion": "0.005",
+            "unit": "0.001",
+            "currency": "RMB"
+        },
+        "chatglm_lite": {
+            "prompt": "0.002",
+            "completion": "0.002",
+            "unit": "0.001",
+            "currency": "RMB"
+        },
+        "chatglm_lite_32k": {
+            "prompt": "0.0004",
+            "completion": "0.0004",
+            "unit": "0.001",
+            "currency": "RMB"
+        },
+        "text_embedding": {
+            "completion": "0",
+            "unit": "0.001",
+            "currency": "RMB"
+        }
+    }
+}

+ 64 - 0
api/core/third_party/langchain/embeddings/zhipuai_embedding.py

@@ -0,0 +1,64 @@
+"""Wrapper around ZhipuAI embedding models."""
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, Extra, root_validator
+
+from langchain.embeddings.base import Embeddings
+from langchain.utils import get_from_dict_or_env
+
+from core.third_party.langchain.llms.zhipuai_llm import ZhipuModelAPI
+
+
+class ZhipuAIEmbeddings(BaseModel, Embeddings):
+    """Wrapper around ZhipuAI embedding models.
+    1024 dimensions.
+    """
+
+    client: Any  #: :meta private:
+    model: str
+    """Model name to use."""
+
+    base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
+    api_key: Optional[str] = None
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        values["api_key"] = get_from_dict_or_env(
+            values, "api_key", "ZHIPUAI_API_KEY"
+        )
+        values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url'])
+        return values
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Call out to ZhipuAI's embedding endpoint.
+
+        Args:
+            texts: The list of texts to embed.
+
+        Returns:
+            List of embeddings, one for each text.
+        """
+        embeddings = []
+        for text in texts:
+            response = self.client.invoke(model=self.model, prompt=text)
+            data = response["data"]
+            embeddings.append(data.get('embedding'))
+
+        return [list(map(float, e)) for e in embeddings]
+
+    def embed_query(self, text: str) -> List[float]:
+        """Call out to ZhipuAI's embedding endpoint.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            Embeddings for the text.
+        """
+        return self.embed_documents([text])[0]

+ 315 - 0
api/core/third_party/langchain/llms/zhipuai_llm.py

@@ -0,0 +1,315 @@
+"""Wrapper around ZhipuAI APIs."""
+from __future__ import annotations
+
+import json
+import logging
+import posixpath
+from typing import (
+    Any,
+    Dict,
+    List,
+    Optional, Iterator, Sequence,
+)
+
+import zhipuai
+from langchain.chat_models.base import BaseChatModel
+from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
+from langchain.schema.messages import AIMessageChunk
+from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration
+from pydantic import Extra, root_validator, BaseModel
+
+from langchain.callbacks.manager import (
+    CallbackManagerForLLMRun,
+)
+from langchain.utils import get_from_dict_or_env
+from zhipuai.model_api.api import InvokeType
+from zhipuai.utils import jwt_token
+from zhipuai.utils.http_client import post, stream
+from zhipuai.utils.sse_client import SSEClient
+
+logger = logging.getLogger(__name__)
+
+
+class ZhipuModelAPI(BaseModel):
+    base_url: str
+    api_key: str
+    api_timeout_seconds = 60
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    def invoke(self, **kwargs):
+        url = self._build_api_url(kwargs, InvokeType.SYNC)
+        response = post(url, self._generate_token(), kwargs, self.api_timeout_seconds)
+        if not response['success']:
+            raise ValueError(
+                f"Error Code: {response['code']}, Message: {response['msg']} "
+            )
+        return response
+
+    def sse_invoke(self, **kwargs):
+        url = self._build_api_url(kwargs, InvokeType.SSE)
+        data = stream(url, self._generate_token(), kwargs, self.api_timeout_seconds)
+        return SSEClient(data)
+
+    def _build_api_url(self, kwargs, *path):
+        if kwargs:
+            if "model" not in kwargs:
+                raise Exception("model param missed")
+            model = kwargs.pop("model")
+        else:
+            model = "-"
+
+        return posixpath.join(self.base_url, model, *path)
+
+    def _generate_token(self):
+        if not self.api_key:
+            raise Exception(
+                "api_key not provided, you could provide it."
+            )
+
+        try:
+            return jwt_token.generate_token(self.api_key)
+        except Exception:
+            raise ValueError(
+                f"Your api_key is invalid, please check it."
+            )
+
+
+class ZhipuAIChatLLM(BaseChatModel):
+    """Wrapper around ZhipuAI large language models.
+    To use, you should pass the api_key as a named parameter to the constructor.
+    Example:
+     .. code-block:: python
+         from core.third_party.langchain.llms.zhipuai import ZhipuAI
+         model = ZhipuAI(model="<model_name>", api_key="my-api-key")
+    """
+
+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        return {"api_key": "API_KEY"}
+
+    @property
+    def lc_serializable(self) -> bool:
+        return True
+
+    client: Any = None  #: :meta private:
+    model: str = "chatglm_lite"
+    """Model name to use."""
+    temperature: float = 0.95
+    """A non-negative float that tunes the degree of randomness in generation."""
+    top_p: float = 0.7
+    """Total probability mass of tokens to consider at each step."""
+    streaming: bool = False
+    """Whether to stream the response or return it all at once."""
+    api_key: Optional[str] = None
+
+    base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        values["api_key"] = get_from_dict_or_env(
+            values, "api_key", "ZHIPUAI_API_KEY"
+        )
+
+        if 'test' in values['base_url']:
+            values['model'] = 'chatglm_130b_test'
+
+        values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url'])
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling OpenAI API."""
+        return {
+            "model": self.model,
+            "temperature": self.temperature,
+            "top_p": self.top_p
+        }
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return self._default_params
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "zhipuai"
+
+    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
+        if isinstance(message, ChatMessage):
+            message_dict = {"role": message.role, "content": message.content}
+        elif isinstance(message, HumanMessage):
+            message_dict = {"role": "user", "content": message.content}
+        elif isinstance(message, AIMessage):
+            message_dict = {"role": "assistant", "content": message.content}
+        elif isinstance(message, SystemMessage):
+            message_dict = {"role": "user", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+        return message_dict
+
+    def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage:
+        role = _dict["role"]
+        if role == "user":
+            return HumanMessage(content=_dict["content"])
+        elif role == "assistant":
+            return AIMessage(content=_dict["content"])
+        elif role == "system":
+            return SystemMessage(content=_dict["content"])
+        else:
+            return ChatMessage(content=_dict["content"], role=role)
+
+    def _create_message_dicts(
+        self, messages: List[BaseMessage]
+    ) -> List[Dict[str, Any]]:
+        dict_messages = []
+        for m in messages:
+            message = self._convert_message_to_dict(m)
+            if dict_messages:
+                previous_message = dict_messages[-1]
+                if previous_message['role'] == message['role']:
+                    dict_messages[-1]['content'] += f"\n{message['content']}"
+                else:
+                    dict_messages.append(message)
+            else:
+                dict_messages.append(message)
+
+        return dict_messages
+
+    def _generate(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+    ) -> ChatResult:
+        if self.streaming:
+            generation: Optional[ChatGenerationChunk] = None
+            llm_output: Optional[Dict] = None
+            for chunk in self._stream(
+                    messages=messages, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                if chunk.generation_info is not None \
+                        and 'token_usage' in chunk.generation_info:
+                    llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
+                    continue
+
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+            return ChatResult(generations=[generation], llm_output=llm_output)
+        else:
+            message_dicts = self._create_message_dicts(messages)
+            request = self._default_params
+            request["prompt"] = message_dicts
+            request.update(kwargs)
+            response = self.client.invoke(**request)
+            return self._create_chat_result(response)
+
+    def _stream(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        message_dicts = self._create_message_dicts(messages)
+        request = self._default_params
+        request["prompt"] = message_dicts
+        request.update(kwargs)
+
+        for event in self.client.sse_invoke(incremental=True, **request).events():
+            if event.event == "add":
+                yield ChatGenerationChunk(message=AIMessageChunk(content=event.data))
+                if run_manager:
+                    run_manager.on_llm_new_token(event.data)
+            elif event.event == "error" or event.event == "interrupted":
+                raise ValueError(
+                    f"{event.data}"
+                )
+            elif event.event == "finish":
+                meta = json.loads(event.meta)
+                token_usage = meta['usage']
+                if token_usage is not None:
+                    if 'prompt_tokens' not in token_usage:
+                        token_usage['prompt_tokens'] = 0
+                    if 'completion_tokens' not in token_usage:
+                        token_usage['completion_tokens'] = token_usage['total_tokens']
+
+                yield ChatGenerationChunk(
+                    message=AIMessageChunk(content=event.data),
+                    generation_info=dict({'token_usage': token_usage})
+                )
+
+    def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
+        data = response["data"]
+        generations = []
+        for res in data["choices"]:
+            message = self._convert_dict_to_message(res)
+            gen = ChatGeneration(
+                message=message
+            )
+            generations.append(gen)
+        token_usage = data.get("usage")
+        if token_usage is not None:
+            if 'prompt_tokens' not in token_usage:
+                token_usage['prompt_tokens'] = 0
+            if 'completion_tokens' not in token_usage:
+                token_usage['completion_tokens'] = token_usage['total_tokens']
+
+        llm_output = {"token_usage": token_usage, "model_name": self.model}
+        return ChatResult(generations=generations, llm_output=llm_output)
+
+    # def get_token_ids(self, text: str) -> List[int]:
+    #     """Return the ordered ids of the tokens in a text.
+    #
+    #     Args:
+    #         text: The string input to tokenize.
+    #
+    #     Returns:
+    #         A list of ids corresponding to the tokens in the text, in order they occur
+    #             in the text.
+    #     """
+    #     from core.third_party.transformers.Token import ChatGLMTokenizer
+    #
+    #     tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm2-6b")
+    #     return tokenizer.encode(text)
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        """Get the number of tokens in the messages.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            messages: The message inputs to tokenize.
+
+        Returns:
+            The sum of the number of tokens across the messages.
+        """
+        return sum([self.get_num_tokens(m.content) for m in messages])
+
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        overall_token_usage: dict = {}
+        for output in llm_outputs:
+            if output is None:
+                # Happens in streaming
+                continue
+            token_usage = output["token_usage"]
+            for k, v in token_usage.items():
+                if k in overall_token_usage:
+                    overall_token_usage[k] += v
+                else:
+                    overall_token_usage[k] = v
+        return {"token_usage": overall_token_usage, "model_name": self.model}

+ 2 - 1
api/requirements.txt

@@ -50,4 +50,5 @@ transformers~=4.31.0
 stripe~=5.5.0
 pandas==1.5.3
 xinference==0.4.2
-safetensors==0.3.2
+safetensors==0.3.2
+zhipuai==1.0.7

+ 5 - 2
api/services/provider_service.py

@@ -548,7 +548,7 @@ class ProviderService:
                 'result': 'success'
             }
 
-    def free_quota_qualification_verify(self, tenant_id: str, provider_name: str):
+    def free_quota_qualification_verify(self, tenant_id: str, provider_name: str, token: Optional[str]):
         api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
         api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
         api_url = api_base_url + '/api/v1/providers/qualification-verify'
@@ -557,8 +557,11 @@ class ProviderService:
             'Content-Type': 'application/json',
             'Authorization': f"Bearer {api_key}"
         }
+        json_data = {'workspace_id': tenant_id, 'provider_name': provider_name}
+        if token:
+            json_data['token'] = token
         response = requests.post(api_url, headers=headers,
-                                 json={'workspace_id': tenant_id, 'provider_name': provider_name})
+                                 json=json_data)
         if not response.ok:
             logging.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
             raise ValueError(f"Error: {response.status_code} ")

+ 3 - 0
api/tests/integration_tests/.env.example

@@ -31,6 +31,9 @@ TONGYI_DASHSCOPE_API_KEY=
 WENXIN_API_KEY=
 WENXIN_SECRET_KEY=
 
+# ZhipuAI Credentials
+ZHIPUAI_API_KEY=
+
 # ChatGLM Credentials
 CHATGLM_API_BASE=
 

+ 50 - 0
api/tests/integration_tests/models/embedding/test_zhipuai_embedding.py

@@ -0,0 +1,50 @@
+import json
+import os
+from unittest.mock import patch
+
+from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
+from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
+from models.provider import Provider, ProviderType
+
+
+def get_mock_provider(valid_api_key):
+    return Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name='zhipuai',
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=json.dumps({
+            'api_key': valid_api_key
+        }),
+        is_valid=True,
+    )
+
+
+def get_mock_embedding_model():
+    model_name = 'text_embedding'
+    valid_api_key = os.environ['ZHIPUAI_API_KEY']
+    provider = ZhipuAIProvider(provider=get_mock_provider(valid_api_key))
+    return ZhipuAIEmbedding(
+        model_provider=provider,
+        name=model_name
+    )
+
+
+def decrypt_side_effect(tenant_id, encrypted_api_key):
+    return encrypted_api_key
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_embedding(mock_decrypt):
+    embedding_model = get_mock_embedding_model()
+    rst = embedding_model.client.embed_query('test')
+    assert isinstance(rst, list)
+    assert len(rst) == 1024
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_doc_embedding(mock_decrypt):
+    embedding_model = get_mock_embedding_model()
+    rst = embedding_model.client.embed_documents(['test', 'test2'])
+    assert isinstance(rst, list)
+    assert len(rst[0]) == 1024

+ 79 - 0
api/tests/integration_tests/models/llm/test_zhipuai_model.py

@@ -0,0 +1,79 @@
+import json
+import os
+from unittest.mock import patch
+
+
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelKwargs
+from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
+from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
+from models.provider import Provider, ProviderType
+
+
+def get_mock_provider(valid_api_key):
+    return Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name='zhipuai',
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=json.dumps({
+            'api_key': valid_api_key
+        }),
+        is_valid=True,
+    )
+
+
+def get_mock_model(model_name: str, streaming: bool = False):
+    model_kwargs = ModelKwargs(
+        temperature=0.01,
+    )
+    valid_api_key = os.environ['ZHIPUAI_API_KEY']
+    model_provider = ZhipuAIProvider(provider=get_mock_provider(valid_api_key))
+    return ZhipuAIModel(
+        model_provider=model_provider,
+        name=model_name,
+        model_kwargs=model_kwargs,
+        streaming=streaming
+    )
+
+
+def decrypt_side_effect(tenant_id, encrypted_api_key):
+    return encrypted_api_key
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_chat_get_num_tokens(mock_decrypt):
+    model = get_mock_model('chatglm_lite')
+    rst = model.get_num_tokens([
+        PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
+        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+    ])
+    assert rst > 0
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_chat_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
+    model = get_mock_model('chatglm_lite')
+    messages = [
+        PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
+    ]
+    rst = model.run(
+        messages,
+    )
+    assert len(rst.content) > 0
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_chat_stream_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
+    model = get_mock_model('chatglm_lite', streaming=True)
+    messages = [
+        PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
+    ]
+    rst = model.run(
+        messages
+    )
+    assert len(rst.content) > 0

+ 1 - 1
api/tests/unit_tests/model_providers/test_spark_provider.py

@@ -39,7 +39,7 @@ def test_is_provider_credentials_valid_or_raise_invalid():
         MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
 
     credential = VALIDATE_CREDENTIAL.copy()
-    credential['api_key'] = 'invalid_key'
+    del credential['api_key']
 
     # raise CredentialsValidateFailedError if api_key is invalid
     with pytest.raises(CredentialsValidateFailedError):

+ 88 - 0
api/tests/unit_tests/model_providers/test_zhipuai_provider.py

@@ -0,0 +1,88 @@
+import pytest
+from unittest.mock import patch
+import json
+
+from langchain.schema import ChatResult, ChatGeneration, AIMessage
+
+from core.model_providers.providers.base import CredentialsValidateFailedError
+from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
+from models.provider import ProviderType, Provider
+
+
+PROVIDER_NAME = 'zhipuai'
+MODEL_PROVIDER_CLASS = ZhipuAIProvider
+VALIDATE_CREDENTIAL = {
+    'api_key': 'valid_key',
+}
+
+
+def encrypt_side_effect(tenant_id, encrypt_key):
+    return f'encrypted_{encrypt_key}'
+
+
+def decrypt_side_effect(tenant_id, encrypted_key):
+    return encrypted_key.replace('encrypted_', '')
+
+
+def test_is_provider_credentials_valid_or_raise_valid(mocker):
+    mocker.patch('core.third_party.langchain.llms.zhipuai_llm.ZhipuAIChatLLM._generate',
+                 return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))
+
+    MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
+
+
+def test_is_provider_credentials_valid_or_raise_invalid():
+    # raise CredentialsValidateFailedError if api_key is not in credentials
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
+
+    credential = VALIDATE_CREDENTIAL.copy()
+    credential['api_key'] = 'invalid_key'
+
+    # raise CredentialsValidateFailedError if api_key is invalid
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
+
+
+@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
+def test_encrypt_credentials(mock_encrypt):
+    result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
+    assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_credentials_custom(mock_decrypt):
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
+
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=json.dumps(encrypted_credential),
+        is_valid=True,
+    )
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_provider_credentials()
+    assert result['api_key'] == 'valid_key'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_credentials_obfuscated(mock_decrypt):
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
+
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=json.dumps(encrypted_credential),
+        is_valid=True,
+    )
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_provider_credentials(obfuscated=True)
+    middle_token = result['api_key'][6:-2]
+    assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
+    assert all(char == '*' for char in middle_token)