Переглянути джерело

add together ai model setting (#3895)

Jyong 1 рік тому
батько
коміт
0ec8b57825

+ 109 - 5
api/core/model_runtime/model_providers/togetherai/llm/llm.py

@@ -1,9 +1,23 @@
 from collections.abc import Generator
+from decimal import Decimal
 from typing import Optional, Union
 
-from core.model_runtime.entities.llm_entities import LLMResult
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
-from core.model_runtime.entities.model_entities import AIModelEntity
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
+from core.model_runtime.entities.message_entities import (
+    PromptMessage,
+    PromptMessageTool,
+)
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    DefaultParameterName,
+    FetchFrom,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+    ParameterType,
+    PriceConfig,
+)
 from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
 
 
@@ -36,8 +50,98 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
 
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
         cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
-
-        return super().get_customizable_model_schema(model, cred_with_endpoint)
+        REPETITION_PENALTY = "repetition_penalty"
+        TOP_K = "top_k"
+        features = []
+
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            model_type=ModelType.LLM,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            features=features,
+            model_properties={
+                ModelPropertyKey.CONTEXT_SIZE: int(cred_with_endpoint.get('context_size', "4096")),
+                ModelPropertyKey.MODE: cred_with_endpoint.get('mode'),
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name=DefaultParameterName.TEMPERATURE.value,
+                    label=I18nObject(en_US="Temperature"),
+                    type=ParameterType.FLOAT,
+                    default=float(cred_with_endpoint.get('temperature', 0.7)),
+                    min=0,
+                    max=2,
+                    precision=2
+                ),
+                ParameterRule(
+                    name=DefaultParameterName.TOP_P.value,
+                    label=I18nObject(en_US="Top P"),
+                    type=ParameterType.FLOAT,
+                    default=float(cred_with_endpoint.get('top_p', 1)),
+                    min=0,
+                    max=1,
+                    precision=2
+                ),
+                ParameterRule(
+                    name=TOP_K,
+                    label=I18nObject(en_US="Top K"),
+                    type=ParameterType.INT,
+                    default=int(cred_with_endpoint.get('top_k', 50)),
+                    min=-2147483647,
+                    max=2147483647,
+                    precision=0
+                ),
+                ParameterRule(
+                    name=REPETITION_PENALTY,
+                    label=I18nObject(en_US="Repetition Penalty"),
+                    type=ParameterType.FLOAT,
+                    default=float(cred_with_endpoint.get('repetition_penalty', 1)),
+                    min=-3.4,
+                    max=3.4,
+                    precision=1
+                ),
+                ParameterRule(
+                    name=DefaultParameterName.MAX_TOKENS.value,
+                    label=I18nObject(en_US="Max Tokens"),
+                    type=ParameterType.INT,
+                    default=512,
+                    min=1,
+                    max=int(cred_with_endpoint.get('max_tokens_to_sample', 4096)),
+                ),
+                ParameterRule(
+                    name=DefaultParameterName.FREQUENCY_PENALTY.value,
+                    label=I18nObject(en_US="Frequency Penalty"),
+                    type=ParameterType.FLOAT,
+                    default=float(credentials.get('frequency_penalty', 0)),
+                    min=-2,
+                    max=2
+                ),
+                ParameterRule(
+                    name=DefaultParameterName.PRESENCE_PENALTY.value,
+                    label=I18nObject(en_US="Presence Penalty"),
+                    type=ParameterType.FLOAT,
+                    default=float(credentials.get('presence_penalty', 0)),
+                    min=-2,
+                    max=2
+                )
+            ],
+            pricing=PriceConfig(
+                input=Decimal(cred_with_endpoint.get('input_price', 0)),
+                output=Decimal(cred_with_endpoint.get('output_price', 0)),
+                unit=Decimal(cred_with_endpoint.get('unit', 0)),
+                currency=cred_with_endpoint.get('currency', "USD")
+            ),
+        )
+
+        if cred_with_endpoint['mode'] == 'chat':
+            entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
+        elif cred_with_endpoint['mode'] == 'completion':
+            entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
+        else:
+            raise ValueError(f"Unknown completion type {cred_with_endpoint['completion_type']}")
+
+        return entity
 
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int: