Pārlūkot izejas kodu

fix: setting default model to gpt-3.5-turbo-1106 and remove default m… (#2274)

Yeuoly 1 gadu atpakaļ
vecāks
revīzija
34634bddf1
2 mainītis faili ar 49 papildinājumiem un 21 dzēšanām
  1. 25 12
      api/controllers/console/app/app.py
  2. 24 9
      api/core/provider_manager.py

+ 25 - 12
api/controllers/console/app/app.py

@@ -107,20 +107,33 @@ class AppListApi(Resource):
             # validate config
             model_config_dict = args['model_config']
 
-            # get model provider
-            model_manager = ModelManager()
-            model_instance = model_manager.get_default_model_instance(
-                tenant_id=current_user.current_tenant_id,
-                model_type=ModelType.LLM
+            # Get provider configurations
+            provider_manager = ProviderManager()
+            provider_configurations = provider_manager.get_configurations(current_user.current_tenant_id)
+
+            # get available models from provider_configurations
+            available_models = provider_configurations.get_models(
+                model_type=ModelType.LLM,
+                only_active=True
             )
 
-            if not model_instance:
-                raise ProviderNotInitializeError(
-                    f"No Default System Reasoning Model available. Please configure "
-                    f"in the Settings -> Model Provider.")
-            else:
-                model_config_dict["model"]["provider"] = model_instance.provider
-                model_config_dict["model"]["name"] = model_instance.model
+            # check if model is available
+            available_models_names = [f'{model.provider.provider}.{model.model}' for model in available_models]
+            provider_model = f"{model_config_dict['model']['provider']}.{model_config_dict['model']['name']}"
+            if provider_model not in available_models_names:
+                model_manager = ModelManager()
+                model_instance = model_manager.get_default_model_instance(
+                    tenant_id=current_user.current_tenant_id,
+                    model_type=ModelType.LLM
+                )
+
+                if not model_instance:
+                    raise ProviderNotInitializeError(
+                        f"No Default System Reasoning Model available. Please configure "
+                        f"in the Settings -> Model Provider.")
+                else:
+                    model_config_dict["model"]["provider"] = model_instance.provider
+                    model_config_dict["model"]["name"] = model_instance.model
 
             model_configuration = AppModelConfigService.validate_configuration(
                 tenant_id=current_user.current_tenant_id,

+ 24 - 9
api/core/provider_manager.py

@@ -218,15 +218,30 @@ class ProviderManager:
             )
 
             if available_models:
-                available_model = available_models[0]
-                default_model = TenantDefaultModel(
-                    tenant_id=tenant_id,
-                    model_type=model_type.to_origin_model_type(),
-                    provider_name=available_model.provider.provider,
-                    model_name=available_model.model
-                )
-                db.session.add(default_model)
-                db.session.commit()
+                found = False
+                for available_model in available_models:
+                    if available_model.model == "gpt-3.5-turbo-1106":
+                        default_model = TenantDefaultModel(
+                            tenant_id=tenant_id,
+                            model_type=model_type.to_origin_model_type(),
+                            provider_name=available_model.provider.provider,
+                            model_name=available_model.model
+                        )
+                        db.session.add(default_model)
+                        db.session.commit()
+                        found = True
+                        break
+
+                if not found:
+                    available_model = available_models[0]
+                    default_model = TenantDefaultModel(
+                        tenant_id=tenant_id,
+                        model_type=model_type.to_origin_model_type(),
+                        provider_name=available_model.provider.provider,
+                        model_name=available_model.model
+                    )
+                    db.session.add(default_model)
+                    db.session.commit()
 
         if not default_model:
             return None