@@ -0,0 +1,284 @@
+import logging
+from collections.abc import Generator
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageTool,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    FetchFrom,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+    ParameterType,
+)
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
+from core.model_runtime.model_providers.volcengine_maas.errors import (
+    AuthErrors,
+    BadRequestErrors,
+    ConnectionErrors,
+    RateLimitErrors,
+    ServerUnavailableErrors,
+)
+from core.model_runtime.model_providers.volcengine_maas.llm.models import ModelConfigs
+from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
+
+logger = logging.getLogger(__name__)
+
+
+class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
+    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                model_parameters: dict, tools: list[PromptMessageTool] | None = None,
+                stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+            -> LLMResult | Generator:
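+        """
+        Invoke the model. Delegates to _generate, which handles both streaming and
+        blocking requests against the Volcengine MaaS endpoint.
+        """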
+        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate credentials by sending a minimal "ping" chat request to the endpoint.
+        """
+        client = MaaSClient.from_credential(credentials)
+        try:
+            client.chat(
+                {
+                    'max_new_tokens': 16,
+                    'temperature': 0.7,
+                    'top_p': 0.9,
+                    'top_k': 15,
+                },
+                [UserPromptMessage(content='ping\nAnswer: ')],
+            )
+        except MaasException as e:
+            raise CredentialsValidateFailedError(e.message)
+
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                       tools: list[PromptMessageTool] | None = None) -> int:
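+        """
+        Return the approximate number of tokens in the prompt messages (the tools argument is not counted).
+        """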
+        if len(prompt_messages) == 0:
+            return 0
+        return self._num_tokens_from_messages(prompt_messages)
+
+    def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int:
+        """
+        Approximate the number of tokens in the messages.
+
+        The messages are converted to MaaS message dicts and counted with the
+        built-in GPT-2 tokenizer helper.
+
+        :param messages: messages
+        """
+        num_tokens = 0
+        messages_dict = [
+            MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages]
+        for message in messages_dict:
+            for key, value in message.items():
+                num_tokens += self._get_num_tokens_by_gpt2(str(key))
+                num_tokens += self._get_num_tokens_by_gpt2(str(value))
+
+        return num_tokens
+
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                  model_parameters: dict, tools: list[PromptMessageTool] | None = None,
+                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+            -> LLMResult | Generator:
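+        """
+        Send a chat request to the Volcengine MaaS endpoint and return either a
+        complete LLMResult (blocking) or a generator of LLMResultChunk (streaming).
+        """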
+
+        client = MaaSClient.from_credential(credentials)
+
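+        # Request parameters: start from the per-base-model defaults in ModelConfigs,
+        # then let credential-level settings and call-time model_parameters override them.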
+        req_params = ModelConfigs.get(
+            credentials['base_model_name'], {}).get('req_params', {}).copy()
+        if credentials.get('context_size'):
+            req_params['max_prompt_tokens'] = credentials.get('context_size')
+        if credentials.get('max_tokens'):
+            req_params['max_new_tokens'] = credentials.get('max_tokens')
+        if model_parameters.get('max_tokens'):
+            req_params['max_new_tokens'] = model_parameters.get('max_tokens')
+        if model_parameters.get('temperature'):
+            req_params['temperature'] = model_parameters.get('temperature')
+        if model_parameters.get('top_p'):
+            req_params['top_p'] = model_parameters.get('top_p')
+        if model_parameters.get('top_k'):
+            req_params['top_k'] = model_parameters.get('top_k')
+        if model_parameters.get('presence_penalty'):
+            req_params['presence_penalty'] = model_parameters.get(
+                'presence_penalty')
+        if model_parameters.get('frequency_penalty'):
+            req_params['frequency_penalty'] = model_parameters.get(
+                'frequency_penalty')
+        if stop:
+            req_params['stop'] = stop
+
+        resp = MaaSClient.wrap_exception(
+            lambda: client.chat(req_params, prompt_messages, stream))
+        if not stream:
+            return self._handle_chat_response(model, credentials, prompt_messages, resp)
+        return self._handle_stream_chat_response(model, credentials, prompt_messages, resp)
+
+    def _handle_stream_chat_response(self, model: str, credentials: dict,
+                                     prompt_messages: list[PromptMessage], resp: Generator) -> Generator:
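+        """
+        Turn the streaming response into LLMResultChunk objects, attaching usage
+        information when the provider reports it on a chunk.
+        """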
+        for index, r in enumerate(resp):
+            choices = r['choices']
+            if not choices:
+                continue
+            choice = choices[0]
+            message = choice['message']
+            usage = None
+            if r.get('usage'):
+                usage = self._calc_usage(model, credentials, r['usage'])
+            yield LLMResultChunk(
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(
+                    index=index,
+                    message=AssistantPromptMessage(
+                        content=message['content'] if message['content'] else '',
+                        tool_calls=[]
+                    ),
+                    usage=usage,
+                    finish_reason=choice.get('finish_reason'),
+                ),
+            )
+
+    def _handle_chat_response(self, model: str, credentials: dict,
+                              prompt_messages: list[PromptMessage], resp: dict) -> LLMResult:
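+        """
+        Convert a blocking chat response into an LLMResult with token usage attached.
+        """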
+        choices = resp['choices']
+        if not choices:
+            return
+        choice = choices[0]
+        message = choice['message']
+
+        return LLMResult(
+            model=model,
+            prompt_messages=prompt_messages,
+            message=AssistantPromptMessage(
+                content=message['content'] if message['content'] else '',
+                tool_calls=[],
+            ),
+            usage=self._calc_usage(model, credentials, resp['usage']),
+        )
+
+    def _calc_usage(self, model: str, credentials: dict, usage: dict) -> LLMUsage:
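+        """
+        Build an LLMUsage object from the provider-reported prompt and completion token counts.
+        """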
+        return self._calc_response_usage(model=model, credentials=credentials,
+                                         prompt_tokens=usage['prompt_tokens'],
+                                         completion_tokens=usage['completion_tokens']
+                                         )
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        """
+        Define the schema of a customizable model from its base-model defaults and credentials.
+        """
+        max_tokens = ModelConfigs.get(
+            credentials['base_model_name'], {}).get('req_params', {}).get('max_new_tokens')
+        if credentials.get('max_tokens'):
+            max_tokens = int(credentials.get('max_tokens'))
+        rules = [
+            ParameterRule(
+                name='temperature',
+                type=ParameterType.FLOAT,
+                use_template='temperature',
+                label=I18nObject(
+                    zh_Hans='温度',
+                    en_US='Temperature'
+                )
+            ),
+            ParameterRule(
+                name='top_p',
+                type=ParameterType.FLOAT,
+                use_template='top_p',
+                label=I18nObject(
+                    zh_Hans='Top P',
+                    en_US='Top P'
+                )
+            ),
+            ParameterRule(
+                name='top_k',
+                type=ParameterType.INT,
+                min=1,
+                default=1,
+                label=I18nObject(
+                    zh_Hans='Top K',
+                    en_US='Top K'
+                )
+            ),
+            ParameterRule(
+                name='presence_penalty',
+                type=ParameterType.FLOAT,
+                use_template='presence_penalty',
+                label=I18nObject(
+                    zh_Hans='存在惩罚',
+                    en_US='Presence Penalty'
+                ),
+                min=-2.0,
+                max=2.0,
+            ),
+            ParameterRule(
+                name='frequency_penalty',
+                type=ParameterType.FLOAT,
+                use_template='frequency_penalty',
+                label=I18nObject(
+                    zh_Hans='频率惩罚',
+                    en_US='Frequency Penalty'
+                ),
+                min=-2.0,
+                max=2.0,
+            ),
+            ParameterRule(
+                name='max_tokens',
+                type=ParameterType.INT,
+                use_template='max_tokens',
+                min=1,
+                max=max_tokens,
+                default=512,
+                label=I18nObject(
+                    zh_Hans='最大生成长度',
+                    en_US='Max Tokens'
+                )
+            ),
+        ]
+
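+        # Model properties come from the per-base-model defaults and can be overridden
+        # by the mode / context_size values supplied in the credentials.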
+        model_properties = ModelConfigs.get(
+            credentials['base_model_name'], {}).get('model_properties', {}).copy()
+        if credentials.get('mode'):
+            model_properties[ModelPropertyKey.MODE] = credentials.get('mode')
+        if credentials.get('context_size'):
+            model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
+                credentials.get('context_size', 4096))
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(
+                en_US=model
+            ),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.LLM,
+            model_properties=model_properties,
+            parameter_rules=rules
+        )
+
+        return entity
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: ConnectionErrors.values(),
+            InvokeServerUnavailableError: ServerUnavailableErrors.values(),
+            InvokeRateLimitError: RateLimitErrors.values(),
+            InvokeAuthorizationError: AuthErrors.values(),
+            InvokeBadRequestError: BadRequestErrors.values(),
+        }