|
@@ -4,12 +4,22 @@ from urllib.parse import urlparse
|
|
|
|
|
|
import tiktoken
|
|
|
|
|
|
-from core.model_runtime.entities.llm_entities import LLMResult
|
|
|
+from core.model_runtime.entities.common_entities import I18nObject
|
|
|
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
|
|
|
from core.model_runtime.entities.message_entities import (
|
|
|
PromptMessage,
|
|
|
PromptMessageTool,
|
|
|
SystemPromptMessage,
|
|
|
)
|
|
|
+from core.model_runtime.entities.model_entities import (
|
|
|
+ AIModelEntity,
|
|
|
+ FetchFrom,
|
|
|
+ ModelFeature,
|
|
|
+ ModelPropertyKey,
|
|
|
+ ModelType,
|
|
|
+ ParameterRule,
|
|
|
+ ParameterType,
|
|
|
+)
|
|
|
from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel
|
|
|
|
|
|
|
|
@@ -125,3 +135,58 @@ class YiLargeLanguageModel(OpenAILargeLanguageModel):
|
|
|
else:
|
|
|
parsed_url = urlparse(credentials["endpoint_url"])
|
|
|
credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
|
|
+
|
|
|
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
|
|
+ return AIModelEntity(
|
|
|
+ model=model,
|
|
|
+ label=I18nObject(en_US=model, zh_Hans=model),
|
|
|
+ model_type=ModelType.LLM,
|
|
|
+ features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
|
|
|
+ if credentials.get("function_calling_type") == "tool_call"
|
|
|
+ else [],
|
|
|
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
|
|
+ model_properties={
|
|
|
+ ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000)),
|
|
|
+ ModelPropertyKey.MODE: LLMMode.CHAT.value,
|
|
|
+ },
|
|
|
+ parameter_rules=[
|
|
|
+ ParameterRule(
|
|
|
+ name="temperature",
|
|
|
+ use_template="temperature",
|
|
|
+ label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
|
|
+ type=ParameterType.FLOAT,
|
|
|
+ ),
|
|
|
+ ParameterRule(
|
|
|
+ name="max_tokens",
|
|
|
+ use_template="max_tokens",
|
|
|
+ default=512,
|
|
|
+ min=1,
|
|
|
+ max=int(credentials.get("max_tokens", 8192)),
|
|
|
+ label=I18nObject(
|
|
|
+ en_US="Max Tokens", zh_Hans="指定生成结果长度的上限。如果生成结果截断,可以调大该参数"
|
|
|
+ ),
|
|
|
+ type=ParameterType.INT,
|
|
|
+ ),
|
|
|
+ ParameterRule(
|
|
|
+ name="top_p",
|
|
|
+ use_template="top_p",
|
|
|
+ label=I18nObject(
|
|
|
+ en_US="Top P",
|
|
|
+ zh_Hans="控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。",
|
|
|
+ ),
|
|
|
+ type=ParameterType.FLOAT,
|
|
|
+ ),
|
|
|
+ ParameterRule(
|
|
|
+ name="top_k",
|
|
|
+ use_template="top_k",
|
|
|
+ label=I18nObject(en_US="Top K", zh_Hans="取样数量"),
|
|
|
+ type=ParameterType.FLOAT,
|
|
|
+ ),
|
|
|
+ ParameterRule(
|
|
|
+ name="frequency_penalty",
|
|
|
+ use_template="frequency_penalty",
|
|
|
+ label=I18nObject(en_US="Frequency Penalty", zh_Hans="重复惩罚"),
|
|
|
+ type=ParameterType.FLOAT,
|
|
|
+ ),
|
|
|
+ ],
|
|
|
+ )
|