@@ -0,0 +1,247 @@
+import json
+from collections.abc import Generator
+from typing import Optional, Union
+
+import requests
+from yarl import URL
+
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
+from core.model_runtime.entities.message_entities import (
+    PromptMessage,
+    PromptMessageContentType,
+    PromptMessageFunction,
+    PromptMessageTool,
+    UserPromptMessage,
+)
+from core.model_runtime.errors.invoke import InvokeError
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
+from core.model_runtime.utils import helper
+
+
+class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
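+    # Model IDs mapped to a non-empty value are served from a model-specific path
+    # under ai.api.nvidia.com; entries mapped to '' use the shared OpenAI-compatible
+    # endpoint at integrate.api.nvidia.com (see _add_custom_parameters).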
+    MODEL_SUFFIX_MAP = {
+        'fuyu-8b': 'vlm/adept/fuyu-8b',
+        'mistralai/mixtral-8x7b-instruct-v0.1': '',
+        'google/gemma-7b': '',
+        'meta/llama2-70b': ''
+    }
+
+    def _invoke(self, model: str, credentials: dict,
+                prompt_messages: list[PromptMessage], model_parameters: dict,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                stream: bool = True, user: Optional[str] = None) \
+            -> Union[LLMResult, Generator]:
+
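+        # Rewrite credentials for NVIDIA's endpoints and flatten any multimodal
+        # prompt content before delegating to the OpenAI-compatible implementation.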
+        self._add_custom_parameters(credentials, model)
+        prompt_messages = self._transform_prompt_messages(prompt_messages)
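+        # The caller-supplied stop words and user id are discarded and not
+        # forwarded to the NVIDIA API.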
+        stop = []
+        user = None
+
+        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def _transform_prompt_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        Flatten multimodal user messages into plain text: text parts are concatenated
+        and image parts are inlined as HTML <img> tags.
+        """
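+        # e.g. a message with parts [text "describe this", image <url>] becomes a
+        # single user message: 'describe this <img src="<url>" />'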
+        for i, p in enumerate(prompt_messages):
+            if isinstance(p, UserPromptMessage) and isinstance(p.content, list):
+                content = p.content
+                content_text = ''
+                for prompt_content in content:
+                    if prompt_content.type == PromptMessageContentType.TEXT:
+                        content_text += prompt_content.data
+                    else:
+                        content_text += f' <img src="{prompt_content.data}" />'
+
+                prompt_message = UserPromptMessage(
+                    content=content_text
+                )
+                prompt_messages[i] = prompt_message
+        return prompt_messages
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        self._add_custom_parameters(credentials, model)
+        self._validate_credentials(model, credentials)
+
+    def _add_custom_parameters(self, credentials: dict, model: str) -> None:
+        credentials['mode'] = 'chat'
+
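+        # Models with a dedicated suffix are called via their full per-model URL
+        # ('server_url'); all other models go through the shared OpenAI-compatible
+        # base URL ('endpoint_url').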
+        if self.MODEL_SUFFIX_MAP[model]:
+            credentials['server_url'] = f'https://ai.api.nvidia.com/v1/{self.MODEL_SUFFIX_MAP[model]}'
+            credentials.pop('endpoint_url', None)
+        else:
+            credentials['endpoint_url'] = 'https://integrate.api.nvidia.com/v1'
+
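+        # Delimiter used by the OpenAI-compatible base class to split chunks of the
+        # streamed response.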
+        credentials['stream_mode_delimiter'] = '\n'
+
+    def _validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials using requests to ensure compatibility with all providers following OpenAI's API standard.
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            headers = {
+                'Content-Type': 'application/json'
+            }
+
+            api_key = credentials.get('api_key')
+            if api_key:
+                headers["Authorization"] = f"Bearer {api_key}"
+
+            endpoint_url = credentials['endpoint_url'] if 'endpoint_url' in credentials else None
+            if endpoint_url and not endpoint_url.endswith('/'):
+                endpoint_url += '/'
+            server_url = credentials['server_url'] if 'server_url' in credentials else None
+
+            # prepare the payload for a simple ping to the model
+            data = {
+                'model': model,
+                'max_tokens': 5
+            }
+
+            completion_type = LLMMode.value_of(credentials['mode'])
+
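+            # 'server_url' already points at the full model-specific URL, while
+            # 'endpoint_url' is a base URL that still needs the completions path
+            # appended.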
+            if completion_type is LLMMode.CHAT:
+                data['messages'] = [
+                    {
+                        "role": "user",
+                        "content": "ping"
+                    },
+                ]
+                if 'endpoint_url' in credentials:
+                    endpoint_url = str(URL(endpoint_url) / 'chat' / 'completions')
+                elif 'server_url' in credentials:
+                    endpoint_url = server_url
+            elif completion_type is LLMMode.COMPLETION:
+                data['prompt'] = 'ping'
+                if 'endpoint_url' in credentials:
+                    endpoint_url = str(URL(endpoint_url) / 'completions')
+                elif 'server_url' in credentials:
+                    endpoint_url = server_url
+            else:
+                raise ValueError("Unsupported completion type for model configuration.")
+
+            # send a post request to validate the credentials
+            response = requests.post(
+                endpoint_url,
+                headers=headers,
+                json=data,
+                timeout=(10, 60)
+            )
+
+            if response.status_code != 200:
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed with status code {response.status_code}')
+
+            try:
+                response.json()
+            except json.JSONDecodeError:
+                raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error')
+        except CredentialsValidateFailedError:
+            raise
+        except Exception as ex:
+            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
+
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                  stream: bool = True,
+                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
+        """
+        Invoke the LLM completion or chat completion endpoint.
+
+        :param model: model name
+        :param credentials: credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+        headers = {
+            'Content-Type': 'application/json',
+            'Accept-Charset': 'utf-8',
+        }
+
+        api_key = credentials.get('api_key')
+        if api_key:
+            headers['Authorization'] = f'Bearer {api_key}'
+
+        if stream:
+            headers['Accept'] = 'text/event-stream'
+
+        endpoint_url = credentials['endpoint_url'] if 'endpoint_url' in credentials else None
+        if endpoint_url and not endpoint_url.endswith('/'):
+            endpoint_url += '/'
+        server_url = credentials['server_url'] if 'server_url' in credentials else None
+
+        data = {
+            "model": model,
+            "stream": stream,
+            **model_parameters
+        }
+
+        completion_type = LLMMode.value_of(credentials['mode'])
+
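+        # As in _validate_credentials: 'server_url' is already the full request URL,
+        # while 'endpoint_url' needs the completions path appended.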
+        if completion_type is LLMMode.CHAT:
+            if 'endpoint_url' in credentials:
+                endpoint_url = str(URL(endpoint_url) / 'chat' / 'completions')
+            elif 'server_url' in credentials:
+                endpoint_url = server_url
+            data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
+        elif completion_type is LLMMode.COMPLETION:
+            data['prompt'] = prompt_messages[0].content
+            if 'endpoint_url' in credentials:
+                endpoint_url = str(URL(endpoint_url) / 'completions')
+            elif 'server_url' in credentials:
+                endpoint_url = server_url
+        else:
+            raise ValueError("Unsupported completion type for model configuration.")
+
+        # annotate tools with names, descriptions, etc.
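+        # 'function_call' uses the legacy OpenAI functions parameter; 'tool_call'
+        # uses the newer tools / tool_choice format.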
+        function_calling_type = credentials.get('function_calling_type', 'no_call')
+        formatted_tools = []
+        if tools:
+            if function_calling_type == 'function_call':
+                data['functions'] = [{
+                    "name": tool.name,
+                    "description": tool.description,
+                    "parameters": tool.parameters
+                } for tool in tools]
+            elif function_calling_type == 'tool_call':
+                data["tool_choice"] = "auto"
+
+                for tool in tools:
+                    formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
+
+                data["tools"] = formatted_tools
+
+        if stop:
+            data["stop"] = stop
+
+        if user:
+            data["user"] = user
+
+        response = requests.post(
+            endpoint_url,
+            headers=headers,
+            json=data,
+            timeout=(10, 60),
+            stream=stream
+        )
+
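+        # requests may default to ISO-8859-1 when the response declares no charset;
+        # force UTF-8 so the (possibly streamed) body decodes correctly.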
+        if response.encoding is None or response.encoding == 'ISO-8859-1':
+            response.encoding = 'utf-8'
+
+        if not response.ok:
+            raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")
+
+        if stream:
+            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
+
+        return self._handle_generate_response(model, credentials, response, prompt_messages)