Forráskód Böngészése

feat: server multi models support (#799)

takatost 1 éve
szülő
commit
5fa2161b05
100 módosított fájl, 3368 hozzáadás és 1699 törlés
  1. 2 1
      .github/workflows/check_no_chinese_comments.py
  2. 26 0
      api/.env.example
  3. 17 2
      api/app.py
  4. 24 13
      api/commands.py
  5. 39 13
      api/config.py
  6. 4 1
      api/controllers/console/__init__.py
  7. 21 5
      api/controllers/console/app/app.py
  8. 1 1
      api/controllers/console/app/audio.py
  9. 9 3
      api/controllers/console/app/completion.py
  10. 1 1
      api/controllers/console/app/generator.py
  11. 1 1
      api/controllers/console/app/message.py
  12. 2 2
      api/controllers/console/app/model_config.py
  13. 1 1
      api/controllers/console/datasets/data_source.py
  14. 29 3
      api/controllers/console/datasets/datasets.py
  15. 43 5
      api/controllers/console/datasets/datasets_document.py
  16. 3 1
      api/controllers/console/datasets/hit_testing.py
  17. 1 1
      api/controllers/console/explore/audio.py
  18. 1 1
      api/controllers/console/explore/completion.py
  19. 1 1
      api/controllers/console/explore/message.py
  20. 1 4
      api/controllers/console/explore/parameter.py
  21. 1 1
      api/controllers/console/universal_chat/audio.py
  22. 3 7
      api/controllers/console/universal_chat/chat.py
  23. 1 1
      api/controllers/console/universal_chat/message.py
  24. 1 4
      api/controllers/console/universal_chat/parameter.py
  25. 0 0
      api/controllers/console/webhook/__init__.py
  26. 53 0
      api/controllers/console/webhook/stripe.py
  27. 203 193
      api/controllers/console/workspace/model_providers.py
  28. 108 0
      api/controllers/console/workspace/models.py
  29. 130 0
      api/controllers/console/workspace/providers.py
  30. 1 1
      api/controllers/console/workspace/workspace.py
  31. 1 4
      api/controllers/service_api/app/app.py
  32. 1 1
      api/controllers/service_api/app/audio.py
  33. 1 1
      api/controllers/service_api/app/completion.py
  34. 1 1
      api/controllers/service_api/dataset/document.py
  35. 1 4
      api/controllers/web/app.py
  36. 1 1
      api/controllers/web/audio.py
  37. 1 1
      api/controllers/web/completion.py
  38. 1 1
      api/controllers/web/message.py
  39. 0 36
      api/core/__init__.py
  40. 9 13
      api/core/agent/agent/calc_token_mixin.py
  41. 9 1
      api/core/agent/agent/multi_dataset_router_agent.py
  42. 3 2
      api/core/agent/agent/openai_function_call.py
  43. 11 3
      api/core/agent/agent/openai_function_call_summarize_mixin.py
  44. 3 2
      api/core/agent/agent/openai_multi_function_call.py
  45. 162 0
      api/core/agent/agent/structed_multi_dataset_router_agent.py
  46. 8 2
      api/core/agent/agent/structured_chat.py
  47. 25 11
      api/core/agent/agent_executor.py
  48. 5 4
      api/core/callback_handler/agent_loop_gather_callback_handler.py
  49. 11 7
      api/core/callback_handler/llm_callback_handler.py
  50. 0 2
      api/core/callback_handler/main_chain_gather_callback_handler.py
  51. 86 111
      api/core/completion.py
  52. 0 109
      api/core/constant/llm_constant.py
  53. 22 37
      api/core/conversation_message_task.py
  54. 6 8
      api/core/docstore/dataset_docstore.py
  55. 33 25
      api/core/embedding/cached_embedding.py
  56. 61 83
      api/core/generator/llm_generator.py
  57. 0 0
      api/core/helper/__init__.py
  58. 20 0
      api/core/helper/encrypter.py
  59. 4 10
      api/core/index/index.py
  60. 46 43
      api/core/indexing_runner.py
  61. 0 148
      api/core/llm/llm_builder.py
  62. 0 15
      api/core/llm/moderation.py
  63. 0 138
      api/core/llm/provider/anthropic_provider.py
  64. 0 145
      api/core/llm/provider/azure_provider.py
  65. 0 132
      api/core/llm/provider/base.py
  66. 0 2
      api/core/llm/provider/errors.py
  67. 0 22
      api/core/llm/provider/huggingface_provider.py
  68. 0 53
      api/core/llm/provider/llm_provider_service.py
  69. 0 55
      api/core/llm/provider/openai_provider.py
  70. 0 62
      api/core/llm/streamable_chat_anthropic.py
  71. 0 41
      api/core/llm/token_calculator.py
  72. 0 26
      api/core/llm/whisper.py
  73. 0 27
      api/core/llm/wrappers/anthropic_wrapper.py
  74. 0 31
      api/core/llm/wrappers/openai_wrapper.py
  75. 12 12
      api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py
  76. 0 0
      api/core/model_providers/error.py
  77. 293 0
      api/core/model_providers/model_factory.py
  78. 228 0
      api/core/model_providers/model_provider_factory.py
  79. 0 0
      api/core/model_providers/models/__init__.py
  80. 22 0
      api/core/model_providers/models/base.py
  81. 0 0
      api/core/model_providers/models/embedding/__init__.py
  82. 78 0
      api/core/model_providers/models/embedding/azure_openai_embedding.py
  83. 40 0
      api/core/model_providers/models/embedding/base.py
  84. 35 0
      api/core/model_providers/models/embedding/minimax_embedding.py
  85. 72 0
      api/core/model_providers/models/embedding/openai_embedding.py
  86. 36 0
      api/core/model_providers/models/embedding/replicate_embedding.py
  87. 0 0
      api/core/model_providers/models/entity/__init__.py
  88. 53 0
      api/core/model_providers/models/entity/message.py
  89. 59 0
      api/core/model_providers/models/entity/model_params.py
  90. 10 0
      api/core/model_providers/models/entity/provider.py
  91. 0 0
      api/core/model_providers/models/llm/__init__.py
  92. 107 0
      api/core/model_providers/models/llm/anthropic_model.py
  93. 177 0
      api/core/model_providers/models/llm/azure_openai_model.py
  94. 269 0
      api/core/model_providers/models/llm/base.py
  95. 70 0
      api/core/model_providers/models/llm/chatglm_model.py
  96. 82 0
      api/core/model_providers/models/llm/huggingface_hub_model.py
  97. 70 0
      api/core/model_providers/models/llm/minimax_model.py
  98. 219 0
      api/core/model_providers/models/llm/openai_model.py
  99. 103 0
      api/core/model_providers/models/llm/replicate_model.py
  100. 73 0
      api/core/model_providers/models/llm/spark_model.py

+ 2 - 1
.github/workflows/check_no_chinese_comments.py

@@ -19,7 +19,8 @@ def check_file_for_chinese_comments(file_path):
 
 def main():
     has_chinese = False
-    excluded_files = ["model_template.py", 'stopwords.py', 'commands.py', 'indexing_runner.py', 'web_reader_tool.py']
+    excluded_files = ["model_template.py", 'stopwords.py', 'commands.py',
+                      'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py']
 
     for root, _, files in os.walk("."):
         for file in files:

+ 26 - 0
api/.env.example

@@ -102,3 +102,29 @@ NOTION_INTEGRATION_TYPE=public
 NOTION_CLIENT_SECRET=you-client-secret
 NOTION_CLIENT_ID=you-client-id
 NOTION_INTERNAL_SECRET=you-internal-secret
+
+# Hosted Model Credentials
+HOSTED_OPENAI_ENABLED=false
+HOSTED_OPENAI_API_KEY=
+HOSTED_OPENAI_API_BASE=
+HOSTED_OPENAI_API_ORGANIZATION=
+HOSTED_OPENAI_QUOTA_LIMIT=200
+HOSTED_OPENAI_PAID_ENABLED=false
+HOSTED_OPENAI_PAID_STRIPE_PRICE_ID=
+HOSTED_OPENAI_PAID_INCREASE_QUOTA=1
+
+HOSTED_AZURE_OPENAI_ENABLED=false
+HOSTED_AZURE_OPENAI_API_KEY=
+HOSTED_AZURE_OPENAI_API_BASE=
+HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200
+
+HOSTED_ANTHROPIC_ENABLED=false
+HOSTED_ANTHROPIC_API_BASE=
+HOSTED_ANTHROPIC_API_KEY=
+HOSTED_ANTHROPIC_QUOTA_LIMIT=1000000
+HOSTED_ANTHROPIC_PAID_ENABLED=false
+HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
+HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1
+
+STRIPE_API_KEY=
+STRIPE_WEBHOOK_SECRET=

+ 17 - 2
api/app.py

@@ -16,8 +16,9 @@ from flask import Flask, request, Response, session
 import flask_login
 from flask_cors import CORS
 
+from core.model_providers.providers import hosted
 from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
-    ext_database, ext_storage, ext_mail
+    ext_database, ext_storage, ext_mail, ext_stripe
 from extensions.ext_database import db
 from extensions.ext_login import login_manager
 
@@ -71,7 +72,7 @@ def create_app(test_config=None) -> Flask:
     register_blueprints(app)
     register_commands(app)
 
-    core.init_app(app)
+    hosted.init_app(app)
 
     return app
 
@@ -88,6 +89,7 @@ def initialize_extensions(app):
     ext_login.init_app(app)
     ext_mail.init_app(app)
     ext_sentry.init_app(app)
+    ext_stripe.init_app(app)
 
 
 def _create_tenant_for_account(account):
@@ -246,5 +248,18 @@ def threads():
     }
 
 
+@app.route('/db-pool-stat')
+def pool_stat():
+    engine = db.engine
+    return {
+        'pool_size': engine.pool.size(),
+        'checked_in_connections': engine.pool.checkedin(),
+        'checked_out_connections': engine.pool.checkedout(),
+        'overflow_connections': engine.pool.overflow(),
+        'connection_timeout': engine.pool.timeout(),
+        'recycle_time': db.engine.pool._recycle
+    }
+
+
 if __name__ == '__main__':
     app.run(host='0.0.0.0', port=5001)

+ 24 - 13
api/commands.py

@@ -1,5 +1,5 @@
 import datetime
-import logging
+import math
 import random
 import string
 import time
@@ -9,18 +9,18 @@ from flask import current_app
 from werkzeug.exceptions import NotFound
 
 from core.index.index import IndexBuilder
+from core.model_providers.providers.hosted import hosted_model_providers
 from libs.password import password_pattern, valid_password, hash_password
 from libs.helper import email as email_validate
 from extensions.ext_database import db
 from libs.rsa import generate_key_pair
 from models.account import InvitationCode, Tenant
-from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment
+from models.dataset import Dataset, DatasetQuery, Document
 from models.model import Account
 import secrets
 import base64
 
-from models.provider import Provider, ProviderName
-from services.provider_service import ProviderService
+from models.provider import Provider, ProviderType, ProviderQuotaType
 
 
 @click.command('reset-password', help='Reset the account password.')
@@ -251,26 +251,37 @@ def clean_unused_dataset_indexes():
 
 @click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
 def sync_anthropic_hosted_providers():
+    if not hosted_model_providers.anthropic:
+        click.echo(click.style('Anthropic hosted provider is not configured.', fg='red'))
+        return
+
     click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
     count = 0
 
     page = 1
     while True:
         try:
-            tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
+            providers = db.session.query(Provider).filter(
+                Provider.provider_name == 'anthropic',
+                Provider.provider_type == ProviderType.SYSTEM.value,
+                Provider.quota_type == ProviderQuotaType.TRIAL.value,
+            ).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
         except NotFound:
             break
 
         page += 1
-        for tenant in tenants:
+        for provider in providers:
             try:
-                click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
-                ProviderService.create_system_provider(
-                    tenant,
-                    ProviderName.ANTHROPIC.value,
-                    current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
-                    True
-                )
+                click.echo('Syncing tenant anthropic hosted provider: {}'.format(provider.tenant_id))
+                original_quota_limit = provider.quota_limit
+                new_quota_limit = hosted_model_providers.anthropic.quota_limit
+                division = math.ceil(new_quota_limit / 1000)
+
+                provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \
+                    else original_quota_limit * division
+                provider.quota_used = division * provider.quota_used
+                db.session.commit()
+
                 count += 1
             except Exception as e:
                 click.echo(click.style(

+ 39 - 13
api/config.py

@@ -41,6 +41,7 @@ DEFAULTS = {
     'SESSION_USE_SIGNER': 'True',
     'DEPLOY_ENV': 'PRODUCTION',
     'SQLALCHEMY_POOL_SIZE': 30,
+    'SQLALCHEMY_POOL_RECYCLE': 3600,
     'SQLALCHEMY_ECHO': 'False',
     'SENTRY_TRACES_SAMPLE_RATE': 1.0,
     'SENTRY_PROFILES_SAMPLE_RATE': 1.0,
@@ -50,9 +51,16 @@ DEFAULTS = {
     'PDF_PREVIEW': 'True',
     'LOG_LEVEL': 'INFO',
     'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
-    'DEFAULT_LLM_PROVIDER': 'openai',
-    'OPENAI_HOSTED_QUOTA_LIMIT': 200,
-    'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
+    'HOSTED_OPENAI_QUOTA_LIMIT': 200,
+    'HOSTED_OPENAI_ENABLED': 'False',
+    'HOSTED_OPENAI_PAID_ENABLED': 'False',
+    'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
+    'HOSTED_AZURE_OPENAI_ENABLED': 'False',
+    'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
+    'HOSTED_ANTHROPIC_QUOTA_LIMIT': 1000000,
+    'HOSTED_ANTHROPIC_ENABLED': 'False',
+    'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
+    'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1,
     'TENANT_DOCUMENT_COUNT': 100,
     'CLEAN_DAY_SETTING': 30
 }
@@ -182,7 +190,10 @@ class Config:
         }
 
         self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}"
-        self.SQLALCHEMY_ENGINE_OPTIONS = {'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE'))}
+        self.SQLALCHEMY_ENGINE_OPTIONS = {
+            'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')),
+            'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE'))
+        }
 
         self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO')
 
@@ -194,20 +205,35 @@ class Config:
         self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')
 
         # hosted provider credentials
-        self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
-        self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')
-
-        self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
-        self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')
+        self.HOSTED_OPENAI_ENABLED = get_bool_env('HOSTED_OPENAI_ENABLED')
+        self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
+        self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
+        self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
+        self.HOSTED_OPENAI_QUOTA_LIMIT = get_env('HOSTED_OPENAI_QUOTA_LIMIT')
+        self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
+        self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
+        self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
+
+        self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
+        self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
+        self.HOSTED_AZURE_OPENAI_API_BASE = get_env('HOSTED_AZURE_OPENAI_API_BASE')
+        self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT')
+
+        self.HOSTED_ANTHROPIC_ENABLED = get_bool_env('HOSTED_ANTHROPIC_ENABLED')
+        self.HOSTED_ANTHROPIC_API_BASE = get_env('HOSTED_ANTHROPIC_API_BASE')
+        self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
+        self.HOSTED_ANTHROPIC_QUOTA_LIMIT = get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')
+        self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
+        self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
+        self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA')
+
+        self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
+        self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
 
         # By default it is False
         # You could disable it for compatibility with certain OpenAPI providers
         self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')
 
-        # For temp use only
-        # set default LLM provider, default is 'openai', support `azure_openai`
-        self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
-
         # notion import setting
         self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
         self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')

+ 4 - 1
api/controllers/console/__init__.py

@@ -18,10 +18,13 @@ from .auth import login, oauth, data_source_oauth, activate
 from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
 
 # Import workspace controllers
-from .workspace import workspace, members, model_providers, account, tool_providers
+from .workspace import workspace, members, providers, model_providers, account, tool_providers, models
 
 # Import explore controllers
 from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio
 
 # Import universal chat controllers
 from .universal_chat import chat, conversation, message, parameter, audio
+
+# Import webhook controllers
+from .webhook import stripe

+ 21 - 5
api/controllers/console/app/app.py

@@ -2,16 +2,17 @@
 import json
 from datetime import datetime
 
-import flask
 from flask_login import login_required, current_user
 from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
-from werkzeug.exceptions import Unauthorized, Forbidden
+from werkzeug.exceptions import Forbidden
 
 from constants.model_template import model_templates, demo_model_templates
 from controllers.console import api
-from controllers.console.app.error import AppNotFoundError
+from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.entity.model_params import ModelType
 from events.app_event import app_was_created, app_was_deleted
 from libs.helper import TimestampField
 from extensions.ext_database import db
@@ -126,9 +127,9 @@ class AppListApi(Resource):
         if args['model_config'] is not None:
             # validate config
             model_configuration = AppModelConfigService.validate_configuration(
+                tenant_id=current_user.current_tenant_id,
                 account=current_user,
-                config=args['model_config'],
-                mode=args['mode']
+                config=args['model_config']
             )
 
             app = App(
@@ -164,6 +165,21 @@ class AppListApi(Resource):
             app = App(**model_config_template['app'])
             app_model_config = AppModelConfig(**model_config_template['model_config'])
 
+            default_model = ModelFactory.get_default_model(
+                tenant_id=current_user.current_tenant_id,
+                model_type=ModelType.TEXT_GENERATION
+            )
+
+            if default_model:
+                model_dict = app_model_config.model_dict
+                model_dict['provider'] = default_model.provider_name
+                model_dict['name'] = default_model.model_name
+                app_model_config.model = json.dumps(model_dict)
+            else:
+                raise ProviderNotInitializeError(
+                    f"No Text Generation Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+
         app.name = args['name']
         app.mode = args['mode']
         app.icon = args['icon']

+ 1 - 1
api/controllers/console/app/audio.py

@@ -14,7 +14,7 @@ from controllers.console.app.error import AppUnavailableError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from flask_restful import Resource
 from services.audio_service import AudioService

+ 9 - 3
api/controllers/console/app/completion.py

@@ -17,7 +17,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.conversation_message_task import PubHandler
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value
 from flask_restful import Resource, reqparse
@@ -41,8 +41,11 @@ class CompletionMessageApi(Resource):
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, location='json')
         parser.add_argument('model_config', type=dict, required=True, location='json')
+        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         args = parser.parse_args()
 
+        streaming = args['response_mode'] != 'blocking'
+
         account = flask_login.current_user
 
         try:
@@ -51,7 +54,7 @@ class CompletionMessageApi(Resource):
                 user=account,
                 args=args,
                 from_source='console',
-                streaming=True,
+                streaming=streaming,
                 is_model_config_override=True
             )
 
@@ -111,8 +114,11 @@ class ChatMessageApi(Resource):
         parser.add_argument('query', type=str, required=True, location='json')
         parser.add_argument('model_config', type=dict, required=True, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
+        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         args = parser.parse_args()
 
+        streaming = args['response_mode'] != 'blocking'
+
         account = flask_login.current_user
 
         try:
@@ -121,7 +127,7 @@ class ChatMessageApi(Resource):
                 user=account,
                 args=args,
                 from_source='console',
-                streaming=True,
+                streaming=streaming,
                 is_model_config_override=True
             )
 

+ 1 - 1
api/controllers/console/app/generator.py

@@ -7,7 +7,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.generator.llm_generator import LLMGenerator
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
     LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
 
 

+ 1 - 1
api/controllers/console/app/message.py

@@ -14,7 +14,7 @@ from controllers.console.app.error import CompletionRequestError, ProviderNotIni
     AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value, TimestampField
 from libs.infinite_scroll_pagination import InfiniteScrollPagination

+ 2 - 2
api/controllers/console/app/model_config.py

@@ -28,9 +28,9 @@ class ModelConfigResource(Resource):
 
         # validate config
         model_configuration = AppModelConfigService.validate_configuration(
+            tenant_id=current_user.current_tenant_id,
             account=current_user,
-            config=request.json,
-            mode=app_model.mode
+            config=request.json
         )
 
         new_app_model_config = AppModelConfig(

+ 1 - 1
api/controllers/console/datasets/data_source.py

@@ -255,7 +255,7 @@ class DataSourceNotionApi(Resource):
         # validate args
         DocumentService.estimate_args_validate(args)
         indexing_runner = IndexingRunner()
-        response = indexing_runner.notion_indexing_estimate(args['notion_info_list'], args['process_rule'])
+        response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])
         return response, 200
 
 

+ 29 - 3
api/controllers/console/datasets/datasets.py

@@ -5,10 +5,13 @@ from flask_restful import Resource, reqparse, fields, marshal, marshal_with
 from werkzeug.exceptions import NotFound, Forbidden
 import services
 from controllers.console import api
+from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.datasets.error import DatasetNameDuplicateError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.indexing_runner import IndexingRunner
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.model_factory import ModelFactory
 from libs.helper import TimestampField
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Document
@@ -97,6 +100,15 @@ class DatasetListApi(Resource):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
 
+        try:
+            ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
+
         try:
             dataset = DatasetService.create_empty_dataset(
                 tenant_id=current_user.current_tenant_id,
@@ -235,12 +247,26 @@ class DatasetIndexingEstimateApi(Resource):
                 raise NotFound("File not found.")
 
             indexing_runner = IndexingRunner()
-            response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'], args['doc_form'])
+
+            try:
+                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
+                                                                  args['process_rule'], args['doc_form'])
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
         elif args['info_list']['data_source_type'] == 'notion_import':
 
             indexing_runner = IndexingRunner()
-            response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'],
-                                                                args['process_rule'], args['doc_form'])
+
+            try:
+                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
+                                                                    args['info_list']['notion_info_list'],
+                                                                    args['process_rule'], args['doc_form'])
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
         else:
             raise ValueError('Data source type not support')
         return response, 200

+ 43 - 5
api/controllers/console/datasets/datasets_document.py

@@ -18,7 +18,9 @@ from controllers.console.datasets.error import DocumentAlreadyFinishedError, Inv
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.indexing_runner import IndexingRunner
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
+    LLMBadRequestError
+from core.model_providers.model_factory import ModelFactory
 from extensions.ext_redis import redis_client
 from libs.helper import TimestampField
 from extensions.ext_database import db
@@ -280,6 +282,15 @@ class DatasetDocumentListApi(Resource):
         # validate args
         DocumentService.document_create_args_validate(args)
 
+        try:
+            ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
+
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
         except ProviderTokenNotInitError as ex:
@@ -319,6 +330,15 @@ class DatasetInitApi(Resource):
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         args = parser.parse_args()
 
+        try:
+            ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
+
         # validate args
         DocumentService.document_create_args_validate(args)
 
@@ -384,7 +404,13 @@ class DocumentIndexingEstimateApi(DocumentResource):
 
                 indexing_runner = IndexingRunner()
 
-                response = indexing_runner.file_indexing_estimate([file], data_process_rule_dict)
+                try:
+                    response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
+                                                                      data_process_rule_dict)
+                except LLMBadRequestError:
+                    raise ProviderNotInitializeError(
+                        f"No Embedding Model available. Please configure a valid provider "
+                        f"in the Settings -> Model Provider.")
 
         return response
 
@@ -445,12 +471,24 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                 raise NotFound("File not found.")
 
             indexing_runner = IndexingRunner()
-            response = indexing_runner.file_indexing_estimate(file_details, data_process_rule_dict)
+            try:
+                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
+                                                                  data_process_rule_dict)
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
         elif dataset.data_source_type:
 
             indexing_runner = IndexingRunner()
-            response = indexing_runner.notion_indexing_estimate(info_list,
-                                                                data_process_rule_dict)
+            try:
+                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
+                                                                    info_list,
+                                                                    data_process_rule_dict)
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
         else:
             raise ValueError('Data source type not support')
         return response

+ 3 - 1
api/controllers/console/datasets/hit_testing.py

@@ -11,7 +11,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
 from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import TimestampField
 from services.dataset_service import DatasetService
 from services.hit_testing_service import HitTestingService
@@ -102,6 +102,8 @@ class HitTestingApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
+        except ValueError as e:
+            raise ValueError(str(e))
         except Exception as e:
             logging.exception("Hit testing failed.")
             raise InternalServerError(str(e))

+ 1 - 1
api/controllers/console/explore/audio.py

@@ -11,7 +11,7 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
     NoAudioUploadedError, AudioTooLargeError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.explore.wraps import InstalledAppResource
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \

+ 1 - 1
api/controllers/console/explore/completion.py

@@ -15,7 +15,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
 from controllers.console.explore.error import NotCompletionAppError, NotChatAppError
 from controllers.console.explore.wraps import InstalledAppResource
 from core.conversation_message_task import PubHandler
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService

+ 1 - 1
api/controllers/console/explore/message.py

@@ -15,7 +15,7 @@ from controllers.console.app.error import AppMoreLikeThisDisabledError, Provider
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError
 from controllers.console.explore.wraps import InstalledAppResource
-from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value, TimestampField
 from services.completion_service import CompletionService

+ 1 - 4
api/controllers/console/explore/parameter.py

@@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields
 from controllers.console import api
 from controllers.console.explore.wraps import InstalledAppResource
 
-from core.llm.llm_builder import LLMBuilder
-from models.provider import ProviderName
 from models.model import InstalledApp
 
 
@@ -35,13 +33,12 @@ class AppParameterApi(InstalledAppResource):
         """Retrieve app parameters."""
         app_model = installed_app.app
         app_model_config = app_model.app_model_config
-        provider_name = LLMBuilder.get_default_provider(installed_app.tenant_id, 'whisper-1')
 
         return {
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
+            'speech_to_text': app_model_config.speech_to_text_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list
         }

+ 1 - 1
api/controllers/console/universal_chat/audio.py

@@ -11,7 +11,7 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
     NoAudioUploadedError, AudioTooLargeError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \

+ 3 - 7
api/controllers/console/universal_chat/chat.py

@@ -12,9 +12,8 @@ from controllers.console import api
 from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.constant import llm_constant
 from core.conversation_message_task import PubHandler
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
     LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService
@@ -27,6 +26,7 @@ class UniversalChatApi(UniversalChatResource):
         parser = reqparse.RequestParser()
         parser.add_argument('query', type=str, required=True, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
+        parser.add_argument('provider', type=str, required=True, location='json')
         parser.add_argument('model', type=str, required=True, location='json')
         parser.add_argument('tools', type=list, required=True, location='json')
         args = parser.parse_args()
@@ -36,11 +36,7 @@ class UniversalChatApi(UniversalChatResource):
         # update app model config
         args['model_config'] = app_model_config.to_dict()
         args['model_config']['model']['name'] = args['model']
-
-        if not llm_constant.models[args['model']]:
-            raise ValueError("Model not exists.")
-
-        args['model_config']['model']['provider'] = llm_constant.models[args['model']]
+        args['model_config']['model']['provider'] = args['provider']
         args['model_config']['agent_mode']['tools'] = args['tools']
 
         if not args['model_config']['agent_mode']['tools']:

+ 1 - 1
api/controllers/console/universal_chat/message.py

@@ -12,7 +12,7 @@ from controllers.console.app.error import ProviderNotInitializeError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value, TimestampField
 from services.errors.conversation import ConversationNotExistsError

+ 1 - 4
api/controllers/console/universal_chat/parameter.py

@@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields
 from controllers.console import api
 from controllers.console.universal_chat.wraps import UniversalChatResource
 
-from core.llm.llm_builder import LLMBuilder
-from models.provider import ProviderName
 from models.model import App
 
 
@@ -23,13 +21,12 @@ class UniversalChatParameterApi(UniversalChatResource):
         """Retrieve app parameters."""
         app_model = universal_app
         app_model_config = app_model.app_model_config
-        provider_name = LLMBuilder.get_default_provider(universal_app.tenant_id, 'whisper-1')
 
         return {
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
+            'speech_to_text': app_model_config.speech_to_text_dict,
         }
 
 

+ 0 - 0
api/tests/test_controllers/__init__.py → api/controllers/console/webhook/__init__.py


+ 53 - 0
api/controllers/console/webhook/stripe.py

@@ -0,0 +1,53 @@
+import logging
+
+import stripe
+from flask import request, current_app
+from flask_restful import Resource
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import only_edition_cloud
+from services.provider_checkout_service import ProviderCheckoutService
+
+
+class StripeWebhookApi(Resource):
+    @setup_required
+    @only_edition_cloud
+    def post(self):
+        payload = request.data
+        sig_header = request.headers.get('STRIPE_SIGNATURE')
+        webhook_secret = current_app.config.get('STRIPE_WEBHOOK_SECRET')
+
+        try:
+            event = stripe.Webhook.construct_event(
+                payload, sig_header, webhook_secret
+            )
+        except ValueError as e:
+            # Invalid payload
+            return 'Invalid payload', 400
+        except stripe.error.SignatureVerificationError as e:
+            # Invalid signature
+            return 'Invalid signature', 400
+
+        # Handle the checkout.session.completed event
+        if event['type'] == 'checkout.session.completed':
+            logging.debug(event['data']['object']['id'])
+            logging.debug(event['data']['object']['amount_subtotal'])
+            logging.debug(event['data']['object']['currency'])
+            logging.debug(event['data']['object']['payment_intent'])
+            logging.debug(event['data']['object']['payment_status'])
+            logging.debug(event['data']['object']['metadata'])
+
+            # Fulfill the purchase...
+            provider_checkout_service = ProviderCheckoutService()
+
+            try:
+                provider_checkout_service.fulfill_provider_order(event)
+            except Exception as e:
+                logging.debug(str(e))
+                return 'success', 200
+
+        return 'success', 200
+
+
+api.add_resource(StripeWebhookApi, '/webhook/stripe')

+ 203 - 193
api/controllers/console/workspace/model_providers.py

@@ -1,24 +1,18 @@
-# -*- coding:utf-8 -*-
-import base64
-import json
-import logging
-
-from flask import current_app
 from flask_login import login_required, current_user
-from flask_restful import Resource, reqparse, abort
+from flask_restful import Resource, reqparse
 from werkzeug.exceptions import Forbidden
 
 from controllers.console import api
+from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.llm.provider.errors import ValidateFailedError
-from extensions.ext_database import db
-from libs import rsa
-from models.provider import Provider, ProviderType, ProviderName
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.providers.base import CredentialsValidateFailedError
+from services.provider_checkout_service import ProviderCheckoutService
 from services.provider_service import ProviderService
 
 
-class ProviderListApi(Resource):
+class ModelProviderListApi(Resource):
 
     @setup_required
     @login_required
@@ -26,156 +20,115 @@ class ProviderListApi(Resource):
     def get(self):
         tenant_id = current_user.current_tenant_id
 
-        """
-        If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:, 
-        azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the 
-        rest is replaced by * and the last two bits are displayed in plaintext
-        
-        If the type is other, decode and return the Token field directly, the field displays the first 6 bits in 
-        plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
-        """
-
-        ProviderService.init_supported_provider(current_user.current_tenant)
-        providers = Provider.query.filter_by(tenant_id=tenant_id).all()
-
-        provider_list = [
-            {
-                'provider_name': p.provider_name,
-                'provider_type': p.provider_type,
-                'is_valid': p.is_valid,
-                'last_used': p.last_used,
-                'is_enabled': p.is_enabled,
-                **({
-                       'quota_type': p.quota_type,
-                       'quota_limit': p.quota_limit,
-                       'quota_used': p.quota_used
-                   } if p.provider_type == ProviderType.SYSTEM.value else {}),
-                'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
-                                                                ProviderName(p.provider_name), only_custom=True)
-                if p.provider_type == ProviderType.CUSTOM.value else None
-            }
-            for p in providers
-        ]
+        provider_service = ProviderService()
+        provider_list = provider_service.get_provider_list(tenant_id)
 
         return provider_list
 
 
-class ProviderTokenApi(Resource):
+class ModelProviderValidateApi(Resource):
 
     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider):
-        if provider not in [p.value for p in ProviderName]:
-            abort(404)
+    def post(self, provider_name: str):
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        provider_service = ProviderService()
+
+        result = True
+        error = None
+
+        try:
+            provider_service.custom_provider_config_validate(
+                provider_name=provider_name,
+                config=args['config']
+            )
+        except CredentialsValidateFailedError as ex:
+            result = False
+            error = str(ex)
+
+        response = {'result': 'success' if result else 'error'}
+
+        if not result:
+            response['error'] = error
+
+        return response
+
+
+class ModelProviderUpdateApi(Resource):
 
-        # The role of the current user in the ta table must be admin or owner
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider_name: str):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
-            logging.log(logging.ERROR,
-                        f'User {current_user.id} is not authorized to update provider token, current_role is {current_user.current_tenant.current_role}')
             raise Forbidden()
 
         parser = reqparse.RequestParser()
-
-        parser.add_argument('token', type=ProviderService.get_token_type(
-            tenant=current_user.current_tenant,
-            provider_name=ProviderName(provider)
-        ), required=True, nullable=False, location='json')
-
+        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
         args = parser.parse_args()
 
-        if args['token']:
-            try:
-                ProviderService.validate_provider_configs(
-                    tenant=current_user.current_tenant,
-                    provider_name=ProviderName(provider),
-                    configs=args['token']
-                )
-                token_is_valid = True
-            except ValidateFailedError as ex:
-                raise ValueError(str(ex))
-
-            base64_encrypted_token = ProviderService.get_encrypted_token(
-                tenant=current_user.current_tenant,
-                provider_name=ProviderName(provider),
-                configs=args['token']
+        provider_service = ProviderService()
+
+        try:
+            provider_service.save_custom_provider_config(
+                tenant_id=current_user.current_tenant_id,
+                provider_name=provider_name,
+                config=args['config']
             )
-        else:
-            base64_encrypted_token = None
-            token_is_valid = False
-
-        tenant = current_user.current_tenant
-
-        provider_model = db.session.query(Provider).filter(
-                Provider.tenant_id == tenant.id,
-                Provider.provider_name == provider,
-                Provider.provider_type == ProviderType.CUSTOM.value
-            ).first()
-
-        # Only allow updating token for CUSTOM provider type
-        if provider_model:
-            provider_model.encrypted_config = base64_encrypted_token
-            provider_model.is_valid = token_is_valid
-        else:
-            provider_model = Provider(tenant_id=tenant.id, provider_name=provider,
-                                      provider_type=ProviderType.CUSTOM.value,
-                                      encrypted_config=base64_encrypted_token,
-                                      is_valid=token_is_valid)
-            db.session.add(provider_model)
-
-        if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid:
-            other_providers = db.session.query(Provider).filter(
-                Provider.tenant_id == tenant.id,
-                Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]),
-                Provider.provider_name != provider,
-                Provider.provider_type == ProviderType.CUSTOM.value
-            ).all()
-
-            for other_provider in other_providers:
-                other_provider.is_valid = False
-
-        db.session.commit()
-
-        if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
-                        ProviderName.HUGGINGFACEHUB.value]:
-            return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
+        except CredentialsValidateFailedError as ex:
+            raise ValueError(str(ex))
 
         return {'result': 'success'}, 201
 
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, provider_name: str):
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
 
-class ProviderTokenValidateApi(Resource):
+        provider_service = ProviderService()
+        provider_service.delete_custom_provider(
+            tenant_id=current_user.current_tenant_id,
+            provider_name=provider_name
+        )
+
+        return {'result': 'success'}, 204
+
+
+class ModelProviderModelValidateApi(Resource):
 
     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider):
-        if provider not in [p.value for p in ProviderName]:
-            abort(404)
-
+    def post(self, provider_name: str):
         parser = reqparse.RequestParser()
-        parser.add_argument('token', type=ProviderService.get_token_type(
-            tenant=current_user.current_tenant,
-            provider_name=ProviderName(provider)
-        ), required=True, nullable=False, location='json')
+        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=['text-generation', 'embeddings', 'speech2text'], location='json')
+        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
         args = parser.parse_args()
 
-        # todo: remove this when the provider is supported
-        if provider in [ProviderName.COHERE.value,
-                        ProviderName.HUGGINGFACEHUB.value]:
-            return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
+        provider_service = ProviderService()
 
         result = True
         error = None
 
         try:
-            ProviderService.validate_provider_configs(
-                tenant=current_user.current_tenant,
-                provider_name=ProviderName(provider),
-                configs=args['token']
+            provider_service.custom_provider_model_config_validate(
+                provider_name=provider_name,
+                model_name=args['model_name'],
+                model_type=args['model_type'],
+                config=args['config']
             )
-        except ValidateFailedError as e:
+        except CredentialsValidateFailedError as ex:
             result = False
-            error = str(e)
+            error = str(ex)
 
         response = {'result': 'success' if result else 'error'}
 
@@ -185,91 +138,148 @@ class ProviderTokenValidateApi(Resource):
         return response
 
 
-class ProviderSystemApi(Resource):
+class ModelProviderModelUpdateApi(Resource):
 
     @setup_required
     @login_required
     @account_initialization_required
-    def put(self, provider):
-        if provider not in [p.value for p in ProviderName]:
-            abort(404)
+    def post(self, provider_name: str):
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('is_enabled', type=bool, required=True, location='json')
+        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=['text-generation', 'embeddings', 'speech2text'], location='json')
+        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
         args = parser.parse_args()
 
-        tenant = current_user.current_tenant_id
-
-        provider_model = Provider.query.filter_by(tenant_id=tenant.id, provider_name=provider).first()
-
-        if provider_model and provider_model.provider_type == ProviderType.SYSTEM.value:
-            provider_model.is_valid = args['is_enabled']
-            db.session.commit()
-        elif not provider_model:
-            if provider == ProviderName.OPENAI.value:
-                quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT']
-            elif provider == ProviderName.ANTHROPIC.value:
-                quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT']
-            else:
-                quota_limit = 0
-
-            ProviderService.create_system_provider(
-                tenant,
-                provider,
-                quota_limit,
-                args['is_enabled']
+        provider_service = ProviderService()
+
+        try:
+            provider_service.add_or_save_custom_provider_model_config(
+                tenant_id=current_user.current_tenant_id,
+                provider_name=provider_name,
+                model_name=args['model_name'],
+                model_type=args['model_type'],
+                config=args['config']
             )
-        else:
-            abort(403)
+        except CredentialsValidateFailedError as ex:
+            raise ValueError(str(ex))
 
-        return {'result': 'success'}
+        return {'result': 'success'}, 200
 
     @setup_required
     @login_required
     @account_initialization_required
-    def get(self, provider):
-        if provider not in [p.value for p in ProviderName]:
-            abort(404)
+    def delete(self, provider_name: str):
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=['text-generation', 'embeddings', 'speech2text'], location='args')
+        args = parser.parse_args()
 
-        # The role of the current user in the ta table must be admin or owner
+        provider_service = ProviderService()
+        provider_service.delete_custom_provider_model(
+            tenant_id=current_user.current_tenant_id,
+            provider_name=provider_name,
+            model_name=args['model_name'],
+            model_type=args['model_type']
+        )
+
+        return {'result': 'success'}, 204
+
+
+class PreferredProviderTypeUpdateApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider_name: str):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
 
-        provider_model = db.session.query(Provider).filter(Provider.tenant_id == current_user.current_tenant_id,
-                                                           Provider.provider_name == provider,
-                                                           Provider.provider_type == ProviderType.SYSTEM.value).first()
-
-        system_model = None
-        if provider_model:
-            system_model = {
-                'result': 'success',
-                'provider': {
-                    'provider_name': provider_model.provider_name,
-                    'provider_type': provider_model.provider_type,
-                    'is_valid': provider_model.is_valid,
-                    'last_used': provider_model.last_used,
-                    'is_enabled': provider_model.is_enabled,
-                    'quota_type': provider_model.quota_type,
-                    'quota_limit': provider_model.quota_limit,
-                    'quota_used': provider_model.quota_used
-                }
-            }
-        else:
-            abort(404)
+        parser = reqparse.RequestParser()
+        parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False,
+                            choices=['system', 'custom'], location='json')
+        args = parser.parse_args()
+
+        provider_service = ProviderService()
+        provider_service.switch_preferred_provider(
+            tenant_id=current_user.current_tenant_id,
+            provider_name=provider_name,
+            preferred_provider_type=args['preferred_provider_type']
+        )
+
+        return {'result': 'success'}
+
+
+class ModelProviderModelParameterRuleApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider_name: str):
+        parser = reqparse.RequestParser()
+        parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
+        args = parser.parse_args()
 
-        return system_model
+        provider_service = ProviderService()
 
+        try:
+            parameter_rules = provider_service.get_model_parameter_rules(
+                tenant_id=current_user.current_tenant_id,
+                model_provider_name=provider_name,
+                model_name=args['model_name'],
+                model_type='text-generation'
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"Current Text Generation Model is invalid. Please switch to the available model.")
+
+        rules = {
+            k: {
+                'enabled': v.enabled,
+                'min': v.min,
+                'max': v.max,
+                'default': v.default
+            }
+            for k, v in vars(parameter_rules).items()
+        }
 
-api.add_resource(ProviderTokenApi, '/providers/<provider>/token',
-                 endpoint='current_providers_token')  # Deprecated
-api.add_resource(ProviderTokenValidateApi, '/providers/<provider>/token-validate',
-                 endpoint='current_providers_token_validate')  # Deprecated
+        return rules
 
-api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
-                 endpoint='workspaces_current_providers_token')  # PUT for updating provider token
-api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
-                 endpoint='workspaces_current_providers_token_validate')  # POST for validating provider token
 
-api.add_resource(ProviderListApi, '/workspaces/current/providers')  # GET for getting providers list
-api.add_resource(ProviderSystemApi, '/workspaces/current/providers/<provider>/system',
-                 endpoint='workspaces_current_providers_system')  # GET for getting provider quota, PUT for updating provider status
+class ModelProviderPaymentCheckoutUrlApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider_name: str):
+        provider_service = ProviderCheckoutService()
+        provider_checkout = provider_service.create_checkout(
+            tenant_id=current_user.current_tenant_id,
+            provider_name=provider_name,
+            account=current_user
+        )
+
+        return {
+            'url': provider_checkout.get_checkout_url()
+        }
+
+
+api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
+api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
+api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
+api.add_resource(ModelProviderModelValidateApi,
+                 '/workspaces/current/model-providers/<string:provider_name>/models/validate')
+api.add_resource(ModelProviderModelUpdateApi,
+                 '/workspaces/current/model-providers/<string:provider_name>/models')
+api.add_resource(PreferredProviderTypeUpdateApi,
+                 '/workspaces/current/model-providers/<string:provider_name>/preferred-provider-type')
+api.add_resource(ModelProviderModelParameterRuleApi,
+                 '/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules')
+api.add_resource(ModelProviderPaymentCheckoutUrlApi,
+                 '/workspaces/current/model-providers/<string:provider_name>/checkout-url')

+ 108 - 0
api/controllers/console/workspace/models.py

@@ -0,0 +1,108 @@
+from flask_login import login_required, current_user
+from flask_restful import Resource, reqparse
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from core.model_providers.model_provider_factory import ModelProviderFactory
+from core.model_providers.models.entity.model_params import ModelType
+from models.provider import ProviderType
+from services.provider_service import ProviderService
+
+
+class DefaultModelApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=['text-generation', 'embeddings', 'speech2text'], location='args')
+        args = parser.parse_args()
+
+        tenant_id = current_user.current_tenant_id
+
+        provider_service = ProviderService()
+        default_model = provider_service.get_default_model_of_model_type(
+            tenant_id=tenant_id,
+            model_type=args['model_type']
+        )
+
+        if not default_model:
+            return None
+
+        model_provider = ModelProviderFactory.get_preferred_model_provider(
+            tenant_id,
+            default_model.provider_name
+        )
+
+        if not model_provider:
+            return {
+                'model_name': default_model.model_name,
+                'model_type': default_model.model_type,
+                'model_provider': {
+                    'provider_name': default_model.provider_name
+                }
+            }
+
+        provider = model_provider.provider
+        rst = {
+            'model_name': default_model.model_name,
+            'model_type': default_model.model_type,
+            'model_provider': {
+                'provider_name': provider.provider_name,
+                'provider_type': provider.provider_type
+            }
+        }
+
+        model_provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name)
+        if provider.provider_type == ProviderType.SYSTEM.value:
+            rst['model_provider']['quota_type'] = provider.quota_type
+            rst['model_provider']['quota_unit'] = model_provider_rules['system_config']['quota_unit']
+            rst['model_provider']['quota_limit'] = provider.quota_limit
+            rst['model_provider']['quota_used'] = provider.quota_used
+
+        return rst
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=['text-generation', 'embeddings', 'speech2text'], location='json')
+        parser.add_argument('provider_name', type=str, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        provider_service = ProviderService()
+        provider_service.update_default_model_of_model_type(
+            tenant_id=current_user.current_tenant_id,
+            model_type=args['model_type'],
+            provider_name=args['provider_name'],
+            model_name=args['model_name']
+        )
+
+        return {'result': 'success'}
+
+
+class ValidModelApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, model_type):
+        ModelType.value_of(model_type)
+
+        provider_service = ProviderService()
+        valid_models = provider_service.get_valid_model_list(
+            tenant_id=current_user.current_tenant_id,
+            model_type=model_type
+        )
+
+        return valid_models
+
+
+api.add_resource(DefaultModelApi, '/workspaces/current/default-model')
+api.add_resource(ValidModelApi, '/workspaces/current/models/model-type/<string:model_type>')

+ 130 - 0
api/controllers/console/workspace/providers.py

@@ -0,0 +1,130 @@
+# -*- coding:utf-8 -*-
+from flask_login import login_required, current_user
+from flask_restful import Resource, reqparse
+from werkzeug.exceptions import Forbidden
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from core.model_providers.providers.base import CredentialsValidateFailedError
+from models.provider import ProviderType
+from services.provider_service import ProviderService
+
+
+class ProviderListApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        tenant_id = current_user.current_tenant_id
+
+        """
+        If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:, 
+        azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the 
+        rest is replaced by * and the last two bits are displayed in plaintext
+        
+        If the type is other, decode and return the Token field directly, the field displays the first 6 bits in 
+        plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
+        """
+
+        provider_service = ProviderService()
+        provider_info_list = provider_service.get_provider_list(tenant_id)
+
+        provider_list = [
+            {
+                'provider_name': p['provider_name'],
+                'provider_type': p['provider_type'],
+                'is_valid': p['is_valid'],
+                'last_used': p['last_used'],
+                'is_enabled': p['is_valid'],
+                **({
+                       'quota_type': p['quota_type'],
+                       'quota_limit': p['quota_limit'],
+                       'quota_used': p['quota_used']
+                   } if p['provider_type'] == ProviderType.SYSTEM.value else {}),
+                'token': (p['config'] if p['provider_name'] != 'openai' else p['config']['openai_api_key'])
+                        if p['config'] else None
+            }
+            for name, provider_info in provider_info_list.items()
+            for p in provider_info['providers']
+        ]
+
+        return provider_list
+
+
+class ProviderTokenApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider):
+        # The role of the current user in the ta table must be admin or owner
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('token', required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        if provider == 'openai':
+            args['token'] = {
+                'openai_api_key': args['token']
+            }
+
+        provider_service = ProviderService()
+        try:
+            provider_service.save_custom_provider_config(
+                tenant_id=current_user.current_tenant_id,
+                provider_name=provider,
+                config=args['token']
+            )
+        except CredentialsValidateFailedError as ex:
+            raise ValueError(str(ex))
+
+        return {'result': 'success'}, 201
+
+
+class ProviderTokenValidateApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider):
+        parser = reqparse.RequestParser()
+        parser.add_argument('token', required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        provider_service = ProviderService()
+
+        if provider == 'openai':
+            args['token'] = {
+                'openai_api_key': args['token']
+            }
+
+        result = True
+        error = None
+
+        try:
+            provider_service.custom_provider_config_validate(
+                provider_name=provider,
+                config=args['token']
+            )
+        except CredentialsValidateFailedError as ex:
+            result = False
+            error = str(ex)
+
+        response = {'result': 'success' if result else 'error'}
+
+        if not result:
+            response['error'] = error
+
+        return response
+
+
+api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
+                 endpoint='workspaces_current_providers_token')  # PUT for updating provider token
+api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
+                 endpoint='workspaces_current_providers_token_validate')  # POST for validating provider token
+
+api.add_resource(ProviderListApi, '/workspaces/current/providers')  # GET for getting providers list

+ 1 - 1
api/controllers/console/workspace/workspace.py

@@ -30,7 +30,7 @@ tenant_fields = {
     'created_at': TimestampField,
     'role': fields.String,
     'providers': fields.List(fields.Nested(provider_fields)),
-    'in_trail': fields.Boolean,
+    'in_trial': fields.Boolean,
     'trial_end_reason': fields.String,
 }
 

+ 1 - 4
api/controllers/service_api/app/app.py

@@ -4,8 +4,6 @@ from flask_restful import fields, marshal_with
 from controllers.service_api import api
 from controllers.service_api.wraps import AppApiResource
 
-from core.llm.llm_builder import LLMBuilder
-from models.provider import ProviderName
 from models.model import App
 
 
@@ -35,13 +33,12 @@ class AppParameterApi(AppApiResource):
     def get(self, app_model: App, end_user):
         """Retrieve app parameters."""
         app_model_config = app_model.app_model_config
-        provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')
 
         return {
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
+            'speech_to_text': app_model_config.speech_to_text_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list
         }

+ 1 - 1
api/controllers/service_api/app/audio.py

@@ -9,7 +9,7 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
     ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \
     ProviderNotSupportSpeechToTextError
 from controllers.service_api.wraps import AppApiResource
-from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from models.model import App, AppModelConfig
 from services.audio_service import AudioService

+ 1 - 1
api/controllers/service_api/app/completion.py

@@ -14,7 +14,7 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
     ProviderModelCurrentlyNotSupportError
 from controllers.service_api.wraps import AppApiResource
 from core.conversation_message_task import PubHandler
-from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService

+ 1 - 1
api/controllers/service_api/dataset/document.py

@@ -11,7 +11,7 @@ from controllers.service_api.app.error import ProviderNotInitializeError
 from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
     DatasetNotInitedError
 from controllers.service_api.wraps import DatasetApiResource
-from core.llm.error import ProviderTokenNotInitError
+from core.model_providers.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 from models.model import UploadFile

+ 1 - 4
api/controllers/web/app.py

@@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields
 from controllers.web import api
 from controllers.web.wraps import WebApiResource
 
-from core.llm.llm_builder import LLMBuilder
-from models.provider import ProviderName
 from models.model import App
 
 
@@ -34,13 +32,12 @@ class AppParameterApi(WebApiResource):
     def get(self, app_model: App, end_user):
         """Retrieve app parameters."""
         app_model_config = app_model.app_model_config
-        provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')
 
         return {
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
+            'speech_to_text': app_model_config.speech_to_text_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list
         }

+ 1 - 1
api/controllers/web/audio.py

@@ -10,7 +10,7 @@ from controllers.web.error import AppUnavailableError, ProviderNotInitializeErro
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.web.wraps import WebApiResource
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \

+ 1 - 1
api/controllers/web/completion.py

@@ -14,7 +14,7 @@ from controllers.web.error import AppUnavailableError, ConversationCompletedErro
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
 from controllers.web.wraps import WebApiResource
 from core.conversation_message_task import PubHandler
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService

+ 1 - 1
api/controllers/web/message.py

@@ -14,7 +14,7 @@ from controllers.web.error import NotChatAppError, CompletionRequestError, Provi
     AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
 from controllers.web.wraps import WebApiResource
-from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value, TimestampField
 from services.completion_service import CompletionService

+ 0 - 36
api/core/__init__.py

@@ -1,36 +0,0 @@
-import os
-from typing import Optional
-
-import langchain
-from flask import Flask
-from pydantic import BaseModel
-
-from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.prompt.prompt_template import OneLineFormatter
-
-
-class HostedOpenAICredential(BaseModel):
-    api_key: str
-
-
-class HostedAnthropicCredential(BaseModel):
-    api_key: str
-
-
-class HostedLLMCredentials(BaseModel):
-    openai: Optional[HostedOpenAICredential] = None
-    anthropic: Optional[HostedAnthropicCredential] = None
-
-
-hosted_llm_credentials = HostedLLMCredentials()
-
-
-def init_app(app: Flask):
-    if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
-        langchain.verbose = True
-
-    if app.config.get("OPENAI_API_KEY"):
-        hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
-
-    if app.config.get("ANTHROPIC_API_KEY"):
-        hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))

+ 9 - 13
api/core/agent/agent/calc_token_mixin.py

@@ -1,20 +1,17 @@
-from typing import cast, List
+from typing import List
 
-from langchain import OpenAI
-from langchain.base_language import BaseLanguageModel
-from langchain.chat_models.openai import ChatOpenAI
 from langchain.schema import BaseMessage
 
-from core.constant import llm_constant
+from core.model_providers.models.entity.message import to_prompt_messages
+from core.model_providers.models.llm.base import BaseLLM
 
 
 class CalcTokenMixin:
 
-    def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
-        llm = cast(ChatOpenAI, llm)
-        return llm.get_num_tokens_from_messages(messages)
+    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
+        return model_instance.get_num_tokens(to_prompt_messages(messages))
 
-    def get_message_rest_tokens(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
+    def get_message_rest_tokens(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
         """
         Got the rest tokens available for the model after excluding messages tokens and completion max tokens
 
@@ -22,10 +19,9 @@ class CalcTokenMixin:
         :param messages:
         :return:
         """
-        llm = cast(ChatOpenAI, llm)
-        llm_max_tokens = llm_constant.max_context_token_length[llm.model_name]
-        completion_max_tokens = llm.max_tokens
-        used_tokens = self.get_num_tokens_from_messages(llm, messages, **kwargs)
+        llm_max_tokens = model_instance.model_rules.max_tokens.max
+        completion_max_tokens = model_instance.model_kwargs.max_tokens
+        used_tokens = self.get_num_tokens_from_messages(model_instance, messages, **kwargs)
         rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
 
         return rest_tokens

+ 9 - 1
api/core/agent/agent/multi_dataset_router_agent.py

@@ -4,9 +4,11 @@ from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, BaseLanguageModel, SystemMessage
+from langchain.schema import AgentAction, AgentFinish, SystemMessage
+from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools import BaseTool
 
+from core.model_providers.models.llm.base import BaseLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 
 
@@ -14,6 +16,12 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
     """
     An Multi Dataset Retrieve Agent driven by Router.
     """
+    model_instance: BaseLLM
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
 
     def should_use_agent(self, query: str):
         """

+ 3 - 2
api/core/agent/agent/openai_function_call.py

@@ -6,7 +6,8 @@ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
+from langchain.schema import AgentAction, AgentFinish, SystemMessage
+from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools import BaseTool
 
 from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
@@ -84,7 +85,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
 
         # summarize messages if rest_tokens < 0
         try:
-            messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
+            messages = self.summarize_messages_if_needed(messages, functions=self.functions)
         except ExceededLLMTokensLimitError as e:
             return AgentFinish(return_values={"output": str(e)}, log=str(e))
 

+ 11 - 3
api/core/agent/agent/openai_function_call_summarize_mixin.py

@@ -3,20 +3,28 @@ from typing import cast, List
 from langchain.chat_models import ChatOpenAI
 from langchain.chat_models.openai import _convert_message_to_dict
 from langchain.memory.summary import SummarizerMixin
-from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage, BaseLanguageModel
+from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage
+from langchain.schema.language_model import BaseLanguageModel
 from pydantic import BaseModel
 
 from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
+from core.model_providers.models.llm.base import BaseLLM
 
 
 class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
     summary_llm: BaseLanguageModel
+    model_instance: BaseLLM
 
-    def summarize_messages_if_needed(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
+    def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
         # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
-        rest_tokens = self.get_message_rest_tokens(llm, messages, **kwargs)
+        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
         rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
         if rest_tokens >= 0:
             return messages

+ 3 - 2
api/core/agent/agent/openai_multi_function_call.py

@@ -6,7 +6,8 @@ from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFuncti
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
+from langchain.schema import AgentAction, AgentFinish, SystemMessage
+from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools import BaseTool
 
 from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
@@ -84,7 +85,7 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope
 
         # summarize messages if rest_tokens < 0
         try:
-            messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
+            messages = self.summarize_messages_if_needed(messages, functions=self.functions)
         except ExceededLLMTokensLimitError as e:
             return AgentFinish(return_values={"output": str(e)}, log=str(e))
 

+ 162 - 0
api/core/agent/agent/structed_multi_dataset_router_agent.py

@@ -0,0 +1,162 @@
+import re
+from typing import List, Tuple, Any, Union, Sequence, Optional, cast
+
+from langchain import BasePromptTemplate
+from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
+from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.manager import Callbacks
+from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
+from langchain.schema import AgentAction, AgentFinish, OutputParserException
+from langchain.tools import BaseTool
+from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
+
+from core.model_providers.models.llm.base import BaseLLM
+from core.tool.dataset_retriever_tool import DatasetRetrieverTool
+
+FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
+The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
+Valid "action" values: "Final Answer" or {tool_names}
+
+Provide only ONE action per $JSON_BLOB, as shown:
+
+```
+{{{{
+  "action": $TOOL_NAME,
+  "action_input": $INPUT
+}}}}
+```
+
+Follow this format:
+
+Question: input question to answer
+Thought: consider previous and subsequent steps
+Action:
+```
+$JSON_BLOB
+```
+Observation: action result
+... (repeat Thought/Action/Observation N times)
+Thought: I know what to respond
+Action:
+```
+{{{{
+  "action": "Final Answer",
+  "action_input": "Final response to human"
+}}}}
+```"""
+
+
+class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
+    model_instance: BaseLLM
+    dataset_tools: Sequence[BaseTool]
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
+    def should_use_agent(self, query: str):
+        """
+        return should use agent
+        Using the ReACT mode to determine whether an agent is needed is costly,
+        so it's better to just use an Agent for reasoning, which is cheaper.
+
+        :param query:
+        :return:
+        """
+        return True
+
+    def plan(
+        self,
+        intermediate_steps: List[Tuple[AgentAction, str]],
+        callbacks: Callbacks = None,
+        **kwargs: Any,
+    ) -> Union[AgentAction, AgentFinish]:
+        """Given input, decided what to do.
+
+        Args:
+            intermediate_steps: Steps the LLM has taken to date,
+                along with observations
+            callbacks: Callbacks to run.
+            **kwargs: User inputs.
+
+        Returns:
+            Action specifying what tool to use.
+        """
+        if len(self.dataset_tools) == 0:
+            return AgentFinish(return_values={"output": ''}, log='')
+        elif len(self.dataset_tools) == 1:
+            tool = next(iter(self.dataset_tools))
+            tool = cast(DatasetRetrieverTool, tool)
+            rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
+            return AgentFinish(return_values={"output": rst}, log=rst)
+
+        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
+        full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
+
+        try:
+            return self.output_parser.parse(full_output)
+        except OutputParserException:
+            return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
+                                          "I don't know how to respond to that."}, "")
+    @classmethod
+    def create_prompt(
+            cls,
+            tools: Sequence[BaseTool],
+            prefix: str = PREFIX,
+            suffix: str = SUFFIX,
+            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
+            format_instructions: str = FORMAT_INSTRUCTIONS,
+            input_variables: Optional[List[str]] = None,
+            memory_prompts: Optional[List[BasePromptTemplate]] = None,
+    ) -> BasePromptTemplate:
+        tool_strings = []
+        for tool in tools:
+            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
+            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
+        formatted_tools = "\n".join(tool_strings)
+        unique_tool_names = set(tool.name for tool in tools)
+        tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
+        format_instructions = format_instructions.format(tool_names=tool_names)
+        template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
+        if input_variables is None:
+            input_variables = ["input", "agent_scratchpad"]
+        _memory_prompts = memory_prompts or []
+        messages = [
+            SystemMessagePromptTemplate.from_template(template),
+            *_memory_prompts,
+            HumanMessagePromptTemplate.from_template(human_message_template),
+        ]
+        return ChatPromptTemplate(input_variables=input_variables, messages=messages)
+
+    @classmethod
+    def from_llm_and_tools(
+            cls,
+            llm: BaseLanguageModel,
+            tools: Sequence[BaseTool],
+            callback_manager: Optional[BaseCallbackManager] = None,
+            output_parser: Optional[AgentOutputParser] = None,
+            prefix: str = PREFIX,
+            suffix: str = SUFFIX,
+            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
+            format_instructions: str = FORMAT_INSTRUCTIONS,
+            input_variables: Optional[List[str]] = None,
+            memory_prompts: Optional[List[BasePromptTemplate]] = None,
+            **kwargs: Any,
+    ) -> Agent:
+        return super().from_llm_and_tools(
+            llm=llm,
+            tools=tools,
+            callback_manager=callback_manager,
+            output_parser=output_parser,
+            prefix=prefix,
+            suffix=suffix,
+            human_message_template=human_message_template,
+            format_instructions=format_instructions,
+            input_variables=input_variables,
+            memory_prompts=memory_prompts,
+            dataset_tools=tools,
+            **kwargs,
+        )

+ 8 - 2
api/core/agent/agent/structured_chat.py

@@ -14,7 +14,7 @@ from langchain.tools import BaseTool
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
 
 from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
-
+from core.model_providers.models.llm.base import BaseLLM
 
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
 The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
@@ -53,6 +53,12 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
     summary_llm: BaseLanguageModel
+    model_instance: BaseLLM
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
 
     def should_use_agent(self, query: str):
         """
@@ -89,7 +95,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
         if prompts:
             messages = prompts[0].to_messages()
 
-        rest_tokens = self.get_message_rest_tokens(self.llm_chain.llm, messages)
+        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
         if rest_tokens < 0:
             full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
 

+ 25 - 11
api/core/agent/agent_executor.py

@@ -3,7 +3,6 @@ import logging
 from typing import Union, Optional
 
 from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
-from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.tools import BaseTool
@@ -13,14 +12,17 @@ from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
 from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
 from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
 from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
+from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
 from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
 from langchain.agents import AgentExecutor as LCAgentExecutor
 
+from core.model_providers.models.llm.base import BaseLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 
 
 class PlanningStrategy(str, enum.Enum):
     ROUTER = 'router'
+    REACT_ROUTER = 'react_router'
     REACT = 'react'
     FUNCTION_CALL = 'function_call'
     MULTI_FUNCTION_CALL = 'multi_function_call'
@@ -28,10 +30,9 @@ class PlanningStrategy(str, enum.Enum):
 
 class AgentConfiguration(BaseModel):
     strategy: PlanningStrategy
-    llm: BaseLanguageModel
+    model_instance: BaseLLM
     tools: list[BaseTool]
-    summary_llm: BaseLanguageModel
-    dataset_llm: BaseLanguageModel
+    summary_model_instance: BaseLLM
     memory: Optional[BaseChatMemory] = None
     callbacks: Callbacks = None
     max_iterations: int = 6
@@ -60,36 +61,49 @@ class AgentExecutor:
     def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]:
         if self.configuration.strategy == PlanningStrategy.REACT:
             agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
-                llm=self.configuration.llm,
+                model_instance=self.configuration.model_instance,
+                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 output_parser=StructuredChatOutputParser(),
-                summary_llm=self.configuration.summary_llm,
+                summary_llm=self.configuration.summary_model_instance.client,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
             agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
-                llm=self.configuration.llm,
+                model_instance=self.configuration.model_instance,
+                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
-                summary_llm=self.configuration.summary_llm,
+                summary_llm=self.configuration.summary_model_instance.client,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
             agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
-                llm=self.configuration.llm,
+                model_instance=self.configuration.model_instance,
+                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
-                summary_llm=self.configuration.summary_llm,
+                summary_llm=self.configuration.summary_model_instance.client,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.ROUTER:
             self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
             agent = MultiDatasetRouterAgent.from_llm_and_tools(
-                llm=self.configuration.dataset_llm,
+                model_instance=self.configuration.model_instance,
+                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
                 verbose=True
             )
+        elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
+            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
+            agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
+                model_instance=self.configuration.model_instance,
+                llm=self.configuration.model_instance.client,
+                tools=self.configuration.tools,
+                output_parser=StructuredChatOutputParser(),
+                verbose=True
+            )
         else:
             raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
 

+ 5 - 4
api/core/callback_handler/agent_loop_gather_callback_handler.py

@@ -10,15 +10,16 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
 
 from core.callback_handler.entity.agent_loop import AgentLoop
 from core.conversation_message_task import ConversationMessageTask
+from core.model_providers.models.llm.base import BaseLLM
 
 
 class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
     raise_error: bool = True
 
-    def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
+    def __init__(self, model_instant: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
-        self.model_name = model_name
+        self.model_instant = model_instant
         self.conversation_message_task = conversation_message_task
         self._agent_loops = []
         self._current_loop = None
@@ -152,7 +153,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
 
             self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_name, self._current_loop
+                self._message_agent_thought, self.model_instant, self._current_loop
             )
 
             self._agent_loops.append(self._current_loop)
@@ -183,7 +184,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             )
 
             self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_name, self._current_loop
+                self._message_agent_thought, self.model_instant, self._current_loop
             )
 
             self._agent_loops.append(self._current_loop)

+ 11 - 7
api/core/callback_handler/llm_callback_handler.py

@@ -3,18 +3,20 @@ import time
 from typing import Any, Dict, List, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import LLMResult, BaseMessage, BaseLanguageModel
+from langchain.schema import LLMResult, BaseMessage
 
 from core.callback_handler.entity.llm_message import LLMMessage
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
+from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
+from core.model_providers.models.llm.base import BaseLLM
 
 
 class LLMCallbackHandler(BaseCallbackHandler):
     raise_error: bool = True
 
-    def __init__(self, llm: BaseLanguageModel,
+    def __init__(self, model_instance: BaseLLM,
                  conversation_message_task: ConversationMessageTask):
-        self.llm = llm
+        self.model_instance = model_instance
         self.llm_message = LLMMessage()
         self.start_at = None
         self.conversation_message_task = conversation_message_task
@@ -46,7 +48,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
             })
 
         self.llm_message.prompt = real_prompts
-        self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])
+        self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0]))
 
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@@ -58,7 +60,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
             "text": prompts[0]
         }]
 
-        self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
+        self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
 
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         end_at = time.perf_counter()
@@ -68,7 +70,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
             self.conversation_message_task.append_message_text(response.generations[0][0].text)
             self.llm_message.completion = response.generations[0][0].text
 
-        self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
+        self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)])
 
         self.conversation_message_task.save_message(self.llm_message)
 
@@ -89,7 +91,9 @@ class LLMCallbackHandler(BaseCallbackHandler):
             if self.conversation_message_task.streaming:
                 end_at = time.perf_counter()
                 self.llm_message.latency = end_at - self.start_at
-                self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
+                self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
+                    [PromptMessage(content=self.llm_message.completion)]
+                )
                 self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
         else:
             logging.error(error)

+ 0 - 2
api/core/callback_handler/main_chain_gather_callback_handler.py

@@ -5,9 +5,7 @@ from typing import Any, Dict, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
 
-from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.callback_handler.entity.chain_result import ChainResult
-from core.constant import llm_constant
 from core.conversation_message_task import ConversationMessageTask
 
 

+ 86 - 111
api/core/completion.py

@@ -2,27 +2,19 @@ import logging
 import re
 from typing import Optional, List, Union, Tuple
 
-from langchain.base_language import BaseLanguageModel
-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.chat_models.base import BaseChatModel
-from langchain.llms import BaseLLM
-from langchain.schema import BaseMessage, HumanMessage
+from langchain.schema import BaseMessage
 from requests.exceptions import ChunkedEncodingError
 
 from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
-from core.constant import llm_constant
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
-from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
-    DifyStdOutCallbackHandler
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
-from core.llm.error import LLMBadRequestError
-from core.llm.fake import FakeLLM
-from core.llm.llm_builder import LLMBuilder
-from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
-from core.llm.streamable_open_ai import StreamableOpenAI
+from core.model_providers.error import LLMBadRequestError
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages
+from core.model_providers.models.llm.base import BaseLLM
 from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_template import JinjaPromptTemplate
@@ -51,12 +43,10 @@ class Completion:
 
             inputs = conversation.inputs
 
-        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
-            mode=app.mode,
+        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
             tenant_id=app.tenant_id,
-            app_model_config=app_model_config,
-            query=query,
-            inputs=inputs
+            model_config=app_model_config.model_dict,
+            streaming=streaming
         )
 
         conversation_message_task = ConversationMessageTask(
@@ -68,10 +58,17 @@ class Completion:
             is_override=is_override,
             inputs=inputs,
             query=query,
-            streaming=streaming
+            streaming=streaming,
+            model_instance=final_model_instance
         )
 
-        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
+        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
+            mode=app.mode,
+            model_instance=final_model_instance,
+            app_model_config=app_model_config,
+            query=query,
+            inputs=inputs
+        )
 
         # init orchestrator rule parser
         orchestrator_rule_parser = OrchestratorRuleParser(
@@ -80,6 +77,7 @@ class Completion:
         )
 
         # parse sensitive_word_avoidance_chain
+        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
         sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
         if sensitive_word_avoidance_chain:
             query = sensitive_word_avoidance_chain.run(query)
@@ -102,15 +100,14 @@ class Completion:
         # run the final llm
         try:
             cls.run_final_llm(
-                tenant_id=app.tenant_id,
+                model_instance=final_model_instance,
                 mode=app.mode,
                 app_model_config=app_model_config,
                 query=query,
                 inputs=inputs,
                 agent_execute_result=agent_execute_result,
                 conversation_message_task=conversation_message_task,
-                memory=memory,
-                streaming=streaming
+                memory=memory
             )
         except ConversationTaskStoppedException:
             return
@@ -121,31 +118,20 @@ class Completion:
             return
 
     @classmethod
-    def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
+    def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
                       agent_execute_result: Optional[AgentExecuteResult],
                       conversation_message_task: ConversationMessageTask,
-                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
+                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
         # When no extra pre prompt is specified,
         # the output of the agent can be used directly as the main output content without calling LLM again
+        fake_response = None
         if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
                 and agent_execute_result.strategy != PlanningStrategy.ROUTER:
-            final_llm = FakeLLM(response=agent_execute_result.output,
-                                origin_llm=agent_execute_result.configuration.llm,
-                                streaming=streaming)
-            final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
-            response = final_llm.generate([[HumanMessage(content=query)]])
-            return response
-
-        final_llm = LLMBuilder.to_llm_from_model(
-            tenant_id=tenant_id,
-            model=app_model_config.model_dict,
-            streaming=streaming
-        )
+            fake_response = agent_execute_result.output
 
         # get llm prompt
-        prompt, stop_words = cls.get_main_llm_prompt(
+        prompt_messages, stop_words = cls.get_main_llm_prompt(
             mode=mode,
-            llm=final_llm,
             model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
             query=query,
@@ -154,25 +140,26 @@ class Completion:
             memory=memory
         )
 
-        final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
-
         cls.recale_llm_max_tokens(
-            final_llm=final_llm,
-            model=app_model_config.model_dict,
-            prompt=prompt,
-            mode=mode
+            model_instance=model_instance,
+            prompt_messages=prompt_messages,
         )
 
-        response = final_llm.generate([prompt], stop_words)
+        response = model_instance.run(
+            messages=prompt_messages,
+            stop=stop_words,
+            callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
+            fake_response=fake_response
+        )
 
         return response
 
     @classmethod
-    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
+    def get_main_llm_prompt(cls, mode: str, model: dict,
                             pre_prompt: str, query: str, inputs: dict,
                             agent_execute_result: Optional[AgentExecuteResult],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
-            Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
+            Tuple[List[PromptMessage], Optional[List[str]]]:
         if mode == 'completion':
             prompt_template = JinjaPromptTemplate.from_template(
                 template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
@@ -200,11 +187,7 @@ And answer according to the language of the user's question.
                 **prompt_inputs
             )
 
-            if isinstance(llm, BaseChatModel):
-                # use chat llm as completion model
-                return [HumanMessage(content=prompt_content)], None
-            else:
-                return prompt_content, None
+            return [PromptMessage(content=prompt_content)], None
         else:
             messages: List[BaseMessage] = []
 
@@ -249,12 +232,14 @@ And answer according to the language of the user's question.
                     inputs=human_inputs
                 )
 
-                curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
-                model_name = model['name']
-                max_tokens = model.get("completion_params").get('max_tokens')
-                rest_tokens = llm_constant.max_context_token_length[model_name] \
-                              - max_tokens - curr_message_tokens
-                rest_tokens = max(rest_tokens, 0)
+                if memory.model_instance.model_rules.max_tokens.max:
+                    curr_message_tokens = memory.model_instance.get_num_tokens(to_prompt_messages([tmp_human_message]))
+                    max_tokens = model.get("completion_params").get('max_tokens')
+                    rest_tokens = memory.model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
+                    rest_tokens = max(rest_tokens, 0)
+                else:
+                    rest_tokens = 2000
+
                 histories = cls.get_history_messages_from_memory(memory, rest_tokens)
                 human_message_prompt += "\n\n" if human_message_prompt else ""
                 human_message_prompt += "Here is the chat histories between human and assistant, " \
@@ -274,17 +259,7 @@ And answer according to the language of the user's question.
             for message in messages:
                 message.content = re.sub(r'<\|.*?\|>', '', message.content)
 
-            return messages, ['\nHuman:', '</histories>']
-
-    @classmethod
-    def get_llm_callbacks(cls, llm: BaseLanguageModel,
-                          streaming: bool,
-                          conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
-        llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
-        if streaming:
-            return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
-        else:
-            return [llm_callback_handler, DifyStdOutCallbackHandler()]
+            return to_prompt_messages(messages), ['\nHuman:', '</histories>']
 
     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
@@ -300,15 +275,15 @@ And answer according to the language of the user's question.
                                      conversation: Conversation,
                                      **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
         # only for calc token in memory
-        memory_llm = LLMBuilder.to_llm_from_model(
+        memory_model_instance = ModelFactory.get_text_generation_model_from_model_config(
             tenant_id=tenant_id,
-            model=app_model_config.model_dict
+            model_config=app_model_config.model_dict
         )
 
         # use llm config from conversation
         memory = ReadOnlyConversationTokenDBBufferSharedMemory(
             conversation=conversation,
-            llm=memory_llm,
+            model_instance=memory_model_instance,
             max_token_limit=kwargs.get("max_token_limit", 2048),
             memory_key=kwargs.get("memory_key", "chat_history"),
             return_messages=kwargs.get("return_messages", True),
@@ -320,21 +295,20 @@ And answer according to the language of the user's question.
         return memory
 
     @classmethod
-    def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
+    def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
                                  query: str, inputs: dict) -> int:
-        llm = LLMBuilder.to_llm_from_model(
-            tenant_id=tenant_id,
-            model=app_model_config.model_dict
-        )
+        model_limited_tokens = model_instance.model_rules.max_tokens.max
+        max_tokens = model_instance.get_model_kwargs().max_tokens
 
-        model_name = app_model_config.model_dict.get("name")
-        model_limited_tokens = llm_constant.max_context_token_length[model_name]
-        max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')
+        if model_limited_tokens is None:
+            return -1
+
+        if max_tokens is None:
+            max_tokens = 0
 
         # get prompt without memory and context
-        prompt, _ = cls.get_main_llm_prompt(
+        prompt_messages, _ = cls.get_main_llm_prompt(
             mode=mode,
-            llm=llm,
             model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
             query=query,
@@ -343,9 +317,7 @@ And answer according to the language of the user's question.
             memory=None
         )
 
-        prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
-            else llm.get_num_tokens_from_messages(prompt)
-
+        prompt_tokens = model_instance.get_num_tokens(prompt_messages)
         rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
         if rest_tokens < 0:
             raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
@@ -354,36 +326,40 @@ And answer according to the language of the user's question.
         return rest_tokens
 
     @classmethod
-    def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
-                              prompt: Union[str, List[BaseMessage]], mode: str):
+    def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]):
         # recalc max_tokens if sum(prompt_token +  max_tokens) over model token limit
-        model_name = model.get("name")
-        model_limited_tokens = llm_constant.max_context_token_length[model_name]
-        max_tokens = model.get("completion_params").get('max_tokens')
+        model_limited_tokens = model_instance.model_rules.max_tokens.max
+        max_tokens = model_instance.get_model_kwargs().max_tokens
 
-        if mode == 'completion' and isinstance(final_llm, BaseLLM):
-            prompt_tokens = final_llm.get_num_tokens(prompt)
-        else:
-            prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)
+        if model_limited_tokens is None:
+            return
+
+        if max_tokens is None:
+            max_tokens = 0
+
+        prompt_tokens = model_instance.get_num_tokens(prompt_messages)
 
         if prompt_tokens + max_tokens > model_limited_tokens:
             max_tokens = max(model_limited_tokens - prompt_tokens, 16)
-            final_llm.max_tokens = max_tokens
+
+            # update model instance max tokens
+            model_kwargs = model_instance.get_model_kwargs()
+            model_kwargs.max_tokens = max_tokens
+            model_instance.set_model_kwargs(model_kwargs)
 
     @classmethod
     def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
                                 app_model_config: AppModelConfig, user: Account, streaming: bool):
 
-        llm = LLMBuilder.to_llm_from_model(
+        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
             tenant_id=app.tenant_id,
-            model=app_model_config.model_dict,
+            model_config=app_model_config.model_dict,
             streaming=streaming
         )
 
         # get llm prompt
-        original_prompt, _ = cls.get_main_llm_prompt(
+        old_prompt_messages, _ = cls.get_main_llm_prompt(
             mode="completion",
-            llm=llm,
             model=app_model_config.model_dict,
             pre_prompt=pre_prompt,
             query=message.query,
@@ -395,10 +371,9 @@ And answer according to the language of the user's question.
         original_completion = message.answer.strip()
 
         prompt = MORE_LIKE_THIS_GENERATE_PROMPT
-        prompt = prompt.format(prompt=original_prompt, original_completion=original_completion)
+        prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)
 
-        if isinstance(llm, BaseChatModel):
-            prompt = [HumanMessage(content=prompt)]
+        prompt_messages = [PromptMessage(content=prompt)]
 
         conversation_message_task = ConversationMessageTask(
             task_id=task_id,
@@ -408,16 +383,16 @@ And answer according to the language of the user's question.
             inputs=message.inputs,
             query=message.query,
             is_override=True if message.override_model_configs else False,
-            streaming=streaming
+            streaming=streaming,
+            model_instance=final_model_instance
         )
 
-        llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)
-
         cls.recale_llm_max_tokens(
-            final_llm=llm,
-            model=app_model_config.model_dict,
-            prompt=prompt,
-            mode='completion'
+            model_instance=final_model_instance,
+            prompt_messages=prompt_messages
         )
 
-        llm.generate([prompt])
+        final_model_instance.run(
+            messages=prompt_messages,
+            callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)]
+        )

+ 0 - 109
api/core/constant/llm_constant.py

@@ -1,109 +0,0 @@
-from _decimal import Decimal
-
-models = {
-    'claude-instant-1': 'anthropic',  # 100,000 tokens
-    'claude-2': 'anthropic',  # 100,000 tokens
-    'gpt-4': 'openai',  # 8,192 tokens
-    'gpt-4-32k': 'openai',  # 32,768 tokens
-    'gpt-3.5-turbo': 'openai',  # 4,096 tokens
-    'gpt-3.5-turbo-16k': 'openai',  # 16384 tokens
-    'text-davinci-003': 'openai',  # 4,097 tokens
-    'text-davinci-002': 'openai',  # 4,097 tokens
-    'text-curie-001': 'openai',  # 2,049 tokens
-    'text-babbage-001': 'openai',  # 2,049 tokens
-    'text-ada-001': 'openai',  # 2,049 tokens
-    'text-embedding-ada-002': 'openai',  # 8191 tokens, 1536 dimensions
-    'whisper-1': 'openai'
-}
-
-max_context_token_length = {
-    'claude-instant-1': 100000,
-    'claude-2': 100000,
-    'gpt-4': 8192,
-    'gpt-4-32k': 32768,
-    'gpt-3.5-turbo': 4096,
-    'gpt-3.5-turbo-16k': 16384,
-    'text-davinci-003': 4097,
-    'text-davinci-002': 4097,
-    'text-curie-001': 2049,
-    'text-babbage-001': 2049,
-    'text-ada-001': 2049,
-    'text-embedding-ada-002': 8191,
-}
-
-models_by_mode = {
-    'chat': [
-        'claude-instant-1',  # 100,000 tokens
-        'claude-2',  # 100,000 tokens
-        'gpt-4',  # 8,192 tokens
-        'gpt-4-32k',  # 32,768 tokens
-        'gpt-3.5-turbo',  # 4,096 tokens
-        'gpt-3.5-turbo-16k',  # 16,384 tokens
-    ],
-    'completion': [
-        'claude-instant-1',  # 100,000 tokens
-        'claude-2',  # 100,000 tokens
-        'gpt-4',  # 8,192 tokens
-        'gpt-4-32k',  # 32,768 tokens
-        'gpt-3.5-turbo',  # 4,096 tokens
-        'gpt-3.5-turbo-16k',  # 16,384 tokens
-        'text-davinci-003',  # 4,097 tokens
-        'text-davinci-002'  # 4,097 tokens
-        'text-curie-001',  # 2,049 tokens
-        'text-babbage-001',  # 2,049 tokens
-        'text-ada-001'  # 2,049 tokens
-    ],
-    'embedding': [
-        'text-embedding-ada-002'  # 8191 tokens, 1536 dimensions
-    ]
-}
-
-model_currency = 'USD'
-
-model_prices = {
-    'claude-instant-1': {
-        'prompt': Decimal('0.00163'),
-        'completion': Decimal('0.00551'),
-    },
-    'claude-2': {
-        'prompt': Decimal('0.01102'),
-        'completion': Decimal('0.03268'),
-    },
-    'gpt-4': {
-        'prompt': Decimal('0.03'),
-        'completion': Decimal('0.06'),
-    },
-    'gpt-4-32k': {
-        'prompt': Decimal('0.06'),
-        'completion': Decimal('0.12')
-    },
-    'gpt-3.5-turbo': {
-        'prompt': Decimal('0.0015'),
-        'completion': Decimal('0.002')
-    },
-    'gpt-3.5-turbo-16k': {
-        'prompt': Decimal('0.003'),
-        'completion': Decimal('0.004')
-    },
-    'text-davinci-003': {
-        'prompt': Decimal('0.02'),
-        'completion': Decimal('0.02')
-    },
-    'text-curie-001': {
-        'prompt': Decimal('0.002'),
-        'completion': Decimal('0.002')
-    },
-    'text-babbage-001': {
-        'prompt': Decimal('0.0005'),
-        'completion': Decimal('0.0005')
-    },
-    'text-ada-001': {
-        'prompt': Decimal('0.0004'),
-        'completion': Decimal('0.0004')
-    },
-    'text-embedding-ada-002': {
-        'usage': Decimal('0.0001'),
-    }
-}
-
-agent_model_name = 'text-davinci-003'

+ 22 - 37
api/core/conversation_message_task.py

@@ -6,9 +6,9 @@ from core.callback_handler.entity.agent_loop import AgentLoop
 from core.callback_handler.entity.dataset_query import DatasetQueryObj
 from core.callback_handler.entity.llm_message import LLMMessage
 from core.callback_handler.entity.chain_result import ChainResult
-from core.constant import llm_constant
-from core.llm.llm_builder import LLMBuilder
-from core.llm.provider.llm_provider_service import LLMProviderService
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.entity.message import to_prompt_messages, MessageType
+from core.model_providers.models.llm.base import BaseLLM
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_template import JinjaPromptTemplate
 from events.message_event import message_was_created
@@ -16,12 +16,11 @@ from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DatasetQuery
 from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain
-from models.provider import ProviderType, Provider
 
 
 class ConversationMessageTask:
     def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
-                 inputs: dict, query: str, streaming: bool,
+                 inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
                  conversation: Optional[Conversation] = None, is_override: bool = False):
         self.task_id = task_id
 
@@ -38,9 +37,12 @@ class ConversationMessageTask:
         self.conversation = conversation
         self.is_new_conversation = False
 
+        self.model_instance = model_instance
+
         self.message = None
 
         self.model_dict = self.app_model_config.model_dict
+        self.provider_name = self.model_dict.get('provider')
         self.model_name = self.model_dict.get('name')
         self.mode = app.mode
 
@@ -56,9 +58,6 @@ class ConversationMessageTask:
         )
 
     def init(self):
-        provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
-        self.model_dict['provider'] = provider_name
-
         override_model_configs = None
         if self.is_override:
             override_model_configs = {
@@ -89,15 +88,19 @@ class ConversationMessageTask:
             if self.app_model_config.pre_prompt:
                 system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
                 system_instruction = system_message.content
-                llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
-                system_instruction_tokens = llm.get_num_tokens_from_messages([system_message])
+                model_instance = ModelFactory.get_text_generation_model(
+                    tenant_id=self.tenant_id,
+                    model_provider_name=self.provider_name,
+                    model_name=self.model_name
+                )
+                system_instruction_tokens = model_instance.get_num_tokens(to_prompt_messages([system_message]))
 
         if not self.conversation:
             self.is_new_conversation = True
             self.conversation = Conversation(
                 app_id=self.app_model_config.app_id,
                 app_model_config_id=self.app_model_config.id,
-                model_provider=self.model_dict.get('provider'),
+                model_provider=self.provider_name,
                 model_id=self.model_name,
                 override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
                 mode=self.mode,
@@ -117,7 +120,7 @@ class ConversationMessageTask:
 
         self.message = Message(
             app_id=self.app_model_config.app_id,
-            model_provider=self.model_dict.get('provider'),
+            model_provider=self.provider_name,
             model_id=self.model_name,
             override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
             conversation_id=self.conversation.id,
@@ -131,7 +134,7 @@ class ConversationMessageTask:
             answer_unit_price=0,
             provider_response_latency=0,
             total_price=0,
-            currency=llm_constant.model_currency,
+            currency=self.model_instance.get_currency(),
             from_source=('console' if isinstance(self.user, Account) else 'api'),
             from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
             from_account_id=(self.user.id if isinstance(self.user, Account) else None),
@@ -145,12 +148,10 @@ class ConversationMessageTask:
         self._pub_handler.pub_text(text)
 
     def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
-        model_name = self.app_model_config.model_dict.get('name')
-
         message_tokens = llm_message.prompt_tokens
         answer_tokens = llm_message.completion_tokens
-        message_unit_price = llm_constant.model_prices[model_name]['prompt']
-        answer_unit_price = llm_constant.model_prices[model_name]['completion']
+        message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
+        answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)
 
         total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
 
@@ -163,8 +164,6 @@ class ConversationMessageTask:
         self.message.provider_response_latency = llm_message.latency
         self.message.total_price = total_price
 
-        self.update_provider_quota()
-
         db.session.commit()
 
         message_was_created.send(
@@ -176,20 +175,6 @@ class ConversationMessageTask:
         if not by_stopped:
             self.end()
 
-    def update_provider_quota(self):
-        llm_provider_service = LLMProviderService(
-            tenant_id=self.app.tenant_id,
-            provider_name=self.message.model_provider,
-        )
-
-        provider = llm_provider_service.get_provider_db_record()
-        if provider and provider.provider_type == ProviderType.SYSTEM.value:
-            db.session.query(Provider).filter(
-                Provider.tenant_id == self.app.tenant_id,
-                Provider.provider_name == provider.provider_name,
-                Provider.quota_limit > Provider.quota_used
-            ).update({'quota_used': Provider.quota_used + 1})
-
     def init_chain(self, chain_result: ChainResult):
         message_chain = MessageChain(
             message_id=self.message.id,
@@ -229,10 +214,10 @@ class ConversationMessageTask:
 
         return message_agent_thought
 
-    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_name: str,
+    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
                      agent_loop: AgentLoop):
-        agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt']
-        agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion']
+        agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN)
+        agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT)
 
         loop_message_tokens = agent_loop.prompt_tokens
         loop_answer_tokens = agent_loop.completion_tokens
@@ -253,7 +238,7 @@ class ConversationMessageTask:
         message_agent_thought.latency = agent_loop.latency
         message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
         message_agent_thought.total_price = loop_total_price
-        message_agent_thought.currency = llm_constant.model_currency
+        message_agent_thought.currency = agent_model_instant.get_currency()
         db.session.flush()
 
     def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):

+ 6 - 8
api/core/docstore/dataset_docstore.py

@@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Sequence
 from langchain.schema import Document
 from sqlalchemy import func
 
-from core.llm.token_calculator import TokenCalculator
+from core.model_providers.model_factory import ModelFactory
 from extensions.ext_database import db
 from models.dataset import Dataset, DocumentSegment
 
@@ -13,12 +13,10 @@ class DatesetDocumentStore:
         self,
         dataset: Dataset,
         user_id: str,
-        embedding_model_name: str,
         document_id: Optional[str] = None,
     ):
         self._dataset = dataset
         self._user_id = user_id
-        self._embedding_model_name = embedding_model_name
         self._document_id = document_id
 
     @classmethod
@@ -39,10 +37,6 @@ class DatesetDocumentStore:
     def user_id(self) -> Any:
         return self._user_id
 
-    @property
-    def embedding_model_name(self) -> Any:
-        return self._embedding_model_name
-
     @property
     def docs(self) -> Dict[str, Document]:
         document_segments = db.session.query(DocumentSegment).filter(
@@ -74,6 +68,10 @@ class DatesetDocumentStore:
         if max_position is None:
             max_position = 0
 
+        embedding_model = ModelFactory.get_embedding_model(
+            tenant_id=self._dataset.tenant_id
+        )
+
         for doc in docs:
             if not isinstance(doc, Document):
                 raise ValueError("doc must be a Document")
@@ -88,7 +86,7 @@ class DatesetDocumentStore:
                 )
 
             # calc embedding use tokens
-            tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content)
+            tokens = embedding_model.get_num_tokens(doc.page_content)
 
             if not segment_document:
                 max_position += 1

+ 33 - 25
api/core/embedding/cached_embedding.py

@@ -4,14 +4,14 @@ from typing import List
 from langchain.embeddings.base import Embeddings
 from sqlalchemy.exc import IntegrityError
 
-from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
+from core.model_providers.models.embedding.base import BaseEmbedding
 from extensions.ext_database import db
 from libs import helper
 from models.dataset import Embedding
 
 
 class CacheEmbedding(Embeddings):
-    def __init__(self, embeddings: Embeddings):
+    def __init__(self, embeddings: BaseEmbedding):
         self._embeddings = embeddings
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
@@ -21,48 +21,54 @@ class CacheEmbedding(Embeddings):
         embedding_queue_texts = []
         for text in texts:
             hash = helper.generate_text_hash(text)
-            embedding = db.session.query(Embedding).filter_by(hash=hash).first()
+            embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
             if embedding:
                 text_embeddings.append(embedding.get_embedding())
             else:
                 embedding_queue_texts.append(text)
 
-        embedding_results = self._embeddings.embed_documents(embedding_queue_texts)
+        if embedding_queue_texts:
+            try:
+                embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
+            except Exception as ex:
+                raise self._embeddings.handle_exceptions(ex)
 
-        i = 0
-        for text in embedding_queue_texts:
-            hash = helper.generate_text_hash(text)
+            i = 0
+            for text in embedding_queue_texts:
+                hash = helper.generate_text_hash(text)
 
-            try:
-                embedding = Embedding(hash=hash)
-                embedding.set_embedding(embedding_results[i])
-                db.session.add(embedding)
-                db.session.commit()
-            except IntegrityError:
-                db.session.rollback()
-                continue
-            except:
-                logging.exception('Failed to add embedding to db')
-                continue
-            finally:
-                i += 1
+                try:
+                    embedding = Embedding(model_name=self._embeddings.name, hash=hash)
+                    embedding.set_embedding(embedding_results[i])
+                    db.session.add(embedding)
+                    db.session.commit()
+                except IntegrityError:
+                    db.session.rollback()
+                    continue
+                except:
+                    logging.exception('Failed to add embedding to db')
+                    continue
+                finally:
+                    i += 1
 
-        text_embeddings.extend(embedding_results)
+            text_embeddings.extend(embedding_results)
         return text_embeddings
 
-    @handle_openai_exceptions
     def embed_query(self, text: str) -> List[float]:
         """Embed query text."""
         # use doc embedding cache or store if not exists
         hash = helper.generate_text_hash(text)
-        embedding = db.session.query(Embedding).filter_by(hash=hash).first()
+        embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
         if embedding:
             return embedding.get_embedding()
 
-        embedding_results = self._embeddings.embed_query(text)
+        try:
+            embedding_results = self._embeddings.client.embed_query(text)
+        except Exception as ex:
+            raise self._embeddings.handle_exceptions(ex)
 
         try:
-            embedding = Embedding(hash=hash)
+            embedding = Embedding(model_name=self._embeddings.name, hash=hash)
             embedding.set_embedding(embedding_results)
             db.session.add(embedding)
             db.session.commit()
@@ -72,3 +78,5 @@ class CacheEmbedding(Embeddings):
             logging.exception('Failed to add embedding to db')
 
         return embedding_results
+
+

+ 61 - 83
api/core/generator/llm_generator.py

@@ -1,13 +1,10 @@
 import logging
 
-from langchain import PromptTemplate
-from langchain.chat_models.base import BaseChatModel
-from langchain.schema import HumanMessage, OutputParserException, BaseMessage, SystemMessage
-
-from core.constant import llm_constant
-from core.llm.llm_builder import LLMBuilder
-from core.llm.streamable_open_ai import StreamableOpenAI
-from core.llm.token_calculator import TokenCalculator
+from langchain.schema import OutputParserException
+
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelKwargs
 from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
 
 from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
@@ -15,9 +12,6 @@ from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTempla
 from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
     GENERATOR_QA_PROMPT
 
-# gpt-3.5-turbo works not well
-generate_base_model = 'text-davinci-003'
-
 
 class LLMGenerator:
     @classmethod
@@ -28,29 +22,35 @@ class LLMGenerator:
             query = query[:300] + "...[TRUNCATED]..." + query[-300:]
 
         prompt = prompt.format(query=query)
-        llm: StreamableOpenAI = LLMBuilder.to_llm(
+
+        model_instance = ModelFactory.get_text_generation_model(
             tenant_id=tenant_id,
-            model_name='gpt-3.5-turbo',
-            max_tokens=50,
-            timeout=600
+            model_kwargs=ModelKwargs(
+                max_tokens=50
+            )
         )
 
-        if isinstance(llm, BaseChatModel):
-            prompt = [HumanMessage(content=prompt)]
-
-        response = llm.generate([prompt])
-        answer = response.generations[0][0].text
+        prompts = [PromptMessage(content=prompt)]
+        response = model_instance.run(prompts)
+        answer = response.content
         return answer.strip()
 
     @classmethod
     def generate_conversation_summary(cls, tenant_id: str, messages):
         max_tokens = 200
-        model = 'gpt-3.5-turbo'
+
+        model_instance = ModelFactory.get_text_generation_model(
+            tenant_id=tenant_id,
+            model_kwargs=ModelKwargs(
+                max_tokens=max_tokens
+            )
+        )
 
         prompt = CONVERSATION_SUMMARY_PROMPT
         prompt_with_empty_context = prompt.format(context='')
-        prompt_tokens = TokenCalculator.get_num_tokens(model, prompt_with_empty_context)
-        rest_tokens = llm_constant.max_context_token_length[model] - prompt_tokens - max_tokens - 1
+        prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
+        max_context_token_length = model_instance.model_rules.max_tokens.max
+        rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1
 
         context = ''
         for message in messages:
@@ -68,25 +68,16 @@ class LLMGenerator:
                 answer = message.answer
 
             message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
-            if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
+            if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
                 context += message_qa_text
 
         if not context:
             return '[message too long, no summary]'
 
         prompt = prompt.format(context=context)
-
-        llm: StreamableOpenAI = LLMBuilder.to_llm(
-            tenant_id=tenant_id,
-            model_name=model,
-            max_tokens=max_tokens
-        )
-
-        if isinstance(llm, BaseChatModel):
-            prompt = [HumanMessage(content=prompt)]
-
-        response = llm.generate([prompt])
-        answer = response.generations[0][0].text
+        prompts = [PromptMessage(content=prompt)]
+        response = model_instance.run(prompts)
+        answer = response.content
         return answer.strip()
 
     @classmethod
@@ -94,16 +85,13 @@ class LLMGenerator:
         prompt = INTRODUCTION_GENERATE_PROMPT
         prompt = prompt.format(prompt=pre_prompt)
 
-        llm: StreamableOpenAI = LLMBuilder.to_llm(
-            tenant_id=tenant_id,
-            model_name=generate_base_model,
+        model_instance = ModelFactory.get_text_generation_model(
+            tenant_id=tenant_id
         )
 
-        if isinstance(llm, BaseChatModel):
-            prompt = [HumanMessage(content=prompt)]
-
-        response = llm.generate([prompt])
-        answer = response.generations[0][0].text
+        prompts = [PromptMessage(content=prompt)]
+        response = model_instance.run(prompts)
+        answer = response.content
         return answer.strip()
 
     @classmethod
@@ -119,23 +107,19 @@ class LLMGenerator:
 
         _input = prompt.format_prompt(histories=histories)
 
-        llm: StreamableOpenAI = LLMBuilder.to_llm(
+        model_instance = ModelFactory.get_text_generation_model(
             tenant_id=tenant_id,
-            model_name='gpt-3.5-turbo',
-            temperature=0,
-            max_tokens=256
+            model_kwargs=ModelKwargs(
+                max_tokens=256,
+                temperature=0
+            )
         )
 
-        if isinstance(llm, BaseChatModel):
-            query = [HumanMessage(content=_input.to_string())]
-        else:
-            query = _input.to_string()
+        prompts = [PromptMessage(content=_input.to_string())]
 
         try:
-            output = llm(query)
-            if isinstance(output, BaseMessage):
-                output = output.content
-            questions = output_parser.parse(output)
+            output = model_instance.run(prompts)
+            questions = output_parser.parse(output.content)
         except Exception:
             logging.exception("Error generating suggested questions after answer")
             questions = []
@@ -160,21 +144,19 @@ class LLMGenerator:
 
         _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)
 
-        llm: StreamableOpenAI = LLMBuilder.to_llm(
+        model_instance = ModelFactory.get_text_generation_model(
             tenant_id=tenant_id,
-            model_name=generate_base_model,
-            temperature=0,
-            max_tokens=512
+            model_kwargs=ModelKwargs(
+                max_tokens=512,
+                temperature=0
+            )
         )
 
-        if isinstance(llm, BaseChatModel):
-            query = [HumanMessage(content=_input.to_string())]
-        else:
-            query = _input.to_string()
+        prompts = [PromptMessage(content=_input.to_string())]
 
         try:
-            output = llm(query)
-            rule_config = output_parser.parse(output)
+            output = model_instance.run(prompts)
+            rule_config = output_parser.parse(output.content)
         except OutputParserException:
             raise ValueError('Please give a valid input for intended audience or hoping to solve problems.')
         except Exception:
@@ -188,25 +170,21 @@ class LLMGenerator:
         return rule_config
 
     @classmethod
-    async def generate_qa_document(cls, llm: StreamableOpenAI, query):
+    def generate_qa_document(cls, tenant_id: str, query):
         prompt = GENERATOR_QA_PROMPT
 
+        model_instance = ModelFactory.get_text_generation_model(
+            tenant_id=tenant_id,
+            model_kwargs=ModelKwargs(
+                max_tokens=2000
+            )
+        )
 
-        if isinstance(llm, BaseChatModel):
-            prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
-
-        response = llm.generate([prompt])
-        answer = response.generations[0][0].text
-        return answer.strip()
-
-    @classmethod
-    def generate_qa_document_sync(cls, llm: StreamableOpenAI, query):
-        prompt = GENERATOR_QA_PROMPT
-
-
-        if isinstance(llm, BaseChatModel):
-            prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
+        prompts = [
+            PromptMessage(content=prompt, type=MessageType.SYSTEM),
+            PromptMessage(content=query)
+        ]
 
-        response = llm.generate([prompt])
-        answer = response.generations[0][0].text
+        response = model_instance.run(prompts)
+        answer = response.content
         return answer.strip()

+ 0 - 0
api/tests/test_helpers/__init__.py → api/core/helper/__init__.py


+ 20 - 0
api/core/helper/encrypter.py

@@ -0,0 +1,20 @@
+import base64
+
+from extensions.ext_database import db
+from libs import rsa
+
+from models.account import Tenant
+
+
+def obfuscated_token(token: str):
+    return token[:6] + '*' * (len(token) - 8) + token[-2:]
+
+
+def encrypt_token(tenant_id: str, token: str):
+    tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first()
+    encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
+    return base64.b64encode(encrypted_token).decode()
+
+
+def decrypt_token(tenant_id: str, token: str):
+    return rsa.decrypt(base64.b64decode(token), tenant_id)

+ 4 - 10
api/core/index/index.py

@@ -1,10 +1,9 @@
 from flask import current_app
-from langchain.embeddings import OpenAIEmbeddings
 
 from core.embedding.cached_embedding import CacheEmbedding
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
 from core.index.vector_index.vector_index import VectorIndex
-from core.llm.llm_builder import LLMBuilder
+from core.model_providers.model_factory import ModelFactory
 from models.dataset import Dataset
 
 
@@ -15,16 +14,11 @@ class IndexBuilder:
             if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
                 return None
 
-            model_credentials = LLMBuilder.get_model_credentials(
-                tenant_id=dataset.tenant_id,
-                model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
-                model_name='text-embedding-ada-002'
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=dataset.tenant_id
             )
 
-            embeddings = CacheEmbedding(OpenAIEmbeddings(
-                max_retries=1,
-                **model_credentials
-            ))
+            embeddings = CacheEmbedding(embedding_model)
 
             return VectorIndex(
                 dataset=dataset,

+ 46 - 43
api/core/indexing_runner.py

@@ -1,4 +1,3 @@
-import concurrent
 import datetime
 import json
 import logging
@@ -6,7 +5,6 @@ import re
 import threading
 import time
 import uuid
-from concurrent.futures import ThreadPoolExecutor
 from typing import Optional, List, cast
 
 from flask_login import current_user
@@ -18,11 +16,10 @@ from core.data_loader.loader.notion import NotionLoader
 from core.docstore.dataset_docstore import DatesetDocumentStore
 from core.generator.llm_generator import LLMGenerator
 from core.index.index import IndexBuilder
-from core.llm.error import ProviderTokenNotInitError
-from core.llm.llm_builder import LLMBuilder
-from core.llm.streamable_open_ai import StreamableOpenAI
+from core.model_providers.error import ProviderTokenNotInitError
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.entity.message import MessageType
 from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
-from core.llm.token_calculator import TokenCalculator
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
@@ -35,9 +32,8 @@ from models.source import DataSourceBinding
 
 class IndexingRunner:
 
-    def __init__(self, embedding_model_name: str = "text-embedding-ada-002"):
+    def __init__(self):
         self.storage = storage
-        self.embedding_model_name = embedding_model_name
 
     def run(self, dataset_documents: List[DatasetDocument]):
         """Run the indexing process."""
@@ -227,11 +223,15 @@ class IndexingRunner:
             dataset_document.stopped_at = datetime.datetime.utcnow()
             db.session.commit()
 
-    def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict,
+    def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
                                doc_form: str = None) -> dict:
         """
         Estimate the indexing for the document.
         """
+        embedding_model = ModelFactory.get_embedding_model(
+            tenant_id=tenant_id
+        )
+
         tokens = 0
         preview_texts = []
         total_segments = 0
@@ -253,44 +253,49 @@ class IndexingRunner:
                 splitter=splitter,
                 processing_rule=processing_rule
             )
+
             total_segments += len(documents)
+
             for document in documents:
                 if len(preview_texts) < 5:
                     preview_texts.append(document.page_content)
 
-                tokens += TokenCalculator.get_num_tokens(self.embedding_model_name,
-                                                         self.filter_string(document.page_content))
+                tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
+
+        text_generation_model = ModelFactory.get_text_generation_model(
+            tenant_id=tenant_id
+        )
+
         if doc_form and doc_form == 'qa_model':
             if len(preview_texts) > 0:
                 # qa model document
-                llm: StreamableOpenAI = LLMBuilder.to_llm(
-                    tenant_id=current_user.current_tenant_id,
-                    model_name='gpt-3.5-turbo',
-                    max_tokens=2000
-                )
-                response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
+                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
                 document_qa_list = self.format_split_text(response)
                 return {
                     "total_segments": total_segments * 20,
                     "tokens": total_segments * 2000,
                     "total_price": '{:f}'.format(
-                        TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
-                    "currency": TokenCalculator.get_currency(self.embedding_model_name),
+                        text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
+                    "currency": embedding_model.get_currency(),
                     "qa_preview": document_qa_list,
                     "preview": preview_texts
                 }
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
-            "currency": TokenCalculator.get_currency(self.embedding_model_name),
+            "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
+            "currency": embedding_model.get_currency(),
             "preview": preview_texts
         }
 
-    def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
+    def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
         """
         Estimate the indexing for the document.
         """
+        embedding_model = ModelFactory.get_embedding_model(
+            tenant_id=tenant_id
+        )
+
         # load data from notion
         tokens = 0
         preview_texts = []
@@ -336,31 +341,31 @@ class IndexingRunner:
                     if len(preview_texts) < 5:
                         preview_texts.append(document.page_content)
 
-                    tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
+                    tokens += embedding_model.get_num_tokens(document.page_content)
+
+        text_generation_model = ModelFactory.get_text_generation_model(
+            tenant_id=tenant_id
+        )
+
         if doc_form and doc_form == 'qa_model':
             if len(preview_texts) > 0:
                 # qa model document
-                llm: StreamableOpenAI = LLMBuilder.to_llm(
-                    tenant_id=current_user.current_tenant_id,
-                    model_name='gpt-3.5-turbo',
-                    max_tokens=2000
-                )
-                response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
+                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
                 document_qa_list = self.format_split_text(response)
                 return {
                     "total_segments": total_segments * 20,
                     "tokens": total_segments * 2000,
                     "total_price": '{:f}'.format(
-                        TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
-                    "currency": TokenCalculator.get_currency(self.embedding_model_name),
+                        text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
+                    "currency": embedding_model.get_currency(),
                     "qa_preview": document_qa_list,
                     "preview": preview_texts
                 }
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
-            "currency": TokenCalculator.get_currency(self.embedding_model_name),
+            "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
+            "currency": embedding_model.get_currency(),
             "preview": preview_texts
         }
 
@@ -459,7 +464,6 @@ class IndexingRunner:
         doc_store = DatesetDocumentStore(
             dataset=dataset,
             user_id=dataset_document.created_by,
-            embedding_model_name=self.embedding_model_name,
             document_id=dataset_document.id
         )
 
@@ -513,17 +517,12 @@ class IndexingRunner:
             all_documents.extend(split_documents)
         # processing qa document
         if document_form == 'qa_model':
-            llm: StreamableOpenAI = LLMBuilder.to_llm(
-                tenant_id=tenant_id,
-                model_name='gpt-3.5-turbo',
-                max_tokens=2000
-            )
             for i in range(0, len(all_documents), 10):
                 threads = []
                 sub_documents = all_documents[i:i + 10]
                 for doc in sub_documents:
                     document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
-                        'llm': llm, 'document_node': doc, 'all_qa_documents': all_qa_documents})
+                        'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents})
                     threads.append(document_format_thread)
                     document_format_thread.start()
                 for thread in threads:
@@ -531,13 +530,13 @@ class IndexingRunner:
             return all_qa_documents
         return all_documents
 
-    def format_qa_document(self, llm: StreamableOpenAI, document_node, all_qa_documents):
+    def format_qa_document(self, tenant_id: str, document_node, all_qa_documents):
         format_documents = []
         if document_node.page_content is None or not document_node.page_content.strip():
             return
         try:
             # qa model document
-            response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content)
+            response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content)
             document_qa_list = self.format_split_text(response)
             qa_documents = []
             for result in document_qa_list:
@@ -638,6 +637,10 @@ class IndexingRunner:
         vector_index = IndexBuilder.get_index(dataset, 'high_quality')
         keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
 
+        embedding_model = ModelFactory.get_embedding_model(
+            tenant_id=dataset.tenant_id
+        )
+
         # chunk nodes by chunk size
         indexing_start_at = time.perf_counter()
         tokens = 0
@@ -648,7 +651,7 @@ class IndexingRunner:
             chunk_documents = documents[i:i + chunk_size]
 
             tokens += sum(
-                TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
+                embedding_model.get_num_tokens(document.page_content)
                 for document in chunk_documents
             )
 

+ 0 - 148
api/core/llm/llm_builder.py

@@ -1,148 +0,0 @@
-from typing import Union, Optional, List
-
-from langchain.callbacks.base import BaseCallbackHandler
-
-from core.constant import llm_constant
-from core.llm.error import ProviderTokenNotInitError
-from core.llm.provider.base import BaseProvider
-from core.llm.provider.llm_provider_service import LLMProviderService
-from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
-from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
-from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
-from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
-from core.llm.streamable_open_ai import StreamableOpenAI
-from models.provider import ProviderType, ProviderName
-
-
-class LLMBuilder:
-    """
-    This class handles the following logic:
-    1. For providers with the name 'OpenAI', the OPENAI_API_KEY value is stored directly in encrypted_config.
-    2. For providers with the name 'Azure OpenAI', encrypted_config stores the serialized values of four fields, as shown below:
-       OPENAI_API_TYPE=azure
-       OPENAI_API_VERSION=2022-12-01
-       OPENAI_API_BASE=https://your-resource-name.openai.azure.com
-       OPENAI_API_KEY=<your Azure OpenAI API key>
-    3. For providers with the name 'Anthropic', the ANTHROPIC_API_KEY value is stored directly in encrypted_config.
-    4. For providers with the name 'Cohere', the COHERE_API_KEY value is stored directly in encrypted_config.
-    5. For providers with the name 'HUGGINGFACEHUB', the HUGGINGFACEHUB_API_KEY value is stored directly in encrypted_config.
-    6. Providers with the provider_type 'CUSTOM' can be created through the admin interface, while 'System' providers cannot be created through the admin interface.
-    7. If both CUSTOM and System providers exist in the records, the CUSTOM provider is preferred by default, but this preference can be changed via an input parameter.
-    8. For providers with the provider_type 'System', the quota_used must not exceed quota_limit. If the quota is exceeded, the provider cannot be used. Currently, only the TRIAL quota_type is supported, which is permanently non-resetting.
-    """
-
-    @classmethod
-    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
-        provider = cls.get_default_provider(tenant_id, model_name)
-
-        model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
-
-        llm_cls = None
-        mode = cls.get_mode_by_model(model_name)
-        if mode == 'chat':
-            if provider == ProviderName.OPENAI.value:
-                llm_cls = StreamableChatOpenAI
-            elif provider == ProviderName.AZURE_OPENAI.value:
-                llm_cls = StreamableAzureChatOpenAI
-            elif provider == ProviderName.ANTHROPIC.value:
-                llm_cls = StreamableChatAnthropic
-        elif mode == 'completion':
-            if provider == ProviderName.OPENAI.value:
-                llm_cls = StreamableOpenAI
-            elif provider == ProviderName.AZURE_OPENAI.value:
-                llm_cls = StreamableAzureOpenAI
-
-        if not llm_cls:
-            raise ValueError(f"model name {model_name} is not supported.")
-
-        model_kwargs = {
-            'model_name': model_name,
-            'temperature': kwargs.get('temperature', 0),
-            'max_tokens': kwargs.get('max_tokens', 256),
-            'top_p': kwargs.get('top_p', 1),
-            'frequency_penalty': kwargs.get('frequency_penalty', 0),
-            'presence_penalty': kwargs.get('presence_penalty', 0),
-            'callbacks': kwargs.get('callbacks', None),
-            'streaming': kwargs.get('streaming', False),
-        }
-
-        model_kwargs.update(model_credentials)
-        model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)
-
-        return llm_cls(**model_kwargs)
-
-    @classmethod
-    def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
-                          callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
-        model_name = model.get("name")
-        completion_params = model.get("completion_params", {})
-
-        return cls.to_llm(
-            tenant_id=tenant_id,
-            model_name=model_name,
-            temperature=completion_params.get('temperature', 0),
-            max_tokens=completion_params.get('max_tokens', 256),
-            top_p=completion_params.get('top_p', 0),
-            frequency_penalty=completion_params.get('frequency_penalty', 0.1),
-            presence_penalty=completion_params.get('presence_penalty', 0.1),
-            streaming=streaming,
-            callbacks=callbacks
-        )
-
-    @classmethod
-    def get_mode_by_model(cls, model_name: str) -> str:
-        if not model_name:
-            raise ValueError(f"empty model name is not supported.")
-
-        if model_name in llm_constant.models_by_mode['chat']:
-            return "chat"
-        elif model_name in llm_constant.models_by_mode['completion']:
-            return "completion"
-        else:
-            raise ValueError(f"model name {model_name} is not supported.")
-
-    @classmethod
-    def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict:
-        """
-        Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
-        Raises an exception if the model_name is not found or if the provider is not found.
-        """
-        if not model_name:
-            raise Exception('model name not found')
-        #
-        # if model_name not in llm_constant.models:
-        #     raise Exception('model {} not found'.format(model_name))
-
-        # model_provider = llm_constant.models[model_name]
-
-        provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
-        return provider_service.get_credentials(model_name)
-
-    @classmethod
-    def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
-        provider_name = llm_constant.models[model_name]
-
-        if provider_name == 'openai':
-            # get the default provider (openai / azure_openai) for the tenant
-            openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
-            azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)
-
-            provider = None
-            if openai_provider and openai_provider.provider_type == ProviderType.CUSTOM.value:
-                provider = openai_provider
-            elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.CUSTOM.value:
-                provider = azure_openai_provider
-            elif openai_provider and openai_provider.provider_type == ProviderType.SYSTEM.value:
-                provider = openai_provider
-            elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.SYSTEM.value:
-                provider = azure_openai_provider
-
-            if not provider:
-                raise ProviderTokenNotInitError(
-                    f"No valid {provider_name} model provider credentials found. "
-                    f"Please go to Settings -> Model Provider to complete your provider credentials."
-                )
-
-            provider_name = provider.provider_name
-
-        return provider_name

+ 0 - 15
api/core/llm/moderation.py

@@ -1,15 +0,0 @@
-import openai
-from models.provider import ProviderName
-
-
-class Moderation:
-
-    def __init__(self, provider: str, api_key: str):
-        self.provider = provider
-        self.api_key = api_key
-
-        if self.provider == ProviderName.OPENAI.value:
-            self.client = openai.Moderation
-
-    def moderate(self, text):
-        return self.client.create(input=text, api_key=self.api_key)

+ 0 - 138
api/core/llm/provider/anthropic_provider.py

@@ -1,138 +0,0 @@
-import json
-import logging
-from typing import Optional, Union
-
-import anthropic
-from langchain.chat_models import ChatAnthropic
-from langchain.schema import HumanMessage
-
-from core import hosted_llm_credentials
-from core.llm.error import ProviderTokenNotInitError
-from core.llm.provider.base import BaseProvider
-from core.llm.provider.errors import ValidateFailedError
-from models.provider import ProviderName, ProviderType
-
-
-class AnthropicProvider(BaseProvider):
-    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
-        return [
-            {
-                'id': 'claude-instant-1',
-                'name': 'claude-instant-1',
-            },
-            {
-                'id': 'claude-2',
-                'name': 'claude-2',
-            },
-        ]
-
-    def get_credentials(self, model_id: Optional[str] = None) -> dict:
-        return self.get_provider_api_key(model_id=model_id)
-
-    def get_provider_name(self):
-        return ProviderName.ANTHROPIC
-
-    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
-        """
-        Returns the provider configs.
-        """
-        try:
-            config = self.get_provider_api_key(only_custom=only_custom)
-        except:
-            config = {
-                'anthropic_api_key': ''
-            }
-
-        if obfuscated:
-            if not config.get('anthropic_api_key'):
-                config = {
-                    'anthropic_api_key': ''
-                }
-
-            config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key'))
-            return config
-
-        return config
-
-    def get_encrypted_token(self, config: Union[dict | str]):
-        """
-        Returns the encrypted token.
-        """
-        return json.dumps({
-            'anthropic_api_key': self.encrypt_token(config['anthropic_api_key'])
-        })
-
-    def get_decrypted_token(self, token: str):
-        """
-        Returns the decrypted token.
-        """
-        config = json.loads(token)
-        config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key'])
-        return config
-
-    def get_token_type(self):
-        return dict
-
-    def config_validate(self, config: Union[dict | str]):
-        """
-        Validates the given config.
-        """
-        # check OpenAI / Azure OpenAI credential is valid
-        openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value)
-        azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value)
-
-        provider = None
-        if openai_provider:
-            provider = openai_provider
-        elif azure_openai_provider:
-            provider = azure_openai_provider
-
-        if not provider:
-            raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.")
-
-        if provider.provider_type == ProviderType.SYSTEM.value:
-            quota_used = provider.quota_used if provider.quota_used is not None else 0
-            quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
-            if quota_used >= quota_limit:
-                raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, "
-                                          f"please configure OpenAI or Azure OpenAI provider first.")
-
-        try:
-            if not isinstance(config, dict):
-                raise ValueError('Config must be a object.')
-
-            if 'anthropic_api_key' not in config:
-                raise ValueError('anthropic_api_key must be provided.')
-
-            chat_llm = ChatAnthropic(
-                model='claude-instant-1',
-                anthropic_api_key=config['anthropic_api_key'],
-                max_tokens_to_sample=10,
-                temperature=0,
-                default_request_timeout=60
-            )
-
-            messages = [
-                HumanMessage(
-                    content="ping"
-                )
-            ]
-
-            chat_llm(messages)
-        except anthropic.APIConnectionError as ex:
-            raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}")
-        except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
-            raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - "
-                                      f"{ex.body['error']['type']}: {ex.body['error']['message']}")
-        except Exception as ex:
-            logging.exception('Anthropic config validation failed')
-            raise ex
-
-    def get_hosted_credentials(self) -> Union[str | dict]:
-        if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key:
-            raise ProviderTokenNotInitError(
-                f"No valid {self.get_provider_name().value} model provider credentials found. "
-                f"Please go to Settings -> Model Provider to complete your provider credentials."
-            )
-
-        return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key}

+ 0 - 145
api/core/llm/provider/azure_provider.py

@@ -1,145 +0,0 @@
-import json
-import logging
-from typing import Optional, Union
-
-import openai
-import requests
-
-from core.llm.provider.base import BaseProvider
-from core.llm.provider.errors import ValidateFailedError
-from models.provider import ProviderName
-
-
-AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
-
-
-class AzureProvider(BaseProvider):
-    def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]:
-        return []
-
-    def check_embedding_model(self, credentials: Optional[dict] = None):
-        credentials = self.get_credentials('text-embedding-ada-002') if not credentials else credentials
-        try:
-            result = openai.Embedding.create(input=['test'],
-                                             engine='text-embedding-ada-002',
-                                             timeout=60,
-                                             api_key=str(credentials.get('openai_api_key')),
-                                             api_base=str(credentials.get('openai_api_base')),
-                                             api_type='azure',
-                                             api_version=str(credentials.get('openai_api_version')))["data"][0][
-                "embedding"]
-        except openai.error.AuthenticationError as e:
-            raise AzureAuthenticationError(str(e))
-        except openai.error.APIConnectionError as e:
-            raise AzureRequestFailedError(
-                'Failed to request Azure OpenAI, please check your API Base Endpoint, The format is `https://xxx.openai.azure.com/`')
-        except openai.error.InvalidRequestError as e:
-            if e.http_status == 404:
-                raise AzureRequestFailedError("Please check your 'gpt-3.5-turbo' or 'text-embedding-ada-002' "
-                                              "deployment name is exists in Azure AI")
-            else:
-                raise AzureRequestFailedError(
-                    'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
-        except openai.error.OpenAIError as e:
-            raise AzureRequestFailedError(
-                'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
-
-        if not isinstance(result, list):
-            raise AzureRequestFailedError('Failed to request Azure OpenAI.')
-
-    def get_credentials(self, model_id: Optional[str] = None) -> dict:
-        """
-        Returns the API credentials for Azure OpenAI as a dictionary.
-        """
-        config = self.get_provider_api_key(model_id=model_id)
-        config['openai_api_type'] = 'azure'
-        config['openai_api_version'] = AZURE_OPENAI_API_VERSION
-        if model_id == 'text-embedding-ada-002':
-            config['deployment'] = model_id.replace('.', '') if model_id else None
-            config['chunk_size'] = 16
-        else:
-            config['deployment_name'] = model_id.replace('.', '') if model_id else None
-        return config
-
-    def get_provider_name(self):
-        return ProviderName.AZURE_OPENAI
-
-    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
-        """
-        Returns the provider configs.
-        """
-        try:
-            config = self.get_provider_api_key(only_custom=only_custom)
-        except:
-            config = {
-                'openai_api_type': 'azure',
-                'openai_api_version': AZURE_OPENAI_API_VERSION,
-                'openai_api_base': '',
-                'openai_api_key': ''
-            }
-
-        if obfuscated:
-            if not config.get('openai_api_key'):
-                config = {
-                    'openai_api_type': 'azure',
-                    'openai_api_version': AZURE_OPENAI_API_VERSION,
-                    'openai_api_base': '',
-                    'openai_api_key': ''
-                }
-
-            config['openai_api_key'] = self.obfuscated_token(config.get('openai_api_key'))
-            return config
-
-        return config
-
-    def get_token_type(self):
-        return dict
-
-    def config_validate(self, config: Union[dict | str]):
-        """
-        Validates the given config.
-        """
-        try:
-            if not isinstance(config, dict):
-                raise ValueError('Config must be a object.')
-
-            if 'openai_api_version' not in config:
-                config['openai_api_version'] = AZURE_OPENAI_API_VERSION
-
-            self.check_embedding_model(credentials=config)
-        except ValidateFailedError as e:
-            raise e
-        except AzureAuthenticationError:
-            raise ValidateFailedError('Validation failed, please check your API Key.')
-        except AzureRequestFailedError as ex:
-            raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
-        except Exception as ex:
-            logging.exception('Azure OpenAI Credentials validation failed')
-            raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
-
-    def get_encrypted_token(self, config: Union[dict | str]):
-        """
-        Returns the encrypted token.
-        """
-        return json.dumps({
-            'openai_api_type': 'azure',
-            'openai_api_version': AZURE_OPENAI_API_VERSION,
-            'openai_api_base': config['openai_api_base'],
-            'openai_api_key': self.encrypt_token(config['openai_api_key'])
-        })
-
-    def get_decrypted_token(self, token: str):
-        """
-        Returns the decrypted token.
-        """
-        config = json.loads(token)
-        config['openai_api_key'] = self.decrypt_token(config['openai_api_key'])
-        return config
-
-
-class AzureAuthenticationError(Exception):
-    pass
-
-
-class AzureRequestFailedError(Exception):
-    pass

+ 0 - 132
api/core/llm/provider/base.py

@@ -1,132 +0,0 @@
-import base64
-from abc import ABC, abstractmethod
-from typing import Optional, Union
-
-from core.constant import llm_constant
-from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
-from extensions.ext_database import db
-from libs import rsa
-from models.account import Tenant
-from models.provider import Provider, ProviderType, ProviderName
-
-
-class BaseProvider(ABC):
-    def __init__(self, tenant_id: str):
-        self.tenant_id = tenant_id
-
-    def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]:
-        """
-        Returns the decrypted API key for the given tenant_id and provider_name.
-        If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
-        If the provider is not found or not valid, raises a ProviderTokenNotInitError.
-        """
-        provider = self.get_provider(only_custom)
-        if not provider:
-            raise ProviderTokenNotInitError(
-                f"No valid {llm_constant.models[model_id]} model provider credentials found. "
-                f"Please go to Settings -> Model Provider to complete your provider credentials."
-            )
-
-        if provider.provider_type == ProviderType.SYSTEM.value:
-            quota_used = provider.quota_used if provider.quota_used is not None else 0
-            quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
-
-            if model_id and model_id == 'gpt-4':
-                raise ModelCurrentlyNotSupportError()
-
-            if quota_used >= quota_limit:
-                raise QuotaExceededError()
-
-            return self.get_hosted_credentials()
-        else:
-            return self.get_decrypted_token(provider.encrypted_config)
-
-    def get_provider(self, only_custom: bool = False) -> Optional[Provider]:
-        """
-        Returns the Provider instance for the given tenant_id and provider_name.
-        If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
-        """
-        return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom)
-
-    @classmethod
-    def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[
-        Provider]:
-        """
-        Returns the Provider instance for the given tenant_id and provider_name.
-        If both CUSTOM and System providers exist.
-        """
-        query = db.session.query(Provider).filter(
-            Provider.tenant_id == tenant_id
-        )
-
-        if provider_name:
-            query = query.filter(Provider.provider_name == provider_name)
-
-        if only_custom:
-            query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value)
-
-        providers = query.order_by(Provider.provider_type.asc()).all()
-
-        for provider in providers:
-            if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
-                return provider
-            elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
-                return provider
-
-        return None
-
-    def get_hosted_credentials(self) -> Union[str | dict]:
-        raise ProviderTokenNotInitError(
-            f"No valid {self.get_provider_name().value} model provider credentials found. "
-            f"Please go to Settings -> Model Provider to complete your provider credentials."
-        )
-
-    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
-        """
-        Returns the provider configs.
-        """
-        try:
-            config = self.get_provider_api_key(only_custom=only_custom)
-        except:
-            config = ''
-
-        if obfuscated:
-            return self.obfuscated_token(config)
-
-        return config
-
-    def obfuscated_token(self, token: str):
-        return token[:6] + '*' * (len(token) - 8) + token[-2:]
-
-    def get_token_type(self):
-        return str
-
-    def get_encrypted_token(self, config: Union[dict | str]):
-        return self.encrypt_token(config)
-
-    def get_decrypted_token(self, token: str):
-        return self.decrypt_token(token)
-
-    def encrypt_token(self, token):
-        tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
-        encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
-        return base64.b64encode(encrypted_token).decode()
-
-    def decrypt_token(self, token):
-        return rsa.decrypt(base64.b64decode(token), self.tenant_id)
-
-    @abstractmethod
-    def get_provider_name(self):
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_credentials(self, model_id: Optional[str] = None) -> dict:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
-        raise NotImplementedError
-
-    @abstractmethod
-    def config_validate(self, config: str):
-        raise NotImplementedError

+ 0 - 2
api/core/llm/provider/errors.py

@@ -1,2 +0,0 @@
-class ValidateFailedError(Exception):
-    description = "Provider Validate failed"

+ 0 - 22
api/core/llm/provider/huggingface_provider.py

@@ -1,22 +0,0 @@
-from typing import Optional
-
-from core.llm.provider.base import BaseProvider
-from models.provider import ProviderName
-
-
-class HuggingfaceProvider(BaseProvider):
-    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
-        credentials = self.get_credentials(model_id)
-        # todo
-        return []
-
-    def get_credentials(self, model_id: Optional[str] = None) -> dict:
-        """
-        Returns the API credentials for Huggingface as a dictionary, for the given tenant_id.
-        """
-        return {
-            'huggingface_api_key': self.get_provider_api_key(model_id=model_id)
-        }
-
-    def get_provider_name(self):
-        return ProviderName.HUGGINGFACEHUB

+ 0 - 53
api/core/llm/provider/llm_provider_service.py

@@ -1,53 +0,0 @@
-from typing import Optional, Union
-
-from core.llm.provider.anthropic_provider import AnthropicProvider
-from core.llm.provider.azure_provider import AzureProvider
-from core.llm.provider.base import BaseProvider
-from core.llm.provider.huggingface_provider import HuggingfaceProvider
-from core.llm.provider.openai_provider import OpenAIProvider
-from models.provider import Provider
-
-
-class LLMProviderService:
-
-    def __init__(self, tenant_id: str, provider_name: str):
-        self.provider = self.init_provider(tenant_id, provider_name)
-
-    def init_provider(self, tenant_id: str, provider_name: str) -> BaseProvider:
-        if provider_name == 'openai':
-            return OpenAIProvider(tenant_id)
-        elif provider_name == 'azure_openai':
-            return AzureProvider(tenant_id)
-        elif provider_name == 'anthropic':
-            return AnthropicProvider(tenant_id)
-        elif provider_name == 'huggingface':
-            return HuggingfaceProvider(tenant_id)
-        else:
-            raise Exception('provider {} not found'.format(provider_name))
-
-    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
-        return self.provider.get_models(model_id)
-
-    def get_credentials(self, model_id: Optional[str] = None) -> dict:
-        return self.provider.get_credentials(model_id)
-
-    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
-        return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom)
-
-    def get_provider_db_record(self) -> Optional[Provider]:
-        return self.provider.get_provider()
-
-    def config_validate(self, config: Union[dict | str]):
-        """
-        Validates the given config.
-
-        :param config:
-        :raises: ValidateFailedError
-        """
-        return self.provider.config_validate(config)
-
-    def get_token_type(self):
-        return self.provider.get_token_type()
-
-    def get_encrypted_token(self, config: Union[dict | str]):
-        return self.provider.get_encrypted_token(config)

+ 0 - 55
api/core/llm/provider/openai_provider.py

@@ -1,55 +0,0 @@
-import logging
-from typing import Optional, Union
-
-import openai
-from openai.error import AuthenticationError, OpenAIError
-
-from core import hosted_llm_credentials
-from core.llm.error import ProviderTokenNotInitError
-from core.llm.moderation import Moderation
-from core.llm.provider.base import BaseProvider
-from core.llm.provider.errors import ValidateFailedError
-from models.provider import ProviderName
-
-
-class OpenAIProvider(BaseProvider):
-    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
-        credentials = self.get_credentials(model_id)
-        response = openai.Model.list(**credentials)
-
-        return [{
-            'id': model['id'],
-            'name': model['id'],
-        } for model in response['data']]
-
-    def get_credentials(self, model_id: Optional[str] = None) -> dict:
-        """
-        Returns the credentials for the given tenant_id and provider_name.
-        """
-        return {
-            'openai_api_key': self.get_provider_api_key(model_id=model_id)
-        }
-
-    def get_provider_name(self):
-        return ProviderName.OPENAI
-
-    def config_validate(self, config: Union[dict | str]):
-        """
-        Validates the given config.
-        """
-        try:
-            Moderation(self.get_provider_name().value, config).moderate('test')
-        except (AuthenticationError, OpenAIError) as ex:
-            raise ValidateFailedError(str(ex))
-        except Exception as ex:
-            logging.exception('OpenAI config validation failed')
-            raise ex
-
-    def get_hosted_credentials(self) -> Union[str | dict]:
-        if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
-            raise ProviderTokenNotInitError(
-                f"No valid {self.get_provider_name().value} model provider credentials found. "
-                f"Please go to Settings -> Model Provider to complete your provider credentials."
-            )
-
-        return hosted_llm_credentials.openai.api_key

+ 0 - 62
api/core/llm/streamable_chat_anthropic.py

@@ -1,62 +0,0 @@
-from typing import List, Optional, Any, Dict
-
-from httpx import Timeout
-from langchain.callbacks.manager import Callbacks
-from langchain.chat_models import ChatAnthropic
-from langchain.schema import BaseMessage, LLMResult, SystemMessage, AIMessage, HumanMessage, ChatMessage
-from pydantic import root_validator
-
-from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
-
-
-class StreamableChatAnthropic(ChatAnthropic):
-    """
-    Wrapper around Anthropic's large language model.
-    """
-
-    default_request_timeout: Optional[float] = Timeout(timeout=300.0, connect=5.0)
-
-    @root_validator()
-    def prepare_params(cls, values: Dict) -> Dict:
-        values['model_name'] = values.get('model')
-        values['max_tokens'] = values.get('max_tokens_to_sample')
-        return values
-
-    @handle_anthropic_exceptions
-    def generate(
-            self,
-            messages: List[List[BaseMessage]],
-            stop: Optional[List[str]] = None,
-            callbacks: Callbacks = None,
-            *,
-            tags: Optional[List[str]] = None,
-            metadata: Optional[Dict[str, Any]] = None,
-            **kwargs: Any,
-    ) -> LLMResult:
-        return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs)
-
-    @classmethod
-    def get_kwargs_from_model_params(cls, params: dict):
-        params['model'] = params.get('model_name')
-        del params['model_name']
-
-        params['max_tokens_to_sample'] = params.get('max_tokens')
-        del params['max_tokens']
-
-        del params['frequency_penalty']
-        del params['presence_penalty']
-
-        return params
-
-    def _convert_one_message_to_text(self, message: BaseMessage) -> str:
-        if isinstance(message, ChatMessage):
-            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
-        elif isinstance(message, HumanMessage):
-            message_text = f"{self.HUMAN_PROMPT} {message.content}"
-        elif isinstance(message, AIMessage):
-            message_text = f"{self.AI_PROMPT} {message.content}"
-        elif isinstance(message, SystemMessage):
-            message_text = f"<admin>{message.content}</admin>"
-        else:
-            raise ValueError(f"Got unknown type {message}")
-        return message_text

+ 0 - 41
api/core/llm/token_calculator.py

@@ -1,41 +0,0 @@
-import decimal
-from typing import Optional
-
-import tiktoken
-
-from core.constant import llm_constant
-
-
-class TokenCalculator:
-    @classmethod
-    def get_num_tokens(cls, model_name: str, text: str):
-        if len(text) == 0:
-            return 0
-
-        enc = tiktoken.encoding_for_model(model_name)
-
-        tokenized_text = enc.encode(text)
-
-        # calculate the number of tokens in the encoded text
-        return len(tokenized_text)
-
-    @classmethod
-    def get_token_price(cls, model_name: str, tokens: int, text_type: Optional[str] = None) -> decimal.Decimal:
-        if model_name in llm_constant.models_by_mode['embedding']:
-            unit_price = llm_constant.model_prices[model_name]['usage']
-        elif text_type == 'prompt':
-            unit_price = llm_constant.model_prices[model_name]['prompt']
-        elif text_type == 'completion':
-            unit_price = llm_constant.model_prices[model_name]['completion']
-        else:
-            raise Exception('Invalid text type')
-
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    @classmethod
-    def get_currency(cls, model_name: str):
-        return llm_constant.model_currency

+ 0 - 26
api/core/llm/whisper.py

@@ -1,26 +0,0 @@
-import openai
-
-from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
-from models.provider import ProviderName
-from core.llm.provider.base import BaseProvider
-
-
-class Whisper:
-
-    def __init__(self, provider: BaseProvider):
-        self.provider = provider
-
-        if self.provider.get_provider_name() == ProviderName.OPENAI:
-            self.client = openai.Audio
-            self.credentials = provider.get_credentials()
-
-    @handle_openai_exceptions
-    def transcribe(self, file):
-        return self.client.transcribe(
-            model='whisper-1', 
-            file=file,
-            api_key=self.credentials.get('openai_api_key'),
-            api_base=self.credentials.get('openai_api_base'),
-            api_type=self.credentials.get('openai_api_type'),
-            api_version=self.credentials.get('openai_api_version'),
-        )

+ 0 - 27
api/core/llm/wrappers/anthropic_wrapper.py

@@ -1,27 +0,0 @@
-import logging
-from functools import wraps
-
-import anthropic
-
-from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
-    LLMBadRequestError
-
-
-def handle_anthropic_exceptions(func):
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        try:
-            return func(*args, **kwargs)
-        except anthropic.APIConnectionError as e:
-            logging.exception("Failed to connect to Anthropic API.")
-            raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}")
-        except anthropic.RateLimitError:
-            raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
-        except anthropic.AuthenticationError as e:
-            raise LLMAuthorizationError(f"Anthropic: {e.message}")
-        except anthropic.BadRequestError as e:
-            raise LLMBadRequestError(f"Anthropic: {e.message}")
-        except anthropic.APIStatusError as e:
-            raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}")
-
-    return wrapper

+ 0 - 31
api/core/llm/wrappers/openai_wrapper.py

@@ -1,31 +0,0 @@
-import logging
-from functools import wraps
-
-import openai
-
-from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
-    LLMBadRequestError
-
-
-def handle_openai_exceptions(func):
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        try:
-            return func(*args, **kwargs)
-        except openai.error.InvalidRequestError as e:
-            logging.exception("Invalid request to OpenAI API.")
-            raise LLMBadRequestError(str(e))
-        except openai.error.APIConnectionError as e:
-            logging.exception("Failed to connect to OpenAI API.")
-            raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
-        except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
-            logging.exception("OpenAI service unavailable.")
-            raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
-        except openai.error.RateLimitError as e:
-            raise LLMRateLimitError(str(e))
-        except openai.error.AuthenticationError as e:
-            raise LLMAuthorizationError(str(e))
-        except openai.error.OpenAIError as e:
-            raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
-
-    return wrapper

+ 12 - 12
api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py

@@ -1,10 +1,10 @@
-from typing import Any, List, Dict, Union
+from typing import Any, List, Dict
 
 from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel
+from langchain.schema import get_buffer_string, BaseMessage
 
-from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
-from core.llm.streamable_open_ai import StreamableOpenAI
+from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages
+from core.model_providers.models.llm.base import BaseLLM
 from extensions.ext_database import db
 from models.model import Conversation, Message
 
@@ -13,7 +13,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
     conversation: Conversation
     human_prefix: str = "Human"
     ai_prefix: str = "Assistant"
-    llm: BaseLanguageModel
+    model_instance: BaseLLM
     memory_key: str = "chat_history"
     max_token_limit: int = 2000
     message_limit: int = 10
@@ -29,23 +29,23 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
 
         messages = list(reversed(messages))
 
-        chat_messages: List[BaseMessage] = []
+        chat_messages: List[PromptMessage] = []
         for message in messages:
-            chat_messages.append(HumanMessage(content=message.query))
-            chat_messages.append(AIMessage(content=message.answer))
+            chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN))
+            chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
 
         if not chat_messages:
-            return chat_messages
+            return []
 
         # prune the chat message if it exceeds the max token limit
-        curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
+        curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
         if curr_buffer_length > self.max_token_limit:
             pruned_memory = []
             while curr_buffer_length > self.max_token_limit and chat_messages:
                 pruned_memory.append(chat_messages.pop(0))
-                curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
+                curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
 
-        return chat_messages
+        return to_lc_messages(chat_messages)
 
     @property
     def memory_variables(self) -> List[str]:

+ 0 - 0
api/core/llm/error.py → api/core/model_providers/error.py


+ 293 - 0
api/core/model_providers/model_factory.py

@@ -0,0 +1,293 @@
+from typing import Optional
+
+from langchain.callbacks.base import Callbacks
+
+from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
+from core.model_providers.model_provider_factory import ModelProviderFactory, DEFAULT_MODELS
+from core.model_providers.models.base import BaseProviderModel
+from core.model_providers.models.embedding.base import BaseEmbedding
+from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.speech2text.base import BaseSpeech2Text
+from extensions.ext_database import db
+from models.provider import TenantDefaultModel
+
+
+class ModelFactory:
+
+    @classmethod
+    def get_text_generation_model_from_model_config(cls, tenant_id: str,
+                                                    model_config: dict,
+                                                    streaming: bool = False,
+                                                    callbacks: Callbacks = None) -> Optional[BaseLLM]:
+        provider_name = model_config.get("provider")
+        model_name = model_config.get("name")
+        completion_params = model_config.get("completion_params", {})
+
+        return cls.get_text_generation_model(
+            tenant_id=tenant_id,
+            model_provider_name=provider_name,
+            model_name=model_name,
+            model_kwargs=ModelKwargs(
+                temperature=completion_params.get('temperature', 0),
+                max_tokens=completion_params.get('max_tokens', 256),
+                top_p=completion_params.get('top_p', 0),
+                frequency_penalty=completion_params.get('frequency_penalty', 0.1),
+                presence_penalty=completion_params.get('presence_penalty', 0.1)
+            ),
+            streaming=streaming,
+            callbacks=callbacks
+        )
+
+    @classmethod
+    def get_text_generation_model(cls,
+                                  tenant_id: str,
+                                  model_provider_name: Optional[str] = None,
+                                  model_name: Optional[str] = None,
+                                  model_kwargs: Optional[ModelKwargs] = None,
+                                  streaming: bool = False,
+                                  callbacks: Callbacks = None) -> Optional[BaseLLM]:
+        """
+        get text generation model.
+
+        :param tenant_id: a string representing the ID of the tenant.
+        :param model_provider_name:
+        :param model_name:
+        :param model_kwargs:
+        :param streaming:
+        :param callbacks:
+        :return:
+        """
+        is_default_model = False
+        if model_provider_name is None and model_name is None:
+            default_model = cls.get_default_model(tenant_id, ModelType.TEXT_GENERATION)
+
+            if not default_model:
+                raise LLMBadRequestError(f"Default model is not available. "
+                                         f"Please configure a Default System Reasoning Model "
+                                         f"in the Settings -> Model Provider.")
+
+            model_provider_name = default_model.provider_name
+            model_name = default_model.model_name
+            is_default_model = True
+
+        # get model provider
+        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
+
+        if not model_provider:
+            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
+
+        # init text generation model
+        model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION)
+
+        try:
+            model_instance = model_class(
+                model_provider=model_provider,
+                name=model_name,
+                model_kwargs=model_kwargs,
+                streaming=streaming,
+                callbacks=callbacks
+            )
+        except LLMBadRequestError as e:
+            if is_default_model:
+                raise LLMBadRequestError(f"Default model {model_name} is not available. "
+                                         f"Please check your model provider credentials.")
+            else:
+                raise e
+
+        if is_default_model:
+            model_instance.deduct_quota = False
+
+        return model_instance
+
+    @classmethod
+    def get_embedding_model(cls,
+                            tenant_id: str,
+                            model_provider_name: Optional[str] = None,
+                            model_name: Optional[str] = None) -> Optional[BaseEmbedding]:
+        """
+        get embedding model.
+
+        :param tenant_id: a string representing the ID of the tenant.
+        :param model_provider_name:
+        :param model_name:
+        :return:
+        """
+        if model_provider_name is None and model_name is None:
+            default_model = cls.get_default_model(tenant_id, ModelType.EMBEDDINGS)
+
+            if not default_model:
+                raise LLMBadRequestError(f"Default model is not available. "
+                                         f"Please configure a Default Embedding Model "
+                                         f"in the Settings -> Model Provider.")
+
+            model_provider_name = default_model.provider_name
+            model_name = default_model.model_name
+
+        # get model provider
+        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
+
+        if not model_provider:
+            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
+
+        # init embedding model
+        model_class = model_provider.get_model_class(model_type=ModelType.EMBEDDINGS)
+        return model_class(
+            model_provider=model_provider,
+            name=model_name
+        )
+
+    @classmethod
+    def get_speech2text_model(cls,
+                              tenant_id: str,
+                              model_provider_name: Optional[str] = None,
+                              model_name: Optional[str] = None) -> Optional[BaseSpeech2Text]:
+        """
+        get speech to text model.
+
+        :param tenant_id: a string representing the ID of the tenant.
+        :param model_provider_name:
+        :param model_name:
+        :return:
+        """
+        if model_provider_name is None and model_name is None:
+            default_model = cls.get_default_model(tenant_id, ModelType.SPEECH_TO_TEXT)
+
+            if not default_model:
+                raise LLMBadRequestError(f"Default model is not available. "
+                                         f"Please configure a Default Speech-to-Text Model "
+                                         f"in the Settings -> Model Provider.")
+
+            model_provider_name = default_model.provider_name
+            model_name = default_model.model_name
+
+        # get model provider
+        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
+
+        if not model_provider:
+            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
+
+        # init speech to text model
+        model_class = model_provider.get_model_class(model_type=ModelType.SPEECH_TO_TEXT)
+        return model_class(
+            model_provider=model_provider,
+            name=model_name
+        )
+
+    @classmethod
+    def get_moderation_model(cls,
+                             tenant_id: str,
+                             model_provider_name: str,
+                             model_name: str) -> Optional[BaseProviderModel]:
+        """
+        get moderation model.
+
+        :param tenant_id: a string representing the ID of the tenant.
+        :param model_provider_name:
+        :param model_name:
+        :return:
+        """
+        # get model provider
+        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
+
+        if not model_provider:
+            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
+
+        # init moderation model
+        model_class = model_provider.get_model_class(model_type=ModelType.MODERATION)
+        return model_class(
+            model_provider=model_provider,
+            name=model_name
+        )
+
+    @classmethod
+    def get_default_model(cls, tenant_id: str, model_type: ModelType) -> TenantDefaultModel:
+        """
+        get default model of model type.
+
+        :param tenant_id:
+        :param model_type:
+        :return:
+        """
+        # get default model
+        default_model = db.session.query(TenantDefaultModel) \
+            .filter(
+            TenantDefaultModel.tenant_id == tenant_id,
+            TenantDefaultModel.model_type == model_type.value
+        ).first()
+
+        if not default_model:
+            model_provider_rules = ModelProviderFactory.get_provider_rules()
+            for model_provider_name, model_provider_rule in model_provider_rules.items():
+                model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
+                if not model_provider:
+                    continue
+
+                model_list = model_provider.get_supported_model_list(model_type)
+                if model_list:
+                    model_info = model_list[0]
+                    default_model = TenantDefaultModel(
+                        tenant_id=tenant_id,
+                        model_type=model_type.value,
+                        provider_name=model_provider_name,
+                        model_name=model_info['id']
+                    )
+                    db.session.add(default_model)
+                    db.session.commit()
+                    break
+
+        return default_model
+
+    @classmethod
+    def update_default_model(cls,
+                             tenant_id: str,
+                             model_type: ModelType,
+                             provider_name: str,
+                             model_name: str) -> TenantDefaultModel:
+        """
+        update default model of model type.
+
+        :param tenant_id:
+        :param model_type:
+        :param provider_name:
+        :param model_name:
+        :return:
+        """
+        model_provider_name = ModelProviderFactory.get_provider_names()
+        if provider_name not in model_provider_name:
+            raise ValueError(f'Invalid provider name: {provider_name}')
+
+        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name)
+
+        if not model_provider:
+            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
+
+        model_list = model_provider.get_supported_model_list(model_type)
+        model_ids = [model['id'] for model in model_list]
+        if model_name not in model_ids:
+            raise ValueError(f'Invalid model name: {model_name}')
+
+        # get default model
+        default_model = db.session.query(TenantDefaultModel) \
+            .filter(
+            TenantDefaultModel.tenant_id == tenant_id,
+            TenantDefaultModel.model_type == model_type.value
+        ).first()
+
+        if default_model:
+            # update default model
+            default_model.provider_name = provider_name
+            default_model.model_name = model_name
+            db.session.commit()
+        else:
+            # create default model
+            default_model = TenantDefaultModel(
+                tenant_id=tenant_id,
+                model_type=model_type.value,
+                provider_name=provider_name,
+                model_name=model_name,
+            )
+            db.session.add(default_model)
+            db.session.commit()
+
+        return default_model

+ 228 - 0
api/core/model_providers/model_provider_factory.py

@@ -0,0 +1,228 @@
+from typing import Type
+
+from sqlalchemy.exc import IntegrityError
+
+from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.providers.base import BaseModelProvider
+from core.model_providers.rules import provider_rules
+from extensions.ext_database import db
+from models.provider import TenantPreferredModelProvider, ProviderType, Provider, ProviderQuotaType
+
+DEFAULT_MODELS = {
+    ModelType.TEXT_GENERATION.value: {
+        'provider_name': 'openai',
+        'model_name': 'gpt-3.5-turbo',
+    },
+    ModelType.EMBEDDINGS.value: {
+        'provider_name': 'openai',
+        'model_name': 'text-embedding-ada-002',
+    },
+    ModelType.SPEECH_TO_TEXT.value: {
+        'provider_name': 'openai',
+        'model_name': 'whisper-1',
+    }
+}
+
+
+class ModelProviderFactory:
+    @classmethod
+    def get_model_provider_class(cls, provider_name: str) -> Type[BaseModelProvider]:
+        if provider_name == 'openai':
+            from core.model_providers.providers.openai_provider import OpenAIProvider
+            return OpenAIProvider
+        elif provider_name == 'anthropic':
+            from core.model_providers.providers.anthropic_provider import AnthropicProvider
+            return AnthropicProvider
+        elif provider_name == 'minimax':
+            from core.model_providers.providers.minimax_provider import MinimaxProvider
+            return MinimaxProvider
+        elif provider_name == 'spark':
+            from core.model_providers.providers.spark_provider import SparkProvider
+            return SparkProvider
+        elif provider_name == 'tongyi':
+            from core.model_providers.providers.tongyi_provider import TongyiProvider
+            return TongyiProvider
+        elif provider_name == 'wenxin':
+            from core.model_providers.providers.wenxin_provider import WenxinProvider
+            return WenxinProvider
+        elif provider_name == 'chatglm':
+            from core.model_providers.providers.chatglm_provider import ChatGLMProvider
+            return ChatGLMProvider
+        elif provider_name == 'azure_openai':
+            from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
+            return AzureOpenAIProvider
+        elif provider_name == 'replicate':
+            from core.model_providers.providers.replicate_provider import ReplicateProvider
+            return ReplicateProvider
+        elif provider_name == 'huggingface_hub':
+            from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
+            return HuggingfaceHubProvider
+        else:
+            raise NotImplementedError
+
+    @classmethod
+    def get_provider_names(cls):
+        """
+        Returns a list of provider names.
+        """
+        return list(provider_rules.keys())
+
+    @classmethod
+    def get_provider_rules(cls):
+        """
+        Returns a list of provider rules.
+
+        :return:
+        """
+        return provider_rules
+
+    @classmethod
+    def get_provider_rule(cls, provider_name: str):
+        """
+        Returns provider rule.
+        """
+        return provider_rules[provider_name]
+
+    @classmethod
+    def get_preferred_model_provider(cls, tenant_id: str, model_provider_name: str):
+        """
+        get preferred model provider.
+
+        :param tenant_id: a string representing the ID of the tenant.
+        :param model_provider_name:
+        :return:
+        """
+        # get preferred provider
+        preferred_provider = cls._get_preferred_provider(tenant_id, model_provider_name)
+        if not preferred_provider or not preferred_provider.is_valid:
+            return None
+
+        # init model provider
+        model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)
+        return model_provider_class(provider=preferred_provider)
+
+    @classmethod
+    def get_preferred_type_by_preferred_model_provider(cls,
+                                                       tenant_id: str,
+                                                       model_provider_name: str,
+                                                       preferred_model_provider: TenantPreferredModelProvider):
+        """
+        get preferred provider type by preferred model provider.
+
+        :param model_provider_name:
+        :param preferred_model_provider:
+        :return:
+        """
+        if not preferred_model_provider:
+            model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
+            support_provider_types = model_provider_rules['support_provider_types']
+
+            if ProviderType.CUSTOM.value in support_provider_types:
+                custom_provider = db.session.query(Provider) \
+                    .filter(
+                        Provider.tenant_id == tenant_id,
+                        Provider.provider_name == model_provider_name,
+                        Provider.provider_type == ProviderType.CUSTOM.value,
+                        Provider.is_valid == True
+                    ).first()
+
+                if custom_provider:
+                    return ProviderType.CUSTOM.value
+
+            model_provider = cls.get_model_provider_class(model_provider_name)
+
+            if ProviderType.SYSTEM.value in support_provider_types \
+                    and model_provider.is_provider_type_system_supported():
+                return ProviderType.SYSTEM.value
+            elif ProviderType.CUSTOM.value in support_provider_types:
+                return ProviderType.CUSTOM.value
+        else:
+            return preferred_model_provider.preferred_provider_type
+
+    @classmethod
+    def _get_preferred_provider(cls, tenant_id: str, model_provider_name: str):
+        """
+        get preferred provider of tenant.
+
+        :param tenant_id:
+        :param model_provider_name:
+        :return:
+        """
+        # get preferred provider type
+        preferred_provider_type = cls._get_preferred_provider_type(tenant_id, model_provider_name)
+
+        # get providers by preferred provider type
+        providers = db.session.query(Provider) \
+            .filter(
+                Provider.tenant_id == tenant_id,
+                Provider.provider_name == model_provider_name,
+                Provider.provider_type == preferred_provider_type
+            ).all()
+
+        no_system_provider = False
+        if preferred_provider_type == ProviderType.SYSTEM.value:
+            quota_type_to_provider_dict = {}
+            for provider in providers:
+                quota_type_to_provider_dict[provider.quota_type] = provider
+
+            model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
+            for quota_type_enum in ProviderQuotaType:
+                quota_type = quota_type_enum.value
+                if quota_type in model_provider_rules['system_config']['supported_quota_types'] \
+                        and quota_type in quota_type_to_provider_dict.keys():
+                    provider = quota_type_to_provider_dict[quota_type]
+                    if provider.is_valid and provider.quota_limit > provider.quota_used:
+                        return provider
+
+            no_system_provider = True
+
+        if no_system_provider:
+            providers = db.session.query(Provider) \
+                .filter(
+                Provider.tenant_id == tenant_id,
+                Provider.provider_name == model_provider_name,
+                Provider.provider_type == ProviderType.CUSTOM.value
+            ).all()
+
+        if preferred_provider_type == ProviderType.CUSTOM.value or no_system_provider:
+            if providers:
+                return providers[0]
+            else:
+                try:
+                    provider = Provider(
+                        tenant_id=tenant_id,
+                        provider_name=model_provider_name,
+                        provider_type=ProviderType.CUSTOM.value,
+                        is_valid=False
+                    )
+                    db.session.add(provider)
+                    db.session.commit()
+                except IntegrityError:
+                    db.session.rollback()
+                    provider = db.session.query(Provider) \
+                        .filter(
+                            Provider.tenant_id == tenant_id,
+                            Provider.provider_name == model_provider_name,
+                            Provider.provider_type == ProviderType.CUSTOM.value
+                        ).first()
+
+                return provider
+
+        return None
+
+    @classmethod
+    def _get_preferred_provider_type(cls, tenant_id: str, model_provider_name: str):
+        """
+        get preferred provider type of tenant.
+
+        :param tenant_id:
+        :param model_provider_name:
+        :return:
+        """
+        preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
+            .filter(
+            TenantPreferredModelProvider.tenant_id == tenant_id,
+            TenantPreferredModelProvider.provider_name == model_provider_name
+        ).first()
+
+        return cls.get_preferred_type_by_preferred_model_provider(tenant_id, model_provider_name, preferred_model_provider)

+ 0 - 0
api/tests/test_libs/__init__.py → api/core/model_providers/models/__init__.py


+ 22 - 0
api/core/model_providers/models/base.py

@@ -0,0 +1,22 @@
+from abc import ABC
+from typing import Any
+
+from core.model_providers.providers.base import BaseModelProvider
+
+
+class BaseProviderModel(ABC):
+    _client: Any
+    _model_provider: BaseModelProvider
+
+    def __init__(self, model_provider: BaseModelProvider, client: Any):
+        self._model_provider = model_provider
+        self._client = client
+
+    @property
+    def client(self):
+        return self._client
+
+    @property
+    def model_provider(self):
+        return self._model_provider
+

+ 0 - 0
api/tests/test_models/__init__.py → api/core/model_providers/models/embedding/__init__.py


+ 78 - 0
api/core/model_providers/models/embedding/azure_openai_embedding.py

@@ -0,0 +1,78 @@
+import decimal
+import logging
+
+import openai
+import tiktoken
+from langchain.embeddings import OpenAIEmbeddings
+
+from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMRateLimitError, \
+    LLMAPIUnavailableError, LLMAPIConnectionError
+from core.model_providers.models.embedding.base import BaseEmbedding
+from core.model_providers.providers.base import BaseModelProvider
+
+AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
+
+
+class AzureOpenAIEmbedding(BaseEmbedding):
+    def __init__(self, model_provider: BaseModelProvider, name: str):
+        self.credentials = model_provider.get_model_credentials(
+            model_name=name,
+            model_type=self.type
+        )
+
+        client = OpenAIEmbeddings(
+            deployment=name,
+            openai_api_type='azure',
+            openai_api_version=AZURE_OPENAI_API_VERSION,
+            chunk_size=16,
+            max_retries=1,
+            **self.credentials
+        )
+
+        super().__init__(model_provider, client, name)
+
+    def get_num_tokens(self, text: str) -> int:
+        """
+        get num tokens of text.
+
+        :param text:
+        :return:
+        """
+        if len(text) == 0:
+            return 0
+
+        enc = tiktoken.encoding_for_model(self.credentials.get('base_model_name'))
+
+        tokenized_text = enc.encode(text)
+
+        # calculate the number of tokens in the encoded text
+        return len(tokenized_text)
+
+    def get_token_price(self, tokens: int):
+        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
+                                                                  rounding=decimal.ROUND_HALF_UP)
+
+        total_price = tokens_per_1k * decimal.Decimal('0.0001')
+        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
+
+    def get_currency(self):
+        return 'USD'
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, openai.error.InvalidRequestError):
+            logging.warning("Invalid request to Azure OpenAI API.")
+            return LLMBadRequestError(str(ex))
+        elif isinstance(ex, openai.error.APIConnectionError):
+            logging.warning("Failed to connect to Azure OpenAI API.")
+            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
+            logging.warning("Azure OpenAI service unavailable.")
+            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, openai.error.RateLimitError):
+            return LLMRateLimitError('Azure ' + str(ex))
+        elif isinstance(ex, openai.error.AuthenticationError):
+            raise LLMAuthorizationError('Azure ' + str(ex))
+        elif isinstance(ex, openai.error.OpenAIError):
+            return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
+        else:
+            return ex

+ 40 - 0
api/core/model_providers/models/embedding/base.py

@@ -0,0 +1,40 @@
+from abc import abstractmethod
+from typing import Any
+
+import tiktoken
+from langchain.schema.language_model import _get_token_ids_default_method
+
+from core.model_providers.models.base import BaseProviderModel
+from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.providers.base import BaseModelProvider
+
+
+class BaseEmbedding(BaseProviderModel):
+    name: str
+    type: ModelType = ModelType.EMBEDDINGS
+
+    def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
+        super().__init__(model_provider, client)
+        self.name = name
+
+    def get_num_tokens(self, text: str) -> int:
+        """
+        get num tokens of text.
+
+        :param text:
+        :return:
+        """
+        if len(text) == 0:
+            return 0
+
+        return len(_get_token_ids_default_method(text))
+
+    def get_token_price(self, tokens: int):
+        return 0
+
+    def get_currency(self):
+        return 'USD'
+
+    @abstractmethod
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        raise NotImplementedError

+ 35 - 0
api/core/model_providers/models/embedding/minimax_embedding.py

@@ -0,0 +1,35 @@
+import decimal
+import logging
+
+from langchain.embeddings import MiniMaxEmbeddings
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.models.embedding.base import BaseEmbedding
+from core.model_providers.providers.base import BaseModelProvider
+
+
+class MinimaxEmbedding(BaseEmbedding):
+    def __init__(self, model_provider: BaseModelProvider, name: str):
+        credentials = model_provider.get_model_credentials(
+            model_name=name,
+            model_type=self.type
+        )
+
+        client = MiniMaxEmbeddings(
+            model=name,
+            **credentials
+        )
+
+        super().__init__(model_provider, client, name)
+
+    def get_token_price(self, tokens: int):
+        return decimal.Decimal('0')
+
+    def get_currency(self):
+        return 'RMB'
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, ValueError):
+            return LLMBadRequestError(f"Minimax: {str(ex)}")
+        else:
+            return ex

+ 72 - 0
api/core/model_providers/models/embedding/openai_embedding.py

@@ -0,0 +1,72 @@
+import decimal
+import logging
+
+import openai
+import tiktoken
+from langchain.embeddings import OpenAIEmbeddings
+
+from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
+    LLMRateLimitError, LLMAuthorizationError
+from core.model_providers.models.embedding.base import BaseEmbedding
+from core.model_providers.providers.base import BaseModelProvider
+
+
+class OpenAIEmbedding(BaseEmbedding):
+    def __init__(self, model_provider: BaseModelProvider, name: str):
+        credentials = model_provider.get_model_credentials(
+            model_name=name,
+            model_type=self.type
+        )
+
+        client = OpenAIEmbeddings(
+            max_retries=1,
+            **credentials
+        )
+
+        super().__init__(model_provider, client, name)
+
+    def get_num_tokens(self, text: str) -> int:
+        """
+        get num tokens of text.
+
+        :param text:
+        :return:
+        """
+        if len(text) == 0:
+            return 0
+
+        enc = tiktoken.encoding_for_model(self.name)
+
+        tokenized_text = enc.encode(text)
+
+        # calculate the number of tokens in the encoded text
+        return len(tokenized_text)
+
+    def get_token_price(self, tokens: int):
+        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
+                                                                  rounding=decimal.ROUND_HALF_UP)
+
+        total_price = tokens_per_1k * decimal.Decimal('0.0001')
+        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
+
+    def get_currency(self):
+        return 'USD'
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, openai.error.InvalidRequestError):
+            logging.warning("Invalid request to OpenAI API.")
+            return LLMBadRequestError(str(ex))
+        elif isinstance(ex, openai.error.APIConnectionError):
+            logging.warning("Failed to connect to OpenAI API.")
+            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
+            logging.warning("OpenAI service unavailable.")
+            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, openai.error.RateLimitError):
+            return LLMRateLimitError(str(ex))
+        elif isinstance(ex, openai.error.AuthenticationError):
+            raise LLMAuthorizationError(str(ex))
+        elif isinstance(ex, openai.error.OpenAIError):
+            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
+        else:
+            return ex

+ 36 - 0
api/core/model_providers/models/embedding/replicate_embedding.py

@@ -0,0 +1,36 @@
+import decimal
+
+from replicate.exceptions import ModelError, ReplicateError
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.providers.base import BaseModelProvider
+from core.third_party.langchain.embeddings.replicate_embedding import ReplicateEmbeddings
+from core.model_providers.models.embedding.base import BaseEmbedding
+
+
+class ReplicateEmbedding(BaseEmbedding):
+    def __init__(self, model_provider: BaseModelProvider, name: str):
+        credentials = model_provider.get_model_credentials(
+            model_name=name,
+            model_type=self.type
+        )
+
+        client = ReplicateEmbeddings(
+            model=name + ':' + credentials.get('model_version'),
+            replicate_api_token=credentials.get('replicate_api_token')
+        )
+
+        super().__init__(model_provider, client, name)
+
+    def get_token_price(self, tokens: int):
+        # replicate only pay for prediction seconds
+        return decimal.Decimal('0')
+
+    def get_currency(self):
+        return 'USD'
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, (ModelError, ReplicateError)):
+            return LLMBadRequestError(f"Replicate: {str(ex)}")
+        else:
+            return ex

+ 0 - 0
api/tests/test_services/__init__.py → api/core/model_providers/models/entity/__init__.py


+ 53 - 0
api/core/model_providers/models/entity/message.py

@@ -0,0 +1,53 @@
+import enum
+
+from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
+from pydantic import BaseModel
+
+
+class LLMRunResult(BaseModel):
+    content: str
+    prompt_tokens: int
+    completion_tokens: int
+
+
+class MessageType(enum.Enum):
+    HUMAN = 'human'
+    ASSISTANT = 'assistant'
+    SYSTEM = 'system'
+
+
+class PromptMessage(BaseModel):
+    type: MessageType = MessageType.HUMAN
+    content: str = ''
+
+
+def to_lc_messages(messages: list[PromptMessage]):
+    lc_messages = []
+    for message in messages:
+        if message.type == MessageType.HUMAN:
+            lc_messages.append(HumanMessage(content=message.content))
+        elif message.type == MessageType.ASSISTANT:
+            lc_messages.append(AIMessage(content=message.content))
+        elif message.type == MessageType.SYSTEM:
+            lc_messages.append(SystemMessage(content=message.content))
+
+    return lc_messages
+
+
+def to_prompt_messages(messages: list[BaseMessage]):
+    prompt_messages = []
+    for message in messages:
+        if isinstance(message, HumanMessage):
+            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
+        elif isinstance(message, AIMessage):
+            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
+        elif isinstance(message, SystemMessage):
+            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
+    return prompt_messages
+
+
+def str_to_prompt_messages(texts: list[str]):
+    prompt_messages = []
+    for text in texts:
+        prompt_messages.append(PromptMessage(content=text))
+    return prompt_messages

+ 59 - 0
api/core/model_providers/models/entity/model_params.py

@@ -0,0 +1,59 @@
+import enum
+from typing import Optional, TypeVar, Generic
+
+from langchain.load.serializable import Serializable
+from pydantic import BaseModel
+
+
+class ModelMode(enum.Enum):
+    COMPLETION = 'completion'
+    CHAT = 'chat'
+
+
+class ModelType(enum.Enum):
+    TEXT_GENERATION = 'text-generation'
+    EMBEDDINGS = 'embeddings'
+    SPEECH_TO_TEXT = 'speech2text'
+    IMAGE = 'image'
+    VIDEO = 'video'
+    MODERATION = 'moderation'
+
+    @staticmethod
+    def value_of(value):
+        for member in ModelType:
+            if member.value == value:
+                return member
+        raise ValueError(f"No matching enum found for value '{value}'")
+
+
+class ModelKwargs(BaseModel):
+    max_tokens: Optional[int]
+    temperature: Optional[float]
+    top_p: Optional[float]
+    presence_penalty: Optional[float]
+    frequency_penalty: Optional[float]
+
+
+class KwargRuleType(enum.Enum):
+    STRING = 'string'
+    INTEGER = 'integer'
+    FLOAT = 'float'
+
+
+T = TypeVar('T')
+
+
+class KwargRule(Generic[T], BaseModel):
+    enabled: bool = True
+    min: Optional[T] = None
+    max: Optional[T] = None
+    default: Optional[T] = None
+    alias: Optional[str] = None
+
+
+class ModelKwargsRules(BaseModel):
+    max_tokens: KwargRule = KwargRule[int](enabled=False)
+    temperature: KwargRule = KwargRule[float](enabled=False)
+    top_p: KwargRule = KwargRule[float](enabled=False)
+    presence_penalty: KwargRule = KwargRule[float](enabled=False)
+    frequency_penalty: KwargRule = KwargRule[float](enabled=False)

+ 10 - 0
api/core/model_providers/models/entity/provider.py

@@ -0,0 +1,10 @@
+from enum import Enum
+
+
+class ProviderQuotaUnit(Enum):
+    TIMES = 'times'
+    TOKENS = 'tokens'
+
+
+class ModelFeature(Enum):
+    AGENT_THOUGHT = 'agent_thought'

+ 0 - 0
api/core/model_providers/models/llm/__init__.py


+ 107 - 0
api/core/model_providers/models/llm/anthropic_model.py

@@ -0,0 +1,107 @@
+import decimal
+import logging
+from functools import wraps
+from typing import List, Optional, Any
+
+import anthropic
+from langchain.callbacks.manager import Callbacks
+from langchain.chat_models import ChatAnthropic
+from langchain.schema import LLMResult
+
+from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
+    LLMRateLimitError, LLMAuthorizationError
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+
+
+class AnthropicModel(BaseLLM):
+    model_mode: ModelMode = ModelMode.CHAT
+
+    def _init_client(self) -> Any:
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+        return ChatAnthropic(
+            model=self.name,
+            streaming=self.streaming,
+            callbacks=self.callbacks,
+            default_request_timeout=60,
+            **self.credentials,
+            **provider_model_kwargs
+        )
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        run predict by prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.generate([prompts], stop, callbacks)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        get num tokens of prompt messages.
+
+        :param messages:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
+
+    def get_token_price(self, tokens: int, message_type: MessageType):
+        model_unit_prices = {
+            'claude-instant-1': {
+                'prompt': decimal.Decimal('1.63'),
+                'completion': decimal.Decimal('5.51'),
+            },
+            'claude-2': {
+                'prompt': decimal.Decimal('11.02'),
+                'completion': decimal.Decimal('32.68'),
+            },
+        }
+
+        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+            unit_price = model_unit_prices[self.name]['prompt']
+        else:
+            unit_price = model_unit_prices[self.name]['completion']
+
+        tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
+                                                                     rounding=decimal.ROUND_HALF_UP)
+
+        total_price = tokens_per_1m * unit_price
+        return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
+
+    def get_currency(self):
+        return 'USD'
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
+        for k, v in provider_model_kwargs.items():
+            if hasattr(self.client, k):
+                setattr(self.client, k, v)
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, anthropic.APIConnectionError):
+            logging.warning("Failed to connect to Anthropic API.")
+            return LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {ex.__cause__}")
+        elif isinstance(ex, anthropic.RateLimitError):
+            return LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
+        elif isinstance(ex, anthropic.AuthenticationError):
+            return LLMAuthorizationError(f"Anthropic: {ex.message}")
+        elif isinstance(ex, anthropic.BadRequestError):
+            return LLMBadRequestError(f"Anthropic: {ex.message}")
+        elif isinstance(ex, anthropic.APIStatusError):
+            return LLMAPIUnavailableError(f"Anthropic: code: {ex.status_code}, cause: {ex.message}")
+        else:
+            return ex
+
+    @classmethod
+    def support_streaming(cls):
+        return True
+

+ 177 - 0
api/core/model_providers/models/llm/azure_openai_model.py

@@ -0,0 +1,177 @@
+import decimal
+import logging
+from functools import wraps
+from typing import List, Optional, Any
+
+import openai
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import LLMResult
+
+from core.model_providers.providers.base import BaseModelProvider
+from core.third_party.langchain.llms.azure_chat_open_ai import EnhanceAzureChatOpenAI
+from core.third_party.langchain.llms.azure_open_ai import EnhanceAzureOpenAI
+from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
+    LLMRateLimitError, LLMAuthorizationError
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+
+AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
+
+
+class AzureOpenAIModel(BaseLLM):
+    def __init__(self, model_provider: BaseModelProvider,
+                 name: str,
+                 model_kwargs: ModelKwargs,
+                 streaming: bool = False,
+                 callbacks: Callbacks = None):
+        if name == 'text-davinci-003':
+            self.model_mode = ModelMode.COMPLETION
+        else:
+            self.model_mode = ModelMode.CHAT
+
+        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
+
+    def _init_client(self) -> Any:
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+        if self.name == 'text-davinci-003':
+            client = EnhanceAzureOpenAI(
+                deployment_name=self.name,
+                streaming=self.streaming,
+                request_timeout=60,
+                openai_api_type='azure',
+                openai_api_version=AZURE_OPENAI_API_VERSION,
+                openai_api_key=self.credentials.get('openai_api_key'),
+                openai_api_base=self.credentials.get('openai_api_base'),
+                callbacks=self.callbacks,
+                **provider_model_kwargs
+            )
+        else:
+            extra_model_kwargs = {
+                'top_p': provider_model_kwargs.get('top_p'),
+                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
+                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
+            }
+
+            client = EnhanceAzureChatOpenAI(
+                deployment_name=self.name,
+                temperature=provider_model_kwargs.get('temperature'),
+                max_tokens=provider_model_kwargs.get('max_tokens'),
+                model_kwargs=extra_model_kwargs,
+                streaming=self.streaming,
+                request_timeout=60,
+                openai_api_type='azure',
+                openai_api_version=AZURE_OPENAI_API_VERSION,
+                openai_api_key=self.credentials.get('openai_api_key'),
+                openai_api_base=self.credentials.get('openai_api_base'),
+                callbacks=self.callbacks,
+            )
+
+        return client
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        run predict by prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.generate([prompts], stop, callbacks)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        get num tokens of prompt messages.
+
+        :param messages:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        if isinstance(prompts, str):
+            return self._client.get_num_tokens(prompts)
+        else:
+            return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
+
+    def get_token_price(self, tokens: int, message_type: MessageType):
+        model_unit_prices = {
+            'gpt-4': {
+                'prompt': decimal.Decimal('0.03'),
+                'completion': decimal.Decimal('0.06'),
+            },
+            'gpt-4-32k': {
+                'prompt': decimal.Decimal('0.06'),
+                'completion': decimal.Decimal('0.12')
+            },
+            'gpt-35-turbo': {
+                'prompt': decimal.Decimal('0.0015'),
+                'completion': decimal.Decimal('0.002')
+            },
+            'gpt-35-turbo-16k': {
+                'prompt': decimal.Decimal('0.003'),
+                'completion': decimal.Decimal('0.004')
+            },
+            'text-davinci-003': {
+                'prompt': decimal.Decimal('0.02'),
+                'completion': decimal.Decimal('0.02')
+            },
+        }
+
+        base_model_name = self.credentials.get("base_model_name")
+        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+            unit_price = model_unit_prices[base_model_name]['prompt']
+        else:
+            unit_price = model_unit_prices[base_model_name]['completion']
+
+        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
+                                                                  rounding=decimal.ROUND_HALF_UP)
+
+        total_price = tokens_per_1k * unit_price
+        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
+
+    def get_currency(self):
+        return 'USD'
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
+        if self.name == 'text-davinci-003':
+            for k, v in provider_model_kwargs.items():
+                if hasattr(self.client, k):
+                    setattr(self.client, k, v)
+        else:
+            extra_model_kwargs = {
+                'top_p': provider_model_kwargs.get('top_p'),
+                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
+                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
+            }
+
+            self.client.temperature = provider_model_kwargs.get('temperature')
+            self.client.max_tokens = provider_model_kwargs.get('max_tokens')
+            self.client.model_kwargs = extra_model_kwargs
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, openai.error.InvalidRequestError):
+            logging.warning("Invalid request to Azure OpenAI API.")
+            return LLMBadRequestError(str(ex))
+        elif isinstance(ex, openai.error.APIConnectionError):
+            logging.warning("Failed to connect to Azure OpenAI API.")
+            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
+            logging.warning("Azure OpenAI service unavailable.")
+            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, openai.error.RateLimitError):
+            return LLMRateLimitError('Azure ' + str(ex))
+        elif isinstance(ex, openai.error.AuthenticationError):
+            raise LLMAuthorizationError('Azure ' + str(ex))
+        elif isinstance(ex, openai.error.OpenAIError):
+            return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
+        else:
+            return ex
+
+    @classmethod
+    def support_streaming(cls):
+        return True

+ 269 - 0
api/core/model_providers/models/llm/base.py

@@ -0,0 +1,269 @@
+from abc import abstractmethod
+from typing import List, Optional, Any, Union
+
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
+
+from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
+from core.model_providers.models.base import BaseProviderModel
+from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult
+from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
+from core.model_providers.providers.base import BaseModelProvider
+from core.third_party.langchain.llms.fake import FakeLLM
+
+
+class BaseLLM(BaseProviderModel):
+    model_mode: ModelMode = ModelMode.COMPLETION
+    name: str
+    model_kwargs: ModelKwargs
+    credentials: dict
+    streaming: bool = False
+    type: ModelType = ModelType.TEXT_GENERATION
+    deduct_quota: bool = True
+
+    def __init__(self, model_provider: BaseModelProvider,
+                 name: str,
+                 model_kwargs: ModelKwargs,
+                 streaming: bool = False,
+                 callbacks: Callbacks = None):
+        self.name = name
+        self.model_rules = model_provider.get_model_parameter_rules(name, self.type)
+        self.model_kwargs = model_kwargs if model_kwargs else ModelKwargs(
+            max_tokens=None,
+            temperature=None,
+            top_p=None,
+            presence_penalty=None,
+            frequency_penalty=None
+        )
+        self.credentials = model_provider.get_model_credentials(
+            model_name=name,
+            model_type=self.type
+        )
+        self.streaming = streaming
+
+        if streaming:
+            default_callback = DifyStreamingStdOutCallbackHandler()
+        else:
+            default_callback = DifyStdOutCallbackHandler()
+
+        if not callbacks:
+            callbacks = [default_callback]
+        else:
+            callbacks.append(default_callback)
+
+        self.callbacks = callbacks
+
+        client = self._init_client()
+        super().__init__(model_provider, client)
+
+    @abstractmethod
+    def _init_client(self) -> Any:
+        raise NotImplementedError
+
+    def run(self, messages: List[PromptMessage],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs) -> LLMRunResult:
+        """
+        run predict by prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        if self.deduct_quota:
+            self.model_provider.check_quota_over_limit()
+
+        if not callbacks:
+            callbacks = self.callbacks
+        else:
+            callbacks.extend(self.callbacks)
+
+        if 'fake_response' in kwargs and kwargs['fake_response']:
+            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
+            fake_llm = FakeLLM(
+                response=kwargs['fake_response'],
+                num_token_func=self.get_num_tokens,
+                streaming=self.streaming,
+                callbacks=callbacks
+            )
+            result = fake_llm.generate([prompts])
+        else:
+            try:
+                result = self._run(
+                    messages=messages,
+                    stop=stop,
+                    callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
+                    **kwargs
+                )
+            except Exception as ex:
+                raise self.handle_exceptions(ex)
+
+        if isinstance(result.generations[0][0], ChatGeneration):
+            completion_content = result.generations[0][0].message.content
+        else:
+            completion_content = result.generations[0][0].text
+
+        if self.streaming and not self.support_streaming():
+            # use FakeLLM to simulate streaming when current model not support streaming but streaming is True
+            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
+            fake_llm = FakeLLM(
+                response=completion_content,
+                num_token_func=self.get_num_tokens,
+                streaming=self.streaming,
+                callbacks=callbacks
+            )
+            fake_llm.generate([prompts])
+
+        if result.llm_output and result.llm_output['token_usage']:
+            prompt_tokens = result.llm_output['token_usage']['prompt_tokens']
+            completion_tokens = result.llm_output['token_usage']['completion_tokens']
+            total_tokens = result.llm_output['token_usage']['total_tokens']
+        else:
+            prompt_tokens = self.get_num_tokens(messages)
+            completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
+            total_tokens = prompt_tokens + completion_tokens
+
+        if self.deduct_quota:
+            self.model_provider.deduct_quota(total_tokens)
+
+        return LLMRunResult(
+            content=completion_content,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens
+        )
+
+    @abstractmethod
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        run predict by prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        get num tokens of prompt messages.
+
+        :param messages:
+        :return:
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_token_price(self, tokens: int, message_type: MessageType):
+        """
+        get token price.
+
+        :param tokens:
+        :param message_type:
+        :return:
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_currency(self):
+        """
+        get token currency.
+
+        :return:
+        """
+        raise NotImplementedError
+
+    def get_model_kwargs(self):
+        return self.model_kwargs
+
+    def set_model_kwargs(self, model_kwargs: ModelKwargs):
+        self.model_kwargs = model_kwargs
+        self._set_model_kwargs(model_kwargs)
+
+    @abstractmethod
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        """
+        Handle llm run exceptions.
+
+        :param ex:
+        :return:
+        """
+        raise NotImplementedError
+
+    def add_callbacks(self, callbacks: Callbacks):
+        """
+        Add callbacks to client.
+
+        :param callbacks:
+        :return:
+        """
+        if not self.client.callbacks:
+            self.client.callbacks = callbacks
+        else:
+            self.client.callbacks.extend(callbacks)
+
+    @classmethod
+    def support_streaming(cls):
+        return False
+
+    def _get_prompt_from_messages(self, messages: List[PromptMessage],
+                                  model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
+        if len(messages) == 0:
+            raise ValueError("prompt must not be empty.")
+
+        if not model_mode:
+            model_mode = self.model_mode
+
+        if model_mode == ModelMode.COMPLETION:
+            return messages[0].content
+        else:
+            chat_messages = []
+            for message in messages:
+                if message.type == MessageType.HUMAN:
+                    chat_messages.append(HumanMessage(content=message.content))
+                elif message.type == MessageType.ASSISTANT:
+                    chat_messages.append(AIMessage(content=message.content))
+                elif message.type == MessageType.SYSTEM:
+                    chat_messages.append(SystemMessage(content=message.content))
+
+            return chat_messages
+
+    def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
+        """
+        convert model kwargs to provider model kwargs.
+
+        :param model_rules:
+        :param model_kwargs:
+        :return:
+        """
+        model_kwargs_input = {}
+        for key, value in model_kwargs.dict().items():
+            rule = getattr(model_rules, key)
+            if not rule.enabled:
+                continue
+
+            if rule.alias:
+                key = rule.alias
+
+            if rule.default is not None and value is None:
+                value = rule.default
+
+            if rule.min is not None:
+                value = max(value, rule.min)
+
+            if rule.max is not None:
+                value = min(value, rule.max)
+
+            model_kwargs_input[key] = value
+
+        return model_kwargs_input

+ 70 - 0
api/core/model_providers/models/llm/chatglm_model.py

@@ -0,0 +1,70 @@
+import decimal
+from typing import List, Optional, Any
+
+from langchain.callbacks.manager import Callbacks
+from langchain.llms import ChatGLM
+from langchain.schema import LLMResult
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+
+
+class ChatGLMModel(BaseLLM):
+    model_mode: ModelMode = ModelMode.COMPLETION
+
+    def _init_client(self) -> Any:
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+        return ChatGLM(
+            callbacks=self.callbacks,
+            endpoint_url=self.credentials.get('api_base'),
+            **provider_model_kwargs
+        )
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        run predict by prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.generate([prompts], stop, callbacks)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        get num tokens of prompt messages.
+
+        :param messages:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return max(self._client.get_num_tokens(prompts), 0)
+
+    def get_token_price(self, tokens: int, message_type: MessageType):
+        return decimal.Decimal('0')
+
+    def get_currency(self):
+        return 'RMB'
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
+        for k, v in provider_model_kwargs.items():
+            if hasattr(self.client, k):
+                setattr(self.client, k, v)
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, ValueError):
+            return LLMBadRequestError(f"ChatGLM: {str(ex)}")
+        else:
+            return ex
+
+    @classmethod
+    def support_streaming(cls):
+        return False

+ 82 - 0
api/core/model_providers/models/llm/huggingface_hub_model.py

@@ -0,0 +1,82 @@
+import decimal
+from functools import wraps
+from typing import List, Optional, Any
+
+from langchain import HuggingFaceHub
+from langchain.callbacks.manager import Callbacks
+from langchain.llms import HuggingFaceEndpoint
+from langchain.schema import LLMResult
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+
+
+class HuggingfaceHubModel(BaseLLM):
+    model_mode: ModelMode = ModelMode.COMPLETION
+
+    def _init_client(self) -> Any:
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+        if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
+            client = HuggingFaceEndpoint(
+                endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
+                task='text2text-generation',
+                model_kwargs=provider_model_kwargs,
+                huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
+                callbacks=self.callbacks,
+            )
+        else:
+            client = HuggingFaceHub(
+                repo_id=self.name,
+                task=self.credentials['task_type'],
+                model_kwargs=provider_model_kwargs,
+                huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
+                callbacks=self.callbacks,
+            )
+
+        return client
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        run predict by prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.generate([prompts], stop, callbacks)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        get num tokens of prompt messages.
+
+        :param messages:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.get_num_tokens(prompts)
+
+    def get_token_price(self, tokens: int, message_type: MessageType):
+        # not support calc price
+        return decimal.Decimal('0')
+
+    def get_currency(self):
+        return 'USD'
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
+        self.client.model_kwargs = provider_model_kwargs
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        return LLMBadRequestError(f"Huggingface Hub: {str(ex)}")
+
+    @classmethod
+    def support_streaming(cls):
+        return False
+

+ 70 - 0
api/core/model_providers/models/llm/minimax_model.py

@@ -0,0 +1,70 @@
+import decimal
+from typing import List, Optional, Any
+
+from langchain.callbacks.manager import Callbacks
+from langchain.llms import Minimax
+from langchain.schema import LLMResult
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+
+
+class MinimaxModel(BaseLLM):
+    model_mode: ModelMode = ModelMode.COMPLETION
+
+    def _init_client(self) -> Any:
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+        return Minimax(
+            model=self.name,
+            model_kwargs={
+                'stream': False
+            },
+            callbacks=self.callbacks,
+            **self.credentials,
+            **provider_model_kwargs
+        )
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        run predict by prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.generate([prompts], stop, callbacks)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        get num tokens of prompt messages.
+
+        :param messages:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return max(self._client.get_num_tokens(prompts), 0)
+
+    def get_token_price(self, tokens: int, message_type: MessageType):
+        return decimal.Decimal('0')
+
+    def get_currency(self):
+        return 'RMB'
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
+        for k, v in provider_model_kwargs.items():
+            if hasattr(self.client, k):
+                setattr(self.client, k, v)
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, ValueError):
+            return LLMBadRequestError(f"Minimax: {str(ex)}")
+        else:
+            return ex

+ 219 - 0
api/core/model_providers/models/llm/openai_model.py

@@ -0,0 +1,219 @@
+import decimal
+import logging
+from typing import List, Optional, Any
+
+import openai
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import LLMResult
+
+from core.model_providers.providers.base import BaseModelProvider
+from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
+from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
+    LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
+from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from models.provider import ProviderType, ProviderQuotaType
+
+COMPLETION_MODELS = [
+    'text-davinci-003',  # 4,097 tokens
+]
+
+CHAT_MODELS = [
+    'gpt-4',  # 8,192 tokens
+    'gpt-4-32k',  # 32,768 tokens
+    'gpt-3.5-turbo',  # 4,096 tokens
+    'gpt-3.5-turbo-16k',  # 16,384 tokens
+]
+
+MODEL_MAX_TOKENS = {
+    'gpt-4': 8192,
+    'gpt-4-32k': 32768,
+    'gpt-3.5-turbo': 4096,
+    'gpt-3.5-turbo-16k': 16384,
+    'text-davinci-003': 4097,
+}
+
+
+class OpenAIModel(BaseLLM):
+    def __init__(self, model_provider: BaseModelProvider,
+                 name: str,
+                 model_kwargs: ModelKwargs,
+                 streaming: bool = False,
+                 callbacks: Callbacks = None):
+        if name in COMPLETION_MODELS:
+            self.model_mode = ModelMode.COMPLETION
+        else:
+            self.model_mode = ModelMode.CHAT
+
+        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
+
+    def _init_client(self) -> Any:
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+        if self.name in COMPLETION_MODELS:
+            client = EnhanceOpenAI(
+                model_name=self.name,
+                streaming=self.streaming,
+                callbacks=self.callbacks,
+                request_timeout=60,
+                **self.credentials,
+                **provider_model_kwargs
+            )
+        else:
+            # Fine-tuning is currently only available for the following base models:
+            # davinci, curie, babbage, and ada.
+            # This means that except for the fixed `completion` model,
+            # all other fine-tuned models are `completion` models.
+            extra_model_kwargs = {
+                'top_p': provider_model_kwargs.get('top_p'),
+                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
+                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
+            }
+
+            client = EnhanceChatOpenAI(
+                model_name=self.name,
+                temperature=provider_model_kwargs.get('temperature'),
+                max_tokens=provider_model_kwargs.get('max_tokens'),
+                model_kwargs=extra_model_kwargs,
+                streaming=self.streaming,
+                callbacks=self.callbacks,
+                request_timeout=60,
+                **self.credentials
+            )
+
+        return client
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        run predict by prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        if self.name == 'gpt-4' \
+                and self.model_provider.provider.provider_type == ProviderType.SYSTEM.value \
+                and self.model_provider.provider.quota_type == ProviderQuotaType.TRIAL.value:
+            raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")
+
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.generate([prompts], stop, callbacks)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        get num tokens of prompt messages.
+
+        :param messages:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        if isinstance(prompts, str):
+            return self._client.get_num_tokens(prompts)
+        else:
+            return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
+
+    def get_token_price(self, tokens: int, message_type: MessageType):
+        model_unit_prices = {
+            'gpt-4': {
+                'prompt': decimal.Decimal('0.03'),
+                'completion': decimal.Decimal('0.06'),
+            },
+            'gpt-4-32k': {
+                'prompt': decimal.Decimal('0.06'),
+                'completion': decimal.Decimal('0.12')
+            },
+            'gpt-3.5-turbo': {
+                'prompt': decimal.Decimal('0.0015'),
+                'completion': decimal.Decimal('0.002')
+            },
+            'gpt-3.5-turbo-16k': {
+                'prompt': decimal.Decimal('0.003'),
+                'completion': decimal.Decimal('0.004')
+            },
+            'text-davinci-003': {
+                'prompt': decimal.Decimal('0.02'),
+                'completion': decimal.Decimal('0.02')
+            },
+        }
+
+        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+            unit_price = model_unit_prices[self.name]['prompt']
+        else:
+            unit_price = model_unit_prices[self.name]['completion']
+
+        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
+                                                                  rounding=decimal.ROUND_HALF_UP)
+
+        total_price = tokens_per_1k * unit_price
+        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
+
+    def get_currency(self):
+        return 'USD'
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
+        if self.name in COMPLETION_MODELS:
+            for k, v in provider_model_kwargs.items():
+                if hasattr(self.client, k):
+                    setattr(self.client, k, v)
+        else:
+            extra_model_kwargs = {
+                'top_p': provider_model_kwargs.get('top_p'),
+                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
+                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
+            }
+
+            self.client.temperature = provider_model_kwargs.get('temperature')
+            self.client.max_tokens = provider_model_kwargs.get('max_tokens')
+            self.client.model_kwargs = extra_model_kwargs
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, openai.error.InvalidRequestError):
+            logging.warning("Invalid request to OpenAI API.")
+            return LLMBadRequestError(str(ex))
+        elif isinstance(ex, openai.error.APIConnectionError):
+            logging.warning("Failed to connect to OpenAI API.")
+            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
+            logging.warning("OpenAI service unavailable.")
+            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, openai.error.RateLimitError):
+            return LLMRateLimitError(str(ex))
+        elif isinstance(ex, openai.error.AuthenticationError):
+            raise LLMAuthorizationError(str(ex))
+        elif isinstance(ex, openai.error.OpenAIError):
+            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
+        else:
+            return ex
+
+    @classmethod
+    def support_streaming(cls):
+        return True
+
+    # def is_model_valid_or_raise(self):
+    #     """
+    #     check is a valid model.
+    #
+    #     :return:
+    #     """
+    #     credentials = self._model_provider.get_credentials()
+    #
+    #     try:
+    #         result = openai.Model.retrieve(
+    #             id=self.name,
+    #             api_key=credentials.get('openai_api_key'),
+    #             request_timeout=60
+    #         )
+    #
+    #         if 'id' not in result or result['id'] != self.name:
+    #             raise LLMNotExistsError(f"OpenAI Model {self.name} not exists.")
+    #     except openai.error.OpenAIError as e:
+    #         raise LLMNotExistsError(f"OpenAI Model {self.name} not exists, cause: {e.__class__.__name__}:{str(e)}")
+    #     except Exception as e:
+    #         logging.exception("OpenAI Model retrieve failed.")
+    #         raise e

+ 103 - 0
api/core/model_providers/models/llm/replicate_model.py

@@ -0,0 +1,103 @@
+import decimal
+from functools import wraps
+from typing import List, Optional, Any
+
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import LLMResult, get_buffer_string
+from replicate.exceptions import ReplicateError, ModelError
+
+from core.model_providers.providers.base import BaseModelProvider
+from core.model_providers.error import LLMBadRequestError
+from core.third_party.langchain.llms.replicate_llm import EnhanceReplicate
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+
+
+class ReplicateModel(BaseLLM):
+    def __init__(self, model_provider: BaseModelProvider,
+                 name: str,
+                 model_kwargs: ModelKwargs,
+                 streaming: bool = False,
+                 callbacks: Callbacks = None):
+        self.model_mode = ModelMode.CHAT if name.endswith('-chat') else ModelMode.COMPLETION
+
+        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
+
+    def _init_client(self) -> Any:
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+
+        return EnhanceReplicate(
+            model=self.name + ':' + self.credentials.get('model_version'),
+            input=provider_model_kwargs,
+            streaming=self.streaming,
+            replicate_api_token=self.credentials.get('replicate_api_token'),
+            callbacks=self.callbacks,
+        )
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        run predict by prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        extra_kwargs = {}
+        if isinstance(prompts, list):
+            system_messages = [message for message in messages if message.type == 'system']
+            if system_messages:
+                system_message = system_messages[0]
+                extra_kwargs['system_prompt'] = system_message.content
+                prompts = [message for message in messages if message.type != 'system']
+
+            prompts = get_buffer_string(prompts)
+
+        # The maximum length the generated tokens can have.
+        # Corresponds to the length of the input prompt + max_new_tokens.
+        if 'max_length' in self._client.input:
+            self._client.input['max_length'] = min(
+                self._client.input['max_length'] + self.get_num_tokens(messages),
+                self.model_rules.max_tokens.max
+            )
+
+        return self._client.generate([prompts], stop, callbacks, **extra_kwargs)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        get num tokens of prompt messages.
+
+        :param messages:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        if isinstance(prompts, list):
+            prompts = get_buffer_string(prompts)
+
+        return self._client.get_num_tokens(prompts)
+
+    def get_token_price(self, tokens: int, message_type: MessageType):
+        # replicate only pay for prediction seconds
+        return decimal.Decimal('0')
+
+    def get_currency(self):
+        return 'USD'
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
+        self.client.input = provider_model_kwargs
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, (ModelError, ReplicateError)):
+            return LLMBadRequestError(f"Replicate: {str(ex)}")
+        else:
+            return ex
+
+    @classmethod
+    def support_streaming(cls):
+        return True

+ 73 - 0
api/core/model_providers/models/llm/spark_model.py

@@ -0,0 +1,73 @@
+import decimal
+from functools import wraps
+from typing import List, Optional, Any
+
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import LLMResult
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.spark import ChatSpark
+from core.third_party.spark.spark_llm import SparkError
+
+
+class SparkModel(BaseLLM):
+    model_mode: ModelMode = ModelMode.CHAT
+
+    def _init_client(self) -> Any:
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+        return ChatSpark(
+            streaming=self.streaming,
+            callbacks=self.callbacks,
+            **self.credentials,
+            **provider_model_kwargs
+        )
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        run predict by prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.generate([prompts], stop, callbacks)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        get num tokens of prompt messages.
+
+        :param messages:
+        :return:
+        """
+        contents = [message.content for message in messages]
+        return max(self._client.get_num_tokens("".join(contents)), 0)
+
+    def get_token_price(self, tokens: int, message_type: MessageType):
+        return decimal.Decimal('0')
+
+    def get_currency(self):
+        return 'RMB'
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
+        for k, v in provider_model_kwargs.items():
+            if hasattr(self.client, k):
+                setattr(self.client, k, v)
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, SparkError):
+            return LLMBadRequestError(f"Spark: {str(ex)}")
+        else:
+            return ex
+
+    @classmethod
+    def support_streaming(cls):
+        return True

Nem az összes módosított fájl került megjelenítésre, mert túl sok fájl változott