소스 검색

feat: add proxy configuration for Cohere model (#4152)

Moonlit 11 달 전
부모
커밋
2fdd64c1b5

+ 18 - 0
api/core/model_runtime/model_providers/cohere/cohere.yaml

@@ -32,6 +32,15 @@ provider_credential_schema:
         zh_Hans: 在此输入您的 API Key
         en_US: Enter your API Key
       show_on: [ ]
+    - variable: base_url
+      label:
+        zh_Hans: API Base
+        en_US: API Base
+      type: text-input
+      required: false
+      placeholder:
+        zh_Hans: 在此输入您的 API Base,如 https://api.cohere.ai/v1
+        en_US: Enter your API Base, e.g. https://api.cohere.ai/v1
 model_credential_schema:
   model:
     label:
@@ -70,3 +79,12 @@ model_credential_schema:
       placeholder:
         zh_Hans: 在此输入您的 API Key
         en_US: Enter your API Key
+    - variable: base_url
+      label:
+        zh_Hans: API Base
+        en_US: API Base
+      type: text-input
+      required: false
+      placeholder:
+        zh_Hans: 在此输入您的 API Base,如 https://api.cohere.ai/v1
+        en_US: Enter your API Base, e.g. https://api.cohere.ai/v1

+ 5 - 4
api/core/model_runtime/model_providers/cohere/llm/llm.py

@@ -173,7 +173,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         # initialize client
-        client = cohere.Client(credentials.get('api_key'))
+        client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
 
         if stop:
             model_parameters['end_sequences'] = stop
@@ -233,7 +233,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
 
         return response
 
-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: Iterator[GenerateStreamedResponse],
+    def _handle_generate_stream_response(self, model: str, credentials: dict,
+                                         response: Iterator[GenerateStreamedResponse],
                                          prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm stream response
@@ -317,7 +318,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         # initialize client
-        client = cohere.Client(credentials.get('api_key'))
+        client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
 
         if stop:
             model_parameters['stop_sequences'] = stop
@@ -636,7 +637,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         :return: number of tokens
         """
         # initialize client
-        client = cohere.Client(credentials.get('api_key'))
+        client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
 
         response = client.tokenize(
             text=text,

+ 1 - 1
api/core/model_runtime/model_providers/cohere/rerank/rerank.py

@@ -44,7 +44,7 @@ class CohereRerankModel(RerankModel):
             )
 
         # initialize client
-        client = cohere.Client(credentials.get('api_key'))
+        client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
         response = client.rerank(
             query=query,
             documents=docs,

+ 2 - 2
api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py

@@ -141,7 +141,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
             return []
 
         # initialize client
-        client = cohere.Client(credentials.get('api_key'))
+        client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
 
         response = client.tokenize(
             text=text,
@@ -180,7 +180,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
         :return: embeddings and used tokens
         """
         # initialize client
-        client = cohere.Client(credentials.get('api_key'))
+        client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
 
         # call embedding model
         response = client.embed(