소스 검색

refactor: optimize provider configuration queries with provider name … (#15491)

Yeuoly 1 개월 전
부모
커밋
a6bc642721
1개의 변경된 파일107개의 추가작업 그리고 93개의 파일을 삭제
  1. 107 93
      api/core/entities/provider_configuration.py

+ 107 - 93
api/core/entities/provider_configuration.py

@@ -7,7 +7,6 @@ from json import JSONDecodeError
 from typing import Optional
 
 from pydantic import BaseModel, ConfigDict, Field
-from sqlalchemy import or_
 
 from constants import HIDDEN_VALUE
 from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
@@ -180,37 +179,35 @@ class ProviderConfiguration(BaseModel):
             else [],
         )
 
-    def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
+    def _get_custom_provider_credentials(self) -> Provider | None:
         """
-        Validate custom credentials.
-        :param credentials: provider credentials
-        :return:
+        Get custom provider credentials.
         """
         # get provider
         model_provider_id = ModelProviderID(self.provider.provider)
+        provider_names = [self.provider.provider]
         if model_provider_id.is_langgenius():
-            provider_record = (
-                db.session.query(Provider)
-                .filter(
-                    Provider.tenant_id == self.tenant_id,
-                    Provider.provider_type == ProviderType.CUSTOM.value,
-                    or_(
-                        Provider.provider_name == model_provider_id.provider_name,
-                        Provider.provider_name == self.provider.provider,
-                    ),
-                )
-                .first()
-            )
-        else:
-            provider_record = (
-                db.session.query(Provider)
-                .filter(
-                    Provider.tenant_id == self.tenant_id,
-                    Provider.provider_type == ProviderType.CUSTOM.value,
-                    Provider.provider_name == self.provider.provider,
-                )
-                .first()
+            provider_names.append(model_provider_id.provider_name)
+
+        provider_record = (
+            db.session.query(Provider)
+            .filter(
+                Provider.tenant_id == self.tenant_id,
+                Provider.provider_type == ProviderType.CUSTOM.value,
+                Provider.provider_name.in_(provider_names),
             )
+            .first()
+        )
+
+        return provider_record
+
+    def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
+        """
+        Validate custom credentials.
+        :param credentials: provider credentials
+        :return:
+        """
+        provider_record = self._get_custom_provider_credentials()
 
         # Get provider credential secret variables
         provider_credential_secret_variables = self.extract_secret_variables(
@@ -291,18 +288,7 @@ class ProviderConfiguration(BaseModel):
         :return:
         """
         # get provider
-        provider_record = (
-            db.session.query(Provider)
-            .filter(
-                Provider.tenant_id == self.tenant_id,
-                or_(
-                    Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name,
-                    Provider.provider_name == self.provider.provider,
-                ),
-                Provider.provider_type == ProviderType.CUSTOM.value,
-            )
-            .first()
-        )
+        provider_record = self._get_custom_provider_credentials()
 
         # delete provider
         if provider_record:
@@ -349,29 +335,47 @@ class ProviderConfiguration(BaseModel):
 
         return None
 
-    def custom_model_credentials_validate(
-        self, model_type: ModelType, model: str, credentials: dict
-    ) -> tuple[ProviderModel | None, dict]:
+    def _get_custom_model_credentials(
+        self,
+        model_type: ModelType,
+        model: str,
+    ) -> ProviderModel | None:
         """
-        Validate custom model credentials.
-
-        :param model_type: model type
-        :param model: model name
-        :param credentials: model credentials
-        :return:
+        Get custom model credentials.
         """
         # get provider model
+        model_provider_id = ModelProviderID(self.provider.provider)
+        provider_names = [self.provider.provider]
+        if model_provider_id.is_langgenius():
+            provider_names.append(model_provider_id.provider_name)
+
         provider_model_record = (
             db.session.query(ProviderModel)
             .filter(
                 ProviderModel.tenant_id == self.tenant_id,
-                ProviderModel.provider_name == self.provider.provider,
+                ProviderModel.provider_name.in_(provider_names),
                 ProviderModel.model_name == model,
                 ProviderModel.model_type == model_type.to_origin_model_type(),
             )
             .first()
         )
 
+        return provider_model_record
+
+    def custom_model_credentials_validate(
+        self, model_type: ModelType, model: str, credentials: dict
+    ) -> tuple[ProviderModel | None, dict]:
+        """
+        Validate custom model credentials.
+
+        :param model_type: model type
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        # get provider model
+        provider_model_record = self._get_custom_model_credentials(model_type, model)
+
         # Get provider credential secret variables
         provider_credential_secret_variables = self.extract_secret_variables(
             self.provider.model_credential_schema.credential_form_schemas
@@ -451,16 +455,7 @@ class ProviderConfiguration(BaseModel):
         :return:
         """
         # get provider model
-        provider_model_record = (
-            db.session.query(ProviderModel)
-            .filter(
-                ProviderModel.tenant_id == self.tenant_id,
-                ProviderModel.provider_name == self.provider.provider,
-                ProviderModel.model_name == model,
-                ProviderModel.model_type == model_type.to_origin_model_type(),
-            )
-            .first()
-        )
+        provider_model_record = self._get_custom_model_credentials(model_type, model)
 
         # delete provider model
         if provider_model_record:
@@ -475,24 +470,35 @@ class ProviderConfiguration(BaseModel):
 
             provider_model_credentials_cache.delete()
 
-    def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
+    def _get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None:
         """
-        Enable model.
-        :param model_type: model type
-        :param model: model name
-        :return:
+        Get provider model setting.
         """
-        model_setting = (
+        model_provider_id = ModelProviderID(self.provider.provider)
+        provider_names = [self.provider.provider]
+        if model_provider_id.is_langgenius():
+            provider_names.append(model_provider_id.provider_name)
+
+        return (
             db.session.query(ProviderModelSetting)
             .filter(
                 ProviderModelSetting.tenant_id == self.tenant_id,
-                ProviderModelSetting.provider_name == self.provider.provider,
+                ProviderModelSetting.provider_name.in_(provider_names),
                 ProviderModelSetting.model_type == model_type.to_origin_model_type(),
                 ProviderModelSetting.model_name == model,
             )
             .first()
         )
 
+    def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
+        """
+        Enable model.
+        :param model_type: model type
+        :param model: model name
+        :return:
+        """
+        model_setting = self._get_provider_model_setting(model_type, model)
+
         if model_setting:
             model_setting.enabled = True
             model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
@@ -516,16 +522,7 @@ class ProviderConfiguration(BaseModel):
         :param model: model name
         :return:
         """
-        model_setting = (
-            db.session.query(ProviderModelSetting)
-            .filter(
-                ProviderModelSetting.tenant_id == self.tenant_id,
-                ProviderModelSetting.provider_name == self.provider.provider,
-                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
-                ProviderModelSetting.model_name == model,
-            )
-            .first()
-        )
+        model_setting = self._get_provider_model_setting(model_type, model)
 
         if model_setting:
             model_setting.enabled = False
@@ -550,13 +547,24 @@ class ProviderConfiguration(BaseModel):
         :param model: model name
         :return:
         """
+        return self._get_provider_model_setting(model_type, model)
+
+    def _get_load_balancing_config(self, model_type: ModelType, model: str) -> Optional[LoadBalancingModelConfig]:
+        """
+        Get load balancing config.
+        """
+        model_provider_id = ModelProviderID(self.provider.provider)
+        provider_names = [self.provider.provider]
+        if model_provider_id.is_langgenius():
+            provider_names.append(model_provider_id.provider_name)
+
         return (
-            db.session.query(ProviderModelSetting)
+            db.session.query(LoadBalancingModelConfig)
             .filter(
-                ProviderModelSetting.tenant_id == self.tenant_id,
-                ProviderModelSetting.provider_name == self.provider.provider,
-                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
-                ProviderModelSetting.model_name == model,
+                LoadBalancingModelConfig.tenant_id == self.tenant_id,
+                LoadBalancingModelConfig.provider_name.in_(provider_names),
+                LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+                LoadBalancingModelConfig.model_name == model,
             )
             .first()
         )
@@ -568,11 +576,16 @@ class ProviderConfiguration(BaseModel):
         :param model: model name
         :return:
         """
+        model_provider_id = ModelProviderID(self.provider.provider)
+        provider_names = [self.provider.provider]
+        if model_provider_id.is_langgenius():
+            provider_names.append(model_provider_id.provider_name)
+
         load_balancing_config_count = (
             db.session.query(LoadBalancingModelConfig)
             .filter(
                 LoadBalancingModelConfig.tenant_id == self.tenant_id,
-                LoadBalancingModelConfig.provider_name == self.provider.provider,
+                LoadBalancingModelConfig.provider_name.in_(provider_names),
                 LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
                 LoadBalancingModelConfig.model_name == model,
             )
@@ -582,16 +595,7 @@ class ProviderConfiguration(BaseModel):
         if load_balancing_config_count <= 1:
             raise ValueError("Model load balancing configuration must be more than 1.")
 
-        model_setting = (
-            db.session.query(ProviderModelSetting)
-            .filter(
-                ProviderModelSetting.tenant_id == self.tenant_id,
-                ProviderModelSetting.provider_name == self.provider.provider,
-                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
-                ProviderModelSetting.model_name == model,
-            )
-            .first()
-        )
+        model_setting = self._get_provider_model_setting(model_type, model)
 
         if model_setting:
             model_setting.load_balancing_enabled = True
@@ -616,11 +620,16 @@ class ProviderConfiguration(BaseModel):
         :param model: model name
         :return:
         """
+        model_provider_id = ModelProviderID(self.provider.provider)
+        provider_names = [self.provider.provider]
+        if model_provider_id.is_langgenius():
+            provider_names.append(model_provider_id.provider_name)
+
         model_setting = (
             db.session.query(ProviderModelSetting)
             .filter(
                 ProviderModelSetting.tenant_id == self.tenant_id,
-                ProviderModelSetting.provider_name == self.provider.provider,
+                ProviderModelSetting.provider_name.in_(provider_names),
                 ProviderModelSetting.model_type == model_type.to_origin_model_type(),
                 ProviderModelSetting.model_name == model,
             )
@@ -677,11 +686,16 @@ class ProviderConfiguration(BaseModel):
             return
 
         # get preferred provider
+        model_provider_id = ModelProviderID(self.provider.provider)
+        provider_names = [self.provider.provider]
+        if model_provider_id.is_langgenius():
+            provider_names.append(model_provider_id.provider_name)
+
         preferred_model_provider = (
             db.session.query(TenantPreferredModelProvider)
             .filter(
                 TenantPreferredModelProvider.tenant_id == self.tenant_id,
-                TenantPreferredModelProvider.provider_name == self.provider.provider,
+                TenantPreferredModelProvider.provider_name.in_(provider_names),
             )
             .first()
         )