Selaa lähdekoodia

feat: backend model load balancing support (#4927)

takatost 10 kuukautta sitten
vanhempi
commit
d1dbbc1e33
47 muutettua tiedostoa jossa 2190 lisäystä ja 255 poistoa
  1. 38 17
      api/config.py
  2. 1 1
      api/controllers/console/__init__.py
  3. 6 1
      api/controllers/console/feature.py
  4. 19 18
      api/controllers/console/version.py
  5. 106 0
      api/controllers/console/workspace/load_balancing_config.py
  6. 110 8
      api/controllers/console/workspace/models.py
  7. 13 12
      api/core/app/apps/base_app_runner.py
  8. 10 8
      api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
  9. 0 0
      api/core/application_manager.py
  10. 10 2
      api/core/entities/model_entities.py
  11. 254 27
      api/core/entities/provider_configuration.py
  12. 19 0
      api/core/entities/provider_entities.py
  13. 1 1
      api/core/extension/extensible.py
  14. 1 0
      api/core/helper/model_provider_cache.py
  15. 0 0
      api/core/helper/module_import_helper.py
  16. 0 0
      api/core/helper/position_helper.py
  17. 4 15
      api/core/indexing_runner.py
  18. 2 11
      api/core/memory/token_buffer_memory.py
  19. 283 9
      api/core/model_manager.py
  20. 1 1
      api/core/model_runtime/model_providers/__base/ai_model.py
  21. 1 1
      api/core/model_runtime/model_providers/__base/model_provider.py
  22. 2 2
      api/core/model_runtime/model_providers/model_provider_factory.py
  23. 7 7
      api/core/prompt/prompt_transform.py
  24. 165 2
      api/core/provider_manager.py
  25. 2 7
      api/core/rag/docstore/dataset_docstore.py
  26. 2 7
      api/core/rag/splitter/fixed_text_splitter.py
  27. 1 1
      api/core/tools/provider/builtin/_positions.py
  28. 4 4
      api/core/tools/provider/builtin_tool_provider.py
  29. 10 10
      api/core/tools/tool_manager.py
  30. 4 13
      api/core/tools/utils/model_invocation_utils.py
  31. 6 6
      api/core/workflow/nodes/question_classifier/question_classifier_node.py
  32. 126 0
      api/migrations/versions/4e99a8df00ff_add_load_balancing.py
  33. 48 5
      api/models/provider.py
  34. 4 14
      api/services/dataset_service.py
  35. 3 10
      api/services/entities/model_provider_entities.py
  36. 25 12
      api/services/feature_service.py
  37. 565 0
      api/services/model_load_balancing_service.py
  38. 56 8
      api/services/model_provider_service.py
  39. 1 1
      api/services/workflow_service.py
  40. 1 7
      api/tasks/batch_create_segment_to_index_task.py
  41. 1 1
      api/tests/integration_tests/utils/test_module_import_helper.py
  42. 5 3
      api/tests/integration_tests/workflow/nodes/test_llm.py
  43. 2 1
      api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
  44. 10 1
      api/tests/unit_tests/core/prompt/test_prompt_transform.py
  45. 77 0
      api/tests/unit_tests/core/test_model_manager.py
  46. 183 0
      api/tests/unit_tests/core/test_provider_manager.py
  47. 1 1
      api/tests/unit_tests/utils/position_helper/test_position_helper.py

+ 38 - 17
api/config.py

@@ -70,6 +70,7 @@ DEFAULTS = {
     'INVITE_EXPIRY_HOURS': 72,
     'BILLING_ENABLED': 'False',
     'CAN_REPLACE_LOGO': 'False',
+    'MODEL_LB_ENABLED': 'False',
     'ETL_TYPE': 'dify',
     'KEYWORD_STORE': 'jieba',
     'BATCH_UPLOAD_LIMIT': 20,
@@ -123,6 +124,7 @@ class Config:
         self.LOG_FILE = get_env('LOG_FILE')
         self.LOG_FORMAT = get_env('LOG_FORMAT')
         self.LOG_DATEFORMAT = get_env('LOG_DATEFORMAT')
+        self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')
 
         # The backend URL prefix of the console API.
         # used to concatenate the login authorization callback or notion integration callback.
@@ -210,27 +212,41 @@ class Config:
             if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
         self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')
 
+        # ------------------------
+        # Code Execution Sandbox Configurations.
+        # ------------------------
+        self.CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT')
+        self.CODE_EXECUTION_API_KEY = get_env('CODE_EXECUTION_API_KEY')
+
         # ------------------------
         # File Storage Configurations.
         # ------------------------
         self.STORAGE_TYPE = get_env('STORAGE_TYPE')
         self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
+
+        # S3 Storage settings
         self.S3_ENDPOINT = get_env('S3_ENDPOINT')
         self.S3_BUCKET_NAME = get_env('S3_BUCKET_NAME')
         self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY')
         self.S3_SECRET_KEY = get_env('S3_SECRET_KEY')
         self.S3_REGION = get_env('S3_REGION')
         self.S3_ADDRESS_STYLE = get_env('S3_ADDRESS_STYLE')
+
+        # Azure Blob Storage settings
         self.AZURE_BLOB_ACCOUNT_NAME = get_env('AZURE_BLOB_ACCOUNT_NAME')
         self.AZURE_BLOB_ACCOUNT_KEY = get_env('AZURE_BLOB_ACCOUNT_KEY')
         self.AZURE_BLOB_CONTAINER_NAME = get_env('AZURE_BLOB_CONTAINER_NAME')
         self.AZURE_BLOB_ACCOUNT_URL = get_env('AZURE_BLOB_ACCOUNT_URL')
+
+        # Aliyun Storage settings
         self.ALIYUN_OSS_BUCKET_NAME = get_env('ALIYUN_OSS_BUCKET_NAME')
         self.ALIYUN_OSS_ACCESS_KEY = get_env('ALIYUN_OSS_ACCESS_KEY')
         self.ALIYUN_OSS_SECRET_KEY = get_env('ALIYUN_OSS_SECRET_KEY')
         self.ALIYUN_OSS_ENDPOINT = get_env('ALIYUN_OSS_ENDPOINT')
         self.ALIYUN_OSS_REGION = get_env('ALIYUN_OSS_REGION')
         self.ALIYUN_OSS_AUTH_VERSION = get_env('ALIYUN_OSS_AUTH_VERSION')
+
+        # Google Cloud Storage settings
         self.GOOGLE_STORAGE_BUCKET_NAME = get_env('GOOGLE_STORAGE_BUCKET_NAME')
         self.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 = get_env('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64')
 
@@ -240,6 +256,7 @@ class Config:
         # ------------------------
         self.VECTOR_STORE = get_env('VECTOR_STORE')
         self.KEYWORD_STORE = get_env('KEYWORD_STORE')
+
         # qdrant settings
         self.QDRANT_URL = get_env('QDRANT_URL')
         self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
@@ -323,6 +340,19 @@ class Config:
         self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
         self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
         self.UPLOAD_IMAGE_FILE_SIZE_LIMIT = int(get_env('UPLOAD_IMAGE_FILE_SIZE_LIMIT'))
+        self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')
+
+        # RAG ETL Configurations.
+        self.ETL_TYPE = get_env('ETL_TYPE')
+        self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL')
+        self.UNSTRUCTURED_API_KEY = get_env('UNSTRUCTURED_API_KEY')
+        self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')
+
+        # Indexing Configurations.
+        self.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH = get_env('INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH')
+
+        # Tool Configurations.
+        self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE')
 
         self.WORKFLOW_MAX_EXECUTION_STEPS = int(get_env('WORKFLOW_MAX_EXECUTION_STEPS'))
         self.WORKFLOW_MAX_EXECUTION_TIME = int(get_env('WORKFLOW_MAX_EXECUTION_TIME'))
@@ -378,24 +408,15 @@ class Config:
         self.HOSTED_FETCH_APP_TEMPLATES_MODE = get_env('HOSTED_FETCH_APP_TEMPLATES_MODE')
         self.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = get_env('HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN')
 
-        self.ETL_TYPE = get_env('ETL_TYPE')
-        self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL')
-        self.UNSTRUCTURED_API_KEY = get_env('UNSTRUCTURED_API_KEY')
-        self.BILLING_ENABLED = get_bool_env('BILLING_ENABLED')
-        self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO')
-
-        self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')
-
-        self.CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT')
-        self.CODE_EXECUTION_API_KEY = get_env('CODE_EXECUTION_API_KEY')
+        # Model Load Balancing Configurations.
+        self.MODEL_LB_ENABLED = get_bool_env('MODEL_LB_ENABLED')
 
-        self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')
-        self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE')
-
-        self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')
-        self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED')
+        # Platform Billing Configurations.
+        self.BILLING_ENABLED = get_bool_env('BILLING_ENABLED')
 
         # ------------------------
-        # Indexing Configurations.
+        # Enterprise feature Configurations.
+        # **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
         # ------------------------
-        self.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH = get_env('INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH')
+        self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED')
+        self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO')

+ 1 - 1
api/controllers/console/__init__.py

@@ -54,4 +54,4 @@ from .explore import (
 from .tag import tags
 
 # Import workspace controllers
-from .workspace import account, members, model_providers, models, tool_providers, workspace
+from .workspace import account, load_balancing_config, members, model_providers, models, tool_providers, workspace

+ 6 - 1
api/controllers/console/feature.py

@@ -1,14 +1,19 @@
 from flask_login import current_user
 from flask_restful import Resource
 
+from libs.login import login_required
 from services.feature_service import FeatureService
 
 from . import api
-from .wraps import cloud_utm_record
+from .setup import setup_required
+from .wraps import account_initialization_required, cloud_utm_record
 
 
 class FeatureApi(Resource):
 
+    @setup_required
+    @login_required
+    @account_initialization_required
     @cloud_utm_record
     def get(self):
         return FeatureService.get_features(current_user.current_tenant_id).dict()

+ 19 - 18
api/controllers/console/version.py

@@ -17,13 +17,19 @@ class VersionApi(Resource):
         args = parser.parse_args()
         check_update_url = current_app.config['CHECK_UPDATE_URL']
 
-        if not check_update_url:
-            return {
-                'version': '0.0.0',
-                'release_date': '',
-                'release_notes': '',
-                'can_auto_update': False
+        result = {
+            'version': current_app.config['CURRENT_VERSION'],
+            'release_date': '',
+            'release_notes': '',
+            'can_auto_update': False,
+            'features': {
+                'can_replace_logo': current_app.config['CAN_REPLACE_LOGO'],
+                'model_load_balancing_enabled': current_app.config['MODEL_LB_ENABLED']
             }
+        }
+
+        if not check_update_url:
+            return result
 
         try:
             response = requests.get(check_update_url, {
@@ -31,20 +37,15 @@ class VersionApi(Resource):
             })
         except Exception as error:
             logging.warning("Check update version error: {}.".format(str(error)))
-            return {
-                'version': args.get('current_version'),
-                'release_date': '',
-                'release_notes': '',
-                'can_auto_update': False
-            }
+            result['version'] = args.get('current_version')
+            return result
 
         content = json.loads(response.content)
-        return {
-            'version': content['version'],
-            'release_date': content['releaseDate'],
-            'release_notes': content['releaseNotes'],
-            'can_auto_update': content['canAutoUpdate']
-        }
+        result['version'] = content['version']
+        result['release_date'] = content['releaseDate']
+        result['release_notes'] = content['releaseNotes']
+        result['can_auto_update'] = content['canAutoUpdate']
+        return result
 
 
 api.add_resource(VersionApi, '/version')

+ 106 - 0
api/controllers/console/workspace/load_balancing_config.py

@@ -0,0 +1,106 @@
+from flask_restful import Resource, reqparse
+from werkzeug.exceptions import Forbidden
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from libs.login import current_user, login_required
+from models.account import TenantAccountRole
+from services.model_load_balancing_service import ModelLoadBalancingService
+
+
+class LoadBalancingCredentialsValidateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider: str):
+        if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
+            raise Forbidden()
+
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=[mt.value for mt in ModelType], location='json')
+        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        # validate model load balancing credentials
+        model_load_balancing_service = ModelLoadBalancingService()
+
+        result = True
+        error = None
+
+        try:
+            model_load_balancing_service.validate_load_balancing_credentials(
+                tenant_id=tenant_id,
+                provider=provider,
+                model=args['model'],
+                model_type=args['model_type'],
+                credentials=args['credentials']
+            )
+        except CredentialsValidateFailedError as ex:
+            result = False
+            error = str(ex)
+
+        response = {'result': 'success' if result else 'error'}
+
+        if not result:
+            response['error'] = error
+
+        return response
+
+
+class LoadBalancingConfigCredentialsValidateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider: str, config_id: str):
+        if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
+            raise Forbidden()
+
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=[mt.value for mt in ModelType], location='json')
+        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        # validate model load balancing config credentials
+        model_load_balancing_service = ModelLoadBalancingService()
+
+        result = True
+        error = None
+
+        try:
+            model_load_balancing_service.validate_load_balancing_credentials(
+                tenant_id=tenant_id,
+                provider=provider,
+                model=args['model'],
+                model_type=args['model_type'],
+                credentials=args['credentials'],
+                config_id=config_id,
+            )
+        except CredentialsValidateFailedError as ex:
+            result = False
+            error = str(ex)
+
+        response = {'result': 'success' if result else 'error'}
+
+        if not result:
+            response['error'] = error
+
+        return response
+
+
+# Load Balancing Config
+api.add_resource(LoadBalancingCredentialsValidateApi,
+                 '/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate')
+
+api.add_resource(LoadBalancingConfigCredentialsValidateApi,
+                 '/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate')

+ 110 - 8
api/controllers/console/workspace/models.py

@@ -12,6 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.utils.encoders import jsonable_encoder
 from libs.login import login_required
 from models.account import TenantAccountRole
+from services.model_load_balancing_service import ModelLoadBalancingService
 from services.model_provider_service import ModelProviderService
 
 
@@ -104,21 +105,56 @@ class ModelProviderModelApi(Resource):
         parser.add_argument('model', type=str, required=True, nullable=False, location='json')
         parser.add_argument('model_type', type=str, required=True, nullable=False,
                             choices=[mt.value for mt in ModelType], location='json')
-        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument('credentials', type=dict, required=False, nullable=True, location='json')
+        parser.add_argument('load_balancing', type=dict, required=False, nullable=True, location='json')
+        parser.add_argument('config_from', type=str, required=False, nullable=True, location='json')
         args = parser.parse_args()
 
-        model_provider_service = ModelProviderService()
+        model_load_balancing_service = ModelLoadBalancingService()
 
-        try:
-            model_provider_service.save_model_credentials(
+        if ('load_balancing' in args and args['load_balancing'] and
+                'enabled' in args['load_balancing'] and args['load_balancing']['enabled']):
+            if 'configs' not in args['load_balancing']:
+                raise ValueError('invalid load balancing configs')
+
+            # save load balancing configs
+            model_load_balancing_service.update_load_balancing_configs(
                 tenant_id=tenant_id,
                 provider=provider,
                 model=args['model'],
                 model_type=args['model_type'],
-                credentials=args['credentials']
+                configs=args['load_balancing']['configs']
             )
-        except CredentialsValidateFailedError as ex:
-            raise ValueError(str(ex))
+
+            # enable load balancing
+            model_load_balancing_service.enable_model_load_balancing(
+                tenant_id=tenant_id,
+                provider=provider,
+                model=args['model'],
+                model_type=args['model_type']
+            )
+        else:
+            # disable load balancing
+            model_load_balancing_service.disable_model_load_balancing(
+                tenant_id=tenant_id,
+                provider=provider,
+                model=args['model'],
+                model_type=args['model_type']
+            )
+
+            if args.get('config_from', '') != 'predefined-model':
+                model_provider_service = ModelProviderService()
+
+                try:
+                    model_provider_service.save_model_credentials(
+                        tenant_id=tenant_id,
+                        provider=provider,
+                        model=args['model'],
+                        model_type=args['model_type'],
+                        credentials=args['credentials']
+                    )
+                except CredentialsValidateFailedError as ex:
+                    raise ValueError(str(ex))
 
         return {'result': 'success'}, 200
 
@@ -170,11 +206,73 @@ class ModelProviderModelCredentialApi(Resource):
             model=args['model']
         )
 
+        model_load_balancing_service = ModelLoadBalancingService()
+        is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
+            tenant_id=tenant_id,
+            provider=provider,
+            model=args['model'],
+            model_type=args['model_type']
+        )
+
         return {
-            "credentials": credentials
+            "credentials": credentials,
+            "load_balancing": {
+                "enabled": is_load_balancing_enabled,
+                "configs": load_balancing_configs
+            }
         }
 
 
+class ModelProviderModelEnableApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def patch(self, provider: str):
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=[mt.value for mt in ModelType], location='json')
+        args = parser.parse_args()
+
+        model_provider_service = ModelProviderService()
+        model_provider_service.enable_model(
+            tenant_id=tenant_id,
+            provider=provider,
+            model=args['model'],
+            model_type=args['model_type']
+        )
+
+        return {'result': 'success'}
+
+
+class ModelProviderModelDisableApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def patch(self, provider: str):
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=[mt.value for mt in ModelType], location='json')
+        args = parser.parse_args()
+
+        model_provider_service = ModelProviderService()
+        model_provider_service.disable_model(
+            tenant_id=tenant_id,
+            provider=provider,
+            model=args['model'],
+            model_type=args['model_type']
+        )
+
+        return {'result': 'success'}
+
+
 class ModelProviderModelValidateApi(Resource):
 
     @setup_required
@@ -259,6 +357,10 @@ class ModelProviderAvailableModelApi(Resource):
 
 
 api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers/<string:provider>/models')
+api.add_resource(ModelProviderModelEnableApi, '/workspaces/current/model-providers/<string:provider>/models/enable',
+                 endpoint='model-provider-model-enable')
+api.add_resource(ModelProviderModelDisableApi, '/workspaces/current/model-providers/<string:provider>/models/disable',
+                 endpoint='model-provider-model-disable')
 api.add_resource(ModelProviderModelCredentialApi,
                  '/workspaces/current/model-providers/<string:provider>/models/credentials')
 api.add_resource(ModelProviderModelValidateApi,

+ 13 - 12
api/core/app/apps/base_app_runner.py

@@ -1,6 +1,6 @@
 import time
 from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import Optional, Union
 
 from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -16,11 +16,11 @@ from core.app.features.hosting_moderation.hosting_moderation import HostingModer
 from core.external_data_tool.external_data_fetch import ExternalDataFetch
 from core.file.file_obj import FileVar
 from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
 from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.errors.invoke import InvokeBadRequestError
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.moderation.input_moderation import InputModeration
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
@@ -45,8 +45,11 @@ class AppRunner:
         :param query: query
         :return:
         """
-        model_type_instance = model_config.provider_model_bundle.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+        # Invoke model
+        model_instance = ModelInstance(
+            provider_model_bundle=model_config.provider_model_bundle,
+            model=model_config.model
+        )
 
         model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
 
@@ -73,9 +76,7 @@ class AppRunner:
             query=query
         )
 
-        prompt_tokens = model_type_instance.get_num_tokens(
-            model_config.model,
-            model_config.credentials,
+        prompt_tokens = model_instance.get_llm_num_tokens(
             prompt_messages
         )
 
@@ -89,8 +90,10 @@ class AppRunner:
     def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity,
                               prompt_messages: list[PromptMessage]):
         # recalc max_tokens if sum(prompt_token +  max_tokens) over model token limit
-        model_type_instance = model_config.provider_model_bundle.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+        model_instance = ModelInstance(
+            provider_model_bundle=model_config.provider_model_bundle,
+            model=model_config.model
+        )
 
         model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
 
@@ -107,9 +110,7 @@ class AppRunner:
         if max_tokens is None:
             max_tokens = 0
 
-        prompt_tokens = model_type_instance.get_num_tokens(
-            model_config.model,
-            model_config.credentials,
+        prompt_tokens = model_instance.get_llm_num_tokens(
             prompt_messages
         )
 

+ 10 - 8
api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py

@@ -37,6 +37,7 @@ from core.app.entities.task_entities import (
 )
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
 from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
+from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
@@ -317,29 +318,30 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         """
         model_config = self._model_config
         model = model_config.model
-        model_type_instance = model_config.provider_model_bundle.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        model_instance = ModelInstance(
+            provider_model_bundle=model_config.provider_model_bundle,
+            model=model_config.model
+        )
 
         # calculate num tokens
         prompt_tokens = 0
         if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
-            prompt_tokens = model_type_instance.get_num_tokens(
-                model,
-                model_config.credentials,
+            prompt_tokens = model_instance.get_llm_num_tokens(
                 self._task_state.llm_result.prompt_messages
             )
 
         completion_tokens = 0
         if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
-            completion_tokens = model_type_instance.get_num_tokens(
-                model,
-                model_config.credentials,
+            completion_tokens = model_instance.get_llm_num_tokens(
                 [self._task_state.llm_result.message]
             )
 
         credentials = model_config.credentials
 
         # transform usage
+        model_type_instance = model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
         self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
             model,
             credentials,

+ 0 - 0
api/core/application_manager.py


+ 10 - 2
api/core/entities/model_entities.py

@@ -16,6 +16,7 @@ class ModelStatus(Enum):
     NO_CONFIGURE = "no-configure"
     QUOTA_EXCEEDED = "quota-exceeded"
     NO_PERMISSION = "no-permission"
+    DISABLED = "disabled"
 
 
 class SimpleModelProviderEntity(BaseModel):
@@ -43,12 +44,19 @@ class SimpleModelProviderEntity(BaseModel):
         )
 
 
-class ModelWithProviderEntity(ProviderModel):
+class ProviderModelWithStatusEntity(ProviderModel):
+    """
+    Model class for model response.
+    """
+    status: ModelStatus
+    load_balancing_enabled: bool = False
+
+
+class ModelWithProviderEntity(ProviderModelWithStatusEntity):
     """
     Model with provider entity.
     """
     provider: SimpleModelProviderEntity
-    status: ModelStatus
 
 
 class DefaultModelProviderEntity(BaseModel):

+ 254 - 27
api/core/entities/provider_configuration.py

@@ -1,6 +1,7 @@
 import datetime
 import json
 import logging
+from collections import defaultdict
 from collections.abc import Iterator
 from json import JSONDecodeError
 from typing import Optional
@@ -8,7 +9,12 @@ from typing import Optional
 from pydantic import BaseModel
 
 from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
-from core.entities.provider_entities import CustomConfiguration, SystemConfiguration, SystemConfigurationStatus
+from core.entities.provider_entities import (
+    CustomConfiguration,
+    ModelSettings,
+    SystemConfiguration,
+    SystemConfigurationStatus,
+)
 from core.helper import encrypter
 from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
 from core.model_runtime.entities.model_entities import FetchFrom, ModelType
@@ -22,7 +28,14 @@ from core.model_runtime.model_providers import model_provider_factory
 from core.model_runtime.model_providers.__base.ai_model import AIModel
 from core.model_runtime.model_providers.__base.model_provider import ModelProvider
 from extensions.ext_database import db
-from models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider
+from models.provider import (
+    LoadBalancingModelConfig,
+    Provider,
+    ProviderModel,
+    ProviderModelSetting,
+    ProviderType,
+    TenantPreferredModelProvider,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -39,6 +52,7 @@ class ProviderConfiguration(BaseModel):
     using_provider_type: ProviderType
     system_configuration: SystemConfiguration
     custom_configuration: CustomConfiguration
+    model_settings: list[ModelSettings]
 
     def __init__(self, **data):
         super().__init__(**data)
@@ -62,6 +76,14 @@ class ProviderConfiguration(BaseModel):
         :param model: model name
         :return:
         """
+        if self.model_settings:
+            # check if model is disabled by admin
+            for model_setting in self.model_settings:
+                if (model_setting.model_type == model_type
+                        and model_setting.model == model):
+                    if not model_setting.enabled:
+                        raise ValueError(f'Model {model} is disabled.')
+
         if self.using_provider_type == ProviderType.SYSTEM:
             restrict_models = []
             for quota_configuration in self.system_configuration.quota_configurations:
@@ -80,15 +102,17 @@ class ProviderConfiguration(BaseModel):
 
             return copy_credentials
         else:
+            credentials = None
             if self.custom_configuration.models:
                 for model_configuration in self.custom_configuration.models:
                     if model_configuration.model_type == model_type and model_configuration.model == model:
-                        return model_configuration.credentials
+                        credentials = model_configuration.credentials
+                        break
 
             if self.custom_configuration.provider:
-                return self.custom_configuration.provider.credentials
-            else:
-                return None
+                credentials = self.custom_configuration.provider.credentials
+
+            return credentials
 
     def get_system_configuration_status(self) -> SystemConfigurationStatus:
         """
@@ -130,7 +154,7 @@ class ProviderConfiguration(BaseModel):
             return credentials
 
         # Obfuscate credentials
-        return self._obfuscated_credentials(
+        return self.obfuscated_credentials(
             credentials=credentials,
             credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
             if self.provider.provider_credential_schema else []
@@ -151,7 +175,7 @@ class ProviderConfiguration(BaseModel):
         ).first()
 
         # Get provider credential secret variables
-        provider_credential_secret_variables = self._extract_secret_variables(
+        provider_credential_secret_variables = self.extract_secret_variables(
             self.provider.provider_credential_schema.credential_form_schemas
             if self.provider.provider_credential_schema else []
         )
@@ -274,7 +298,7 @@ class ProviderConfiguration(BaseModel):
                     return credentials
 
                 # Obfuscate credentials
-                return self._obfuscated_credentials(
+                return self.obfuscated_credentials(
                     credentials=credentials,
                     credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
                     if self.provider.model_credential_schema else []
@@ -302,7 +326,7 @@ class ProviderConfiguration(BaseModel):
         ).first()
 
         # Get provider credential secret variables
-        provider_credential_secret_variables = self._extract_secret_variables(
+        provider_credential_secret_variables = self.extract_secret_variables(
             self.provider.model_credential_schema.credential_form_schemas
             if self.provider.model_credential_schema else []
         )
@@ -402,6 +426,160 @@ class ProviderConfiguration(BaseModel):
 
             provider_model_credentials_cache.delete()
 
+    def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
+        """
+        Enable model.
+        :param model_type: model type
+        :param model: model name
+        :return:
+        """
+        model_setting = db.session.query(ProviderModelSetting) \
+            .filter(
+            ProviderModelSetting.tenant_id == self.tenant_id,
+            ProviderModelSetting.provider_name == self.provider.provider,
+            ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+            ProviderModelSetting.model_name == model
+        ).first()
+
+        if model_setting:
+            model_setting.enabled = True
+            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            db.session.commit()
+        else:
+            model_setting = ProviderModelSetting(
+                tenant_id=self.tenant_id,
+                provider_name=self.provider.provider,
+                model_type=model_type.to_origin_model_type(),
+                model_name=model,
+                enabled=True
+            )
+            db.session.add(model_setting)
+            db.session.commit()
+
+        return model_setting
+
+    def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
+        """
+        Disable model.
+        :param model_type: model type
+        :param model: model name
+        :return:
+        """
+        model_setting = db.session.query(ProviderModelSetting) \
+            .filter(
+            ProviderModelSetting.tenant_id == self.tenant_id,
+            ProviderModelSetting.provider_name == self.provider.provider,
+            ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+            ProviderModelSetting.model_name == model
+        ).first()
+
+        if model_setting:
+            model_setting.enabled = False
+            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            db.session.commit()
+        else:
+            model_setting = ProviderModelSetting(
+                tenant_id=self.tenant_id,
+                provider_name=self.provider.provider,
+                model_type=model_type.to_origin_model_type(),
+                model_name=model,
+                enabled=False
+            )
+            db.session.add(model_setting)
+            db.session.commit()
+
+        return model_setting
+
+    def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
+        """
+        Get provider model setting.
+        :param model_type: model type
+        :param model: model name
+        :return:
+        """
+        return db.session.query(ProviderModelSetting) \
+            .filter(
+            ProviderModelSetting.tenant_id == self.tenant_id,
+            ProviderModelSetting.provider_name == self.provider.provider,
+            ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+            ProviderModelSetting.model_name == model
+        ).first()
+
+    def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
+        """
+        Enable model load balancing.
+        :param model_type: model type
+        :param model: model name
+        :return:
+        """
+        load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \
+            .filter(
+            LoadBalancingModelConfig.tenant_id == self.tenant_id,
+            LoadBalancingModelConfig.provider_name == self.provider.provider,
+            LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+            LoadBalancingModelConfig.model_name == model
+        ).count()
+
+        if load_balancing_config_count <= 1:
+            raise ValueError('Model load balancing configuration must be more than 1.')
+
+        model_setting = db.session.query(ProviderModelSetting) \
+            .filter(
+            ProviderModelSetting.tenant_id == self.tenant_id,
+            ProviderModelSetting.provider_name == self.provider.provider,
+            ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+            ProviderModelSetting.model_name == model
+        ).first()
+
+        if model_setting:
+            model_setting.load_balancing_enabled = True
+            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            db.session.commit()
+        else:
+            model_setting = ProviderModelSetting(
+                tenant_id=self.tenant_id,
+                provider_name=self.provider.provider,
+                model_type=model_type.to_origin_model_type(),
+                model_name=model,
+                load_balancing_enabled=True
+            )
+            db.session.add(model_setting)
+            db.session.commit()
+
+        return model_setting
+
+    def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
+        """
+        Disable model load balancing.
+        :param model_type: model type
+        :param model: model name
+        :return:
+        """
+        model_setting = db.session.query(ProviderModelSetting) \
+            .filter(
+            ProviderModelSetting.tenant_id == self.tenant_id,
+            ProviderModelSetting.provider_name == self.provider.provider,
+            ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+            ProviderModelSetting.model_name == model
+        ).first()
+
+        if model_setting:
+            model_setting.load_balancing_enabled = False
+            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            db.session.commit()
+        else:
+            model_setting = ProviderModelSetting(
+                tenant_id=self.tenant_id,
+                provider_name=self.provider.provider,
+                model_type=model_type.to_origin_model_type(),
+                model_name=model,
+                load_balancing_enabled=False
+            )
+            db.session.add(model_setting)
+            db.session.commit()
+
+        return model_setting
+
     def get_provider_instance(self) -> ModelProvider:
         """
         Get provider instance.
@@ -453,7 +631,7 @@ class ProviderConfiguration(BaseModel):
 
         db.session.commit()
 
-    def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
+    def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
         """
         Extract secret input form variables.
 
@@ -467,7 +645,7 @@ class ProviderConfiguration(BaseModel):
 
         return secret_input_form_variables
 
-    def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
+    def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
         """
         Obfuscated credentials.
 
@@ -476,7 +654,7 @@ class ProviderConfiguration(BaseModel):
         :return:
         """
         # Get provider credential secret variables
-        credential_secret_variables = self._extract_secret_variables(
+        credential_secret_variables = self.extract_secret_variables(
             credential_form_schemas
         )
 
@@ -522,15 +700,22 @@ class ProviderConfiguration(BaseModel):
         else:
             model_types = provider_instance.get_provider_schema().supported_model_types
 
+        # Group model settings by model type and model
+        model_setting_map = defaultdict(dict)
+        for model_setting in self.model_settings:
+            model_setting_map[model_setting.model_type][model_setting.model] = model_setting
+
         if self.using_provider_type == ProviderType.SYSTEM:
             provider_models = self._get_system_provider_models(
                 model_types=model_types,
-                provider_instance=provider_instance
+                provider_instance=provider_instance,
+                model_setting_map=model_setting_map
             )
         else:
             provider_models = self._get_custom_provider_models(
                 model_types=model_types,
-                provider_instance=provider_instance
+                provider_instance=provider_instance,
+                model_setting_map=model_setting_map
             )
 
         if only_active:
@@ -541,18 +726,27 @@ class ProviderConfiguration(BaseModel):
 
     def _get_system_provider_models(self,
                                     model_types: list[ModelType],
-                                    provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
+                                    provider_instance: ModelProvider,
+                                    model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
+            -> list[ModelWithProviderEntity]:
         """
         Get system provider models.
 
         :param model_types: model types
         :param provider_instance: provider instance
+        :param model_setting_map: model setting map
         :return:
         """
         provider_models = []
         for model_type in model_types:
-            provider_models.extend(
-                [
+            for m in provider_instance.models(model_type):
+                status = ModelStatus.ACTIVE
+                if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
+                    model_setting = model_setting_map[m.model_type][m.model]
+                    if model_setting.enabled is False:
+                        status = ModelStatus.DISABLED
+
+                provider_models.append(
                     ModelWithProviderEntity(
                         model=m.model,
                         label=m.label,
@@ -562,11 +756,9 @@ class ProviderConfiguration(BaseModel):
                         model_properties=m.model_properties,
                         deprecated=m.deprecated,
                         provider=SimpleModelProviderEntity(self.provider),
-                        status=ModelStatus.ACTIVE
+                        status=status
                     )
-                    for m in provider_instance.models(model_type)
-                ]
-            )
+                )
 
         if self.provider.provider not in original_provider_configurate_methods:
             original_provider_configurate_methods[self.provider.provider] = []
@@ -586,7 +778,8 @@ class ProviderConfiguration(BaseModel):
                 break
 
             if should_use_custom_model:
-                if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
+                if original_provider_configurate_methods[self.provider.provider] == [
+                    ConfigurateMethod.CUSTOMIZABLE_MODEL]:
                     # only customizable model
                     for restrict_model in restrict_models:
                         copy_credentials = self.system_configuration.credentials.copy()
@@ -611,6 +804,13 @@ class ProviderConfiguration(BaseModel):
                         if custom_model_schema.model_type not in model_types:
                             continue
 
+                        status = ModelStatus.ACTIVE
+                        if (custom_model_schema.model_type in model_setting_map
+                                and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
+                            model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
+                            if model_setting.enabled is False:
+                                status = ModelStatus.DISABLED
+
                         provider_models.append(
                             ModelWithProviderEntity(
                                 model=custom_model_schema.model,
@@ -621,7 +821,7 @@ class ProviderConfiguration(BaseModel):
                                 model_properties=custom_model_schema.model_properties,
                                 deprecated=custom_model_schema.deprecated,
                                 provider=SimpleModelProviderEntity(self.provider),
-                                status=ModelStatus.ACTIVE
+                                status=status
                             )
                         )
 
@@ -632,16 +832,20 @@ class ProviderConfiguration(BaseModel):
                     m.status = ModelStatus.NO_PERMISSION
                 elif not quota_configuration.is_valid:
                     m.status = ModelStatus.QUOTA_EXCEEDED
+
         return provider_models
 
     def _get_custom_provider_models(self,
                                     model_types: list[ModelType],
-                                    provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
+                                    provider_instance: ModelProvider,
+                                    model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
+            -> list[ModelWithProviderEntity]:
         """
         Get custom provider models.
 
         :param model_types: model types
         :param provider_instance: provider instance
+        :param model_setting_map: model setting map
         :return:
         """
         provider_models = []
@@ -656,6 +860,16 @@ class ProviderConfiguration(BaseModel):
 
             models = provider_instance.models(model_type)
             for m in models:
+                status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
+                load_balancing_enabled = False
+                if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
+                    model_setting = model_setting_map[m.model_type][m.model]
+                    if model_setting.enabled is False:
+                        status = ModelStatus.DISABLED
+
+                    if len(model_setting.load_balancing_configs) > 1:
+                        load_balancing_enabled = True
+
                 provider_models.append(
                     ModelWithProviderEntity(
                         model=m.model,
@@ -666,7 +880,8 @@ class ProviderConfiguration(BaseModel):
                         model_properties=m.model_properties,
                         deprecated=m.deprecated,
                         provider=SimpleModelProviderEntity(self.provider),
-                        status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
+                        status=status,
+                        load_balancing_enabled=load_balancing_enabled
                     )
                 )
 
@@ -690,6 +905,17 @@ class ProviderConfiguration(BaseModel):
             if not custom_model_schema:
                 continue
 
+            status = ModelStatus.ACTIVE
+            load_balancing_enabled = False
+            if (custom_model_schema.model_type in model_setting_map
+                    and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
+                model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
+                if model_setting.enabled is False:
+                    status = ModelStatus.DISABLED
+
+                if len(model_setting.load_balancing_configs) > 1:
+                    load_balancing_enabled = True
+
             provider_models.append(
                 ModelWithProviderEntity(
                     model=custom_model_schema.model,
@@ -700,7 +926,8 @@ class ProviderConfiguration(BaseModel):
                     model_properties=custom_model_schema.model_properties,
                     deprecated=custom_model_schema.deprecated,
                     provider=SimpleModelProviderEntity(self.provider),
-                    status=ModelStatus.ACTIVE
+                    status=status,
+                    load_balancing_enabled=load_balancing_enabled
                 )
             )
 

+ 19 - 0
api/core/entities/provider_entities.py

@@ -72,3 +72,22 @@ class CustomConfiguration(BaseModel):
     """
     provider: Optional[CustomProviderConfiguration] = None
     models: list[CustomModelConfiguration] = []
+
+
+class ModelLoadBalancingConfiguration(BaseModel):
+    """
+    Class for model load balancing configuration.
+    """
+    id: str
+    name: str
+    credentials: dict
+
+
+class ModelSettings(BaseModel):
+    """
+    Model class for model settings.
+    """
+    model: str
+    model_type: ModelType
+    enabled: bool = True
+    load_balancing_configs: list[ModelLoadBalancingConfiguration] = []

+ 1 - 1
api/core/extension/extensible.py

@@ -7,7 +7,7 @@ from typing import Any, Optional
 
 from pydantic import BaseModel
 
-from core.utils.position_helper import sort_to_dict_by_position_map
+from core.helper.position_helper import sort_to_dict_by_position_map
 
 
 class ExtensionModule(enum.Enum):

+ 1 - 0
api/core/helper/model_provider_cache.py

@@ -9,6 +9,7 @@ from extensions.ext_redis import redis_client
 class ProviderCredentialsCacheType(Enum):
     PROVIDER = "provider"
     MODEL = "provider_model"
+    LOAD_BALANCING_MODEL = "load_balancing_provider_model"
 
 
 class ProviderCredentialsCache:

+ 0 - 0
api/core/utils/module_import_helper.py → api/core/helper/module_import_helper.py


+ 0 - 0
api/core/utils/position_helper.py → api/core/helper/position_helper.py


+ 4 - 15
api/core/indexing_runner.py

@@ -286,11 +286,7 @@ class IndexingRunner:
                 if len(preview_texts) < 5:
                     preview_texts.append(document.page_content)
                 if indexing_technique == 'high_quality' or embedding_model_instance:
-                    embedding_model_type_instance = embedding_model_instance.model_type_instance
-                    embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
-                    tokens += embedding_model_type_instance.get_num_tokens(
-                        model=embedding_model_instance.model,
-                        credentials=embedding_model_instance.credentials,
+                    tokens += embedding_model_instance.get_text_embedding_num_tokens(
                         texts=[self.filter_string(document.page_content)]
                     )
 
@@ -658,10 +654,6 @@ class IndexingRunner:
         tokens = 0
         chunk_size = 10
 
-        embedding_model_type_instance = None
-        if embedding_model_instance:
-            embedding_model_type_instance = embedding_model_instance.model_type_instance
-            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
         # create keyword index
         create_keyword_thread = threading.Thread(target=self._process_keyword_index,
                                                  args=(current_app._get_current_object(),
@@ -674,8 +666,7 @@ class IndexingRunner:
                     chunk_documents = documents[i:i + chunk_size]
                     futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor,
                                                    chunk_documents, dataset,
-                                                   dataset_document, embedding_model_instance,
-                                                   embedding_model_type_instance))
+                                                   dataset_document, embedding_model_instance))
 
                 for future in futures:
                     tokens += future.result()
@@ -716,7 +707,7 @@ class IndexingRunner:
                 db.session.commit()
 
     def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document,
-                       embedding_model_instance, embedding_model_type_instance):
+                       embedding_model_instance):
         with flask_app.app_context():
             # check document is paused
             self._check_document_paused_status(dataset_document.id)
@@ -724,9 +715,7 @@ class IndexingRunner:
             tokens = 0
             if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
                 tokens += sum(
-                    embedding_model_type_instance.get_num_tokens(
-                        embedding_model_instance.model,
-                        embedding_model_instance.credentials,
+                    embedding_model_instance.get_text_embedding_num_tokens(
                         [document.page_content]
                     )
                     for document in chunk_documents

+ 2 - 11
api/core/memory/token_buffer_memory.py

@@ -9,8 +9,6 @@ from core.model_runtime.entities.message_entities import (
     TextPromptMessageContent,
     UserPromptMessage,
 )
-from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.model_providers import model_provider_factory
 from extensions.ext_database import db
 from models.model import AppMode, Conversation, Message
 
@@ -78,12 +76,7 @@ class TokenBufferMemory:
             return []
 
         # prune the chat message if it exceeds the max token limit
-        provider_instance = model_provider_factory.get_provider_instance(self.model_instance.provider)
-        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
-
-        curr_message_tokens = model_type_instance.get_num_tokens(
-            self.model_instance.model,
-            self.model_instance.credentials,
+        curr_message_tokens = self.model_instance.get_llm_num_tokens(
             prompt_messages
         )
 
@@ -91,9 +84,7 @@ class TokenBufferMemory:
             pruned_memory = []
             while curr_message_tokens > max_token_limit and prompt_messages:
                 pruned_memory.append(prompt_messages.pop(0))
-                curr_message_tokens = model_type_instance.get_num_tokens(
-                    self.model_instance.model,
-                    self.model_instance.credentials,
+                curr_message_tokens = self.model_instance.get_llm_num_tokens(
                     prompt_messages
                 )
 

+ 283 - 9
api/core/model_manager.py

@@ -1,7 +1,10 @@
+import logging
+import os
 from collections.abc import Generator
 from typing import IO, Optional, Union, cast
 
-from core.entities.provider_configuration import ProviderModelBundle
+from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
+from core.entities.provider_entities import ModelLoadBalancingConfiguration
 from core.errors.error import ProviderTokenNotInitError
 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult
@@ -9,6 +12,7 @@ from core.model_runtime.entities.message_entities import PromptMessage, PromptMe
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.entities.rerank_entities import RerankResult
 from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
 from core.model_runtime.model_providers.__base.rerank_model import RerankModel
@@ -16,6 +20,10 @@ from core.model_runtime.model_providers.__base.speech2text_model import Speech2T
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.model_runtime.model_providers.__base.tts_model import TTSModel
 from core.provider_manager import ProviderManager
+from extensions.ext_redis import redis_client
+from models.provider import ProviderType
+
+logger = logging.getLogger(__name__)
 
 
 class ModelInstance:
@@ -29,6 +37,12 @@ class ModelInstance:
         self.provider = provider_model_bundle.configuration.provider.provider
         self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
         self.model_type_instance = self.provider_model_bundle.model_type_instance
+        self.load_balancing_manager = self._get_load_balancing_manager(
+            configuration=provider_model_bundle.configuration,
+            model_type=provider_model_bundle.model_type_instance.model_type,
+            model=model,
+            credentials=self.credentials
+        )
 
     def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict:
         """
@@ -37,8 +51,10 @@ class ModelInstance:
         :param model: model name
         :return:
         """
-        credentials = provider_model_bundle.configuration.get_current_credentials(
-            model_type=provider_model_bundle.model_type_instance.model_type,
+        configuration = provider_model_bundle.configuration
+        model_type = provider_model_bundle.model_type_instance.model_type
+        credentials = configuration.get_current_credentials(
+            model_type=model_type,
             model=model
         )
 
@@ -47,6 +63,43 @@ class ModelInstance:
 
         return credentials
 
+    def _get_load_balancing_manager(self, configuration: ProviderConfiguration,
+                                    model_type: ModelType,
+                                    model: str,
+                                    credentials: dict) -> Optional["LBModelManager"]:
+        """
+        Get load balancing model credentials
+        :param configuration: provider configuration
+        :param model_type: model type
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        if configuration.model_settings and configuration.using_provider_type == ProviderType.CUSTOM:
+            current_model_setting = None
+            # check if model is disabled by admin
+            for model_setting in configuration.model_settings:
+                if (model_setting.model_type == model_type
+                        and model_setting.model == model):
+                    current_model_setting = model_setting
+                    break
+
+            # check if load balancing is enabled
+            if current_model_setting and current_model_setting.load_balancing_configs:
+                # use load balancing proxy to choose credentials
+                lb_model_manager = LBModelManager(
+                    tenant_id=configuration.tenant_id,
+                    provider=configuration.provider.provider,
+                    model_type=model_type,
+                    model=model,
+                    load_balancing_configs=current_model_setting.load_balancing_configs,
+                    managed_credentials=credentials if configuration.custom_configuration.provider else None
+                )
+
+                return lb_model_manager
+
+        return None
+
     def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
                    tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                    stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
@@ -67,7 +120,8 @@ class ModelInstance:
             raise Exception("Model type instance is not LargeLanguageModel")
 
         self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
-        return self.model_type_instance.invoke(
+        return self._round_robin_invoke(
+            function=self.model_type_instance.invoke,
             model=self.model,
             credentials=self.credentials,
             prompt_messages=prompt_messages,
@@ -79,6 +133,27 @@ class ModelInstance:
             callbacks=callbacks
         )
 
+    def get_llm_num_tokens(self, prompt_messages: list[PromptMessage],
+                           tools: Optional[list[PromptMessageTool]] = None) -> int:
+        """
+        Get number of tokens for llm
+
+        :param prompt_messages: prompt messages
+        :param tools: tools for tool calling
+        :return:
+        """
+        if not isinstance(self.model_type_instance, LargeLanguageModel):
+            raise Exception("Model type instance is not LargeLanguageModel")
+
+        self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
+        return self._round_robin_invoke(
+            function=self.model_type_instance.get_num_tokens,
+            model=self.model,
+            credentials=self.credentials,
+            prompt_messages=prompt_messages,
+            tools=tools
+        )
+
     def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
             -> TextEmbeddingResult:
         """
@@ -92,13 +167,32 @@ class ModelInstance:
             raise Exception("Model type instance is not TextEmbeddingModel")
 
         self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
-        return self.model_type_instance.invoke(
+        return self._round_robin_invoke(
+            function=self.model_type_instance.invoke,
             model=self.model,
             credentials=self.credentials,
             texts=texts,
             user=user
         )
 
+    def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
+        """
+        Get number of tokens for text embedding
+
+        :param texts: texts to embed
+        :return:
+        """
+        if not isinstance(self.model_type_instance, TextEmbeddingModel):
+            raise Exception("Model type instance is not TextEmbeddingModel")
+
+        self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
+        return self._round_robin_invoke(
+            function=self.model_type_instance.get_num_tokens,
+            model=self.model,
+            credentials=self.credentials,
+            texts=texts
+        )
+
     def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None,
                       top_n: Optional[int] = None,
                       user: Optional[str] = None) \
@@ -117,7 +211,8 @@ class ModelInstance:
             raise Exception("Model type instance is not RerankModel")
 
         self.model_type_instance = cast(RerankModel, self.model_type_instance)
-        return self.model_type_instance.invoke(
+        return self._round_robin_invoke(
+            function=self.model_type_instance.invoke,
             model=self.model,
             credentials=self.credentials,
             query=query,
@@ -140,7 +235,8 @@ class ModelInstance:
             raise Exception("Model type instance is not ModerationModel")
 
         self.model_type_instance = cast(ModerationModel, self.model_type_instance)
-        return self.model_type_instance.invoke(
+        return self._round_robin_invoke(
+            function=self.model_type_instance.invoke,
             model=self.model,
             credentials=self.credentials,
             text=text,
@@ -160,7 +256,8 @@ class ModelInstance:
             raise Exception("Model type instance is not Speech2TextModel")
 
         self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
-        return self.model_type_instance.invoke(
+        return self._round_robin_invoke(
+            function=self.model_type_instance.invoke,
             model=self.model,
             credentials=self.credentials,
             file=file,
@@ -183,7 +280,8 @@ class ModelInstance:
             raise Exception("Model type instance is not TTSModel")
 
         self.model_type_instance = cast(TTSModel, self.model_type_instance)
-        return self.model_type_instance.invoke(
+        return self._round_robin_invoke(
+            function=self.model_type_instance.invoke,
             model=self.model,
             credentials=self.credentials,
             content_text=content_text,
@@ -193,6 +291,43 @@ class ModelInstance:
             streaming=streaming
         )
 
+    def _round_robin_invoke(self, function: callable, *args, **kwargs):
+        """
+        Round-robin invoke
+        :param function: function to invoke
+        :param args: function args
+        :param kwargs: function kwargs
+        :return:
+        """
+        if not self.load_balancing_manager:
+            return function(*args, **kwargs)
+
+        last_exception = None
+        while True:
+            lb_config = self.load_balancing_manager.fetch_next()
+            if not lb_config:
+                if not last_exception:
+                    raise ProviderTokenNotInitError("Model credentials is not initialized.")
+                else:
+                    raise last_exception
+
+            try:
+                if 'credentials' in kwargs:
+                    del kwargs['credentials']
+                return function(*args, **kwargs, credentials=lb_config.credentials)
+            except InvokeRateLimitError as e:
+                # expire in 60 seconds
+                self.load_balancing_manager.cooldown(lb_config, expire=60)
+                last_exception = e
+                continue
+            except (InvokeAuthorizationError, InvokeConnectionError) as e:
+                # expire in 10 seconds
+                self.load_balancing_manager.cooldown(lb_config, expire=10)
+                last_exception = e
+                continue
+            except Exception as e:
+                raise e
+
     def get_tts_voices(self, language: str) -> list:
         """
         Invoke large language tts model voices
@@ -226,6 +361,7 @@ class ModelManager:
         """
         if not provider:
             return self.get_default_model_instance(tenant_id, model_type)
+
         provider_model_bundle = self._provider_manager.get_provider_model_bundle(
             tenant_id=tenant_id,
             provider=provider,
@@ -255,3 +391,141 @@ class ModelManager:
             model_type=model_type,
             model=default_model_entity.model
         )
+
+
+class LBModelManager:
+    def __init__(self, tenant_id: str,
+                 provider: str,
+                 model_type: ModelType,
+                 model: str,
+                 load_balancing_configs: list[ModelLoadBalancingConfiguration],
+                 managed_credentials: Optional[dict] = None) -> None:
+        """
+        Load balancing model manager
+        :param load_balancing_configs: all load balancing configurations
+        :param managed_credentials: credentials if load balancing configuration name is __inherit__
+        """
+        self._tenant_id = tenant_id
+        self._provider = provider
+        self._model_type = model_type
+        self._model = model
+        self._load_balancing_configs = load_balancing_configs
+
+        for load_balancing_config in self._load_balancing_configs:
+            if load_balancing_config.name == "__inherit__":
+                if not managed_credentials:
+                    # remove __inherit__ if managed credentials is not provided
+                    self._load_balancing_configs.remove(load_balancing_config)
+                else:
+                    load_balancing_config.credentials = managed_credentials
+
+    def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]:
+        """
+        Get next model load balancing config
+        Strategy: Round Robin
+        :return:
+        """
+        cache_key = "model_lb_index:{}:{}:{}:{}".format(
+            self._tenant_id,
+            self._provider,
+            self._model_type.value,
+            self._model
+        )
+
+        cooldown_load_balancing_configs = []
+        max_index = len(self._load_balancing_configs)
+
+        while True:
+            current_index = redis_client.incr(cache_key)
+            if current_index >= 10000000:
+                current_index = 1
+                redis_client.set(cache_key, current_index)
+
+            redis_client.expire(cache_key, 3600)
+            if current_index > max_index:
+                current_index = current_index % max_index
+
+            real_index = current_index - 1
+            if real_index > max_index:
+                real_index = 0
+
+            config = self._load_balancing_configs[real_index]
+
+            if self.in_cooldown(config):
+                cooldown_load_balancing_configs.append(config)
+                if len(cooldown_load_balancing_configs) >= len(self._load_balancing_configs):
+                    # all configs are in cooldown
+                    return None
+
+                continue
+
+            if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
+                logger.info(f"Model LB\nid: {config.id}\nname:{config.name}\n"
+                            f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"
+                            f"model_type: {self._model_type.value}\nmodel: {self._model}")
+
+            return config
+
+        return None
+
+    def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None:
+        """
+        Cooldown model load balancing config
+        :param config: model load balancing config
+        :param expire: cooldown time
+        :return:
+        """
+        cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
+            self._tenant_id,
+            self._provider,
+            self._model_type.value,
+            self._model,
+            config.id
+        )
+
+        redis_client.setex(cooldown_cache_key, expire, 'true')
+
+    def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool:
+        """
+        Check if model load balancing config is in cooldown
+        :param config: model load balancing config
+        :return:
+        """
+        cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
+            self._tenant_id,
+            self._provider,
+            self._model_type.value,
+            self._model,
+            config.id
+        )
+
+        return redis_client.exists(cooldown_cache_key)
+
+    @classmethod
+    def get_config_in_cooldown_and_ttl(cls, tenant_id: str,
+                                       provider: str,
+                                       model_type: ModelType,
+                                       model: str,
+                                       config_id: str) -> tuple[bool, int]:
+        """
+        Get model load balancing config is in cooldown and ttl
+        :param tenant_id: workspace id
+        :param provider: provider name
+        :param model_type: model type
+        :param model: model name
+        :param config_id: model load balancing config id
+        :return:
+        """
+        cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
+            tenant_id,
+            provider,
+            model_type.value,
+            model,
+            config_id
+        )
+
+        ttl = redis_client.ttl(cooldown_cache_key)
+        if ttl == -2:
+            return False, 0
+
+        return True, ttl

+ 1 - 1
api/core/model_runtime/model_providers/__base/ai_model.py

@@ -3,6 +3,7 @@ import os
 from abc import ABC, abstractmethod
 from typing import Optional
 
+from core.helper.position_helper import get_position_map, sort_by_position_map
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
 from core.model_runtime.entities.model_entities import (
@@ -17,7 +18,6 @@ from core.model_runtime.entities.model_entities import (
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
 from core.tools.utils.yaml_utils import load_yaml_file
-from core.utils.position_helper import get_position_map, sort_by_position_map
 
 
 class AIModel(ABC):

+ 1 - 1
api/core/model_runtime/model_providers/__base/model_provider.py

@@ -1,11 +1,11 @@
 import os
 from abc import ABC, abstractmethod
 
+from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from core.model_runtime.entities.provider_entities import ProviderEntity
 from core.model_runtime.model_providers.__base.ai_model import AIModel
 from core.tools.utils.yaml_utils import load_yaml_file
-from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source
 
 
 class ModelProvider(ABC):

+ 2 - 2
api/core/model_runtime/model_providers/model_provider_factory.py

@@ -4,13 +4,13 @@ from typing import Optional
 
 from pydantic import BaseModel
 
+from core.helper.module_import_helper import load_single_subclass_from_source
+from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
 from core.model_runtime.model_providers.__base.model_provider import ModelProvider
 from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
 from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
-from core.utils.module_import_helper import load_single_subclass_from_source
-from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map
 
 logger = logging.getLogger(__name__)
 

+ 7 - 7
api/core/prompt/prompt_transform.py

@@ -1,10 +1,10 @@
-from typing import Optional, cast
+from typing import Optional
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_manager import ModelInstance
 from core.model_runtime.entities.message_entities import PromptMessage
 from core.model_runtime.entities.model_entities import ModelPropertyKey
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 
 
@@ -25,12 +25,12 @@ class PromptTransform:
 
         model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
         if model_context_tokens:
-            model_type_instance = model_config.provider_model_bundle.model_type_instance
-            model_type_instance = cast(LargeLanguageModel, model_type_instance)
+            model_instance = ModelInstance(
+                provider_model_bundle=model_config.provider_model_bundle,
+                model=model_config.model
+            )
 
-            curr_message_tokens = model_type_instance.get_num_tokens(
-                model_config.model,
-                model_config.credentials,
+            curr_message_tokens = model_instance.get_llm_num_tokens(
                 prompt_messages
             )
 

+ 165 - 2
api/core/provider_manager.py

@@ -11,6 +11,8 @@ from core.entities.provider_entities import (
     CustomConfiguration,
     CustomModelConfiguration,
     CustomProviderConfiguration,
+    ModelLoadBalancingConfiguration,
+    ModelSettings,
     QuotaConfiguration,
     SystemConfiguration,
 )
@@ -26,13 +28,16 @@ from core.model_runtime.model_providers import model_provider_factory
 from extensions import ext_hosting_provider
 from extensions.ext_database import db
 from models.provider import (
+    LoadBalancingModelConfig,
     Provider,
     ProviderModel,
+    ProviderModelSetting,
     ProviderQuotaType,
     ProviderType,
     TenantDefaultModel,
     TenantPreferredModelProvider,
 )
+from services.feature_service import FeatureService
 
 
 class ProviderManager:
@@ -98,6 +103,13 @@ class ProviderManager:
         # Get All preferred provider types of the workspace
         provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
 
+        # Get All provider model settings
+        provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id)
+
+        # Get All load balancing configs
+        provider_name_to_provider_load_balancing_model_configs_dict \
+            = self._get_all_provider_load_balancing_configs(tenant_id)
+
         provider_configurations = ProviderConfigurations(
             tenant_id=tenant_id
         )
@@ -147,13 +159,28 @@ class ProviderManager:
                     if system_configuration.enabled and has_valid_quota:
                         using_provider_type = ProviderType.SYSTEM
 
+            # Get provider load balancing configs
+            provider_model_settings = provider_name_to_provider_model_settings_dict.get(provider_name)
+
+            # Get provider load balancing configs
+            provider_load_balancing_configs \
+                = provider_name_to_provider_load_balancing_model_configs_dict.get(provider_name)
+
+            # Convert to model settings
+            model_settings = self._to_model_settings(
+                provider_entity=provider_entity,
+                provider_model_settings=provider_model_settings,
+                load_balancing_model_configs=provider_load_balancing_configs
+            )
+
             provider_configuration = ProviderConfiguration(
                 tenant_id=tenant_id,
                 provider=provider_entity,
                 preferred_provider_type=preferred_provider_type,
                 using_provider_type=using_provider_type,
                 system_configuration=system_configuration,
-                custom_configuration=custom_configuration
+                custom_configuration=custom_configuration,
+                model_settings=model_settings
             )
 
             provider_configurations[provider_name] = provider_configuration
@@ -338,7 +365,7 @@ class ProviderManager:
         """
         Get All preferred provider types of the workspace.
 
-        :param tenant_id:
+        :param tenant_id: workspace id
         :return:
         """
         preferred_provider_types = db.session.query(TenantPreferredModelProvider) \
@@ -353,6 +380,48 @@ class ProviderManager:
 
         return provider_name_to_preferred_provider_type_records_dict
 
+    def _get_all_provider_model_settings(self, tenant_id: str) -> dict[str, list[ProviderModelSetting]]:
+        """
+        Get All provider model settings of the workspace.
+
+        :param tenant_id: workspace id
+        :return:
+        """
+        provider_model_settings = db.session.query(ProviderModelSetting) \
+            .filter(
+            ProviderModelSetting.tenant_id == tenant_id
+        ).all()
+
+        provider_name_to_provider_model_settings_dict = defaultdict(list)
+        for provider_model_setting in provider_model_settings:
+            (provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name]
+             .append(provider_model_setting))
+
+        return provider_name_to_provider_model_settings_dict
+
+    def _get_all_provider_load_balancing_configs(self, tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]:
+        """
+        Get All provider load balancing configs of the workspace.
+
+        :param tenant_id: workspace id
+        :return:
+        """
+        model_load_balancing_enabled = FeatureService.get_features(tenant_id).model_load_balancing_enabled
+        if not model_load_balancing_enabled:
+            return dict()
+
+        provider_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
+            .filter(
+            LoadBalancingModelConfig.tenant_id == tenant_id
+        ).all()
+
+        provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
+        for provider_load_balancing_config in provider_load_balancing_configs:
+            (provider_name_to_provider_load_balancing_model_configs_dict[provider_load_balancing_config.provider_name]
+             .append(provider_load_balancing_config))
+
+        return provider_name_to_provider_load_balancing_model_configs_dict
+
     def _init_trial_provider_records(self, tenant_id: str,
                                      provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]:
         """
@@ -726,3 +795,97 @@ class ProviderManager:
                 secret_input_form_variables.append(credential_form_schema.variable)
 
         return secret_input_form_variables
+
+    def _to_model_settings(self, provider_entity: ProviderEntity,
+                           provider_model_settings: Optional[list[ProviderModelSetting]] = None,
+                           load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None) \
+            -> list[ModelSettings]:
+        """
+        Convert to model settings.
+
+        :param provider_model_settings: provider model settings include enabled, load balancing enabled
+        :param load_balancing_model_configs: load balancing model configs
+        :return:
+        """
+        # Get provider model credential secret variables
+        model_credential_secret_variables = self._extract_secret_variables(
+            provider_entity.model_credential_schema.credential_form_schemas
+            if provider_entity.model_credential_schema else []
+        )
+
+        model_settings = []
+        if not provider_model_settings:
+            return model_settings
+
+        for provider_model_setting in provider_model_settings:
+            load_balancing_configs = []
+            if provider_model_setting.load_balancing_enabled and load_balancing_model_configs:
+                for load_balancing_model_config in load_balancing_model_configs:
+                    if (load_balancing_model_config.model_name == provider_model_setting.model_name
+                            and load_balancing_model_config.model_type == provider_model_setting.model_type):
+                        if not load_balancing_model_config.enabled:
+                            continue
+
+                        if not load_balancing_model_config.encrypted_config:
+                            if load_balancing_model_config.name == "__inherit__":
+                                load_balancing_configs.append(ModelLoadBalancingConfiguration(
+                                    id=load_balancing_model_config.id,
+                                    name=load_balancing_model_config.name,
+                                    credentials={}
+                                ))
+                            continue
+
+                        provider_model_credentials_cache = ProviderCredentialsCache(
+                            tenant_id=load_balancing_model_config.tenant_id,
+                            identity_id=load_balancing_model_config.id,
+                            cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
+                        )
+
+                        # Get cached provider model credentials
+                        cached_provider_model_credentials = provider_model_credentials_cache.get()
+
+                        if not cached_provider_model_credentials:
+                            try:
+                                provider_model_credentials = json.loads(load_balancing_model_config.encrypted_config)
+                            except JSONDecodeError:
+                                continue
+
+                            # Get decoding rsa key and cipher for decrypting credentials
+                            if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
+                                self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(
+                                    load_balancing_model_config.tenant_id)
+
+                            for variable in model_credential_secret_variables:
+                                if variable in provider_model_credentials:
+                                    try:
+                                        provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
+                                            provider_model_credentials.get(variable),
+                                            self.decoding_rsa_key,
+                                            self.decoding_cipher_rsa
+                                        )
+                                    except ValueError:
+                                        pass
+
+                            # cache provider model credentials
+                            provider_model_credentials_cache.set(
+                                credentials=provider_model_credentials
+                            )
+                        else:
+                            provider_model_credentials = cached_provider_model_credentials
+
+                        load_balancing_configs.append(ModelLoadBalancingConfiguration(
+                            id=load_balancing_model_config.id,
+                            name=load_balancing_model_config.name,
+                            credentials=provider_model_credentials
+                        ))
+
+            model_settings.append(
+                ModelSettings(
+                    model=provider_model_setting.model_name,
+                    model_type=ModelType.value_of(provider_model_setting.model_type),
+                    enabled=provider_model_setting.enabled,
+                    load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else []
+                )
+            )
+
+        return model_settings

+ 2 - 7
api/core/rag/docstore/dataset_docstore.py

@@ -1,11 +1,10 @@
 from collections.abc import Sequence
-from typing import Any, Optional, cast
+from typing import Any, Optional
 
 from sqlalchemy import func
 
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.dataset import Dataset, DocumentSegment
@@ -95,11 +94,7 @@ class DatasetDocumentStore:
 
             # calc embedding use tokens
             if embedding_model:
-                model_type_instance = embedding_model.model_type_instance
-                model_type_instance = cast(TextEmbeddingModel, model_type_instance)
-                tokens = model_type_instance.get_num_tokens(
-                    model=embedding_model.model,
-                    credentials=embedding_model.credentials,
+                tokens = embedding_model.get_text_embedding_num_tokens(
                     texts=[doc.page_content]
                 )
             else:

+ 2 - 7
api/core/rag/splitter/fixed_text_splitter.py

@@ -1,10 +1,9 @@
 """Functionality for splitting text."""
 from __future__ import annotations
 
-from typing import Any, Optional, cast
+from typing import Any, Optional
 
 from core.model_manager import ModelInstance
-from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
 from core.rag.splitter.text_splitter import (
     TS,
@@ -35,11 +34,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
                 return 0
 
             if embedding_model_instance:
-                embedding_model_type_instance = embedding_model_instance.model_type_instance
-                embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
-                return embedding_model_type_instance.get_num_tokens(
-                    model=embedding_model_instance.model,
-                    credentials=embedding_model_instance.credentials,
+                return embedding_model_instance.get_text_embedding_num_tokens(
                     texts=[text]
                 )
             else:

+ 1 - 1
api/core/tools/provider/builtin/_positions.py

@@ -1,7 +1,7 @@
 import os.path
 
+from core.helper.position_helper import get_position_map, sort_by_position_map
 from core.tools.entities.api_entities import UserToolProvider
-from core.utils.position_helper import get_position_map, sort_by_position_map
 
 
 class BuiltinToolProviderSort:

+ 4 - 4
api/core/tools/provider/builtin_tool_provider.py

@@ -2,6 +2,7 @@ from abc import abstractmethod
 from os import listdir, path
 from typing import Any
 
+from core.helper.module_import_helper import load_single_subclass_from_source
 from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
 from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
 from core.tools.errors import (
@@ -14,7 +15,6 @@ from core.tools.tool.builtin_tool import BuiltinTool
 from core.tools.tool.tool import Tool
 from core.tools.utils.tool_parameter_converter import ToolParameterConverter
 from core.tools.utils.yaml_utils import load_yaml_file
-from core.utils.module_import_helper import load_single_subclass_from_source
 
 
 class BuiltinToolProviderController(ToolProviderController):
@@ -82,7 +82,7 @@ class BuiltinToolProviderController(ToolProviderController):
             return {}
         
         return self.credentials_schema.copy()
-    
+
     def get_tools(self) -> list[Tool]:
         """
             returns a list of tools that the provider can provide
@@ -127,7 +127,7 @@ class BuiltinToolProviderController(ToolProviderController):
             :return: type of the provider
         """
         return ToolProviderType.BUILT_IN
-    
+
     @property
     def tool_labels(self) -> list[str]:
         """
@@ -137,7 +137,7 @@ class BuiltinToolProviderController(ToolProviderController):
         """
         label_enums = self._get_tool_labels()
         return [default_tool_label_dict[label].name for label in label_enums]
-    
+
     def _get_tool_labels(self) -> list[ToolLabelEnum]:
         """
             returns the labels of the provider

+ 10 - 10
api/core/tools/tool_manager.py

@@ -10,6 +10,7 @@ from flask import current_app
 
 from core.agent.entities import AgentToolEntity
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.helper.module_import_helper import load_single_subclass_from_source
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
 from core.tools.entities.common_entities import I18nObject
@@ -31,7 +32,6 @@ from core.tools.utils.configuration import (
     ToolParameterConfigurationManager,
 )
 from core.tools.utils.tool_parameter_converter import ToolParameterConverter
-from core.utils.module_import_helper import load_single_subclass_from_source
 from core.workflow.nodes.tool.entities import ToolEntity
 from extensions.ext_database import db
 from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
@@ -102,10 +102,10 @@ class ToolManager:
             raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
 
     @classmethod
-    def get_tool_runtime(cls, provider_type: str, 
+    def get_tool_runtime(cls, provider_type: str,
                          provider_id: str,
-                         tool_name: str, 
-                         tenant_id: str, 
+                         tool_name: str,
+                         tenant_id: str,
                          invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
                          tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
         -> Union[BuiltinTool, ApiTool]:
@@ -222,7 +222,7 @@ class ToolManager:
             get the agent tool runtime
         """
         tool_entity = cls.get_tool_runtime(
-            provider_type=agent_tool.provider_type, 
+            provider_type=agent_tool.provider_type,
             provider_id=agent_tool.provider_id,
             tool_name=agent_tool.tool_name,
             tenant_id=tenant_id,
@@ -235,7 +235,7 @@ class ToolManager:
             # check file types
             if parameter.type == ToolParameter.ToolParameterType.FILE:
                 raise ValueError(f"file type parameter {parameter.name} not supported in agent")
-            
+
             if parameter.form == ToolParameter.ToolParameterForm.FORM:
                 # save tool parameter to tool entity memory
                 value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
@@ -403,7 +403,7 @@ class ToolManager:
 
             # get builtin providers
             builtin_providers = cls.list_builtin_providers()
-            
+
             # get db builtin providers
             db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
                 filter(BuiltinToolProvider.tenant_id == tenant_id).all()
@@ -428,7 +428,7 @@ class ToolManager:
         if 'api' in filters:
             db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
                 filter(ApiToolProvider.tenant_id == tenant_id).all()
-            
+
             api_provider_controllers = [{
                 'provider': provider,
                 'controller': ToolTransformService.api_provider_to_controller(provider)
@@ -450,7 +450,7 @@ class ToolManager:
             # get workflow providers
             workflow_providers: list[WorkflowToolProvider] = db.session.query(WorkflowToolProvider). \
                 filter(WorkflowToolProvider.tenant_id == tenant_id).all()
-            
+
             workflow_provider_controllers = []
             for provider in workflow_providers:
                 try:
@@ -460,7 +460,7 @@ class ToolManager:
                 except Exception as e:
                     # app has been deleted
                     pass
-            
+
             labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
 
             for provider_controller in workflow_provider_controllers:

+ 4 - 13
api/core/tools/utils/model_invocation_utils.py

@@ -73,10 +73,8 @@ class ModelInvocationUtils:
         if not model_instance:
             raise InvokeModelError('Model not found')
         
-        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
-
         # get tokens
-        tokens = llm_model.get_num_tokens(model_instance.model, model_instance.credentials, prompt_messages)
+        tokens = model_instance.get_llm_num_tokens(prompt_messages)
 
         return tokens
 
@@ -108,13 +106,8 @@ class ModelInvocationUtils:
             tenant_id=tenant_id, model_type=ModelType.LLM,
         )
 
-        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
-
-        # get model credentials
-        model_credentials = model_instance.credentials
-
         # get prompt tokens
-        prompt_tokens = llm_model.get_num_tokens(model_instance.model, model_credentials, prompt_messages)
+        prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
 
         model_parameters = {
             'temperature': 0.8,
@@ -144,9 +137,7 @@ class ModelInvocationUtils:
         db.session.commit()
 
         try:
-            response: LLMResult = llm_model.invoke(
-                model=model_instance.model,
-                credentials=model_credentials,
+            response: LLMResult = model_instance.invoke_llm(
                 prompt_messages=prompt_messages,
                 model_parameters=model_parameters,
                 tools=[], stop=[], stream=False, user=user_id, callbacks=[]
@@ -176,4 +167,4 @@ class ModelInvocationUtils:
 
         db.session.commit()
 
-        return response
+        return response

+ 6 - 6
api/core/workflow/nodes/question_classifier/question_classifier_node.py

@@ -4,9 +4,9 @@ from typing import Optional, Union, cast
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_manager import ModelInstance
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
 from core.model_runtime.entities.model_entities import ModelPropertyKey
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
@@ -200,12 +200,12 @@ class QuestionClassifierNode(LLMNode):
 
         model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
         if model_context_tokens:
-            model_type_instance = model_config.provider_model_bundle.model_type_instance
-            model_type_instance = cast(LargeLanguageModel, model_type_instance)
+            model_instance = ModelInstance(
+                provider_model_bundle=model_config.provider_model_bundle,
+                model=model_config.model
+            )
 
-            curr_message_tokens = model_type_instance.get_num_tokens(
-                model_config.model,
-                model_config.credentials,
+            curr_message_tokens = model_instance.get_llm_num_tokens(
                 prompt_messages
             )
 

+ 126 - 0
api/migrations/versions/4e99a8df00ff_add_load_balancing.py

@@ -0,0 +1,126 @@
+"""add load balancing
+
+Revision ID: 4e99a8df00ff
+Revises: 47cc7df8c4f3
+Create Date: 2024-05-10 12:08:09.812736
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+import models as models
+
+# revision identifiers, used by Alembic.
+revision = '4e99a8df00ff'
+down_revision = '64a70a7aab8b'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('load_balancing_model_configs',
+    sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('tenant_id', models.StringUUID(), nullable=False),
+    sa.Column('provider_name', sa.String(length=255), nullable=False),
+    sa.Column('model_name', sa.String(length=255), nullable=False),
+    sa.Column('model_type', sa.String(length=40), nullable=False),
+    sa.Column('name', sa.String(length=255), nullable=False),
+    sa.Column('encrypted_config', sa.Text(), nullable=True),
+    sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.PrimaryKeyConstraint('id', name='load_balancing_model_config_pkey')
+    )
+    with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op:
+        batch_op.create_index('load_balancing_model_config_tenant_provider_model_idx', ['tenant_id', 'provider_name', 'model_type'], unique=False)
+
+    op.create_table('provider_model_settings',
+    sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('tenant_id', models.StringUUID(), nullable=False),
+    sa.Column('provider_name', sa.String(length=255), nullable=False),
+    sa.Column('model_name', sa.String(length=255), nullable=False),
+    sa.Column('model_type', sa.String(length=40), nullable=False),
+    sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
+    sa.Column('load_balancing_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.PrimaryKeyConstraint('id', name='provider_model_setting_pkey')
+    )
+    with op.batch_alter_table('provider_model_settings', schema=None) as batch_op:
+        batch_op.create_index('provider_model_setting_tenant_provider_model_idx', ['tenant_id', 'provider_name', 'model_type'], unique=False)
+
+    with op.batch_alter_table('provider_models', schema=None) as batch_op:
+        batch_op.alter_column('provider_name',
+               existing_type=sa.VARCHAR(length=40),
+               type_=sa.String(length=255),
+               existing_nullable=False)
+
+    with op.batch_alter_table('provider_orders', schema=None) as batch_op:
+        batch_op.alter_column('provider_name',
+               existing_type=sa.VARCHAR(length=40),
+               type_=sa.String(length=255),
+               existing_nullable=False)
+
+    with op.batch_alter_table('providers', schema=None) as batch_op:
+        batch_op.alter_column('provider_name',
+               existing_type=sa.VARCHAR(length=40),
+               type_=sa.String(length=255),
+               existing_nullable=False)
+
+    with op.batch_alter_table('tenant_default_models', schema=None) as batch_op:
+        batch_op.alter_column('provider_name',
+               existing_type=sa.VARCHAR(length=40),
+               type_=sa.String(length=255),
+               existing_nullable=False)
+
+    with op.batch_alter_table('tenant_preferred_model_providers', schema=None) as batch_op:
+        batch_op.alter_column('provider_name',
+               existing_type=sa.VARCHAR(length=40),
+               type_=sa.String(length=255),
+               existing_nullable=False)
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('tenant_preferred_model_providers', schema=None) as batch_op:
+        batch_op.alter_column('provider_name',
+               existing_type=sa.String(length=255),
+               type_=sa.VARCHAR(length=40),
+               existing_nullable=False)
+
+    with op.batch_alter_table('tenant_default_models', schema=None) as batch_op:
+        batch_op.alter_column('provider_name',
+               existing_type=sa.String(length=255),
+               type_=sa.VARCHAR(length=40),
+               existing_nullable=False)
+
+    with op.batch_alter_table('providers', schema=None) as batch_op:
+        batch_op.alter_column('provider_name',
+               existing_type=sa.String(length=255),
+               type_=sa.VARCHAR(length=40),
+               existing_nullable=False)
+
+    with op.batch_alter_table('provider_orders', schema=None) as batch_op:
+        batch_op.alter_column('provider_name',
+               existing_type=sa.String(length=255),
+               type_=sa.VARCHAR(length=40),
+               existing_nullable=False)
+
+    with op.batch_alter_table('provider_models', schema=None) as batch_op:
+        batch_op.alter_column('provider_name',
+               existing_type=sa.String(length=255),
+               type_=sa.VARCHAR(length=40),
+               existing_nullable=False)
+
+    with op.batch_alter_table('provider_model_settings', schema=None) as batch_op:
+        batch_op.drop_index('provider_model_setting_tenant_provider_model_idx')
+
+    op.drop_table('provider_model_settings')
+    with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op:
+        batch_op.drop_index('load_balancing_model_config_tenant_provider_model_idx')
+
+    op.drop_table('load_balancing_model_configs')
+    # ### end Alembic commands ###

+ 48 - 5
api/models/provider.py

@@ -47,7 +47,7 @@ class Provider(db.Model):
 
     id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
     tenant_id = db.Column(StringUUID, nullable=False)
-    provider_name = db.Column(db.String(40), nullable=False)
+    provider_name = db.Column(db.String(255), nullable=False)
     provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying"))
     encrypted_config = db.Column(db.Text, nullable=True)
     is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
@@ -94,7 +94,7 @@ class ProviderModel(db.Model):
 
     id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
     tenant_id = db.Column(StringUUID, nullable=False)
-    provider_name = db.Column(db.String(40), nullable=False)
+    provider_name = db.Column(db.String(255), nullable=False)
     model_name = db.Column(db.String(255), nullable=False)
     model_type = db.Column(db.String(40), nullable=False)
     encrypted_config = db.Column(db.Text, nullable=True)
@@ -112,7 +112,7 @@ class TenantDefaultModel(db.Model):
 
     id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
     tenant_id = db.Column(StringUUID, nullable=False)
-    provider_name = db.Column(db.String(40), nullable=False)
+    provider_name = db.Column(db.String(255), nullable=False)
     model_name = db.Column(db.String(255), nullable=False)
     model_type = db.Column(db.String(40), nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@@ -128,7 +128,7 @@ class TenantPreferredModelProvider(db.Model):
 
     id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
     tenant_id = db.Column(StringUUID, nullable=False)
-    provider_name = db.Column(db.String(40), nullable=False)
+    provider_name = db.Column(db.String(255), nullable=False)
     preferred_provider_type = db.Column(db.String(40), nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@@ -143,7 +143,7 @@ class ProviderOrder(db.Model):
 
     id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
     tenant_id = db.Column(StringUUID, nullable=False)
-    provider_name = db.Column(db.String(40), nullable=False)
+    provider_name = db.Column(db.String(255), nullable=False)
     account_id = db.Column(StringUUID, nullable=False)
     payment_product_id = db.Column(db.String(191), nullable=False)
     payment_id = db.Column(db.String(191))
@@ -157,3 +157,46 @@ class ProviderOrder(db.Model):
     refunded_at = db.Column(db.DateTime)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+
+
+class ProviderModelSetting(db.Model):
+    """
+    Provider model settings for record the model enabled status and load balancing status.
+    """
+    __tablename__ = 'provider_model_settings'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='provider_model_setting_pkey'),
+        db.Index('provider_model_setting_tenant_provider_model_idx', 'tenant_id', 'provider_name', 'model_type'),
+    )
+
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
+    provider_name = db.Column(db.String(255), nullable=False)
+    model_name = db.Column(db.String(255), nullable=False)
+    model_type = db.Column(db.String(40), nullable=False)
+    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true'))
+    load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+
+
+class LoadBalancingModelConfig(db.Model):
+    """
+    Configurations for load balancing models.
+    """
+    __tablename__ = 'load_balancing_model_configs'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='load_balancing_model_config_pkey'),
+        db.Index('load_balancing_model_config_tenant_provider_model_idx', 'tenant_id', 'provider_name', 'model_type'),
+    )
+
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
+    provider_name = db.Column(db.String(255), nullable=False)
+    model_name = db.Column(db.String(255), nullable=False)
+    model_type = db.Column(db.String(40), nullable=False)
+    name = db.Column(db.String(255), nullable=False)
+    encrypted_config = db.Column(db.Text, nullable=True)
+    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true'))
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

+ 4 - 14
api/services/dataset_service.py

@@ -4,7 +4,7 @@ import logging
 import random
 import time
 import uuid
-from typing import Optional, cast
+from typing import Optional
 
 from flask import current_app
 from flask_login import current_user
@@ -13,7 +13,6 @@ from sqlalchemy import func
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.models.document import Document as RAGDocument
 from events.dataset_event import dataset_was_deleted
@@ -1144,10 +1143,7 @@ class SegmentService:
                 model=dataset.embedding_model
             )
             # calc embedding use tokens
-            model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
-            tokens = model_type_instance.get_num_tokens(
-                model=embedding_model.model,
-                credentials=embedding_model.credentials,
+            tokens = embedding_model.get_text_embedding_num_tokens(
                 texts=[content]
             )
         lock_name = 'add_segment_lock_document_id_{}'.format(document.id)
@@ -1215,10 +1211,7 @@ class SegmentService:
                 tokens = 0
                 if dataset.indexing_technique == 'high_quality' and embedding_model:
                     # calc embedding use tokens
-                    model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
-                    tokens = model_type_instance.get_num_tokens(
-                        model=embedding_model.model,
-                        credentials=embedding_model.credentials,
+                    tokens = embedding_model.get_text_embedding_num_tokens(
                         texts=[content]
                     )
                 segment_document = DocumentSegment(
@@ -1321,10 +1314,7 @@ class SegmentService:
                     )
 
                     # calc embedding use tokens
-                    model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
-                    tokens = model_type_instance.get_num_tokens(
-                        model=embedding_model.model,
-                        credentials=embedding_model.credentials,
+                    tokens = embedding_model.get_text_embedding_num_tokens(
                         texts=[content]
                     )
                 segment.content = content

+ 3 - 10
api/services/entities/model_provider_entities.py

@@ -4,10 +4,10 @@ from typing import Optional
 from flask import current_app
 from pydantic import BaseModel
 
-from core.entities.model_entities import ModelStatus, ModelWithProviderEntity
+from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity
 from core.entities.provider_entities import QuotaConfiguration
 from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.model_entities import ModelType, ProviderModel
+from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.entities.provider_entities import (
     ConfigurateMethod,
     ModelCredentialSchema,
@@ -79,13 +79,6 @@ class ProviderResponse(BaseModel):
             )
 
 
-class ModelResponse(ProviderModel):
-    """
-    Model class for model response.
-    """
-    status: ModelStatus
-
-
 class ProviderWithModelsResponse(BaseModel):
     """
     Model class for provider with models response.
@@ -95,7 +88,7 @@ class ProviderWithModelsResponse(BaseModel):
     icon_small: Optional[I18nObject] = None
     icon_large: Optional[I18nObject] = None
     status: CustomConfigurationStatus
-    models: list[ModelResponse]
+    models: list[ProviderModelWithStatusEntity]
 
     def __init__(self, **data) -> None:
         super().__init__(**data)

+ 25 - 12
api/services/feature_service.py

@@ -29,6 +29,7 @@ class FeatureModel(BaseModel):
     documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50)
     docs_processing: str = 'standard'
     can_replace_logo: bool = False
+    model_load_balancing_enabled: bool = False
 
 
 class SystemFeatureModel(BaseModel):
@@ -63,6 +64,7 @@ class FeatureService:
     @classmethod
     def _fulfill_params_from_env(cls, features: FeatureModel):
         features.can_replace_logo = current_app.config['CAN_REPLACE_LOGO']
+        features.model_load_balancing_enabled = current_app.config['MODEL_LB_ENABLED']
 
     @classmethod
     def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
@@ -72,23 +74,34 @@ class FeatureService:
         features.billing.subscription.plan = billing_info['subscription']['plan']
         features.billing.subscription.interval = billing_info['subscription']['interval']
 
-        features.members.size = billing_info['members']['size']
-        features.members.limit = billing_info['members']['limit']
+        if 'members' in billing_info:
+            features.members.size = billing_info['members']['size']
+            features.members.limit = billing_info['members']['limit']
 
-        features.apps.size = billing_info['apps']['size']
-        features.apps.limit = billing_info['apps']['limit']
+        if 'apps' in billing_info:
+            features.apps.size = billing_info['apps']['size']
+            features.apps.limit = billing_info['apps']['limit']
 
-        features.vector_space.size = billing_info['vector_space']['size']
-        features.vector_space.limit = billing_info['vector_space']['limit']
+        if 'vector_space' in billing_info:
+            features.vector_space.size = billing_info['vector_space']['size']
+            features.vector_space.limit = billing_info['vector_space']['limit']
 
-        features.documents_upload_quota.size = billing_info['documents_upload_quota']['size']
-        features.documents_upload_quota.limit = billing_info['documents_upload_quota']['limit']
+        if 'documents_upload_quota' in billing_info:
+            features.documents_upload_quota.size = billing_info['documents_upload_quota']['size']
+            features.documents_upload_quota.limit = billing_info['documents_upload_quota']['limit']
 
-        features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size']
-        features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit']
+        if 'annotation_quota_limit' in billing_info:
+            features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size']
+            features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit']
 
-        features.docs_processing = billing_info['docs_processing']
-        features.can_replace_logo = billing_info['can_replace_logo']
+        if 'docs_processing' in billing_info:
+            features.docs_processing = billing_info['docs_processing']
+
+        if 'can_replace_logo' in billing_info:
+            features.can_replace_logo = billing_info['can_replace_logo']
+
+        if 'model_load_balancing_enabled' in billing_info:
+            features.model_load_balancing_enabled = billing_info['model_load_balancing_enabled']
 
     @classmethod
     def _fulfill_params_from_enterprise(cls, features):

+ 565 - 0
api/services/model_load_balancing_service.py

@@ -0,0 +1,565 @@
+import datetime
+import json
+import logging
+from json import JSONDecodeError
+from typing import Optional
+
+from core.entities.provider_configuration import ProviderConfiguration
+from core.helper import encrypter
+from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
+from core.model_manager import LBModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.entities.provider_entities import (
+    ModelCredentialSchema,
+    ProviderCredentialSchema,
+)
+from core.model_runtime.model_providers import model_provider_factory
+from core.provider_manager import ProviderManager
+from extensions.ext_database import db
+from models.provider import LoadBalancingModelConfig
+
+logger = logging.getLogger(__name__)
+
+
+class ModelLoadBalancingService:
+
+    def __init__(self) -> None:
+        self.provider_manager = ProviderManager()
+
+    def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
+        """
+        enable model load balancing.
+
+        :param tenant_id: workspace id
+        :param provider: provider name
+        :param model: model name
+        :param model_type: model type
+        :return:
+        """
+        # Get all provider configurations of the current workspace
+        provider_configurations = self.provider_manager.get_configurations(tenant_id)
+
+        # Get provider configuration
+        provider_configuration = provider_configurations.get(provider)
+        if not provider_configuration:
+            raise ValueError(f"Provider {provider} does not exist.")
+
+        # Enable model load balancing
+        provider_configuration.enable_model_load_balancing(
+            model=model,
+            model_type=ModelType.value_of(model_type)
+        )
+
+    def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
+        """
+        disable model load balancing.
+
+        :param tenant_id: workspace id
+        :param provider: provider name
+        :param model: model name
+        :param model_type: model type
+        :return:
+        """
+        # Get all provider configurations of the current workspace
+        provider_configurations = self.provider_manager.get_configurations(tenant_id)
+
+        # Get provider configuration
+        provider_configuration = provider_configurations.get(provider)
+        if not provider_configuration:
+            raise ValueError(f"Provider {provider} does not exist.")
+
+        # disable model load balancing
+        provider_configuration.disable_model_load_balancing(
+            model=model,
+            model_type=ModelType.value_of(model_type)
+        )
+
+    def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, model_type: str) \
+            -> tuple[bool, list[dict]]:
+        """
+        Get load balancing configurations.
+        :param tenant_id: workspace id
+        :param provider: provider name
+        :param model: model name
+        :param model_type: model type
+        :return:
+        """
+        # Get all provider configurations of the current workspace
+        provider_configurations = self.provider_manager.get_configurations(tenant_id)
+
+        # Get provider configuration
+        provider_configuration = provider_configurations.get(provider)
+        if not provider_configuration:
+            raise ValueError(f"Provider {provider} does not exist.")
+
+        # Convert model type to ModelType
+        model_type = ModelType.value_of(model_type)
+
+        # Get provider model setting
+        provider_model_setting = provider_configuration.get_provider_model_setting(
+            model_type=model_type,
+            model=model,
+        )
+
+        is_load_balancing_enabled = False
+        if provider_model_setting and provider_model_setting.load_balancing_enabled:
+            is_load_balancing_enabled = True
+
+        # Get load balancing configurations
+        load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
+            .filter(
+            LoadBalancingModelConfig.tenant_id == tenant_id,
+            LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
+            LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+            LoadBalancingModelConfig.model_name == model
+        ).order_by(LoadBalancingModelConfig.created_at).all()
+
+        if provider_configuration.custom_configuration.provider:
+            # check if the inherit configuration exists,
+            # inherit is represented for the provider or model custom credentials
+            inherit_config_exists = False
+            for load_balancing_config in load_balancing_configs:
+                if load_balancing_config.name == '__inherit__':
+                    inherit_config_exists = True
+                    break
+
+            if not inherit_config_exists:
+                # Initialize the inherit configuration
+                inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type)
+
+                # prepend the inherit configuration
+                load_balancing_configs.insert(0, inherit_config)
+            else:
+                # move the inherit configuration to the first
+                for i, load_balancing_config in enumerate(load_balancing_configs):
+                    if load_balancing_config.name == '__inherit__':
+                        inherit_config = load_balancing_configs.pop(i)
+                        load_balancing_configs.insert(0, inherit_config)
+
+        # Get credential form schemas from model credential schema or provider credential schema
+        credential_schemas = self._get_credential_schema(provider_configuration)
+
+        # Get decoding rsa key and cipher for decrypting credentials
+        decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
+
+        # fetch status and ttl for each config
+        datas = []
+        for load_balancing_config in load_balancing_configs:
+            in_cooldown, ttl = LBModelManager.get_config_in_cooldown_and_ttl(
+                tenant_id=tenant_id,
+                provider=provider,
+                model=model,
+                model_type=model_type,
+                config_id=load_balancing_config.id
+            )
+
+            try:
+                if load_balancing_config.encrypted_config:
+                    credentials = json.loads(load_balancing_config.encrypted_config)
+                else:
+                    credentials = {}
+            except JSONDecodeError:
+                credentials = {}
+
+            # Get provider credential secret variables
+            credential_secret_variables = provider_configuration.extract_secret_variables(
+                credential_schemas.credential_form_schemas
+            )
+
+            # decrypt credentials
+            for variable in credential_secret_variables:
+                if variable in credentials:
+                    try:
+                        credentials[variable] = encrypter.decrypt_token_with_decoding(
+                            credentials.get(variable),
+                            decoding_rsa_key,
+                            decoding_cipher_rsa
+                        )
+                    except ValueError:
+                        pass
+
+            # Obfuscate credentials
+            credentials = provider_configuration.obfuscated_credentials(
+                credentials=credentials,
+                credential_form_schemas=credential_schemas.credential_form_schemas
+            )
+
+            datas.append({
+                'id': load_balancing_config.id,
+                'name': load_balancing_config.name,
+                'credentials': credentials,
+                'enabled': load_balancing_config.enabled,
+                'in_cooldown': in_cooldown,
+                'ttl': ttl
+            })
+
+        return is_load_balancing_enabled, datas
+
+    def get_load_balancing_config(self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str) \
+            -> Optional[dict]:
+        """
+        Get load balancing configuration.
+        :param tenant_id: workspace id
+        :param provider: provider name
+        :param model: model name
+        :param model_type: model type
+        :param config_id: load balancing config id
+        :return:
+        """
+        # Get all provider configurations of the current workspace
+        provider_configurations = self.provider_manager.get_configurations(tenant_id)
+
+        # Get provider configuration
+        provider_configuration = provider_configurations.get(provider)
+        if not provider_configuration:
+            raise ValueError(f"Provider {provider} does not exist.")
+
+        # Convert model type to ModelType
+        model_type = ModelType.value_of(model_type)
+
+        # Get load balancing configurations
+        load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \
+            .filter(
+            LoadBalancingModelConfig.tenant_id == tenant_id,
+            LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
+            LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+            LoadBalancingModelConfig.model_name == model,
+            LoadBalancingModelConfig.id == config_id
+        ).first()
+
+        if not load_balancing_model_config:
+            return None
+
+        try:
+            if load_balancing_model_config.encrypted_config:
+                credentials = json.loads(load_balancing_model_config.encrypted_config)
+            else:
+                credentials = {}
+        except JSONDecodeError:
+            credentials = {}
+
+        # Get credential form schemas from model credential schema or provider credential schema
+        credential_schemas = self._get_credential_schema(provider_configuration)
+
+        # Obfuscate credentials
+        credentials = provider_configuration.obfuscated_credentials(
+            credentials=credentials,
+            credential_form_schemas=credential_schemas.credential_form_schemas
+        )
+
+        return {
+            'id': load_balancing_model_config.id,
+            'name': load_balancing_model_config.name,
+            'credentials': credentials,
+            'enabled': load_balancing_model_config.enabled
+        }
+
+    def _init_inherit_config(self, tenant_id: str, provider: str, model: str, model_type: ModelType) \
+            -> LoadBalancingModelConfig:
+        """
+        Initialize the inherit configuration.
+        :param tenant_id: workspace id
+        :param provider: provider name
+        :param model: model name
+        :param model_type: model type
+        :return:
+        """
+        # Initialize the inherit configuration
+        inherit_config = LoadBalancingModelConfig(
+            tenant_id=tenant_id,
+            provider_name=provider,
+            model_type=model_type.to_origin_model_type(),
+            model_name=model,
+            name='__inherit__'
+        )
+        db.session.add(inherit_config)
+        db.session.commit()
+
+        return inherit_config
+
+    def update_load_balancing_configs(self, tenant_id: str,
+                                      provider: str,
+                                      model: str,
+                                      model_type: str,
+                                      configs: list[dict]) -> None:
+        """
+        Update load balancing configurations.
+        :param tenant_id: workspace id
+        :param provider: provider name
+        :param model: model name
+        :param model_type: model type
+        :param configs: load balancing configs
+        :return:
+        """
+        # Get all provider configurations of the current workspace
+        provider_configurations = self.provider_manager.get_configurations(tenant_id)
+
+        # Get provider configuration
+        provider_configuration = provider_configurations.get(provider)
+        if not provider_configuration:
+            raise ValueError(f"Provider {provider} does not exist.")
+
+        # Convert model type to ModelType
+        model_type = ModelType.value_of(model_type)
+
+        if not isinstance(configs, list):
+            raise ValueError('Invalid load balancing configs')
+
+        current_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
+            .filter(
+            LoadBalancingModelConfig.tenant_id == tenant_id,
+            LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
+            LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+            LoadBalancingModelConfig.model_name == model
+        ).all()
+
+        # id as key, config as value
+        current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}
+        updated_config_ids = set()
+
+        for config in configs:
+            if not isinstance(config, dict):
+                raise ValueError('Invalid load balancing config')
+
+            config_id = config.get('id')
+            name = config.get('name')
+            credentials = config.get('credentials')
+            enabled = config.get('enabled')
+
+            if not name:
+                raise ValueError('Invalid load balancing config name')
+
+            if enabled is None:
+                raise ValueError('Invalid load balancing config enabled')
+
+            # is config exists
+            if config_id:
+                config_id = str(config_id)
+
+                if config_id not in current_load_balancing_configs_dict:
+                    raise ValueError('Invalid load balancing config id: {}'.format(config_id))
+
+                updated_config_ids.add(config_id)
+
+                load_balancing_config = current_load_balancing_configs_dict[config_id]
+
+                # check duplicate name
+                for current_load_balancing_config in current_load_balancing_configs:
+                    if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name:
+                        raise ValueError('Load balancing config name {} already exists'.format(name))
+
+                if credentials:
+                    if not isinstance(credentials, dict):
+                        raise ValueError('Invalid load balancing config credentials')
+
+                    # validate custom provider config
+                    credentials = self._custom_credentials_validate(
+                        tenant_id=tenant_id,
+                        provider_configuration=provider_configuration,
+                        model_type=model_type,
+                        model=model,
+                        credentials=credentials,
+                        load_balancing_model_config=load_balancing_config,
+                        validate=False
+                    )
+
+                    # update load balancing config
+                    load_balancing_config.encrypted_config = json.dumps(credentials)
+
+                load_balancing_config.name = name
+                load_balancing_config.enabled = enabled
+                load_balancing_config.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+                db.session.commit()
+
+                self._clear_credentials_cache(tenant_id, config_id)
+            else:
+                # create load balancing config
+                if name == '__inherit__':
+                    raise ValueError('Invalid load balancing config name')
+
+                # check duplicate name
+                for current_load_balancing_config in current_load_balancing_configs:
+                    if current_load_balancing_config.name == name:
+                        raise ValueError('Load balancing config name {} already exists'.format(name))
+
+                if not credentials:
+                    raise ValueError('Invalid load balancing config credentials')
+
+                if not isinstance(credentials, dict):
+                    raise ValueError('Invalid load balancing config credentials')
+
+                # validate custom provider config
+                credentials = self._custom_credentials_validate(
+                    tenant_id=tenant_id,
+                    provider_configuration=provider_configuration,
+                    model_type=model_type,
+                    model=model,
+                    credentials=credentials,
+                    validate=False
+                )
+
+                # create load balancing config
+                load_balancing_model_config = LoadBalancingModelConfig(
+                    tenant_id=tenant_id,
+                    provider_name=provider_configuration.provider.provider,
+                    model_type=model_type.to_origin_model_type(),
+                    model_name=model,
+                    name=name,
+                    encrypted_config=json.dumps(credentials)
+                )
+
+                db.session.add(load_balancing_model_config)
+                db.session.commit()
+
+        # get deleted config ids
+        deleted_config_ids = set(current_load_balancing_configs_dict.keys()) - updated_config_ids
+        for config_id in deleted_config_ids:
+            db.session.delete(current_load_balancing_configs_dict[config_id])
+            db.session.commit()
+
+            self._clear_credentials_cache(tenant_id, config_id)
+
+    def validate_load_balancing_credentials(self, tenant_id: str,
+                                            provider: str,
+                                            model: str,
+                                            model_type: str,
+                                            credentials: dict,
+                                            config_id: Optional[str] = None) -> None:
+        """
+        Validate load balancing credentials.
+        :param tenant_id: workspace id
+        :param provider: provider name
+        :param model_type: model type
+        :param model: model name
+        :param credentials: credentials
+        :param config_id: load balancing config id
+        :return:
+        """
+        # Get all provider configurations of the current workspace
+        provider_configurations = self.provider_manager.get_configurations(tenant_id)
+
+        # Get provider configuration
+        provider_configuration = provider_configurations.get(provider)
+        if not provider_configuration:
+            raise ValueError(f"Provider {provider} does not exist.")
+
+        # Convert model type to ModelType
+        model_type = ModelType.value_of(model_type)
+
+        load_balancing_model_config = None
+        if config_id:
+            # Get load balancing config
+            load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \
+                .filter(
+                LoadBalancingModelConfig.tenant_id == tenant_id,
+                LoadBalancingModelConfig.provider_name == provider,
+                LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+                LoadBalancingModelConfig.model_name == model,
+                LoadBalancingModelConfig.id == config_id
+            ).first()
+
+            if not load_balancing_model_config:
+                raise ValueError(f"Load balancing config {config_id} does not exist.")
+
+        # Validate custom provider config
+        self._custom_credentials_validate(
+            tenant_id=tenant_id,
+            provider_configuration=provider_configuration,
+            model_type=model_type,
+            model=model,
+            credentials=credentials,
+            load_balancing_model_config=load_balancing_model_config
+        )
+
+    def _custom_credentials_validate(self, tenant_id: str,
+                                     provider_configuration: ProviderConfiguration,
+                                     model_type: ModelType,
+                                     model: str,
+                                     credentials: dict,
+                                     load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
+                                     validate: bool = True) -> dict:
+        """
+        Validate custom credentials.
+        :param tenant_id: workspace id
+        :param provider_configuration: provider configuration
+        :param model_type: model type
+        :param model: model name
+        :param credentials: credentials
+        :param load_balancing_model_config: load balancing model config
+        :param validate: validate credentials
+        :return:
+        """
+        # Get credential form schemas from model credential schema or provider credential schema
+        credential_schemas = self._get_credential_schema(provider_configuration)
+
+        # Get provider credential secret variables
+        provider_credential_secret_variables = provider_configuration.extract_secret_variables(
+            credential_schemas.credential_form_schemas
+        )
+
+        if load_balancing_model_config:
+            try:
+                # fix origin data
+                if load_balancing_model_config.encrypted_config:
+                    original_credentials = json.loads(load_balancing_model_config.encrypted_config)
+                else:
+                    original_credentials = {}
+            except JSONDecodeError:
+                original_credentials = {}
+
+            # encrypt credentials
+            for key, value in credentials.items():
+                if key in provider_credential_secret_variables:
+                    # if send [__HIDDEN__] in secret input, it will be same as original value
+                    if value == '[__HIDDEN__]' and key in original_credentials:
+                        credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key])
+
+        if validate:
+            if isinstance(credential_schemas, ModelCredentialSchema):
+                credentials = model_provider_factory.model_credentials_validate(
+                    provider=provider_configuration.provider.provider,
+                    model_type=model_type,
+                    model=model,
+                    credentials=credentials
+                )
+            else:
+                credentials = model_provider_factory.provider_credentials_validate(
+                    provider=provider_configuration.provider.provider,
+                    credentials=credentials
+                )
+
+        for key, value in credentials.items():
+            if key in provider_credential_secret_variables:
+                credentials[key] = encrypter.encrypt_token(tenant_id, value)
+
+        return credentials
+
+    def _get_credential_schema(self, provider_configuration: ProviderConfiguration) \
+            -> ModelCredentialSchema | ProviderCredentialSchema:
+        """
+        Get form schemas.
+        :param provider_configuration: provider configuration
+        :return:
+        """
+        # Get credential form schemas from model credential schema or provider credential schema
+        if provider_configuration.provider.model_credential_schema:
+            credential_schema = provider_configuration.provider.model_credential_schema
+        else:
+            credential_schema = provider_configuration.provider.provider_credential_schema
+
+        return credential_schema
+
+    def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None:
+        """
+        Clear credentials cache.
+        :param tenant_id: workspace id
+        :param config_id: load balancing config id
+        :return:
+        """
+        provider_model_credentials_cache = ProviderCredentialsCache(
+            tenant_id=tenant_id,
+            identity_id=config_id,
+            cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
+        )
+
+        provider_model_credentials_cache.delete()

+ 56 - 8
api/services/model_provider_service.py

@@ -6,7 +6,7 @@ from typing import Optional, cast
 import requests
 from flask import current_app
 
-from core.entities.model_entities import ModelStatus
+from core.entities.model_entities import ModelStatus, ProviderModelWithStatusEntity
 from core.model_runtime.entities.model_entities import ModelType, ParameterRule
 from core.model_runtime.model_providers import model_provider_factory
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -16,7 +16,6 @@ from services.entities.model_provider_entities import (
     CustomConfigurationResponse,
     CustomConfigurationStatus,
     DefaultModelResponse,
-    ModelResponse,
     ModelWithProviderEntityResponse,
     ProviderResponse,
     ProviderWithModelsResponse,
@@ -303,6 +302,9 @@ class ModelProviderService:
             if model.deprecated:
                 continue
 
+            if model.status != ModelStatus.ACTIVE:
+                continue
+
             provider_models[model.provider.provider].append(model)
 
         # convert to ProviderWithModelsResponse list
@@ -313,24 +315,22 @@ class ModelProviderService:
 
             first_model = models[0]
 
-            has_active_models = any([model.status == ModelStatus.ACTIVE for model in models])
-
             providers_with_models.append(
                 ProviderWithModelsResponse(
                     provider=provider,
                     label=first_model.provider.label,
                     icon_small=first_model.provider.icon_small,
                     icon_large=first_model.provider.icon_large,
-                    status=CustomConfigurationStatus.ACTIVE
-                    if has_active_models else CustomConfigurationStatus.NO_CONFIGURE,
-                    models=[ModelResponse(
+                    status=CustomConfigurationStatus.ACTIVE,
+                    models=[ProviderModelWithStatusEntity(
                         model=model.model,
                         label=model.label,
                         model_type=model.model_type,
                         features=model.features,
                         fetch_from=model.fetch_from,
                         model_properties=model.model_properties,
-                        status=model.status
+                        status=model.status,
+                        load_balancing_enabled=model.load_balancing_enabled
                     ) for model in models]
                 )
             )
@@ -486,6 +486,54 @@ class ModelProviderService:
         # Switch preferred provider type
         provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum)
 
+    def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
+        """
+        enable model.
+
+        :param tenant_id: workspace id
+        :param provider: provider name
+        :param model: model name
+        :param model_type: model type
+        :return:
+        """
+        # Get all provider configurations of the current workspace
+        provider_configurations = self.provider_manager.get_configurations(tenant_id)
+
+        # Get provider configuration
+        provider_configuration = provider_configurations.get(provider)
+        if not provider_configuration:
+            raise ValueError(f"Provider {provider} does not exist.")
+
+        # Enable model
+        provider_configuration.enable_model(
+            model=model,
+            model_type=ModelType.value_of(model_type)
+        )
+
+    def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
+        """
+        disable model.
+
+        :param tenant_id: workspace id
+        :param provider: provider name
+        :param model: model name
+        :param model_type: model type
+        :return:
+        """
+        # Get all provider configurations of the current workspace
+        provider_configurations = self.provider_manager.get_configurations(tenant_id)
+
+        # Get provider configuration
+        provider_configuration = provider_configurations.get(provider)
+        if not provider_configuration:
+            raise ValueError(f"Provider {provider} does not exist.")
+
+        # Enable model
+        provider_configuration.disable_model(
+            model=model,
+            model_type=ModelType.value_of(model_type)
+        )
+
     def free_quota_submit(self, tenant_id: str, provider: str):
         api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
         api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")

+ 1 - 1
api/services/workflow_service.py

@@ -68,7 +68,7 @@ class WorkflowService:
                             account: Account) -> Workflow:
         """
         Sync draft workflow
-        @throws WorkflowHashNotEqualError
+        :raises WorkflowHashNotEqualError
         """
         # fetch draft workflow by app_model
         workflow = self.get_draft_workflow(app_model=app_model)

+ 1 - 7
api/tasks/batch_create_segment_to_index_task.py

@@ -2,7 +2,6 @@ import datetime
 import logging
 import time
 import uuid
-from typing import cast
 
 import click
 from celery import shared_task
@@ -11,7 +10,6 @@ from sqlalchemy import func
 from core.indexing_runner import IndexingRunner
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs import helper
@@ -59,16 +57,12 @@ def batch_create_segment_to_index_task(job_id: str, content: list, dataset_id: s
                 model=dataset.embedding_model
             )
 
-        model_type_instance = embedding_model.model_type_instance
-        model_type_instance = cast(TextEmbeddingModel, model_type_instance)
         for segment in content:
             content = segment['content']
             doc_id = str(uuid.uuid4())
             segment_hash = helper.generate_text_hash(content)
             # calc embedding use tokens
-            tokens = model_type_instance.get_num_tokens(
-                model=embedding_model.model,
-                credentials=embedding_model.credentials,
+            tokens = embedding_model.get_text_embedding_num_tokens(
                 texts=[content]
             ) if embedding_model else 0
             max_position = db.session.query(func.max(DocumentSegment.position)).filter(

+ 1 - 1
api/tests/integration_tests/utils/test_module_import_helper.py

@@ -1,6 +1,6 @@
 import os
 
-from core.utils.module_import_helper import import_module_from_source, load_single_subclass_from_source
+from core.helper.module_import_helper import import_module_from_source, load_single_subclass_from_source
 from tests.integration_tests.utils.parent_class import ParentClass
 
 

+ 5 - 3
api/tests/integration_tests/workflow/nodes/test_llm.py

@@ -92,7 +92,8 @@ def test_execute_llm(setup_openai_mock):
                 provider=CustomProviderConfiguration(
                     credentials=credentials
                 )
-            )
+            ),
+            model_settings=[]
         ),
         provider_instance=provider_instance,
         model_type_instance=model_type_instance
@@ -206,10 +207,11 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
                 provider=CustomProviderConfiguration(
                     credentials=credentials
                 )
-            )
+            ),
+            model_settings=[]
         ),
         provider_instance=provider_instance,
-        model_type_instance=model_type_instance
+        model_type_instance=model_type_instance,
     )
 
     model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo')

+ 2 - 1
api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py

@@ -42,7 +42,8 @@ def get_mocked_fetch_model_config(
                 provider=CustomProviderConfiguration(
                     credentials=credentials
                 )
-            )
+            ),
+            model_settings=[]
         ),
         provider_instance=provider_instance,
         model_type_instance=model_type_instance

+ 10 - 1
api/tests/unit_tests/core/prompt/test_prompt_transform.py

@@ -1,9 +1,10 @@
 from unittest.mock import MagicMock
 
 from core.app.app_config.entities import ModelConfigEntity
-from core.entities.provider_configuration import ProviderModelBundle
+from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
 from core.model_runtime.entities.message_entities import UserPromptMessage
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule
+from core.model_runtime.entities.provider_entities import ProviderEntity
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.prompt_transform import PromptTransform
 
@@ -22,8 +23,16 @@ def test__calculate_rest_token():
     large_language_model_mock = MagicMock(spec=LargeLanguageModel)
     large_language_model_mock.get_num_tokens.return_value = 6
 
+    provider_mock = MagicMock(spec=ProviderEntity)
+    provider_mock.provider = 'openai'
+
+    provider_configuration_mock = MagicMock(spec=ProviderConfiguration)
+    provider_configuration_mock.provider = provider_mock
+    provider_configuration_mock.model_settings = None
+
     provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
     provider_model_bundle_mock.model_type_instance = large_language_model_mock
+    provider_model_bundle_mock.configuration = provider_configuration_mock
 
     model_config_mock = MagicMock(spec=ModelConfigEntity)
     model_config_mock.model = 'gpt-4'

+ 77 - 0
api/tests/unit_tests/core/test_model_manager.py

@@ -0,0 +1,77 @@
+from unittest.mock import MagicMock
+
+import pytest
+
+from core.entities.provider_entities import ModelLoadBalancingConfiguration
+from core.model_manager import LBModelManager
+from core.model_runtime.entities.model_entities import ModelType
+
+
+@pytest.fixture
+def lb_model_manager():
+    load_balancing_configs = [
+        ModelLoadBalancingConfiguration(
+            id='id1',
+            name='__inherit__',
+            credentials={}
+        ),
+        ModelLoadBalancingConfiguration(
+            id='id2',
+            name='first',
+            credentials={"openai_api_key": "fake_key"}
+        ),
+        ModelLoadBalancingConfiguration(
+            id='id3',
+            name='second',
+            credentials={"openai_api_key": "fake_key"}
+        )
+    ]
+
+    lb_model_manager = LBModelManager(
+        tenant_id='tenant_id',
+        provider='openai',
+        model_type=ModelType.LLM,
+        model='gpt-4',
+        load_balancing_configs=load_balancing_configs,
+        managed_credentials={"openai_api_key": "fake_key"}
+    )
+
+    lb_model_manager.cooldown = MagicMock(return_value=None)
+
+    def is_cooldown(config: ModelLoadBalancingConfiguration):
+        if config.id == 'id1':
+            return True
+
+        return False
+
+    lb_model_manager.in_cooldown = MagicMock(side_effect=is_cooldown)
+
+    return lb_model_manager
+
+
+def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
+    assert len(lb_model_manager._load_balancing_configs) == 3
+
+    config1 = lb_model_manager._load_balancing_configs[0]
+    config2 = lb_model_manager._load_balancing_configs[1]
+    config3 = lb_model_manager._load_balancing_configs[2]
+
+    assert lb_model_manager.in_cooldown(config1) is True
+    assert lb_model_manager.in_cooldown(config2) is False
+    assert lb_model_manager.in_cooldown(config3) is False
+
+    start_index = 0
+    def incr(key):
+        nonlocal start_index
+        start_index += 1
+        return start_index
+
+    mocker.patch('redis.Redis.incr', side_effect=incr)
+    mocker.patch('redis.Redis.set', return_value=None)
+    mocker.patch('redis.Redis.expire', return_value=None)
+
+    config = lb_model_manager.fetch_next()
+    assert config == config2
+
+    config = lb_model_manager.fetch_next()
+    assert config == config3

+ 183 - 0
api/tests/unit_tests/core/test_provider_manager.py

@@ -0,0 +1,183 @@
+from core.entities.provider_entities import ModelSettings
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.model_providers import model_provider_factory
+from core.provider_manager import ProviderManager
+from models.provider import LoadBalancingModelConfig, ProviderModelSetting
+
+
+def test__to_model_settings(mocker):
+    # Get all provider entities
+    provider_entities = model_provider_factory.get_providers()
+
+    provider_entity = None
+    for provider in provider_entities:
+        if provider.provider == 'openai':
+            provider_entity = provider
+
+    # Mocking the inputs
+    provider_model_settings = [ProviderModelSetting(
+        id='id',
+        tenant_id='tenant_id',
+        provider_name='openai',
+        model_name='gpt-4',
+        model_type='text-generation',
+        enabled=True,
+        load_balancing_enabled=True
+    )]
+    load_balancing_model_configs = [
+        LoadBalancingModelConfig(
+            id='id1',
+            tenant_id='tenant_id',
+            provider_name='openai',
+            model_name='gpt-4',
+            model_type='text-generation',
+            name='__inherit__',
+            encrypted_config=None,
+            enabled=True
+        ),
+        LoadBalancingModelConfig(
+            id='id2',
+            tenant_id='tenant_id',
+            provider_name='openai',
+            model_name='gpt-4',
+            model_type='text-generation',
+            name='first',
+            encrypted_config='{"openai_api_key": "fake_key"}',
+            enabled=True
+        )
+    ]
+
+    mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"})
+
+    provider_manager = ProviderManager()
+
+    # Running the method
+    result = provider_manager._to_model_settings(
+        provider_entity,
+        provider_model_settings,
+        load_balancing_model_configs
+    )
+
+    # Asserting that the result is as expected
+    assert len(result) == 1
+    assert isinstance(result[0], ModelSettings)
+    assert result[0].model == 'gpt-4'
+    assert result[0].model_type == ModelType.LLM
+    assert result[0].enabled is True
+    assert len(result[0].load_balancing_configs) == 2
+    assert result[0].load_balancing_configs[0].name == '__inherit__'
+    assert result[0].load_balancing_configs[1].name == 'first'
+
+
+def test__to_model_settings_only_one_lb(mocker):
+    # Get all provider entities
+    provider_entities = model_provider_factory.get_providers()
+
+    provider_entity = None
+    for provider in provider_entities:
+        if provider.provider == 'openai':
+            provider_entity = provider
+
+    # Mocking the inputs
+    provider_model_settings = [ProviderModelSetting(
+        id='id',
+        tenant_id='tenant_id',
+        provider_name='openai',
+        model_name='gpt-4',
+        model_type='text-generation',
+        enabled=True,
+        load_balancing_enabled=True
+    )]
+    load_balancing_model_configs = [
+        LoadBalancingModelConfig(
+            id='id1',
+            tenant_id='tenant_id',
+            provider_name='openai',
+            model_name='gpt-4',
+            model_type='text-generation',
+            name='__inherit__',
+            encrypted_config=None,
+            enabled=True
+        )
+    ]
+
+    mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"})
+
+    provider_manager = ProviderManager()
+
+    # Running the method
+    result = provider_manager._to_model_settings(
+        provider_entity,
+        provider_model_settings,
+        load_balancing_model_configs
+    )
+
+    # Asserting that the result is as expected
+    assert len(result) == 1
+    assert isinstance(result[0], ModelSettings)
+    assert result[0].model == 'gpt-4'
+    assert result[0].model_type == ModelType.LLM
+    assert result[0].enabled is True
+    assert len(result[0].load_balancing_configs) == 0
+
+
+def test__to_model_settings_lb_disabled(mocker):
+    # Get all provider entities
+    provider_entities = model_provider_factory.get_providers()
+
+    provider_entity = None
+    for provider in provider_entities:
+        if provider.provider == 'openai':
+            provider_entity = provider
+
+    # Mocking the inputs
+    provider_model_settings = [ProviderModelSetting(
+        id='id',
+        tenant_id='tenant_id',
+        provider_name='openai',
+        model_name='gpt-4',
+        model_type='text-generation',
+        enabled=True,
+        load_balancing_enabled=False
+    )]
+    load_balancing_model_configs = [
+        LoadBalancingModelConfig(
+            id='id1',
+            tenant_id='tenant_id',
+            provider_name='openai',
+            model_name='gpt-4',
+            model_type='text-generation',
+            name='__inherit__',
+            encrypted_config=None,
+            enabled=True
+        ),
+        LoadBalancingModelConfig(
+            id='id2',
+            tenant_id='tenant_id',
+            provider_name='openai',
+            model_name='gpt-4',
+            model_type='text-generation',
+            name='first',
+            encrypted_config='{"openai_api_key": "fake_key"}',
+            enabled=True
+        )
+    ]
+
+    mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"})
+
+    provider_manager = ProviderManager()
+
+    # Running the method
+    result = provider_manager._to_model_settings(
+        provider_entity,
+        provider_model_settings,
+        load_balancing_model_configs
+    )
+
+    # Asserting that the result is as expected
+    assert len(result) == 1
+    assert isinstance(result[0], ModelSettings)
+    assert result[0].model == 'gpt-4'
+    assert result[0].model_type == ModelType.LLM
+    assert result[0].enabled is True
+    assert len(result[0].load_balancing_configs) == 0

+ 1 - 1
api/tests/unit_tests/utils/position_helper/test_position_helper.py

@@ -2,7 +2,7 @@ from textwrap import dedent
 
 import pytest
 
-from core.utils.position_helper import get_position_map
+from core.helper.position_helper import get_position_map
 
 
 @pytest.fixture