Ver Fonte

add example api url endpoint in placeholder (#1887)

Co-authored-by: takatost <takatost@gmail.com>
Chenhe Gu há 1 ano atrás
pai
commit
77f9e8ce0f

+ 96 - 52
api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py

@@ -1,5 +1,6 @@
 import logging
 from decimal import Decimal
+from urllib.parse import urljoin
 
 import requests
 import json
@@ -9,9 +10,12 @@ from typing import Optional, Generator, Union, List, cast
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.utils import helper
 
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, AssistantPromptMessage, PromptMessageContent, \
-    PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, ToolPromptMessage
-from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, DefaultParameterName, \
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, \
+    AssistantPromptMessage, PromptMessageContent, \
+    PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, \
+    ToolPromptMessage
+from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, \
+    DefaultParameterName, \
     ParameterType, ModelPropertyKey, FetchFrom, AIModelEntity
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.errors.invoke import InvokeError
@@ -70,7 +74,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         :return:
         """
         return self._num_tokens_from_messages(model, prompt_messages, tools)
-        
+
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
         Validate model credentials using requests to ensure compatibility with all providers following OpenAI's API standard.
@@ -89,6 +93,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 headers["Authorization"] = f"Bearer {api_key}"
 
             endpoint_url = credentials['endpoint_url']
+            if not endpoint_url.endswith('/'):
+                endpoint_url += '/'
 
             # prepare the payload for a simple ping to the model
             data = {
@@ -105,11 +111,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                         "content": "ping"
                     },
                 ]
+                endpoint_url = urljoin(endpoint_url, 'chat/completions')
             elif completion_type is LLMMode.COMPLETION:
                 data['prompt'] = 'ping'
+                endpoint_url = urljoin(endpoint_url, 'completions')
             else:
                 raise ValueError("Unsupported completion type for model configuration.")
-        
+
             # send a post request to validate the credentials
             response = requests.post(
                 endpoint_url,
@@ -119,8 +127,24 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             )
 
             if response.status_code != 200:
-                raise CredentialsValidateFailedError(f'Credentials validation failed with status code {response.status_code}: {response.text}')
-
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed with status code {response.status_code}')
+
+            try:
+                json_result = response.json()
+            except json.JSONDecodeError as e:
+                raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error')
+
+            if (completion_type is LLMMode.CHAT
+                    and ('object' not in json_result or json_result['object'] != 'chat.completion')):
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed: invalid response object, must be \'chat.completion\'')
+            elif (completion_type is LLMMode.COMPLETION
+                  and ('object' not in json_result or json_result['object'] != 'text_completion')):
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed: invalid response object, must be \'text_completion\'')
+        except CredentialsValidateFailedError:
+            raise
         except Exception as ex:
             raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
 
@@ -134,8 +158,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             model_type=ModelType.LLM,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
-                ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size'),
-                ModelPropertyKey.MODE: 'chat'
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
+                ModelPropertyKey.MODE: credentials.get('mode'),
             },
             parameter_rules=[
                 ParameterRule(
@@ -197,11 +221,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
         return entity
 
-
     # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, 
-                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True, \
-                        user: Optional[str] = None) -> Union[LLMResult, Generator]:
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                  stream: bool = True, \
+                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke llm completion model
 
@@ -223,7 +247,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             headers["Authorization"] = f"Bearer {api_key}"
 
         endpoint_url = credentials["endpoint_url"]
-    
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'
+
         data = {
             "model": model,
             "stream": stream,
@@ -233,8 +259,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         completion_type = LLMMode.value_of(credentials['mode'])
 
         if completion_type is LLMMode.CHAT:
+            endpoint_url = urljoin(endpoint_url, 'chat/completions')
             data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
         elif completion_type == LLMMode.COMPLETION:
+            endpoint_url = urljoin(endpoint_url, 'completions')
             data['prompt'] = prompt_messages[0].content
         else:
             raise ValueError("Unsupported completion type for model configuration.")
@@ -245,8 +273,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             data["tool_choice"] = "auto"
 
             for tool in tools:
-                formatted_tools.append( helper.dump_model(PromptMessageFunction(function=tool)))
-            
+                formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
+
             data["tools"] = formatted_tools
 
         if stop:
@@ -254,7 +282,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
         if user:
             data["user"] = user
-    
+
         response = requests.post(
             endpoint_url,
             headers=headers,
@@ -275,8 +303,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
         return self._handle_generate_response(model, credentials, response, prompt_messages)
 
-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, 
-                                        prompt_messages: list[PromptMessage]) -> Generator:
+    def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
+                                         prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm stream response
 
@@ -313,51 +341,64 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             if chunk:
                 decoded_chunk = chunk.decode('utf-8').strip().lstrip('data: ').lstrip()
 
+                chunk_json = None
                 try:
                     chunk_json = json.loads(decoded_chunk)
                 # stream ended
                 except json.JSONDecodeError as e:
                     yield create_final_llm_result_chunk(
-                        index=chunk_index + 1, 
+                        index=chunk_index + 1,
                         message=AssistantPromptMessage(content=""),
                         finish_reason="Non-JSON encountered."
                     )
 
-                if len(chunk_json['choices']) == 0:
+                if not chunk_json or len(chunk_json['choices']) == 0:
                     continue
 
-                delta = chunk_json['choices'][0]['delta']
-                chunk_index = chunk_json['choices'][0]['index']
+                choice = chunk_json['choices'][0]
+                chunk_index = choice['index'] if 'index' in choice else chunk_index
 
-                if delta.get('finish_reason') is None and (delta.get('content') is None or delta.get('content') == ''):
-                    continue
-                
-                assistant_message_tool_calls = delta.get('tool_calls', None)
-                # assistant_message_function_call = delta.delta.function_call
-
-                # extract tool calls from response
-                if assistant_message_tool_calls:
-                    tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-                # function_call = self._extract_response_function_call(assistant_message_function_call)
-                # tool_calls = [function_call] if function_call else []
-
-                # transform assistant message to prompt message
-                assistant_prompt_message = AssistantPromptMessage(
-                    content=delta.get('content', ''),
-                    tool_calls=tool_calls if assistant_message_tool_calls else []
-                )
+                if 'delta' in choice:
+                    delta = choice['delta']
+                    if delta.get('content') is None or delta.get('content') == '':
+                        continue
+
+                    assistant_message_tool_calls = delta.get('tool_calls', None)
+                    # assistant_message_function_call = delta.delta.function_call
+
+                    # extract tool calls from response
+                    if assistant_message_tool_calls:
+                        tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+                    # function_call = self._extract_response_function_call(assistant_message_function_call)
+                    # tool_calls = [function_call] if function_call else []
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(
+                        content=delta.get('content', ''),
+                        tool_calls=tool_calls if assistant_message_tool_calls else []
+                    )
 
-                full_assistant_content += delta.get('content', '')
+                    full_assistant_content += delta.get('content', '')
+                elif 'text' in choice:
+                    if choice.get('text') is None or choice.get('text') == '':
+                        continue
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(
+                        content=choice.get('text', '')
+                    )
+
+                    full_assistant_content += choice.get('text', '')
+                else:
+                    continue
 
                 # check payload indicator for completion
                 if chunk_json['choices'][0].get('finish_reason') is not None:
-                   
                     yield create_final_llm_result_chunk(
                         index=chunk_index,
                         message=assistant_prompt_message,
                         finish_reason=chunk_json['choices'][0]['finish_reason']
                     )
-
                 else:
                     yield LLMResultChunk(
                         model=model,
@@ -373,10 +414,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     message=AssistantPromptMessage(content=""),
                     finish_reason="End of stream."
                 )
-            
-    def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response, 
-                                        prompt_messages: list[PromptMessage]) -> LLMResult:
-        
+
+            chunk_index += 1
+
+    def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response,
+                                  prompt_messages: list[PromptMessage]) -> LLMResult:
+
         response_json = response.json()
 
         completion_type = LLMMode.value_of(credentials['mode'])
@@ -455,7 +498,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
             if message.tool_calls:
-                message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call in
+                message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
+                                              in
                                               message.tool_calls]
                 # function_call = message.tool_calls[0]
                 # message_dict["function_call"] = {
@@ -484,7 +528,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             message_dict["name"] = message.name
 
         return message_dict
-    
+
     def _num_tokens_from_string(self, model: str, text: str,
                                 tools: Optional[list[PromptMessageTool]] = None) -> int:
         """
@@ -507,10 +551,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         """
         Approximate num tokens with GPT2 tokenizer.
         """
-       
+
         tokens_per_message = 3
         tokens_per_name = 1
-       
+
         num_tokens = 0
         messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
         for message in messages_dict:
@@ -599,7 +643,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     num_tokens += self._get_num_tokens_by_gpt2(required_field)
 
         return num_tokens
-    
+
     def _extract_response_tool_calls(self,
                                      response_tool_calls: list[dict]) \
             -> list[AssistantPromptMessage.ToolCall]:

+ 2 - 2
api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml

@@ -33,8 +33,8 @@ model_credential_schema:
     type: text-input
     required: true
     placeholder:
-      zh_Hans: 在此输入您的 API endpoint URL
-      en_US: Enter your API endpoint URL  
+      zh_Hans: Base URL, eg. https://api.openai.com/v1
+      en_US: Base URL, eg. https://api.openai.com/v1
   - variable: mode
     show_on:
       - variable: __model_type

+ 23 - 5
api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py

@@ -1,6 +1,7 @@
 import time
 from decimal import Decimal
 from typing import Optional
+from urllib.parse import urljoin
 import requests
 import json
 
@@ -42,8 +43,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
         if api_key:
             headers["Authorization"] = f"Bearer {api_key}"
 
+        endpoint_url = credentials.get('endpoint_url')
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'
 
-        endpoint_url = credentials['endpoint_url']
+        endpoint_url = urljoin(endpoint_url, 'embeddings')
 
         extra_model_kwargs = {}
         if user:
@@ -144,8 +148,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
             if api_key:
                 headers["Authorization"] = f"Bearer {api_key}"
 
+            endpoint_url = credentials.get('endpoint_url')
+            if not endpoint_url.endswith('/'):
+                endpoint_url += '/'
 
-            endpoint_url = credentials['endpoint_url']
+            endpoint_url = urljoin(endpoint_url, 'embeddings')
 
             payload = {
                 'input': 'ping',
@@ -160,8 +167,19 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
             )
 
             if response.status_code != 200:
-                raise CredentialsValidateFailedError(f"Invalid response status: {response.status_code}")
-
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed with status code {response.status_code}')
+
+            try:
+                json_result = response.json()
+            except json.JSONDecodeError as e:
+                raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error')
+
+            if 'model' not in json_result:
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed: invalid response')
+        except CredentialsValidateFailedError:
+            raise
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
@@ -175,7 +193,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
             model_type=ModelType.TEXT_EMBEDDING,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
-                ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size'),
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
                 ModelPropertyKey.MAX_CHUNKS: 1,
             },
             parameter_rules=[],