
chore(api/tests): apply ruff reformat #7590 (#7591)

Co-authored-by: -LAN- <laipz8200@outlook.com>
Bowen Liang 8 months ago
parent
commit
b035c02f78
100 changed files with 2288 additions and 3556 deletions
  1. api/pyproject.toml (+0 -1)
  2. api/tests/integration_tests/model_runtime/__mock/anthropic.py (+28 -40)
  3. api/tests/integration_tests/model_runtime/__mock/google.py (+21 -37)
  4. api/tests/integration_tests/model_runtime/__mock/huggingface.py (+4 -3)
  5. api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py (+11 -16)
  6. api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py (+21 -21)
  7. api/tests/integration_tests/model_runtime/__mock/openai.py (+13 -8)
  8. api/tests/integration_tests/model_runtime/__mock/openai_chat.py (+102 -76)
  9. api/tests/integration_tests/model_runtime/__mock/openai_completion.py (+43 -34)
  10. api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py (+13 -19)
  11. api/tests/integration_tests/model_runtime/__mock/openai_moderation.py (+68 -34)
  12. api/tests/integration_tests/model_runtime/__mock/openai_remote.py (+6 -5)
  13. api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py (+9 -10)
  14. api/tests/integration_tests/model_runtime/__mock/xinference.py (+65 -70)
  15. api/tests/integration_tests/model_runtime/anthropic/test_llm.py (+26 -49)
  16. api/tests/integration_tests/model_runtime/anthropic/test_provider.py (+3 -9)
  17. api/tests/integration_tests/model_runtime/azure_openai/test_llm.py (+73 -89)
  18. api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py (+20 -29)
  19. api/tests/integration_tests/model_runtime/baichuan/test_llm.py (+57 -75)
  20. api/tests/integration_tests/model_runtime/baichuan/test_provider.py (+2 -10)
  21. api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py (+15 -27)
  22. api/tests/integration_tests/model_runtime/bedrock/test_llm.py (+25 -41)
  23. api/tests/integration_tests/model_runtime/bedrock/test_provider.py (+2 -4)
  24. api/tests/integration_tests/model_runtime/chatglm/test_llm.py (+77 -138)
  25. api/tests/integration_tests/model_runtime/chatglm/test_provider.py (+3 -11)
  26. api/tests/integration_tests/model_runtime/cohere/test_llm.py (+54 -129)
  27. api/tests/integration_tests/model_runtime/cohere/test_provider.py (+2 -8)
  28. api/tests/integration_tests/model_runtime/cohere/test_rerank.py (+7 -19)
  29. api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py (+9 -29)
  30. api/tests/integration_tests/model_runtime/google/test_llm.py (+38 -71)
  31. api/tests/integration_tests/model_runtime/google/test_provider.py (+3 -9)
  32. api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py (+117 -143)
  33. api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py (+40 -49)
  34. api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py (+18 -20)
  35. api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py (+20 -18)
  36. api/tests/integration_tests/model_runtime/hunyuan/test_llm.py (+24 -45)
  37. api/tests/integration_tests/model_runtime/hunyuan/test_provider.py (+3 -8)
  38. api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py (+21 -29)
  39. api/tests/integration_tests/model_runtime/jina/test_provider.py (+2 -10)
  40. api/tests/integration_tests/model_runtime/jina/test_text_embedding.py (+9 -23)
  41. api/tests/integration_tests/model_runtime/localai/test_embedding.py (+3 -3)
  42. api/tests/integration_tests/model_runtime/localai/test_llm.py (+55 -99)
  43. api/tests/integration_tests/model_runtime/localai/test_rerank.py (+25 -31)
  44. api/tests/integration_tests/model_runtime/localai/test_speech2text.py (+9 -21)
  45. api/tests/integration_tests/model_runtime/minimax/test_embedding.py (+17 -24)
  46. api/tests/integration_tests/model_runtime/minimax/test_llm.py (+46 -61)
  47. api/tests/integration_tests/model_runtime/minimax/test_provider.py (+4 -4)
  48. api/tests/integration_tests/model_runtime/novita/test_llm.py (+23 -47)
  49. api/tests/integration_tests/model_runtime/novita/test_provider.py (+2 -4)
  50. api/tests/integration_tests/model_runtime/ollama/test_llm.py (+56 -88)
  51. api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py (+21 -27)
  52. api/tests/integration_tests/model_runtime/openai/test_llm.py (+43 -78)
  53. api/tests/integration_tests/model_runtime/openai/test_moderation.py (+11 -22)
  54. api/tests/integration_tests/model_runtime/openai/test_provider.py (+3 -9)
  55. api/tests/integration_tests/model_runtime/openai/test_speech2text.py (+12 -23)
  56. api/tests/integration_tests/model_runtime/openai/test_text_embedding.py (+12 -33)
  57. api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py (+60 -89)
  58. api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py (+4 -13)
  59. api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py (+20 -32)
  60. api/tests/integration_tests/model_runtime/openllm/test_embedding.py (+14 -19)
  61. api/tests/integration_tests/model_runtime/openllm/test_llm.py (+30 -39)
  62. api/tests/integration_tests/model_runtime/openrouter/test_llm.py (+26 -45)
  63. api/tests/integration_tests/model_runtime/replicate/test_llm.py (+34 -40)
  64. api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py (+40 -55)
  65. api/tests/integration_tests/model_runtime/sagemaker/test_provider.py (+2 -6)
  66. api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py (+6 -6)
  67. api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py (+5 -27)
  68. api/tests/integration_tests/model_runtime/siliconflow/test_llm.py (+19 -52)
  69. api/tests/integration_tests/model_runtime/siliconflow/test_provider.py (+2 -8)
  70. api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py (+5 -7)
  71. api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py (+4 -12)
  72. api/tests/integration_tests/model_runtime/siliconflow/test_text_embedding.py (+1 -3)
  73. api/tests/integration_tests/model_runtime/spark/test_llm.py (+28 -49)
  74. api/tests/integration_tests/model_runtime/spark/test_provider.py (+4 -6)
  75. api/tests/integration_tests/model_runtime/stepfun/test_llm.py (+36 -87)
  76. api/tests/integration_tests/model_runtime/test_model_provider_factory.py (+8 -21)
  77. api/tests/integration_tests/model_runtime/togetherai/test_llm.py (+29 -45)
  78. api/tests/integration_tests/model_runtime/tongyi/test_llm.py (+18 -49)
  79. api/tests/integration_tests/model_runtime/tongyi/test_provider.py (+2 -6)
  80. api/tests/integration_tests/model_runtime/tongyi/test_response_format.py (+7 -11)
  81. api/tests/integration_tests/model_runtime/upstage/test_llm.py (+59 -118)
  82. api/tests/integration_tests/model_runtime/upstage/test_provider.py (+3 -9)
  83. api/tests/integration_tests/model_runtime/upstage/test_text_embedding.py (+12 -25)
  84. api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py (+32 -38)
  85. api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py (+50 -63)
  86. api/tests/integration_tests/model_runtime/wenxin/test_embedding.py (+16 -28)
  87. api/tests/integration_tests/model_runtime/wenxin/test_llm.py (+69 -126)
  88. api/tests/integration_tests/model_runtime/wenxin/test_provider.py (+2 -10)
  89. api/tests/integration_tests/model_runtime/xinference/test_embeddings.py (+21 -25)
  90. api/tests/integration_tests/model_runtime/xinference/test_llm.py (+85 -116)
  91. api/tests/integration_tests/model_runtime/xinference/test_rerank.py (+15 -17)
  92. api/tests/integration_tests/model_runtime/zhinao/test_llm.py (+19 -52)
  93. api/tests/integration_tests/model_runtime/zhinao/test_provider.py (+2 -8)
  94. api/tests/integration_tests/model_runtime/zhipuai/test_llm.py (+31 -77)
  95. api/tests/integration_tests/model_runtime/zhipuai/test_provider.py (+2 -8)
  96. api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py (+8 -30)
  97. api/tests/integration_tests/tools/__mock/http.py (+6 -9)
  98. api/tests/integration_tests/tools/__mock_server/openapi_todo.py (+4 -2)
  99. api/tests/integration_tests/tools/api_tool/test_api_tool.py (+27 -24)
  100. api/tests/integration_tests/tools/test_all_provider.py (+5 -4)

+ 0 - 1
api/pyproject.toml

@@ -76,7 +76,6 @@ exclude = [
     "migrations/**/*",
     "services/**/*.py",
     "tasks/**/*.py",
-    "tests/**/*.py",
 ]
 
 [tool.pytest_env]
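
The only configuration change in this commit is the hunk above: dropping "tests/**/*.py" from ruff's exclude list in api/pyproject.toml brings the whole test suite under the formatter, and every diff that follows is the resulting mechanical rewrite (single quotes to double quotes, call arguments collapsed or exploded at the line-length limit, trailing commas added). A minimal sketch of reproducing the rewrite locally, assuming ruff is installed in the api environment (the pinned version lives in the project's dev dependencies, not shown here):

# Reproduce the reformat; ruff picks up the settings in api/pyproject.toml
# once the tests exclude entry is gone.
import subprocess

subprocess.run(["ruff", "format", "tests/"], check=True, cwd="api")
# Verify the tree is clean: --check exits non-zero if anything would change.
subprocess.run(["ruff", "format", "--check", "tests/"], check=True, cwd="api")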

+ 28 - 40
api/tests/integration_tests/model_runtime/__mock/anthropic.py

@@ -22,23 +22,20 @@ from anthropic.types import (
 )
 from anthropic.types.message_delta_event import Delta
 
-MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
+MOCK = os.getenv("MOCK_SWITCH", "false") == "true"
 
 
 class MockAnthropicClass:
     @staticmethod
     def mocked_anthropic_chat_create_sync(model: str) -> Message:
         return Message(
-            id='msg-123',
-            type='message',
-            role='assistant',
-            content=[ContentBlock(text='hello, I\'m a chatbot from anthropic', type='text')],
+            id="msg-123",
+            type="message",
+            role="assistant",
+            content=[ContentBlock(text="hello, I'm a chatbot from anthropic", type="text")],
             model=model,
-            stop_reason='stop_sequence',
-            usage=Usage(
-                input_tokens=1,
-                output_tokens=1
-            )
+            stop_reason="stop_sequence",
+            usage=Usage(input_tokens=1, output_tokens=1),
         )
 
     @staticmethod
@@ -46,52 +43,43 @@ class MockAnthropicClass:
         full_response_text = "hello, I'm a chatbot from anthropic"
 
         yield MessageStartEvent(
-            type='message_start',
+            type="message_start",
             message=Message(
-                id='msg-123',
+                id="msg-123",
                 content=[],
-                role='assistant',
+                role="assistant",
                 model=model,
                 stop_reason=None,
-                type='message',
-                usage=Usage(
-                    input_tokens=1,
-                    output_tokens=1
-                )
-            )
+                type="message",
+                usage=Usage(input_tokens=1, output_tokens=1),
+            ),
         )
 
         index = 0
         for i in range(0, len(full_response_text)):
             yield ContentBlockDeltaEvent(
-                type='content_block_delta',
-                delta=TextDelta(text=full_response_text[i], type='text_delta'),
-                index=index
+                type="content_block_delta", delta=TextDelta(text=full_response_text[i], type="text_delta"), index=index
             )
 
             index += 1
 
         yield MessageDeltaEvent(
-            type='message_delta',
-            delta=Delta(
-                stop_reason='stop_sequence'
-            ),
-            usage=MessageDeltaUsage(
-                output_tokens=1
-            )
+            type="message_delta", delta=Delta(stop_reason="stop_sequence"), usage=MessageDeltaUsage(output_tokens=1)
         )
 
-        yield MessageStopEvent(type='message_stop')
-
-    def mocked_anthropic(self: Messages, *,
-                         max_tokens: int,
-                         messages: Iterable[MessageParam],
-                         model: str,
-                         stream: Literal[True],
-                         **kwargs: Any
-                         ) -> Union[Message, Stream[MessageStreamEvent]]:
+        yield MessageStopEvent(type="message_stop")
+
+    def mocked_anthropic(
+        self: Messages,
+        *,
+        max_tokens: int,
+        messages: Iterable[MessageParam],
+        model: str,
+        stream: Literal[True],
+        **kwargs: Any,
+    ) -> Union[Message, Stream[MessageStreamEvent]]:
         if len(self._client.api_key) < 18:
-            raise anthropic.AuthenticationError('Invalid API key')
+            raise anthropic.AuthenticationError("Invalid API key")
 
         if stream:
             return MockAnthropicClass.mocked_anthropic_chat_create_stream(model=model)
@@ -102,7 +90,7 @@ class MockAnthropicClass:
 @pytest.fixture
 def setup_anthropic_mock(request, monkeypatch: MonkeyPatch):
     if MOCK:
-        monkeypatch.setattr(Messages, 'create', MockAnthropicClass.mocked_anthropic)
+        monkeypatch.setattr(Messages, "create", MockAnthropicClass.mocked_anthropic)
 
     yield
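
The pattern above recurs in every __mock module that follows: a module-level MOCK flag read from the MOCK_SWITCH environment variable, plus a pytest fixture that monkeypatches the SDK entry point only when mocking is enabled. A hedged sketch of how a test opts in (the test body here is illustrative, not part of this commit):

# Hypothetical consumer of the fixture above; the indirect parametrization
# mirrors the pattern used by the real tests later in this diff.
import pytest

@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
def test_example(setup_anthropic_mock):
    # With MOCK_SWITCH=true, Messages.create is patched, so this test never
    # reaches the real Anthropic API and receives the canned Message instead.
    ...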
 

+ 21 - 37
api/tests/integration_tests/model_runtime/__mock/google.py

@@ -12,63 +12,46 @@ from google.generativeai.client import _ClientManager, configure
 from google.generativeai.types import GenerateContentResponse
 from google.generativeai.types.generation_types import BaseGenerateContentResponse
 
-current_api_key = ''
+current_api_key = ""
+
 
 class MockGoogleResponseClass:
     _done = False
 
     def __iter__(self):
-        full_response_text = 'it\'s google!'
+        full_response_text = "it's google!"
 
         for i in range(0, len(full_response_text) + 1, 1):
             if i == len(full_response_text):
                 self._done = True
                 yield GenerateContentResponse(
-                    done=True,
-                    iterator=None,
-                    result=glm.GenerateContentResponse({
-
-                    }),
-                    chunks=[]
+                    done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
                 )
             else:
                 yield GenerateContentResponse(
-                    done=False,
-                    iterator=None,
-                    result=glm.GenerateContentResponse({
-
-                    }),
-                    chunks=[]
+                    done=False, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
                 )
 
+
 class MockGoogleResponseCandidateClass:
-    finish_reason = 'stop'
+    finish_reason = "stop"
 
     @property
     def content(self) -> gag_content.Content:
-        return gag_content.Content(
-            parts=[
-                gag_content.Part(text='it\'s google!')
-            ]
-        )
+        return gag_content.Content(parts=[gag_content.Part(text="it's google!")])
+
 
 class MockGoogleClass:
     @staticmethod
     def generate_content_sync() -> GenerateContentResponse:
-        return GenerateContentResponse(
-            done=True,
-            iterator=None,
-            result=glm.GenerateContentResponse({
-
-            }),
-            chunks=[]
-        )
+        return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[])
 
     @staticmethod
     def generate_content_stream() -> Generator[GenerateContentResponse, None, None]:
         return MockGoogleResponseClass()
 
-    def generate_content(self: GenerativeModel,
+    def generate_content(
+        self: GenerativeModel,
         contents: content_types.ContentsType,
         *,
         generation_config: generation_config_types.GenerationConfigType | None = None,
@@ -79,21 +62,21 @@ class MockGoogleClass:
         global current_api_key
 
         if len(current_api_key) < 16:
-            raise Exception('Invalid API key')
+            raise Exception("Invalid API key")
 
         if stream:
             return MockGoogleClass.generate_content_stream()
-        
+
         return MockGoogleClass.generate_content_sync()
-    
+
     @property
     def generative_response_text(self) -> str:
-        return 'it\'s google!'
-    
+        return "it's google!"
+
     @property
     def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
         return [MockGoogleResponseCandidateClass()]
-    
+
     def make_client(self: _ClientManager, name: str):
         global current_api_key
 
@@ -121,7 +104,8 @@ class MockGoogleClass:
 
         if not self.default_metadata:
             return client
-    
+
+
 @pytest.fixture
 def setup_google_mock(request, monkeypatch: MonkeyPatch):
     monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
@@ -131,4 +115,4 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch):
 
     yield
 
-    monkeypatch.undo()
+    monkeypatch.undo()

+ 4 - 3
api/tests/integration_tests/model_runtime/__mock/huggingface.py

@@ -6,14 +6,15 @@ from huggingface_hub import InferenceClient
 
 from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass
 
-MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
+MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
+
 
 @pytest.fixture
 def setup_huggingface_mock(request, monkeypatch: MonkeyPatch):
     if MOCK:
         monkeypatch.setattr(InferenceClient, "text_generation", MockHuggingfaceChatClass.text_generation)
-    
+
     yield
 
     if MOCK:
-        monkeypatch.undo()
+        monkeypatch.undo()

+ 11 - 16
api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py

@@ -22,10 +22,8 @@ class MockHuggingfaceChatClass:
             details=Details(
                 finish_reason="length",
                 generated_tokens=6,
-                tokens=[
-                    Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)
-                ]
-            )
+                tokens=[Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)],
+            ),
         )
 
         return response
@@ -36,26 +34,23 @@ class MockHuggingfaceChatClass:
 
         for i in range(0, len(full_text)):
             response = TextGenerationStreamResponse(
-                token = Token(id=i, text=full_text[i], logprob=0.0, special=False),
+                token=Token(id=i, text=full_text[i], logprob=0.0, special=False),
             )
             response.generated_text = full_text[i]
-            response.details = StreamDetails(finish_reason='stop_sequence', generated_tokens=1)
+            response.details = StreamDetails(finish_reason="stop_sequence", generated_tokens=1)
 
             yield response
 
-    def text_generation(self: InferenceClient, prompt: str, *,
-        stream: Literal[False] = ...,
-        model: Optional[str] = None,
-        **kwargs: Any
+    def text_generation(
+        self: InferenceClient, prompt: str, *, stream: Literal[False] = ..., model: Optional[str] = None, **kwargs: Any
     ) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]:
         # check if key is valid
-        if not re.match(r'Bearer\shf\-[a-zA-Z0-9]{16,}', self.headers['authorization']):
-            raise BadRequestError('Invalid API key')
-        
+        if not re.match(r"Bearer\shf\-[a-zA-Z0-9]{16,}", self.headers["authorization"]):
+            raise BadRequestError("Invalid API key")
+
         if model is None:
-            raise BadRequestError('Invalid model')
-        
+            raise BadRequestError("Invalid model")
+
         if stream:
             return MockHuggingfaceChatClass.generate_create_stream(model)
         return MockHuggingfaceChatClass.generate_create_sync(model)
-

+ 21 - 21
api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py

@@ -5,10 +5,10 @@ class MockTEIClass:
     @staticmethod
     def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
         # During mock, we don't have a real server to query, so we just return a dummy value
-        if 'rerank' in model_name:
-            model_type = 'reranker'
+        if "rerank" in model_name:
+            model_type = "reranker"
         else:
-            model_type = 'embedding'
+            model_type = "embedding"
 
         return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1)
 
@@ -17,16 +17,16 @@ class MockTEIClass:
         # Use space as token separator, and split the text into tokens
         tokenized_texts = []
         for text in texts:
-            tokens = text.split(' ')
+            tokens = text.split(" ")
             current_index = 0
             tokenized_text = []
             for idx, token in enumerate(tokens):
                 s_token = {
-                    'id': idx,
-                    'text': token,
-                    'special': False,
-                    'start': current_index,
-                    'stop': current_index + len(token),
+                    "id": idx,
+                    "text": token,
+                    "special": False,
+                    "start": current_index,
+                    "stop": current_index + len(token),
                 }
                 current_index += len(token) + 1
                 tokenized_text.append(s_token)
@@ -55,18 +55,18 @@ class MockTEIClass:
             embedding = [0.1] * 768
             embeddings.append(
                 {
-                    'object': 'embedding',
-                    'embedding': embedding,
-                    'index': idx,
+                    "object": "embedding",
+                    "embedding": embedding,
+                    "index": idx,
                 }
             )
         return {
-            'object': 'list',
-            'data': embeddings,
-            'model': 'MODEL_NAME',
-            'usage': {
-                'prompt_tokens': sum(len(text.split(' ')) for text in texts),
-                'total_tokens': sum(len(text.split(' ')) for text in texts),
+            "object": "list",
+            "data": embeddings,
+            "model": "MODEL_NAME",
+            "usage": {
+                "prompt_tokens": sum(len(text.split(" ")) for text in texts),
+                "total_tokens": sum(len(text.split(" ")) for text in texts),
             },
         }
 
@@ -83,9 +83,9 @@ class MockTEIClass:
         for idx, text in enumerate(texts):
             reranked_docs.append(
                 {
-                    'index': idx,
-                    'text': text,
-                    'score': 0.9,
+                    "index": idx,
+                    "text": text,
+                    "score": 0.9,
                 }
             )
             # For mock, only return the first document

+ 13 - 8
api/tests/integration_tests/model_runtime/__mock/openai.py

@@ -21,13 +21,17 @@ from tests.integration_tests.model_runtime.__mock.openai_remote import MockModel
 from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass
 
 
-def mock_openai(monkeypatch: MonkeyPatch, methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]]) -> Callable[[], None]:
+def mock_openai(
+    monkeypatch: MonkeyPatch,
+    methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]],
+) -> Callable[[], None]:
     """
-        mock openai module
+    mock openai module
 
-        :param monkeypatch: pytest monkeypatch fixture
-        :return: unpatch function
+    :param monkeypatch: pytest monkeypatch fixture
+    :return: unpatch function
     """
+
     def unpatch() -> None:
         monkeypatch.undo()
 
@@ -52,15 +56,16 @@ def mock_openai(monkeypatch: MonkeyPatch, methods: list[Literal["completion", "c
     return unpatch
 
 
-MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
+MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
+
 
 @pytest.fixture
 def setup_openai_mock(request, monkeypatch):
-    methods = request.param if hasattr(request, 'param') else []
+    methods = request.param if hasattr(request, "param") else []
     if MOCK:
         unpatch = mock_openai(monkeypatch, methods=methods)
-    
+
     yield
 
     if MOCK:
-        unpatch()
+        unpatch()
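
Note that mock_openai above composes per-surface patches and hands back monkeypatch.undo as the unpatch callable, while the fixture reads its method list from request.param, so indirect parametrization selects which OpenAI surfaces get mocked per test. A short hypothetical sketch of that selection, mirroring the openai tests later in this commit:

# Hypothetical: ask the fixture to patch only the chat surface for this test.
import pytest

@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_example(setup_openai_mock):
    # With MOCK_SWITCH=true, chat completion calls resolve to MockChatClass
    # rather than the live OpenAI endpoint.
    ...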

+ 102 - 76
api/tests/integration_tests/model_runtime/__mock/openai_chat.py

@@ -43,62 +43,64 @@ class MockChatClass:
         if not functions or len(functions) == 0:
             return None
         function: completion_create_params.Function = functions[0]
-        function_name = function['name']
-        function_description = function['description']
-        function_parameters = function['parameters']
-        function_parameters_type = function_parameters['type']
-        if function_parameters_type != 'object':
+        function_name = function["name"]
+        function_description = function["description"]
+        function_parameters = function["parameters"]
+        function_parameters_type = function_parameters["type"]
+        if function_parameters_type != "object":
             return None
-        function_parameters_properties = function_parameters['properties']
-        function_parameters_required = function_parameters['required']
+        function_parameters_properties = function_parameters["properties"]
+        function_parameters_required = function_parameters["required"]
         parameters = {}
         for parameter_name, parameter in function_parameters_properties.items():
             if parameter_name not in function_parameters_required:
                 continue
-            parameter_type = parameter['type']
-            if parameter_type == 'string':
-                if 'enum' in parameter:
-                    if len(parameter['enum']) == 0:
+            parameter_type = parameter["type"]
+            if parameter_type == "string":
+                if "enum" in parameter:
+                    if len(parameter["enum"]) == 0:
                         continue
-                    parameters[parameter_name] = parameter['enum'][0]
+                    parameters[parameter_name] = parameter["enum"][0]
                 else:
-                    parameters[parameter_name] = 'kawaii'
-            elif parameter_type == 'integer':
+                    parameters[parameter_name] = "kawaii"
+            elif parameter_type == "integer":
                 parameters[parameter_name] = 114514
-            elif parameter_type == 'number':
+            elif parameter_type == "number":
                 parameters[parameter_name] = 1919810.0
-            elif parameter_type == 'boolean':
+            elif parameter_type == "boolean":
                 parameters[parameter_name] = True
 
         return FunctionCall(name=function_name, arguments=dumps(parameters))
-        
+
     @staticmethod
-    def generate_tool_calls(tools = NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
+    def generate_tool_calls(tools=NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
         list_tool_calls = []
         if not tools or len(tools) == 0:
             return None
         tool = tools[0]
 
-        if 'type' in tools and tools['type'] != 'function':
+        if "type" in tools and tools["type"] != "function":
             return None
 
-        function = tool['function']
+        function = tool["function"]
 
         function_call = MockChatClass.generate_function_call(functions=[function])
         if function_call is None:
             return None
-        
-        list_tool_calls.append(ChatCompletionMessageToolCall(
-            id='sakurajima-mai',
-            function=Function(
-                name=function_call.name,
-                arguments=function_call.arguments,
-            ),
-            type='function'
-        ))
+
+        list_tool_calls.append(
+            ChatCompletionMessageToolCall(
+                id="sakurajima-mai",
+                function=Function(
+                    name=function_call.name,
+                    arguments=function_call.arguments,
+                ),
+                type="function",
+            )
+        )
 
         return list_tool_calls
-    
+
     @staticmethod
     def mocked_openai_chat_create_sync(
         model: str,
@@ -111,30 +113,27 @@ class MockChatClass:
             tool_calls = MockChatClass.generate_tool_calls(tools=tools)
 
         return _ChatCompletion(
-            id='cmpl-3QJQa5jXJ5Z5X',
+            id="cmpl-3QJQa5jXJ5Z5X",
             choices=[
                 _ChatCompletionChoice(
-                    finish_reason='content_filter',
+                    finish_reason="content_filter",
                     index=0,
                     message=ChatCompletionMessage(
-                        content='elaina',
-                        role='assistant',
-                        function_call=function_call,
-                        tool_calls=tool_calls
-                    )
+                        content="elaina", role="assistant", function_call=function_call, tool_calls=tool_calls
+                    ),
                 )
             ],
             created=int(time()),
             model=model,
-            object='chat.completion',
-            system_fingerprint='',
+            object="chat.completion",
+            system_fingerprint="",
             usage=CompletionUsage(
                 prompt_tokens=2,
                 completion_tokens=1,
                 total_tokens=3,
-            )
+            ),
         )
-    
+
     @staticmethod
     def mocked_openai_chat_create_stream(
         model: str,
@@ -150,36 +149,40 @@ class MockChatClass:
         for i in range(0, len(full_text) + 1):
             if i == len(full_text):
                 yield ChatCompletionChunk(
-                    id='cmpl-3QJQa5jXJ5Z5X',
+                    id="cmpl-3QJQa5jXJ5Z5X",
                     choices=[
                         Choice(
                             delta=ChoiceDelta(
-                                content='',
+                                content="",
                                 function_call=ChoiceDeltaFunctionCall(
                                     name=function_call.name,
                                     arguments=function_call.arguments,
-                                ) if function_call else None,
-                                role='assistant',
+                                )
+                                if function_call
+                                else None,
+                                role="assistant",
                                 tool_calls=[
                                     ChoiceDeltaToolCall(
                                         index=0,
-                                        id='misaka-mikoto',
+                                        id="misaka-mikoto",
                                         function=ChoiceDeltaToolCallFunction(
                                             name=tool_calls[0].function.name,
                                             arguments=tool_calls[0].function.arguments,
                                         ),
-                                        type='function'
+                                        type="function",
                                     )
-                                ] if tool_calls and len(tool_calls) > 0 else None
+                                ]
+                                if tool_calls and len(tool_calls) > 0
+                                else None,
                             ),
-                            finish_reason='function_call',
+                            finish_reason="function_call",
                             index=0,
                         )
                     ],
                     created=int(time()),
                     model=model,
-                    object='chat.completion.chunk',
-                    system_fingerprint='',
+                    object="chat.completion.chunk",
+                    system_fingerprint="",
                     usage=CompletionUsage(
                         prompt_tokens=2,
                         completion_tokens=17,
@@ -188,30 +191,45 @@ class MockChatClass:
                 )
             else:
                 yield ChatCompletionChunk(
-                    id='cmpl-3QJQa5jXJ5Z5X',
+                    id="cmpl-3QJQa5jXJ5Z5X",
                     choices=[
                         Choice(
                             delta=ChoiceDelta(
                                 content=full_text[i],
-                                role='assistant',
+                                role="assistant",
                             ),
-                            finish_reason='content_filter',
+                            finish_reason="content_filter",
                             index=0,
                         )
                     ],
                     created=int(time()),
                     model=model,
-                    object='chat.completion.chunk',
-                    system_fingerprint='',
+                    object="chat.completion.chunk",
+                    system_fingerprint="",
                 )
 
-    def chat_create(self: Completions, *,
+    def chat_create(
+        self: Completions,
+        *,
         messages: list[ChatCompletionMessageParam],
-        model: Union[str,Literal[
-            "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613",
-            "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
-            "gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301",
-            "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613"],
+        model: Union[
+            str,
+            Literal[
+                "gpt-4-1106-preview",
+                "gpt-4-vision-preview",
+                "gpt-4",
+                "gpt-4-0314",
+                "gpt-4-0613",
+                "gpt-4-32k",
+                "gpt-4-32k-0314",
+                "gpt-4-32k-0613",
+                "gpt-3.5-turbo-1106",
+                "gpt-3.5-turbo",
+                "gpt-3.5-turbo-16k",
+                "gpt-3.5-turbo-0301",
+                "gpt-3.5-turbo-0613",
+                "gpt-3.5-turbo-16k-0613",
+            ],
         ],
         functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
         response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
@@ -220,24 +238,32 @@ class MockChatClass:
         **kwargs: Any,
     ):
         openai_models = [
-            "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613",
-            "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
-            "gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301",
-            "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613",
-        ]
-        azure_openai_models = [
-            "gpt35", "gpt-4v", "gpt-35-turbo"
+            "gpt-4-1106-preview",
+            "gpt-4-vision-preview",
+            "gpt-4",
+            "gpt-4-0314",
+            "gpt-4-0613",
+            "gpt-4-32k",
+            "gpt-4-32k-0314",
+            "gpt-4-32k-0613",
+            "gpt-3.5-turbo-1106",
+            "gpt-3.5-turbo",
+            "gpt-3.5-turbo-16k",
+            "gpt-3.5-turbo-0301",
+            "gpt-3.5-turbo-0613",
+            "gpt-3.5-turbo-16k-0613",
         ]
-        if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
-            raise InvokeAuthorizationError('Invalid base url')
+        azure_openai_models = ["gpt35", "gpt-4v", "gpt-35-turbo"]
+        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
+            raise InvokeAuthorizationError("Invalid base url")
         if model in openai_models + azure_openai_models:
-            if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI:
+            if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI:
                 # sometime, provider use OpenAI compatible API will not have api key or have different api key format
                 # so we only check if model is in openai_models
-                raise InvokeAuthorizationError('Invalid api key')
+                raise InvokeAuthorizationError("Invalid api key")
             if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
-                raise InvokeAuthorizationError('Invalid api key')
+                raise InvokeAuthorizationError("Invalid api key")
         if stream:
             return MockChatClass.mocked_openai_chat_create_stream(model=model, functions=functions, tools=tools)
-        
-        return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools)
+
+        return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools)

+ 43 - 34
api/tests/integration_tests/model_runtime/__mock/openai_completion.py

@@ -17,9 +17,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
 
 class MockCompletionsClass:
     @staticmethod
-    def mocked_openai_completion_create_sync(
-        model: str
-    ) -> CompletionMessage:
+    def mocked_openai_completion_create_sync(model: str) -> CompletionMessage:
         return CompletionMessage(
             id="cmpl-3QJQa5jXJ5Z5X",
             object="text_completion",
@@ -38,13 +36,11 @@ class MockCompletionsClass:
                 prompt_tokens=2,
                 completion_tokens=1,
                 total_tokens=3,
-            )
+            ),
         )
-    
+
     @staticmethod
-    def mocked_openai_completion_create_stream(
-        model: str
-    ) -> Generator[CompletionMessage, None, None]:
+    def mocked_openai_completion_create_stream(model: str) -> Generator[CompletionMessage, None, None]:
         full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
         for i in range(0, len(full_text) + 1):
             if i == len(full_text):
@@ -76,46 +72,59 @@ class MockCompletionsClass:
                     model=model,
                     system_fingerprint="",
                     choices=[
-                        CompletionChoice(
-                            text=full_text[i],
-                            index=0,
-                            logprobs=None,
-                            finish_reason="content_filter"
-                        )
+                        CompletionChoice(text=full_text[i], index=0, logprobs=None, finish_reason="content_filter")
                     ],
                 )
 
-    def completion_create(self: Completions, *, model: Union[
-            str, Literal["babbage-002", "davinci-002", "gpt-3.5-turbo-instruct",
-                "text-davinci-003", "text-davinci-002", "text-davinci-001",
-                "code-davinci-002", "text-curie-001", "text-babbage-001",
-                "text-ada-001"],
+    def completion_create(
+        self: Completions,
+        *,
+        model: Union[
+            str,
+            Literal[
+                "babbage-002",
+                "davinci-002",
+                "gpt-3.5-turbo-instruct",
+                "text-davinci-003",
+                "text-davinci-002",
+                "text-davinci-001",
+                "code-davinci-002",
+                "text-curie-001",
+                "text-babbage-001",
+                "text-ada-001",
+            ],
         ],
         prompt: Union[str, list[str], list[int], list[list[int]], None],
         stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
-        **kwargs: Any
+        **kwargs: Any,
     ):
         openai_models = [
-            "babbage-002", "davinci-002", "gpt-3.5-turbo-instruct", "text-davinci-003", "text-davinci-002", "text-davinci-001",
-            "code-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001",
-        ]
-        azure_openai_models = [
-            "gpt-35-turbo-instruct"
+            "babbage-002",
+            "davinci-002",
+            "gpt-3.5-turbo-instruct",
+            "text-davinci-003",
+            "text-davinci-002",
+            "text-davinci-001",
+            "code-davinci-002",
+            "text-curie-001",
+            "text-babbage-001",
+            "text-ada-001",
         ]
+        azure_openai_models = ["gpt-35-turbo-instruct"]
 
-        if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
-            raise InvokeAuthorizationError('Invalid base url')
+        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
+            raise InvokeAuthorizationError("Invalid base url")
         if model in openai_models + azure_openai_models:
-            if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI:
+            if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI:
                 # sometime, provider use OpenAI compatible API will not have api key or have different api key format
                 # so we only check if model is in openai_models
-                raise InvokeAuthorizationError('Invalid api key')
+                raise InvokeAuthorizationError("Invalid api key")
             if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
-                raise InvokeAuthorizationError('Invalid api key')
-            
+                raise InvokeAuthorizationError("Invalid api key")
+
         if not prompt:
-            raise BadRequestError('Invalid prompt')
+            raise BadRequestError("Invalid prompt")
         if stream:
             return MockCompletionsClass.mocked_openai_completion_create_stream(model=model)
-        
-        return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)
+
+        return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)

+ 13 - 19
api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py

File diff suppressed because it is too large


+ 68 - 34
api/tests/integration_tests/model_runtime/__mock/openai_moderation.py

@@ -10,58 +10,92 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
 
 
 class MockModerationClass:
-    def moderation_create(self: Moderations,*,
+    def moderation_create(
+        self: Moderations,
+        *,
         input: Union[str, list[str]],
         model: Union[str, Literal["text-moderation-latest", "text-moderation-stable"]] | NotGiven = NOT_GIVEN,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> ModerationCreateResponse:
         if isinstance(input, str):
             input = [input]
 
-        if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
-            raise InvokeAuthorizationError('Invalid base url')
-        
+        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
+            raise InvokeAuthorizationError("Invalid base url")
+
         if len(self._client.api_key) < 18:
-            raise InvokeAuthorizationError('Invalid API key')
+            raise InvokeAuthorizationError("Invalid API key")
 
         for text in input:
             result = []
-            if 'kill' in text:
+            if "kill" in text:
                 moderation_categories = {
-                    'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False,
-                    'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False,
-                    'sexual/minors': False, 'violence': False, 'violence/graphic': False
+                    "harassment": False,
+                    "harassment/threatening": False,
+                    "hate": False,
+                    "hate/threatening": False,
+                    "self-harm": False,
+                    "self-harm/instructions": False,
+                    "self-harm/intent": False,
+                    "sexual": False,
+                    "sexual/minors": False,
+                    "violence": False,
+                    "violence/graphic": False,
                 }
                 moderation_categories_scores = {
-                    'harassment': 1.0, 'harassment/threatening': 1.0, 'hate': 1.0, 'hate/threatening': 1.0,
-                    'self-harm': 1.0, 'self-harm/instructions': 1.0, 'self-harm/intent': 1.0, 'sexual': 1.0,
-                    'sexual/minors': 1.0, 'violence': 1.0, 'violence/graphic': 1.0
+                    "harassment": 1.0,
+                    "harassment/threatening": 1.0,
+                    "hate": 1.0,
+                    "hate/threatening": 1.0,
+                    "self-harm": 1.0,
+                    "self-harm/instructions": 1.0,
+                    "self-harm/intent": 1.0,
+                    "sexual": 1.0,
+                    "sexual/minors": 1.0,
+                    "violence": 1.0,
+                    "violence/graphic": 1.0,
                 }
 
-                result.append(Moderation(
-                    flagged=True,
-                    categories=Categories(**moderation_categories),
-                    category_scores=CategoryScores(**moderation_categories_scores)
-                ))
+                result.append(
+                    Moderation(
+                        flagged=True,
+                        categories=Categories(**moderation_categories),
+                        category_scores=CategoryScores(**moderation_categories_scores),
+                    )
+                )
             else:
                 moderation_categories = {
-                    'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False,
-                    'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False,
-                    'sexual/minors': False, 'violence': False, 'violence/graphic': False
+                    "harassment": False,
+                    "harassment/threatening": False,
+                    "hate": False,
+                    "hate/threatening": False,
+                    "self-harm": False,
+                    "self-harm/instructions": False,
+                    "self-harm/intent": False,
+                    "sexual": False,
+                    "sexual/minors": False,
+                    "violence": False,
+                    "violence/graphic": False,
                 }
                 moderation_categories_scores = {
-                    'harassment': 0.0, 'harassment/threatening': 0.0, 'hate': 0.0, 'hate/threatening': 0.0,
-                    'self-harm': 0.0, 'self-harm/instructions': 0.0, 'self-harm/intent': 0.0, 'sexual': 0.0,
-                    'sexual/minors': 0.0, 'violence': 0.0, 'violence/graphic': 0.0
+                    "harassment": 0.0,
+                    "harassment/threatening": 0.0,
+                    "hate": 0.0,
+                    "hate/threatening": 0.0,
+                    "self-harm": 0.0,
+                    "self-harm/instructions": 0.0,
+                    "self-harm/intent": 0.0,
+                    "sexual": 0.0,
+                    "sexual/minors": 0.0,
+                    "violence": 0.0,
+                    "violence/graphic": 0.0,
                 }
-                result.append(Moderation(
-                    flagged=False,
-                    categories=Categories(**moderation_categories),
-                    category_scores=CategoryScores(**moderation_categories_scores)
-                ))
+                result.append(
+                    Moderation(
+                        flagged=False,
+                        categories=Categories(**moderation_categories),
+                        category_scores=CategoryScores(**moderation_categories_scores),
+                    )
+                )
 
-        return ModerationCreateResponse(
-            id='shiroii kuloko',
-            model=model,
-            results=result
-        )
+        return ModerationCreateResponse(id="shiroii kuloko", model=model, results=result)

+ 6 - 5
api/tests/integration_tests/model_runtime/__mock/openai_remote.py

@@ -6,17 +6,18 @@ from openai.types.model import Model
 
 class MockModelClass:
     """
-        mock class for openai.models.Models
+    mock class for openai.models.Models
     """
+
     def list(
         self,
         **kwargs,
     ) -> list[Model]:
         return [
             Model(
-                id='ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ',
+                id="ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ",
                 created=int(time()),
-                object='model',
-                owned_by='organization:org-123',
+                object="model",
+                owned_by="organization:org-123",
             )
-        ]
+        ]

+ 9 - 10
api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py

@@ -9,7 +9,8 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
 
 
 class MockSpeech2TextClass:
-    def speech2text_create(self: Transcriptions,
+    def speech2text_create(
+        self: Transcriptions,
         *,
         file: FileTypes,
         model: Union[str, Literal["whisper-1"]],
@@ -17,14 +18,12 @@ class MockSpeech2TextClass:
         prompt: str | NotGiven = NOT_GIVEN,
         response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] | NotGiven = NOT_GIVEN,
         temperature: float | NotGiven = NOT_GIVEN,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> Transcription:
-        if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
-            raise InvokeAuthorizationError('Invalid base url')
-        
+        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
+            raise InvokeAuthorizationError("Invalid base url")
+
         if len(self._client.api_key) < 18:
-            raise InvokeAuthorizationError('Invalid API key')
-        
-        return Transcription(
-            text='1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
-        )
+            raise InvokeAuthorizationError("Invalid API key")
+
+        return Transcription(text="1, 2, 3, 4, 5, 6, 7, 8, 9, 10")

+ 65 - 70
api/tests/integration_tests/model_runtime/__mock/xinference.py

@@ -19,40 +19,43 @@ from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage
 
 
 class MockXinferenceClass:
-    def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
-        if not re.match(r'https?:\/\/[^\s\/$.?#].[^\s]*$', self.base_url):
-            raise RuntimeError('404 Not Found')
-        
-        if 'generate' == model_uid:
+    def get_chat_model(
+        self: Client, model_uid: str
+    ) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
+        if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url):
+            raise RuntimeError("404 Not Found")
+
+        if "generate" == model_uid:
             return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
-        if 'chat' == model_uid:
+        if "chat" == model_uid:
             return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
-        if 'embedding' == model_uid:
+        if "embedding" == model_uid:
             return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
-        if 'rerank' == model_uid:
+        if "rerank" == model_uid:
             return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
-        raise RuntimeError('404 Not Found')
-        
+        raise RuntimeError("404 Not Found")
+
     def get(self: Session, url: str, **kwargs):
         response = Response()
-        if 'v1/models/' in url:
+        if "v1/models/" in url:
             # get model uid
-            model_uid = url.split('/')[-1] or ''
-            if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
-                model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
+            model_uid = url.split("/")[-1] or ""
+            if not re.match(
+                r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid
+            ) and model_uid not in ["generate", "chat", "embedding", "rerank"]:
                 response.status_code = 404
-                response._content = b'{}'
+                response._content = b"{}"
                 return response
 
             # check if url is valid
-            if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
+            if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", url):
                 response.status_code = 404
-                response._content = b'{}'
+                response._content = b"{}"
                 return response
-            
-            if model_uid in ['generate', 'chat']:
+
+            if model_uid in ["generate", "chat"]:
                 response.status_code = 200
-                response._content = b'''{
+                response._content = b"""{
                     "model_type": "LLM",
                     "address": "127.0.0.1:43877",
                     "accelerators": [
@@ -75,12 +78,12 @@ class MockXinferenceClass:
                     "revision": null,
                     "context_length": 2048,
                     "replica": 1
-                }'''
+                }"""
                 return response
-            
-            elif model_uid == 'embedding':
+
+            elif model_uid == "embedding":
                 response.status_code = 200
-                response._content = b'''{
+                response._content = b"""{
                     "model_type": "embedding",
                     "address": "127.0.0.1:43877",
                     "accelerators": [
@@ -93,51 +96,48 @@ class MockXinferenceClass:
                     ],
                     "revision": null,
                     "max_tokens": 512
-                }'''
+                }"""
                 return response
-            
-        elif 'v1/cluster/auth' in url:
+
+        elif "v1/cluster/auth" in url:
             response.status_code = 200
-            response._content = b'''{
+            response._content = b"""{
                 "auth": true
-            }'''
+            }"""
             return response
-        
+
     def _check_cluster_authenticated(self):
         self._cluster_authed = True
-        
-    def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool) -> dict:
+
+    def rerank(
+        self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool
+    ) -> dict:
         # check if self._model_uid is a valid uuid
-        if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
-            self._model_uid != 'rerank':
-            raise RuntimeError('404 Not Found')
-        
-        if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._base_url):
-            raise RuntimeError('404 Not Found')
+        if (
+            not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
+            and self._model_uid != "rerank"
+        ):
+            raise RuntimeError("404 Not Found")
+
+        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._base_url):
+            raise RuntimeError("404 Not Found")
 
         if top_n is None:
             top_n = 1
 
         return {
-            'results': [
-                {
-                    'index': i,
-                    'document': doc,
-                    'relevance_score': 0.9
-                }
-                for i, doc in enumerate(documents[:top_n])
+            "results": [
+                {"index": i, "document": doc, "relevance_score": 0.9} for i, doc in enumerate(documents[:top_n])
             ]
         }
-        
-    def create_embedding(
-        self: RESTfulGenerateModelHandle,
-        input: Union[str, list[str]],
-        **kwargs
-    ) -> dict:
+
+    def create_embedding(self: RESTfulGenerateModelHandle, input: Union[str, list[str]], **kwargs) -> dict:
         # check if self._model_uid is a valid uuid
-        if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
-            self._model_uid != 'embedding':
-            raise RuntimeError('404 Not Found')
+        if (
+            not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
+            and self._model_uid != "embedding"
+        ):
+            raise RuntimeError("404 Not Found")
 
         if isinstance(input, str):
             input = [input]
@@ -147,32 +147,27 @@ class MockXinferenceClass:
             object="list",
             model=self._model_uid,
             data=[
-                EmbeddingData(
-                    index=i,
-                    object="embedding",
-                    embedding=[1919.810 for _ in range(768)]
-                )
+                EmbeddingData(index=i, object="embedding", embedding=[1919.810 for _ in range(768)])
                 for i in range(ipt_len)
             ],
-            usage=EmbeddingUsage(
-                prompt_tokens=ipt_len,
-                total_tokens=ipt_len
-            )
+            usage=EmbeddingUsage(prompt_tokens=ipt_len, total_tokens=ipt_len),
         )
 
         return embedding
 
-MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
+
+MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
+
 
 @pytest.fixture
 def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
     if MOCK:
-        monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
-        monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated)
-        monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
-        monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
-        monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)
+        monkeypatch.setattr(Client, "get_model", MockXinferenceClass.get_chat_model)
+        monkeypatch.setattr(Client, "_check_cluster_authenticated", MockXinferenceClass._check_cluster_authenticated)
+        monkeypatch.setattr(Session, "get", MockXinferenceClass.get)
+        monkeypatch.setattr(RESTfulEmbeddingModelHandle, "create_embedding", MockXinferenceClass.create_embedding)
+        monkeypatch.setattr(RESTfulRerankModelHandle, "rerank", MockXinferenceClass.rerank)
     yield
 
     if MOCK:
-        monkeypatch.undo()
+        monkeypatch.undo()

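A minimal sketch of how a test consumes the fixture above (the test name and body are hypothetical; the import path and the MOCK_SWITCH behavior come from the diff): importing the fixture binds it in the test module's namespace, which is the same pattern the real test files use, and MOCK_SWITCH=true routes the SDK calls to MockXinferenceClass.

# Hypothetical consumer module; only the import path is taken from the diff.
from tests.integration_tests.model_runtime.__mock.xinference import setup_xinference_mock


def test_rerank_is_mockable(setup_xinference_mock):
    # With MOCK_SWITCH=true, Client.get_model, Session.get, create_embedding
    # and rerank are monkeypatched above, so no xinference server is needed.
    ...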
+ 26 - 49
api/tests/integration_tests/model_runtime/anthropic/test_llm.py

@@ -10,79 +10,60 @@ from core.model_runtime.model_providers.anthropic.llm.llm import AnthropicLargeL
 from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
 
 
-@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
+@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
 def test_validate_credentials(setup_anthropic_mock):
     model = AnthropicLargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='claude-instant-1.2',
-            credentials={
-                'anthropic_api_key': 'invalid_key'
-            }
-        )
+        model.validate_credentials(model="claude-instant-1.2", credentials={"anthropic_api_key": "invalid_key"})
 
     model.validate_credentials(
-        model='claude-instant-1.2',
-        credentials={
-            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
-        }
+        model="claude-instant-1.2", credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}
     )
 
-@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
 def test_invoke_model(setup_anthropic_mock):
     model = AnthropicLargeLanguageModel()
 
     response = model.invoke(
-        model='claude-instant-1.2',
+        model="claude-instant-1.2",
         credentials={
-            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY'),
-            'anthropic_api_url': os.environ.get('ANTHROPIC_API_URL')
+            "anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY"),
+            "anthropic_api_url": os.environ.get("ANTHROPIC_API_URL"),
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
-        model_parameters={
-            'temperature': 0.0,
-            'top_p': 1.0,
-            'max_tokens': 10
-        },
-        stop=['How'],
+        model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens": 10},
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
 
-@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
 def test_invoke_stream_model(setup_anthropic_mock):
     model = AnthropicLargeLanguageModel()
 
     response = model.invoke(
-        model='claude-instant-1.2',
-        credentials={
-            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
-        },
+        model="claude-instant-1.2",
+        credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
-        model_parameters={
-            'temperature': 0.0,
-            'max_tokens': 100
-        },
+        model_parameters={"temperature": 0.0, "max_tokens": 100},
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -98,18 +79,14 @@ def test_get_num_tokens():
     model = AnthropicLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='claude-instant-1.2',
-        credentials={
-            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
-        },
+        model="claude-instant-1.2",
+        credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+            UserPromptMessage(content="Hello World!"),
+        ],
     )
 
     assert num_tokens == 18

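The [["none"]] argument in these tests reaches the fixture through pytest's indirect parametrization. A standalone sketch of that mechanism (the fixture body here is a stand-in; the real one lives in __mock/anthropic.py and uses the parameter to decide which SDK surfaces to patch):

import pytest


@pytest.fixture
def setup_anthropic_mock(request):
    # request.param is the value supplied via
    # @pytest.mark.parametrize(..., indirect=True), e.g. ["none"].
    yield getattr(request, "param", [])


@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
def test_mock_receives_param(setup_anthropic_mock):
    assert setup_anthropic_mock == ["none"]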
+ 3 - 9
api/tests/integration_tests/model_runtime/anthropic/test_provider.py

@@ -7,17 +7,11 @@ from core.model_runtime.model_providers.anthropic.anthropic import AnthropicProv
 from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
 
 
-@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
+@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
 def test_validate_provider_credentials(setup_anthropic_mock):
     provider = AnthropicProvider()
 
     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={}
-        )
+        provider.validate_provider_credentials(credentials={})
 
-    provider.validate_provider_credentials(
-        credentials={
-            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
-        }
-    )
+    provider.validate_provider_credentials(credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")})

The file diff is not shown because of its large size.
+ 73 - 89
api/tests/integration_tests/model_runtime/azure_openai/test_llm.py


+ 20 - 29
api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py

@@ -8,45 +8,43 @@ from core.model_runtime.model_providers.azure_openai.text_embedding.text_embeddi
 from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
 
 
-@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
+@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
 def test_validate_credentials(setup_openai_mock):
     model = AzureOpenAITextEmbeddingModel()
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='embedding',
+            model="embedding",
             credentials={
-                'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'),
-                'openai_api_key': 'invalid_key',
-                'base_model_name': 'text-embedding-ada-002'
-            }
+                "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
+                "openai_api_key": "invalid_key",
+                "base_model_name": "text-embedding-ada-002",
+            },
         )
 
     model.validate_credentials(
-        model='embedding',
+        model="embedding",
         credentials={
-            'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'),
-            'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'),
-            'base_model_name': 'text-embedding-ada-002'
-        }
+            "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
+            "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
+            "base_model_name": "text-embedding-ada-002",
+        },
     )
 
-@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
 def test_invoke_model(setup_openai_mock):
     model = AzureOpenAITextEmbeddingModel()
 
     result = model.invoke(
-        model='embedding',
+        model="embedding",
         credentials={
-            'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'),
-            'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'),
-            'base_model_name': 'text-embedding-ada-002'
+            "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
+            "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
+            "base_model_name": "text-embedding-ada-002",
         },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        texts=["hello", "world"],
+        user="abc-123",
     )
 
     assert isinstance(result, TextEmbeddingResult)
@@ -58,14 +56,7 @@ def test_get_num_tokens():
     model = AzureOpenAITextEmbeddingModel()
 
     num_tokens = model.get_num_tokens(
-        model='embedding',
-        credentials={
-            'base_model_name': 'text-embedding-ada-002'
-        },
-        texts=[
-            "hello",
-            "world"
-        ]
+        model="embedding", credentials={"base_model_name": "text-embedding-ada-002"}, texts=["hello", "world"]
     )
 
     assert num_tokens == 2

+ 57 - 75
api/tests/integration_tests/model_runtime/baichuan/test_llm.py

@@ -17,111 +17,99 @@ def test_predefined_models():
     assert len(model_schemas) >= 1
     assert isinstance(model_schemas[0], AIModelEntity)
 
+
 def test_validate_credentials_for_chat_model():
     sleep(3)
     model = BaichuanLarguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='baichuan2-turbo',
-            credentials={
-                'api_key': 'invalid_key',
-                'secret_key': 'invalid_key'
-            }
+            model="baichuan2-turbo", credentials={"api_key": "invalid_key", "secret_key": "invalid_key"}
         )
 
     model.validate_credentials(
-        model='baichuan2-turbo',
+        model="baichuan2-turbo",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
-            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
-        }
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
+            "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
+        },
     )
 
+
 def test_invoke_model():
     sleep(3)
     model = BaichuanLarguageModel()
 
     response = model.invoke(
-        model='baichuan2-turbo',
+        model="baichuan2-turbo",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
-            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
+            "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
         },
-        stop=['you'],
+        stop=["you"],
         user="abc-123",
-        stream=False
+        stream=False,
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
+
 def test_invoke_model_with_system_message():
     sleep(3)
     model = BaichuanLarguageModel()
 
     response = model.invoke(
-        model='baichuan2-turbo',
+        model="baichuan2-turbo",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
-            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
+            "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
         },
         prompt_messages=[
-            SystemPromptMessage(
-                content='请记住你是Kasumi。'
-            ),
-            UserPromptMessage(
-                content='现在告诉我你是谁?'
-            )
+            SystemPromptMessage(content="请记住你是Kasumi。"),
+            UserPromptMessage(content="现在告诉我你是谁?"),
         ],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
         },
-        stop=['you'],
+        stop=["you"],
         user="abc-123",
-        stream=False
+        stream=False,
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
+
 def test_invoke_stream_model():
     sleep(3)
     model = BaichuanLarguageModel()
 
     response = model.invoke(
-        model='baichuan2-turbo',
+        model="baichuan2-turbo",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
-            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
+            "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
         },
-        stop=['you'],
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -131,34 +119,31 @@ def test_invoke_stream_model():
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
+
 def test_invoke_with_search():
     sleep(3)
     model = BaichuanLarguageModel()
 
     response = model.invoke(
-        model='baichuan2-turbo',
+        model="baichuan2-turbo",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
-            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
+            "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='北京今天的天气怎么样'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
-            'with_search_enhance': True,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
+            "with_search_enhance": True,
         },
-        stop=['you'],
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
-    total_message = ''
+    total_message = ""
     for chunk in response:
         assert isinstance(chunk, LLMResultChunk)
         assert isinstance(chunk.delta, LLMResultChunkDelta)
@@ -166,25 +151,22 @@ def test_invoke_with_search():
         assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True
         total_message += chunk.delta.message.content
 
-    assert '不' not in total_message
+    assert "不" not in total_message
+
 
 def test_get_num_tokens():
     sleep(3)
     model = BaichuanLarguageModel()
 
     response = model.get_num_tokens(
-        model='baichuan2-turbo',
+        model="baichuan2-turbo",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
-            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
+            "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
-        tools=[]
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        tools=[],
     )
 
     assert isinstance(response, int)
-    assert response == 9
+    assert response == 9

+ 2 - 10
api/tests/integration_tests/model_runtime/baichuan/test_provider.py

@@ -10,14 +10,6 @@ def test_validate_provider_credentials():
     provider = BaichuanProvider()
 
     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={
-                'api_key': 'hahahaha'
-            }
-        )
+        provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})
 
-    provider.validate_provider_credentials(
-        credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY')
-        }
-    )
+    provider.validate_provider_credentials(credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")})

+ 15 - 27
api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py

@@ -11,18 +11,10 @@ def test_validate_credentials():
     model = BaichuanTextEmbeddingModel()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='baichuan-text-embedding',
-            credentials={
-                'api_key': 'invalid_key'
-            }
-        )
+        model.validate_credentials(model="baichuan-text-embedding", credentials={"api_key": "invalid_key"})
 
     model.validate_credentials(
-        model='baichuan-text-embedding',
-        credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY')
-        }
+        model="baichuan-text-embedding", credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")}
     )
 
 
@@ -30,44 +22,40 @@ def test_invoke_model():
     model = BaichuanTextEmbeddingModel()
 
     result = model.invoke(
-        model='baichuan-text-embedding',
+        model="baichuan-text-embedding",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
         },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        texts=["hello", "world"],
+        user="abc-123",
     )
 
     assert isinstance(result, TextEmbeddingResult)
     assert len(result.embeddings) == 2
     assert result.usage.total_tokens == 6
 
+
 def test_get_num_tokens():
     model = BaichuanTextEmbeddingModel()
 
     num_tokens = model.get_num_tokens(
-        model='baichuan-text-embedding',
+        model="baichuan-text-embedding",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
         },
-        texts=[
-            "hello",
-            "world"
-        ]
+        texts=["hello", "world"],
     )
 
     assert num_tokens == 2
 
+
 def test_max_chunks():
     model = BaichuanTextEmbeddingModel()
 
     result = model.invoke(
-        model='baichuan-text-embedding',
+        model="baichuan-text-embedding",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
         },
         texts=[
             "hello",
@@ -92,8 +80,8 @@ def test_max_chunks():
             "world",
             "hello",
             "world",
-        ]
+        ],
     )
 
     assert isinstance(result, TextEmbeddingResult)
-    assert len(result.embeddings) == 22
+    assert len(result.embeddings) == 22

+ 25 - 41
api/tests/integration_tests/model_runtime/bedrock/test_llm.py

@@ -13,77 +13,63 @@ def test_validate_credentials():
     model = BedrockLargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='meta.llama2-13b-chat-v1',
-            credentials={
-                'anthropic_api_key': 'invalid_key'
-            }
-        )
+        model.validate_credentials(model="meta.llama2-13b-chat-v1", credentials={"anthropic_api_key": "invalid_key"})
 
     model.validate_credentials(
-        model='meta.llama2-13b-chat-v1',
+        model="meta.llama2-13b-chat-v1",
         credentials={
             "aws_region": os.getenv("AWS_REGION"),
             "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
-            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
-        }
+            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
+        },
     )
 
+
 def test_invoke_model():
     model = BedrockLargeLanguageModel()
 
     response = model.invoke(
-        model='meta.llama2-13b-chat-v1',
+        model="meta.llama2-13b-chat-v1",
         credentials={
             "aws_region": os.getenv("AWS_REGION"),
             "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
-            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
-        model_parameters={
-            'temperature': 0.0,
-            'top_p': 1.0,
-            'max_tokens_to_sample': 10
-        },
-        stop=['How'],
+        model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens_to_sample": 10},
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
 
+
 def test_invoke_stream_model():
     model = BedrockLargeLanguageModel()
 
     response = model.invoke(
-        model='meta.llama2-13b-chat-v1',
+        model="meta.llama2-13b-chat-v1",
         credentials={
             "aws_region": os.getenv("AWS_REGION"),
             "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
-            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
-        model_parameters={
-            'temperature': 0.0,
-            'max_tokens_to_sample': 100
-        },
+        model_parameters={"temperature": 0.0, "max_tokens_to_sample": 100},
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -100,20 +86,18 @@ def test_get_num_tokens():
     model = BedrockLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='meta.llama2-13b-chat-v1',
-        credentials = {
+        model="meta.llama2-13b-chat-v1",
+        credentials={
             "aws_region": os.getenv("AWS_REGION"),
             "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
-            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
         },
         messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+            UserPromptMessage(content="Hello World!"),
+        ],
     )
 
     assert num_tokens == 18

+ 2 - 4
api/tests/integration_tests/model_runtime/bedrock/test_provider.py

@@ -10,14 +10,12 @@ def test_validate_provider_credentials():
     provider = BedrockProvider()
 
     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={}
-        )
+        provider.validate_provider_credentials(credentials={})
 
     provider.validate_provider_credentials(
         credentials={
             "aws_region": os.getenv("AWS_REGION"),
             "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
-            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
         }
     )

+ 77 - 138
api/tests/integration_tests/model_runtime/chatglm/test_llm.py

@@ -23,79 +23,64 @@ def test_predefined_models():
     assert len(model_schemas) >= 1
     assert isinstance(model_schemas[0], AIModelEntity)
 
-@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
 def test_validate_credentials_for_chat_model(setup_openai_mock):
     model = ChatGLMLargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='chatglm2-6b',
-            credentials={
-                'api_base': 'invalid_key'
-            }
-        )
-
-    model.validate_credentials(
-        model='chatglm2-6b',
-        credentials={
-            'api_base': os.environ.get('CHATGLM_API_BASE')
-        }
-    )
+        model.validate_credentials(model="chatglm2-6b", credentials={"api_base": "invalid_key"})
+
+    model.validate_credentials(model="chatglm2-6b", credentials={"api_base": os.environ.get("CHATGLM_API_BASE")})
+
 
-@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
 def test_invoke_model(setup_openai_mock):
     model = ChatGLMLargeLanguageModel()
 
     response = model.invoke(
-        model='chatglm2-6b',
-        credentials={
-            'api_base': os.environ.get('CHATGLM_API_BASE')
-        },
+        model="chatglm2-6b",
+        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
+            "temperature": 0.7,
+            "top_p": 1.0,
         },
-        stop=['you'],
+        stop=["you"],
         user="abc-123",
-        stream=False
+        stream=False,
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
-@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
 def test_invoke_stream_model(setup_openai_mock):
     model = ChatGLMLargeLanguageModel()
 
     response = model.invoke(
-        model='chatglm2-6b',
-        credentials={
-            'api_base': os.environ.get('CHATGLM_API_BASE')
-        },
+        model="chatglm2-6b",
+        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
+            "temperature": 0.7,
+            "top_p": 1.0,
         },
-        stop=['you'],
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -105,56 +90,45 @@ def test_invoke_stream_model(setup_openai_mock):
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
-@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
 def test_invoke_stream_model_with_functions(setup_openai_mock):
     model = ChatGLMLargeLanguageModel()
 
     response = model.invoke(
-        model='chatglm3-6b',
-        credentials={
-            'api_base': os.environ.get('CHATGLM_API_BASE')
-        },
+        model="chatglm3-6b",
+        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
         prompt_messages=[
             SystemPromptMessage(
-                content='你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。'
+                content="你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。"
             ),
-            UserPromptMessage(
-                content='波士顿天气如何?'
-            )
+            UserPromptMessage(content="波士顿天气如何?"),
         ],
         model_parameters={
-            'temperature': 0,
-            'top_p': 1.0,
+            "temperature": 0,
+            "top_p": 1.0,
         },
-        stop=['you'],
-        user='abc-123',
+        stop=["you"],
+        user="abc-123",
         stream=True,
         tools=[
             PromptMessageTool(
-                name='get_current_weather',
-                description='Get the current weather in a given location',
+                name="get_current_weather",
+                description="Get the current weather in a given location",
                 parameters={
                     "type": "object",
                     "properties": {
-                        "location": {
-                        "type": "string",
-                            "description": "The city and state e.g. San Francisco, CA"
-                        },
-                        "unit": {
-                            "type": "string",
-                            "enum": ["celsius", "fahrenheit"]
-                        }
+                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                     },
-                    "required": [
-                        "location"
-                    ]
-                }
+                    "required": ["location"],
+                },
             )
-        ]
+        ],
     )
 
     assert isinstance(response, Generator)
-    
+
     call: LLMResultChunk = None
     chunks = []
 
@@ -170,122 +144,87 @@ def test_invoke_stream_model_with_functions(setup_openai_mock):
             break
 
     assert call is not None
-    assert call.delta.message.tool_calls[0].function.name == 'get_current_weather'
+    assert call.delta.message.tool_calls[0].function.name == "get_current_weather"
 
-@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
 def test_invoke_model_with_functions(setup_openai_mock):
     model = ChatGLMLargeLanguageModel()
 
     response = model.invoke(
-        model='chatglm3-6b',
-        credentials={
-            'api_base': os.environ.get('CHATGLM_API_BASE')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='What is the weather like in San Francisco?'
-            )
-        ],
+        model="chatglm3-6b",
+        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
+        prompt_messages=[UserPromptMessage(content="What is the weather like in San Francisco?")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
+            "temperature": 0.7,
+            "top_p": 1.0,
         },
-        stop=['you'],
-        user='abc-123',
+        stop=["you"],
+        user="abc-123",
         stream=False,
         tools=[
             PromptMessageTool(
-                name='get_current_weather',
-                description='Get the current weather in a given location',
+                name="get_current_weather",
+                description="Get the current weather in a given location",
                 parameters={
                     "type": "object",
                     "properties": {
-                        "location": {
-                        "type": "string",
-                            "description": "The city and state e.g. San Francisco, CA"
-                        },
-                        "unit": {
-                            "type": "string",
-                            "enum": [
-                                "c",
-                                "f"
-                            ]
-                        }
+                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                        "unit": {"type": "string", "enum": ["c", "f"]},
                     },
-                    "required": [
-                        "location"
-                    ]
-                }
+                    "required": ["location"],
+                },
             )
-        ]
+        ],
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
-    assert response.message.tool_calls[0].function.name == 'get_current_weather'
+    assert response.message.tool_calls[0].function.name == "get_current_weather"
 
 
 def test_get_num_tokens():
     model = ChatGLMLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='chatglm2-6b',
-        credentials={
-            'api_base': os.environ.get('CHATGLM_API_BASE')
-        },
+        model="chatglm2-6b",
+        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
         tools=[
             PromptMessageTool(
-                name='get_current_weather',
-                description='Get the current weather in a given location',
+                name="get_current_weather",
+                description="Get the current weather in a given location",
                 parameters={
                     "type": "object",
                     "properties": {
-                        "location": {
-                        "type": "string",
-                            "description": "The city and state e.g. San Francisco, CA"
-                        },
-                        "unit": {
-                            "type": "string",
-                            "enum": [
-                                "c",
-                                "f"
-                            ]
-                        }
+                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                        "unit": {"type": "string", "enum": ["c", "f"]},
                     },
-                    "required": [
-                        "location"
-                    ]
-                }
+                    "required": ["location"],
+                },
             )
-        ]
+        ],
     )
 
     assert isinstance(num_tokens, int)
     assert num_tokens == 77
 
     num_tokens = model.get_num_tokens(
-        model='chatglm2-6b',
-        credentials={
-            'api_base': os.environ.get('CHATGLM_API_BASE')
-        },
+        model="chatglm2-6b",
+        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
     )
 
     assert isinstance(num_tokens, int)
-    assert num_tokens == 21
+    assert num_tokens == 21

+ 3 - 11
api/tests/integration_tests/model_runtime/chatglm/test_provider.py

@@ -7,19 +7,11 @@ from core.model_runtime.model_providers.chatglm.chatglm import ChatGLMProvider
 from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
 
 
-@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
 def test_validate_provider_credentials(setup_openai_mock):
     provider = ChatGLMProvider()
 
     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={
-                'api_base': 'hahahaha'
-            }
-        )
+        provider.validate_provider_credentials(credentials={"api_base": "hahahaha"})
 
-    provider.validate_provider_credentials(
-        credentials={
-            'api_base': os.environ.get('CHATGLM_API_BASE')
-        }
-    )
+    provider.validate_provider_credentials(credentials={"api_base": os.environ.get("CHATGLM_API_BASE")})

+ 54 - 129
api/tests/integration_tests/model_runtime/cohere/test_llm.py

@@ -13,87 +13,49 @@ def test_validate_credentials_for_chat_model():
     model = CohereLargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='command-light-chat',
-            credentials={
-                'api_key': 'invalid_key'
-            }
-        )
-
-    model.validate_credentials(
-        model='command-light-chat',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        }
-    )
+        model.validate_credentials(model="command-light-chat", credentials={"api_key": "invalid_key"})
+
+    model.validate_credentials(model="command-light-chat", credentials={"api_key": os.environ.get("COHERE_API_KEY")})
 
 
 def test_validate_credentials_for_completion_model():
     model = CohereLargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='command-light',
-            credentials={
-                'api_key': 'invalid_key'
-            }
-        )
-
-    model.validate_credentials(
-        model='command-light',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        }
-    )
+        model.validate_credentials(model="command-light", credentials={"api_key": "invalid_key"})
+
+    model.validate_credentials(model="command-light", credentials={"api_key": os.environ.get("COHERE_API_KEY")})
 
 
 def test_invoke_completion_model():
     model = CohereLargeLanguageModel()
 
-    credentials = {
-        'api_key': os.environ.get('COHERE_API_KEY')
-    }
+    credentials = {"api_key": os.environ.get("COHERE_API_KEY")}
 
     result = model.invoke(
-        model='command-light',
+        model="command-light",
         credentials=credentials,
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.0,
-            'max_tokens': 1
-        },
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        model_parameters={"temperature": 0.0, "max_tokens": 1},
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(result, LLMResult)
     assert len(result.message.content) > 0
-    assert model._num_tokens_from_string('command-light', credentials, result.message.content) == 1
+    assert model._num_tokens_from_string("command-light", credentials, result.message.content) == 1
 
 
 def test_invoke_stream_completion_model():
     model = CohereLargeLanguageModel()
 
     result = model.invoke(
-        model='command-light',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.0,
-            'max_tokens': 100
-        },
+        model="command-light",
+        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        model_parameters={"temperature": 0.0, "max_tokens": 100},
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(result, Generator)
@@ -109,28 +71,24 @@ def test_invoke_chat_model():
     model = CohereLargeLanguageModel()
 
     result = model.invoke(
-        model='command-light-chat',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        },
+        model="command-light-chat",
+        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
         model_parameters={
-            'temperature': 0.0,
-            'p': 0.99,
-            'presence_penalty': 0.0,
-            'frequency_penalty': 0.0,
-            'max_tokens': 10
+            "temperature": 0.0,
+            "p": 0.99,
+            "presence_penalty": 0.0,
+            "frequency_penalty": 0.0,
+            "max_tokens": 10,
         },
-        stop=['How'],
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(result, LLMResult)
@@ -141,24 +99,17 @@ def test_invoke_stream_chat_model():
     model = CohereLargeLanguageModel()
 
     result = model.invoke(
-        model='command-light-chat',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        },
+        model="command-light-chat",
+        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
-        model_parameters={
-            'temperature': 0.0,
-            'max_tokens': 100
-        },
+        model_parameters={"temperature": 0.0, "max_tokens": 100},
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(result, Generator)
@@ -177,32 +128,22 @@ def test_get_num_tokens():
     model = CohereLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='command-light',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+        model="command-light",
+        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
     )
 
     assert num_tokens == 3
 
     num_tokens = model.get_num_tokens(
-        model='command-light-chat',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        },
+        model="command-light-chat",
+        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+            UserPromptMessage(content="Hello World!"),
+        ],
     )
 
     assert num_tokens == 15
@@ -213,25 +154,17 @@ def test_fine_tuned_model():
 
     # test invoke
     result = model.invoke(
-        model='85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY'),
-            'mode': 'completion'
-        },
+        model="85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft",
+        credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "completion"},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
-        model_parameters={
-            'temperature': 0.0,
-            'max_tokens': 100
-        },
+        model_parameters={"temperature": 0.0, "max_tokens": 100},
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(result, LLMResult)
@@ -242,25 +175,17 @@ def test_fine_tuned_chat_model():
 
     # test invoke
     result = model.invoke(
-        model='94f2d55a-4c79-4c00-bde4-23962e74b170-ft',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY'),
-            'mode': 'chat'
-        },
+        model="94f2d55a-4c79-4c00-bde4-23962e74b170-ft",
+        credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "chat"},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
-        model_parameters={
-            'temperature': 0.0,
-            'max_tokens': 100
-        },
+        model_parameters={"temperature": 0.0, "max_tokens": 100},
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(result, LLMResult)

+ 2 - 8
api/tests/integration_tests/model_runtime/cohere/test_provider.py

@@ -10,12 +10,6 @@ def test_validate_provider_credentials():
     provider = CohereProvider()
 
     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={}
-        )
+        provider.validate_provider_credentials(credentials={})
 
-    provider.validate_provider_credentials(
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        }
-    )
+    provider.validate_provider_credentials(credentials={"api_key": os.environ.get("COHERE_API_KEY")})

+ 7 - 19
api/tests/integration_tests/model_runtime/cohere/test_rerank.py

@@ -11,29 +11,17 @@ def test_validate_credentials():
     model = CohereRerankModel()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='rerank-english-v2.0',
-            credentials={
-                'api_key': 'invalid_key'
-            }
-        )
-
-    model.validate_credentials(
-        model='rerank-english-v2.0',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        }
-    )
+        model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": "invalid_key"})
+
+    model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")})
 
 
 def test_invoke_model():
     model = CohereRerankModel()
 
     result = model.invoke(
-        model='rerank-english-v2.0',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        },
+        model="rerank-english-v2.0",
+        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
         query="What is the capital of the United States?",
         docs=[
             "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
@@ -41,9 +29,9 @@ def test_invoke_model():
             "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) "
             "is the capital of the United States. It is a federal district. The President of the USA and many major "
             "national government offices are in the territory. This makes it the political center of the United "
-            "States of America."
+            "States of America.",
         ],
-        score_threshold=0.8
+        score_threshold=0.8,
     )
 
     assert isinstance(result, RerankResult)

+ 9 - 29
api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py

@@ -11,18 +11,10 @@ def test_validate_credentials():
     model = CohereTextEmbeddingModel()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='embed-multilingual-v3.0',
-            credentials={
-                'api_key': 'invalid_key'
-            }
-        )
+        model.validate_credentials(model="embed-multilingual-v3.0", credentials={"api_key": "invalid_key"})
 
     model.validate_credentials(
-        model='embed-multilingual-v3.0',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        }
+        model="embed-multilingual-v3.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")}
     )
 
 
@@ -30,17 +22,10 @@ def test_invoke_model():
     model = CohereTextEmbeddingModel()
 
     result = model.invoke(
-        model='embed-multilingual-v3.0',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        },
-        texts=[
-            "hello",
-            "world",
-            " ".join(["long_text"] * 100),
-            " ".join(["another_long_text"] * 100)
-        ],
-        user="abc-123"
+        model="embed-multilingual-v3.0",
+        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
+        texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
+        user="abc-123",
     )
 
     assert isinstance(result, TextEmbeddingResult)
@@ -52,14 +37,9 @@ def test_get_num_tokens():
     model = CohereTextEmbeddingModel()
 
     num_tokens = model.get_num_tokens(
-        model='embed-multilingual-v3.0',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        },
-        texts=[
-            "hello",
-            "world"
-        ]
+        model="embed-multilingual-v3.0",
+        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
+        texts=["hello", "world"],
     )
 
     assert num_tokens == 3

The file diff is not shown because of its large size.
+ 38 - 71
api/tests/integration_tests/model_runtime/google/test_llm.py


+ 3 - 9
api/tests/integration_tests/model_runtime/google/test_provider.py

@@ -7,17 +7,11 @@ from core.model_runtime.model_providers.google.google import GoogleProvider
 from tests.integration_tests.model_runtime.__mock.google import setup_google_mock
 
 
-@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True)
+@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
 def test_validate_provider_credentials(setup_google_mock):
     provider = GoogleProvider()
 
     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={}
-        )
+        provider.validate_provider_credentials(credentials={})
 
-    provider.validate_provider_credentials(
-        credentials={
-            'google_api_key': os.environ.get('GOOGLE_API_KEY')
-        }
-    )
+    provider.validate_provider_credentials(credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")})

+ 117 - 143
api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py

@@ -10,87 +10,75 @@ from core.model_runtime.model_providers.huggingface_hub.llm.llm import Huggingfa
 from tests.integration_tests.model_runtime.__mock.huggingface import setup_huggingface_mock
 
 
-@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
+@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
 def test_hosted_inference_api_validate_credentials(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='HuggingFaceH4/zephyr-7b-beta',
-            credentials={
-                'huggingfacehub_api_type': 'hosted_inference_api',
-                'huggingfacehub_api_token': 'invalid_key'
-            }
+            model="HuggingFaceH4/zephyr-7b-beta",
+            credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"},
         )
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='fake-model',
-            credentials={
-                'huggingfacehub_api_type': 'hosted_inference_api',
-                'huggingfacehub_api_token': 'invalid_key'
-            }
+            model="fake-model",
+            credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"},
         )
 
     model.validate_credentials(
-        model='HuggingFaceH4/zephyr-7b-beta',
+        model="HuggingFaceH4/zephyr-7b-beta",
         credentials={
-            'huggingfacehub_api_type': 'hosted_inference_api',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
-        }
+            "huggingfacehub_api_type": "hosted_inference_api",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+        },
     )
 
-@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
 def test_hosted_inference_api_invoke_model(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()
 
     response = model.invoke(
-        model='HuggingFaceH4/zephyr-7b-beta',
+        model="HuggingFaceH4/zephyr-7b-beta",
         credentials={
-            'huggingfacehub_api_type': 'hosted_inference_api',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
+            "huggingfacehub_api_type": "hosted_inference_api",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Who are you?'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Who are you?")],
         model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
         },
-        stop=['How'],
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
 
-@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
 def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()
 
     response = model.invoke(
-        model='HuggingFaceH4/zephyr-7b-beta',
+        model="HuggingFaceH4/zephyr-7b-beta",
         credentials={
-            'huggingfacehub_api_type': 'hosted_inference_api',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
+            "huggingfacehub_api_type": "hosted_inference_api",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Who are you?'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Who are you?")],
         model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
         },
-        stop=['How'],
+        stop=["How"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -101,86 +89,81 @@ def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock):
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
-@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
 def test_inference_endpoints_text_generation_validate_credentials(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='openchat/openchat_3.5',
+            model="openchat/openchat_3.5",
             credentials={
-                'huggingfacehub_api_type': 'inference_endpoints',
-                'huggingfacehub_api_token': 'invalid_key',
-                'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
-                'task_type': 'text-generation'
-            }
+                "huggingfacehub_api_type": "inference_endpoints",
+                "huggingfacehub_api_token": "invalid_key",
+                "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
+                "task_type": "text-generation",
+            },
         )
 
     model.validate_credentials(
-        model='openchat/openchat_3.5',
+        model="openchat/openchat_3.5",
         credentials={
-            'huggingfacehub_api_type': 'inference_endpoints',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
-            'task_type': 'text-generation'
-        }
+            "huggingfacehub_api_type": "inference_endpoints",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
+            "task_type": "text-generation",
+        },
     )
 
-@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
 def test_inference_endpoints_text_generation_invoke_model(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()
 
     response = model.invoke(
-        model='openchat/openchat_3.5',
+        model="openchat/openchat_3.5",
         credentials={
-            'huggingfacehub_api_type': 'inference_endpoints',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
-            'task_type': 'text-generation'
+            "huggingfacehub_api_type": "inference_endpoints",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
+            "task_type": "text-generation",
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Who are you?'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Who are you?")],
         model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
         },
-        stop=['How'],
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
 
-@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
 def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()
 
     response = model.invoke(
-        model='openchat/openchat_3.5',
+        model="openchat/openchat_3.5",
         credentials={
-            'huggingfacehub_api_type': 'inference_endpoints',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
-            'task_type': 'text-generation'
+            "huggingfacehub_api_type": "inference_endpoints",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
+            "task_type": "text-generation",
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Who are you?'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Who are you?")],
         model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
         },
-        stop=['How'],
+        stop=["How"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -191,86 +174,81 @@ def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingface_mock):
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
-@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
 def test_inference_endpoints_text2text_generation_validate_credentials(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='google/mt5-base',
+            model="google/mt5-base",
             credentials={
-                'huggingfacehub_api_type': 'inference_endpoints',
-                'huggingfacehub_api_token': 'invalid_key',
-                'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
-                'task_type': 'text2text-generation'
-            }
+                "huggingfacehub_api_type": "inference_endpoints",
+                "huggingfacehub_api_token": "invalid_key",
+                "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
+                "task_type": "text2text-generation",
+            },
         )
 
     model.validate_credentials(
-        model='google/mt5-base',
+        model="google/mt5-base",
         credentials={
-            'huggingfacehub_api_type': 'inference_endpoints',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
-            'task_type': 'text2text-generation'
-        }
+            "huggingfacehub_api_type": "inference_endpoints",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
+            "task_type": "text2text-generation",
+        },
     )
 
-@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
 def test_inference_endpoints_text2text_generation_invoke_model(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()
 
     response = model.invoke(
-        model='google/mt5-base',
+        model="google/mt5-base",
         credentials={
-            'huggingfacehub_api_type': 'inference_endpoints',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
-            'task_type': 'text2text-generation'
+            "huggingfacehub_api_type": "inference_endpoints",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
+            "task_type": "text2text-generation",
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Who are you?'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Who are you?")],
         model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
         },
-        stop=['How'],
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
 
-@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
 def test_inference_endpoints_text2text_generation_invoke_stream_model(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()
 
     response = model.invoke(
-        model='google/mt5-base',
+        model="google/mt5-base",
         credentials={
-            'huggingfacehub_api_type': 'inference_endpoints',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
-            'task_type': 'text2text-generation'
+            "huggingfacehub_api_type": "inference_endpoints",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
+            "task_type": "text2text-generation",
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Who are you?'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Who are you?")],
         model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
         },
-        stop=['How'],
+        stop=["How"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -286,18 +264,14 @@ def test_get_num_tokens():
     model = HuggingfaceHubLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='google/mt5-base',
+        model="google/mt5-base",
         credentials={
-            'huggingfacehub_api_type': 'inference_endpoints',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
-            'task_type': 'text2text-generation'
+            "huggingfacehub_api_type": "inference_endpoints",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
+            "task_type": "text2text-generation",
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
     )
 
     assert num_tokens == 7

+ 40 - 49
api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py

@@ -14,19 +14,19 @@ def test_hosted_inference_api_validate_credentials():
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='facebook/bart-base',
+            model="facebook/bart-base",
             credentials={
-                'huggingfacehub_api_type': 'hosted_inference_api',
-                'huggingfacehub_api_token': 'invalid_key',
-            }
+                "huggingfacehub_api_type": "hosted_inference_api",
+                "huggingfacehub_api_token": "invalid_key",
+            },
         )
 
     model.validate_credentials(
-        model='facebook/bart-base',
+        model="facebook/bart-base",
         credentials={
-            'huggingfacehub_api_type': 'hosted_inference_api',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-        }
+            "huggingfacehub_api_type": "hosted_inference_api",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+        },
     )
 
 
@@ -34,15 +34,12 @@ def test_hosted_inference_api_invoke_model():
     model = HuggingfaceHubTextEmbeddingModel()
 
     result = model.invoke(
-        model='facebook/bart-base',
+        model="facebook/bart-base",
         credentials={
-            'huggingfacehub_api_type': 'hosted_inference_api',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
+            "huggingfacehub_api_type": "hosted_inference_api",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
         },
-        texts=[
-            "hello",
-            "world"
-        ]
+        texts=["hello", "world"],
     )
 
     assert isinstance(result, TextEmbeddingResult)
@@ -55,25 +52,25 @@ def test_inference_endpoints_validate_credentials():
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='all-MiniLM-L6-v2',
+            model="all-MiniLM-L6-v2",
             credentials={
-                'huggingfacehub_api_type': 'inference_endpoints',
-                'huggingfacehub_api_token': 'invalid_key',
-                'huggingface_namespace': 'Dify-AI',
-                'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
-                'task_type': 'feature-extraction'
-            }
+                "huggingfacehub_api_type": "inference_endpoints",
+                "huggingfacehub_api_token": "invalid_key",
+                "huggingface_namespace": "Dify-AI",
+                "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
+                "task_type": "feature-extraction",
+            },
         )
 
     model.validate_credentials(
-        model='all-MiniLM-L6-v2',
+        model="all-MiniLM-L6-v2",
         credentials={
-            'huggingfacehub_api_type': 'inference_endpoints',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'huggingface_namespace': 'Dify-AI',
-            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
-            'task_type': 'feature-extraction'
-        }
+            "huggingfacehub_api_type": "inference_endpoints",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+            "huggingface_namespace": "Dify-AI",
+            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
+            "task_type": "feature-extraction",
+        },
     )
 
 
@@ -81,18 +78,15 @@ def test_inference_endpoints_invoke_model():
     model = HuggingfaceHubTextEmbeddingModel()
 
     result = model.invoke(
-        model='all-MiniLM-L6-v2',
+        model="all-MiniLM-L6-v2",
         credentials={
-            'huggingfacehub_api_type': 'inference_endpoints',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'huggingface_namespace': 'Dify-AI',
-            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
-            'task_type': 'feature-extraction'
+            "huggingfacehub_api_type": "inference_endpoints",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+            "huggingface_namespace": "Dify-AI",
+            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
+            "task_type": "feature-extraction",
         },
-        texts=[
-            "hello",
-            "world"
-        ]
+        texts=["hello", "world"],
     )
 
     assert isinstance(result, TextEmbeddingResult)
@@ -104,18 +98,15 @@ def test_get_num_tokens():
     model = HuggingfaceHubTextEmbeddingModel()
 
     num_tokens = model.get_num_tokens(
-        model='all-MiniLM-L6-v2',
+        model="all-MiniLM-L6-v2",
         credentials={
-            'huggingfacehub_api_type': 'inference_endpoints',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'huggingface_namespace': 'Dify-AI',
-            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
-            'task_type': 'feature-extraction'
+            "huggingfacehub_api_type": "inference_endpoints",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+            "huggingface_namespace": "Dify-AI",
+            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
+            "task_type": "feature-extraction",
         },
-        texts=[
-            "hello",
-            "world"
-        ]
+        texts=["hello", "world"],
     )
 
     assert num_tokens == 2

+ 18 - 20
api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py

@@ -10,61 +10,59 @@ from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import (
 )
 from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass
 
-MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
+MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
 
 
 @pytest.fixture
 def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
     if MOCK:
-        monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
-        monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
-        monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
-        monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
+        monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter)
+        monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize)
+        monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings)
+        monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank)
     yield
 
     if MOCK:
         monkeypatch.undo()
 
 
-@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
+@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
 def test_validate_credentials(setup_tei_mock):
     model = HuggingfaceTeiTextEmbeddingModel()
     # model name is only used in mock
-    model_name = 'embedding'
+    model_name = "embedding"
 
     if MOCK:
         # TEI Provider will check the model type by API endpoint; at a real server, the model type is correct,
         # so we don't need to check the model type here. Only check it in the mock.
         with pytest.raises(CredentialsValidateFailedError):
             model.validate_credentials(
-                model='reranker',
+                model="reranker",
                 credentials={
-                    'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
-                }
+                    "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
+                },
             )
 
     model.validate_credentials(
         model=model_name,
         credentials={
-            'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
-        }
+            "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
+        },
     )
 
-@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
 def test_invoke_model(setup_tei_mock):
     model = HuggingfaceTeiTextEmbeddingModel()
-    model_name = 'embedding'
+    model_name = "embedding"
 
     result = model.invoke(
         model=model_name,
         credentials={
-            'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
+            "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
         },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        texts=["hello", "world"],
+        user="abc-123",
     )
 
     assert isinstance(result, TextEmbeddingResult)

+ 20 - 18
api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py

@@ -11,63 +11,65 @@ from core.model_runtime.model_providers.huggingface_tei.rerank.rerank import (
 from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper
 from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass
 
-MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
+MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
 
 
 @pytest.fixture
 def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
     if MOCK:
-        monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
-        monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
-        monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
-        monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
+        monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter)
+        monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize)
+        monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings)
+        monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank)
     yield
 
     if MOCK:
         monkeypatch.undo()
 
-@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
 def test_validate_credentials(setup_tei_mock):
     model = HuggingfaceTeiRerankModel()
     # model name is only used in mock
-    model_name = 'reranker'
+    model_name = "reranker"
 
     if MOCK:
         # TEI Provider will check the model type by API endpoint; at a real server, the model type is correct,
         # so we don't need to check the model type here. Only check it in the mock.
         with pytest.raises(CredentialsValidateFailedError):
             model.validate_credentials(
-                model='embedding',
+                model="embedding",
                 credentials={
-                    'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
-                }
+                    "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
+                },
             )
 
     model.validate_credentials(
         model=model_name,
         credentials={
-            'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
-        }
+            "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
+        },
     )
 
-@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
 def test_invoke_model(setup_tei_mock):
     model = HuggingfaceTeiRerankModel()
     # model name is only used in mock
-    model_name = 'reranker'
+    model_name = "reranker"
 
     result = model.invoke(
         model=model_name,
         credentials={
-            'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
+            "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
         },
         query="Who is Kasumi?",
         docs=[
-            "Kasumi is a girl's name of Japanese origin meaning \"mist\".",
+            'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
             "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
-            "and she leads a team named PopiParty."
+            "and she leads a team named PopiParty.",
         ],
-        score_threshold=0.8
+        score_threshold=0.8,
     )
 
     assert isinstance(result, RerankResult)

+ 24 - 45
api/tests/integration_tests/model_runtime/hunyuan/test_llm.py

@@ -14,19 +14,15 @@ def test_validate_credentials():
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='hunyuan-standard',
-            credentials={
-                'secret_id': 'invalid_key',
-                'secret_key': 'invalid_key'
-            }
+            model="hunyuan-standard", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"}
         )
 
     model.validate_credentials(
-        model='hunyuan-standard',
+        model="hunyuan-standard",
         credentials={
-            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
-            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
-        }
+            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
+            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
+        },
     )
 
 
@@ -34,23 +30,16 @@ def test_invoke_model():
     model = HunyuanLargeLanguageModel()
 
     response = model.invoke(
-        model='hunyuan-standard',
+        model="hunyuan-standard",
         credentials={
-            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
-            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hi'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.5,
-            'max_tokens': 10
+            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
+            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
         },
-        stop=['How'],
+        prompt_messages=[UserPromptMessage(content="Hi")],
+        model_parameters={"temperature": 0.5, "max_tokens": 10},
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, LLMResult)
@@ -61,23 +50,15 @@ def test_invoke_stream_model():
     model = HunyuanLargeLanguageModel()
 
     response = model.invoke(
-        model='hunyuan-standard',
+        model="hunyuan-standard",
         credentials={
-            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
-            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hi'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.5,
-            'max_tokens': 100,
-            'seed': 1234
+            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
+            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
         },
+        prompt_messages=[UserPromptMessage(content="Hi")],
+        model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -93,19 +74,17 @@ def test_get_num_tokens():
     model = HunyuanLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='hunyuan-standard',
+        model="hunyuan-standard",
         credentials={
-            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
-            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
+            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
+            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+            UserPromptMessage(content="Hello World!"),
+        ],
     )
 
     assert num_tokens == 14

+ 3 - 8
api/tests/integration_tests/model_runtime/hunyuan/test_provider.py

@@ -10,16 +10,11 @@ def test_validate_provider_credentials():
     provider = HunyuanProvider()
 
     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={
-                'secret_id': 'invalid_key',
-                'secret_key': 'invalid_key'
-            }
-        )
+        provider.validate_provider_credentials(credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"})
 
     provider.validate_provider_credentials(
         credentials={
-            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
-            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
+            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
+            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
         }
     )

+ 21 - 29
api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py

@@ -12,19 +12,15 @@ def test_validate_credentials():
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='hunyuan-embedding',
-            credentials={
-                'secret_id': 'invalid_key',
-                'secret_key': 'invalid_key'
-            }
+            model="hunyuan-embedding", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"}
         )
 
     model.validate_credentials(
-        model='hunyuan-embedding',
+        model="hunyuan-embedding",
         credentials={
-            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
-            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
-        }
+            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
+            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
+        },
     )
 
 
@@ -32,47 +28,43 @@ def test_invoke_model():
     model = HunyuanTextEmbeddingModel()
 
     result = model.invoke(
-        model='hunyuan-embedding',
+        model="hunyuan-embedding",
         credentials={
-            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
-            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
+            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
+            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
         },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        texts=["hello", "world"],
+        user="abc-123",
     )
 
     assert isinstance(result, TextEmbeddingResult)
     assert len(result.embeddings) == 2
     assert result.usage.total_tokens == 6
 
+
 def test_get_num_tokens():
     model = HunyuanTextEmbeddingModel()
 
     num_tokens = model.get_num_tokens(
-        model='hunyuan-embedding',
+        model="hunyuan-embedding",
         credentials={
-            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
-            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
+            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
+            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
         },
-        texts=[
-            "hello",
-            "world"
-        ]
+        texts=["hello", "world"],
     )
 
     assert num_tokens == 2
 
+
 def test_max_chunks():
     model = HunyuanTextEmbeddingModel()
 
     result = model.invoke(
-        model='hunyuan-embedding',
+        model="hunyuan-embedding",
         credentials={
-            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
-            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
+            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
+            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
         },
         texts=[
             "hello",
@@ -97,8 +89,8 @@ def test_max_chunks():
             "world",
             "hello",
             "world",
-        ]
+        ],
     )
 
     assert isinstance(result, TextEmbeddingResult)
-    assert len(result.embeddings) == 22
+    assert len(result.embeddings) == 22

+ 2 - 10
api/tests/integration_tests/model_runtime/jina/test_provider.py

@@ -10,14 +10,6 @@ def test_validate_provider_credentials():
     provider = JinaProvider()
 
     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={
-                'api_key': 'hahahaha'
-            }
-        )
+        provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})
 
-    provider.validate_provider_credentials(
-        credentials={
-            'api_key': os.environ.get('JINA_API_KEY')
-        }
-    )
+    provider.validate_provider_credentials(credentials={"api_key": os.environ.get("JINA_API_KEY")})

+ 9 - 23
api/tests/integration_tests/model_runtime/jina/test_text_embedding.py

@@ -11,18 +11,10 @@ def test_validate_credentials():
     model = JinaTextEmbeddingModel()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='jina-embeddings-v2-base-en',
-            credentials={
-                'api_key': 'invalid_key'
-            }
-        )
+        model.validate_credentials(model="jina-embeddings-v2-base-en", credentials={"api_key": "invalid_key"})
 
     model.validate_credentials(
-        model='jina-embeddings-v2-base-en',
-        credentials={
-            'api_key': os.environ.get('JINA_API_KEY')
-        }
+        model="jina-embeddings-v2-base-en", credentials={"api_key": os.environ.get("JINA_API_KEY")}
     )
 
 
@@ -30,15 +22,12 @@ def test_invoke_model():
     model = JinaTextEmbeddingModel()
 
     result = model.invoke(
-        model='jina-embeddings-v2-base-en',
+        model="jina-embeddings-v2-base-en",
         credentials={
-            'api_key': os.environ.get('JINA_API_KEY'),
+            "api_key": os.environ.get("JINA_API_KEY"),
         },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        texts=["hello", "world"],
+        user="abc-123",
     )
 
     assert isinstance(result, TextEmbeddingResult)
@@ -50,14 +39,11 @@ def test_get_num_tokens():
     model = JinaTextEmbeddingModel()
 
     num_tokens = model.get_num_tokens(
-        model='jina-embeddings-v2-base-en',
+        model="jina-embeddings-v2-base-en",
         credentials={
-            'api_key': os.environ.get('JINA_API_KEY'),
+            "api_key": os.environ.get("JINA_API_KEY"),
         },
-        texts=[
-            "hello",
-            "world"
-        ]
+        texts=["hello", "world"],
     )
 
     assert num_tokens == 6

+ 3 - 3
api/tests/integration_tests/model_runtime/localai/test_embedding.py

@@ -1,4 +1,4 @@
 """
-    LocalAI Embedding Interface is temporarily unavailable due to 
-    we could not find a way to test it for now.
-"""
+LocalAI Embedding Interface is temporarily unavailable because
+we could not find a way to test it for now.
+"""

+ 55 - 99
api/tests/integration_tests/model_runtime/localai/test_llm.py

@@ -21,99 +21,78 @@ def test_validate_credentials_for_chat_model():
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='chinese-llama-2-7b',
+            model="chinese-llama-2-7b",
             credentials={
-                'server_url': 'hahahaha',
-                'completion_type': 'completion',
-            }
+                "server_url": "hahahaha",
+                "completion_type": "completion",
+            },
         )
 
     model.validate_credentials(
-        model='chinese-llama-2-7b',
+        model="chinese-llama-2-7b",
         credentials={
-            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
-            'completion_type': 'completion',
-        }
+            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
+            "completion_type": "completion",
+        },
     )
 
+
 def test_invoke_completion_model():
     model = LocalAILanguageModel()
 
     response = model.invoke(
-        model='chinese-llama-2-7b',
+        model="chinese-llama-2-7b",
         credentials={
-            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
-            'completion_type': 'completion',
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='ping'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'max_tokens': 10
+            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
+            "completion_type": "completion",
         },
+        prompt_messages=[UserPromptMessage(content="ping")],
+        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
         stop=[],
         user="abc-123",
-        stream=False
+        stream=False,
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
+
 def test_invoke_chat_model():
     model = LocalAILanguageModel()
 
     response = model.invoke(
-        model='chinese-llama-2-7b',
+        model="chinese-llama-2-7b",
         credentials={
-            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
-            'completion_type': 'chat_completion',
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='ping'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'max_tokens': 10
+            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
+            "completion_type": "chat_completion",
         },
+        prompt_messages=[UserPromptMessage(content="ping")],
+        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
         stop=[],
         user="abc-123",
-        stream=False
+        stream=False,
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
+
 def test_invoke_stream_completion_model():
     model = LocalAILanguageModel()
 
     response = model.invoke(
-        model='chinese-llama-2-7b',
+        model="chinese-llama-2-7b",
         credentials={
-            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
-            'completion_type': 'completion',
+            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
+            "completion_type": "completion",
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'max_tokens': 10
-        },
-        stop=['you'],
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -123,28 +102,21 @@ def test_invoke_stream_completion_model():
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
+
 def test_invoke_stream_chat_model():
     model = LocalAILanguageModel()
 
     response = model.invoke(
-        model='chinese-llama-2-7b',
+        model="chinese-llama-2-7b",
         credentials={
-            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
-            'completion_type': 'chat_completion',
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'max_tokens': 10
+            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
+            "completion_type": "chat_completion",
         },
-        stop=['you'],
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -154,64 +126,48 @@ def test_invoke_stream_chat_model():
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
+
 def test_get_num_tokens():
     model = LocalAILanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='????',
+        model="????",
         credentials={
-            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
-            'completion_type': 'chat_completion',
+            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
+            "completion_type": "chat_completion",
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
         tools=[
             PromptMessageTool(
-                name='get_current_weather',
-                description='Get the current weather in a given location',
+                name="get_current_weather",
+                description="Get the current weather in a given location",
                 parameters={
                     "type": "object",
                     "properties": {
-                        "location": {
-                        "type": "string",
-                            "description": "The city and state e.g. San Francisco, CA"
-                        },
-                        "unit": {
-                            "type": "string",
-                            "enum": [
-                                "c",
-                                "f"
-                            ]
-                        }
+                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                        "unit": {"type": "string", "enum": ["c", "f"]},
                     },
-                    "required": [
-                        "location"
-                    ]
-                }
+                    "required": ["location"],
+                },
             )
-        ]
+        ],
     )
 
     assert isinstance(num_tokens, int)
     assert num_tokens == 77
 
     num_tokens = model.get_num_tokens(
-        model='????',
+        model="????",
         credentials={
-            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
-            'completion_type': 'chat_completion',
+            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
+            "completion_type": "chat_completion",
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
     )
 
     assert isinstance(num_tokens, int)

+ 25 - 31
api/tests/integration_tests/model_runtime/localai/test_rerank.py

@@ -12,30 +12,29 @@ def test_validate_credentials_for_chat_model():
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='bge-reranker-v2-m3',
+            model="bge-reranker-v2-m3",
             credentials={
-                'server_url': 'hahahaha',
-                'completion_type': 'completion',
-            }
+                "server_url": "hahahaha",
+                "completion_type": "completion",
+            },
         )
 
     model.validate_credentials(
-        model='bge-reranker-base',
+        model="bge-reranker-base",
         credentials={
-            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
-            'completion_type': 'completion',
-        }
+            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
+            "completion_type": "completion",
+        },
     )
 
+
 def test_invoke_rerank_model():
     model = LocalaiRerankModel()
 
     response = model.invoke(
-        model='bge-reranker-base',
-        credentials={
-            'server_url': os.environ.get('LOCALAI_SERVER_URL')
-        },
-        query='Organic skincare products for sensitive skin',
+        model="bge-reranker-base",
+        credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")},
+        query="Organic skincare products for sensitive skin",
         docs=[
             "Eco-friendly kitchenware for modern homes",
             "Biodegradable cleaning supplies for eco-conscious consumers",
@@ -45,43 +44,38 @@ def test_invoke_rerank_model():
             "Sustainable gardening tools and compost solutions",
             "Sensitive skin-friendly facial cleansers and toners",
             "Organic food wraps and storage solutions",
-            "Yoga mats made from recycled materials"
+            "Yoga mats made from recycled materials",
         ],
         top_n=3,
         score_threshold=0.75,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, RerankResult)
     assert len(response.docs) == 3
 
+
 def test__invoke():
     model = LocalaiRerankModel()
 
     # Test case 1: Empty docs
     result = model._invoke(
-        model='bge-reranker-base',
-        credentials={
-            'server_url': 'https://example.com',
-            'api_key': '1234567890'
-        },
-        query='Organic skincare products for sensitive skin',
+        model="bge-reranker-base",
+        credentials={"server_url": "https://example.com", "api_key": "1234567890"},
+        query="Organic skincare products for sensitive skin",
         docs=[],
         top_n=3,
         score_threshold=0.75,
-        user="abc-123"
+        user="abc-123",
     )
     assert isinstance(result, RerankResult)
     assert len(result.docs) == 0
 
     # Test case 2: Valid invocation
     result = model._invoke(
-        model='bge-reranker-base',
-        credentials={
-            'server_url': 'https://example.com',
-            'api_key': '1234567890'
-        },
-        query='Organic skincare products for sensitive skin',
+        model="bge-reranker-base",
+        credentials={"server_url": "https://example.com", "api_key": "1234567890"},
+        query="Organic skincare products for sensitive skin",
         docs=[
             "Eco-friendly kitchenware for modern homes",
             "Biodegradable cleaning supplies for eco-conscious consumers",
@@ -91,12 +85,12 @@ def test__invoke():
             "Sustainable gardening tools and compost solutions",
             "Sensitive skin-friendly facial cleansers and toners",
             "Organic food wraps and storage solutions",
-            "Yoga mats made from recycled materials"
+            "Yoga mats made from recycled materials",
         ],
         top_n=3,
         score_threshold=0.75,
-        user="abc-123"
+        user="abc-123",
     )
     assert isinstance(result, RerankResult)
     assert len(result.docs) == 3
-    assert all(isinstance(doc, RerankDocument) for doc in result.docs)
+    assert all(isinstance(doc, RerankDocument) for doc in result.docs)

+ 9 - 21
api/tests/integration_tests/model_runtime/localai/test_speech2text.py

@@ -10,19 +10,9 @@ def test_validate_credentials():
     model = LocalAISpeech2text()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='whisper-1',
-            credentials={
-                'server_url': 'invalid_url'
-            }
-        )
+        model.validate_credentials(model="whisper-1", credentials={"server_url": "invalid_url"})
 
-    model.validate_credentials(
-        model='whisper-1',
-        credentials={
-            'server_url': os.environ.get('LOCALAI_SERVER_URL')
-        }
-    )
+    model.validate_credentials(model="whisper-1", credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")})
 
 
 def test_invoke_model():
@@ -32,23 +22,21 @@ def test_invoke_model():
     current_dir = os.path.dirname(os.path.abspath(__file__))
 
     # Get assets directory
-    assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
+    assets_dir = os.path.join(os.path.dirname(current_dir), "assets")
 
     # Construct the path to the audio file
-    audio_file_path = os.path.join(assets_dir, 'audio.mp3')
+    audio_file_path = os.path.join(assets_dir, "audio.mp3")
 
     # Open the file and get the file object
-    with open(audio_file_path, 'rb') as audio_file:
+    with open(audio_file_path, "rb") as audio_file:
         file = audio_file
 
         result = model.invoke(
-            model='whisper-1',
-            credentials={
-                'server_url': os.environ.get('LOCALAI_SERVER_URL')
-            },
+            model="whisper-1",
+            credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")},
             file=file,
-            user="abc-123"
+            user="abc-123",
         )
 
         assert isinstance(result, str)
-        assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
+        assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"

+ 17 - 24
api/tests/integration_tests/model_runtime/minimax/test_embedding.py

@@ -12,54 +12,47 @@ def test_validate_credentials():
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='embo-01',
-            credentials={
-                'minimax_api_key': 'invalid_key',
-                'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
-            }
+            model="embo-01",
+            credentials={"minimax_api_key": "invalid_key", "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID")},
         )
 
     model.validate_credentials(
-        model='embo-01',
+        model="embo-01",
         credentials={
-            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
-            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
-        }
+            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
+            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
+        },
     )
 
+
 def test_invoke_model():
     model = MinimaxTextEmbeddingModel()
 
     result = model.invoke(
-        model='embo-01',
+        model="embo-01",
         credentials={
-            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
-            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
+            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
+            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
         },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        texts=["hello", "world"],
+        user="abc-123",
     )
 
     assert isinstance(result, TextEmbeddingResult)
     assert len(result.embeddings) == 2
     assert result.usage.total_tokens == 16
 
+
 def test_get_num_tokens():
     model = MinimaxTextEmbeddingModel()
 
     num_tokens = model.get_num_tokens(
-        model='embo-01',
+        model="embo-01",
         credentials={
-            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
-            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
+            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
+            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
         },
-        texts=[
-            "hello",
-            "world"
-        ]
+        texts=["hello", "world"],
     )
 
     assert num_tokens == 2

+ 46 - 61
api/tests/integration_tests/model_runtime/minimax/test_llm.py

@@ -17,79 +17,70 @@ def test_predefined_models():
     assert len(model_schemas) >= 1
     assert isinstance(model_schemas[0], AIModelEntity)
 
+
 def test_validate_credentials_for_chat_model():
     sleep(3)
     model = MinimaxLargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='abab5.5-chat',
-            credentials={
-                'minimax_api_key': 'invalid_key',
-                'minimax_group_id': 'invalid_key'
-            }
+            model="abab5.5-chat", credentials={"minimax_api_key": "invalid_key", "minimax_group_id": "invalid_key"}
         )
 
     model.validate_credentials(
-        model='abab5.5-chat',
+        model="abab5.5-chat",
         credentials={
-            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
-            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
-        }
+            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
+            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
+        },
     )
 
+
 def test_invoke_model():
     sleep(3)
     model = MinimaxLargeLanguageModel()
 
     response = model.invoke(
-        model='abab5-chat',
+        model="abab5-chat",
         credentials={
-            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
-            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
+            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
+            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
         },
-        stop=['you'],
+        stop=["you"],
         user="abc-123",
-        stream=False
+        stream=False,
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
+
 def test_invoke_stream_model():
     sleep(3)
     model = MinimaxLargeLanguageModel()
 
     response = model.invoke(
-        model='abab5.5-chat',
+        model="abab5.5-chat",
         credentials={
-            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
-            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
+            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
+            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
         },
-        stop=['you'],
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -99,34 +90,31 @@ def test_invoke_stream_model():
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
+
 def test_invoke_with_search():
     sleep(3)
     model = MinimaxLargeLanguageModel()
 
     response = model.invoke(
-        model='abab5.5-chat',
+        model="abab5.5-chat",
         credentials={
-            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
-            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
+            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
+            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='北京今天的天气怎么样'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
-            'plugin_web_search': True,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
+            "plugin_web_search": True,
         },
-        stop=['you'],
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
-    total_message = ''
+    total_message = ""
     for chunk in response:
         assert isinstance(chunk, LLMResultChunk)
         assert isinstance(chunk.delta, LLMResultChunkDelta)
@@ -134,25 +122,22 @@ def test_invoke_with_search():
         total_message += chunk.delta.message.content
         assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True
 
-    assert '参考资料' in total_message
+    assert "参考资料" in total_message
+
 
 def test_get_num_tokens():
     sleep(3)
     model = MinimaxLargeLanguageModel()
 
     response = model.get_num_tokens(
-        model='abab5.5-chat',
+        model="abab5.5-chat",
         credentials={
-            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
-            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
+            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
+            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
-        tools=[]
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        tools=[],
     )
 
     assert isinstance(response, int)
-    assert response == 30
+    assert response == 30
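
Two layout rules account for most of the bare + lines in this file: top-level functions must be separated by exactly two blank lines, and every file must end with a single newline (the final assert response == 30 is rewritten with identical text, apparently only to add that missing end-of-file newline). A minimal sketch of the blank-line rule:

    def first_test():
        return 1


    def second_test():  # exactly two blank lines above, as the formatter enforces
        return 2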

+ 4 - 4
api/tests/integration_tests/model_runtime/minimax/test_provider.py

@@ -12,14 +12,14 @@ def test_validate_provider_credentials():
     with pytest.raises(CredentialsValidateFailedError):
         provider.validate_provider_credentials(
             credentials={
-                'minimax_api_key': 'hahahaha',
-                'minimax_group_id': '123',
+                "minimax_api_key": "hahahaha",
+                "minimax_group_id": "123",
             }
         )
 
     provider.validate_provider_credentials(
         credentials={
-            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
-            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID'),
+            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
+            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
         }
     )
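
Every provider test in this changeset follows the same contract: invalid credentials must raise CredentialsValidateFailedError, valid ones must return silently. A minimal sketch of the pattern, with the error class and validator as stand-ins for the real model-runtime types:

    import pytest


    class CredentialsValidateFailedError(Exception):
        """Stand-in for the error type these tests import from the model runtime."""


    def validate(credentials: dict) -> None:
        if not credentials.get("minimax_api_key", "").startswith("real-"):
            raise CredentialsValidateFailedError("invalid credentials")


    def test_validate_credentials():
        with pytest.raises(CredentialsValidateFailedError):
            validate(credentials={"minimax_api_key": "hahahaha"})
        validate(credentials={"minimax_api_key": "real-key"})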

+ 23 - 47
api/tests/integration_tests/model_runtime/novita/test_llm.py

@@ -19,19 +19,12 @@ def test_validate_credentials():
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='meta-llama/llama-3-8b-instruct',
-            credentials={
-                'api_key': 'invalid_key',
-                'mode': 'chat'
-            }
+            model="meta-llama/llama-3-8b-instruct", credentials={"api_key": "invalid_key", "mode": "chat"}
         )
 
     model.validate_credentials(
-        model='meta-llama/llama-3-8b-instruct',
-        credentials={
-            'api_key': os.environ.get('NOVITA_API_KEY'),
-            'mode': 'chat'
-        }
+        model="meta-llama/llama-3-8b-instruct",
+        credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "chat"},
     )
 
 
@@ -39,27 +32,22 @@ def test_invoke_model():
     model = NovitaLargeLanguageModel()
 
     response = model.invoke(
-        model='meta-llama/llama-3-8b-instruct',
-        credentials={
-            'api_key': os.environ.get('NOVITA_API_KEY'),
-            'mode': 'completion'
-        },
+        model="meta-llama/llama-3-8b-instruct",
+        credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "completion"},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Who are you?'
-            )
+            UserPromptMessage(content="Who are you?"),
         ],
         model_parameters={
-            'temperature': 1.0,
-            'top_p': 0.5,
-            'max_tokens': 10,
+            "temperature": 1.0,
+            "top_p": 0.5,
+            "max_tokens": 10,
         },
-        stop=['How'],
+        stop=["How"],
         stream=False,
-        user="novita"
+        user="novita",
     )
 
     assert isinstance(response, LLMResult)
@@ -70,27 +58,17 @@ def test_invoke_stream_model():
     model = NovitaLargeLanguageModel()
 
     response = model.invoke(
-        model='meta-llama/llama-3-8b-instruct',
-        credentials={
-            'api_key': os.environ.get('NOVITA_API_KEY'),
-            'mode': 'chat'
-        },
+        model="meta-llama/llama-3-8b-instruct",
+        credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "chat"},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Who are you?'
-            )
+            UserPromptMessage(content="Who are you?"),
         ],
-        model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
-            'max_tokens': 100
-        },
+        model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "max_tokens": 100},
         stream=True,
-        user="novita"
+        user="novita",
     )
 
     assert isinstance(response, Generator)
@@ -105,18 +83,16 @@ def test_get_num_tokens():
     model = NovitaLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='meta-llama/llama-3-8b-instruct',
+        model="meta-llama/llama-3-8b-instruct",
         credentials={
-            'api_key': os.environ.get('NOVITA_API_KEY'),
+            "api_key": os.environ.get("NOVITA_API_KEY"),
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+            UserPromptMessage(content="Hello World!"),
+        ],
     )
 
     assert isinstance(num_tokens, int)
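
The bulk of this diff is quote normalization: ruff format, like Black, rewrites string literals to double quotes, keeping single quotes only when switching would force escapes. A short sketch of the resulting convention:

    name = "novita"             # was 'novita'
    mode = "chat"               # was 'chat'
    quoted = 'a "quoted" word'  # single quotes kept: double quotes would need escaping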

+ 2 - 4
api/tests/integration_tests/model_runtime/novita/test_provider.py

@@ -10,12 +10,10 @@ def test_validate_provider_credentials():
     provider = NovitaProvider()
 
     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={}
-        )
+        provider.validate_provider_credentials(credentials={})
 
     provider.validate_provider_credentials(
         credentials={
-            'api_key': os.environ.get('NOVITA_API_KEY'),
+            "api_key": os.environ.get("NOVITA_API_KEY"),
         }
     )

File diff suppressed because it is too large
+ 56 - 88
api/tests/integration_tests/model_runtime/ollama/test_llm.py


+ 21 - 27
api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py

@@ -12,21 +12,21 @@ def test_validate_credentials():
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='mistral:text',
+            model="mistral:text",
             credentials={
-                'base_url': 'http://localhost:21434',
-                'mode': 'chat',
-                'context_size': 4096,
-            }
+                "base_url": "http://localhost:21434",
+                "mode": "chat",
+                "context_size": 4096,
+            },
         )
 
     model.validate_credentials(
-        model='mistral:text',
+        model="mistral:text",
         credentials={
-            'base_url': os.environ.get('OLLAMA_BASE_URL'),
-            'mode': 'chat',
-            'context_size': 4096,
-        }
+            "base_url": os.environ.get("OLLAMA_BASE_URL"),
+            "mode": "chat",
+            "context_size": 4096,
+        },
     )
 
 
@@ -34,17 +34,14 @@ def test_invoke_model():
     model = OllamaEmbeddingModel()
 
     result = model.invoke(
-        model='mistral:text',
+        model="mistral:text",
         credentials={
-            'base_url': os.environ.get('OLLAMA_BASE_URL'),
-            'mode': 'chat',
-            'context_size': 4096,
+            "base_url": os.environ.get("OLLAMA_BASE_URL"),
+            "mode": "chat",
+            "context_size": 4096,
         },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        texts=["hello", "world"],
+        user="abc-123",
     )
 
     assert isinstance(result, TextEmbeddingResult)
@@ -56,16 +53,13 @@ def test_get_num_tokens():
     model = OllamaEmbeddingModel()
 
     num_tokens = model.get_num_tokens(
-        model='mistral:text',
+        model="mistral:text",
         credentials={
-            'base_url': os.environ.get('OLLAMA_BASE_URL'),
-            'mode': 'chat',
-            'context_size': 4096,
+            "base_url": os.environ.get("OLLAMA_BASE_URL"),
+            "mode": "chat",
+            "context_size": 4096,
         },
-        texts=[
-            "hello",
-            "world"
-        ]
+        texts=["hello", "world"],
     )
 
     assert num_tokens == 2
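
The converse of the collapsing rule shows up here: a trailing comma after the last element (the "magic trailing comma") pins a collection in its expanded, one-item-per-line form even when it would otherwise fit. A small sketch of both forms:

    # Trailing comma present: the formatter keeps the dict expanded.
    credentials = {
        "base_url": "http://localhost:21434",
        "mode": "chat",
        "context_size": 4096,
    }

    # No trailing comma and short enough: collapsed onto one line.
    texts = ["hello", "world"]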

File diff suppressed because it is too large
+ 43 - 78
api/tests/integration_tests/model_runtime/openai/test_llm.py


+ 11 - 22
api/tests/integration_tests/model_runtime/openai/test_moderation.py

@@ -7,48 +7,37 @@ from core.model_runtime.model_providers.openai.moderation.moderation import Open
 from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
 
 
-@pytest.mark.parametrize('setup_openai_mock', [['moderation']], indirect=True)
+@pytest.mark.parametrize("setup_openai_mock", [["moderation"]], indirect=True)
 def test_validate_credentials(setup_openai_mock):
     model = OpenAIModerationModel()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='text-moderation-stable',
-            credentials={
-                'openai_api_key': 'invalid_key'
-            }
-        )
+        model.validate_credentials(model="text-moderation-stable", credentials={"openai_api_key": "invalid_key"})
 
     model.validate_credentials(
-        model='text-moderation-stable',
-        credentials={
-            'openai_api_key': os.environ.get('OPENAI_API_KEY')
-        }
+        model="text-moderation-stable", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}
     )
 
-@pytest.mark.parametrize('setup_openai_mock', [['moderation']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock", [["moderation"]], indirect=True)
 def test_invoke_model(setup_openai_mock):
     model = OpenAIModerationModel()
 
     result = model.invoke(
-        model='text-moderation-stable',
-        credentials={
-            'openai_api_key': os.environ.get('OPENAI_API_KEY')
-        },
+        model="text-moderation-stable",
+        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
         text="hello",
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(result, bool)
     assert result is False
 
     result = model.invoke(
-        model='text-moderation-stable',
-        credentials={
-            'openai_api_key': os.environ.get('OPENAI_API_KEY')
-        },
+        model="text-moderation-stable",
+        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
         text="i will kill you",
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(result, bool)
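
The indirect=True parametrization used throughout the OpenAI tests routes each parameter to the setup_openai_mock fixture instead of to the test function. A minimal self-contained sketch of that mechanism (the real fixture, which patches the OpenAI client, lives in __mock/openai.py):

    import pytest


    @pytest.fixture
    def setup_openai_mock(request):
        # With indirect=True, the parametrize value arrives as request.param.
        yield request.param  # e.g. ["moderation"]


    @pytest.mark.parametrize("setup_openai_mock", [["moderation"]], indirect=True)
    def test_receives_param(setup_openai_mock):
        assert setup_openai_mock == ["moderation"]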

+ 3 - 9
api/tests/integration_tests/model_runtime/openai/test_provider.py

@@ -7,17 +7,11 @@ from core.model_runtime.model_providers.openai.openai import OpenAIProvider
 from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
 
 
-@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
 def test_validate_provider_credentials(setup_openai_mock):
     provider = OpenAIProvider()
 
     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={}
-        )
+        provider.validate_provider_credentials(credentials={})
 
-    provider.validate_provider_credentials(
-        credentials={
-            'openai_api_key': os.environ.get('OPENAI_API_KEY')
-        }
-    )
+    provider.validate_provider_credentials(credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")})

+ 12 - 23
api/tests/integration_tests/model_runtime/openai/test_speech2text.py

@@ -7,26 +7,17 @@ from core.model_runtime.model_providers.openai.speech2text.speech2text import Op
 from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
 
 
-@pytest.mark.parametrize('setup_openai_mock', [['speech2text']], indirect=True)
+@pytest.mark.parametrize("setup_openai_mock", [["speech2text"]], indirect=True)
 def test_validate_credentials(setup_openai_mock):
     model = OpenAISpeech2TextModel()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='whisper-1',
-            credentials={
-                'openai_api_key': 'invalid_key'
-            }
-        )
+        model.validate_credentials(model="whisper-1", credentials={"openai_api_key": "invalid_key"})
+
+    model.validate_credentials(model="whisper-1", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")})
 
-    model.validate_credentials(
-        model='whisper-1',
-        credentials={
-            'openai_api_key': os.environ.get('OPENAI_API_KEY')
-        }
-    )
 
-@pytest.mark.parametrize('setup_openai_mock', [['speech2text']], indirect=True)
+@pytest.mark.parametrize("setup_openai_mock", [["speech2text"]], indirect=True)
 def test_invoke_model(setup_openai_mock):
     model = OpenAISpeech2TextModel()
 
@@ -34,23 +25,21 @@ def test_invoke_model(setup_openai_mock):
     current_dir = os.path.dirname(os.path.abspath(__file__))
 
     # Get assets directory
-    assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
+    assets_dir = os.path.join(os.path.dirname(current_dir), "assets")
 
     # Construct the path to the audio file
-    audio_file_path = os.path.join(assets_dir, 'audio.mp3')
+    audio_file_path = os.path.join(assets_dir, "audio.mp3")
 
     # Open the file and get the file object
-    with open(audio_file_path, 'rb') as audio_file:
+    with open(audio_file_path, "rb") as audio_file:
         file = audio_file
 
         result = model.invoke(
-            model='whisper-1',
-            credentials={
-                'openai_api_key': os.environ.get('OPENAI_API_KEY')
-            },
+            model="whisper-1",
+            credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
             file=file,
-            user="abc-123"
+            user="abc-123",
         )
 
         assert isinstance(result, str)
-        assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
+        assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"
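
One detail the speech-to-text tests depend on: the invoke call sits inside the with open(...) block, because the file handle closes the moment the block exits. A minimal sketch of that lifetime, reading this script's own file instead of audio.mp3:

    with open(__file__, "rb") as audio_file:
        data = audio_file.read()  # the handle is only usable in here
    assert audio_file.closed      # closed as soon as the block exits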

+ 12 - 33
api/tests/integration_tests/model_runtime/openai/test_text_embedding.py

@@ -8,42 +8,27 @@ from core.model_runtime.model_providers.openai.text_embedding.text_embedding imp
 from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
 
 
-@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
+@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
 def test_validate_credentials(setup_openai_mock):
     model = OpenAITextEmbeddingModel()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='text-embedding-ada-002',
-            credentials={
-                'openai_api_key': 'invalid_key'
-            }
-        )
+        model.validate_credentials(model="text-embedding-ada-002", credentials={"openai_api_key": "invalid_key"})
 
     model.validate_credentials(
-        model='text-embedding-ada-002',
-        credentials={
-            'openai_api_key': os.environ.get('OPENAI_API_KEY')
-        }
+        model="text-embedding-ada-002", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}
     )
 
-@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
 def test_invoke_model(setup_openai_mock):
     model = OpenAITextEmbeddingModel()
 
     result = model.invoke(
-        model='text-embedding-ada-002',
-        credentials={
-            'openai_api_key': os.environ.get('OPENAI_API_KEY'),
-            'openai_api_base': 'https://api.openai.com'
-        },
-        texts=[
-            "hello",
-            "world",
-            " ".join(["long_text"] * 100),
-            " ".join(["another_long_text"] * 100)
-        ],
-        user="abc-123"
+        model="text-embedding-ada-002",
+        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"},
+        texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
+        user="abc-123",
     )
 
     assert isinstance(result, TextEmbeddingResult)
@@ -55,15 +40,9 @@ def test_get_num_tokens():
     model = OpenAITextEmbeddingModel()
 
     num_tokens = model.get_num_tokens(
-        model='text-embedding-ada-002',
-        credentials={
-            'openai_api_key': os.environ.get('OPENAI_API_KEY'),
-            'openai_api_base': 'https://api.openai.com'
-        },
-        texts=[
-            "hello",
-            "world"
-        ]
+        model="text-embedding-ada-002",
+        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"},
+        texts=["hello", "world"],
     )
 
     assert num_tokens == 2
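
A hedged sketch of why num_tokens == 2 holds: the token counts of the inputs are summed, and under the cl100k_base encoding used by text-embedding-ada-002 both "hello" and "world" are single tokens (this assumes tiktoken is available; the production counting path may differ):

    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")
    total = sum(len(enc.encode(text)) for text in ["hello", "world"])
    assert total == 2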

+ 60 - 89
api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py

@@ -23,21 +23,17 @@ def test_validate_credentials():
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='mistralai/Mixtral-8x7B-Instruct-v0.1',
-            credentials={
-                'api_key': 'invalid_key',
-                'endpoint_url': 'https://api.together.xyz/v1/',
-                'mode': 'chat'
-            }
+            model="mistralai/Mixtral-8x7B-Instruct-v0.1",
+            credentials={"api_key": "invalid_key", "endpoint_url": "https://api.together.xyz/v1/", "mode": "chat"},
         )
 
     model.validate_credentials(
-        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
+        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
         credentials={
-            'api_key': os.environ.get('TOGETHER_API_KEY'),
-            'endpoint_url': 'https://api.together.xyz/v1/',
-            'mode': 'chat'
-        }
+            "api_key": os.environ.get("TOGETHER_API_KEY"),
+            "endpoint_url": "https://api.together.xyz/v1/",
+            "mode": "chat",
+        },
     )
 
 
@@ -45,28 +41,26 @@ def test_invoke_model():
     model = OAIAPICompatLargeLanguageModel()
 
     response = model.invoke(
-        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
+        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
         credentials={
-            'api_key': os.environ.get('TOGETHER_API_KEY'),
-            'endpoint_url': 'https://api.together.xyz/v1/',
-            'mode': 'completion'
+            "api_key": os.environ.get("TOGETHER_API_KEY"),
+            "endpoint_url": "https://api.together.xyz/v1/",
+            "mode": "completion",
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Who are you?'
-            )
+            UserPromptMessage(content="Who are you?"),
         ],
         model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
         },
-        stop=['How'],
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, LLMResult)
@@ -77,29 +71,27 @@ def test_invoke_stream_model():
     model = OAIAPICompatLargeLanguageModel()
 
     response = model.invoke(
-        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
+        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
         credentials={
-            'api_key': os.environ.get('TOGETHER_API_KEY'),
-            'endpoint_url': 'https://api.together.xyz/v1/',
-            'mode': 'chat',
-            'stream_mode_delimiter': '\\n\\n'
+            "api_key": os.environ.get("TOGETHER_API_KEY"),
+            "endpoint_url": "https://api.together.xyz/v1/",
+            "mode": "chat",
+            "stream_mode_delimiter": "\\n\\n",
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Who are you?'
-            )
+            UserPromptMessage(content="Who are you?"),
         ],
         model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
         },
-        stop=['How'],
+        stop=["How"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -114,28 +106,26 @@ def test_invoke_stream_model_without_delimiter():
     model = OAIAPICompatLargeLanguageModel()
 
     response = model.invoke(
-        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
+        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
         credentials={
-            'api_key': os.environ.get('TOGETHER_API_KEY'),
-            'endpoint_url': 'https://api.together.xyz/v1/',
-            'mode': 'chat'
+            "api_key": os.environ.get("TOGETHER_API_KEY"),
+            "endpoint_url": "https://api.together.xyz/v1/",
+            "mode": "chat",
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Who are you?'
-            )
+            UserPromptMessage(content="Who are you?"),
         ],
         model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
         },
-        stop=['How'],
+        stop=["How"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -151,51 +141,37 @@ def test_invoke_chat_model_with_tools():
     model = OAIAPICompatLargeLanguageModel()
 
     result = model.invoke(
-        model='gpt-3.5-turbo',
+        model="gpt-3.5-turbo",
         credentials={
-            'api_key': os.environ.get('OPENAI_API_KEY'),
-            'endpoint_url': 'https://api.openai.com/v1/',
-            'mode': 'chat'
+            "api_key": os.environ.get("OPENAI_API_KEY"),
+            "endpoint_url": "https://api.openai.com/v1/",
+            "mode": "chat",
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
             UserPromptMessage(
                 content="what's the weather today in London?",
-            )
+            ),
         ],
         tools=[
             PromptMessageTool(
-                name='get_weather',
-                description='Determine weather in my location',
+                name="get_weather",
+                description="Determine weather in my location",
                 parameters={
                     "type": "object",
                     "properties": {
-                        "location": {
-                            "type": "string",
-                            "description": "The city and state e.g. San Francisco, CA"
-                        },
-                        "unit": {
-                            "type": "string",
-                            "enum": [
-                                "celsius",
-                                "fahrenheit"
-                            ]
-                        }
+                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                     },
-                    "required": [
-                        "location"
-                    ]
-                }
+                    "required": ["location"],
+                },
             ),
         ],
-        model_parameters={
-            'temperature': 0.0,
-            'max_tokens': 1024
-        },
+        model_parameters={"temperature": 0.0, "max_tokens": 1024},
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(result, LLMResult)
@@ -207,19 +183,14 @@ def test_get_num_tokens():
     model = OAIAPICompatLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
-        credentials={
-            'api_key': os.environ.get('OPENAI_API_KEY'),
-            'endpoint_url': 'https://api.openai.com/v1/'
-        },
+        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
+        credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+            UserPromptMessage(content="Hello World!"),
+        ],
     )
 
     assert isinstance(num_tokens, int)
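
The tool definition above shows the same collapsing rule applied to nested literals: inner dicts that fit on one line are joined, while the outer dict, ending in a trailing comma, stays expanded. The JSON schema itself is valid as a plain dict:

    get_weather_parameters = {
        "type": "object",
        "properties": {
            "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
            "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
        },
        "required": ["location"],
    }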

+ 4 - 13
api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py

@@ -14,18 +14,12 @@ def test_validate_credentials():
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
             model="whisper-1",
-            credentials={
-                "api_key": "invalid_key",
-                "endpoint_url": "https://api.openai.com/v1/"
-            },
+            credentials={"api_key": "invalid_key", "endpoint_url": "https://api.openai.com/v1/"},
         )
 
     model.validate_credentials(
         model="whisper-1",
-        credentials={
-            "api_key": os.environ.get("OPENAI_API_KEY"),
-            "endpoint_url": "https://api.openai.com/v1/"
-        },
+        credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"},
     )
 
 
@@ -47,13 +41,10 @@ def test_invoke_model():
 
         result = model.invoke(
             model="whisper-1",
-            credentials={
-                "api_key": os.environ.get("OPENAI_API_KEY"),
-                "endpoint_url": "https://api.openai.com/v1/"
-            },
+            credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"},
             file=file,
             user="abc-123",
         )
 
         assert isinstance(result, str)
-        assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
+        assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"

+ 20 - 32
api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py

@@ -12,27 +12,23 @@ from core.model_runtime.model_providers.openai_api_compatible.text_embedding.tex
 Using OpenAI's API as testing endpoint
 """
 
+
 def test_validate_credentials():
     model = OAICompatEmbeddingModel()
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='text-embedding-ada-002',
-            credentials={
-                'api_key': 'invalid_key',
-                'endpoint_url': 'https://api.openai.com/v1/',
-                'context_size': 8184
-                
-            }
+            model="text-embedding-ada-002",
+            credentials={"api_key": "invalid_key", "endpoint_url": "https://api.openai.com/v1/", "context_size": 8184},
         )
 
     model.validate_credentials(
-        model='text-embedding-ada-002',
+        model="text-embedding-ada-002",
         credentials={
-            'api_key': os.environ.get('OPENAI_API_KEY'),
-            'endpoint_url': 'https://api.openai.com/v1/',
-            'context_size': 8184
-        }
+            "api_key": os.environ.get("OPENAI_API_KEY"),
+            "endpoint_url": "https://api.openai.com/v1/",
+            "context_size": 8184,
+        },
     )
 
 
@@ -40,19 +36,14 @@ def test_invoke_model():
     model = OAICompatEmbeddingModel()
 
     result = model.invoke(
-        model='text-embedding-ada-002',
+        model="text-embedding-ada-002",
         credentials={
-            'api_key': os.environ.get('OPENAI_API_KEY'),
-            'endpoint_url': 'https://api.openai.com/v1/',
-            'context_size': 8184
+            "api_key": os.environ.get("OPENAI_API_KEY"),
+            "endpoint_url": "https://api.openai.com/v1/",
+            "context_size": 8184,
         },
-        texts=[
-            "hello",
-            "world",
-            " ".join(["long_text"] * 100),
-            " ".join(["another_long_text"] * 100)
-        ],
-        user="abc-123"
+        texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
+        user="abc-123",
     )
 
     assert isinstance(result, TextEmbeddingResult)
@@ -64,16 +55,13 @@ def test_get_num_tokens():
     model = OAICompatEmbeddingModel()
 
     num_tokens = model.get_num_tokens(
-        model='text-embedding-ada-002',
+        model="text-embedding-ada-002",
         credentials={
-            'api_key': os.environ.get('OPENAI_API_KEY'),
-            'endpoint_url': 'https://api.openai.com/v1/embeddings',
-            'context_size': 8184
+            "api_key": os.environ.get("OPENAI_API_KEY"),
+            "endpoint_url": "https://api.openai.com/v1/embeddings",
+            "context_size": 8184,
         },
-        texts=[
-            "hello",
-            "world"
-        ]
+        texts=["hello", "world"],
     )
 
-    assert num_tokens == 2
+    assert num_tokens == 2

+ 14 - 19
api/tests/integration_tests/model_runtime/openllm/test_embedding.py

@@ -12,17 +12,17 @@ def test_validate_credentials():
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='NOT IMPORTANT',
+            model="NOT IMPORTANT",
             credentials={
-                'server_url': 'ww' + os.environ.get('OPENLLM_SERVER_URL'),
-            }
+                "server_url": "ww" + os.environ.get("OPENLLM_SERVER_URL"),
+            },
         )
 
     model.validate_credentials(
-        model='NOT IMPORTANT',
+        model="NOT IMPORTANT",
         credentials={
-            'server_url': os.environ.get('OPENLLM_SERVER_URL'),
-        }
+            "server_url": os.environ.get("OPENLLM_SERVER_URL"),
+        },
     )
 
 
@@ -30,33 +30,28 @@ def test_invoke_model():
     model = OpenLLMTextEmbeddingModel()
 
     result = model.invoke(
-        model='NOT IMPORTANT',
+        model="NOT IMPORTANT",
         credentials={
-            'server_url': os.environ.get('OPENLLM_SERVER_URL'),
+            "server_url": os.environ.get("OPENLLM_SERVER_URL"),
         },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        texts=["hello", "world"],
+        user="abc-123",
     )
 
     assert isinstance(result, TextEmbeddingResult)
     assert len(result.embeddings) == 2
     assert result.usage.total_tokens > 0
 
+
 def test_get_num_tokens():
     model = OpenLLMTextEmbeddingModel()
 
     num_tokens = model.get_num_tokens(
-        model='NOT IMPORTANT',
+        model="NOT IMPORTANT",
         credentials={
-            'server_url': os.environ.get('OPENLLM_SERVER_URL'),
+            "server_url": os.environ.get("OPENLLM_SERVER_URL"),
         },
-        texts=[
-            "hello",
-            "world"
-        ]
+        texts=["hello", "world"],
     )
 
     assert num_tokens == 2

+ 30 - 39
api/tests/integration_tests/model_runtime/openllm/test_llm.py

@@ -14,67 +14,61 @@ def test_validate_credentials_for_chat_model():
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='NOT IMPORTANT',
+            model="NOT IMPORTANT",
             credentials={
-                'server_url': 'invalid_key',
-            }
+                "server_url": "invalid_key",
+            },
         )
 
     model.validate_credentials(
-        model='NOT IMPORTANT',
+        model="NOT IMPORTANT",
         credentials={
-            'server_url': os.environ.get('OPENLLM_SERVER_URL'),
-        }
+            "server_url": os.environ.get("OPENLLM_SERVER_URL"),
+        },
     )
 
+
 def test_invoke_model():
     model = OpenLLMLargeLanguageModel()
 
     response = model.invoke(
-        model='NOT IMPORTANT',
+        model="NOT IMPORTANT",
         credentials={
-            'server_url': os.environ.get('OPENLLM_SERVER_URL'),
+            "server_url": os.environ.get("OPENLLM_SERVER_URL"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
         },
-        stop=['you'],
+        stop=["you"],
         user="abc-123",
-        stream=False
+        stream=False,
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
+
 def test_invoke_stream_model():
     model = OpenLLMLargeLanguageModel()
 
     response = model.invoke(
-        model='NOT IMPORTANT',
+        model="NOT IMPORTANT",
         credentials={
-            'server_url': os.environ.get('OPENLLM_SERVER_URL'),
+            "server_url": os.environ.get("OPENLLM_SERVER_URL"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
         },
-        stop=['you'],
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -84,21 +78,18 @@ def test_invoke_stream_model():
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
+
 def test_get_num_tokens():
     model = OpenLLMLargeLanguageModel()
 
     response = model.get_num_tokens(
-        model='NOT IMPORTANT',
+        model="NOT IMPORTANT",
         credentials={
-            'server_url': os.environ.get('OPENLLM_SERVER_URL'),
+            "server_url": os.environ.get("OPENLLM_SERVER_URL"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
-        tools=[]
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        tools=[],
     )
 
     assert isinstance(response, int)
-    assert response == 3
+    assert response == 3
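
The streaming tests across these providers all assert the same contract: with stream=True, invoke returns a generator that is consumed chunk by chunk. A minimal stand-in showing the shape being asserted:

    from collections.abc import Generator


    def fake_stream() -> Generator[str, None, None]:
        # Stand-in for model.invoke(..., stream=True).
        yield from ("Hello", " ", "World")


    chunks = fake_stream()
    assert isinstance(chunks, Generator)
    assert "".join(chunks) == "Hello World"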

+ 26 - 45
api/tests/integration_tests/model_runtime/openrouter/test_llm.py

@@ -19,19 +19,12 @@ def test_validate_credentials():
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='mistralai/mixtral-8x7b-instruct',
-            credentials={
-                'api_key': 'invalid_key',
-                'mode': 'chat'
-            }
+            model="mistralai/mixtral-8x7b-instruct", credentials={"api_key": "invalid_key", "mode": "chat"}
         )
 
     model.validate_credentials(
-        model='mistralai/mixtral-8x7b-instruct',
-        credentials={
-            'api_key': os.environ.get('TOGETHER_API_KEY'),
-            'mode': 'chat'
-        }
+        model="mistralai/mixtral-8x7b-instruct",
+        credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"},
     )
 
 
@@ -39,27 +32,22 @@ def test_invoke_model():
     model = OpenRouterLargeLanguageModel()
 
     response = model.invoke(
-        model='mistralai/mixtral-8x7b-instruct',
-        credentials={
-            'api_key': os.environ.get('TOGETHER_API_KEY'),
-            'mode': 'completion'
-        },
+        model="mistralai/mixtral-8x7b-instruct",
+        credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "completion"},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Who are you?'
-            )
+            UserPromptMessage(content="Who are you?"),
         ],
         model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
         },
-        stop=['How'],
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, LLMResult)
@@ -70,27 +58,22 @@ def test_invoke_stream_model():
     model = OpenRouterLargeLanguageModel()
 
     response = model.invoke(
-        model='mistralai/mixtral-8x7b-instruct',
-        credentials={
-            'api_key': os.environ.get('TOGETHER_API_KEY'),
-            'mode': 'chat'
-        },
+        model="mistralai/mixtral-8x7b-instruct",
+        credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Who are you?'
-            )
+            UserPromptMessage(content="Who are you?"),
         ],
         model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
         },
-        stop=['How'],
+        stop=["How"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -105,18 +88,16 @@ def test_get_num_tokens():
     model = OpenRouterLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='mistralai/mixtral-8x7b-instruct',
+        model="mistralai/mixtral-8x7b-instruct",
         credentials={
-            'api_key': os.environ.get('TOGETHER_API_KEY'),
+            "api_key": os.environ.get("TOGETHER_API_KEY"),
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+            UserPromptMessage(content="Hello World!"),
+        ],
     )
 
     assert isinstance(num_tokens, int)

+ 34 - 40
api/tests/integration_tests/model_runtime/replicate/test_llm.py

@@ -14,19 +14,19 @@ def test_validate_credentials():
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='meta/llama-2-13b-chat',
+            model="meta/llama-2-13b-chat",
             credentials={
-                'replicate_api_token': 'invalid_key',
-                'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d'
-            }
+                "replicate_api_token": "invalid_key",
+                "model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d",
+            },
         )
 
     model.validate_credentials(
-        model='meta/llama-2-13b-chat',
+        model="meta/llama-2-13b-chat",
         credentials={
-            'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
-            'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d'
-        }
+            "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
+            "model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d",
+        },
     )
 
 
@@ -34,27 +34,25 @@ def test_invoke_model():
     model = ReplicateLargeLanguageModel()
 
     response = model.invoke(
-        model='meta/llama-2-13b-chat',
+        model="meta/llama-2-13b-chat",
         credentials={
-            'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
-            'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d'
+            "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
+            "model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d",
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Who are you?'
-            )
+            UserPromptMessage(content="Who are you?"),
         ],
         model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
         },
-        stop=['How'],
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, LLMResult)
@@ -65,27 +63,25 @@ def test_invoke_stream_model():
     model = ReplicateLargeLanguageModel()
 
     response = model.invoke(
-        model='mistralai/mixtral-8x7b-instruct-v0.1',
+        model="mistralai/mixtral-8x7b-instruct-v0.1",
         credentials={
-            'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
-            'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e'
+            "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
+            "model_version": "2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e",
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Who are you?'
-            )
+            UserPromptMessage(content="Who are you?"),
         ],
         model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
         },
-        stop=['How'],
+        stop=["How"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -100,19 +96,17 @@ def test_get_num_tokens():
     model = ReplicateLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='',
+        model="",
         credentials={
-            'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
-            'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e'
+            "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
+            "model_version": "2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e",
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+            UserPromptMessage(content="Hello World!"),
+        ],
     )
 
     assert num_tokens == 14

+ 40 - 55
api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py

@@ -12,19 +12,19 @@ def test_validate_credentials_one():
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='replicate/all-mpnet-base-v2',
+            model="replicate/all-mpnet-base-v2",
             credentials={
-                'replicate_api_token': 'invalid_key',
-                'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305'
-            }
+                "replicate_api_token": "invalid_key",
+                "model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305",
+            },
         )
 
     model.validate_credentials(
-        model='replicate/all-mpnet-base-v2',
+        model="replicate/all-mpnet-base-v2",
         credentials={
-            'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
-            'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305'
-        }
+            "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
+            "model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305",
+        },
     )
 
 
@@ -33,19 +33,19 @@ def test_validate_credentials_two():
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='nateraw/bge-large-en-v1.5',
+            model="nateraw/bge-large-en-v1.5",
             credentials={
-                'replicate_api_token': 'invalid_key',
-                'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1'
-            }
+                "replicate_api_token": "invalid_key",
+                "model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1",
+            },
         )
 
     model.validate_credentials(
-        model='nateraw/bge-large-en-v1.5',
+        model="nateraw/bge-large-en-v1.5",
         credentials={
-            'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
-            'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1'
-        }
+            "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
+            "model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1",
+        },
     )
 
 
@@ -53,16 +53,13 @@ def test_invoke_model_one():
     model = ReplicateEmbeddingModel()
 
     result = model.invoke(
-        model='nateraw/bge-large-en-v1.5',
+        model="nateraw/bge-large-en-v1.5",
         credentials={
-            'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
-            'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1'
+            "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
+            "model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1",
         },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        texts=["hello", "world"],
+        user="abc-123",
     )
 
     assert isinstance(result, TextEmbeddingResult)
@@ -74,16 +71,13 @@ def test_invoke_model_two():
     model = ReplicateEmbeddingModel()
 
     result = model.invoke(
-        model='andreasjansson/clip-features',
+        model="andreasjansson/clip-features",
         credentials={
-            'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
-            'model_version': '75b33f253f7714a281ad3e9b28f63e3232d583716ef6718f2e46641077ea040a'
+            "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
+            "model_version": "75b33f253f7714a281ad3e9b28f63e3232d583716ef6718f2e46641077ea040a",
         },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        texts=["hello", "world"],
+        user="abc-123",
     )
 
     assert isinstance(result, TextEmbeddingResult)
@@ -95,16 +89,13 @@ def test_invoke_model_three():
     model = ReplicateEmbeddingModel()
 
     result = model.invoke(
-        model='replicate/all-mpnet-base-v2',
+        model="replicate/all-mpnet-base-v2",
         credentials={
-            'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
-            'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305'
+            "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
+            "model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305",
         },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        texts=["hello", "world"],
+        user="abc-123",
     )
 
     assert isinstance(result, TextEmbeddingResult)
@@ -116,16 +107,13 @@ def test_invoke_model_four():
     model = ReplicateEmbeddingModel()
 
     result = model.invoke(
-        model='nateraw/jina-embeddings-v2-base-en',
+        model="nateraw/jina-embeddings-v2-base-en",
         credentials={
-            'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
-            'model_version': 'f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e'
+            "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
+            "model_version": "f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e",
         },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        texts=["hello", "world"],
+        user="abc-123",
     )
 
     assert isinstance(result, TextEmbeddingResult)
@@ -137,15 +125,12 @@ def test_get_num_tokens():
     model = ReplicateEmbeddingModel()
 
     num_tokens = model.get_num_tokens(
-        model='nateraw/jina-embeddings-v2-base-en',
+        model="nateraw/jina-embeddings-v2-base-en",
         credentials={
-            'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
-            'model_version': 'f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e'
+            "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
+            "model_version": "f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e",
         },
-        texts=[
-            "hello",
-            "world"
-        ]
+        texts=["hello", "world"],
     )
 
     assert num_tokens == 2

+ 2 - 6
api/tests/integration_tests/model_runtime/sagemaker/test_provider.py

@@ -10,10 +10,6 @@ def test_validate_provider_credentials():
     provider = SageMakerProvider()
 
     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={}
-        )
+        provider.validate_provider_credentials(credentials={})
 
-    provider.validate_provider_credentials(
-        credentials={}
-    )
+    provider.validate_provider_credentials(credentials={})

+ 6 - 6
api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py

@@ -12,11 +12,11 @@ def test_validate_credentials():
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='bge-m3-rerank-v2',
+            model="bge-m3-rerank-v2",
             credentials={
                 "aws_region": os.getenv("AWS_REGION"),
                 "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
-                "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+                "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
             },
             query="What is the capital of the United States?",
             docs=[
@@ -25,7 +25,7 @@ def test_validate_credentials():
                 "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
                 "are a political division controlled by the United States. Its capital is Saipan.",
             ],
-            score_threshold=0.8
+            score_threshold=0.8,
         )
 
 
@@ -33,11 +33,11 @@ def test_invoke_model():
     model = SageMakerRerankModel()
 
     result = model.invoke(
-        model='bge-m3-rerank-v2',
+        model="bge-m3-rerank-v2",
         credentials={
             "aws_region": os.getenv("AWS_REGION"),
             "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
-            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
         },
         query="What is the capital of the United States?",
         docs=[
@@ -46,7 +46,7 @@ def test_invoke_model():
             "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
             "are a political division controlled by the United States. Its capital is Saipan.",
         ],
-        score_threshold=0.8
+        score_threshold=0.8,
     )
 
     assert isinstance(result, RerankResult)

+ 5 - 27
api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py

@@ -11,45 +11,23 @@ def test_validate_credentials():
     model = SageMakerEmbeddingModel()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='bge-m3',
-            credentials={
-            }
-        )
+        model.validate_credentials(model="bge-m3", credentials={})
 
-    model.validate_credentials(
-        model='bge-m3-embedding',
-        credentials={
-        }
-    )
+    model.validate_credentials(model="bge-m3-embedding", credentials={})
 
 
 def test_invoke_model():
     model = SageMakerEmbeddingModel()
 
-    result = model.invoke(
-        model='bge-m3-embedding',
-        credentials={
-        },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
-    )
+    result = model.invoke(model="bge-m3-embedding", credentials={}, texts=["hello", "world"], user="abc-123")
 
     assert isinstance(result, TextEmbeddingResult)
     assert len(result.embeddings) == 2
 
+
 def test_get_num_tokens():
     model = SageMakerEmbeddingModel()
 
-    num_tokens = model.get_num_tokens(
-        model='bge-m3-embedding',
-        credentials={
-        },
-        texts=[
-        ]
-    )
+    num_tokens = model.get_num_tokens(model="bge-m3-embedding", credentials={}, texts=[])
 
     assert num_tokens == 0

+ 19 - 52
api/tests/integration_tests/model_runtime/siliconflow/test_llm.py

@@ -13,41 +13,22 @@ def test_validate_credentials():
     model = SiliconflowLargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='deepseek-ai/DeepSeek-V2-Chat',
-            credentials={
-                'api_key': 'invalid_key'
-            }
-        )
-
-    model.validate_credentials(
-        model='deepseek-ai/DeepSeek-V2-Chat',
-        credentials={
-            'api_key': os.environ.get('API_KEY')
-        }
-    )
+        model.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials={"api_key": "invalid_key"})
+
+    model.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials={"api_key": os.environ.get("API_KEY")})
 
 
 def test_invoke_model():
     model = SiliconflowLargeLanguageModel()
 
     response = model.invoke(
-        model='deepseek-ai/DeepSeek-V2-Chat',
-        credentials={
-            'api_key': os.environ.get('API_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Who are you?'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.5,
-            'max_tokens': 10
-        },
-        stop=['How'],
+        model="deepseek-ai/DeepSeek-V2-Chat",
+        credentials={"api_key": os.environ.get("API_KEY")},
+        prompt_messages=[UserPromptMessage(content="Who are you?")],
+        model_parameters={"temperature": 0.5, "max_tokens": 10},
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, LLMResult)
@@ -58,22 +39,12 @@ def test_invoke_stream_model():
     model = SiliconflowLargeLanguageModel()
 
     response = model.invoke(
-        model='deepseek-ai/DeepSeek-V2-Chat',
-        credentials={
-            'api_key': os.environ.get('API_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.5,
-            'max_tokens': 100,
-            'seed': 1234
-        },
+        model="deepseek-ai/DeepSeek-V2-Chat",
+        credentials={"api_key": os.environ.get("API_KEY")},
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -89,18 +60,14 @@ def test_get_num_tokens():
     model = SiliconflowLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='deepseek-ai/DeepSeek-V2-Chat',
-        credentials={
-            'api_key': os.environ.get('API_KEY')
-        },
+        model="deepseek-ai/DeepSeek-V2-Chat",
+        credentials={"api_key": os.environ.get("API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+            UserPromptMessage(content="Hello World!"),
+        ],
     )
 
     assert num_tokens == 12

+ 2 - 8
api/tests/integration_tests/model_runtime/siliconflow/test_provider.py

@@ -10,12 +10,6 @@ def test_validate_provider_credentials():
     provider = SiliconflowProvider()
 
     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={}
-        )
+        provider.validate_provider_credentials(credentials={})
 
-    provider.validate_provider_credentials(
-        credentials={
-            'api_key': os.environ.get('API_KEY')
-        }
-    )
+    provider.validate_provider_credentials(credentials={"api_key": os.environ.get("API_KEY")})

+ 5 - 7
api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py

@@ -13,9 +13,7 @@ def test_validate_credentials():
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
             model="BAAI/bge-reranker-v2-m3",
-            credentials={
-                "api_key": "invalid_key"
-            },
+            credentials={"api_key": "invalid_key"},
         )
 
     model.validate_credentials(
@@ -30,17 +28,17 @@ def test_invoke_model():
     model = SiliconflowRerankModel()
 
     result = model.invoke(
-        model='BAAI/bge-reranker-v2-m3',
+        model="BAAI/bge-reranker-v2-m3",
         credentials={
             "api_key": os.environ.get("API_KEY"),
         },
         query="Who is Kasumi?",
         docs=[
-            "Kasumi is a girl's name of Japanese origin meaning \"mist\".",
+            'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
             "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
-            "and she leads a team named PopiParty."
+            "and she leads a team named PopiParty.",
         ],
-        score_threshold=0.8
+        score_threshold=0.8,
     )
 
     assert isinstance(result, RerankResult)

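The rerank hunk above also shows the quote rule: string literals are normalized to double quotes, except where the body itself contains double quotes, in which case single quotes are chosen so fewer characters need escaping (here, one escaped apostrophe instead of two escaped quotation marks). Sketch:

    # Quote normalization sketch: double quotes by default, single quotes
    # when that choice needs fewer escapes.
    default = "Who is Kasumi?"  # was 'Who is Kasumi?' before formatting
    fewer_escapes = 'Kasumi is a girl\'s name of Japanese origin meaning "mist".'
    assert '"mist"' in fewer_escapes
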
+ 4 - 12
api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py

@@ -12,16 +12,12 @@ def test_validate_credentials():
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
             model="iic/SenseVoiceSmall",
-            credentials={
-                "api_key": "invalid_key"
-            },
+            credentials={"api_key": "invalid_key"},
         )
 
     model.validate_credentials(
         model="iic/SenseVoiceSmall",
-        credentials={
-            "api_key": os.environ.get("API_KEY")
-        },
+        credentials={"api_key": os.environ.get("API_KEY")},
     )
 
 
@@ -42,12 +38,8 @@ def test_invoke_model():
         file = audio_file
 
         result = model.invoke(
-            model="iic/SenseVoiceSmall",
-            credentials={
-                "api_key": os.environ.get("API_KEY")
-            },
-            file=file
+            model="iic/SenseVoiceSmall", credentials={"api_key": os.environ.get("API_KEY")}, file=file
         )
 
         assert isinstance(result, str)
-        assert result == '1,2,3,4,5,6,7,8,9,10.'
+        assert result == "1,2,3,4,5,6,7,8,9,10."

+ 1 - 3
api/tests/integration_tests/model_runtime/siliconflow/test_text_embedding.py

@@ -15,9 +15,7 @@ def test_validate_credentials():
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
             model="BAAI/bge-large-zh-v1.5",
-            credentials={
-                "api_key": "invalid_key"
-            },
+            credentials={"api_key": "invalid_key"},
         )
 
     model.validate_credentials(

+ 28 - 49
api/tests/integration_tests/model_runtime/spark/test_llm.py

@@ -13,20 +13,15 @@ def test_validate_credentials():
     model = SparkLargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='spark-1.5',
-            credentials={
-                'app_id': 'invalid_key'
-            }
-        )
+        model.validate_credentials(model="spark-1.5", credentials={"app_id": "invalid_key"})
 
     model.validate_credentials(
-        model='spark-1.5',
+        model="spark-1.5",
         credentials={
-            'app_id': os.environ.get('SPARK_APP_ID'),
-            'api_secret': os.environ.get('SPARK_API_SECRET'),
-            'api_key': os.environ.get('SPARK_API_KEY')
-        }
+            "app_id": os.environ.get("SPARK_APP_ID"),
+            "api_secret": os.environ.get("SPARK_API_SECRET"),
+            "api_key": os.environ.get("SPARK_API_KEY"),
+        },
     )
 
 
@@ -34,24 +29,17 @@ def test_invoke_model():
     model = SparkLargeLanguageModel()
 
     response = model.invoke(
-        model='spark-1.5',
+        model="spark-1.5",
         credentials={
-            'app_id': os.environ.get('SPARK_APP_ID'),
-            'api_secret': os.environ.get('SPARK_API_SECRET'),
-            'api_key': os.environ.get('SPARK_API_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Who are you?'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.5,
-            'max_tokens': 10
+            "app_id": os.environ.get("SPARK_APP_ID"),
+            "api_secret": os.environ.get("SPARK_API_SECRET"),
+            "api_key": os.environ.get("SPARK_API_KEY"),
         },
-        stop=['How'],
+        prompt_messages=[UserPromptMessage(content="Who are you?")],
+        model_parameters={"temperature": 0.5, "max_tokens": 10},
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, LLMResult)
@@ -62,23 +50,16 @@ def test_invoke_stream_model():
     model = SparkLargeLanguageModel()
 
     response = model.invoke(
-        model='spark-1.5',
+        model="spark-1.5",
         credentials={
-            'app_id': os.environ.get('SPARK_APP_ID'),
-            'api_secret': os.environ.get('SPARK_API_SECRET'),
-            'api_key': os.environ.get('SPARK_API_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.5,
-            'max_tokens': 100
+            "app_id": os.environ.get("SPARK_APP_ID"),
+            "api_secret": os.environ.get("SPARK_API_SECRET"),
+            "api_key": os.environ.get("SPARK_API_KEY"),
         },
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        model_parameters={"temperature": 0.5, "max_tokens": 100},
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -94,20 +75,18 @@ def test_get_num_tokens():
     model = SparkLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='spark-1.5',
+        model="spark-1.5",
         credentials={
-            'app_id': os.environ.get('SPARK_APP_ID'),
-            'api_secret': os.environ.get('SPARK_API_SECRET'),
-            'api_key': os.environ.get('SPARK_API_KEY')
+            "app_id": os.environ.get("SPARK_APP_ID"),
+            "api_secret": os.environ.get("SPARK_API_SECRET"),
+            "api_key": os.environ.get("SPARK_API_KEY"),
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+            UserPromptMessage(content="Hello World!"),
+        ],
     )
 
     assert num_tokens == 14

+ 4 - 6
api/tests/integration_tests/model_runtime/spark/test_provider.py

@@ -10,14 +10,12 @@ def test_validate_provider_credentials():
     provider = SparkProvider()
 
     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={}
-        )
+        provider.validate_provider_credentials(credentials={})
 
     provider.validate_provider_credentials(
         credentials={
-            'app_id': os.environ.get('SPARK_APP_ID'),
-            'api_secret': os.environ.get('SPARK_API_SECRET'),
-            'api_key': os.environ.get('SPARK_API_KEY')
+            "app_id": os.environ.get("SPARK_APP_ID"),
+            "api_secret": os.environ.get("SPARK_API_SECRET"),
+            "api_key": os.environ.get("SPARK_API_KEY"),
         }
     )

+ 36 - 87
api/tests/integration_tests/model_runtime/stepfun/test_llm.py

@@ -21,40 +21,22 @@ def test_validate_credentials():
     model = StepfunLargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='step-1-8k',
-            credentials={
-                'api_key': 'invalid_key'
-            }
-        )
-
-    model.validate_credentials(
-        model='step-1-8k',
-        credentials={
-            'api_key': os.environ.get('STEPFUN_API_KEY')
-        }
-    )
+        model.validate_credentials(model="step-1-8k", credentials={"api_key": "invalid_key"})
+
+    model.validate_credentials(model="step-1-8k", credentials={"api_key": os.environ.get("STEPFUN_API_KEY")})
+
 
 def test_invoke_model():
     model = StepfunLargeLanguageModel()
 
     response = model.invoke(
-        model='step-1-8k',
-        credentials={
-            'api_key': os.environ.get('STEPFUN_API_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.9,
-            'top_p': 0.7
-        },
-        stop=['Hi'],
+        model="step-1-8k",
+        credentials={"api_key": os.environ.get("STEPFUN_API_KEY")},
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        model_parameters={"temperature": 0.9, "top_p": 0.7},
+        stop=["Hi"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, LLMResult)
@@ -65,24 +47,17 @@ def test_invoke_stream_model():
     model = StepfunLargeLanguageModel()
 
     response = model.invoke(
-        model='step-1-8k',
-        credentials={
-            'api_key': os.environ.get('STEPFUN_API_KEY')
-        },
+        model="step-1-8k",
+        credentials={"api_key": os.environ.get("STEPFUN_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
-        model_parameters={
-            'temperature': 0.9,
-            'top_p': 0.7
-        },
+        model_parameters={"temperature": 0.9, "top_p": 0.7},
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -98,10 +73,7 @@ def test_get_customizable_model_schema():
     model = StepfunLargeLanguageModel()
 
     schema = model.get_customizable_model_schema(
-        model='step-1-8k',
-        credentials={
-            'api_key': os.environ.get('STEPFUN_API_KEY')
-        }
+        model="step-1-8k", credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}
     )
     assert isinstance(schema, AIModelEntity)
 
@@ -110,67 +82,44 @@ def test_invoke_chat_model_with_tools():
     model = StepfunLargeLanguageModel()
 
     result = model.invoke(
-        model='step-1-8k',
-        credentials={
-            'api_key': os.environ.get('STEPFUN_API_KEY')
-        },
+        model="step-1-8k",
+        credentials={"api_key": os.environ.get("STEPFUN_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
             UserPromptMessage(
                 content="what's the weather today in Shanghai?",
-            )
+            ),
         ],
-        model_parameters={
-            'temperature': 0.9,
-            'max_tokens': 100
-        },
+        model_parameters={"temperature": 0.9, "max_tokens": 100},
         tools=[
             PromptMessageTool(
-                name='get_weather',
-                description='Determine weather in my location',
+                name="get_weather",
+                description="Determine weather in my location",
                 parameters={
                     "type": "object",
                     "properties": {
-                      "location": {
-                        "type": "string",
-                        "description": "The city and state e.g. San Francisco, CA"
-                      },
-                      "unit": {
-                        "type": "string",
-                        "enum": [
-                          "c",
-                          "f"
-                        ]
-                      }
+                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                        "unit": {"type": "string", "enum": ["c", "f"]},
                     },
-                    "required": [
-                      "location"
-                    ]
-                  }
+                    "required": ["location"],
+                },
             ),
             PromptMessageTool(
-                name='get_stock_price',
-                description='Get the current stock price',
+                name="get_stock_price",
+                description="Get the current stock price",
                 parameters={
                     "type": "object",
-                    "properties": {
-                      "symbol": {
-                        "type": "string",
-                        "description": "The stock symbol"
-                      }
-                    },
-                    "required": [
-                      "symbol"
-                    ]
-                  }
-            )
+                    "properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
+                    "required": ["symbol"],
+                },
+            ),
         ],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(result, LLMResult)
     assert isinstance(result.message, AssistantPromptMessage)
-    assert len(result.message.tool_calls) > 0
\ No newline at end of file
+    assert len(result.message.tool_calls) > 0

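The bare + lines in the stepfun hunks (for instance before def test_invoke_model():) are vertical-spacing normalization: the formatter separates top-level definitions with exactly two blank lines. The otherwise-identical -/+ pair closing the file is it restoring a missing newline at end of file. Spacing sketch:

    # Spacing sketch: exactly two blank lines between top-level definitions.
    def first_case() -> None:
        pass


    def second_case() -> None:
        pass
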
+ 8 - 21
api/tests/integration_tests/model_runtime/test_model_provider_factory.py

@@ -24,13 +24,8 @@ def test_get_models():
     providers = factory.get_models(
         model_type=ModelType.LLM,
         provider_configs=[
-            ProviderConfig(
-                provider='openai',
-                credentials={
-                    'openai_api_key': os.environ.get('OPENAI_API_KEY')
-                }
-            )
-        ]
+            ProviderConfig(provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")})
+        ],
     )
 
     logger.debug(providers)
@@ -44,29 +39,21 @@ def test_get_models():
             assert provider_model.model_type == ModelType.LLM
 
     providers = factory.get_models(
-        provider='openai',
+        provider="openai",
         provider_configs=[
-            ProviderConfig(
-                provider='openai',
-                credentials={
-                    'openai_api_key': os.environ.get('OPENAI_API_KEY')
-                }
-            )
-        ]
+            ProviderConfig(provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")})
+        ],
     )
 
     assert len(providers) == 1
     assert isinstance(providers[0], SimpleProviderEntity)
-    assert providers[0].provider == 'openai'
+    assert providers[0].provider == "openai"
 
 
 def test_provider_credentials_validate():
     factory = ModelProviderFactory()
     factory.provider_credentials_validate(
-        provider='openai',
-        credentials={
-            'openai_api_key': os.environ.get('OPENAI_API_KEY')
-        }
+        provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}
     )
 
 
@@ -79,4 +66,4 @@ def test__get_model_provider_map():
         logger.debug(model_provider.provider_instance)
 
     assert len(model_providers) >= 1
-    assert isinstance(model_providers['openai'], ModelProviderExtension)
+    assert isinstance(model_providers["openai"], ModelProviderExtension)

+ 29 - 45
api/tests/integration_tests/model_runtime/togetherai/test_llm.py

@@ -19,76 +19,61 @@ def test_validate_credentials():
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='mistralai/Mixtral-8x7B-Instruct-v0.1',
-            credentials={
-                'api_key': 'invalid_key',
-                'mode': 'chat'
-            }
+            model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={"api_key": "invalid_key", "mode": "chat"}
         )
 
     model.validate_credentials(
-        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
-        credentials={
-            'api_key': os.environ.get('TOGETHER_API_KEY'),
-            'mode': 'chat'
-        }
+        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
+        credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"},
     )
 
+
 def test_invoke_model():
     model = TogetherAILargeLanguageModel()
 
     response = model.invoke(
-        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
-        credentials={
-            'api_key': os.environ.get('TOGETHER_API_KEY'),
-            'mode': 'completion'
-        },
+        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
+        credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "completion"},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Who are you?'
-            )
+            UserPromptMessage(content="Who are you?"),
         ],
         model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
         },
-        stop=['How'],
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
 
+
 def test_invoke_stream_model():
     model = TogetherAILargeLanguageModel()
 
     response = model.invoke(
-        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
-        credentials={
-            'api_key': os.environ.get('TOGETHER_API_KEY'),
-            'mode': 'chat'
-        },
+        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
+        credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Who are you?'
-            )
+            UserPromptMessage(content="Who are you?"),
         ],
         model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
         },
-        stop=['How'],
+        stop=["How"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -98,22 +83,21 @@ def test_invoke_stream_model():
         assert isinstance(chunk.delta, LLMResultChunkDelta)
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
 
+
 def test_get_num_tokens():
     model = TogetherAILargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
+        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
         credentials={
-            'api_key': os.environ.get('TOGETHER_API_KEY'),
+            "api_key": os.environ.get("TOGETHER_API_KEY"),
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+            UserPromptMessage(content="Hello World!"),
+        ],
     )
 
     assert isinstance(num_tokens, int)

+ 18 - 49
api/tests/integration_tests/model_runtime/tongyi/test_llm.py

@@ -13,18 +13,10 @@ def test_validate_credentials():
     model = TongyiLargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='qwen-turbo',
-            credentials={
-                'dashscope_api_key': 'invalid_key'
-            }
-        )
+        model.validate_credentials(model="qwen-turbo", credentials={"dashscope_api_key": "invalid_key"})
 
     model.validate_credentials(
-        model='qwen-turbo',
-        credentials={
-            'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
-        }
+        model="qwen-turbo", credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}
     )
 
 
@@ -32,22 +24,13 @@ def test_invoke_model():
     model = TongyiLargeLanguageModel()
 
     response = model.invoke(
-        model='qwen-turbo',
-        credentials={
-            'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Who are you?'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.5,
-            'max_tokens': 10
-        },
-        stop=['How'],
+        model="qwen-turbo",
+        credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")},
+        prompt_messages=[UserPromptMessage(content="Who are you?")],
+        model_parameters={"temperature": 0.5, "max_tokens": 10},
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, LLMResult)
@@ -58,22 +41,12 @@ def test_invoke_stream_model():
     model = TongyiLargeLanguageModel()
 
     response = model.invoke(
-        model='qwen-turbo',
-        credentials={
-            'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.5,
-            'max_tokens': 100,
-            'seed': 1234
-        },
+        model="qwen-turbo",
+        credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")},
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -89,18 +62,14 @@ def test_get_num_tokens():
     model = TongyiLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='qwen-turbo',
-        credentials={
-            'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
-        },
+        model="qwen-turbo",
+        credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+            UserPromptMessage(content="Hello World!"),
+        ],
     )
 
     assert num_tokens == 12

+ 2 - 6
api/tests/integration_tests/model_runtime/tongyi/test_provider.py

@@ -10,12 +10,8 @@ def test_validate_provider_credentials():
     provider = TongyiProvider()
 
     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={}
-        )
+        provider.validate_provider_credentials(credentials={})
 
     provider.validate_provider_credentials(
-        credentials={
-            'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
-        }
+        credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}
     )

+ 7 - 11
api/tests/integration_tests/model_runtime/tongyi/test_response_format.py

@@ -39,21 +39,17 @@ def invoke_model_with_json_response(model_name="qwen-max-0403"):
 
     response = model.invoke(
         model=model_name,
-        credentials={
-            'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
-        },
+        credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")},
         prompt_messages=[
-            UserPromptMessage(
-                content='output json data with format `{"data": "test", "code": 200, "msg": "success"}'
-            )
+            UserPromptMessage(content='output json data with format `{"data": "test", "code": 200, "msg": "success"}')
         ],
         model_parameters={
-            'temperature': 0.5,
-            'max_tokens': 50,
-            'response_format': 'JSON',
+            "temperature": 0.5,
+            "max_tokens": 50,
+            "response_format": "JSON",
         },
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
     print("=====================================")
     print(response)
@@ -81,4 +77,4 @@ def is_json(s):
         json.loads(s)
     except ValueError:
         return False
-    return True
\ No newline at end of file
+    return True

+ 59 - 118
api/tests/integration_tests/model_runtime/upstage/test_llm.py

@@ -26,151 +26,113 @@ def test_predefined_models():
     assert len(model_schemas) >= 1
     assert isinstance(model_schemas[0], AIModelEntity)
 
-@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
 def test_validate_credentials_for_chat_model(setup_openai_mock):
     model = UpstageLargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
         # model name to gpt-3.5-turbo because of mocking
-        model.validate_credentials(
-            model='gpt-3.5-turbo',
-            credentials={
-                'upstage_api_key': 'invalid_key'
-            }
-        )
+        model.validate_credentials(model="gpt-3.5-turbo", credentials={"upstage_api_key": "invalid_key"})
 
     model.validate_credentials(
-        model='solar-1-mini-chat',
-        credentials={
-            'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
-        }
+        model="solar-1-mini-chat", credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}
     )
 
-@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
 def test_invoke_chat_model(setup_openai_mock):
     model = UpstageLargeLanguageModel()
 
     result = model.invoke(
-        model='solar-1-mini-chat',
-        credentials={
-            'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
-        },
+        model="solar-1-mini-chat",
+        credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
         model_parameters={
-            'temperature': 0.0,
-            'top_p': 1.0,
-            'presence_penalty': 0.0,
-            'frequency_penalty': 0.0,
-            'max_tokens': 10
+            "temperature": 0.0,
+            "top_p": 1.0,
+            "presence_penalty": 0.0,
+            "frequency_penalty": 0.0,
+            "max_tokens": 10,
         },
-        stop=['How'],
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(result, LLMResult)
     assert len(result.message.content) > 0
 
-@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
 def test_invoke_chat_model_with_tools(setup_openai_mock):
     model = UpstageLargeLanguageModel()
 
     result = model.invoke(
-        model='solar-1-mini-chat',
-        credentials={
-            'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
-        },
+        model="solar-1-mini-chat",
+        credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
             UserPromptMessage(
                 content="what's the weather today in London?",
-            )
+            ),
         ],
-        model_parameters={
-            'temperature': 0.0,
-            'max_tokens': 100
-        },
+        model_parameters={"temperature": 0.0, "max_tokens": 100},
         tools=[
             PromptMessageTool(
-                name='get_weather',
-                description='Determine weather in my location',
+                name="get_weather",
+                description="Determine weather in my location",
                 parameters={
                     "type": "object",
                     "properties": {
-                      "location": {
-                        "type": "string",
-                        "description": "The city and state e.g. San Francisco, CA"
-                      },
-                      "unit": {
-                        "type": "string",
-                        "enum": [
-                          "c",
-                          "f"
-                        ]
-                      }
+                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                        "unit": {"type": "string", "enum": ["c", "f"]},
                     },
-                    "required": [
-                      "location"
-                    ]
-                  }
+                    "required": ["location"],
+                },
             ),
             PromptMessageTool(
-                name='get_stock_price',
-                description='Get the current stock price',
+                name="get_stock_price",
+                description="Get the current stock price",
                 parameters={
                     "type": "object",
-                    "properties": {
-                      "symbol": {
-                        "type": "string",
-                        "description": "The stock symbol"
-                      }
-                    },
-                    "required": [
-                      "symbol"
-                    ]
-                  }
-            )
+                    "properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
+                    "required": ["symbol"],
+                },
+            ),
         ],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(result, LLMResult)
     assert isinstance(result.message, AssistantPromptMessage)
     assert len(result.message.tool_calls) > 0
 
-@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
 def test_invoke_stream_chat_model(setup_openai_mock):
     model = UpstageLargeLanguageModel()
 
     result = model.invoke(
-        model='solar-1-mini-chat',
-        credentials={
-            'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
-        },
+        model="solar-1-mini-chat",
+        credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
-        model_parameters={
-            'temperature': 0.0,
-            'max_tokens': 100
-        },
+        model_parameters={"temperature": 0.0, "max_tokens": 100},
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(result, Generator)
@@ -189,57 +151,36 @@ def test_get_num_tokens():
     model = UpstageLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='solar-1-mini-chat',
-        credentials={
-            'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+        model="solar-1-mini-chat",
+        credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")},
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
     )
 
     assert num_tokens == 13
 
     num_tokens = model.get_num_tokens(
-        model='solar-1-mini-chat',
-        credentials={
-            'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
-        },
+        model="solar-1-mini-chat",
+        credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
         tools=[
             PromptMessageTool(
-                name='get_weather',
-                description='Determine weather in my location',
+                name="get_weather",
+                description="Determine weather in my location",
                 parameters={
                     "type": "object",
                     "properties": {
-                      "location": {
-                        "type": "string",
-                        "description": "The city and state e.g. San Francisco, CA"
-                      },
-                      "unit": {
-                        "type": "string",
-                        "enum": [
-                          "c",
-                          "f"
-                        ]
-                      }
+                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                        "unit": {"type": "string", "enum": ["c", "f"]},
                     },
-                    "required": [
-                      "location"
-                    ]
-                }
+                    "required": ["location"],
+                },
             ),
-        ]
+        ],
     )
 
     assert num_tokens == 106

+ 3 - 9
api/tests/integration_tests/model_runtime/upstage/test_provider.py

@@ -7,17 +7,11 @@ from core.model_runtime.model_providers.upstage.upstage import UpstageProvider
 from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
 
 
-@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
 def test_validate_provider_credentials(setup_openai_mock):
     provider = UpstageProvider()
 
     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={}
-        )
+        provider.validate_provider_credentials(credentials={})
 
-    provider.validate_provider_credentials(
-        credentials={
-            'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
-        }
-    )
+    provider.validate_provider_credentials(credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")})

+ 12 - 25
api/tests/integration_tests/model_runtime/upstage/test_text_embedding.py

@@ -8,41 +8,31 @@ from core.model_runtime.model_providers.upstage.text_embedding.text_embedding im
 from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
 
 
-@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
+@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
 def test_validate_credentials(setup_openai_mock):
     model = UpstageTextEmbeddingModel()
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='solar-embedding-1-large-passage',
-            credentials={
-                'upstage_api_key': 'invalid_key'
-            }
+            model="solar-embedding-1-large-passage", credentials={"upstage_api_key": "invalid_key"}
         )
 
     model.validate_credentials(
-        model='solar-embedding-1-large-passage',
-        credentials={
-            'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
-        }
+        model="solar-embedding-1-large-passage", credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}
     )
 
-@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
 def test_invoke_model(setup_openai_mock):
     model = UpstageTextEmbeddingModel()
 
     result = model.invoke(
-        model='solar-embedding-1-large-passage',
+        model="solar-embedding-1-large-passage",
         credentials={
-            'upstage_api_key': os.environ.get('UPSTAGE_API_KEY'),
+            "upstage_api_key": os.environ.get("UPSTAGE_API_KEY"),
         },
-        texts=[
-            "hello",
-            "world",
-            " ".join(["long_text"] * 100),
-            " ".join(["another_long_text"] * 100)
-        ],
-        user="abc-123"
+        texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
+        user="abc-123",
     )
 
     assert isinstance(result, TextEmbeddingResult)
@@ -54,14 +44,11 @@ def test_get_num_tokens():
     model = UpstageTextEmbeddingModel()
 
     num_tokens = model.get_num_tokens(
-        model='solar-embedding-1-large-passage',
+        model="solar-embedding-1-large-passage",
         credentials={
-            'upstage_api_key': os.environ.get('UPSTAGE_API_KEY'),
+            "upstage_api_key": os.environ.get("UPSTAGE_API_KEY"),
         },
-        texts=[
-            "hello",
-            "world"
-        ]
+        texts=["hello", "world"],
     )
 
     assert num_tokens == 5

+ 32 - 38
api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py

@@ -14,26 +14,26 @@ def test_validate_credentials():
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='NOT IMPORTANT',
+            model="NOT IMPORTANT",
             credentials={
-                'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
-                'volc_region': 'cn-beijing',
-                'volc_access_key_id': 'INVALID',
-                'volc_secret_access_key': 'INVALID',
-                'endpoint_id': 'INVALID',
-                'base_model_name': 'Doubao-embedding',
-            }
+                "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
+                "volc_region": "cn-beijing",
+                "volc_access_key_id": "INVALID",
+                "volc_secret_access_key": "INVALID",
+                "endpoint_id": "INVALID",
+                "base_model_name": "Doubao-embedding",
+            },
         )
 
     model.validate_credentials(
-        model='NOT IMPORTANT',
+        model="NOT IMPORTANT",
         credentials={
-            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
-            'volc_region': 'cn-beijing',
-            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
-            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
-            'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
-            'base_model_name': 'Doubao-embedding',
+            "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
+            "volc_region": "cn-beijing",
+            "volc_access_key_id": os.environ.get("VOLC_API_KEY"),
+            "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
+            "endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"),
+            "base_model_name": "Doubao-embedding",
         },
     )
 
@@ -42,20 +42,17 @@ def test_invoke_model():
     model = VolcengineMaaSTextEmbeddingModel()
 
     result = model.invoke(
-        model='NOT IMPORTANT',
+        model="NOT IMPORTANT",
         credentials={
-            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
-            'volc_region': 'cn-beijing',
-            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
-            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
-            'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
-            'base_model_name': 'Doubao-embedding',
+            "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
+            "volc_region": "cn-beijing",
+            "volc_access_key_id": os.environ.get("VOLC_API_KEY"),
+            "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
+            "endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"),
+            "base_model_name": "Doubao-embedding",
         },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        texts=["hello", "world"],
+        user="abc-123",
     )
 
     assert isinstance(result, TextEmbeddingResult)
@@ -67,19 +64,16 @@ def test_get_num_tokens():
     model = VolcengineMaaSTextEmbeddingModel()
 
     num_tokens = model.get_num_tokens(
-        model='NOT IMPORTANT',
+        model="NOT IMPORTANT",
         credentials={
-            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
-            'volc_region': 'cn-beijing',
-            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
-            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
-            'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
-            'base_model_name': 'Doubao-embedding',
+            "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
+            "volc_region": "cn-beijing",
+            "volc_access_key_id": os.environ.get("VOLC_API_KEY"),
+            "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
+            "endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"),
+            "base_model_name": "Doubao-embedding",
         },
-        texts=[
-            "hello",
-            "world"
-        ]
+        texts=["hello", "world"],
     )
 
     assert num_tokens == 2

+ 50 - 63
api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py

@@ -14,25 +14,25 @@ def test_validate_credentials_for_chat_model():
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='NOT IMPORTANT',
+            model="NOT IMPORTANT",
             credentials={
-                'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
-                'volc_region': 'cn-beijing',
-                'volc_access_key_id': 'INVALID',
-                'volc_secret_access_key': 'INVALID',
-                'endpoint_id': 'INVALID',
-            }
+                "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
+                "volc_region": "cn-beijing",
+                "volc_access_key_id": "INVALID",
+                "volc_secret_access_key": "INVALID",
+                "endpoint_id": "INVALID",
+            },
         )
 
     model.validate_credentials(
-        model='NOT IMPORTANT',
+        model="NOT IMPORTANT",
         credentials={
-            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
-            'volc_region': 'cn-beijing',
-            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
-            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
-            'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
-        }
+            "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
+            "volc_region": "cn-beijing",
+            "volc_access_key_id": os.environ.get("VOLC_API_KEY"),
+            "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
+            "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"),
+        },
     )
 
 
@@ -40,28 +40,24 @@ def test_invoke_model():
     model = VolcengineMaaSLargeLanguageModel()
 
     response = model.invoke(
-        model='NOT IMPORTANT',
+        model="NOT IMPORTANT",
         credentials={
-            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
-            'volc_region': 'cn-beijing',
-            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
-            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
-            'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
-            'base_model_name': 'Skylark2-pro-4k',
+            "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
+            "volc_region": "cn-beijing",
+            "volc_access_key_id": os.environ.get("VOLC_API_KEY"),
+            "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
+            "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"),
+            "base_model_name": "Skylark2-pro-4k",
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
         },
-        stop=['you'],
+        stop=["you"],
         user="abc-123",
-        stream=False
+        stream=False,
     )
 
     assert isinstance(response, LLMResult)
@@ -73,28 +69,24 @@ def test_invoke_stream_model():
     model = VolcengineMaaSLargeLanguageModel()
 
     response = model.invoke(
-        model='NOT IMPORTANT',
+        model="NOT IMPORTANT",
         credentials={
-            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
-            'volc_region': 'cn-beijing',
-            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
-            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
-            'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
-            'base_model_name': 'Skylark2-pro-4k',
+            "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
+            "volc_region": "cn-beijing",
+            "volc_access_key_id": os.environ.get("VOLC_API_KEY"),
+            "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
+            "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"),
+            "base_model_name": "Skylark2-pro-4k",
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
         },
-        stop=['you'],
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -102,29 +94,24 @@ def test_invoke_stream_model():
         assert isinstance(chunk, LLMResultChunk)
         assert isinstance(chunk.delta, LLMResultChunkDelta)
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
-        assert len(
-            chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
 
 def test_get_num_tokens():
     model = VolcengineMaaSLargeLanguageModel()
 
     response = model.get_num_tokens(
-        model='NOT IMPORTANT',
+        model="NOT IMPORTANT",
         credentials={
-            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
-            'volc_region': 'cn-beijing',
-            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
-            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
-            'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
-            'base_model_name': 'Skylark2-pro-4k',
+            "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
+            "volc_region": "cn-beijing",
+            "volc_access_key_id": os.environ.get("VOLC_API_KEY"),
+            "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
+            "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"),
+            "base_model_name": "Skylark2-pro-4k",
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
-        tools=[]
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        tools=[],
     )
 
     assert isinstance(response, int)

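Joining the wrapped assert in test_invoke_stream_model above is behavior-preserving, but the one-line form deserves a second look: a conditional expression binds more loosely than a comparison, so the operand is (len(...) > 0) if chunk.delta.finish_reason is None else True, not len(...) > (0 if ...). An equivalent parenthesized sketch:

    # Precedence sketch for the joined assert: the conditional expression is
    # the assert statement's single operand.
    finish_reason = None
    content = "Hello"
    assert (len(content) > 0 if finish_reason is None else True)
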
+ 16 - 28
api/tests/integration_tests/model_runtime/wenxin/test_embedding.py

@@ -10,13 +10,10 @@ def test_invoke_embedding_v1():
     model = WenxinTextEmbeddingModel()
 
     response = model.invoke(
-        model='embedding-v1',
-        credentials={
-            'api_key': os.environ.get('WENXIN_API_KEY'),
-            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
-        },
-        texts=['hello', '你好', 'xxxxx'],
-        user="abc-123"
+        model="embedding-v1",
+        credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
+        texts=["hello", "你好", "xxxxx"],
+        user="abc-123",
     )
 
     assert isinstance(response, TextEmbeddingResult)
@@ -29,13 +26,10 @@ def test_invoke_embedding_bge_large_en():
     model = WenxinTextEmbeddingModel()
 
     response = model.invoke(
-        model='bge-large-en',
-        credentials={
-            'api_key': os.environ.get('WENXIN_API_KEY'),
-            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
-        },
-        texts=['hello', '你好', 'xxxxx'],
-        user="abc-123"
+        model="bge-large-en",
+        credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
+        texts=["hello", "你好", "xxxxx"],
+        user="abc-123",
     )
 
     assert isinstance(response, TextEmbeddingResult)
@@ -48,13 +42,10 @@ def test_invoke_embedding_bge_large_zh():
     model = WenxinTextEmbeddingModel()
 
     response = model.invoke(
-        model='bge-large-zh',
-        credentials={
-            'api_key': os.environ.get('WENXIN_API_KEY'),
-            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
-        },
-        texts=['hello', '你好', 'xxxxx'],
-        user="abc-123"
+        model="bge-large-zh",
+        credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
+        texts=["hello", "你好", "xxxxx"],
+        user="abc-123",
     )
 
     assert isinstance(response, TextEmbeddingResult)
@@ -67,13 +58,10 @@ def test_invoke_embedding_tao_8k():
     model = WenxinTextEmbeddingModel()
 
     response = model.invoke(
-        model='tao-8k',
-        credentials={
-            'api_key': os.environ.get('WENXIN_API_KEY'),
-            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
-        },
-        texts=['hello', '你好', 'xxxxx'],
-        user="abc-123"
+        model="tao-8k",
+        credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
+        texts=["hello", "你好", "xxxxx"],
+        user="abc-123",
     )
 
     assert isinstance(response, TextEmbeddingResult)
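
Two opposite reflows run through this file, and both follow from the formatter's magic trailing comma: a literal whose last element carries no trailing comma is collapsed onto one line once it fits, while a trailing comma pins the expanded layout. A sketch of both cases, assuming the default skip-magic-trailing-comma = false:

import os

# no trailing comma inside -> collapsed when it fits the line length:
credentials = {"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}

# trailing comma after the last item -> stays expanded:
model_parameters = {
    "temperature": 0.7,
    "top_p": 1.0,
}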

+ 69 - 126
api/tests/integration_tests/model_runtime/wenxin/test_llm.py

@@ -17,161 +17,125 @@ def test_predefined_models():
     assert len(model_schemas) >= 1
     assert isinstance(model_schemas[0], AIModelEntity)
 
+
 def test_validate_credentials_for_chat_model():
     sleep(3)
     model = ErnieBotLargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='ernie-bot',
-            credentials={
-                'api_key': 'invalid_key',
-                'secret_key': 'invalid_key'
-            }
+            model="ernie-bot", credentials={"api_key": "invalid_key", "secret_key": "invalid_key"}
         )
 
     model.validate_credentials(
-        model='ernie-bot',
-        credentials={
-            'api_key': os.environ.get('WENXIN_API_KEY'),
-            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
-        }
+        model="ernie-bot",
+        credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
     )
 
+
 def test_invoke_model_ernie_bot():
     sleep(3)
     model = ErnieBotLargeLanguageModel()
 
     response = model.invoke(
-        model='ernie-bot',
-        credentials={
-            'api_key': os.environ.get('WENXIN_API_KEY'),
-            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
+        model="ernie-bot",
+        credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
+            "temperature": 0.7,
+            "top_p": 1.0,
         },
-        stop=['you'],
+        stop=["you"],
         user="abc-123",
-        stream=False
+        stream=False,
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
+
 def test_invoke_model_ernie_bot_turbo():
     sleep(3)
     model = ErnieBotLargeLanguageModel()
 
     response = model.invoke(
-        model='ernie-bot-turbo',
-        credentials={
-            'api_key': os.environ.get('WENXIN_API_KEY'),
-            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
+        model="ernie-bot-turbo",
+        credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
+            "temperature": 0.7,
+            "top_p": 1.0,
         },
-        stop=['you'],
+        stop=["you"],
         user="abc-123",
-        stream=False
+        stream=False,
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
+
 def test_invoke_model_ernie_8k():
     sleep(3)
     model = ErnieBotLargeLanguageModel()
 
     response = model.invoke(
-        model='ernie-bot-8k',
-        credentials={
-            'api_key': os.environ.get('WENXIN_API_KEY'),
-            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
+        model="ernie-bot-8k",
+        credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
+            "temperature": 0.7,
+            "top_p": 1.0,
         },
-        stop=['you'],
+        stop=["you"],
         user="abc-123",
-        stream=False
+        stream=False,
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
+
 def test_invoke_model_ernie_bot_4():
     sleep(3)
     model = ErnieBotLargeLanguageModel()
 
     response = model.invoke(
-        model='ernie-bot-4',
-        credentials={
-            'api_key': os.environ.get('WENXIN_API_KEY'),
-            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
+        model="ernie-bot-4",
+        credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
+            "temperature": 0.7,
+            "top_p": 1.0,
         },
-        stop=['you'],
+        stop=["you"],
         user="abc-123",
-        stream=False
+        stream=False,
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
+
 def test_invoke_stream_model():
     sleep(3)
     model = ErnieBotLargeLanguageModel()
 
     response = model.invoke(
-        model='ernie-3.5-8k',
-        credentials={
-            'api_key': os.environ.get('WENXIN_API_KEY'),
-            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
+        model="ernie-3.5-8k",
+        credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
+            "temperature": 0.7,
+            "top_p": 1.0,
         },
-        stop=['you'],
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -181,63 +145,48 @@ def test_invoke_stream_model():
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
+
 def test_invoke_model_with_system():
     sleep(3)
     model = ErnieBotLargeLanguageModel()
 
     response = model.invoke(
-        model='ernie-bot',
-        credentials={
-            'api_key': os.environ.get('WENXIN_API_KEY'),
-            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
-        },
-        prompt_messages=[
-            SystemPromptMessage(
-                content='你是Kasumi'
-            ),
-            UserPromptMessage(
-                content='你是谁?'
-            )
-        ],
+        model="ernie-bot",
+        credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
+        prompt_messages=[SystemPromptMessage(content="你是Kasumi"), UserPromptMessage(content="你是谁?")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
+            "temperature": 0.7,
+            "top_p": 1.0,
         },
-        stop=['you'],
+        stop=["you"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, LLMResult)
-    assert 'kasumi' in response.message.content.lower()
+    assert "kasumi" in response.message.content.lower()
+
 
 def test_invoke_with_search():
     sleep(3)
     model = ErnieBotLargeLanguageModel()
 
     response = model.invoke(
-        model='ernie-bot',
-        credentials={
-            'api_key': os.environ.get('WENXIN_API_KEY'),
-            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='北京今天的天气怎么样'
-            )
-        ],
+        model="ernie-bot",
+        credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
+        prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'disable_search': True,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "disable_search": True,
         },
         stop=[],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
-    total_message = ''
+    total_message = ""
     for chunk in response:
         assert isinstance(chunk, LLMResultChunk)
         assert isinstance(chunk.delta, LLMResultChunkDelta)
@@ -247,25 +196,19 @@ def test_invoke_with_search():
         assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True
 
    # the response should contain a refusal such as 对不起, 我不能, or 不支持
-    assert ('不' in total_message or '抱歉' in total_message or '无法' in total_message)
+    assert "不" in total_message or "抱歉" in total_message or "无法" in total_message
+
 
 def test_get_num_tokens():
     sleep(3)
     model = ErnieBotLargeLanguageModel()
 
     response = model.get_num_tokens(
-        model='ernie-bot',
-        credentials={
-            'api_key': os.environ.get('WENXIN_API_KEY'),
-            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
-        tools=[]
+        model="ernie-bot",
+        credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        tools=[],
     )
 
     assert isinstance(response, int)
-    assert response == 10
+    assert response == 10
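
The isolated + lines in front of each def above are the formatter restoring the second blank line PEP 8 requires between top-level definitions; the file previously used a single one. The enforced pattern:

def test_first():
    pass


def test_second():  # exactly two blank lines separate top-level definitions
    pass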

+ 2 - 10
api/tests/integration_tests/model_runtime/wenxin/test_provider.py

@@ -10,16 +10,8 @@ def test_validate_provider_credentials():
     provider = WenxinProvider()
 
     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={
-                'api_key': 'hahahaha',
-                'secret_key': 'hahahaha'
-            }
-        )
+        provider.validate_provider_credentials(credentials={"api_key": "hahahaha", "secret_key": "hahahaha"})
 
     provider.validate_provider_credentials(
-        credentials={
-            'api_key': os.environ.get('WENXIN_API_KEY'),
-            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
-        }
+        credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}
     )

+ 21 - 25
api/tests/integration_tests/model_runtime/xinference/test_embeddings.py

@@ -8,61 +8,57 @@ from core.model_runtime.model_providers.xinference.text_embedding.text_embedding
 from tests.integration_tests.model_runtime.__mock.xinference import MOCK, setup_xinference_mock
 
 
-@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
+@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True)
 def test_validate_credentials(setup_xinference_mock):
     model = XinferenceTextEmbeddingModel()
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='bge-base-en',
+            model="bge-base-en",
             credentials={
-                'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-                'model_uid': 'www ' + os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
-            }
+                "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+                "model_uid": "www " + os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"),
+            },
         )
 
     model.validate_credentials(
-        model='bge-base-en',
+        model="bge-base-en",
         credentials={
-            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-            'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
-        }
+            "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+            "model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"),
+        },
     )
 
-@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True)
 def test_invoke_model(setup_xinference_mock):
     model = XinferenceTextEmbeddingModel()
 
     result = model.invoke(
-        model='bge-base-en',
+        model="bge-base-en",
         credentials={
-            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-            'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
+            "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+            "model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"),
         },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        texts=["hello", "world"],
+        user="abc-123",
     )
 
     assert isinstance(result, TextEmbeddingResult)
     assert len(result.embeddings) == 2
     assert result.usage.total_tokens > 0
 
+
 def test_get_num_tokens():
     model = XinferenceTextEmbeddingModel()
 
     num_tokens = model.get_num_tokens(
-        model='bge-base-en',
+        model="bge-base-en",
         credentials={
-            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-            'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
+            "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+            "model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"),
         },
-        texts=[
-            "hello",
-            "world"
-        ]
+        texts=["hello", "world"],
     )
 
     assert num_tokens == 2

+ 85 - 116
api/tests/integration_tests/model_runtime/xinference/test_llm.py

@@ -20,92 +20,84 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_moc
 from tests.integration_tests.model_runtime.__mock.xinference import setup_xinference_mock
 
 
-@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True)
+@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True)
 def test_validate_credentials_for_chat_model(setup_openai_mock, setup_xinference_mock):
     model = XinferenceAILargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='ChatGLM3',
+            model="ChatGLM3",
             credentials={
-                'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-                'model_uid': 'www ' + os.environ.get('XINFERENCE_CHAT_MODEL_UID')
-            }
+                "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+                "model_uid": "www " + os.environ.get("XINFERENCE_CHAT_MODEL_UID"),
+            },
         )
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='aaaaa',
-            credentials={
-                'server_url': '',
-                'model_uid': ''
-            }
-        )
+        model.validate_credentials(model="aaaaa", credentials={"server_url": "", "model_uid": ""})
 
     model.validate_credentials(
-        model='ChatGLM3',
+        model="ChatGLM3",
         credentials={
-            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-            'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
-        }
+            "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+            "model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"),
+        },
     )
 
-@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True)
 def test_invoke_chat_model(setup_openai_mock, setup_xinference_mock):
     model = XinferenceAILargeLanguageModel()
 
     response = model.invoke(
-        model='ChatGLM3',
+        model="ChatGLM3",
         credentials={
-            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-            'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
+            "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+            "model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"),
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
+            "temperature": 0.7,
+            "top_p": 1.0,
         },
-        stop=['you'],
+        stop=["you"],
         user="abc-123",
-        stream=False
+        stream=False,
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
-@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True)
 def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock):
     model = XinferenceAILargeLanguageModel()
 
     response = model.invoke(
-        model='ChatGLM3',
+        model="ChatGLM3",
         credentials={
-            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-            'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
+            "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+            "model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"),
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
+            "temperature": 0.7,
+            "top_p": 1.0,
         },
-        stop=['you'],
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -114,6 +106,8 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock):
         assert isinstance(chunk.delta, LLMResultChunkDelta)
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+
+
 """
    Function calling of xinference does not support stream mode currently
 """
@@ -168,7 +162,7 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock):
 #     )
 
 #     assert isinstance(response, Generator)
-    
+
 #     call: LLMResultChunk = None
 #     chunks = []
 
@@ -241,86 +235,75 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock):
 #     assert response.usage.total_tokens > 0
 #     assert response.message.tool_calls[0].function.name == 'get_current_weather'
 
-@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True)
 def test_validate_credentials_for_generation_model(setup_openai_mock, setup_xinference_mock):
     model = XinferenceAILargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='alapaca',
+            model="alapaca",
             credentials={
-                'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-                'model_uid': 'www ' + os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
-            }
+                "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+                "model_uid": "www " + os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
+            },
         )
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='alapaca',
-            credentials={
-                'server_url': '',
-                'model_uid': ''
-            }
-        )
+        model.validate_credentials(model="alapaca", credentials={"server_url": "", "model_uid": ""})
 
     model.validate_credentials(
-        model='alapaca',
+        model="alapaca",
         credentials={
-            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-            'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
-        }
+            "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+            "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
+        },
     )
 
-@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True)
 def test_invoke_generation_model(setup_openai_mock, setup_xinference_mock):
     model = XinferenceAILargeLanguageModel()
 
     response = model.invoke(
-        model='alapaca',
+        model="alapaca",
         credentials={
-            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-            'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
+            "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+            "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='the United States is'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="the United States is")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
+            "temperature": 0.7,
+            "top_p": 1.0,
         },
-        stop=['you'],
+        stop=["you"],
         user="abc-123",
-        stream=False
+        stream=False,
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
-@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True)
 def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock):
     model = XinferenceAILargeLanguageModel()
 
     response = model.invoke(
-        model='alapaca',
+        model="alapaca",
         credentials={
-            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-            'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
+            "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+            "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='the United States is'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="the United States is")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
+            "temperature": 0.7,
+            "top_p": 1.0,
         },
-        stop=['you'],
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -330,68 +313,54 @@ def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
+
 def test_get_num_tokens():
     model = XinferenceAILargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='ChatGLM3',
+        model="ChatGLM3",
         credentials={
-            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-            'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
+            "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+            "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
         tools=[
             PromptMessageTool(
-                name='get_current_weather',
-                description='Get the current weather in a given location',
+                name="get_current_weather",
+                description="Get the current weather in a given location",
                 parameters={
                     "type": "object",
                     "properties": {
-                        "location": {
-                        "type": "string",
-                            "description": "The city and state e.g. San Francisco, CA"
-                        },
-                        "unit": {
-                            "type": "string",
-                            "enum": [
-                                "c",
-                                "f"
-                            ]
-                        }
+                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                        "unit": {"type": "string", "enum": ["c", "f"]},
                     },
-                    "required": [
-                        "location"
-                    ]
-                }
+                    "required": ["location"],
+                },
             )
-        ]
+        ],
     )
 
     assert isinstance(num_tokens, int)
     assert num_tokens == 77
 
     num_tokens = model.get_num_tokens(
-        model='ChatGLM3',
+        model="ChatGLM3",
         credentials={
-            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-            'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
+            "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+            "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
     )
 
     assert isinstance(num_tokens, int)
-    assert num_tokens == 21
+    assert num_tokens == 21
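
The commented-out function-calling tests above are touched only to strip trailing whitespace; the lone -/+ pair near the middle of the hunk is a line holding nothing but spaces being emptied. The same rule in one line:

assert "    ".rstrip() == ""  # a whitespace-only line is reduced to an empty string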

+ 15 - 17
api/tests/integration_tests/model_runtime/xinference/test_rerank.py

@@ -8,44 +8,42 @@ from core.model_runtime.model_providers.xinference.rerank.rerank import Xinferen
 from tests.integration_tests.model_runtime.__mock.xinference import MOCK, setup_xinference_mock
 
 
-@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
+@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True)
 def test_validate_credentials(setup_xinference_mock):
     model = XinferenceRerankModel()
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='bge-reranker-base',
-            credentials={
-                'server_url': 'awdawdaw',
-                'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID')
-            }
+            model="bge-reranker-base",
+            credentials={"server_url": "awdawdaw", "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID")},
         )
 
     model.validate_credentials(
-        model='bge-reranker-base',
+        model="bge-reranker-base",
         credentials={
-            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-            'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID')
-        }
+            "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+            "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID"),
+        },
     )
 
-@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True)
 def test_invoke_model(setup_xinference_mock):
     model = XinferenceRerankModel()
 
     result = model.invoke(
-        model='bge-reranker-base',
+        model="bge-reranker-base",
         credentials={
-            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-            'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID')
+            "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+            "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID"),
         },
         query="Who is Kasumi?",
         docs=[
-            "Kasumi is a girl's name of Japanese origin meaning \"mist\".",
+            'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
             "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
-            "and she leads a team named PopiParty."
+            "and she leads a team named PopiParty.",
         ],
-        score_threshold=0.8
+        score_threshold=0.8,
     )
 
     assert isinstance(result, RerankResult)

+ 19 - 52
api/tests/integration_tests/model_runtime/zhinao/test_llm.py

@@ -13,41 +13,22 @@ def test_validate_credentials():
     model = ZhinaoLargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='360gpt2-pro',
-            credentials={
-                'api_key': 'invalid_key'
-            }
-        )
-
-    model.validate_credentials(
-        model='360gpt2-pro',
-        credentials={
-            'api_key': os.environ.get('ZHINAO_API_KEY')
-        }
-    )
+        model.validate_credentials(model="360gpt2-pro", credentials={"api_key": "invalid_key"})
+
+    model.validate_credentials(model="360gpt2-pro", credentials={"api_key": os.environ.get("ZHINAO_API_KEY")})
 
 
 def test_invoke_model():
     model = ZhinaoLargeLanguageModel()
 
     response = model.invoke(
-        model='360gpt2-pro',
-        credentials={
-            'api_key': os.environ.get('ZHINAO_API_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Who are you?'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.5,
-            'max_tokens': 10
-        },
-        stop=['How'],
+        model="360gpt2-pro",
+        credentials={"api_key": os.environ.get("ZHINAO_API_KEY")},
+        prompt_messages=[UserPromptMessage(content="Who are you?")],
+        model_parameters={"temperature": 0.5, "max_tokens": 10},
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, LLMResult)
@@ -58,22 +39,12 @@ def test_invoke_stream_model():
     model = ZhinaoLargeLanguageModel()
 
     response = model.invoke(
-        model='360gpt2-pro',
-        credentials={
-            'api_key': os.environ.get('ZHINAO_API_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.5,
-            'max_tokens': 100,
-            'seed': 1234
-        },
+        model="360gpt2-pro",
+        credentials={"api_key": os.environ.get("ZHINAO_API_KEY")},
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -89,18 +60,14 @@ def test_get_num_tokens():
     model = ZhinaoLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='360gpt2-pro',
-        credentials={
-            'api_key': os.environ.get('ZHINAO_API_KEY')
-        },
+        model="360gpt2-pro",
+        credentials={"api_key": os.environ.get("ZHINAO_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+            UserPromptMessage(content="Hello World!"),
+        ],
     )
 
     assert num_tokens == 21

+ 2 - 8
api/tests/integration_tests/model_runtime/zhinao/test_provider.py

@@ -10,12 +10,6 @@ def test_validate_provider_credentials():
     provider = ZhinaoProvider()
 
     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={}
-        )
+        provider.validate_provider_credentials(credentials={})
 
-    provider.validate_provider_credentials(
-        credentials={
-            'api_key': os.environ.get('ZHINAO_API_KEY')
-        }
-    )
+    provider.validate_provider_credentials(credentials={"api_key": os.environ.get("ZHINAO_API_KEY")})

+ 31 - 77
api/tests/integration_tests/model_runtime/zhipuai/test_llm.py

@@ -18,41 +18,22 @@ def test_validate_credentials():
     model = ZhipuAILargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='chatglm_turbo',
-            credentials={
-                'api_key': 'invalid_key'
-            }
-        )
-
-    model.validate_credentials(
-        model='chatglm_turbo',
-        credentials={
-            'api_key': os.environ.get('ZHIPUAI_API_KEY')
-        }
-    )
+        model.validate_credentials(model="chatglm_turbo", credentials={"api_key": "invalid_key"})
+
+    model.validate_credentials(model="chatglm_turbo", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")})
 
 
 def test_invoke_model():
     model = ZhipuAILargeLanguageModel()
 
     response = model.invoke(
-        model='chatglm_turbo',
-        credentials={
-            'api_key': os.environ.get('ZHIPUAI_API_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Who are you?'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.9,
-            'top_p': 0.7
-        },
-        stop=['How'],
+        model="chatglm_turbo",
+        credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")},
+        prompt_messages=[UserPromptMessage(content="Who are you?")],
+        model_parameters={"temperature": 0.9, "top_p": 0.7},
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, LLMResult)
@@ -63,21 +44,12 @@ def test_invoke_stream_model():
     model = ZhipuAILargeLanguageModel()
 
     response = model.invoke(
-        model='chatglm_turbo',
-        credentials={
-            'api_key': os.environ.get('ZHIPUAI_API_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.9,
-            'top_p': 0.7
-        },
+        model="chatglm_turbo",
+        credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")},
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        model_parameters={"temperature": 0.9, "top_p": 0.7},
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -93,63 +65,45 @@ def test_get_num_tokens():
     model = ZhipuAILargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='chatglm_turbo',
-        credentials={
-            'api_key': os.environ.get('ZHIPUAI_API_KEY')
-        },
+        model="chatglm_turbo",
+        credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+            UserPromptMessage(content="Hello World!"),
+        ],
     )
 
     assert num_tokens == 14
 
+
 def test_get_tools_num_tokens():
     model = ZhipuAILargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='tools',
-        credentials={
-            'api_key': os.environ.get('ZHIPUAI_API_KEY')
-        },
+        model="tools",
+        credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")},
         tools=[
             PromptMessageTool(
-                name='get_current_weather',
-                description='Get the current weather in a given location',
+                name="get_current_weather",
+                description="Get the current weather in a given location",
                 parameters={
                     "type": "object",
                     "properties": {
-                        "location": {
-                        "type": "string",
-                            "description": "The city and state e.g. San Francisco, CA"
-                        },
-                        "unit": {
-                            "type": "string",
-                            "enum": [
-                                "c",
-                                "f"
-                            ]
-                        }
+                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                        "unit": {"type": "string", "enum": ["c", "f"]},
                     },
-                    "required": [
-                        "location"
-                    ]
-                }
+                    "required": ["location"],
+                },
             )
         ],
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+            UserPromptMessage(content="Hello World!"),
+        ],
     )
 
     assert num_tokens == 88

+ 2 - 8
api/tests/integration_tests/model_runtime/zhipuai/test_provider.py

@@ -10,12 +10,6 @@ def test_validate_provider_credentials():
     provider = ZhipuaiProvider()
 
     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={}
-        )
+        provider.validate_provider_credentials(credentials={})
 
-    provider.validate_provider_credentials(
-        credentials={
-            'api_key': os.environ.get('ZHIPUAI_API_KEY')
-        }
-    )
+    provider.validate_provider_credentials(credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")})

+ 8 - 30
api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py

@@ -11,34 +11,19 @@ def test_validate_credentials():
     model = ZhipuAITextEmbeddingModel()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='text_embedding',
-            credentials={
-                'api_key': 'invalid_key'
-            }
-        )
-
-    model.validate_credentials(
-        model='text_embedding',
-        credentials={
-            'api_key': os.environ.get('ZHIPUAI_API_KEY')
-        }
-    )
+        model.validate_credentials(model="text_embedding", credentials={"api_key": "invalid_key"})
+
+    model.validate_credentials(model="text_embedding", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")})
 
 
 def test_invoke_model():
     model = ZhipuAITextEmbeddingModel()
 
     result = model.invoke(
-        model='text_embedding',
-        credentials={
-            'api_key': os.environ.get('ZHIPUAI_API_KEY')
-        },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        model="text_embedding",
+        credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")},
+        texts=["hello", "world"],
+        user="abc-123",
     )
 
     assert isinstance(result, TextEmbeddingResult)
@@ -50,14 +35,7 @@ def test_get_num_tokens():
     model = ZhipuAITextEmbeddingModel()
 
     num_tokens = model.get_num_tokens(
-        model='text_embedding',
-        credentials={
-            'api_key': os.environ.get('ZHIPUAI_API_KEY')
-        },
-        texts=[
-            "hello",
-            "world"
-        ]
+        model="text_embedding", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, texts=["hello", "world"]
     )
 
     assert num_tokens == 2
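
The final get_num_tokens call above lands on a single continuation line of well over 88 columns, so the project evidently raises ruff's default line length (the exact value sits in api/pyproject.toml and is not shown in this diff). A sketch of the same collapse; get_num_tokens here is a hypothetical stand-in for the model method:

import os

def get_num_tokens(model, credentials, texts):  # hypothetical stand-in
    return len(texts)

# all keyword arguments fit one continuation line under the raised limit:
num_tokens = get_num_tokens(
    model="text_embedding", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, texts=["hello", "world"]
)
assert num_tokens == 2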

+ 6 - 9
api/tests/integration_tests/tools/__mock/http.py

@@ -7,20 +7,17 @@ from _pytest.monkeypatch import MonkeyPatch
 
 
 class MockedHttp:
-    def httpx_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'],
-                      url: str, **kwargs) -> httpx.Response:
+    def httpx_request(
+        method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
+    ) -> httpx.Response:
         """
         Mocked httpx.request
         """
         request = httpx.Request(
-            method,
-            url,
-            params=kwargs.get('params'),
-            headers=kwargs.get('headers'),
-            cookies=kwargs.get('cookies')
+            method, url, params=kwargs.get("params"), headers=kwargs.get("headers"), cookies=kwargs.get("cookies")
         )
-        data = kwargs.get('data', None)
-        resp = json.dumps(data).encode('utf-8') if data else b'OK'
+        data = kwargs.get("data", None)
+        resp = json.dumps(data).encode("utf-8") if data else b"OK"
         response = httpx.Response(
             status_code=200,
             request=request,

+ 4 - 2
api/tests/integration_tests/tools/__mock_server/openapi_todo.py

@@ -10,6 +10,7 @@ todos_data = {
     "user1": ["Go for a run", "Read a book"],
 }
 
+
 class TodosResource(Resource):
     def get(self, username):
         todos = todos_data.get(username, [])
@@ -32,7 +33,8 @@ class TodosResource(Resource):
 
         return {"error": "Invalid todo index"}, 400
 
-api.add_resource(TodosResource, '/todos/<string:username>')
 
-if __name__ == '__main__':
+api.add_resource(TodosResource, "/todos/<string:username>")
+
+if __name__ == "__main__":
     app.run(port=5003, debug=True)

+ 27 - 24
api/tests/integration_tests/tools/api_tool/test_api_tool.py

@@ -3,37 +3,40 @@ from core.tools.tool.tool import Tool
 from tests.integration_tests.tools.__mock.http import setup_http_mock
 
 tool_bundle = {
-    'server_url': 'http://www.example.com/{path_param}',
-    'method': 'post',
-    'author': '',
-    'openapi': {'parameters': [{'in': 'path', 'name': 'path_param'},
-                               {'in': 'query', 'name': 'query_param'},
-                               {'in': 'cookie', 'name': 'cookie_param'},
-                               {'in': 'header', 'name': 'header_param'},
-                               ],
-                'requestBody': {
-                    'content': {'application/json': {'schema': {'properties': {'body_param': {'type': 'string'}}}}}}
-                },
-    'parameters': []
+    "server_url": "http://www.example.com/{path_param}",
+    "method": "post",
+    "author": "",
+    "openapi": {
+        "parameters": [
+            {"in": "path", "name": "path_param"},
+            {"in": "query", "name": "query_param"},
+            {"in": "cookie", "name": "cookie_param"},
+            {"in": "header", "name": "header_param"},
+        ],
+        "requestBody": {
+            "content": {"application/json": {"schema": {"properties": {"body_param": {"type": "string"}}}}}
+        },
+    },
+    "parameters": [],
 }
 parameters = {
-    'path_param': 'p_param',
-    'query_param': 'q_param',
-    'cookie_param': 'c_param',
-    'header_param': 'h_param',
-    'body_param': 'b_param',
+    "path_param": "p_param",
+    "query_param": "q_param",
+    "cookie_param": "c_param",
+    "header_param": "h_param",
+    "body_param": "b_param",
 }
 
 
 def test_api_tool(setup_http_mock):
-    tool = ApiTool(api_bundle=tool_bundle, runtime=Tool.Runtime(credentials={'auth_type': 'none'}))
+    tool = ApiTool(api_bundle=tool_bundle, runtime=Tool.Runtime(credentials={"auth_type": "none"}))
     headers = tool.assembling_request(parameters)
     response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters)
 
     assert response.status_code == 200
-    assert '/p_param' == response.request.url.path
-    assert b'query_param=q_param' == response.request.url.query
-    assert 'h_param' == response.request.headers.get('header_param')
-    assert 'application/json' == response.request.headers.get('content-type')
-    assert 'cookie_param=c_param' == response.request.headers.get('cookie')
-    assert 'b_param' in response.content.decode()
+    assert "/p_param" == response.request.url.path
+    assert b"query_param=q_param" == response.request.url.query
+    assert "h_param" == response.request.headers.get("header_param")
+    assert "application/json" == response.request.headers.get("content-type")
+    assert "cookie_param=c_param" == response.request.headers.get("cookie")
+    assert "b_param" in response.content.decode()

+ 5 - 4
api/tests/integration_tests/tools/test_all_provider.py

@@ -7,16 +7,17 @@ provider_names = [provider.identity.name for provider in provider_generator]
 ToolManager.clear_builtin_providers_cache()
 provider_generator = ToolManager.list_builtin_providers()
 
-@pytest.mark.parametrize('name', provider_names)
+
+@pytest.mark.parametrize("name", provider_names)
 def test_tool_providers(benchmark, name):
     """
     Test that all tool providers can be loaded
     """
-    
+
     def test(generator):
         try:
             return next(generator)
         except StopIteration:
             return None
-    
-    benchmark.pedantic(test, args=(provider_generator,), iterations=1, rounds=1)
+
+    benchmark.pedantic(test, args=(provider_generator,), iterations=1, rounds=1)
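
Several files in this commit end with a hunk whose - and + lines are textually identical (this one, plus the "assert response == 10" and "assert num_tokens == 21" pairs earlier): the formatter is only appending the newline the file was missing at EOF. A quick check for such files; the path is just an example:

from pathlib import Path

# False before the reformat, True after: ruff format guarantees a final "\n"
text = Path("api/tests/integration_tests/tools/test_all_provider.py").read_text()
print(text.endswith("\n"))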

Неке датотеке нису приказане због велике количине промена