Pārlūkot izejas kodu

feat: add `supported_model_types` field and filter in provider list (#1581)

takatost 1 gadu atpakaļ
vecāks
revīzija
c9368925a3

+ 5 - 1
api/controllers/console/workspace/model_providers.py

@@ -21,8 +21,12 @@ class ModelProviderListApi(Resource):
     def get(self):
         tenant_id = current_user.current_tenant_id
 
+        parser = reqparse.RequestParser()
+        parser.add_argument('model_type', type=str, required=False, nullable=True, location='args')
+        args = parser.parse_args()
+
         provider_service = ProviderService()
-        provider_list = provider_service.get_provider_list(tenant_id)
+        provider_list = provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get('model_type'))
 
         return provider_list
 

+ 3 - 0
api/core/model_providers/rules/anthropic.json

@@ -12,6 +12,9 @@
         "quota_limit": 0
     },
     "model_flexibility": "fixed",
+    "supported_model_types": [
+        "text-generation"
+    ],
     "price_config": {
         "claude-instant-1": {
             "prompt": "1.63",

+ 4 - 0
api/core/model_providers/rules/azure_openai.json

@@ -4,6 +4,10 @@
     ],
     "system_config": null,
     "model_flexibility": "configurable",
+    "supported_model_types": [
+        "text-generation",
+        "embeddings"
+    ],
     "price_config":{
         "gpt-4": {
             "prompt": "0.03",

+ 3 - 0
api/core/model_providers/rules/baichuan.json

@@ -4,6 +4,9 @@
     ],
     "system_config": null,
     "model_flexibility": "fixed",
+    "supported_model_types": [
+        "text-generation"
+    ],
     "price_config": {
         "baichuan2-53b": {
             "prompt": "0.01",

+ 4 - 1
api/core/model_providers/rules/chatglm.json

@@ -3,5 +3,8 @@
         "custom"
     ],
     "system_config": null,
-    "model_flexibility": "fixed"
+    "model_flexibility": "fixed",
+    "supported_model_types": [
+        "text-generation"
+    ]
 }

+ 4 - 1
api/core/model_providers/rules/cohere.json

@@ -3,5 +3,8 @@
         "custom"
     ],
     "system_config": null,
-    "model_flexibility": "fixed"
+    "model_flexibility": "fixed",
+    "supported_model_types": [
+        "reranking"
+    ]
 }

+ 5 - 1
api/core/model_providers/rules/huggingface_hub.json

@@ -3,5 +3,9 @@
         "custom"
     ],
     "system_config": null,
-    "model_flexibility": "configurable"
+    "model_flexibility": "configurable",
+    "supported_model_types": [
+        "text-generation",
+        "embeddings"
+    ]
 }

+ 5 - 1
api/core/model_providers/rules/localai.json

@@ -3,5 +3,9 @@
         "custom"
     ],
     "system_config": null,
-    "model_flexibility": "configurable"
+    "model_flexibility": "configurable",
+    "supported_model_types": [
+        "text-generation",
+        "embeddings"
+    ]
 }

+ 4 - 0
api/core/model_providers/rules/minimax.json

@@ -10,6 +10,10 @@
         "quota_unit": "tokens"
     },
     "model_flexibility": "fixed",
+    "supported_model_types": [
+        "text-generation",
+        "embeddings"
+    ],
     "price_config": {
         "abab5.5-chat": {
             "prompt": "0.015",

+ 6 - 0
api/core/model_providers/rules/openai.json

@@ -11,6 +11,12 @@
         "quota_limit": 200
     },
     "model_flexibility": "fixed",
+    "supported_model_types": [
+        "text-generation",
+        "embeddings",
+        "speech2text",
+        "moderation"
+    ],
     "price_config": {
         "gpt-4": {
             "prompt": "0.03",

+ 5 - 1
api/core/model_providers/rules/openllm.json

@@ -3,5 +3,9 @@
         "custom"
     ],
     "system_config": null,
-    "model_flexibility": "configurable"
+    "model_flexibility": "configurable",
+    "supported_model_types": [
+        "text-generation",
+        "embeddings"
+    ]
 }

+ 5 - 1
api/core/model_providers/rules/replicate.json

@@ -3,5 +3,9 @@
         "custom"
     ],
     "system_config": null,
-    "model_flexibility": "configurable"
+    "model_flexibility": "configurable",
+    "supported_model_types": [
+        "text-generation",
+        "embeddings"
+    ]
 }

+ 3 - 0
api/core/model_providers/rules/spark.json

@@ -10,6 +10,9 @@
         "quota_unit": "tokens"
     },
     "model_flexibility": "fixed",
+    "supported_model_types": [
+        "text-generation"
+    ],
     "price_config": {
         "spark": {
             "prompt": "0.18",

+ 3 - 0
api/core/model_providers/rules/tongyi.json

@@ -4,6 +4,9 @@
     ],
     "system_config": null,
     "model_flexibility": "fixed",
+    "supported_model_types": [
+        "text-generation"
+    ],
     "price_config": {
         "qwen-turbo": {
             "prompt": "0.012",

+ 3 - 0
api/core/model_providers/rules/wenxin.json

@@ -4,6 +4,9 @@
     ],
     "system_config": null,
     "model_flexibility": "fixed",
+    "supported_model_types": [
+        "text-generation"
+    ],
     "price_config": {
         "ernie-bot-4": {
             "prompt": "0",

+ 5 - 1
api/core/model_providers/rules/xinference.json

@@ -3,5 +3,9 @@
         "custom"
     ],
     "system_config": null,
-    "model_flexibility": "configurable"
+    "model_flexibility": "configurable",
+    "supported_model_types": [
+        "text-generation",
+        "embeddings"
+    ]
 }

+ 4 - 0
api/core/model_providers/rules/zhipuai.json

@@ -10,6 +10,10 @@
         "quota_unit": "tokens"
     },
     "model_flexibility": "fixed",
+    "supported_model_types": [
+        "text-generation",
+        "embeddings"
+    ],
     "price_config": {
         "chatglm_turbo": {
             "prompt": "0.005",

+ 7 - 2
api/services/provider_service.py

@@ -17,11 +17,12 @@ from models.provider import Provider, ProviderModel, TenantPreferredModelProvide
 
 class ProviderService:
 
-    def get_provider_list(self, tenant_id: str):
+    def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list:
         """
         get provider list of tenant.
 
-        :param tenant_id:
+        :param tenant_id: workspace id
+        :param model_type: filter by model type
         :return:
         """
         # get rules for all providers
@@ -79,6 +80,9 @@ class ProviderService:
         providers_list = {}
 
         for model_provider_name, model_provider_rule in model_provider_rules.items():
+            if model_type and model_type not in model_provider_rule.get('supported_model_types', []):
+                continue
+
             # get preferred provider type
             preferred_model_provider = provider_name_to_preferred_provider_type_dict.get(model_provider_name)
             preferred_provider_type = ModelProviderFactory.get_preferred_type_by_preferred_model_provider(
@@ -90,6 +94,7 @@ class ProviderService:
             provider_config_dict = {
                 "preferred_provider_type": preferred_provider_type,
                 "model_flexibility": model_provider_rule['model_flexibility'],
+                "supported_model_types": model_provider_rule.get("supported_model_types", []),
             }
 
             provider_parameter_dict = {}