@@ -1,9 +1,23 @@
 from collections.abc import Generator
+from decimal import Decimal
 from typing import Optional, Union
 
-from core.model_runtime.entities.llm_entities import LLMResult
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
-from core.model_runtime.entities.model_entities import AIModelEntity
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
+from core.model_runtime.entities.message_entities import (
+    PromptMessage,
+    PromptMessageTool,
+)
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    DefaultParameterName,
+    FetchFrom,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+    ParameterType,
+    PriceConfig,
+)
 from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
 
 
@@ -36,8 +50,98 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
 
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
         cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
-
-        return super().get_customizable_model_schema(model, cred_with_endpoint)
+        REPETITION_PENALTY = "repetition_penalty"
+        TOP_K = "top_k"
+        features = []
+
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            model_type=ModelType.LLM,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            features=features,
+            model_properties={
+                ModelPropertyKey.CONTEXT_SIZE: int(cred_with_endpoint.get('context_size', "4096")),
+                ModelPropertyKey.MODE: cred_with_endpoint.get('mode'),
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name=DefaultParameterName.TEMPERATURE.value,
+                    label=I18nObject(en_US="Temperature"),
+                    type=ParameterType.FLOAT,
+                    default=float(cred_with_endpoint.get('temperature', 0.7)),
+                    min=0,
+                    max=2,
+                    precision=2
+                ),
+                ParameterRule(
+                    name=DefaultParameterName.TOP_P.value,
+                    label=I18nObject(en_US="Top P"),
+                    type=ParameterType.FLOAT,
+                    default=float(cred_with_endpoint.get('top_p', 1)),
+                    min=0,
+                    max=1,
+                    precision=2
+                ),
+                ParameterRule(
+                    name=TOP_K,
+                    label=I18nObject(en_US="Top K"),
+                    type=ParameterType.INT,
+                    default=int(cred_with_endpoint.get('top_k', 50)),
+                    min=-2147483647,
+                    max=2147483647,
+                    precision=0
+                ),
+                ParameterRule(
+                    name=REPETITION_PENALTY,
+                    label=I18nObject(en_US="Repetition Penalty"),
+                    type=ParameterType.FLOAT,
+                    default=float(cred_with_endpoint.get('repetition_penalty', 1)),
+                    min=-3.4,
+                    max=3.4,
+                    precision=1
+                ),
+                ParameterRule(
+                    name=DefaultParameterName.MAX_TOKENS.value,
+                    label=I18nObject(en_US="Max Tokens"),
+                    type=ParameterType.INT,
+                    default=512,
+                    min=1,
+                    max=int(cred_with_endpoint.get('max_tokens_to_sample', 4096)),
+                ),
+                ParameterRule(
+                    name=DefaultParameterName.FREQUENCY_PENALTY.value,
+                    label=I18nObject(en_US="Frequency Penalty"),
+                    type=ParameterType.FLOAT,
+                    default=float(credentials.get('frequency_penalty', 0)),
+                    min=-2,
+                    max=2
+                ),
+                ParameterRule(
+                    name=DefaultParameterName.PRESENCE_PENALTY.value,
+                    label=I18nObject(en_US="Presence Penalty"),
+                    type=ParameterType.FLOAT,
+                    default=float(credentials.get('presence_penalty', 0)),
+                    min=-2,
+                    max=2
+                )
+            ],
+            pricing=PriceConfig(
+                input=Decimal(cred_with_endpoint.get('input_price', 0)),
+                output=Decimal(cred_with_endpoint.get('output_price', 0)),
+                unit=Decimal(cred_with_endpoint.get('unit', 0)),
+                currency=cred_with_endpoint.get('currency', "USD")
+            ),
+        )
+
+        if cred_with_endpoint['mode'] == 'chat':
+            entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
+        elif cred_with_endpoint['mode'] == 'completion':
+            entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
+        else:
+            raise ValueError(f"Unknown completion type {cred_with_endpoint['mode']}")
+
+        return entity
 
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int: