@@ -0,0 +1,267 @@
+from collections.abc import Generator
+
+from httpx import Response, post
+from yarl import URL
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageTool,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    FetchFrom,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+    ParameterType,
+)
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+
+
+class TritonInferenceAILargeLanguageModel(LargeLanguageModel):
+    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                model_parameters: dict, tools: list[PromptMessageTool] | None = None,
+                stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+            -> LLMResult | Generator:
+        """
+        invoke LLM
+
+        see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke`
+        """
+        return self._generate(
+            model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
+            tools=tools, stop=stop, stream=stream, user=user,
+        )
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        validate credentials
+        """
+        if 'server_url' not in credentials:
+            raise CredentialsValidateFailedError('server_url is required in credentials')
+
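+        # Issue a minimal one-shot completion to confirm the endpoint and model actually respond.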
+        try:
+            self._invoke(model=model, credentials=credentials, prompt_messages=[
+                UserPromptMessage(content='ping')
+            ], model_parameters={}, stream=False)
+        except InvokeError as ex:
+            raise CredentialsValidateFailedError(f'An error occurred during connection: {str(ex)}')
+
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                       tools: list[PromptMessageTool] | None = None) -> int:
+        """
+        get number of tokens
+
+        because TritonInference LLM is a customized model, we cannot detect which tokenizer to use,
+        so we just take the GPT2 tokenizer as default
+        """
+        return self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages))
+
+    def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str:
+        """
+        convert prompt message to text
+        """
+        text = ''
+        for item in message:
+            if isinstance(item, UserPromptMessage):
+                text += f'User: {item.content}'
+            elif isinstance(item, SystemPromptMessage):
+                text += f'System: {item.content}'
+            elif isinstance(item, AssistantPromptMessage):
+                text += f'Assistant: {item.content}'
+            else:
+                raise NotImplementedError(f'PromptMessage type {type(item)} is not supported')
+        return text
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        """
+        used to define customizable model schema
+        """
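+        # Expose the standard sampling controls; max_tokens is capped by the configured context length.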
+        rules = [
+            ParameterRule(
+                name='temperature',
+                type=ParameterType.FLOAT,
+                use_template='temperature',
+                label=I18nObject(
+                    zh_Hans='温度',
+                    en_US='Temperature'
+                ),
+            ),
+            ParameterRule(
+                name='top_p',
+                type=ParameterType.FLOAT,
+                use_template='top_p',
+                label=I18nObject(
+                    zh_Hans='Top P',
+                    en_US='Top P'
+                )
+            ),
+            ParameterRule(
+                name='max_tokens',
+                type=ParameterType.INT,
+                use_template='max_tokens',
+                min=1,
+                max=int(credentials.get('context_length', 2048)),
+                default=min(512, int(credentials.get('context_length', 2048))),
+                label=I18nObject(
+                    zh_Hans='最大生成长度',
+                    en_US='Max Tokens'
+                )
+            )
+        ]
+
+        completion_type = None
+
+        if 'completion_type' in credentials:
+            if credentials['completion_type'] == 'chat':
+                completion_type = LLMMode.CHAT.value
+            elif credentials['completion_type'] == 'completion':
+                completion_type = LLMMode.COMPLETION.value
+            else:
+                raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
+
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(
+                en_US=model
+            ),
+            parameter_rules=rules,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.LLM,
+            model_properties={
+                ModelPropertyKey.MODE: completion_type,
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_length', 2048)),
+            },
+        )
+
+        return entity
+
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                  model_parameters: dict,
+                  tools: list[PromptMessageTool] | None = None,
+                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+            -> LLMResult | Generator:
+        """
+        generate text from LLM
+        """
+        if 'server_url' not in credentials:
+            raise CredentialsValidateFailedError('server_url is required in credentials')
+
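+        # Respect the provider-level switch: if credentials disable streaming, reject stream requests up front.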
+        if 'stream' in credentials and not bool(credentials['stream']) and stream:
+            raise ValueError(f'stream is not supported by model {model}')
+
+        try:
+            parameters = {}
+            if 'temperature' in model_parameters:
+                parameters['temperature'] = model_parameters['temperature']
+            if 'top_p' in model_parameters:
+                parameters['top_p'] = model_parameters['top_p']
+            if 'top_k' in model_parameters:
+                parameters['top_k'] = model_parameters['top_k']
+            if 'presence_penalty' in model_parameters:
+                parameters['presence_penalty'] = model_parameters['presence_penalty']
+            if 'frequency_penalty' in model_parameters:
+                parameters['frequency_penalty'] = model_parameters['frequency_penalty']
+
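+            # Triton's HTTP generate extension endpoint: POST /v2/models/<model>/generate.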
+            response = post(str(URL(credentials['server_url']) / 'v2' / 'models' / model / 'generate'), json={
+                'text_input': self._convert_prompt_message_to_text(prompt_messages),
+                'max_tokens': model_parameters.get('max_tokens', 512),
+                'parameters': {
+                    'stream': False,
+                    **parameters
+                },
+            }, timeout=(10, 120))
+            response.raise_for_status()
+            if response.status_code != 200:
+                raise InvokeBadRequestError(f'Invoke failed with status code {response.status_code}, {response.text}')
+
+            if stream:
+                return self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
+                                                         tools=tools, resp=response)
+            return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
+                                                       tools=tools, resp=response)
+        except InvokeError:
+            # re-raise errors already mapped to the unified invoke error types unchanged
+            raise
+        except Exception as ex:
+            raise InvokeConnectionError(f'An error occurred during connection: {str(ex)}')
+
+    def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                                       tools: list[PromptMessageTool],
+                                       resp: Response) -> LLMResult:
+        """
+        handle normal chat generate response
+        """
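+        # 'text_output' is the completion field returned by Triton text backends such as TensorRT-LLM and vLLM.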
+        text = resp.json()['text_output']
+
+        usage = LLMUsage.empty_usage()
+        usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+        usage.completion_tokens = self._get_num_tokens_by_gpt2(text)
+
+        return LLMResult(
+            model=model,
+            prompt_messages=prompt_messages,
+            message=AssistantPromptMessage(
+                content=text
+            ),
+            usage=usage
+        )
+
+    def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                                     tools: list[PromptMessageTool],
+                                     resp: Response) -> Generator:
+        """
+        handle stream chat generate response
+        """
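+        # The upstream request was made non-streaming, so streaming is emulated by yielding the full text as a single chunk.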
+        text = resp.json()['text_output']
+
+        usage = LLMUsage.empty_usage()
+        usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+        usage.completion_tokens = self._get_num_tokens_by_gpt2(text)
+
+        yield LLMResultChunk(
+            model=model,
+            prompt_messages=prompt_messages,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=AssistantPromptMessage(
+                    content=text
+                ),
+                usage=usage
+            )
+        )
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
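+        # Most failures surface as generic exceptions already converted in _generate,
+        # so only ValueError is mapped to a bad-request error here.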
+        return {
+            InvokeConnectionError: [
+            ],
+            InvokeServerUnavailableError: [
+            ],
+            InvokeRateLimitError: [
+            ],
+            InvokeAuthorizationError: [
+            ],
+            InvokeBadRequestError: [
+                ValueError
+            ]
+        }