Przeglądaj źródła

chore: optimize ark model parameters (#7378)

sino 8 miesięcy temu
rodzic
commit
a0a67873aa

+ 21 - 52
api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py

@@ -35,7 +35,10 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
     RateLimitErrors,
     ServerUnavailableErrors,
 )
-from core.model_runtime.model_providers.volcengine_maas.llm.models import ModelConfigs
+from core.model_runtime.model_providers.volcengine_maas.llm.models import (
+    get_model_config,
+    get_v2_req_params,
+)
 from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
 
 logger = logging.getLogger(__name__)
@@ -95,37 +98,12 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
             -> LLMResult | Generator:
 
         client = MaaSClient.from_credential(credentials)
-
-        req_params = ModelConfigs.get(
-            credentials['base_model_name'], {}).get('req_params', {}).copy()
-        if credentials.get('context_size'):
-            req_params['max_prompt_tokens'] = credentials.get('context_size')
-        if credentials.get('max_tokens'):
-            req_params['max_new_tokens'] = credentials.get('max_tokens')
-        if model_parameters.get('max_tokens'):
-            req_params['max_new_tokens'] = model_parameters.get('max_tokens')
-        if model_parameters.get('temperature'):
-            req_params['temperature'] = model_parameters.get('temperature')
-        if model_parameters.get('top_p'):
-            req_params['top_p'] = model_parameters.get('top_p')
-        if model_parameters.get('top_k'):
-            req_params['top_k'] = model_parameters.get('top_k')
-        if model_parameters.get('presence_penalty'):
-            req_params['presence_penalty'] = model_parameters.get(
-                'presence_penalty')
-        if model_parameters.get('frequency_penalty'):
-            req_params['frequency_penalty'] = model_parameters.get(
-                'frequency_penalty')
-        if stop:
-            req_params['stop'] = stop
-
+        req_params = get_v2_req_params(credentials, model_parameters, stop)
         extra_model_kwargs = {}
-        
         if tools:
             extra_model_kwargs['tools'] = [
                 MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools
             ]
-
         resp = MaaSClient.wrap_exception(
             lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
         if not stream:
@@ -197,10 +175,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
         """
             used to define customizable model schema
         """
-        max_tokens = ModelConfigs.get(
-            credentials['base_model_name'], {}).get('req_params', {}).get('max_new_tokens')
-        if credentials.get('max_tokens'):
-            max_tokens = int(credentials.get('max_tokens'))
+        model_config = get_model_config(credentials)
+    
         rules = [
             ParameterRule(
                 name='temperature',
@@ -234,10 +210,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
                 name='presence_penalty',
                 type=ParameterType.FLOAT,
                 use_template='presence_penalty',
-                label={
-                    'en_US': 'Presence Penalty',
-                    'zh_Hans': '存在惩罚',
-                },
+                label=I18nObject(
+                    en_US='Presence Penalty',
+                    zh_Hans= '存在惩罚',
+                ),
                 min=-2.0,
                 max=2.0,
             ),
@@ -245,10 +221,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
                 name='frequency_penalty',
                 type=ParameterType.FLOAT,
                 use_template='frequency_penalty',
-                label={
-                    'en_US': 'Frequency Penalty',
-                    'zh_Hans': '频率惩罚',
-                },
+                label=I18nObject(
+                    en_US= 'Frequency Penalty',
+                    zh_Hans= '频率惩罚',
+                ),
                 min=-2.0,
                 max=2.0,
             ),
@@ -257,7 +233,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
                 type=ParameterType.INT,
                 use_template='max_tokens',
                 min=1,
-                max=max_tokens,
+                max=model_config.properties.max_tokens,
                 default=512,
                 label=I18nObject(
                     zh_Hans='最大生成长度',
@@ -266,17 +242,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
             ),
         ]
 
-        model_properties = ModelConfigs.get(
-            credentials['base_model_name'], {}).get('model_properties', {}).copy()
-        if credentials.get('mode'):
-            model_properties[ModelPropertyKey.MODE] = credentials.get('mode')
-        if credentials.get('context_size'):
-            model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
-                credentials.get('context_size', 4096))
-
-        model_features = ModelConfigs.get(
-            credentials['base_model_name'], {}).get('features', [])
-
+        model_properties = {}
+        model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
+        model_properties[ModelPropertyKey.MODE] = model_config.properties.mode.value
+       
         entity = AIModelEntity(
             model=model,
             label=I18nObject(
@@ -286,7 +255,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
             model_type=ModelType.LLM,
             model_properties=model_properties,
             parameter_rules=rules,
-            features=model_features,
+            features=model_config.features,
         )
 
         return entity

+ 120 - 178
api/core/model_runtime/model_providers/volcengine_maas/llm/models.py

@@ -1,181 +1,123 @@
+from pydantic import BaseModel
+
+from core.model_runtime.entities.llm_entities import LLMMode
 from core.model_runtime.entities.model_entities import ModelFeature
 
-ModelConfigs = {
-    'Doubao-pro-4k': {
-        'req_params': {
-            'max_prompt_tokens': 4096,
-            'max_new_tokens': 4096,
-        },
-        'model_properties': {
-            'context_size': 4096,
-            'mode': 'chat',
-        },
-        'features': [
-            ModelFeature.TOOL_CALL
-        ],
-    },
-    'Doubao-lite-4k': {
-        'req_params': {
-            'max_prompt_tokens': 4096,
-            'max_new_tokens': 4096,
-        },
-        'model_properties': {
-            'context_size': 4096,
-            'mode': 'chat',
-        },
-        'features': [
-            ModelFeature.TOOL_CALL
-        ],
-    },
-    'Doubao-pro-32k': {
-        'req_params': {
-            'max_prompt_tokens': 32768,
-            'max_new_tokens': 32768,
-        },
-        'model_properties': {
-            'context_size': 32768,
-            'mode': 'chat',
-        },
-        'features': [
-            ModelFeature.TOOL_CALL
-        ],
-    },
-    'Doubao-lite-32k': {
-        'req_params': {
-            'max_prompt_tokens': 32768,
-            'max_new_tokens': 32768,
-        },
-        'model_properties': {
-            'context_size': 32768,
-            'mode': 'chat',
-        },
-        'features': [
-            ModelFeature.TOOL_CALL
-        ],
-    },
-    'Doubao-pro-128k': {
-        'req_params': {
-            'max_prompt_tokens': 131072,
-            'max_new_tokens': 131072,
-        },
-        'model_properties': {
-            'context_size': 131072,
-            'mode': 'chat',
-        },
-        'features': [
-            ModelFeature.TOOL_CALL
-        ],
-    },
-    'Doubao-lite-128k': {
-        'req_params': {
-            'max_prompt_tokens': 131072,
-            'max_new_tokens': 131072,
-        },
-        'model_properties': {
-            'context_size': 131072,
-            'mode': 'chat',
-        },
-        'features': [
-            ModelFeature.TOOL_CALL
-        ],
-    },
-    'Skylark2-pro-4k': {
-        'req_params': {
-            'max_prompt_tokens': 4096,
-            'max_new_tokens': 4000,
-        },
-        'model_properties': {
-            'context_size': 4096,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'Llama3-8B': {
-        'req_params': {
-            'max_prompt_tokens': 8192,
-            'max_new_tokens': 8192,
-        },
-        'model_properties': {
-            'context_size': 8192,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'Llama3-70B': {
-        'req_params': {
-            'max_prompt_tokens': 8192,
-            'max_new_tokens': 8192,
-        },
-        'model_properties': {
-            'context_size': 8192,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'Moonshot-v1-8k': {
-        'req_params': {
-            'max_prompt_tokens': 8192,
-            'max_new_tokens': 4096,
-        },
-        'model_properties': {
-            'context_size': 8192,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'Moonshot-v1-32k': {
-        'req_params': {
-            'max_prompt_tokens': 32768,
-            'max_new_tokens': 16384,
-        },
-        'model_properties': {
-            'context_size': 32768,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'Moonshot-v1-128k': {
-        'req_params': {
-            'max_prompt_tokens': 131072,
-            'max_new_tokens': 65536,
-        },
-        'model_properties': {
-            'context_size': 131072,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'GLM3-130B': {
-        'req_params': {
-            'max_prompt_tokens': 8192,
-            'max_new_tokens': 4096,
-        },
-        'model_properties': {
-            'context_size': 8192,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'GLM3-130B-Fin': {
-        'req_params': {
-            'max_prompt_tokens': 8192,
-            'max_new_tokens': 4096,
-        },
-        'model_properties': {
-            'context_size': 8192,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'Mistral-7B': {
-        'req_params': {
-            'max_prompt_tokens': 8192,
-            'max_new_tokens': 2048,
-        },
-        'model_properties': {
-            'context_size': 8192,
-            'mode': 'chat',
-        },
-        'features': [],
-    }
+
+class ModelProperties(BaseModel):
+    context_size: int 
+    max_tokens: int 
+    mode: LLMMode
+
+class ModelConfig(BaseModel):
+    properties: ModelProperties
+    features: list[ModelFeature]
+
+
+configs: dict[str, ModelConfig] = {
+    'Doubao-pro-4k': ModelConfig(
+        properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
+        features=[ModelFeature.TOOL_CALL]
+    ),
+    'Doubao-lite-4k': ModelConfig(
+        properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
+        features=[ModelFeature.TOOL_CALL]
+    ),
+    'Doubao-pro-32k': ModelConfig(
+        properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
+        features=[ModelFeature.TOOL_CALL]
+    ),
+    'Doubao-lite-32k': ModelConfig(
+        properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
+        features=[ModelFeature.TOOL_CALL]
+    ),
+    'Doubao-pro-128k': ModelConfig(
+        properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
+        features=[ModelFeature.TOOL_CALL]
+    ),
+    'Doubao-lite-128k': ModelConfig(
+        properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
+        features=[ModelFeature.TOOL_CALL]
+    ),
+    'Skylark2-pro-4k': ModelConfig(
+        properties=ModelProperties(context_size=4096, max_tokens=4000, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'Llama3-8B': ModelConfig(
+        properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'Llama3-70B': ModelConfig(
+        properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'Moonshot-v1-8k': ModelConfig(
+        properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'Moonshot-v1-32k': ModelConfig(
+        properties=ModelProperties(context_size=32768, max_tokens=16384, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'Moonshot-v1-128k': ModelConfig(
+        properties=ModelProperties(context_size=131072, max_tokens=65536, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'GLM3-130B': ModelConfig(
+        properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'GLM3-130B-Fin': ModelConfig(
+        properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'Mistral-7B': ModelConfig(
+        properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT),
+        features=[]
+    )
 }
+
+def get_model_config(credentials: dict)->ModelConfig:
+    base_model = credentials.get('base_model_name', '')
+    model_configs = configs.get(base_model)
+    if not model_configs:
+        return ModelConfig(
+                properties=ModelProperties(
+                context_size=int(credentials.get('context_size', 0)),
+                max_tokens=int(credentials.get('max_tokens', 0)),
+                mode= LLMMode.value_of(credentials.get('mode', 'chat')),
+            ),
+            features=[]
+        )
+    return model_configs
+
+
+def get_v2_req_params(credentials: dict, model_parameters: dict, 
+                      stop: list[str] | None=None):
+    req_params = {}
+    # predefined properties
+    model_configs = get_model_config(credentials)
+    if model_configs:
+        req_params['max_prompt_tokens'] = model_configs.properties.context_size
+        req_params['max_new_tokens'] = model_configs.properties.max_tokens
+
+    # model parameters
+    if model_parameters.get('max_tokens'):
+        req_params['max_new_tokens'] = model_parameters.get('max_tokens')
+    if model_parameters.get('temperature'):
+        req_params['temperature'] = model_parameters.get('temperature')
+    if model_parameters.get('top_p'):
+        req_params['top_p'] = model_parameters.get('top_p')
+    if model_parameters.get('top_k'):
+        req_params['top_k'] = model_parameters.get('top_k')
+    if model_parameters.get('presence_penalty'):
+        req_params['presence_penalty'] = model_parameters.get(
+            'presence_penalty')
+    if model_parameters.get('frequency_penalty'):
+        req_params['frequency_penalty'] = model_parameters.get(
+            'frequency_penalty')
+            
+    if stop:
+        req_params['stop'] = stop
+
+    return req_params

+ 25 - 7
api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py

@@ -1,9 +1,27 @@
+from pydantic import BaseModel
+
+
+class ModelProperties(BaseModel):
+    context_size: int 
+    max_chunks: int 
+
+class ModelConfig(BaseModel):
+    properties: ModelProperties
+
 ModelConfigs = {
-    'Doubao-embedding': {
-        'req_params': {},
-        'model_properties': {
-            'context_size': 4096,
-            'max_chunks': 1,
-        }
-    },
+    'Doubao-embedding': ModelConfig(
+        properties=ModelProperties(context_size=4096, max_chunks=1)
+    ),
 }
+
+def get_model_config(credentials: dict)->ModelConfig:
+    base_model = credentials.get('base_model_name', '')
+    model_configs = ModelConfigs.get(base_model)
+    if not model_configs:
+        return ModelConfig(
+                properties=ModelProperties(
+                context_size=int(credentials.get('context_size', 0)),
+                max_chunks=int(credentials.get('max_chunks', 0)),
+            )
+        )
+    return model_configs

+ 5 - 9
api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py

@@ -30,7 +30,7 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
     RateLimitErrors,
     ServerUnavailableErrors,
 )
-from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import ModelConfigs
+from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config
 from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
 
 
@@ -115,14 +115,10 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
         """
             generate custom model entities from credentials
         """
-        model_properties = ModelConfigs.get(
-            credentials['base_model_name'], {}).get('model_properties', {}).copy()
-        if credentials.get('context_size'):
-            model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
-                credentials.get('context_size', 4096))
-        if credentials.get('max_chunks'):
-            model_properties[ModelPropertyKey.MAX_CHUNKS] = int(
-                credentials.get('max_chunks', 4096))
+        model_config = get_model_config(credentials)
+        model_properties = {}
+        model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
+        model_properties[ModelPropertyKey.MAX_CHUNKS] = model_config.properties.max_chunks
         entity = AIModelEntity(
             model=model,
             label=I18nObject(en_US=model),