瀏覽代碼

feat: support doubao llm and embeding models (#4431)

sino 11 月之前
父節點
當前提交
6e9066ebf4

+ 61 - 1
api/core/model_runtime/model_providers/volcengine_maas/llm/models.py

@@ -1,4 +1,64 @@
 ModelConfigs = {
+    'Doubao-pro-4k': {
+        'req_params': {
+            'max_prompt_tokens': 4096,
+            'max_new_tokens': 4096,
+        },
+        'model_properties': {
+            'context_size': 4096,
+            'mode': 'chat',
+        }
+    },
+    'Doubao-lite-4k': {
+        'req_params': {
+            'max_prompt_tokens': 4096,
+            'max_new_tokens': 4096,
+        },
+        'model_properties': {
+            'context_size': 4096,
+            'mode': 'chat',
+        }
+    },
+    'Doubao-pro-32k': {
+        'req_params': {
+            'max_prompt_tokens': 32768,
+            'max_new_tokens': 32768,
+        },
+        'model_properties': {
+            'context_size': 32768,
+            'mode': 'chat',
+        }
+    },
+    'Doubao-lite-32k': {
+        'req_params': {
+            'max_prompt_tokens': 32768,
+            'max_new_tokens': 32768,
+        },
+        'model_properties': {
+            'context_size': 32768,
+            'mode': 'chat',
+        }
+    },
+    'Doubao-pro-128k': {
+        'req_params': {
+            'max_prompt_tokens': 131072,
+            'max_new_tokens': 131072,
+        },
+        'model_properties': {
+            'context_size': 131072,
+            'mode': 'chat',
+        }
+    },
+    'Doubao-lite-128k': {
+        'req_params': {
+            'max_prompt_tokens': 131072,
+            'max_new_tokens': 131072,
+        },
+        'model_properties': {
+            'context_size': 131072,
+            'mode': 'chat',
+        }
+    },
     'Skylark2-pro-4k': {
         'req_params': {
             'max_prompt_tokens': 4096,
@@ -8,5 +68,5 @@ ModelConfigs = {
             'context_size': 4096,
             'mode': 'chat',
         }
-    }
+    },
 }

+ 9 - 0
api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py

@@ -0,0 +1,9 @@
+ModelConfigs = {
+    'Doubao-embedding': {
+        'req_params': {},
+        'model_properties': {
+            'context_size': 4096,
+            'max_chunks': 1,
+        }
+    },
+}

+ 40 - 2
api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py

@@ -1,7 +1,16 @@
 import time
+from decimal import Decimal
 from typing import Optional
 
-from core.model_runtime.entities.model_entities import PriceType
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    FetchFrom,
+    ModelPropertyKey,
+    ModelType,
+    PriceConfig,
+    PriceType,
+)
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
@@ -21,6 +30,7 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
     RateLimitErrors,
     ServerUnavailableErrors,
 )
+from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import ModelConfigs
 from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
 
 
@@ -45,7 +55,7 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
         resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts))
 
         usage = self._calc_response_usage(
-            model=model, credentials=credentials, tokens=resp['total_tokens'])
+            model=model, credentials=credentials, tokens=resp['usage']['total_tokens'])
 
         result = TextEmbeddingResult(
             model=model,
@@ -101,6 +111,34 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
             InvokeBadRequestError: BadRequestErrors.values(),
         }
 
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+        """
+            generate custom model entities from credentials
+        """
+        model_properties = ModelConfigs.get(
+            credentials['base_model_name'], {}).get('model_properties', {}).copy()
+        if credentials.get('context_size'):
+            model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
+                credentials.get('context_size', 4096))
+        if credentials.get('max_chunks'):
+            model_properties[ModelPropertyKey.MAX_CHUNKS] = int(
+                credentials.get('max_chunks', 4096))
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            model_type=ModelType.TEXT_EMBEDDING,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties=model_properties,
+            parameter_rules=[],
+            pricing=PriceConfig(
+                input=Decimal(credentials.get('input_price', 0)),
+                unit=Decimal(credentials.get('unit', 0)),
+                currency=credentials.get('currency', "USD")
+            )
+        )
+
+        return entity
+
     def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
         """
         Calculate response usage

+ 42 - 5
api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml

@@ -76,21 +76,60 @@ model_credential_schema:
         en_US: Enter your Endpoint ID
         zh_Hans: 输入您的 Endpoint ID
     - variable: base_model_name
-      show_on:
-        - variable: __model_type
-          value: llm
       label:
         en_US: Base Model
         zh_Hans: 基础模型
       type: select
       required: true
       options:
+        - label:
+            en_US: Doubao-pro-4k
+          value: Doubao-pro-4k
+          show_on:
+            - variable: __model_type
+              value: llm
+        - label:
+            en_US: Doubao-lite-4k
+          value: Doubao-lite-4k
+          show_on:
+            - variable: __model_type
+              value: llm
+        - label:
+            en_US: Doubao-pro-32k
+          value: Doubao-pro-32k
+          show_on:
+            - variable: __model_type
+              value: llm
+        - label:
+            en_US: Doubao-lite-32k
+          value: Doubao-lite-32k
+          show_on:
+            - variable: __model_type
+              value: llm
+        - label:
+            en_US: Doubao-pro-128k
+          value: Doubao-pro-128k
+          show_on:
+            - variable: __model_type
+              value: llm
+        - label:
+            en_US: Doubao-lite-128k
+          value: Doubao-lite-128k
+          show_on:
+            - variable: __model_type
+              value: llm
         - label:
             en_US: Skylark2-pro-4k
           value: Skylark2-pro-4k
           show_on:
             - variable: __model_type
               value: llm
+        - label:
+            en_US: Doubao-embedding
+          value: Doubao-embedding
+          show_on:
+            - variable: __model_type
+              value: text-embedding
         - label:
             en_US: Custom
             zh_Hans: 自定义
@@ -122,8 +161,6 @@ model_credential_schema:
     - variable: context_size
       required: true
       show_on:
-        - variable: __model_type
-          value: llm
         - variable: base_model_name
           value: Custom
       label:

+ 4 - 0
api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py

@@ -21,6 +21,7 @@ def test_validate_credentials():
                 'volc_access_key_id': 'INVALID',
                 'volc_secret_access_key': 'INVALID',
                 'endpoint_id': 'INVALID',
+                'base_model_name': 'Doubao-embedding',
             }
         )
 
@@ -32,6 +33,7 @@ def test_validate_credentials():
             'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
             'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
             'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
+            'base_model_name': 'Doubao-embedding',
         },
     )
 
@@ -47,6 +49,7 @@ def test_invoke_model():
             'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
             'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
             'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
+            'base_model_name': 'Doubao-embedding',
         },
         texts=[
             "hello",
@@ -71,6 +74,7 @@ def test_get_num_tokens():
             'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
             'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
             'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
+            'base_model_name': 'Doubao-embedding',
         },
         texts=[
             "hello",