Browse Source

feat: server multi models support (#799)

takatost 1 year ago
parent
commit
5fa2161b05
100 changed files with 3368 additions and 1699 deletions
  1. +2 -1  .github/workflows/check_no_chinese_comments.py
  2. +26 -0  api/.env.example
  3. +17 -2  api/app.py
  4. +24 -13  api/commands.py
  5. +39 -13  api/config.py
  6. +4 -1  api/controllers/console/__init__.py
  7. +21 -5  api/controllers/console/app/app.py
  8. +1 -1  api/controllers/console/app/audio.py
  9. +9 -3  api/controllers/console/app/completion.py
  10. +1 -1  api/controllers/console/app/generator.py
  11. +1 -1  api/controllers/console/app/message.py
  12. +2 -2  api/controllers/console/app/model_config.py
  13. +1 -1  api/controllers/console/datasets/data_source.py
  14. +29 -3  api/controllers/console/datasets/datasets.py
  15. +43 -5  api/controllers/console/datasets/datasets_document.py
  16. +3 -1  api/controllers/console/datasets/hit_testing.py
  17. +1 -1  api/controllers/console/explore/audio.py
  18. +1 -1  api/controllers/console/explore/completion.py
  19. +1 -1  api/controllers/console/explore/message.py
  20. +1 -4  api/controllers/console/explore/parameter.py
  21. +1 -1  api/controllers/console/universal_chat/audio.py
  22. +3 -7  api/controllers/console/universal_chat/chat.py
  23. +1 -1  api/controllers/console/universal_chat/message.py
  24. +1 -4  api/controllers/console/universal_chat/parameter.py
  25. +0 -0  api/controllers/console/webhook/__init__.py
  26. +53 -0  api/controllers/console/webhook/stripe.py
  27. +203 -193  api/controllers/console/workspace/model_providers.py
  28. +108 -0  api/controllers/console/workspace/models.py
  29. +130 -0  api/controllers/console/workspace/providers.py
  30. +1 -1  api/controllers/console/workspace/workspace.py
  31. +1 -4  api/controllers/service_api/app/app.py
  32. +1 -1  api/controllers/service_api/app/audio.py
  33. +1 -1  api/controllers/service_api/app/completion.py
  34. +1 -1  api/controllers/service_api/dataset/document.py
  35. +1 -4  api/controllers/web/app.py
  36. +1 -1  api/controllers/web/audio.py
  37. +1 -1  api/controllers/web/completion.py
  38. +1 -1  api/controllers/web/message.py
  39. +0 -36  api/core/__init__.py
  40. +9 -13  api/core/agent/agent/calc_token_mixin.py
  41. +9 -1  api/core/agent/agent/multi_dataset_router_agent.py
  42. +3 -2  api/core/agent/agent/openai_function_call.py
  43. +11 -3  api/core/agent/agent/openai_function_call_summarize_mixin.py
  44. +3 -2  api/core/agent/agent/openai_multi_function_call.py
  45. +162 -0  api/core/agent/agent/structed_multi_dataset_router_agent.py
  46. +8 -2  api/core/agent/agent/structured_chat.py
  47. +25 -11  api/core/agent/agent_executor.py
  48. +5 -4  api/core/callback_handler/agent_loop_gather_callback_handler.py
  49. +11 -7  api/core/callback_handler/llm_callback_handler.py
  50. +0 -2  api/core/callback_handler/main_chain_gather_callback_handler.py
  51. +86 -111  api/core/completion.py
  52. +0 -109  api/core/constant/llm_constant.py
  53. +22 -37  api/core/conversation_message_task.py
  54. +6 -8  api/core/docstore/dataset_docstore.py
  55. +33 -25  api/core/embedding/cached_embedding.py
  56. +61 -83  api/core/generator/llm_generator.py
  57. +0 -0  api/core/helper/__init__.py
  58. +20 -0  api/core/helper/encrypter.py
  59. +4 -10  api/core/index/index.py
  60. +46 -43  api/core/indexing_runner.py
  61. +0 -148  api/core/llm/llm_builder.py
  62. +0 -15  api/core/llm/moderation.py
  63. +0 -138  api/core/llm/provider/anthropic_provider.py
  64. +0 -145  api/core/llm/provider/azure_provider.py
  65. +0 -132  api/core/llm/provider/base.py
  66. +0 -2  api/core/llm/provider/errors.py
  67. +0 -22  api/core/llm/provider/huggingface_provider.py
  68. +0 -53  api/core/llm/provider/llm_provider_service.py
  69. +0 -55  api/core/llm/provider/openai_provider.py
  70. +0 -62  api/core/llm/streamable_chat_anthropic.py
  71. +0 -41  api/core/llm/token_calculator.py
  72. +0 -26  api/core/llm/whisper.py
  73. +0 -27  api/core/llm/wrappers/anthropic_wrapper.py
  74. +0 -31  api/core/llm/wrappers/openai_wrapper.py
  75. +12 -12  api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py
  76. +0 -0  api/core/model_providers/error.py
  77. +293 -0  api/core/model_providers/model_factory.py
  78. +228 -0  api/core/model_providers/model_provider_factory.py
  79. +0 -0  api/core/model_providers/models/__init__.py
  80. +22 -0  api/core/model_providers/models/base.py
  81. +0 -0  api/core/model_providers/models/embedding/__init__.py
  82. +78 -0  api/core/model_providers/models/embedding/azure_openai_embedding.py
  83. +40 -0  api/core/model_providers/models/embedding/base.py
  84. +35 -0  api/core/model_providers/models/embedding/minimax_embedding.py
  85. +72 -0  api/core/model_providers/models/embedding/openai_embedding.py
  86. +36 -0  api/core/model_providers/models/embedding/replicate_embedding.py
  87. +0 -0  api/core/model_providers/models/entity/__init__.py
  88. +53 -0  api/core/model_providers/models/entity/message.py
  89. +59 -0  api/core/model_providers/models/entity/model_params.py
  90. +10 -0  api/core/model_providers/models/entity/provider.py
  91. +0 -0  api/core/model_providers/models/llm/__init__.py
  92. +107 -0  api/core/model_providers/models/llm/anthropic_model.py
  93. +177 -0  api/core/model_providers/models/llm/azure_openai_model.py
  94. +269 -0  api/core/model_providers/models/llm/base.py
  95. +70 -0  api/core/model_providers/models/llm/chatglm_model.py
  96. +82 -0  api/core/model_providers/models/llm/huggingface_hub_model.py
  97. +70 -0  api/core/model_providers/models/llm/minimax_model.py
  98. +219 -0  api/core/model_providers/models/llm/openai_model.py
  99. +103 -0  api/core/model_providers/models/llm/replicate_model.py
  100. +73 -0  api/core/model_providers/models/llm/spark_model.py

+ 2 - 1
.github/workflows/check_no_chinese_comments.py

@@ -19,7 +19,8 @@ def check_file_for_chinese_comments(file_path):
 
 def main():
     has_chinese = False
-    excluded_files = ["model_template.py", 'stopwords.py', 'commands.py', 'indexing_runner.py', 'web_reader_tool.py']
+    excluded_files = ["model_template.py", 'stopwords.py', 'commands.py',
+                      'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py']
 
     for root, _, files in os.walk("."):
         for file in files:

+ 26 - 0
api/.env.example

@@ -102,3 +102,29 @@ NOTION_INTEGRATION_TYPE=public
 NOTION_CLIENT_SECRET=you-client-secret
 NOTION_CLIENT_ID=you-client-id
 NOTION_INTERNAL_SECRET=you-internal-secret
+
+# Hosted Model Credentials
+HOSTED_OPENAI_ENABLED=false
+HOSTED_OPENAI_API_KEY=
+HOSTED_OPENAI_API_BASE=
+HOSTED_OPENAI_API_ORGANIZATION=
+HOSTED_OPENAI_QUOTA_LIMIT=200
+HOSTED_OPENAI_PAID_ENABLED=false
+HOSTED_OPENAI_PAID_STRIPE_PRICE_ID=
+HOSTED_OPENAI_PAID_INCREASE_QUOTA=1
+
+HOSTED_AZURE_OPENAI_ENABLED=false
+HOSTED_AZURE_OPENAI_API_KEY=
+HOSTED_AZURE_OPENAI_API_BASE=
+HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200
+
+HOSTED_ANTHROPIC_ENABLED=false
+HOSTED_ANTHROPIC_API_BASE=
+HOSTED_ANTHROPIC_API_KEY=
+HOSTED_ANTHROPIC_QUOTA_LIMIT=1000000
+HOSTED_ANTHROPIC_PAID_ENABLED=false
+HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
+HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1
+
+STRIPE_API_KEY=
+STRIPE_WEBHOOK_SECRET=

+ 17 - 2
api/app.py

@@ -16,8 +16,9 @@ from flask import Flask, request, Response, session
 import flask_login
 from flask_cors import CORS
 
+from core.model_providers.providers import hosted
 from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
-    ext_database, ext_storage, ext_mail
+    ext_database, ext_storage, ext_mail, ext_stripe
 from extensions.ext_database import db
 from extensions.ext_login import login_manager
@@ -71,7 +72,7 @@ def create_app(test_config=None) -> Flask:
     register_blueprints(app)
     register_commands(app)
 
-    core.init_app(app)
+    hosted.init_app(app)
 
     return app
 
@@ -88,6 +89,7 @@ def initialize_extensions(app):
     ext_login.init_app(app)
     ext_mail.init_app(app)
     ext_sentry.init_app(app)
+    ext_stripe.init_app(app)
 
 
 def _create_tenant_for_account(account):
@@ -246,5 +248,18 @@ def threads():
     }
 
 
+@app.route('/db-pool-stat')
+def pool_stat():
+    engine = db.engine
+    return {
+        'pool_size': engine.pool.size(),
+        'checked_in_connections': engine.pool.checkedin(),
+        'checked_out_connections': engine.pool.checkedout(),
+        'overflow_connections': engine.pool.overflow(),
+        'connection_timeout': engine.pool.timeout(),
+        'recycle_time': db.engine.pool._recycle
+    }
+
+
 if __name__ == '__main__':
     app.run(host='0.0.0.0', port=5001)
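
A quick way to sanity-check the new debug route once the server is up. This is a minimal sketch, not part of the commit; it assumes the API is running locally on the port passed to app.run above, and the exact numbers depend on your pool configuration:

import json
import urllib.request

# Fetch the pool statistics exposed by the /db-pool-stat route added above.
with urllib.request.urlopen('http://localhost:5001/db-pool-stat') as resp:
    stats = json.load(resp)

print(stats['pool_size'])     # 30 with the default SQLALCHEMY_POOL_SIZE
print(stats['recycle_time'])  # 3600 with the new SQLALCHEMY_POOL_RECYCLE default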

+ 24 - 13
api/commands.py

@@ -1,5 +1,5 @@
 import datetime
-import logging
+import math
 import random
 import string
 import time
@@ -9,18 +9,18 @@ from flask import current_app
 from werkzeug.exceptions import NotFound
 
 from core.index.index import IndexBuilder
+from core.model_providers.providers.hosted import hosted_model_providers
 from libs.password import password_pattern, valid_password, hash_password
 from libs.helper import email as email_validate
 from extensions.ext_database import db
 from libs.rsa import generate_key_pair
 from models.account import InvitationCode, Tenant
-from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment
+from models.dataset import Dataset, DatasetQuery, Document
 from models.model import Account
 import secrets
 import base64
 
-from models.provider import Provider, ProviderName
-from services.provider_service import ProviderService
+from models.provider import Provider, ProviderType, ProviderQuotaType
 
 
 @click.command('reset-password', help='Reset the account password.')
@@ -251,26 +251,37 @@ def clean_unused_dataset_indexes():
 
 @click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
 def sync_anthropic_hosted_providers():
+    if not hosted_model_providers.anthropic:
+        click.echo(click.style('Anthropic hosted provider is not configured.', fg='red'))
+        return
+
     click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
     count = 0
 
     page = 1
     while True:
         try:
-            tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
+            providers = db.session.query(Provider).filter(
+                Provider.provider_name == 'anthropic',
+                Provider.provider_type == ProviderType.SYSTEM.value,
+                Provider.quota_type == ProviderQuotaType.TRIAL.value,
+            ).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
         except NotFound:
             break
 
         page += 1
-        for tenant in tenants:
+        for provider in providers:
             try:
-                click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
-                ProviderService.create_system_provider(
-                    tenant,
-                    ProviderName.ANTHROPIC.value,
-                    current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
-                    True
-                )
+                click.echo('Syncing tenant anthropic hosted provider: {}'.format(provider.tenant_id))
+                original_quota_limit = provider.quota_limit
+                new_quota_limit = hosted_model_providers.anthropic.quota_limit
+                division = math.ceil(new_quota_limit / 1000)
+
+                provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \
+                    else original_quota_limit * division
+                provider.quota_used = division * provider.quota_used
+                db.session.commit()
+
                count += 1
             except Exception as e:
                 click.echo(click.style(
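
The quota arithmetic above is easy to misread, so here is a minimal standalone sketch of the same rule (the constants come from this commit's defaults: the hosted Anthropic quota limit is now 1000000 and the legacy trial limit was 1000):

import math

new_quota_limit = 1000000                      # hosted_model_providers.anthropic.quota_limit
division = math.ceil(new_quota_limit / 1000)   # 1000

def rescale(original_quota_limit, quota_used):
    # Tenants still on the legacy limit of 1000 jump straight to the new limit;
    # tenants with a custom limit (and everyone's usage) are scaled by `division`.
    quota_limit = new_quota_limit if original_quota_limit == 1000 \
        else original_quota_limit * division
    return quota_limit, quota_used * division

print(rescale(1000, 5))  # (1000000, 5000)
print(rescale(2000, 0))  # (2000000, 0)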

+ 39 - 13
api/config.py

@@ -41,6 +41,7 @@ DEFAULTS = {
     'SESSION_USE_SIGNER': 'True',
     'DEPLOY_ENV': 'PRODUCTION',
     'SQLALCHEMY_POOL_SIZE': 30,
+    'SQLALCHEMY_POOL_RECYCLE': 3600,
     'SQLALCHEMY_ECHO': 'False',
     'SENTRY_TRACES_SAMPLE_RATE': 1.0,
     'SENTRY_PROFILES_SAMPLE_RATE': 1.0,
@@ -50,9 +51,16 @@ DEFAULTS = {
     'PDF_PREVIEW': 'True',
     'LOG_LEVEL': 'INFO',
     'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
-    'DEFAULT_LLM_PROVIDER': 'openai',
-    'OPENAI_HOSTED_QUOTA_LIMIT': 200,
-    'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
+    'HOSTED_OPENAI_QUOTA_LIMIT': 200,
+    'HOSTED_OPENAI_ENABLED': 'False',
+    'HOSTED_OPENAI_PAID_ENABLED': 'False',
+    'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
+    'HOSTED_AZURE_OPENAI_ENABLED': 'False',
+    'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
+    'HOSTED_ANTHROPIC_QUOTA_LIMIT': 1000000,
+    'HOSTED_ANTHROPIC_ENABLED': 'False',
+    'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
+    'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1,
     'TENANT_DOCUMENT_COUNT': 100,
     'CLEAN_DAY_SETTING': 30
 }
@@ -182,7 +190,10 @@
         }
 
         self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}"
-        self.SQLALCHEMY_ENGINE_OPTIONS = {'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE'))}
+        self.SQLALCHEMY_ENGINE_OPTIONS = {
+            'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')),
+            'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE'))
+        }
 
         self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO')
 
@@ -194,20 +205,35 @@ class Config:
         self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')
 
         # hosted provider credentials
-        self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
-        self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')
-
-        self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
-        self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')
+        self.HOSTED_OPENAI_ENABLED = get_bool_env('HOSTED_OPENAI_ENABLED')
+        self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
+        self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
+        self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
+        self.HOSTED_OPENAI_QUOTA_LIMIT = get_env('HOSTED_OPENAI_QUOTA_LIMIT')
+        self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
+        self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
+        self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
+
+        self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
+        self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
+        self.HOSTED_AZURE_OPENAI_API_BASE = get_env('HOSTED_AZURE_OPENAI_API_BASE')
+        self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT')
+
+        self.HOSTED_ANTHROPIC_ENABLED = get_bool_env('HOSTED_ANTHROPIC_ENABLED')
+        self.HOSTED_ANTHROPIC_API_BASE = get_env('HOSTED_ANTHROPIC_API_BASE')
+        self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
+        self.HOSTED_ANTHROPIC_QUOTA_LIMIT = get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')
+        self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
+        self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
+        self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA')
+
+        self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
+        self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
 
         # By default it is False
         # You could disable it for compatibility with certain OpenAPI providers
         self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')
 
-        # For temp use only
-        # set default LLM provider, default is 'openai', support `azure_openai`
-        self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
-
         # notion import setting
         self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
         self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')

+ 4 - 1
api/controllers/console/__init__.py

@@ -18,10 +18,13 @@ from .auth import login, oauth, data_source_oauth, activate
 from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
 
 # Import workspace controllers
-from .workspace import workspace, members, model_providers, account, tool_providers
+from .workspace import workspace, members, providers, model_providers, account, tool_providers, models
 
 # Import explore controllers
 from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio
 
 # Import universal chat controllers
 from .universal_chat import chat, conversation, message, parameter, audio
+
+# Import webhook controllers
+from .webhook import stripe

+ 21 - 5
api/controllers/console/app/app.py

@@ -2,16 +2,17 @@
 import json
 from datetime import datetime
 
-import flask
 from flask_login import login_required, current_user
 from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
-from werkzeug.exceptions import Unauthorized, Forbidden
+from werkzeug.exceptions import Forbidden
 
 from constants.model_template import model_templates, demo_model_templates
 from controllers.console import api
-from controllers.console.app.error import AppNotFoundError
+from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.entity.model_params import ModelType
 from events.app_event import app_was_created, app_was_deleted
 from libs.helper import TimestampField
 from extensions.ext_database import db
@@ -126,9 +127,9 @@ class AppListApi(Resource):
         if args['model_config'] is not None:
             # validate config
             model_configuration = AppModelConfigService.validate_configuration(
+                tenant_id=current_user.current_tenant_id,
                 account=current_user,
-                config=args['model_config'],
-                mode=args['mode']
+                config=args['model_config']
             )
 
             app = App(
@@ -164,6 +165,21 @@ class AppListApi(Resource):
             app = App(**model_config_template['app'])
             app_model_config = AppModelConfig(**model_config_template['model_config'])
 
+            default_model = ModelFactory.get_default_model(
+                tenant_id=current_user.current_tenant_id,
+                model_type=ModelType.TEXT_GENERATION
+            )
+
+            if default_model:
+                model_dict = app_model_config.model_dict
+                model_dict['provider'] = default_model.provider_name
+                model_dict['name'] = default_model.model_name
+                app_model_config.model = json.dumps(model_dict)
+            else:
+                raise ProviderNotInitializeError(
+                    f"No Text Generation Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+
         app.name = args['name']
         app.mode = args['mode']
         app.icon = args['icon']

+ 1 - 1
api/controllers/console/app/audio.py

@@ -14,7 +14,7 @@ from controllers.console.app.error import AppUnavailableError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from flask_restful import Resource
 from services.audio_service import AudioService

+ 9 - 3
api/controllers/console/app/completion.py

@@ -17,7 +17,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.conversation_message_task import PubHandler
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value
 from flask_restful import Resource, reqparse
@@ -41,8 +41,11 @@ class CompletionMessageApi(Resource):
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, location='json')
         parser.add_argument('model_config', type=dict, required=True, location='json')
+        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         args = parser.parse_args()
 
+        streaming = args['response_mode'] != 'blocking'
+
         account = flask_login.current_user
 
         try:
@@ -51,7 +54,7 @@
                 user=account,
                 args=args,
                 from_source='console',
-                streaming=True,
+                streaming=streaming,
                 is_model_config_override=True
             )
 
@@ -111,8 +114,11 @@ class ChatMessageApi(Resource):
         parser.add_argument('query', type=str, required=True, location='json')
         parser.add_argument('model_config', type=dict, required=True, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
+        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         args = parser.parse_args()
 
+        streaming = args['response_mode'] != 'blocking'
+
         account = flask_login.current_user
 
         try:
@@ -121,7 +127,7 @@
                 user=account,
                 args=args,
                 from_source='console',
-                streaming=True,
+                streaming=streaming,
                 is_model_config_override=True
             )
 
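Note that response_mode is optional in both parsers above: when the field is omitted, reqparse yields None, and None != 'blocking' evaluates to True, so the console keeps its previous always-streaming behaviour unless a client explicitly asks for blocking mode. A one-line illustration:

for mode in (None, 'streaming', 'blocking'):
    print(mode, mode != 'blocking')  # None True / streaming True / blocking False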

+ 1 - 1
api/controllers/console/app/generator.py

@@ -7,7 +7,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.generator.llm_generator import LLMGenerator
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
     LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
 
 

+ 1 - 1
api/controllers/console/app/message.py

@@ -14,7 +14,7 @@ from controllers.console.app.error import CompletionRequestError, ProviderNotIni
     AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value, TimestampField
 from libs.infinite_scroll_pagination import InfiniteScrollPagination

+ 2 - 2
api/controllers/console/app/model_config.py

@@ -28,9 +28,9 @@ class ModelConfigResource(Resource):
 
         # validate config
         model_configuration = AppModelConfigService.validate_configuration(
+            tenant_id=current_user.current_tenant_id,
             account=current_user,
-            config=request.json,
-            mode=app_model.mode
+            config=request.json
         )
 
         new_app_model_config = AppModelConfig(

+ 1 - 1
api/controllers/console/datasets/data_source.py

@@ -255,7 +255,7 @@ class DataSourceNotionApi(Resource):
         # validate args
         DocumentService.estimate_args_validate(args)
         indexing_runner = IndexingRunner()
-        response = indexing_runner.notion_indexing_estimate(args['notion_info_list'], args['process_rule'])
+        response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])
         return response, 200
 
 

+ 29 - 3
api/controllers/console/datasets/datasets.py

@@ -5,10 +5,13 @@ from flask_restful import Resource, reqparse, fields, marshal, marshal_with
 from werkzeug.exceptions import NotFound, Forbidden
 import services
 from controllers.console import api
+from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.datasets.error import DatasetNameDuplicateError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.indexing_runner import IndexingRunner
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.model_factory import ModelFactory
 from libs.helper import TimestampField
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Document
@@ -97,6 +100,15 @@ class DatasetListApi(Resource):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
 
+        try:
+            ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
+
         try:
             dataset = DatasetService.create_empty_dataset(
                 tenant_id=current_user.current_tenant_id,
@@ -235,12 +247,26 @@ class DatasetIndexingEstimateApi(Resource):
                 raise NotFound("File not found.")
 
             indexing_runner = IndexingRunner()
-            response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'], args['doc_form'])
+
+            try:
+                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
+                                                                  args['process_rule'], args['doc_form'])
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
         elif args['info_list']['data_source_type'] == 'notion_import':
 
             indexing_runner = IndexingRunner()
-            response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'],
-                                                                args['process_rule'], args['doc_form'])
+
+            try:
+                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
+                                                                    args['info_list']['notion_info_list'],
+                                                                    args['process_rule'], args['doc_form'])
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
         else:
             raise ValueError('Data source type not support')
         return response, 200

+ 43 - 5
api/controllers/console/datasets/datasets_document.py

@@ -18,7 +18,9 @@ from controllers.console.datasets.error import DocumentAlreadyFinishedError, Inv
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.indexing_runner import IndexingRunner
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
+    LLMBadRequestError
+from core.model_providers.model_factory import ModelFactory
 from extensions.ext_redis import redis_client
 from libs.helper import TimestampField
 from extensions.ext_database import db
@@ -280,6 +282,15 @@ class DatasetDocumentListApi(Resource):
         # validate args
         DocumentService.document_create_args_validate(args)
 
+        try:
+            ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
+
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
         except ProviderTokenNotInitError as ex:
@@ -319,6 +330,15 @@ class DatasetInitApi(Resource):
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         args = parser.parse_args()
 
+        try:
+            ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
+
         # validate args
         DocumentService.document_create_args_validate(args)
 
@@ -384,7 +404,13 @@ class DocumentIndexingEstimateApi(DocumentResource):
 
                 indexing_runner = IndexingRunner()
 
-                response = indexing_runner.file_indexing_estimate([file], data_process_rule_dict)
+                try:
+                    response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
+                                                                      data_process_rule_dict)
+                except LLMBadRequestError:
+                    raise ProviderNotInitializeError(
+                        f"No Embedding Model available. Please configure a valid provider "
+                        f"in the Settings -> Model Provider.")
 
         return response
 
@@ -445,12 +471,24 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                 raise NotFound("File not found.")
 
             indexing_runner = IndexingRunner()
-            response = indexing_runner.file_indexing_estimate(file_details, data_process_rule_dict)
+            try:
+                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
+                                                                  data_process_rule_dict)
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
         elif dataset.data_source_type:
 
             indexing_runner = IndexingRunner()
-            response = indexing_runner.notion_indexing_estimate(info_list,
-                                                                data_process_rule_dict)
+            try:
+                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
+                                                                    info_list,
+                                                                    data_process_rule_dict)
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
         else:
             raise ValueError('Data source type not support')
         return response

+ 3 - 1
api/controllers/console/datasets/hit_testing.py

@@ -11,7 +11,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
 from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import TimestampField
 from services.dataset_service import DatasetService
 from services.hit_testing_service import HitTestingService
@@ -102,6 +102,8 @@ class HitTestingApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
+        except ValueError as e:
+            raise ValueError(str(e))
         except Exception as e:
             logging.exception("Hit testing failed.")
             raise InternalServerError(str(e))

+ 1 - 1
api/controllers/console/explore/audio.py

@@ -11,7 +11,7 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
     NoAudioUploadedError, AudioTooLargeError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.explore.wraps import InstalledAppResource
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \

+ 1 - 1
api/controllers/console/explore/completion.py

@@ -15,7 +15,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
 from controllers.console.explore.error import NotCompletionAppError, NotChatAppError
 from controllers.console.explore.wraps import InstalledAppResource
 from core.conversation_message_task import PubHandler
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService

+ 1 - 1
api/controllers/console/explore/message.py

@@ -15,7 +15,7 @@ from controllers.console.app.error import AppMoreLikeThisDisabledError, Provider
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError
 from controllers.console.explore.wraps import InstalledAppResource
-from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value, TimestampField
 from services.completion_service import CompletionService

+ 1 - 4
api/controllers/console/explore/parameter.py

@@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields
 from controllers.console import api
 from controllers.console.explore.wraps import InstalledAppResource
 
-from core.llm.llm_builder import LLMBuilder
-from models.provider import ProviderName
 from models.model import InstalledApp
 
 
@@ -35,13 +33,12 @@ class AppParameterApi(InstalledAppResource):
         """Retrieve app parameters."""
         app_model = installed_app.app
         app_model_config = app_model.app_model_config
-        provider_name = LLMBuilder.get_default_provider(installed_app.tenant_id, 'whisper-1')
 
         return {
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
+            'speech_to_text': app_model_config.speech_to_text_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list
         }

+ 1 - 1
api/controllers/console/universal_chat/audio.py

@@ -11,7 +11,7 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
     NoAudioUploadedError, AudioTooLargeError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \

+ 3 - 7
api/controllers/console/universal_chat/chat.py

@@ -12,9 +12,8 @@ from controllers.console import api
 from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.constant import llm_constant
 from core.conversation_message_task import PubHandler
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
     LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService
@@ -27,6 +26,7 @@ class UniversalChatApi(UniversalChatResource):
         parser = reqparse.RequestParser()
         parser.add_argument('query', type=str, required=True, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
+        parser.add_argument('provider', type=str, required=True, location='json')
         parser.add_argument('model', type=str, required=True, location='json')
         parser.add_argument('tools', type=list, required=True, location='json')
         args = parser.parse_args()
@@ -36,11 +36,7 @@ class UniversalChatApi(UniversalChatResource):
         # update app model config
         args['model_config'] = app_model_config.to_dict()
         args['model_config']['model']['name'] = args['model']
-
-        if not llm_constant.models[args['model']]:
-            raise ValueError("Model not exists.")
-
-        args['model_config']['model']['provider'] = llm_constant.models[args['model']]
+        args['model_config']['model']['provider'] = args['provider']
         args['model_config']['agent_mode']['tools'] = args['tools']
 
         if not args['model_config']['agent_mode']['tools']:
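
With the llm_constant.models lookup table gone, a universal chat request must now name the provider explicitly. A hypothetical request body matching the parser above (the provider and model values are illustrative only, not taken from the diff):

payload = {
    'query': 'Hello',
    'provider': 'openai',      # newly required field added by this commit
    'model': 'gpt-3.5-turbo',  # example model name
    'tools': [],               # required by the parser
}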

+ 1 - 1
api/controllers/console/universal_chat/message.py

@@ -12,7 +12,7 @@ from controllers.console.app.error import ProviderNotInitializeError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value, TimestampField
 from services.errors.conversation import ConversationNotExistsError

+ 1 - 4
api/controllers/console/universal_chat/parameter.py

@@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields
 from controllers.console import api
 from controllers.console.universal_chat.wraps import UniversalChatResource
 
-from core.llm.llm_builder import LLMBuilder
-from models.provider import ProviderName
 from models.model import App
 
 
@@ -23,13 +21,12 @@ class UniversalChatParameterApi(UniversalChatResource):
         """Retrieve app parameters."""
         """Retrieve app parameters."""
         app_model = universal_app
         app_model = universal_app
         app_model_config = app_model.app_model_config
         app_model_config = app_model.app_model_config
-        provider_name = LLMBuilder.get_default_provider(universal_app.tenant_id, 'whisper-1')
 
 
         return {
         return {
             'opening_statement': app_model_config.opening_statement,
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
+            'speech_to_text': app_model_config.speech_to_text_dict,
         }
         }
 
 
 
 

+ 0 - 0
api/tests/test_controllers/__init__.py → api/controllers/console/webhook/__init__.py


+ 53 - 0
api/controllers/console/webhook/stripe.py

@@ -0,0 +1,53 @@
+import logging
+
+import stripe
+from flask import request, current_app
+from flask_restful import Resource
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import only_edition_cloud
+from services.provider_checkout_service import ProviderCheckoutService
+
+
+class StripeWebhookApi(Resource):
+    @setup_required
+    @only_edition_cloud
+    def post(self):
+        payload = request.data
+        sig_header = request.headers.get('STRIPE_SIGNATURE')
+        webhook_secret = current_app.config.get('STRIPE_WEBHOOK_SECRET')
+
+        try:
+            event = stripe.Webhook.construct_event(
+                payload, sig_header, webhook_secret
+            )
+        except ValueError as e:
+            # Invalid payload
+            return 'Invalid payload', 400
+        except stripe.error.SignatureVerificationError as e:
+            # Invalid signature
+            return 'Invalid signature', 400
+
+        # Handle the checkout.session.completed event
+        if event['type'] == 'checkout.session.completed':
+            logging.debug(event['data']['object']['id'])
+            logging.debug(event['data']['object']['amount_subtotal'])
+            logging.debug(event['data']['object']['currency'])
+            logging.debug(event['data']['object']['payment_intent'])
+            logging.debug(event['data']['object']['payment_status'])
+            logging.debug(event['data']['object']['metadata'])
+
+            # Fulfill the purchase...
+            provider_checkout_service = ProviderCheckoutService()
+
+            try:
+                provider_checkout_service.fulfill_provider_order(event)
+            except Exception as e:
+                logging.debug(str(e))
+                return 'success', 200
+
+        return 'success', 200
+
+
+api.add_resource(StripeWebhookApi, '/webhook/stripe')
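
Note: the handler verifies requests with stripe.Webhook.construct_event(), which checks Stripe's documented signature scheme (v1 = HMAC-SHA256 over "{timestamp}.{payload}"). A sketch of how a test request could be signed; every value is illustrative, and 'secret' must match the app's STRIPE_WEBHOOK_SECRET:

import hashlib
import hmac
import json
import time

secret = "whsec_test_secret"  # assumed test value
event = {
    "type": "checkout.session.completed",
    "data": {"object": {
        "id": "cs_test_123",
        "amount_subtotal": 1000,
        "currency": "usd",
        "payment_intent": "pi_test_123",
        "payment_status": "paid",
        "metadata": {},
    }},
}
payload = json.dumps(event)
timestamp = int(time.time())
digest = hmac.new(secret.encode(), f"{timestamp}.{payload}".encode(), hashlib.sha256).hexdigest()
stripe_signature_header = f"t={timestamp},v1={digest}"
# POSTing `payload` with this header to /webhook/stripe exercises the handler.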

+ 203 - 193
api/controllers/console/workspace/model_providers.py

@@ -1,24 +1,18 @@
-# -*- coding:utf-8 -*-
-import base64
-import json
-import logging
-
-from flask import current_app
 from flask_login import login_required, current_user
-from flask_restful import Resource, reqparse, abort
+from flask_restful import Resource, reqparse
 from werkzeug.exceptions import Forbidden
 
 from controllers.console import api
+from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.llm.provider.errors import ValidateFailedError
-from extensions.ext_database import db
-from libs import rsa
-from models.provider import Provider, ProviderType, ProviderName
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.providers.base import CredentialsValidateFailedError
+from services.provider_checkout_service import ProviderCheckoutService
 from services.provider_service import ProviderService
 
 
-class ProviderListApi(Resource):
+class ModelProviderListApi(Resource):
 
     @setup_required
     @login_required
@@ -26,156 +20,115 @@ class ProviderListApi(Resource):
     def get(self):
         tenant_id = current_user.current_tenant_id
 
-        """
-        If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:, 
-        azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the 
-        rest is replaced by * and the last two bits are displayed in plaintext
-        
-        If the type is other, decode and return the Token field directly, the field displays the first 6 bits in 
-        plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
-        """
-
-        ProviderService.init_supported_provider(current_user.current_tenant)
-        providers = Provider.query.filter_by(tenant_id=tenant_id).all()
-
-        provider_list = [
-            {
-                'provider_name': p.provider_name,
-                'provider_type': p.provider_type,
-                'is_valid': p.is_valid,
-                'last_used': p.last_used,
-                'is_enabled': p.is_enabled,
-                **({
-                       'quota_type': p.quota_type,
-                       'quota_limit': p.quota_limit,
-                       'quota_used': p.quota_used
-                   } if p.provider_type == ProviderType.SYSTEM.value else {}),
-                'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
-                                                                ProviderName(p.provider_name), only_custom=True)
-                if p.provider_type == ProviderType.CUSTOM.value else None
-            }
-            for p in providers
-        ]
+        provider_service = ProviderService()
+        provider_list = provider_service.get_provider_list(tenant_id)
 
         return provider_list
 
 
-class ProviderTokenApi(Resource):
+class ModelProviderValidateApi(Resource):
 
     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider):
-        if provider not in [p.value for p in ProviderName]:
-            abort(404)
+    def post(self, provider_name: str):
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        provider_service = ProviderService()
+
+        result = True
+        error = None
+
+        try:
+            provider_service.custom_provider_config_validate(
+                provider_name=provider_name,
+                config=args['config']
+            )
+        except CredentialsValidateFailedError as ex:
+            result = False
+            error = str(ex)
+
+        response = {'result': 'success' if result else 'error'}
+
+        if not result:
+            response['error'] = error
+
+        return response
+
+
+class ModelProviderUpdateApi(Resource):
 
-        # The role of the current user in the ta table must be admin or owner
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider_name: str):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
-            logging.log(logging.ERROR,
-                        f'User {current_user.id} is not authorized to update provider token, current_role is {current_user.current_tenant.current_role}')
             raise Forbidden()
 
         parser = reqparse.RequestParser()
-
-        parser.add_argument('token', type=ProviderService.get_token_type(
-            tenant=current_user.current_tenant,
-            provider_name=ProviderName(provider)
-        ), required=True, nullable=False, location='json')
-
+        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
         args = parser.parse_args()
 
-        if args['token']:
-            try:
-                ProviderService.validate_provider_configs(
-                    tenant=current_user.current_tenant,
-                    provider_name=ProviderName(provider),
-                    configs=args['token']
-                )
-                token_is_valid = True
-            except ValidateFailedError as ex:
-                raise ValueError(str(ex))
-
-            base64_encrypted_token = ProviderService.get_encrypted_token(
-                tenant=current_user.current_tenant,
-                provider_name=ProviderName(provider),
-                configs=args['token']
+        provider_service = ProviderService()
+
+        try:
+            provider_service.save_custom_provider_config(
+                tenant_id=current_user.current_tenant_id,
+                provider_name=provider_name,
+                config=args['config']
             )
-        else:
-            base64_encrypted_token = None
-            token_is_valid = False
-
-        tenant = current_user.current_tenant
-
-        provider_model = db.session.query(Provider).filter(
-                Provider.tenant_id == tenant.id,
-                Provider.provider_name == provider,
-                Provider.provider_type == ProviderType.CUSTOM.value
-            ).first()
-
-        # Only allow updating token for CUSTOM provider type
-        if provider_model:
-            provider_model.encrypted_config = base64_encrypted_token
-            provider_model.is_valid = token_is_valid
-        else:
-            provider_model = Provider(tenant_id=tenant.id, provider_name=provider,
-                                      provider_type=ProviderType.CUSTOM.value,
-                                      encrypted_config=base64_encrypted_token,
-                                      is_valid=token_is_valid)
-            db.session.add(provider_model)
-
-        if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid:
-            other_providers = db.session.query(Provider).filter(
-                Provider.tenant_id == tenant.id,
-                Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]),
-                Provider.provider_name != provider,
-                Provider.provider_type == ProviderType.CUSTOM.value
-            ).all()
-
-            for other_provider in other_providers:
-                other_provider.is_valid = False
-
-        db.session.commit()
-
-        if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
-                        ProviderName.HUGGINGFACEHUB.value]:
-            return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
+        except CredentialsValidateFailedError as ex:
+            raise ValueError(str(ex))
 
         return {'result': 'success'}, 201
 
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, provider_name: str):
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
 
 
-class ProviderTokenValidateApi(Resource):
+        provider_service = ProviderService()
+        provider_service.delete_custom_provider(
+            tenant_id=current_user.current_tenant_id,
+            provider_name=provider_name
+        )
+
+        return {'result': 'success'}, 204
+
+
+class ModelProviderModelValidateApi(Resource):
 
     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider):
-        if provider not in [p.value for p in ProviderName]:
-            abort(404)
-
+    def post(self, provider_name: str):
         parser = reqparse.RequestParser()
-        parser.add_argument('token', type=ProviderService.get_token_type(
-            tenant=current_user.current_tenant,
-            provider_name=ProviderName(provider)
-        ), required=True, nullable=False, location='json')
+        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=['text-generation', 'embeddings', 'speech2text'], location='json')
+        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
         args = parser.parse_args()
 
-        # todo: remove this when the provider is supported
-        if provider in [ProviderName.COHERE.value,
-                        ProviderName.HUGGINGFACEHUB.value]:
-            return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
+        provider_service = ProviderService()
 
         result = True
         error = None
 
         try:
-            ProviderService.validate_provider_configs(
-                tenant=current_user.current_tenant,
-                provider_name=ProviderName(provider),
-                configs=args['token']
+            provider_service.custom_provider_model_config_validate(
+                provider_name=provider_name,
+                model_name=args['model_name'],
+                model_type=args['model_type'],
+                config=args['config']
             )
-        except ValidateFailedError as e:
+        except CredentialsValidateFailedError as ex:
             result = False
-            error = str(e)
+            error = str(ex)
 
         response = {'result': 'success' if result else 'error'}
 
@@ -185,91 +138,148 @@ class ProviderTokenValidateApi(Resource):
         return response
 
 
-class ProviderSystemApi(Resource):
+class ModelProviderModelUpdateApi(Resource):
 
     @setup_required
     @login_required
     @account_initialization_required
-    def put(self, provider):
-        if provider not in [p.value for p in ProviderName]:
-            abort(404)
+    def post(self, provider_name: str):
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('is_enabled', type=bool, required=True, location='json')
+        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=['text-generation', 'embeddings', 'speech2text'], location='json')
+        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
         args = parser.parse_args()
 
-        tenant = current_user.current_tenant_id
-
-        provider_model = Provider.query.filter_by(tenant_id=tenant.id, provider_name=provider).first()
-
-        if provider_model and provider_model.provider_type == ProviderType.SYSTEM.value:
-            provider_model.is_valid = args['is_enabled']
-            db.session.commit()
-        elif not provider_model:
-            if provider == ProviderName.OPENAI.value:
-                quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT']
-            elif provider == ProviderName.ANTHROPIC.value:
-                quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT']
-            else:
-                quota_limit = 0
-
-            ProviderService.create_system_provider(
-                tenant,
-                provider,
-                quota_limit,
-                args['is_enabled']
+        provider_service = ProviderService()
+
+        try:
+            provider_service.add_or_save_custom_provider_model_config(
+                tenant_id=current_user.current_tenant_id,
+                provider_name=provider_name,
+                model_name=args['model_name'],
+                model_type=args['model_type'],
+                config=args['config']
             )
-        else:
-            abort(403)
+        except CredentialsValidateFailedError as ex:
+            raise ValueError(str(ex))
 
-        return {'result': 'success'}
+        return {'result': 'success'}, 200
 
     @setup_required
     @login_required
     @account_initialization_required
-    def get(self, provider):
-        if provider not in [p.value for p in ProviderName]:
-            abort(404)
+    def delete(self, provider_name: str):
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=['text-generation', 'embeddings', 'speech2text'], location='args')
+        args = parser.parse_args()
 
-        # The role of the current user in the ta table must be admin or owner
+        provider_service = ProviderService()
+        provider_service.delete_custom_provider_model(
+            tenant_id=current_user.current_tenant_id,
+            provider_name=provider_name,
+            model_name=args['model_name'],
+            model_type=args['model_type']
+        )
+
+        return {'result': 'success'}, 204
+
+
+class PreferredProviderTypeUpdateApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider_name: str):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
 
-        provider_model = db.session.query(Provider).filter(Provider.tenant_id == current_user.current_tenant_id,
-                                                           Provider.provider_name == provider,
-                                                           Provider.provider_type == ProviderType.SYSTEM.value).first()
-
-        system_model = None
-        if provider_model:
-            system_model = {
-                'result': 'success',
-                'provider': {
-                    'provider_name': provider_model.provider_name,
-                    'provider_type': provider_model.provider_type,
-                    'is_valid': provider_model.is_valid,
-                    'last_used': provider_model.last_used,
-                    'is_enabled': provider_model.is_enabled,
-                    'quota_type': provider_model.quota_type,
-                    'quota_limit': provider_model.quota_limit,
-                    'quota_used': provider_model.quota_used
-                }
-            }
-        else:
-            abort(404)
+        parser = reqparse.RequestParser()
+        parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False,
+                            choices=['system', 'custom'], location='json')
+        args = parser.parse_args()
+
+        provider_service = ProviderService()
+        provider_service.switch_preferred_provider(
+            tenant_id=current_user.current_tenant_id,
+            provider_name=provider_name,
+            preferred_provider_type=args['preferred_provider_type']
+        )
+
+        return {'result': 'success'}
+
+
+class ModelProviderModelParameterRuleApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider_name: str):
+        parser = reqparse.RequestParser()
+        parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
+        args = parser.parse_args()
 
-        return system_model
+        provider_service = ProviderService()
 
+        try:
+            parameter_rules = provider_service.get_model_parameter_rules(
+                tenant_id=current_user.current_tenant_id,
+                model_provider_name=provider_name,
+                model_name=args['model_name'],
+                model_type='text-generation'
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"Current Text Generation Model is invalid. Please switch to the available model.")
+
+        rules = {
+            k: {
+                'enabled': v.enabled,
+                'min': v.min,
+                'max': v.max,
+                'default': v.default
+            }
+            for k, v in vars(parameter_rules).items()
+        }
 
-api.add_resource(ProviderTokenApi, '/providers/<provider>/token',
-                 endpoint='current_providers_token')  # Deprecated
-api.add_resource(ProviderTokenValidateApi, '/providers/<provider>/token-validate',
-                 endpoint='current_providers_token_validate')  # Deprecated
+        return rules
 
-api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
-                 endpoint='workspaces_current_providers_token')  # PUT for updating provider token
-api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
-                 endpoint='workspaces_current_providers_token_validate')  # POST for validating provider token
 
-api.add_resource(ProviderListApi, '/workspaces/current/providers')  # GET for getting providers list
-api.add_resource(ProviderSystemApi, '/workspaces/current/providers/<provider>/system',
-                 endpoint='workspaces_current_providers_system')  # GET for getting provider quota, PUT for updating provider status
+class ModelProviderPaymentCheckoutUrlApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider_name: str):
+        provider_service = ProviderCheckoutService()
+        provider_checkout = provider_service.create_checkout(
+            tenant_id=current_user.current_tenant_id,
+            provider_name=provider_name,
+            account=current_user
+        )
+
+        return {
+            'url': provider_checkout.get_checkout_url()
+        }
+
+
+api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
+api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
+api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
+api.add_resource(ModelProviderModelValidateApi,
+                 '/workspaces/current/model-providers/<string:provider_name>/models/validate')
+api.add_resource(ModelProviderModelUpdateApi,
+                 '/workspaces/current/model-providers/<string:provider_name>/models')
+api.add_resource(PreferredProviderTypeUpdateApi,
+                 '/workspaces/current/model-providers/<string:provider_name>/preferred-provider-type')
+api.add_resource(ModelProviderModelParameterRuleApi,
+                 '/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules')
+api.add_resource(ModelProviderPaymentCheckoutUrlApi,
+                 '/workspaces/current/model-providers/<string:provider_name>/checkout-url')
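
Note: a sketch of how a console client might call the new routes (the base URL, credential values, and model names are assumptions; only the paths and payload shapes come from this file):

import requests

BASE = "http://localhost:5001/console/api"  # mount prefix is an assumption

# Validate custom provider credentials before saving them.
requests.post(f"{BASE}/workspaces/current/model-providers/openai/validate",
              json={"config": {"openai_api_key": "sk-..."}})

# Save the credentials (same payload shape, different route).
requests.post(f"{BASE}/workspaces/current/model-providers/openai",
              json={"config": {"openai_api_key": "sk-..."}})

# Fetch parameter rules for one of the provider's text-generation models.
requests.get(f"{BASE}/workspaces/current/model-providers/openai/models/parameter-rules",
             params={"model_name": "gpt-3.5-turbo"})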

+ 108 - 0
api/controllers/console/workspace/models.py

@@ -0,0 +1,108 @@
+from flask_login import login_required, current_user
+from flask_restful import Resource, reqparse
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from core.model_providers.model_provider_factory import ModelProviderFactory
+from core.model_providers.models.entity.model_params import ModelType
+from models.provider import ProviderType
+from services.provider_service import ProviderService
+
+
+class DefaultModelApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=['text-generation', 'embeddings', 'speech2text'], location='args')
+        args = parser.parse_args()
+
+        tenant_id = current_user.current_tenant_id
+
+        provider_service = ProviderService()
+        default_model = provider_service.get_default_model_of_model_type(
+            tenant_id=tenant_id,
+            model_type=args['model_type']
+        )
+
+        if not default_model:
+            return None
+
+        model_provider = ModelProviderFactory.get_preferred_model_provider(
+            tenant_id,
+            default_model.provider_name
+        )
+
+        if not model_provider:
+            return {
+                'model_name': default_model.model_name,
+                'model_type': default_model.model_type,
+                'model_provider': {
+                    'provider_name': default_model.provider_name
+                }
+            }
+
+        provider = model_provider.provider
+        rst = {
+            'model_name': default_model.model_name,
+            'model_type': default_model.model_type,
+            'model_provider': {
+                'provider_name': provider.provider_name,
+                'provider_type': provider.provider_type
+            }
+        }
+
+        model_provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name)
+        if provider.provider_type == ProviderType.SYSTEM.value:
+            rst['model_provider']['quota_type'] = provider.quota_type
+            rst['model_provider']['quota_unit'] = model_provider_rules['system_config']['quota_unit']
+            rst['model_provider']['quota_limit'] = provider.quota_limit
+            rst['model_provider']['quota_used'] = provider.quota_used
+
+        return rst
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=['text-generation', 'embeddings', 'speech2text'], location='json')
+        parser.add_argument('provider_name', type=str, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        provider_service = ProviderService()
+        provider_service.update_default_model_of_model_type(
+            tenant_id=current_user.current_tenant_id,
+            model_type=args['model_type'],
+            provider_name=args['provider_name'],
+            model_name=args['model_name']
+        )
+
+        return {'result': 'success'}
+
+
+class ValidModelApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, model_type):
+        ModelType.value_of(model_type)
+
+        provider_service = ProviderService()
+        valid_models = provider_service.get_valid_model_list(
+            tenant_id=current_user.current_tenant_id,
+            model_type=model_type
+        )
+
+        return valid_models
+
+
+api.add_resource(DefaultModelApi, '/workspaces/current/default-model')
+api.add_resource(ValidModelApi, '/workspaces/current/models/model-type/<string:model_type>')
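
Note: a sketch of the payloads these two routes exchange (values illustrative; the response shape follows the rst dict built in DefaultModelApi.get):

# Hypothetical POST body for /workspaces/current/default-model:
set_default_payload = {
    "model_type": "text-generation",   # or "embeddings" / "speech2text"
    "provider_name": "openai",
    "model_name": "gpt-3.5-turbo",
}

# GET /workspaces/current/default-model?model_type=text-generation might return
# something like this; the quota_* fields appear only for system-type providers:
example_response = {
    "model_name": "gpt-3.5-turbo",
    "model_type": "text-generation",
    "model_provider": {"provider_name": "openai", "provider_type": "system"},
}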

+ 130 - 0
api/controllers/console/workspace/providers.py

@@ -0,0 +1,130 @@
+# -*- coding:utf-8 -*-
+from flask_login import login_required, current_user
+from flask_restful import Resource, reqparse
+from werkzeug.exceptions import Forbidden
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from core.model_providers.providers.base import CredentialsValidateFailedError
+from models.provider import ProviderType
+from services.provider_service import ProviderService
+
+
+class ProviderListApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        tenant_id = current_user.current_tenant_id
+
+        """
+        If the type is AZURE_OPENAI, decode and return the four fields azure_api_type, azure_api_version,
+        azure_api_base and azure_api_key as an object, where azure_api_key shows the first 6 characters
+        in plaintext, the rest replaced by *, and the last two characters in plaintext.
+
+        For any other type, decode and return the token field directly; it shows the first 6 characters
+        in plaintext, the rest replaced by *, and the last two characters in plaintext.
+        """
+
+        provider_service = ProviderService()
+        provider_info_list = provider_service.get_provider_list(tenant_id)
+
+        provider_list = [
+            {
+                'provider_name': p['provider_name'],
+                'provider_type': p['provider_type'],
+                'is_valid': p['is_valid'],
+                'last_used': p['last_used'],
+                'is_enabled': p['is_valid'],
+                **({
+                       'quota_type': p['quota_type'],
+                       'quota_limit': p['quota_limit'],
+                       'quota_used': p['quota_used']
+                   } if p['provider_type'] == ProviderType.SYSTEM.value else {}),
+                'token': (p['config'] if p['provider_name'] != 'openai' else p['config']['openai_api_key'])
+                        if p['config'] else None
+            }
+            for name, provider_info in provider_info_list.items()
+            for p in provider_info['providers']
+        ]
+
+        return provider_list
+
+
+class ProviderTokenApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider):
+        # The role of the current user in the ta table must be admin or owner
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('token', required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        if provider == 'openai':
+            args['token'] = {
+                'openai_api_key': args['token']
+            }
+
+        provider_service = ProviderService()
+        try:
+            provider_service.save_custom_provider_config(
+                tenant_id=current_user.current_tenant_id,
+                provider_name=provider,
+                config=args['token']
+            )
+        except CredentialsValidateFailedError as ex:
+            raise ValueError(str(ex))
+
+        return {'result': 'success'}, 201
+
+
+class ProviderTokenValidateApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider):
+        parser = reqparse.RequestParser()
+        parser.add_argument('token', required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        provider_service = ProviderService()
+
+        if provider == 'openai':
+            args['token'] = {
+                'openai_api_key': args['token']
+            }
+
+        result = True
+        error = None
+
+        try:
+            provider_service.custom_provider_config_validate(
+                provider_name=provider,
+                config=args['token']
+            )
+        except CredentialsValidateFailedError as ex:
+            result = False
+            error = str(ex)
+
+        response = {'result': 'success' if result else 'error'}
+
+        if not result:
+            response['error'] = error
+
+        return response
+
+
+api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
+                 endpoint='workspaces_current_providers_token')  # POST for updating provider token
+api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
+                 endpoint='workspaces_current_providers_token_validate')  # POST for validating provider token
+
+api.add_resource(ProviderListApi, '/workspaces/current/providers')  # GET for getting providers list
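
Note: these endpoints preserve the pre-existing token-based API on top of the new provider service. A sketch of the legacy request, value illustrative:

# The token is still posted as a plain string; for OpenAI the handler above
# wraps it into the new config shape before saving.
legacy_payload = {"token": "sk-..."}
# POST /workspaces/current/providers/openai/token effectively stores
# {"openai_api_key": "sk-..."} via save_custom_provider_config().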

+ 1 - 1
api/controllers/console/workspace/workspace.py

@@ -30,7 +30,7 @@ tenant_fields = {
     'created_at': TimestampField,
     'role': fields.String,
     'providers': fields.List(fields.Nested(provider_fields)),
-    'in_trail': fields.Boolean,
+    'in_trial': fields.Boolean,
     'trial_end_reason': fields.String,
 }
 

+ 1 - 4
api/controllers/service_api/app/app.py

@@ -4,8 +4,6 @@ from flask_restful import fields, marshal_with
 from controllers.service_api import api
 from controllers.service_api.wraps import AppApiResource
 
-from core.llm.llm_builder import LLMBuilder
-from models.provider import ProviderName
 from models.model import App
 
 
@@ -35,13 +33,12 @@ class AppParameterApi(AppApiResource):
     def get(self, app_model: App, end_user):
         """Retrieve app parameters."""
         app_model_config = app_model.app_model_config
-        provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')
 
         return {
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
+            'speech_to_text': app_model_config.speech_to_text_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list
         }

+ 1 - 1
api/controllers/service_api/app/audio.py

@@ -9,7 +9,7 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
     ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \
     ProviderNotSupportSpeechToTextError
 from controllers.service_api.wraps import AppApiResource
-from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from models.model import App, AppModelConfig
 from services.audio_service import AudioService

+ 1 - 1
api/controllers/service_api/app/completion.py

@@ -14,7 +14,7 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
     ProviderModelCurrentlyNotSupportError
 from controllers.service_api.wraps import AppApiResource
 from core.conversation_message_task import PubHandler
-from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService

+ 1 - 1
api/controllers/service_api/dataset/document.py

@@ -11,7 +11,7 @@ from controllers.service_api.app.error import ProviderNotInitializeError
 from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
     DatasetNotInitedError
 from controllers.service_api.wraps import DatasetApiResource
-from core.llm.error import ProviderTokenNotInitError
+from core.model_providers.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 from models.model import UploadFile

+ 1 - 4
api/controllers/web/app.py

@@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields
 from controllers.web import api
 from controllers.web.wraps import WebApiResource
 
-from core.llm.llm_builder import LLMBuilder
-from models.provider import ProviderName
 from models.model import App
 
 
@@ -34,13 +32,12 @@ class AppParameterApi(WebApiResource):
     def get(self, app_model: App, end_user):
         """Retrieve app parameters."""
         app_model_config = app_model.app_model_config
-        provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')
 
         return {
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
+            'speech_to_text': app_model_config.speech_to_text_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list
         }

+ 1 - 1
api/controllers/web/audio.py

@@ -10,7 +10,7 @@ from controllers.web.error import AppUnavailableError, ProviderNotInitializeErro
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.web.wraps import WebApiResource
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \

+ 1 - 1
api/controllers/web/completion.py

@@ -14,7 +14,7 @@ from controllers.web.error import AppUnavailableError, ConversationCompletedErro
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
 from controllers.web.wraps import WebApiResource
 from core.conversation_message_task import PubHandler
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService

+ 1 - 1
api/controllers/web/message.py

@@ -14,7 +14,7 @@ from controllers.web.error import NotChatAppError, CompletionRequestError, Provi
     AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
 from controllers.web.wraps import WebApiResource
-from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value, TimestampField
 from services.completion_service import CompletionService

+ 0 - 36
api/core/__init__.py

@@ -1,36 +0,0 @@
-import os
-from typing import Optional
-
-import langchain
-from flask import Flask
-from pydantic import BaseModel
-
-from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.prompt.prompt_template import OneLineFormatter
-
-
-class HostedOpenAICredential(BaseModel):
-    api_key: str
-
-
-class HostedAnthropicCredential(BaseModel):
-    api_key: str
-
-
-class HostedLLMCredentials(BaseModel):
-    openai: Optional[HostedOpenAICredential] = None
-    anthropic: Optional[HostedAnthropicCredential] = None
-
-
-hosted_llm_credentials = HostedLLMCredentials()
-
-
-def init_app(app: Flask):
-    if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
-        langchain.verbose = True
-
-    if app.config.get("OPENAI_API_KEY"):
-        hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
-
-    if app.config.get("ANTHROPIC_API_KEY"):
-        hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))

+ 9 - 13
api/core/agent/agent/calc_token_mixin.py

@@ -1,20 +1,17 @@
-from typing import cast, List
+from typing import List
 
-from langchain import OpenAI
-from langchain.base_language import BaseLanguageModel
-from langchain.chat_models.openai import ChatOpenAI
 from langchain.schema import BaseMessage
 
-from core.constant import llm_constant
+from core.model_providers.models.entity.message import to_prompt_messages
+from core.model_providers.models.llm.base import BaseLLM
 
 
 class CalcTokenMixin:
 
-    def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
-        llm = cast(ChatOpenAI, llm)
-        return llm.get_num_tokens_from_messages(messages)
+    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
+        return model_instance.get_num_tokens(to_prompt_messages(messages))
 
-    def get_message_rest_tokens(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
+    def get_message_rest_tokens(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
         """
         Got the rest tokens available for the model after excluding messages tokens and completion max tokens
 
@@ -22,10 +19,9 @@ class CalcTokenMixin:
         :param messages:
         :return:
         """
-        llm = cast(ChatOpenAI, llm)
-        llm_max_tokens = llm_constant.max_context_token_length[llm.model_name]
-        completion_max_tokens = llm.max_tokens
-        used_tokens = self.get_num_tokens_from_messages(llm, messages, **kwargs)
+        llm_max_tokens = model_instance.model_rules.max_tokens.max
+        completion_max_tokens = model_instance.model_kwargs.max_tokens
+        used_tokens = self.get_num_tokens_from_messages(model_instance, messages, **kwargs)
         rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
 
         return rest_tokens
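
Note: the remaining budget is simply the model's max context length minus the tokens reserved for the completion minus the tokens already used by the messages. A worked example with made-up numbers:

# Illustrative arithmetic only; none of these numbers come from the diff.
llm_max_tokens = 4096          # model_rules.max_tokens.max for a hypothetical model
completion_max_tokens = 1024   # reserved for the completion (model_kwargs.max_tokens)
used_tokens = 2900             # tokens consumed by the prompt messages

rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
print(rest_tokens)             # 172 -> positive, so no summarization is needed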

+ 9 - 1
api/core/agent/agent/multi_dataset_router_agent.py

@@ -4,9 +4,11 @@ from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, BaseLanguageModel, SystemMessage
+from langchain.schema import AgentAction, AgentFinish, SystemMessage
+from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools import BaseTool
 
+from core.model_providers.models.llm.base import BaseLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 
 
@@ -14,6 +16,12 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
     """
     An Multi Dataset Retrieve Agent driven by Router.
     """
+    model_instance: BaseLLM
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
 
 
     def should_use_agent(self, query: str):
         """

+ 3 - 2
api/core/agent/agent/openai_function_call.py

@@ -6,7 +6,8 @@ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
+from langchain.schema import AgentAction, AgentFinish, SystemMessage
+from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools import BaseTool
 
 from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
@@ -84,7 +85,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
 
         # summarize messages if rest_tokens < 0
         try:
-            messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
+            messages = self.summarize_messages_if_needed(messages, functions=self.functions)
         except ExceededLLMTokensLimitError as e:
             return AgentFinish(return_values={"output": str(e)}, log=str(e))
 

+ 11 - 3
api/core/agent/agent/openai_function_call_summarize_mixin.py

@@ -3,20 +3,28 @@ from typing import cast, List
 from langchain.chat_models import ChatOpenAI
 from langchain.chat_models.openai import _convert_message_to_dict
 from langchain.memory.summary import SummarizerMixin
-from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage, BaseLanguageModel
+from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage
+from langchain.schema.language_model import BaseLanguageModel
 from pydantic import BaseModel
 
 from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
+from core.model_providers.models.llm.base import BaseLLM
 
 
 class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
     summary_llm: BaseLanguageModel
+    model_instance: BaseLLM
 
-    def summarize_messages_if_needed(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
+    def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
         # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
-        rest_tokens = self.get_message_rest_tokens(llm, messages, **kwargs)
+        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
         rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
         if rest_tokens >= 0:
             return messages
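
Note: the new model_instance field is typed as BaseLLM, which has no pydantic validator, hence the Config with arbitrary_types_allowed added to these classes. A minimal sketch of the pattern under pydantic v1, as used by langchain at this time (class names hypothetical):

from pydantic import BaseModel


class ModelHandle:
    """Stands in for a non-pydantic type such as BaseLLM."""


class Holder(BaseModel):
    model_instance: ModelHandle  # pydantic v1 rejects this field without the Config

    class Config:
        arbitrary_types_allowed = True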

+ 3 - 2
api/core/agent/agent/openai_multi_function_call.py

@@ -6,7 +6,8 @@ from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFuncti
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
+from langchain.schema import AgentAction, AgentFinish, SystemMessage
+from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools import BaseTool
 
 from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
@@ -84,7 +85,7 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope
 
         # summarize messages if rest_tokens < 0
         try:
-            messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
+            messages = self.summarize_messages_if_needed(messages, functions=self.functions)
         except ExceededLLMTokensLimitError as e:
             return AgentFinish(return_values={"output": str(e)}, log=str(e))
 

+ 162 - 0
api/core/agent/agent/structed_multi_dataset_router_agent.py

@@ -0,0 +1,162 @@
+import re
+from typing import List, Tuple, Any, Union, Sequence, Optional, cast
+
+from langchain import BasePromptTemplate
+from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
+from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.manager import Callbacks
+from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
+from langchain.schema import AgentAction, AgentFinish, OutputParserException
+from langchain.tools import BaseTool
+from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
+
+from core.model_providers.models.llm.base import BaseLLM
+from core.tool.dataset_retriever_tool import DatasetRetrieverTool
+
+FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
+The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
+Valid "action" values: "Final Answer" or {tool_names}
+
+Provide only ONE action per $JSON_BLOB, as shown:
+
+```
+{{{{
+  "action": $TOOL_NAME,
+  "action_input": $INPUT
+}}}}
+```
+
+Follow this format:
+
+Question: input question to answer
+Thought: consider previous and subsequent steps
+Action:
+```
+$JSON_BLOB
+```
+Observation: action result
+... (repeat Thought/Action/Observation N times)
+Thought: I know what to respond
+Action:
+```
+{{{{
+  "action": "Final Answer",
+  "action_input": "Final response to human"
+}}}}
+```"""
+
+
+class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
+    model_instance: BaseLLM
+    dataset_tools: Sequence[BaseTool]
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
+    def should_use_agent(self, query: str):
+        """
+        Return whether an agent should be used for this query.
+        Using ReACT just to decide whether an agent is needed is itself costly,
+        so it is cheaper to always reason through the agent directly.
+
+        :param query: user query
+        :return: always True
+        """
+        return True
+
+    def plan(
+        self,
+        intermediate_steps: List[Tuple[AgentAction, str]],
+        callbacks: Callbacks = None,
+        **kwargs: Any,
+    ) -> Union[AgentAction, AgentFinish]:
+        """Given input, decided what to do.
+
+        Args:
+            intermediate_steps: Steps the LLM has taken to date,
+                along with observations
+            callbacks: Callbacks to run.
+            **kwargs: User inputs.
+
+        Returns:
+            Action specifying what tool to use.
+        """
+        if len(self.dataset_tools) == 0:
+            return AgentFinish(return_values={"output": ''}, log='')
+        elif len(self.dataset_tools) == 1:
+            tool = next(iter(self.dataset_tools))
+            tool = cast(DatasetRetrieverTool, tool)
+            rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
+            return AgentFinish(return_values={"output": rst}, log=rst)
+
+        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
+        full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
+
+        try:
+            return self.output_parser.parse(full_output)
+        except OutputParserException:
+            return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
+                                          "I don't know how to respond to that."}, "")
+
+    @classmethod
+    def create_prompt(
+            cls,
+            tools: Sequence[BaseTool],
+            prefix: str = PREFIX,
+            suffix: str = SUFFIX,
+            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
+            format_instructions: str = FORMAT_INSTRUCTIONS,
+            input_variables: Optional[List[str]] = None,
+            memory_prompts: Optional[List[BasePromptTemplate]] = None,
+    ) -> BasePromptTemplate:
+        tool_strings = []
+        for tool in tools:
+            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
+            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
+        formatted_tools = "\n".join(tool_strings)
+        unique_tool_names = set(tool.name for tool in tools)
+        tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
+        format_instructions = format_instructions.format(tool_names=tool_names)
+        template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
+        if input_variables is None:
+            input_variables = ["input", "agent_scratchpad"]
+        _memory_prompts = memory_prompts or []
+        messages = [
+            SystemMessagePromptTemplate.from_template(template),
+            *_memory_prompts,
+            HumanMessagePromptTemplate.from_template(human_message_template),
+        ]
+        return ChatPromptTemplate(input_variables=input_variables, messages=messages)
+
+    @classmethod
+    def from_llm_and_tools(
+            cls,
+            llm: BaseLanguageModel,
+            tools: Sequence[BaseTool],
+            callback_manager: Optional[BaseCallbackManager] = None,
+            output_parser: Optional[AgentOutputParser] = None,
+            prefix: str = PREFIX,
+            suffix: str = SUFFIX,
+            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
+            format_instructions: str = FORMAT_INSTRUCTIONS,
+            input_variables: Optional[List[str]] = None,
+            memory_prompts: Optional[List[BasePromptTemplate]] = None,
+            **kwargs: Any,
+    ) -> Agent:
+        return super().from_llm_and_tools(
+            llm=llm,
+            tools=tools,
+            callback_manager=callback_manager,
+            output_parser=output_parser,
+            prefix=prefix,
+            suffix=suffix,
+            human_message_template=human_message_template,
+            format_instructions=format_instructions,
+            input_variables=input_variables,
+            memory_prompts=memory_prompts,
+            dataset_tools=tools,
+            **kwargs,
+        )
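A note on the brace escaping in create_prompt above (standalone sketch, not from the patch): str(tool.args) contains literal JSON-style braces, which str.format-style prompt templates treat as placeholders, so each brace must be doubled for every formatting pass it survives; the quadrupling mirrors upstream StructuredChatAgent.create_prompt.

# Each .format() pass halves the braces: "{{" renders as a literal "{".
schema = str({'query': {'type': 'string'}})
escaped = schema.replace("{", "{{").replace("}", "}}")
print(("args: " + escaped).format())  # -> args: {'query': {'type': 'string'}}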

+ 8 - 2
api/core/agent/agent/structured_chat.py

@@ -14,7 +14,7 @@ from langchain.tools import BaseTool
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
 
 from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
-
+from core.model_providers.models.llm.base import BaseLLM
 
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
 The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
@@ -53,6 +53,12 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
     summary_llm: BaseLanguageModel
+    model_instance: BaseLLM
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
 
     def should_use_agent(self, query: str):
         """
@@ -89,7 +95,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
         if prompts:
             messages = prompts[0].to_messages()
 
-        rest_tokens = self.get_message_rest_tokens(self.llm_chain.llm, messages)
+        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
         if rest_tokens < 0:
             full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
 
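For orientation, the rest-token check above sketched standalone (attribute and helper names as they appear elsewhere in this patch; the mixin's real implementation may differ):

from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM

def rest_tokens_sketch(model_instance: BaseLLM, messages) -> int:
    # context window minus what the current prompt already consumes
    max_context = model_instance.model_rules.max_tokens.max
    used = model_instance.get_num_tokens(to_prompt_messages(messages))
    return max_context - used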

+ 25 - 11
api/core/agent/agent_executor.py

@@ -3,7 +3,6 @@ import logging
 from typing import Union, Optional
 
 from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
-from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.tools import BaseTool
@@ -13,14 +12,17 @@ from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
 from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
 from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
 from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
+from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
 from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
 from langchain.agents import AgentExecutor as LCAgentExecutor
 
+from core.model_providers.models.llm.base import BaseLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 
 
 class PlanningStrategy(str, enum.Enum):
     ROUTER = 'router'
+    REACT_ROUTER = 'react_router'
     REACT = 'react'
     FUNCTION_CALL = 'function_call'
     MULTI_FUNCTION_CALL = 'multi_function_call'
@@ -28,10 +30,9 @@ class PlanningStrategy(str, enum.Enum):
 
 class AgentConfiguration(BaseModel):
     strategy: PlanningStrategy
-    llm: BaseLanguageModel
+    model_instance: BaseLLM
     tools: list[BaseTool]
-    summary_llm: BaseLanguageModel
-    dataset_llm: BaseLanguageModel
+    summary_model_instance: BaseLLM
     memory: Optional[BaseChatMemory] = None
     callbacks: Callbacks = None
     max_iterations: int = 6
@@ -60,36 +61,49 @@ class AgentExecutor:
     def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]:
         if self.configuration.strategy == PlanningStrategy.REACT:
             agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
-                llm=self.configuration.llm,
+                model_instance=self.configuration.model_instance,
+                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 output_parser=StructuredChatOutputParser(),
-                summary_llm=self.configuration.summary_llm,
+                summary_llm=self.configuration.summary_model_instance.client,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
             agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
-                llm=self.configuration.llm,
+                model_instance=self.configuration.model_instance,
+                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
-                summary_llm=self.configuration.summary_llm,
+                summary_llm=self.configuration.summary_model_instance.client,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
             agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
-                llm=self.configuration.llm,
+                model_instance=self.configuration.model_instance,
+                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
-                summary_llm=self.configuration.summary_llm,
+                summary_llm=self.configuration.summary_model_instance.client,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.ROUTER:
             self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
             agent = MultiDatasetRouterAgent.from_llm_and_tools(
-                llm=self.configuration.dataset_llm,
+                model_instance=self.configuration.model_instance,
+                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
                 verbose=True
             )
+        elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
+            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
+            agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
+                model_instance=self.configuration.model_instance,
+                llm=self.configuration.model_instance.client,
+                tools=self.configuration.tools,
+                output_parser=StructuredChatOutputParser(),
+                verbose=True
+            )
         else:
             raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
 
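A hypothetical wiring of the new strategy (field names from this diff; the tool list and the executor's constructor signature are assumptions):

from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy

config = AgentConfiguration(
    strategy=PlanningStrategy.REACT_ROUTER,
    model_instance=model_instance,          # a BaseLLM wrapper from ModelFactory
    summary_model_instance=model_instance,  # reused here for brevity
    tools=dataset_tools,                    # DatasetRetrieverTool instances
)
executor = AgentExecutor(config)            # _init_agent() then builds the ReACT router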

+ 5 - 4
api/core/callback_handler/agent_loop_gather_callback_handler.py

@@ -10,15 +10,16 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
 
 from core.callback_handler.entity.agent_loop import AgentLoop
 from core.conversation_message_task import ConversationMessageTask
+from core.model_providers.models.llm.base import BaseLLM
 
 
 class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
     raise_error: bool = True
 
-    def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
+    def __init__(self, model_instant: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
-        self.model_name = model_name
+        self.model_instant = model_instant
         self.conversation_message_task = conversation_message_task
         self._agent_loops = []
         self._current_loop = None
@@ -152,7 +153,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
 
             self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_name, self._current_loop
+                self._message_agent_thought, self.model_instant, self._current_loop
             )
 
             self._agent_loops.append(self._current_loop)
@@ -183,7 +184,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             )
 
             self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_name, self._current_loop
+                self._message_agent_thought, self.model_instant, self._current_loop
             )
 
             self._agent_loops.append(self._current_loop)

+ 11 - 7
api/core/callback_handler/llm_callback_handler.py

@@ -3,18 +3,20 @@ import time
 from typing import Any, Dict, List, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import LLMResult, BaseMessage, BaseLanguageModel
+from langchain.schema import LLMResult, BaseMessage
 
 from core.callback_handler.entity.llm_message import LLMMessage
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
+from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
+from core.model_providers.models.llm.base import BaseLLM
 
 
 class LLMCallbackHandler(BaseCallbackHandler):
     raise_error: bool = True
 
-    def __init__(self, llm: BaseLanguageModel,
+    def __init__(self, model_instance: BaseLLM,
                  conversation_message_task: ConversationMessageTask):
-        self.llm = llm
+        self.model_instance = model_instance
         self.llm_message = LLMMessage()
         self.start_at = None
         self.conversation_message_task = conversation_message_task
@@ -46,7 +48,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
             })
 
         self.llm_message.prompt = real_prompts
-        self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])
+        self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0]))
 
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@@ -58,7 +60,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
             "text": prompts[0]
         }]
 
-        self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
+        self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
 
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         end_at = time.perf_counter()
@@ -68,7 +70,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
             self.conversation_message_task.append_message_text(response.generations[0][0].text)
             self.llm_message.completion = response.generations[0][0].text
 
-        self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
+        self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)])
 
         self.conversation_message_task.save_message(self.llm_message)
 
@@ -89,7 +91,9 @@ class LLMCallbackHandler(BaseCallbackHandler):
             if self.conversation_message_task.streaming:
                 end_at = time.perf_counter()
                 self.llm_message.latency = end_at - self.start_at
-                self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
+                self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
+                    [PromptMessage(content=self.llm_message.completion)]
+                )
                 self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
         else:
             logging.error(error)
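The counting path in isolation (a sketch; names as used above): every token count now goes through the provider-agnostic model instance instead of LangChain's per-model tokenizers.

from core.model_providers.models.entity.message import PromptMessage

prompt_tokens = model_instance.get_num_tokens([PromptMessage(content="Hello!")])
# chat-style runs first convert LangChain messages via to_prompt_messages(...)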

+ 0 - 2
api/core/callback_handler/main_chain_gather_callback_handler.py

@@ -5,9 +5,7 @@ from typing import Any, Dict, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
 
-from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.callback_handler.entity.chain_result import ChainResult
-from core.constant import llm_constant
 from core.conversation_message_task import ConversationMessageTask
 
 

+ 86 - 111
api/core/completion.py

@@ -2,27 +2,19 @@ import logging
 import re
 from typing import Optional, List, Union, Tuple
 
-from langchain.base_language import BaseLanguageModel
-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.chat_models.base import BaseChatModel
-from langchain.llms import BaseLLM
-from langchain.schema import BaseMessage, HumanMessage
+from langchain.schema import BaseMessage
 from requests.exceptions import ChunkedEncodingError
 
 from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
-from core.constant import llm_constant
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
-from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
-    DifyStdOutCallbackHandler
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
-from core.llm.error import LLMBadRequestError
-from core.llm.fake import FakeLLM
-from core.llm.llm_builder import LLMBuilder
-from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
-from core.llm.streamable_open_ai import StreamableOpenAI
+from core.model_providers.error import LLMBadRequestError
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages
+from core.model_providers.models.llm.base import BaseLLM
 from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_template import JinjaPromptTemplate
@@ -51,12 +43,10 @@ class Completion:
 
             inputs = conversation.inputs
 
-        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
-            mode=app.mode,
+        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
             tenant_id=app.tenant_id,
-            app_model_config=app_model_config,
-            query=query,
-            inputs=inputs
+            model_config=app_model_config.model_dict,
+            streaming=streaming
         )
 
         conversation_message_task = ConversationMessageTask(
@@ -68,10 +58,17 @@ class Completion:
             is_override=is_override,
             inputs=inputs,
             query=query,
-            streaming=streaming
+            streaming=streaming,
+            model_instance=final_model_instance
         )
 
-        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
+        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
+            mode=app.mode,
+            model_instance=final_model_instance,
+            app_model_config=app_model_config,
+            query=query,
+            inputs=inputs
+        )
 
         # init orchestrator rule parser
         orchestrator_rule_parser = OrchestratorRuleParser(
@@ -80,6 +77,7 @@ class Completion:
         )
 
         # parse sensitive_word_avoidance_chain
+        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
         sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
         if sensitive_word_avoidance_chain:
             query = sensitive_word_avoidance_chain.run(query)
@@ -102,15 +100,14 @@ class Completion:
         # run the final llm
         try:
             cls.run_final_llm(
-                tenant_id=app.tenant_id,
+                model_instance=final_model_instance,
                 mode=app.mode,
                 app_model_config=app_model_config,
                 query=query,
                 inputs=inputs,
                 agent_execute_result=agent_execute_result,
                 conversation_message_task=conversation_message_task,
-                memory=memory,
-                streaming=streaming
+                memory=memory
             )
         except ConversationTaskStoppedException:
             return
@@ -121,31 +118,20 @@
             return
 
     @classmethod
-    def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
+    def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
                       agent_execute_result: Optional[AgentExecuteResult],
                       conversation_message_task: ConversationMessageTask,
-                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
+                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
         # When no extra pre prompt is specified,
         # the output of the agent can be used directly as the main output content without calling LLM again
+        fake_response = None
         if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
                 and agent_execute_result.strategy != PlanningStrategy.ROUTER:
-            final_llm = FakeLLM(response=agent_execute_result.output,
-                                origin_llm=agent_execute_result.configuration.llm,
-                                streaming=streaming)
-            final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
-            response = final_llm.generate([[HumanMessage(content=query)]])
-            return response
-
-        final_llm = LLMBuilder.to_llm_from_model(
-            tenant_id=tenant_id,
-            model=app_model_config.model_dict,
-            streaming=streaming
-        )
+            fake_response = agent_execute_result.output
 
         # get llm prompt
-        prompt, stop_words = cls.get_main_llm_prompt(
+        prompt_messages, stop_words = cls.get_main_llm_prompt(
             mode=mode,
-            llm=final_llm,
             model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
             query=query,
@@ -154,25 +140,26 @@
             memory=memory
         )
 
-        final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
-
         cls.recale_llm_max_tokens(
-            final_llm=final_llm,
-            model=app_model_config.model_dict,
-            prompt=prompt,
-            mode=mode
+            model_instance=model_instance,
+            prompt_messages=prompt_messages,
         )
 
-        response = final_llm.generate([prompt], stop_words)
+        response = model_instance.run(
+            messages=prompt_messages,
+            stop=stop_words,
+            callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
+            fake_response=fake_response
+        )
 
         return response
 
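How the fake_response shortcut is meant to play out (hedged; the answer string is invented): when the agent already produced the final text and no pre_prompt forces a second pass, run() can push that text through the callbacks without a real provider call.

response = model_instance.run(
    messages=prompt_messages,
    stop=stop_words,
    callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
    fake_response="text the agent already produced",  # skips the actual LLM call
)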
     @classmethod
-    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
+    def get_main_llm_prompt(cls, mode: str, model: dict,
                             pre_prompt: str, query: str, inputs: dict,
                             agent_execute_result: Optional[AgentExecuteResult],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
-            Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
+            Tuple[List[PromptMessage], Optional[List[str]]]:
         if mode == 'completion':
             prompt_template = JinjaPromptTemplate.from_template(
                 template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
@@ -200,11 +187,7 @@ And answer according to the language of the user's question.
                 **prompt_inputs
             )
 
-            if isinstance(llm, BaseChatModel):
-                # use chat llm as completion model
-                return [HumanMessage(content=prompt_content)], None
-            else:
-                return prompt_content, None
+            return [PromptMessage(content=prompt_content)], None
         else:
             messages: List[BaseMessage] = []
 
@@ -249,12 +232,14 @@ And answer according to the language of the user's question.
                     inputs=human_inputs
                 )
 
-                curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
-                model_name = model['name']
-                max_tokens = model.get("completion_params").get('max_tokens')
-                rest_tokens = llm_constant.max_context_token_length[model_name] \
-                              - max_tokens - curr_message_tokens
-                rest_tokens = max(rest_tokens, 0)
+                if memory.model_instance.model_rules.max_tokens.max:
+                    curr_message_tokens = memory.model_instance.get_num_tokens(to_prompt_messages([tmp_human_message]))
+                    max_tokens = model.get("completion_params").get('max_tokens')
+                    rest_tokens = memory.model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
+                    rest_tokens = max(rest_tokens, 0)
+                else:
+                    rest_tokens = 2000
+
                 histories = cls.get_history_messages_from_memory(memory, rest_tokens)
                 human_message_prompt += "\n\n" if human_message_prompt else ""
                 human_message_prompt += "Here is the chat histories between human and assistant, " \
@@ -274,17 +259,7 @@ And answer according to the language of the user's question.
             for message in messages:
                 message.content = re.sub(r'<\|.*?\|>', '', message.content)
 
-            return messages, ['\nHuman:', '</histories>']
-
-    @classmethod
-    def get_llm_callbacks(cls, llm: BaseLanguageModel,
-                          streaming: bool,
-                          conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
-        llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
-        if streaming:
-            return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
-        else:
-            return [llm_callback_handler, DifyStdOutCallbackHandler()]
+            return to_prompt_messages(messages), ['\nHuman:', '</histories>']
 
     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
@@ -300,15 +275,15 @@ And answer according to the language of the user's question.
                                      conversation: Conversation,
                                      **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
         # only for calc token in memory
-        memory_llm = LLMBuilder.to_llm_from_model(
+        memory_model_instance = ModelFactory.get_text_generation_model_from_model_config(
             tenant_id=tenant_id,
-            model=app_model_config.model_dict
+            model_config=app_model_config.model_dict
         )
 
         # use llm config from conversation
         memory = ReadOnlyConversationTokenDBBufferSharedMemory(
             conversation=conversation,
-            llm=memory_llm,
+            model_instance=memory_model_instance,
             max_token_limit=kwargs.get("max_token_limit", 2048),
             memory_key=kwargs.get("memory_key", "chat_history"),
             return_messages=kwargs.get("return_messages", True),
@@ -320,21 +295,20 @@ And answer according to the language of the user's question.
         return memory
 
     @classmethod
-    def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
+    def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
                                  query: str, inputs: dict) -> int:
-        llm = LLMBuilder.to_llm_from_model(
-            tenant_id=tenant_id,
-            model=app_model_config.model_dict
-        )
+        model_limited_tokens = model_instance.model_rules.max_tokens.max
+        max_tokens = model_instance.get_model_kwargs().max_tokens
 
-        model_name = app_model_config.model_dict.get("name")
-        model_limited_tokens = llm_constant.max_context_token_length[model_name]
-        max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')
+        if model_limited_tokens is None:
+            return -1
+
+        if max_tokens is None:
+            max_tokens = 0
 
         # get prompt without memory and context
-        prompt, _ = cls.get_main_llm_prompt(
+        prompt_messages, _ = cls.get_main_llm_prompt(
             mode=mode,
-            llm=llm,
             model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
             query=query,
@@ -343,9 +317,7 @@ And answer according to the language of the user's question.
             memory=None
         )
 
-        prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
-            else llm.get_num_tokens_from_messages(prompt)
-
+        prompt_tokens = model_instance.get_num_tokens(prompt_messages)
         rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
         if rest_tokens < 0:
             raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
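Worked numbers for the guard above (values assumed): with a 4,096-token window, max_tokens=512 and a 3,700-token prompt, the check yields 4096 - 512 - 3700 = -116, so the request fails fast with LLMBadRequestError instead of erroring at the provider.

model_limited_tokens, max_tokens, prompt_tokens = 4096, 512, 3700
rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
assert rest_tokens == -116  # < 0 -> LLMBadRequestError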
@@ -354,36 +326,40 @@ And answer according to the language of the user's question.
         return rest_tokens
 
     @classmethod
-    def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
-                              prompt: Union[str, List[BaseMessage]], mode: str):
+    def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]):
         # recalc max_tokens if sum(prompt_token +  max_tokens) over model token limit
-        model_name = model.get("name")
-        model_limited_tokens = llm_constant.max_context_token_length[model_name]
-        max_tokens = model.get("completion_params").get('max_tokens')
+        model_limited_tokens = model_instance.model_rules.max_tokens.max
+        max_tokens = model_instance.get_model_kwargs().max_tokens
 
-        if mode == 'completion' and isinstance(final_llm, BaseLLM):
-            prompt_tokens = final_llm.get_num_tokens(prompt)
-        else:
-            prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)
+        if model_limited_tokens is None:
+            return
+
+        if max_tokens is None:
+            max_tokens = 0
+
+        prompt_tokens = model_instance.get_num_tokens(prompt_messages)
 
         if prompt_tokens + max_tokens > model_limited_tokens:
             max_tokens = max(model_limited_tokens - prompt_tokens, 16)
-            final_llm.max_tokens = max_tokens
+
+            # update model instance max tokens
+            model_kwargs = model_instance.get_model_kwargs()
+            model_kwargs.max_tokens = max_tokens
+            model_instance.set_model_kwargs(model_kwargs)
 
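The clamp in recale_llm_max_tokens, worked with assumed numbers: a 4,000-token prompt plus a requested max_tokens of 512 overflows a 4,096-token window, so max_tokens is shrunk to fit, never below the floor of 16.

model_limited_tokens, prompt_tokens, max_tokens = 4096, 4000, 512
if prompt_tokens + max_tokens > model_limited_tokens:
    max_tokens = max(model_limited_tokens - prompt_tokens, 16)
assert max_tokens == 96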
     @classmethod
     def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
                                 app_model_config: AppModelConfig, user: Account, streaming: bool):
 
-        llm = LLMBuilder.to_llm_from_model(
+        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
             tenant_id=app.tenant_id,
-            model=app_model_config.model_dict,
+            model_config=app_model_config.model_dict,
             streaming=streaming
         )
 
         # get llm prompt
-        original_prompt, _ = cls.get_main_llm_prompt(
+        old_prompt_messages, _ = cls.get_main_llm_prompt(
             mode="completion",
-            llm=llm,
             model=app_model_config.model_dict,
             pre_prompt=pre_prompt,
             query=message.query,
@@ -395,10 +371,9 @@ And answer according to the language of the user's question.
         original_completion = message.answer.strip()
 
         prompt = MORE_LIKE_THIS_GENERATE_PROMPT
-        prompt = prompt.format(prompt=original_prompt, original_completion=original_completion)
+        prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)
 
-        if isinstance(llm, BaseChatModel):
-            prompt = [HumanMessage(content=prompt)]
+        prompt_messages = [PromptMessage(content=prompt)]
 
         conversation_message_task = ConversationMessageTask(
             task_id=task_id,
@@ -408,16 +383,16 @@ And answer according to the language of the user's question.
             inputs=message.inputs,
             query=message.query,
             is_override=True if message.override_model_configs else False,
-            streaming=streaming
+            streaming=streaming,
+            model_instance=final_model_instance
         )
 
-        llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)
-
         cls.recale_llm_max_tokens(
-            final_llm=llm,
-            model=app_model_config.model_dict,
-            prompt=prompt,
-            mode='completion'
+            model_instance=final_model_instance,
+            prompt_messages=prompt_messages
         )
 
-        llm.generate([prompt])
+        final_model_instance.run(
+            messages=prompt_messages,
+            callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)]
+        )

+ 0 - 109
api/core/constant/llm_constant.py

@@ -1,109 +0,0 @@
-from _decimal import Decimal
-
-models = {
-    'claude-instant-1': 'anthropic',  # 100,000 tokens
-    'claude-2': 'anthropic',  # 100,000 tokens
-    'gpt-4': 'openai',  # 8,192 tokens
-    'gpt-4-32k': 'openai',  # 32,768 tokens
-    'gpt-3.5-turbo': 'openai',  # 4,096 tokens
-    'gpt-3.5-turbo-16k': 'openai',  # 16384 tokens
-    'text-davinci-003': 'openai',  # 4,097 tokens
-    'text-davinci-002': 'openai',  # 4,097 tokens
-    'text-curie-001': 'openai',  # 2,049 tokens
-    'text-babbage-001': 'openai',  # 2,049 tokens
-    'text-ada-001': 'openai',  # 2,049 tokens
-    'text-embedding-ada-002': 'openai',  # 8191 tokens, 1536 dimensions
-    'whisper-1': 'openai'
-}
-
-max_context_token_length = {
-    'claude-instant-1': 100000,
-    'claude-2': 100000,
-    'gpt-4': 8192,
-    'gpt-4-32k': 32768,
-    'gpt-3.5-turbo': 4096,
-    'gpt-3.5-turbo-16k': 16384,
-    'text-davinci-003': 4097,
-    'text-davinci-002': 4097,
-    'text-curie-001': 2049,
-    'text-babbage-001': 2049,
-    'text-ada-001': 2049,
-    'text-embedding-ada-002': 8191,
-}
-
-models_by_mode = {
-    'chat': [
-        'claude-instant-1',  # 100,000 tokens
-        'claude-2',  # 100,000 tokens
-        'gpt-4',  # 8,192 tokens
-        'gpt-4-32k',  # 32,768 tokens
-        'gpt-3.5-turbo',  # 4,096 tokens
-        'gpt-3.5-turbo-16k',  # 16,384 tokens
-    ],
-    'completion': [
-        'claude-instant-1',  # 100,000 tokens
-        'claude-2',  # 100,000 tokens
-        'gpt-4',  # 8,192 tokens
-        'gpt-4-32k',  # 32,768 tokens
-        'gpt-3.5-turbo',  # 4,096 tokens
-        'gpt-3.5-turbo-16k',  # 16,384 tokens
-        'text-davinci-003',  # 4,097 tokens
-        'text-davinci-002'  # 4,097 tokens
-        'text-curie-001',  # 2,049 tokens
-        'text-babbage-001',  # 2,049 tokens
-        'text-ada-001'  # 2,049 tokens
-    ],
-    'embedding': [
-        'text-embedding-ada-002'  # 8191 tokens, 1536 dimensions
-    ]
-}
-
-model_currency = 'USD'
-
-model_prices = {
-    'claude-instant-1': {
-        'prompt': Decimal('0.00163'),
-        'completion': Decimal('0.00551'),
-    },
-    'claude-2': {
-        'prompt': Decimal('0.01102'),
-        'completion': Decimal('0.03268'),
-    },
-    'gpt-4': {
-        'prompt': Decimal('0.03'),
-        'completion': Decimal('0.06'),
-    },
-    'gpt-4-32k': {
-        'prompt': Decimal('0.06'),
-        'completion': Decimal('0.12')
-    },
-    'gpt-3.5-turbo': {
-        'prompt': Decimal('0.0015'),
-        'completion': Decimal('0.002')
-    },
-    'gpt-3.5-turbo-16k': {
-        'prompt': Decimal('0.003'),
-        'completion': Decimal('0.004')
-    },
-    'text-davinci-003': {
-        'prompt': Decimal('0.02'),
-        'completion': Decimal('0.02')
-    },
-    'text-curie-001': {
-        'prompt': Decimal('0.002'),
-        'completion': Decimal('0.002')
-    },
-    'text-babbage-001': {
-        'prompt': Decimal('0.0005'),
-        'completion': Decimal('0.0005')
-    },
-    'text-ada-001': {
-        'prompt': Decimal('0.0004'),
-        'completion': Decimal('0.0004')
-    },
-    'text-embedding-ada-002': {
-        'usage': Decimal('0.0001'),
-    }
-}
-
-agent_model_name = 'text-davinci-003'
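What replaces these deleted tables (attribute and method names as they appear elsewhere in this patch, shown here out of context as a sketch): per-model limits, prices and currency now hang off the model instance instead of module-level dicts.

from core.model_providers.models.entity.message import MessageType

limit = model_instance.model_rules.max_tokens.max                    # was max_context_token_length[name]
prompt_price = model_instance.get_token_price(1, MessageType.HUMAN)  # was model_prices[name]['prompt']
currency = model_instance.get_currency()                             # was model_currency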

+ 22 - 37
api/core/conversation_message_task.py

@@ -6,9 +6,9 @@ from core.callback_handler.entity.agent_loop import AgentLoop
 from core.callback_handler.entity.dataset_query import DatasetQueryObj
 from core.callback_handler.entity.llm_message import LLMMessage
 from core.callback_handler.entity.chain_result import ChainResult
-from core.constant import llm_constant
-from core.llm.llm_builder import LLMBuilder
-from core.llm.provider.llm_provider_service import LLMProviderService
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.entity.message import to_prompt_messages, MessageType
+from core.model_providers.models.llm.base import BaseLLM
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_template import JinjaPromptTemplate
 from events.message_event import message_was_created
@@ -16,12 +16,11 @@ from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DatasetQuery
 from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain
-from models.provider import ProviderType, Provider
 
 
 class ConversationMessageTask:
     def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
-                 inputs: dict, query: str, streaming: bool,
+                 inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
                  conversation: Optional[Conversation] = None, is_override: bool = False):
         self.task_id = task_id
 
@@ -38,9 +37,12 @@ class ConversationMessageTask:
         self.conversation = conversation
         self.is_new_conversation = False
 
+        self.model_instance = model_instance
+
         self.message = None
 
         self.model_dict = self.app_model_config.model_dict
+        self.provider_name = self.model_dict.get('provider')
         self.model_name = self.model_dict.get('name')
         self.mode = app.mode
 
@@ -56,9 +58,6 @@ class ConversationMessageTask:
         )
 
     def init(self):
-        provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
-        self.model_dict['provider'] = provider_name
-
         override_model_configs = None
         if self.is_override:
             override_model_configs = {
@@ -89,15 +88,19 @@ class ConversationMessageTask:
             if self.app_model_config.pre_prompt:
                 system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
                 system_instruction = system_message.content
-                llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
-                system_instruction_tokens = llm.get_num_tokens_from_messages([system_message])
+                model_instance = ModelFactory.get_text_generation_model(
+                    tenant_id=self.tenant_id,
+                    model_provider_name=self.provider_name,
+                    model_name=self.model_name
+                )
+                system_instruction_tokens = model_instance.get_num_tokens(to_prompt_messages([system_message]))
 
         if not self.conversation:
             self.is_new_conversation = True
             self.conversation = Conversation(
                 app_id=self.app_model_config.app_id,
                 app_model_config_id=self.app_model_config.id,
-                model_provider=self.model_dict.get('provider'),
+                model_provider=self.provider_name,
                 model_id=self.model_name,
                 override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
                 mode=self.mode,
@@ -117,7 +120,7 @@ class ConversationMessageTask:
 
         self.message = Message(
             app_id=self.app_model_config.app_id,
-            model_provider=self.model_dict.get('provider'),
+            model_provider=self.provider_name,
             model_id=self.model_name,
             override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
             conversation_id=self.conversation.id,
@@ -131,7 +134,7 @@ class ConversationMessageTask:
             answer_unit_price=0,
             provider_response_latency=0,
             total_price=0,
-            currency=llm_constant.model_currency,
+            currency=self.model_instance.get_currency(),
             from_source=('console' if isinstance(self.user, Account) else 'api'),
             from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
             from_account_id=(self.user.id if isinstance(self.user, Account) else None),
@@ -145,12 +148,10 @@ class ConversationMessageTask:
         self._pub_handler.pub_text(text)
 
     def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
-        model_name = self.app_model_config.model_dict.get('name')
-
         message_tokens = llm_message.prompt_tokens
         answer_tokens = llm_message.completion_tokens
-        message_unit_price = llm_constant.model_prices[model_name]['prompt']
-        answer_unit_price = llm_constant.model_prices[model_name]['completion']
+        message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
+        answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)
 
         total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
 
@@ -163,8 +164,6 @@ class ConversationMessageTask:
         self.message.provider_response_latency = llm_message.latency
         self.message.total_price = total_price
 
-        self.update_provider_quota()
-
         db.session.commit()
 
         message_was_created.send(
@@ -176,20 +175,6 @@ class ConversationMessageTask:
         if not by_stopped:
             self.end()
 
-    def update_provider_quota(self):
-        llm_provider_service = LLMProviderService(
-            tenant_id=self.app.tenant_id,
-            provider_name=self.message.model_provider,
-        )
-
-        provider = llm_provider_service.get_provider_db_record()
-        if provider and provider.provider_type == ProviderType.SYSTEM.value:
-            db.session.query(Provider).filter(
-                Provider.tenant_id == self.app.tenant_id,
-                Provider.provider_name == provider.provider_name,
-                Provider.quota_limit > Provider.quota_used
-            ).update({'quota_used': Provider.quota_used + 1})
-
     def init_chain(self, chain_result: ChainResult):
     def init_chain(self, chain_result: ChainResult):
         message_chain = MessageChain(
         message_chain = MessageChain(
             message_id=self.message.id,
             message_id=self.message.id,
@@ -229,10 +214,10 @@ class ConversationMessageTask:
 
 
         return message_agent_thought
         return message_agent_thought
 
 
-    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_name: str,
+    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
                      agent_loop: AgentLoop):
                      agent_loop: AgentLoop):
-        agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt']
-        agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion']
+        agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN)
+        agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT)
 
 
         loop_message_tokens = agent_loop.prompt_tokens
         loop_message_tokens = agent_loop.prompt_tokens
         loop_answer_tokens = agent_loop.completion_tokens
         loop_answer_tokens = agent_loop.completion_tokens
@@ -253,7 +238,7 @@ class ConversationMessageTask:
         message_agent_thought.latency = agent_loop.latency
         message_agent_thought.latency = agent_loop.latency
         message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
         message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
         message_agent_thought.total_price = loop_total_price
         message_agent_thought.total_price = loop_total_price
-        message_agent_thought.currency = llm_constant.model_currency
+        message_agent_thought.currency = agent_model_instant.get_currency()
         db.session.flush()
         db.session.flush()
 
 
     def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
     def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
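
Reviewer note: the removed llm_constant price and currency tables are replaced by per-instance lookups on the model wrapper. A minimal sketch of the new accounting path, assuming get_token_price(1, ...) returns a per-token unit price as used in save_message above (the tenant id and token counts are illustrative placeholders):

    from core.model_providers.model_factory import ModelFactory
    from core.model_providers.models.entity.message import MessageType

    # Resolve the tenant's default text-generation model; provider fallback
    # now happens inside ModelFactory rather than LLMBuilder.get_default_provider.
    model_instance = ModelFactory.get_text_generation_model(tenant_id='tenant-id')

    # Unit prices per token, split by message role (assumption: HUMAN maps to
    # the prompt side and ASSISTANT to the completion side, mirroring save_message).
    message_unit_price = model_instance.get_token_price(1, MessageType.HUMAN)
    answer_unit_price = model_instance.get_token_price(1, MessageType.ASSISTANT)

    # Illustrative totals for a 1000-token prompt and a 200-token answer.
    total_price = 1000 * message_unit_price + 200 * answer_unit_price
    print(total_price, model_instance.get_currency())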

+ 6 - 8
api/core/docstore/dataset_docstore.py

@@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Sequence
 from langchain.schema import Document
 from sqlalchemy import func

-from core.llm.token_calculator import TokenCalculator
+from core.model_providers.model_factory import ModelFactory
 from extensions.ext_database import db
 from models.dataset import Dataset, DocumentSegment

@@ -13,12 +13,10 @@ class DatesetDocumentStore:
         self,
         dataset: Dataset,
         user_id: str,
-        embedding_model_name: str,
         document_id: Optional[str] = None,
     ):
         self._dataset = dataset
         self._user_id = user_id
-        self._embedding_model_name = embedding_model_name
         self._document_id = document_id

     @classmethod
@@ -39,10 +37,6 @@ class DatesetDocumentStore:
     def user_id(self) -> Any:
         return self._user_id

-    @property
-    def embedding_model_name(self) -> Any:
-        return self._embedding_model_name
-
     @property
     def docs(self) -> Dict[str, Document]:
         document_segments = db.session.query(DocumentSegment).filter(
@@ -74,6 +68,10 @@ class DatesetDocumentStore:
         if max_position is None:
             max_position = 0

+        embedding_model = ModelFactory.get_embedding_model(
+            tenant_id=self._dataset.tenant_id
+        )
+
         for doc in docs:
             if not isinstance(doc, Document):
                 raise ValueError("doc must be a Document")
@@ -88,7 +86,7 @@ class DatesetDocumentStore:
             )

             # calc embedding use tokens
-            tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content)
+            tokens = embedding_model.get_num_tokens(doc.page_content)

             if not segment_document:
                 max_position += 1
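
A short usage sketch of the slimmed-down store: the embedding model is resolved per call from the dataset's tenant rather than pinned at construction time (the dataset object and user id below are placeholders, not values from this PR):

    from core.docstore.dataset_docstore import DatesetDocumentStore
    from core.model_providers.model_factory import ModelFactory

    # No embedding_model_name argument anymore; the store only needs the dataset.
    doc_store = DatesetDocumentStore(dataset=dataset, user_id='user-id')

    # add_documents() resolves the tenant's default embedding model on demand
    # and counts tokens against it, as in the hunk above.
    embedding_model = ModelFactory.get_embedding_model(tenant_id=dataset.tenant_id)
    tokens = embedding_model.get_num_tokens('some segment text')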

+ 33 - 25
api/core/embedding/cached_embedding.py

@@ -4,14 +4,14 @@ from typing import List
 from langchain.embeddings.base import Embeddings
 from sqlalchemy.exc import IntegrityError

-from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
+from core.model_providers.models.embedding.base import BaseEmbedding
 from extensions.ext_database import db
 from libs import helper
 from models.dataset import Embedding


 class CacheEmbedding(Embeddings):
-    def __init__(self, embeddings: Embeddings):
+    def __init__(self, embeddings: BaseEmbedding):
         self._embeddings = embeddings

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
@@ -21,48 +21,54 @@ class CacheEmbedding(Embeddings):
         embedding_queue_texts = []
         for text in texts:
             hash = helper.generate_text_hash(text)
-            embedding = db.session.query(Embedding).filter_by(hash=hash).first()
+            embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
             if embedding:
                 text_embeddings.append(embedding.get_embedding())
             else:
                 embedding_queue_texts.append(text)

-        embedding_results = self._embeddings.embed_documents(embedding_queue_texts)
+        if embedding_queue_texts:
+            try:
+                embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
+            except Exception as ex:
+                raise self._embeddings.handle_exceptions(ex)

-        i = 0
-        for text in embedding_queue_texts:
-            hash = helper.generate_text_hash(text)
+            i = 0
+            for text in embedding_queue_texts:
+                hash = helper.generate_text_hash(text)

-            try:
-                embedding = Embedding(hash=hash)
-                embedding.set_embedding(embedding_results[i])
-                db.session.add(embedding)
-                db.session.commit()
-            except IntegrityError:
-                db.session.rollback()
-                continue
-            except:
-                logging.exception('Failed to add embedding to db')
-                continue
-            finally:
-                i += 1
+                try:
+                    embedding = Embedding(model_name=self._embeddings.name, hash=hash)
+                    embedding.set_embedding(embedding_results[i])
+                    db.session.add(embedding)
+                    db.session.commit()
+                except IntegrityError:
+                    db.session.rollback()
+                    continue
+                except:
+                    logging.exception('Failed to add embedding to db')
+                    continue
+                finally:
+                    i += 1

-        text_embeddings.extend(embedding_results)
+            text_embeddings.extend(embedding_results)
         return text_embeddings

-    @handle_openai_exceptions
     def embed_query(self, text: str) -> List[float]:
         """Embed query text."""
         # use doc embedding cache or store if not exists
         hash = helper.generate_text_hash(text)
-        embedding = db.session.query(Embedding).filter_by(hash=hash).first()
+        embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
         if embedding:
             return embedding.get_embedding()

-        embedding_results = self._embeddings.embed_query(text)
+        try:
+            embedding_results = self._embeddings.client.embed_query(text)
+        except Exception as ex:
+            raise self._embeddings.handle_exceptions(ex)

         try:
-            embedding = Embedding(hash=hash)
+            embedding = Embedding(model_name=self._embeddings.name, hash=hash)
             embedding.set_embedding(embedding_results)
             db.session.add(embedding)
             db.session.commit()
@@ -72,3 +78,5 @@ class CacheEmbedding(Embeddings):
             logging.exception('Failed to add embedding to db')

         return embedding_results
+
+
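
Why model_name joined the cache key: with multiple providers, two models can embed the same text differently, so a hash-only lookup could return a vector produced by the wrong model. A hedged sketch of the wrapper's use under the new scheme (the tenant id is illustrative):

    from core.embedding.cached_embedding import CacheEmbedding
    from core.model_providers.model_factory import ModelFactory

    embedding_model = ModelFactory.get_embedding_model(tenant_id='tenant-id')
    embeddings = CacheEmbedding(embedding_model)

    # The first call embeds and stores under (model_name, text hash); the
    # second is served from the Embedding table without hitting the provider.
    vectors = embeddings.embed_documents(['hello world'])
    cached = embeddings.embed_query('hello world')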

+ 61 - 83
api/core/generator/llm_generator.py

@@ -1,13 +1,10 @@
 import logging

-from langchain import PromptTemplate
-from langchain.chat_models.base import BaseChatModel
-from langchain.schema import HumanMessage, OutputParserException, BaseMessage, SystemMessage
-
-from core.constant import llm_constant
-from core.llm.llm_builder import LLMBuilder
-from core.llm.streamable_open_ai import StreamableOpenAI
-from core.llm.token_calculator import TokenCalculator
+from langchain.schema import OutputParserException
+
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelKwargs
 from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser

 from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
@@ -15,9 +12,6 @@ from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTempla
 from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
     GENERATOR_QA_PROMPT

-# gpt-3.5-turbo works not well
-generate_base_model = 'text-davinci-003'
-

 class LLMGenerator:
     @classmethod
@@ -28,29 +22,35 @@
             query = query[:300] + "...[TRUNCATED]..." + query[-300:]

         prompt = prompt.format(query=query)
-        llm: StreamableOpenAI = LLMBuilder.to_llm(
+
+        model_instance = ModelFactory.get_text_generation_model(
             tenant_id=tenant_id,
-            model_name='gpt-3.5-turbo',
-            max_tokens=50,
-            timeout=600
+            model_kwargs=ModelKwargs(
+                max_tokens=50
+            )
         )

-        if isinstance(llm, BaseChatModel):
-            prompt = [HumanMessage(content=prompt)]
-
-        response = llm.generate([prompt])
-        answer = response.generations[0][0].text
+        prompts = [PromptMessage(content=prompt)]
+        response = model_instance.run(prompts)
+        answer = response.content
         return answer.strip()

     @classmethod
     def generate_conversation_summary(cls, tenant_id: str, messages):
         max_tokens = 200
-        model = 'gpt-3.5-turbo'
+
+        model_instance = ModelFactory.get_text_generation_model(
+            tenant_id=tenant_id,
+            model_kwargs=ModelKwargs(
+                max_tokens=max_tokens
+            )
+        )

         prompt = CONVERSATION_SUMMARY_PROMPT
         prompt_with_empty_context = prompt.format(context='')
-        prompt_tokens = TokenCalculator.get_num_tokens(model, prompt_with_empty_context)
-        rest_tokens = llm_constant.max_context_token_length[model] - prompt_tokens - max_tokens - 1
+        prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
+        max_context_token_length = model_instance.model_rules.max_tokens.max
+        rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1

         context = ''
         for message in messages:
@@ -68,25 +68,16 @@
                 answer = message.answer

             message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
-            if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
+            if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
                 context += message_qa_text

         if not context:
             return '[message too long, no summary]'

         prompt = prompt.format(context=context)
-
-        llm: StreamableOpenAI = LLMBuilder.to_llm(
-            tenant_id=tenant_id,
-            model_name=model,
-            max_tokens=max_tokens
-        )
-
-        if isinstance(llm, BaseChatModel):
-            prompt = [HumanMessage(content=prompt)]
-
-        response = llm.generate([prompt])
-        answer = response.generations[0][0].text
+        prompts = [PromptMessage(content=prompt)]
+        response = model_instance.run(prompts)
+        answer = response.content
         return answer.strip()

     @classmethod
@@ -94,16 +85,13 @@
         prompt = INTRODUCTION_GENERATE_PROMPT
         prompt = prompt.format(prompt=pre_prompt)

-        llm: StreamableOpenAI = LLMBuilder.to_llm(
-            tenant_id=tenant_id,
-            model_name=generate_base_model,
+        model_instance = ModelFactory.get_text_generation_model(
+            tenant_id=tenant_id
         )

-        if isinstance(llm, BaseChatModel):
-            prompt = [HumanMessage(content=prompt)]
-
-        response = llm.generate([prompt])
-        answer = response.generations[0][0].text
+        prompts = [PromptMessage(content=prompt)]
+        response = model_instance.run(prompts)
+        answer = response.content
         return answer.strip()

     @classmethod
@@ -119,23 +107,19 @@

         _input = prompt.format_prompt(histories=histories)

-        llm: StreamableOpenAI = LLMBuilder.to_llm(
+        model_instance = ModelFactory.get_text_generation_model(
             tenant_id=tenant_id,
-            model_name='gpt-3.5-turbo',
-            temperature=0,
-            max_tokens=256
+            model_kwargs=ModelKwargs(
+                max_tokens=256,
+                temperature=0
+            )
         )

-        if isinstance(llm, BaseChatModel):
-            query = [HumanMessage(content=_input.to_string())]
-        else:
-            query = _input.to_string()
+        prompts = [PromptMessage(content=_input.to_string())]

         try:
-            output = llm(query)
-            if isinstance(output, BaseMessage):
-                output = output.content
-            questions = output_parser.parse(output)
+            output = model_instance.run(prompts)
+            questions = output_parser.parse(output.content)
         except Exception:
             logging.exception("Error generating suggested questions after answer")
             questions = []
@@ -160,21 +144,19 @@

         _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)

-        llm: StreamableOpenAI = LLMBuilder.to_llm(
+        model_instance = ModelFactory.get_text_generation_model(
             tenant_id=tenant_id,
-            model_name=generate_base_model,
-            temperature=0,
-            max_tokens=512
+            model_kwargs=ModelKwargs(
+                max_tokens=512,
+                temperature=0
+            )
         )

-        if isinstance(llm, BaseChatModel):
-            query = [HumanMessage(content=_input.to_string())]
-        else:
-            query = _input.to_string()
+        prompts = [PromptMessage(content=_input.to_string())]

         try:
-            output = llm(query)
-            rule_config = output_parser.parse(output)
+            output = model_instance.run(prompts)
+            rule_config = output_parser.parse(output.content)
         except OutputParserException:
             raise ValueError('Please give a valid input for intended audience or hoping to solve problems.')
         except Exception:
@@ -188,25 +170,21 @@
         return rule_config

     @classmethod
-    async def generate_qa_document(cls, llm: StreamableOpenAI, query):
+    def generate_qa_document(cls, tenant_id: str, query):
         prompt = GENERATOR_QA_PROMPT

+        model_instance = ModelFactory.get_text_generation_model(
+            tenant_id=tenant_id,
+            model_kwargs=ModelKwargs(
+                max_tokens=2000
+            )
+        )

-        if isinstance(llm, BaseChatModel):
-            prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
-
-        response = llm.generate([prompt])
-        answer = response.generations[0][0].text
-        return answer.strip()
-
-    @classmethod
-    def generate_qa_document_sync(cls, llm: StreamableOpenAI, query):
-        prompt = GENERATOR_QA_PROMPT
-
-
-        if isinstance(llm, BaseChatModel):
-            prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
+        prompts = [
+            PromptMessage(content=prompt, type=MessageType.SYSTEM),
+            PromptMessage(content=query)
+        ]

-        response = llm.generate([prompt])
-        answer = response.generations[0][0].text
+        response = model_instance.run(prompts)
+        answer = response.content
         return answer.strip()
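
Every generator method now follows the same shape, which removes the per-call BaseChatModel/HumanMessage branching. A minimal sketch of that pattern (the prompt text and tenant id are illustrative; only the ModelKwargs fields seen in this diff are used):

    from core.model_providers.model_factory import ModelFactory
    from core.model_providers.models.entity.message import PromptMessage, MessageType
    from core.model_providers.models.entity.model_params import ModelKwargs

    model_instance = ModelFactory.get_text_generation_model(
        tenant_id='tenant-id',
        model_kwargs=ModelKwargs(max_tokens=50, temperature=0)
    )

    # PromptMessage replaces raw LangChain messages: the model wrapper decides
    # how to map roles onto chat or completion provider APIs.
    prompts = [
        PromptMessage(content='You are a titler.', type=MessageType.SYSTEM),
        PromptMessage(content='Name this conversation.')
    ]
    answer = model_instance.run(prompts).content.strip()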

+ 0 - 0
api/tests/test_helpers/__init__.py → api/core/helper/__init__.py


+ 20 - 0
api/core/helper/encrypter.py

@@ -0,0 +1,20 @@
+import base64
+
+from extensions.ext_database import db
+from libs import rsa
+
+from models.account import Tenant
+
+
+def obfuscated_token(token: str):
+    return token[:6] + '*' * (len(token) - 8) + token[-2:]
+
+
+def encrypt_token(tenant_id: str, token: str):
+    tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first()
+    encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
+    return base64.b64encode(encrypted_token).decode()
+
+
+def decrypt_token(tenant_id: str, token: str):
+    return rsa.decrypt(base64.b64decode(token), tenant_id)
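
A quick usage note on the new helper: tokens are encrypted with the tenant's RSA public key and base64-wrapped for storage, then decrypted via libs.rsa with the tenant id. One caveat worth flagging in review, shown with an illustrative key:

    from core.helper import encrypter

    # obfuscated_token keeps the first 6 and last 2 characters visible and
    # masks the middle; note that tokens of 8 characters or fewer get no
    # masking, since '*' * (len(token) - 8) collapses to an empty string.
    masked = encrypter.obfuscated_token('sk-abcdefghijklmnop')
    # -> 'sk-abc***********op'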

+ 4 - 10
api/core/index/index.py

@@ -1,10 +1,9 @@
 from flask import current_app
-from langchain.embeddings import OpenAIEmbeddings

 from core.embedding.cached_embedding import CacheEmbedding
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
 from core.index.vector_index.vector_index import VectorIndex
-from core.llm.llm_builder import LLMBuilder
+from core.model_providers.model_factory import ModelFactory
 from models.dataset import Dataset


@@ -15,16 +14,11 @@ class IndexBuilder:
             if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
                 return None

-            model_credentials = LLMBuilder.get_model_credentials(
-                tenant_id=dataset.tenant_id,
-                model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
-                model_name='text-embedding-ada-002'
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=dataset.tenant_id
             )

-            embeddings = CacheEmbedding(OpenAIEmbeddings(
-                max_retries=1,
-                **model_credentials
-            ))
+            embeddings = CacheEmbedding(embedding_model)

             return VectorIndex(
                 dataset=dataset,
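
A minimal sketch of what call sites look like after this change (dataset stands for a models.dataset.Dataset row; only the two index types appearing in this PR are shown):

    from core.index.index import IndexBuilder

    # High-quality datasets get a vector index backed by the tenant's default
    # embedding model via CacheEmbedding; 'economy' keeps the keyword table.
    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
    keyword_table_index = IndexBuilder.get_index(dataset, 'economy')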

+ 46 - 43
api/core/indexing_runner.py

@@ -1,4 +1,3 @@
-import concurrent
 import datetime
 import json
 import logging
@@ -6,7 +5,6 @@ import re
 import threading
 import time
 import uuid
-from concurrent.futures import ThreadPoolExecutor
 from typing import Optional, List, cast

 from flask_login import current_user
@@ -18,11 +16,10 @@ from core.data_loader.loader.notion import NotionLoader
 from core.docstore.dataset_docstore import DatesetDocumentStore
 from core.generator.llm_generator import LLMGenerator
 from core.index.index import IndexBuilder
-from core.llm.error import ProviderTokenNotInitError
-from core.llm.llm_builder import LLMBuilder
-from core.llm.streamable_open_ai import StreamableOpenAI
+from core.model_providers.error import ProviderTokenNotInitError
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.entity.message import MessageType
 from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
-from core.llm.token_calculator import TokenCalculator
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
@@ -35,9 +32,8 @@ from models.source import DataSourceBinding


 class IndexingRunner:

-    def __init__(self, embedding_model_name: str = "text-embedding-ada-002"):
+    def __init__(self):
         self.storage = storage
-        self.embedding_model_name = embedding_model_name

     def run(self, dataset_documents: List[DatasetDocument]):
         """Run the indexing process."""
@@ -227,11 +223,15 @@
             dataset_document.stopped_at = datetime.datetime.utcnow()
             db.session.commit()

-    def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict,
+    def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
                                doc_form: str = None) -> dict:
         """
         Estimate the indexing for the document.
         """
+        embedding_model = ModelFactory.get_embedding_model(
+            tenant_id=tenant_id
+        )
+
         tokens = 0
         preview_texts = []
         total_segments = 0
@@ -253,44 +253,49 @@
                 splitter=splitter,
                 processing_rule=processing_rule
             )
+
             total_segments += len(documents)
+
             for document in documents:
                 if len(preview_texts) < 5:
                     preview_texts.append(document.page_content)

-                tokens += TokenCalculator.get_num_tokens(self.embedding_model_name,
-                                                         self.filter_string(document.page_content))
+                tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
+
+        text_generation_model = ModelFactory.get_text_generation_model(
+            tenant_id=tenant_id
+        )
+
         if doc_form and doc_form == 'qa_model':
             if len(preview_texts) > 0:
                 # qa model document
-                llm: StreamableOpenAI = LLMBuilder.to_llm(
-                    tenant_id=current_user.current_tenant_id,
-                    model_name='gpt-3.5-turbo',
-                    max_tokens=2000
-                )
-                response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
+                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
                 document_qa_list = self.format_split_text(response)
                 return {
                     "total_segments": total_segments * 20,
                     "tokens": total_segments * 2000,
                     "total_price": '{:f}'.format(
-                        TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
-                    "currency": TokenCalculator.get_currency(self.embedding_model_name),
+                        text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
+                    "currency": embedding_model.get_currency(),
                     "qa_preview": document_qa_list,
                     "preview": preview_texts
                 }
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
-            "currency": TokenCalculator.get_currency(self.embedding_model_name),
+            "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
+            "currency": embedding_model.get_currency(),
            "preview": preview_texts
         }

-    def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
+    def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
         """
         Estimate the indexing for the document.
         """
+        embedding_model = ModelFactory.get_embedding_model(
+            tenant_id=tenant_id
+        )
+
         # load data from notion
         tokens = 0
         preview_texts = []
@@ -336,31 +341,31 @@
                     if len(preview_texts) < 5:
                         preview_texts.append(document.page_content)

-                    tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
+                    tokens += embedding_model.get_num_tokens(document.page_content)
+
+        text_generation_model = ModelFactory.get_text_generation_model(
+            tenant_id=tenant_id
+        )
+
         if doc_form and doc_form == 'qa_model':
             if len(preview_texts) > 0:
                 # qa model document
-                llm: StreamableOpenAI = LLMBuilder.to_llm(
-                    tenant_id=current_user.current_tenant_id,
-                    model_name='gpt-3.5-turbo',
-                    max_tokens=2000
-                )
-                response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
+                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
                 document_qa_list = self.format_split_text(response)
                 return {
                     "total_segments": total_segments * 20,
                     "tokens": total_segments * 2000,
                     "total_price": '{:f}'.format(
-                        TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
-                    "currency": TokenCalculator.get_currency(self.embedding_model_name),
+                        text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
+                    "currency": embedding_model.get_currency(),
                     "qa_preview": document_qa_list,
                     "preview": preview_texts
                 }
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
-            "currency": TokenCalculator.get_currency(self.embedding_model_name),
+            "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
+            "currency": embedding_model.get_currency(),
             "preview": preview_texts
         }

@@ -459,7 +464,6 @@
         doc_store = DatesetDocumentStore(
             dataset=dataset,
             user_id=dataset_document.created_by,
-            embedding_model_name=self.embedding_model_name,
             document_id=dataset_document.id
         )

@@ -513,17 +517,12 @@
             all_documents.extend(split_documents)
         # processing qa document
         if document_form == 'qa_model':
-            llm: StreamableOpenAI = LLMBuilder.to_llm(
-                tenant_id=tenant_id,
-                model_name='gpt-3.5-turbo',
-                max_tokens=2000
-            )
             for i in range(0, len(all_documents), 10):
                 threads = []
                 sub_documents = all_documents[i:i + 10]
                 for doc in sub_documents:
                     document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
-                        'llm': llm, 'document_node': doc, 'all_qa_documents': all_qa_documents})
+                        'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents})
                     threads.append(document_format_thread)
                     document_format_thread.start()
                 for thread in threads:
@@ -531,13 +530,13 @@
             return all_qa_documents
         return all_documents

-    def format_qa_document(self, llm: StreamableOpenAI, document_node, all_qa_documents):
+    def format_qa_document(self, tenant_id: str, document_node, all_qa_documents):
         format_documents = []
         if document_node.page_content is None or not document_node.page_content.strip():
             return
         try:
             # qa model document
-            response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content)
+            response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content)
             document_qa_list = self.format_split_text(response)
             qa_documents = []
             for result in document_qa_list:
@@ -638,6 +637,10 @@
         vector_index = IndexBuilder.get_index(dataset, 'high_quality')
         keyword_table_index = IndexBuilder.get_index(dataset, 'economy')

+        embedding_model = ModelFactory.get_embedding_model(
+            tenant_id=dataset.tenant_id
+        )
+
         # chunk nodes by chunk size
         indexing_start_at = time.perf_counter()
         tokens = 0
@@ -648,7 +651,7 @@
             chunk_documents = documents[i:i + chunk_size]

             tokens += sum(
-                TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
+                embedding_model.get_num_tokens(document.page_content)
                 for document in chunk_documents
             )

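For completeness, a hedged sketch of the updated estimator call: tenant_id is now an explicit parameter instead of an embedding model name baked into the runner (file_details and processing_rule below are placeholders):

    from core.indexing_runner import IndexingRunner

    runner = IndexingRunner()  # no default embedding_model_name anymore
    estimate = runner.file_indexing_estimate(
        tenant_id=current_user.current_tenant_id,
        file_details=file_details,
        tmp_processing_rule=processing_rule,
        doc_form='qa_model'
    )
    # Keys returned per the hunks above:
    print(estimate['total_segments'], estimate['tokens'],
          estimate['total_price'], estimate['currency'])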

+ 0 - 148
api/core/llm/llm_builder.py

@@ -1,148 +0,0 @@
-from typing import Union, Optional, List
-
-from langchain.callbacks.base import BaseCallbackHandler
-
-from core.constant import llm_constant
-from core.llm.error import ProviderTokenNotInitError
-from core.llm.provider.base import BaseProvider
-from core.llm.provider.llm_provider_service import LLMProviderService
-from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
-from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
-from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
-from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
-from core.llm.streamable_open_ai import StreamableOpenAI
-from models.provider import ProviderType, ProviderName
-
-
-class LLMBuilder:
-    """
-    This class handles the following logic:
-    1. For providers with the name 'OpenAI', the OPENAI_API_KEY value is stored directly in encrypted_config.
-    2. For providers with the name 'Azure OpenAI', encrypted_config stores the serialized values of four fields, as shown below:
-       OPENAI_API_TYPE=azure
-       OPENAI_API_VERSION=2022-12-01
-       OPENAI_API_BASE=https://your-resource-name.openai.azure.com
-       OPENAI_API_KEY=<your Azure OpenAI API key>
-    3. For providers with the name 'Anthropic', the ANTHROPIC_API_KEY value is stored directly in encrypted_config.
-    4. For providers with the name 'Cohere', the COHERE_API_KEY value is stored directly in encrypted_config.
-    5. For providers with the name 'HUGGINGFACEHUB', the HUGGINGFACEHUB_API_KEY value is stored directly in encrypted_config.
-    6. Providers with the provider_type 'CUSTOM' can be created through the admin interface, while 'System' providers cannot be created through the admin interface.
-    7. If both CUSTOM and System providers exist in the records, the CUSTOM provider is preferred by default, but this preference can be changed via an input parameter.
-    8. For providers with the provider_type 'System', the quota_used must not exceed quota_limit. If the quota is exceeded, the provider cannot be used. Currently, only the TRIAL quota_type is supported, which is permanently non-resetting.
-    """
-
-    @classmethod
-    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
-        provider = cls.get_default_provider(tenant_id, model_name)
-
-        model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
-
-        llm_cls = None
-        mode = cls.get_mode_by_model(model_name)
-        if mode == 'chat':
-            if provider == ProviderName.OPENAI.value:
-                llm_cls = StreamableChatOpenAI
-            elif provider == ProviderName.AZURE_OPENAI.value:
-                llm_cls = StreamableAzureChatOpenAI
-            elif provider == ProviderName.ANTHROPIC.value:
-                llm_cls = StreamableChatAnthropic
-        elif mode == 'completion':
-            if provider == ProviderName.OPENAI.value:
-                llm_cls = StreamableOpenAI
-            elif provider == ProviderName.AZURE_OPENAI.value:
-                llm_cls = StreamableAzureOpenAI
-
-        if not llm_cls:
-            raise ValueError(f"model name {model_name} is not supported.")
-
-        model_kwargs = {
-            'model_name': model_name,
-            'temperature': kwargs.get('temperature', 0),
-            'max_tokens': kwargs.get('max_tokens', 256),
-            'top_p': kwargs.get('top_p', 1),
-            'frequency_penalty': kwargs.get('frequency_penalty', 0),
-            'presence_penalty': kwargs.get('presence_penalty', 0),
-            'callbacks': kwargs.get('callbacks', None),
-            'streaming': kwargs.get('streaming', False),
-        }
-
-        model_kwargs.update(model_credentials)
-        model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)
-
-        return llm_cls(**model_kwargs)
-
-    @classmethod
-    def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
-                          callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
-        model_name = model.get("name")
-        completion_params = model.get("completion_params", {})
-
-        return cls.to_llm(
-            tenant_id=tenant_id,
-            model_name=model_name,
-            temperature=completion_params.get('temperature', 0),
-            max_tokens=completion_params.get('max_tokens', 256),
-            top_p=completion_params.get('top_p', 0),
-            frequency_penalty=completion_params.get('frequency_penalty', 0.1),
-            presence_penalty=completion_params.get('presence_penalty', 0.1),
-            streaming=streaming,
-            callbacks=callbacks
-        )
-
-    @classmethod
-    def get_mode_by_model(cls, model_name: str) -> str:
-        if not model_name:
-            raise ValueError(f"empty model name is not supported.")
-
-        if model_name in llm_constant.models_by_mode['chat']:
-            return "chat"
-        elif model_name in llm_constant.models_by_mode['completion']:
-            return "completion"
-        else:
-            raise ValueError(f"model name {model_name} is not supported.")
-
-    @classmethod
-    def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict:
-        """
-        Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
-        Raises an exception if the model_name is not found or if the provider is not found.
-        """
-        if not model_name:
-            raise Exception('model name not found')
-        #
-        # if model_name not in llm_constant.models:
-        #     raise Exception('model {} not found'.format(model_name))
-
-        # model_provider = llm_constant.models[model_name]
-
-        provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
-        return provider_service.get_credentials(model_name)
-
-    @classmethod
-    def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
-        provider_name = llm_constant.models[model_name]
-
-        if provider_name == 'openai':
-            # get the default provider (openai / azure_openai) for the tenant
-            openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
-            azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)
-
-            provider = None
-            if openai_provider and openai_provider.provider_type == ProviderType.CUSTOM.value:
-                provider = openai_provider
-            elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.CUSTOM.value:
-                provider = azure_openai_provider
-            elif openai_provider and openai_provider.provider_type == ProviderType.SYSTEM.value:
-                provider = openai_provider
-            elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.SYSTEM.value:
-                provider = azure_openai_provider
-
-            if not provider:
-                raise ProviderTokenNotInitError(
-                    f"No valid {provider_name} model provider credentials found. "
-                    f"Please go to Settings -> Model Provider to complete your provider credentials."
-                )
-
-            provider_name = provider.provider_name
-
-        return provider_name
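
Migration note for anyone still holding a reference to this class: a hedged before/after, with the new call shape taken from the hunks earlier in this PR (the tenant id is a placeholder):

    # Old (removed):
    #   llm = LLMBuilder.to_llm(tenant_id, 'gpt-3.5-turbo', max_tokens=50)
    #   provider = LLMBuilder.get_default_provider(tenant_id, model_name)
    # New: provider resolution now happens inside ModelFactory.
    from core.model_providers.model_factory import ModelFactory
    from core.model_providers.models.entity.model_params import ModelKwargs

    model_instance = ModelFactory.get_text_generation_model(
        tenant_id='tenant-id',
        model_kwargs=ModelKwargs(max_tokens=50)
    )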

+ 0 - 15
api/core/llm/moderation.py

@@ -1,15 +0,0 @@
-import openai
-from models.provider import ProviderName
-
-
-class Moderation:
-
-    def __init__(self, provider: str, api_key: str):
-        self.provider = provider
-        self.api_key = api_key
-
-        if self.provider == ProviderName.OPENAI.value:
-            self.client = openai.Moderation
-
-    def moderate(self, text):
-        return self.client.create(input=text, api_key=self.api_key)

+ 0 - 138
api/core/llm/provider/anthropic_provider.py

@@ -1,138 +0,0 @@
-import json
-import logging
-from typing import Optional, Union
-
-import anthropic
-from langchain.chat_models import ChatAnthropic
-from langchain.schema import HumanMessage
-
-from core import hosted_llm_credentials
-from core.llm.error import ProviderTokenNotInitError
-from core.llm.provider.base import BaseProvider
-from core.llm.provider.errors import ValidateFailedError
-from models.provider import ProviderName, ProviderType
-
-
-class AnthropicProvider(BaseProvider):
-    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
-        return [
-            {
-                'id': 'claude-instant-1',
-                'name': 'claude-instant-1',
-            },
-            {
-                'id': 'claude-2',
-                'name': 'claude-2',
-            },
-        ]
-
-    def get_credentials(self, model_id: Optional[str] = None) -> dict:
-        return self.get_provider_api_key(model_id=model_id)
-
-    def get_provider_name(self):
-        return ProviderName.ANTHROPIC
-
-    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
-        """
-        Returns the provider configs.
-        """
-        try:
-            config = self.get_provider_api_key(only_custom=only_custom)
-        except:
-            config = {
-                'anthropic_api_key': ''
-            }
-
-        if obfuscated:
-            if not config.get('anthropic_api_key'):
-                config = {
-                    'anthropic_api_key': ''
-                }
-
-            config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key'))
-            return config
-
-        return config
-
-    def get_encrypted_token(self, config: Union[dict | str]):
-        """
-        Returns the encrypted token.
-        """
-        return json.dumps({
-            'anthropic_api_key': self.encrypt_token(config['anthropic_api_key'])
-        })
-
-    def get_decrypted_token(self, token: str):
-        """
-        Returns the decrypted token.
-        """
-        config = json.loads(token)
-        config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key'])
-        return config
-
-    def get_token_type(self):
-        return dict
-
-    def config_validate(self, config: Union[dict | str]):
-        """
-        Validates the given config.
-        """
-        # check OpenAI / Azure OpenAI credential is valid
-        openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value)
-        azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value)
-
-        provider = None
-        if openai_provider:
-            provider = openai_provider
-        elif azure_openai_provider:
-            provider = azure_openai_provider
-
-        if not provider:
-            raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.")
-
-        if provider.provider_type == ProviderType.SYSTEM.value:
-            quota_used = provider.quota_used if provider.quota_used is not None else 0
-            quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
-            if quota_used >= quota_limit:
-                raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, "
-                                          f"please configure OpenAI or Azure OpenAI provider first.")
-
-        try:
-            if not isinstance(config, dict):
-                raise ValueError('Config must be a object.')
-
-            if 'anthropic_api_key' not in config:
-                raise ValueError('anthropic_api_key must be provided.')
-
-            chat_llm = ChatAnthropic(
-                model='claude-instant-1',
-                anthropic_api_key=config['anthropic_api_key'],
-                max_tokens_to_sample=10,
-                temperature=0,
-                default_request_timeout=60
-            )
-
-            messages = [
-                HumanMessage(
-                    content="ping"
-                )
-            ]
-
-            chat_llm(messages)
-        except anthropic.APIConnectionError as ex:
-            raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}")
-        except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
-            raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - "
-                                      f"{ex.body['error']['type']}: {ex.body['error']['message']}")
-        except Exception as ex:
-            logging.exception('Anthropic config validation failed')
-            raise ex
-
-    def get_hosted_credentials(self) -> Union[str | dict]:
-        if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key:
-            raise ProviderTokenNotInitError(
-                f"No valid {self.get_provider_name().value} model provider credentials found. "
-                f"Please go to Settings -> Model Provider to complete your provider credentials."
-            )
-
-        return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key}

+ 0 - 145
api/core/llm/provider/azure_provider.py

@@ -1,145 +0,0 @@
-import json
-import logging
-from typing import Optional, Union
-
-import openai
-import requests
-
-from core.llm.provider.base import BaseProvider
-from core.llm.provider.errors import ValidateFailedError
-from models.provider import ProviderName
-
-
-AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
-
-
-class AzureProvider(BaseProvider):
-    def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]:
-        return []
-
-    def check_embedding_model(self, credentials: Optional[dict] = None):
-        credentials = self.get_credentials('text-embedding-ada-002') if not credentials else credentials
-        try:
-            result = openai.Embedding.create(input=['test'],
-                                             engine='text-embedding-ada-002',
-                                             timeout=60,
-                                             api_key=str(credentials.get('openai_api_key')),
-                                             api_base=str(credentials.get('openai_api_base')),
-                                             api_type='azure',
-                                             api_version=str(credentials.get('openai_api_version')))["data"][0][
-                "embedding"]
-        except openai.error.AuthenticationError as e:
-            raise AzureAuthenticationError(str(e))
-        except openai.error.APIConnectionError as e:
-            raise AzureRequestFailedError(
-                'Failed to request Azure OpenAI. Please check your API Base Endpoint; the format is `https://xxx.openai.azure.com/`')
-        except openai.error.InvalidRequestError as e:
-            if e.http_status == 404:
-                raise AzureRequestFailedError("Please check your 'gpt-3.5-turbo' or 'text-embedding-ada-002' "
-                                              "deployment name is exists in Azure AI")
-            else:
-                raise AzureRequestFailedError(
-                    'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
-        except openai.error.OpenAIError as e:
-            raise AzureRequestFailedError(
-                'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
-
-        if not isinstance(result, list):
-            raise AzureRequestFailedError('Failed to request Azure OpenAI.')
-
-    def get_credentials(self, model_id: Optional[str] = None) -> dict:
-        """
-        Returns the API credentials for Azure OpenAI as a dictionary.
-        """
-        config = self.get_provider_api_key(model_id=model_id)
-        config['openai_api_type'] = 'azure'
-        config['openai_api_version'] = AZURE_OPENAI_API_VERSION
-        if model_id == 'text-embedding-ada-002':
-            config['deployment'] = model_id.replace('.', '') if model_id else None
-            config['chunk_size'] = 16
-        else:
-            config['deployment_name'] = model_id.replace('.', '') if model_id else None
-        return config
-
-    def get_provider_name(self):
-        return ProviderName.AZURE_OPENAI
-
-    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
-        """
-        Returns the provider configs.
-        """
-        try:
-            config = self.get_provider_api_key(only_custom=only_custom)
-        except:
-            config = {
-                'openai_api_type': 'azure',
-                'openai_api_version': AZURE_OPENAI_API_VERSION,
-                'openai_api_base': '',
-                'openai_api_key': ''
-            }
-
-        if obfuscated:
-            if not config.get('openai_api_key'):
-                config = {
-                    'openai_api_type': 'azure',
-                    'openai_api_version': AZURE_OPENAI_API_VERSION,
-                    'openai_api_base': '',
-                    'openai_api_key': ''
-                }
-
-            config['openai_api_key'] = self.obfuscated_token(config.get('openai_api_key'))
-            return config
-
-        return config
-
-    def get_token_type(self):
-        return dict
-
-    def config_validate(self, config: Union[dict | str]):
-        """
-        Validates the given config.
-        """
-        try:
-            if not isinstance(config, dict):
-                raise ValueError('Config must be an object.')
-
-            if 'openai_api_version' not in config:
-                config['openai_api_version'] = AZURE_OPENAI_API_VERSION
-
-            self.check_embedding_model(credentials=config)
-        except ValidateFailedError as e:
-            raise e
-        except AzureAuthenticationError:
-            raise ValidateFailedError('Validation failed, please check your API Key.')
-        except AzureRequestFailedError as ex:
-            raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
-        except Exception as ex:
-            logging.exception('Azure OpenAI Credentials validation failed')
-            raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
-
-    def get_encrypted_token(self, config: Union[dict | str]):
-        """
-        Returns the encrypted token.
-        """
-        return json.dumps({
-            'openai_api_type': 'azure',
-            'openai_api_version': AZURE_OPENAI_API_VERSION,
-            'openai_api_base': config['openai_api_base'],
-            'openai_api_key': self.encrypt_token(config['openai_api_key'])
-        })
-
-    def get_decrypted_token(self, token: str):
-        """
-        Returns the decrypted token.
-        """
-        config = json.loads(token)
-        config['openai_api_key'] = self.decrypt_token(config['openai_api_key'])
-        return config
-
-
-class AzureAuthenticationError(Exception):
-    pass
-
-
-class AzureRequestFailedError(Exception):
-    pass
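
For reference, the credential dict that the removed get_credentials assembled for the 'text-embedding-ada-002' case looked roughly like the sketch below (the key and endpoint come from the tenant's stored config; the values shown are placeholders):

    {
        'openai_api_key': '<decrypted key>',
        'openai_api_base': 'https://xxx.openai.azure.com/',
        'openai_api_type': 'azure',
        'openai_api_version': '2023-07-01-preview',
        'deployment': 'text-embedding-ada-002',
        'chunk_size': 16
    }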

+ 0 - 132
api/core/llm/provider/base.py

@@ -1,132 +0,0 @@
-import base64
-from abc import ABC, abstractmethod
-from typing import Optional, Union
-
-from core.constant import llm_constant
-from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
-from extensions.ext_database import db
-from libs import rsa
-from models.account import Tenant
-from models.provider import Provider, ProviderType, ProviderName
-
-
-class BaseProvider(ABC):
-    def __init__(self, tenant_id: str):
-        self.tenant_id = tenant_id
-
-    def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]:
-        """
-        Returns the decrypted API key for the given tenant_id and provider_name.
-        If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
-        If the provider is not found or not valid, raises a ProviderTokenNotInitError.
-        """
-        provider = self.get_provider(only_custom)
-        if not provider:
-            raise ProviderTokenNotInitError(
-                f"No valid {llm_constant.models[model_id]} model provider credentials found. "
-                f"Please go to Settings -> Model Provider to complete your provider credentials."
-            )
-
-        if provider.provider_type == ProviderType.SYSTEM.value:
-            quota_used = provider.quota_used if provider.quota_used is not None else 0
-            quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
-
-            if model_id and model_id == 'gpt-4':
-                raise ModelCurrentlyNotSupportError()
-
-            if quota_used >= quota_limit:
-                raise QuotaExceededError()
-
-            return self.get_hosted_credentials()
-        else:
-            return self.get_decrypted_token(provider.encrypted_config)
-
-    def get_provider(self, only_custom: bool = False) -> Optional[Provider]:
-        """
-        Returns the Provider instance for the given tenant_id and provider_name.
-        If both CUSTOM and SYSTEM providers exist, the valid CUSTOM provider takes precedence; pass only_custom to ignore SYSTEM providers.
-        """
-        return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom)
-
-    @classmethod
-    def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[
-        Provider]:
-        """
-        Returns the Provider instance for the given tenant_id and provider_name.
-        If both CUSTOM and SYSTEM providers exist, the first valid CUSTOM provider is returned.
-        """
-        query = db.session.query(Provider).filter(
-            Provider.tenant_id == tenant_id
-        )
-
-        if provider_name:
-            query = query.filter(Provider.provider_name == provider_name)
-
-        if only_custom:
-            query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value)
-
-        providers = query.order_by(Provider.provider_type.asc()).all()
-
-        for provider in providers:
-            if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
-                return provider
-            elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
-                return provider
-
-        return None
-
-    def get_hosted_credentials(self) -> Union[str | dict]:
-        raise ProviderTokenNotInitError(
-            f"No valid {self.get_provider_name().value} model provider credentials found. "
-            f"Please go to Settings -> Model Provider to complete your provider credentials."
-        )
-
-    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
-        """
-        Returns the provider configs.
-        """
-        try:
-            config = self.get_provider_api_key(only_custom=only_custom)
-        except:
-            config = ''
-
-        if obfuscated:
-            return self.obfuscated_token(config)
-
-        return config
-
-    def obfuscated_token(self, token: str):
-        return token[:6] + '*' * (len(token) - 8) + token[-2:]
-
-    def get_token_type(self):
-        return str
-
-    def get_encrypted_token(self, config: Union[dict | str]):
-        return self.encrypt_token(config)
-
-    def get_decrypted_token(self, token: str):
-        return self.decrypt_token(token)
-
-    def encrypt_token(self, token):
-        tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
-        encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
-        return base64.b64encode(encrypted_token).decode()
-
-    def decrypt_token(self, token):
-        return rsa.decrypt(base64.b64decode(token), self.tenant_id)
-
-    @abstractmethod
-    def get_provider_name(self):
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_credentials(self, model_id: Optional[str] = None) -> dict:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
-        raise NotImplementedError
-
-    @abstractmethod
-    def config_validate(self, config: str):
-        raise NotImplementedError
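
A worked example of the obfuscated_token masking defined above, using a hypothetical 19-character key:

    token = 'sk-1234567890abcdef'
    # token[:6] + '*' * (len(token) - 8) + token[-2:]
    # -> 'sk-123' + '***********' + 'ef' == 'sk-123***********ef'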

+ 0 - 2
api/core/llm/provider/errors.py

@@ -1,2 +0,0 @@
-class ValidateFailedError(Exception):
-    description = "Provider Validate failed"

+ 0 - 22
api/core/llm/provider/huggingface_provider.py

@@ -1,22 +0,0 @@
-from typing import Optional
-
-from core.llm.provider.base import BaseProvider
-from models.provider import ProviderName
-
-
-class HuggingfaceProvider(BaseProvider):
-    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
-        credentials = self.get_credentials(model_id)
-        # todo
-        return []
-
-    def get_credentials(self, model_id: Optional[str] = None) -> dict:
-        """
-        Returns the API credentials for Huggingface as a dictionary, for the given tenant_id.
-        """
-        return {
-            'huggingface_api_key': self.get_provider_api_key(model_id=model_id)
-        }
-
-    def get_provider_name(self):
-        return ProviderName.HUGGINGFACEHUB

+ 0 - 53
api/core/llm/provider/llm_provider_service.py

@@ -1,53 +0,0 @@
-from typing import Optional, Union
-
-from core.llm.provider.anthropic_provider import AnthropicProvider
-from core.llm.provider.azure_provider import AzureProvider
-from core.llm.provider.base import BaseProvider
-from core.llm.provider.huggingface_provider import HuggingfaceProvider
-from core.llm.provider.openai_provider import OpenAIProvider
-from models.provider import Provider
-
-
-class LLMProviderService:
-
-    def __init__(self, tenant_id: str, provider_name: str):
-        self.provider = self.init_provider(tenant_id, provider_name)
-
-    def init_provider(self, tenant_id: str, provider_name: str) -> BaseProvider:
-        if provider_name == 'openai':
-            return OpenAIProvider(tenant_id)
-        elif provider_name == 'azure_openai':
-            return AzureProvider(tenant_id)
-        elif provider_name == 'anthropic':
-            return AnthropicProvider(tenant_id)
-        elif provider_name == 'huggingface':
-            return HuggingfaceProvider(tenant_id)
-        else:
-            raise Exception('provider {} not found'.format(provider_name))
-
-    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
-        return self.provider.get_models(model_id)
-
-    def get_credentials(self, model_id: Optional[str] = None) -> dict:
-        return self.provider.get_credentials(model_id)
-
-    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
-        return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom)
-
-    def get_provider_db_record(self) -> Optional[Provider]:
-        return self.provider.get_provider()
-
-    def config_validate(self, config: Union[dict | str]):
-        """
-        Validates the given config.
-
-        :param config:
-        :raises: ValidateFailedError
-        """
-        return self.provider.config_validate(config)
-
-    def get_token_type(self):
-        return self.provider.get_token_type()
-
-    def get_encrypted_token(self, config: Union[dict | str]):
-        return self.provider.get_encrypted_token(config)
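
For comparison with the new factories added below, a sketch of the call pattern this removed service supported (tenant_id is a placeholder):

    service = LLMProviderService(tenant_id, 'openai')
    credentials = service.get_credentials()                  # e.g. {'openai_api_key': '...'}
    configs = service.get_provider_configs(obfuscated=True)  # masked key, safe to display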

+ 0 - 55
api/core/llm/provider/openai_provider.py

@@ -1,55 +0,0 @@
-import logging
-from typing import Optional, Union
-
-import openai
-from openai.error import AuthenticationError, OpenAIError
-
-from core import hosted_llm_credentials
-from core.llm.error import ProviderTokenNotInitError
-from core.llm.moderation import Moderation
-from core.llm.provider.base import BaseProvider
-from core.llm.provider.errors import ValidateFailedError
-from models.provider import ProviderName
-
-
-class OpenAIProvider(BaseProvider):
-    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
-        credentials = self.get_credentials(model_id)
-        response = openai.Model.list(**credentials)
-
-        return [{
-            'id': model['id'],
-            'name': model['id'],
-        } for model in response['data']]
-
-    def get_credentials(self, model_id: Optional[str] = None) -> dict:
-        """
-        Returns the credentials for the given tenant_id and provider_name.
-        """
-        return {
-            'openai_api_key': self.get_provider_api_key(model_id=model_id)
-        }
-
-    def get_provider_name(self):
-        return ProviderName.OPENAI
-
-    def config_validate(self, config: Union[dict | str]):
-        """
-        Validates the given config.
-        """
-        try:
-            Moderation(self.get_provider_name().value, config).moderate('test')
-        except (AuthenticationError, OpenAIError) as ex:
-            raise ValidateFailedError(str(ex))
-        except Exception as ex:
-            logging.exception('OpenAI config validation failed')
-            raise ex
-
-    def get_hosted_credentials(self) -> Union[str | dict]:
-        if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
-            raise ProviderTokenNotInitError(
-                f"No valid {self.get_provider_name().value} model provider credentials found. "
-                f"Please go to Settings -> Model Provider to complete your provider credentials."
-            )
-
-        return hosted_llm_credentials.openai.api_key

+ 0 - 62
api/core/llm/streamable_chat_anthropic.py

@@ -1,62 +0,0 @@
-from typing import List, Optional, Any, Dict
-
-from httpx import Timeout
-from langchain.callbacks.manager import Callbacks
-from langchain.chat_models import ChatAnthropic
-from langchain.schema import BaseMessage, LLMResult, SystemMessage, AIMessage, HumanMessage, ChatMessage
-from pydantic import root_validator
-
-from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
-
-
-class StreamableChatAnthropic(ChatAnthropic):
-    """
-    Wrapper around Anthropic's large language model.
-    """
-
-    default_request_timeout: Optional[float] = Timeout(timeout=300.0, connect=5.0)
-
-    @root_validator()
-    def prepare_params(cls, values: Dict) -> Dict:
-        values['model_name'] = values.get('model')
-        values['max_tokens'] = values.get('max_tokens_to_sample')
-        return values
-
-    @handle_anthropic_exceptions
-    def generate(
-            self,
-            messages: List[List[BaseMessage]],
-            stop: Optional[List[str]] = None,
-            callbacks: Callbacks = None,
-            *,
-            tags: Optional[List[str]] = None,
-            metadata: Optional[Dict[str, Any]] = None,
-            **kwargs: Any,
-    ) -> LLMResult:
-        return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs)
-
-    @classmethod
-    def get_kwargs_from_model_params(cls, params: dict):
-        params['model'] = params.get('model_name')
-        del params['model_name']
-
-        params['max_tokens_to_sample'] = params.get('max_tokens')
-        del params['max_tokens']
-
-        del params['frequency_penalty']
-        del params['presence_penalty']
-
-        return params
-
-    def _convert_one_message_to_text(self, message: BaseMessage) -> str:
-        if isinstance(message, ChatMessage):
-            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
-        elif isinstance(message, HumanMessage):
-            message_text = f"{self.HUMAN_PROMPT} {message.content}"
-        elif isinstance(message, AIMessage):
-            message_text = f"{self.AI_PROMPT} {message.content}"
-        elif isinstance(message, SystemMessage):
-            message_text = f"<admin>{message.content}</admin>"
-        else:
-            raise ValueError(f"Got unknown type {message}")
-        return message_text
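
For reference, a sketch of how the removed converter rendered chat history into Anthropic's prompt format (HUMAN_PROMPT and AI_PROMPT are the anthropic SDK constants, conventionally '\n\nHuman:' and '\n\nAssistant:'):

    from langchain.schema import SystemMessage, HumanMessage, AIMessage

    messages = [SystemMessage(content='be terse'),
                HumanMessage(content='ping'),
                AIMessage(content='pong')]
    # applying _convert_one_message_to_text to each yields roughly:
    # "<admin>be terse</admin>\n\nHuman: ping\n\nAssistant: pong"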

+ 0 - 41
api/core/llm/token_calculator.py

@@ -1,41 +0,0 @@
-import decimal
-from typing import Optional
-
-import tiktoken
-
-from core.constant import llm_constant
-
-
-class TokenCalculator:
-    @classmethod
-    def get_num_tokens(cls, model_name: str, text: str):
-        if len(text) == 0:
-            return 0
-
-        enc = tiktoken.encoding_for_model(model_name)
-
-        tokenized_text = enc.encode(text)
-
-        # calculate the number of tokens in the encoded text
-        return len(tokenized_text)
-
-    @classmethod
-    def get_token_price(cls, model_name: str, tokens: int, text_type: Optional[str] = None) -> decimal.Decimal:
-        if model_name in llm_constant.models_by_mode['embedding']:
-            unit_price = llm_constant.model_prices[model_name]['usage']
-        elif text_type == 'prompt':
-            unit_price = llm_constant.model_prices[model_name]['prompt']
-        elif text_type == 'completion':
-            unit_price = llm_constant.model_prices[model_name]['completion']
-        else:
-            raise Exception('Invalid text type')
-
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    @classmethod
-    def get_currency(cls, model_name: str):
-        return llm_constant.model_currency
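
A worked example of the removed pricing math, assuming a unit price of 0.002 USD per 1K prompt tokens:

    import decimal

    tokens = 1234
    unit_price = decimal.Decimal('0.002')  # assumed prompt price per 1K tokens
    tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(
        decimal.Decimal('0.001'), rounding=decimal.ROUND_HALF_UP)      # 1.234
    total_price = (tokens_per_1k * unit_price).quantize(
        decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)  # 0.0024680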

+ 0 - 26
api/core/llm/whisper.py

@@ -1,26 +0,0 @@
-import openai
-
-from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
-from models.provider import ProviderName
-from core.llm.provider.base import BaseProvider
-
-
-class Whisper:
-
-    def __init__(self, provider: BaseProvider):
-        self.provider = provider
-
-        if self.provider.get_provider_name() == ProviderName.OPENAI:
-            self.client = openai.Audio
-            self.credentials = provider.get_credentials()
-
-    @handle_openai_exceptions
-    def transcribe(self, file):
-        return self.client.transcribe(
-            model='whisper-1', 
-            file=file,
-            api_key=self.credentials.get('openai_api_key'),
-            api_base=self.credentials.get('openai_api_base'),
-            api_type=self.credentials.get('openai_api_type'),
-            api_version=self.credentials.get('openai_api_version'),
-        )
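
For reference, the removed class was invoked roughly like this (the provider supplies the decrypted OpenAI credentials; the file path is a placeholder):

    provider = OpenAIProvider(tenant_id)
    transcript = Whisper(provider).transcribe(open('sample.mp3', 'rb'))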

+ 0 - 27
api/core/llm/wrappers/anthropic_wrapper.py

@@ -1,27 +0,0 @@
-import logging
-from functools import wraps
-
-import anthropic
-
-from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
-    LLMBadRequestError
-
-
-def handle_anthropic_exceptions(func):
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        try:
-            return func(*args, **kwargs)
-        except anthropic.APIConnectionError as e:
-            logging.exception("Failed to connect to Anthropic API.")
-            raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}")
-        except anthropic.RateLimitError:
-            raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
-        except anthropic.AuthenticationError as e:
-            raise LLMAuthorizationError(f"Anthropic: {e.message}")
-        except anthropic.BadRequestError as e:
-            raise LLMBadRequestError(f"Anthropic: {e.message}")
-        except anthropic.APIStatusError as e:
-            raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}")
-
-    return wrapper

+ 0 - 31
api/core/llm/wrappers/openai_wrapper.py

@@ -1,31 +0,0 @@
-import logging
-from functools import wraps
-
-import openai
-
-from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
-    LLMBadRequestError
-
-
-def handle_openai_exceptions(func):
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        try:
-            return func(*args, **kwargs)
-        except openai.error.InvalidRequestError as e:
-            logging.exception("Invalid request to OpenAI API.")
-            raise LLMBadRequestError(str(e))
-        except openai.error.APIConnectionError as e:
-            logging.exception("Failed to connect to OpenAI API.")
-            raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
-        except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
-            logging.exception("OpenAI service unavailable.")
-            raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
-        except openai.error.RateLimitError as e:
-            raise LLMRateLimitError(str(e))
-        except openai.error.AuthenticationError as e:
-            raise LLMAuthorizationError(str(e))
-        except openai.error.OpenAIError as e:
-            raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
-
-    return wrapper

+ 12 - 12
api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py

@@ -1,10 +1,10 @@
-from typing import Any, List, Dict, Union
+from typing import Any, List, Dict

 from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel
+from langchain.schema import get_buffer_string, BaseMessage

-from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
-from core.llm.streamable_open_ai import StreamableOpenAI
+from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages
+from core.model_providers.models.llm.base import BaseLLM
 from extensions.ext_database import db
 from models.model import Conversation, Message

@@ -13,7 +13,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
     conversation: Conversation
     human_prefix: str = "Human"
     ai_prefix: str = "Assistant"
-    llm: BaseLanguageModel
+    model_instance: BaseLLM
     memory_key: str = "chat_history"
     max_token_limit: int = 2000
     message_limit: int = 10

@@ -29,23 +29,23 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):

         messages = list(reversed(messages))

-        chat_messages: List[BaseMessage] = []
+        chat_messages: List[PromptMessage] = []
         for message in messages:
-            chat_messages.append(HumanMessage(content=message.query))
-            chat_messages.append(AIMessage(content=message.answer))
+            chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN))
+            chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))

         if not chat_messages:
-            return chat_messages
+            return []

         # prune the chat message if it exceeds the max token limit
-        curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
+        curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
         if curr_buffer_length > self.max_token_limit:
             pruned_memory = []
             while curr_buffer_length > self.max_token_limit and chat_messages:
                 pruned_memory.append(chat_messages.pop(0))
-                curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
+                curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)

-        return chat_messages
+        return to_lc_messages(chat_messages)

     @property
     def memory_variables(self) -> List[str]:

+ 0 - 0
api/core/llm/error.py → api/core/model_providers/error.py


+ 293 - 0
api/core/model_providers/model_factory.py

@@ -0,0 +1,293 @@
+from typing import Optional
+
+from langchain.callbacks.base import Callbacks
+
+from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
+from core.model_providers.model_provider_factory import ModelProviderFactory, DEFAULT_MODELS
+from core.model_providers.models.base import BaseProviderModel
+from core.model_providers.models.embedding.base import BaseEmbedding
+from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.speech2text.base import BaseSpeech2Text
+from extensions.ext_database import db
+from models.provider import TenantDefaultModel
+
+
+class ModelFactory:
+
+    @classmethod
+    def get_text_generation_model_from_model_config(cls, tenant_id: str,
+                                                    model_config: dict,
+                                                    streaming: bool = False,
+                                                    callbacks: Callbacks = None) -> Optional[BaseLLM]:
+        provider_name = model_config.get("provider")
+        model_name = model_config.get("name")
+        completion_params = model_config.get("completion_params", {})
+
+        return cls.get_text_generation_model(
+            tenant_id=tenant_id,
+            model_provider_name=provider_name,
+            model_name=model_name,
+            model_kwargs=ModelKwargs(
+                temperature=completion_params.get('temperature', 0),
+                max_tokens=completion_params.get('max_tokens', 256),
+                top_p=completion_params.get('top_p', 0),
+                frequency_penalty=completion_params.get('frequency_penalty', 0.1),
+                presence_penalty=completion_params.get('presence_penalty', 0.1)
+            ),
+            streaming=streaming,
+            callbacks=callbacks
+        )
+
+    @classmethod
+    def get_text_generation_model(cls,
+                                  tenant_id: str,
+                                  model_provider_name: Optional[str] = None,
+                                  model_name: Optional[str] = None,
+                                  model_kwargs: Optional[ModelKwargs] = None,
+                                  streaming: bool = False,
+                                  callbacks: Callbacks = None) -> Optional[BaseLLM]:
+        """
+        get text generation model.
+
+        :param tenant_id: a string representing the ID of the tenant.
+        :param model_provider_name:
+        :param model_name:
+        :param model_kwargs:
+        :param streaming:
+        :param callbacks:
+        :return:
+        """
+        is_default_model = False
+        if model_provider_name is None and model_name is None:
+            default_model = cls.get_default_model(tenant_id, ModelType.TEXT_GENERATION)
+
+            if not default_model:
+                raise LLMBadRequestError(f"Default model is not available. "
+                                         f"Please configure a Default System Reasoning Model "
+                                         f"in the Settings -> Model Provider.")
+
+            model_provider_name = default_model.provider_name
+            model_name = default_model.model_name
+            is_default_model = True
+
+        # get model provider
+        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
+
+        if not model_provider:
+            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
+
+        # init text generation model
+        model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION)
+
+        try:
+            model_instance = model_class(
+                model_provider=model_provider,
+                name=model_name,
+                model_kwargs=model_kwargs,
+                streaming=streaming,
+                callbacks=callbacks
+            )
+        except LLMBadRequestError as e:
+            if is_default_model:
+                raise LLMBadRequestError(f"Default model {model_name} is not available. "
+                                         f"Please check your model provider credentials.")
+            else:
+                raise e
+
+        if is_default_model:
+            model_instance.deduct_quota = False
+
+        return model_instance
+
+    @classmethod
+    def get_embedding_model(cls,
+                            tenant_id: str,
+                            model_provider_name: Optional[str] = None,
+                            model_name: Optional[str] = None) -> Optional[BaseEmbedding]:
+        """
+        get embedding model.
+
+        :param tenant_id: a string representing the ID of the tenant.
+        :param model_provider_name:
+        :param model_name:
+        :return:
+        """
+        if model_provider_name is None and model_name is None:
+            default_model = cls.get_default_model(tenant_id, ModelType.EMBEDDINGS)
+
+            if not default_model:
+                raise LLMBadRequestError(f"Default model is not available. "
+                                         f"Please configure a Default Embedding Model "
+                                         f"in the Settings -> Model Provider.")
+
+            model_provider_name = default_model.provider_name
+            model_name = default_model.model_name
+
+        # get model provider
+        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
+
+        if not model_provider:
+            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
+
+        # init embedding model
+        model_class = model_provider.get_model_class(model_type=ModelType.EMBEDDINGS)
+        return model_class(
+            model_provider=model_provider,
+            name=model_name
+        )
+
+    @classmethod
+    def get_speech2text_model(cls,
+                              tenant_id: str,
+                              model_provider_name: Optional[str] = None,
+                              model_name: Optional[str] = None) -> Optional[BaseSpeech2Text]:
+        """
+        get speech to text model.
+
+        :param tenant_id: a string representing the ID of the tenant.
+        :param model_provider_name:
+        :param model_name:
+        :return:
+        """
+        if model_provider_name is None and model_name is None:
+            default_model = cls.get_default_model(tenant_id, ModelType.SPEECH_TO_TEXT)
+
+            if not default_model:
+                raise LLMBadRequestError(f"Default model is not available. "
+                                         f"Please configure a Default Speech-to-Text Model "
+                                         f"in the Settings -> Model Provider.")
+
+            model_provider_name = default_model.provider_name
+            model_name = default_model.model_name
+
+        # get model provider
+        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
+
+        if not model_provider:
+            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
+
+        # init speech to text model
+        model_class = model_provider.get_model_class(model_type=ModelType.SPEECH_TO_TEXT)
+        return model_class(
+            model_provider=model_provider,
+            name=model_name
+        )
+
+    @classmethod
+    def get_moderation_model(cls,
+                             tenant_id: str,
+                             model_provider_name: str,
+                             model_name: str) -> Optional[BaseProviderModel]:
+        """
+        get moderation model.
+
+        :param tenant_id: a string representing the ID of the tenant.
+        :param model_provider_name:
+        :param model_name:
+        :return:
+        """
+        # get model provider
+        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
+
+        if not model_provider:
+            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
+
+        # init moderation model
+        model_class = model_provider.get_model_class(model_type=ModelType.MODERATION)
+        return model_class(
+            model_provider=model_provider,
+            name=model_name
+        )
+
+    @classmethod
+    def get_default_model(cls, tenant_id: str, model_type: ModelType) -> TenantDefaultModel:
+        """
+        get default model of model type.
+
+        :param tenant_id:
+        :param model_type:
+        :return:
+        """
+        # get default model
+        default_model = db.session.query(TenantDefaultModel) \
+            .filter(
+            TenantDefaultModel.tenant_id == tenant_id,
+            TenantDefaultModel.model_type == model_type.value
+        ).first()
+
+        if not default_model:
+            model_provider_rules = ModelProviderFactory.get_provider_rules()
+            for model_provider_name, model_provider_rule in model_provider_rules.items():
+                model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
+                if not model_provider:
+                    continue
+
+                model_list = model_provider.get_supported_model_list(model_type)
+                if model_list:
+                    model_info = model_list[0]
+                    default_model = TenantDefaultModel(
+                        tenant_id=tenant_id,
+                        model_type=model_type.value,
+                        provider_name=model_provider_name,
+                        model_name=model_info['id']
+                    )
+                    db.session.add(default_model)
+                    db.session.commit()
+                    break
+
+        return default_model
+
+    @classmethod
+    def update_default_model(cls,
+                             tenant_id: str,
+                             model_type: ModelType,
+                             provider_name: str,
+                             model_name: str) -> TenantDefaultModel:
+        """
+        update default model of model type.
+
+        :param tenant_id:
+        :param model_type:
+        :param provider_name:
+        :param model_name:
+        :return:
+        """
+        model_provider_names = ModelProviderFactory.get_provider_names()
+        if provider_name not in model_provider_names:
+            raise ValueError(f'Invalid provider name: {provider_name}')
+
+        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name)
+
+        if not model_provider:
+            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
+
+        model_list = model_provider.get_supported_model_list(model_type)
+        model_ids = [model['id'] for model in model_list]
+        if model_name not in model_ids:
+            raise ValueError(f'Invalid model name: {model_name}')
+
+        # get default model
+        default_model = db.session.query(TenantDefaultModel) \
+            .filter(
+            TenantDefaultModel.tenant_id == tenant_id,
+            TenantDefaultModel.model_type == model_type.value
+        ).first()
+
+        if default_model:
+            # update default model
+            default_model.provider_name = provider_name
+            default_model.model_name = model_name
+            db.session.commit()
+        else:
+            # create default model
+            default_model = TenantDefaultModel(
+                tenant_id=tenant_id,
+                model_type=model_type.value,
+                provider_name=provider_name,
+                model_name=model_name,
+            )
+            db.session.add(default_model)
+            db.session.commit()
+
+        return default_model
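
A minimal usage sketch of the new factory, assuming a valid tenant_id and configured OpenAI credentials:

    from core.model_providers.model_factory import ModelFactory
    from core.model_providers.models.entity.model_params import ModelKwargs

    # resolves the tenant's preferred provider and wraps the model behind BaseLLM
    llm = ModelFactory.get_text_generation_model(
        tenant_id=tenant_id,  # placeholder: current workspace ID
        model_provider_name='openai',
        model_name='gpt-3.5-turbo',
        model_kwargs=ModelKwargs(
            temperature=0,
            max_tokens=256,
            top_p=0,
            frequency_penalty=0.1,
            presence_penalty=0.1
        )
    )

    # omitting both provider and model name falls back to the tenant's default model
    default_llm = ModelFactory.get_text_generation_model(tenant_id=tenant_id)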

+ 228 - 0
api/core/model_providers/model_provider_factory.py

@@ -0,0 +1,228 @@
+from typing import Type
+
+from sqlalchemy.exc import IntegrityError
+
+from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.providers.base import BaseModelProvider
+from core.model_providers.rules import provider_rules
+from extensions.ext_database import db
+from models.provider import TenantPreferredModelProvider, ProviderType, Provider, ProviderQuotaType
+
+DEFAULT_MODELS = {
+    ModelType.TEXT_GENERATION.value: {
+        'provider_name': 'openai',
+        'model_name': 'gpt-3.5-turbo',
+    },
+    ModelType.EMBEDDINGS.value: {
+        'provider_name': 'openai',
+        'model_name': 'text-embedding-ada-002',
+    },
+    ModelType.SPEECH_TO_TEXT.value: {
+        'provider_name': 'openai',
+        'model_name': 'whisper-1',
+    }
+}
+
+
+class ModelProviderFactory:
+    @classmethod
+    def get_model_provider_class(cls, provider_name: str) -> Type[BaseModelProvider]:
+        if provider_name == 'openai':
+            from core.model_providers.providers.openai_provider import OpenAIProvider
+            return OpenAIProvider
+        elif provider_name == 'anthropic':
+            from core.model_providers.providers.anthropic_provider import AnthropicProvider
+            return AnthropicProvider
+        elif provider_name == 'minimax':
+            from core.model_providers.providers.minimax_provider import MinimaxProvider
+            return MinimaxProvider
+        elif provider_name == 'spark':
+            from core.model_providers.providers.spark_provider import SparkProvider
+            return SparkProvider
+        elif provider_name == 'tongyi':
+            from core.model_providers.providers.tongyi_provider import TongyiProvider
+            return TongyiProvider
+        elif provider_name == 'wenxin':
+            from core.model_providers.providers.wenxin_provider import WenxinProvider
+            return WenxinProvider
+        elif provider_name == 'chatglm':
+            from core.model_providers.providers.chatglm_provider import ChatGLMProvider
+            return ChatGLMProvider
+        elif provider_name == 'azure_openai':
+            from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
+            return AzureOpenAIProvider
+        elif provider_name == 'replicate':
+            from core.model_providers.providers.replicate_provider import ReplicateProvider
+            return ReplicateProvider
+        elif provider_name == 'huggingface_hub':
+            from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
+            return HuggingfaceHubProvider
+        else:
+            raise NotImplementedError
+
+    @classmethod
+    def get_provider_names(cls):
+        """
+        Returns a list of provider names.
+        """
+        return list(provider_rules.keys())
+
+    @classmethod
+    def get_provider_rules(cls):
+        """
+        Returns the provider rules dict, keyed by provider name.
+
+        :return:
+        """
+        return provider_rules
+
+    @classmethod
+    def get_provider_rule(cls, provider_name: str):
+        """
+        Returns provider rule.
+        """
+        return provider_rules[provider_name]
+
+    @classmethod
+    def get_preferred_model_provider(cls, tenant_id: str, model_provider_name: str):
+        """
+        get preferred model provider.
+
+        :param tenant_id: a string representing the ID of the tenant.
+        :param model_provider_name:
+        :return:
+        """
+        # get preferred provider
+        preferred_provider = cls._get_preferred_provider(tenant_id, model_provider_name)
+        if not preferred_provider or not preferred_provider.is_valid:
+            return None
+
+        # init model provider
+        model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)
+        return model_provider_class(provider=preferred_provider)
+
+    @classmethod
+    def get_preferred_type_by_preferred_model_provider(cls,
+                                                       tenant_id: str,
+                                                       model_provider_name: str,
+                                                       preferred_model_provider: TenantPreferredModelProvider):
+        """
+        get preferred provider type by preferred model provider.
+
+        :param tenant_id:
+        :param model_provider_name:
+        :param preferred_model_provider:
+        :return:
+        """
+        if not preferred_model_provider:
+            model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
+            support_provider_types = model_provider_rules['support_provider_types']
+
+            if ProviderType.CUSTOM.value in support_provider_types:
+                custom_provider = db.session.query(Provider) \
+                    .filter(
+                        Provider.tenant_id == tenant_id,
+                        Provider.provider_name == model_provider_name,
+                        Provider.provider_type == ProviderType.CUSTOM.value,
+                        Provider.is_valid == True
+                    ).first()
+
+                if custom_provider:
+                    return ProviderType.CUSTOM.value
+
+            model_provider = cls.get_model_provider_class(model_provider_name)
+
+            if ProviderType.SYSTEM.value in support_provider_types \
+                    and model_provider.is_provider_type_system_supported():
+                return ProviderType.SYSTEM.value
+            elif ProviderType.CUSTOM.value in support_provider_types:
+                return ProviderType.CUSTOM.value
+        else:
+            return preferred_model_provider.preferred_provider_type
+
+    @classmethod
+    def _get_preferred_provider(cls, tenant_id: str, model_provider_name: str):
+        """
+        get preferred provider of tenant.
+
+        :param tenant_id:
+        :param model_provider_name:
+        :return:
+        """
+        # get preferred provider type
+        preferred_provider_type = cls._get_preferred_provider_type(tenant_id, model_provider_name)
+
+        # get providers by preferred provider type
+        providers = db.session.query(Provider) \
+            .filter(
+                Provider.tenant_id == tenant_id,
+                Provider.provider_name == model_provider_name,
+                Provider.provider_type == preferred_provider_type
+            ).all()
+
+        no_system_provider = False
+        if preferred_provider_type == ProviderType.SYSTEM.value:
+            quota_type_to_provider_dict = {}
+            for provider in providers:
+                quota_type_to_provider_dict[provider.quota_type] = provider
+
+            model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
+            for quota_type_enum in ProviderQuotaType:
+                quota_type = quota_type_enum.value
+                if quota_type in model_provider_rules['system_config']['supported_quota_types'] \
+                        and quota_type in quota_type_to_provider_dict.keys():
+                    provider = quota_type_to_provider_dict[quota_type]
+                    if provider.is_valid and provider.quota_limit > provider.quota_used:
+                        return provider
+
+            no_system_provider = True
+
+        if no_system_provider:
+            providers = db.session.query(Provider) \
+                .filter(
+                Provider.tenant_id == tenant_id,
+                Provider.provider_name == model_provider_name,
+                Provider.provider_type == ProviderType.CUSTOM.value
+            ).all()
+
+        if preferred_provider_type == ProviderType.CUSTOM.value or no_system_provider:
+            if providers:
+                return providers[0]
+            else:
+                try:
+                    provider = Provider(
+                        tenant_id=tenant_id,
+                        provider_name=model_provider_name,
+                        provider_type=ProviderType.CUSTOM.value,
+                        is_valid=False
+                    )
+                    db.session.add(provider)
+                    db.session.commit()
+                except IntegrityError:
+                    db.session.rollback()
+                    provider = db.session.query(Provider) \
+                        .filter(
+                            Provider.tenant_id == tenant_id,
+                            Provider.provider_name == model_provider_name,
+                            Provider.provider_type == ProviderType.CUSTOM.value
+                        ).first()
+
+                return provider
+
+        return None
+
+    @classmethod
+    def _get_preferred_provider_type(cls, tenant_id: str, model_provider_name: str):
+        """
+        get preferred provider type of tenant.
+
+        :param tenant_id:
+        :param model_provider_name:
+        :return:
+        """
+        preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
+            .filter(
+            TenantPreferredModelProvider.tenant_id == tenant_id,
+            TenantPreferredModelProvider.provider_name == model_provider_name
+        ).first()
+
+        return cls.get_preferred_type_by_preferred_model_provider(tenant_id, model_provider_name, preferred_model_provider)
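
A sketch of how callers resolve a tenant's preferred provider and its model class (tenant_id is a placeholder; the error mirrors the callers in model_factory.py above):

    from core.model_providers.error import ProviderTokenNotInitError
    from core.model_providers.model_provider_factory import ModelProviderFactory
    from core.model_providers.models.entity.model_params import ModelType

    provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, 'openai')
    if provider is None:
        # no valid SYSTEM or CUSTOM provider is configured for this tenant
        raise ProviderTokenNotInitError("openai provider credentials are not initialized.")

    model_class = provider.get_model_class(model_type=ModelType.TEXT_GENERATION)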

+ 0 - 0
api/tests/test_libs/__init__.py → api/core/model_providers/models/__init__.py


+ 22 - 0
api/core/model_providers/models/base.py

@@ -0,0 +1,22 @@
+from abc import ABC
+from typing import Any
+
+from core.model_providers.providers.base import BaseModelProvider
+
+
+class BaseProviderModel(ABC):
+    _client: Any
+    _model_provider: BaseModelProvider
+
+    def __init__(self, model_provider: BaseModelProvider, client: Any):
+        self._model_provider = model_provider
+        self._client = client
+
+    @property
+    def client(self):
+        return self._client
+
+    @property
+    def model_provider(self):
+        return self._model_provider
+

+ 0 - 0
api/tests/test_models/__init__.py → api/core/model_providers/models/embedding/__init__.py


+ 78 - 0
api/core/model_providers/models/embedding/azure_openai_embedding.py

@@ -0,0 +1,78 @@
+import decimal
+import logging
+
+import openai
+import tiktoken
+from langchain.embeddings import OpenAIEmbeddings
+
+from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMRateLimitError, \
+    LLMAPIUnavailableError, LLMAPIConnectionError
+from core.model_providers.models.embedding.base import BaseEmbedding
+from core.model_providers.providers.base import BaseModelProvider
+
+AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
+
+
+class AzureOpenAIEmbedding(BaseEmbedding):
+    def __init__(self, model_provider: BaseModelProvider, name: str):
+        self.credentials = model_provider.get_model_credentials(
+            model_name=name,
+            model_type=self.type
+        )
+
+        client = OpenAIEmbeddings(
+            deployment=name,
+            openai_api_type='azure',
+            openai_api_version=AZURE_OPENAI_API_VERSION,
+            chunk_size=16,
+            max_retries=1,
+            **self.credentials
+        )
+
+        super().__init__(model_provider, client, name)
+
+    def get_num_tokens(self, text: str) -> int:
+        """
+        get num tokens of text.
+
+        :param text:
+        :return:
+        """
+        if len(text) == 0:
+            return 0
+
+        enc = tiktoken.encoding_for_model(self.credentials.get('base_model_name'))
+
+        tokenized_text = enc.encode(text)
+
+        # calculate the number of tokens in the encoded text
+        return len(tokenized_text)
+
+    def get_token_price(self, tokens: int):
+        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
+                                                                  rounding=decimal.ROUND_HALF_UP)
+
+        total_price = tokens_per_1k * decimal.Decimal('0.0001')
+        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
+
+    def get_currency(self):
+        return 'USD'
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, openai.error.InvalidRequestError):
+            logging.warning("Invalid request to Azure OpenAI API.")
+            return LLMBadRequestError(str(ex))
+        elif isinstance(ex, openai.error.APIConnectionError):
+            logging.warning("Failed to connect to Azure OpenAI API.")
+            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
+            logging.warning("Azure OpenAI service unavailable.")
+            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, openai.error.RateLimitError):
+            return LLMRateLimitError('Azure ' + str(ex))
+        elif isinstance(ex, openai.error.AuthenticationError):
+            return LLMAuthorizationError('Azure ' + str(ex))
+        elif isinstance(ex, openai.error.OpenAIError):
+            return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
+        else:
+            return ex

+ 40 - 0
api/core/model_providers/models/embedding/base.py

@@ -0,0 +1,40 @@
+from abc import abstractmethod
+from typing import Any
+
+import tiktoken
+from langchain.schema.language_model import _get_token_ids_default_method
+
+from core.model_providers.models.base import BaseProviderModel
+from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.providers.base import BaseModelProvider
+
+
+class BaseEmbedding(BaseProviderModel):
+    name: str
+    type: ModelType = ModelType.EMBEDDINGS
+
+    def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
+        super().__init__(model_provider, client)
+        self.name = name
+
+    def get_num_tokens(self, text: str) -> int:
+        """
+        get num tokens of text.
+
+        :param text:
+        :return:
+        """
+        if len(text) == 0:
+            return 0
+
+        return len(_get_token_ids_default_method(text))
+
+    def get_token_price(self, tokens: int):
+        return 0
+
+    def get_currency(self):
+        return 'USD'
+
+    @abstractmethod
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        raise NotImplementedError

+ 35 - 0
api/core/model_providers/models/embedding/minimax_embedding.py

@@ -0,0 +1,35 @@
+import decimal
+import logging
+
+from langchain.embeddings import MiniMaxEmbeddings
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.models.embedding.base import BaseEmbedding
+from core.model_providers.providers.base import BaseModelProvider
+
+
+class MinimaxEmbedding(BaseEmbedding):
+    def __init__(self, model_provider: BaseModelProvider, name: str):
+        credentials = model_provider.get_model_credentials(
+            model_name=name,
+            model_type=self.type
+        )
+
+        client = MiniMaxEmbeddings(
+            model=name,
+            **credentials
+        )
+
+        super().__init__(model_provider, client, name)
+
+    def get_token_price(self, tokens: int):
+        return decimal.Decimal('0')
+
+    def get_currency(self):
+        return 'RMB'
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, ValueError):
+            return LLMBadRequestError(f"Minimax: {str(ex)}")
+        else:
+            return ex

+ 72 - 0
api/core/model_providers/models/embedding/openai_embedding.py

@@ -0,0 +1,72 @@
+import decimal
+import logging
+
+import openai
+import tiktoken
+from langchain.embeddings import OpenAIEmbeddings
+
+from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
+    LLMRateLimitError, LLMAuthorizationError
+from core.model_providers.models.embedding.base import BaseEmbedding
+from core.model_providers.providers.base import BaseModelProvider
+
+
+class OpenAIEmbedding(BaseEmbedding):
+    def __init__(self, model_provider: BaseModelProvider, name: str):
+        credentials = model_provider.get_model_credentials(
+            model_name=name,
+            model_type=self.type
+        )
+
+        client = OpenAIEmbeddings(
+            max_retries=1,
+            **credentials
+        )
+
+        super().__init__(model_provider, client, name)
+
+    def get_num_tokens(self, text: str) -> int:
+        """
+        get num tokens of text.
+
+        :param text:
+        :return:
+        """
+        if len(text) == 0:
+            return 0
+
+        enc = tiktoken.encoding_for_model(self.name)
+
+        tokenized_text = enc.encode(text)
+
+        # calculate the number of tokens in the encoded text
+        return len(tokenized_text)
+
+    def get_token_price(self, tokens: int):
+        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
+                                                                  rounding=decimal.ROUND_HALF_UP)
+
+        total_price = tokens_per_1k * decimal.Decimal('0.0001')
+        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
+
+    def get_currency(self):
+        return 'USD'
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, openai.error.InvalidRequestError):
+            logging.warning("Invalid request to OpenAI API.")
+            return LLMBadRequestError(str(ex))
+        elif isinstance(ex, openai.error.APIConnectionError):
+            logging.warning("Failed to connect to OpenAI API.")
+            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
+            logging.warning("OpenAI service unavailable.")
+            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, openai.error.RateLimitError):
+            return LLMRateLimitError(str(ex))
+        elif isinstance(ex, openai.error.AuthenticationError):
+            return LLMAuthorizationError(str(ex))
+        elif isinstance(ex, openai.error.OpenAIError):
+            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
+        else:
+            return ex
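
A sketch of how these handle_exceptions hooks are presumably consumed: the caller invokes the wrapped client and re-raises whatever mapped exception the hook returns (embed_documents is the LangChain OpenAIEmbeddings method held in self.client):

    try:
        vectors = embedding.client.embed_documents(['some text'])
    except Exception as ex:
        raise embedding.handle_exceptions(ex)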

+ 36 - 0
api/core/model_providers/models/embedding/replicate_embedding.py

@@ -0,0 +1,36 @@
+import decimal
+
+from replicate.exceptions import ModelError, ReplicateError
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.providers.base import BaseModelProvider
+from core.third_party.langchain.embeddings.replicate_embedding import ReplicateEmbeddings
+from core.model_providers.models.embedding.base import BaseEmbedding
+
+
+class ReplicateEmbedding(BaseEmbedding):
+    def __init__(self, model_provider: BaseModelProvider, name: str):
+        credentials = model_provider.get_model_credentials(
+            model_name=name,
+            model_type=self.type
+        )
+
+        client = ReplicateEmbeddings(
+            model=name + ':' + credentials.get('model_version'),
+            replicate_api_token=credentials.get('replicate_api_token')
+        )
+
+        super().__init__(model_provider, client, name)
+
+    def get_token_price(self, tokens: int):
+        # Replicate bills by prediction time (seconds), so token usage is not priced
+        return decimal.Decimal('0')
+
+    def get_currency(self):
+        return 'USD'
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, (ModelError, ReplicateError)):
+            return LLMBadRequestError(f"Replicate: {str(ex)}")
+        else:
+            return ex

+ 0 - 0
api/tests/test_services/__init__.py → api/core/model_providers/models/entity/__init__.py


+ 53 - 0
api/core/model_providers/models/entity/message.py

@@ -0,0 +1,53 @@
+import enum
+
+from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
+from pydantic import BaseModel
+
+
+class LLMRunResult(BaseModel):
+    content: str
+    prompt_tokens: int
+    completion_tokens: int
+
+
+class MessageType(enum.Enum):
+    HUMAN = 'human'
+    ASSISTANT = 'assistant'
+    SYSTEM = 'system'
+
+
+class PromptMessage(BaseModel):
+    type: MessageType = MessageType.HUMAN
+    content: str = ''
+
+
+def to_lc_messages(messages: list[PromptMessage]):
+    lc_messages = []
+    for message in messages:
+        if message.type == MessageType.HUMAN:
+            lc_messages.append(HumanMessage(content=message.content))
+        elif message.type == MessageType.ASSISTANT:
+            lc_messages.append(AIMessage(content=message.content))
+        elif message.type == MessageType.SYSTEM:
+            lc_messages.append(SystemMessage(content=message.content))
+
+    return lc_messages
+
+
+def to_prompt_messages(messages: list[BaseMessage]):
+    prompt_messages = []
+    for message in messages:
+        if isinstance(message, HumanMessage):
+            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
+        elif isinstance(message, AIMessage):
+            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
+        elif isinstance(message, SystemMessage):
+            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
+    return prompt_messages
+
+
+def str_to_prompt_messages(texts: list[str]):
+    prompt_messages = []
+    for text in texts:
+        prompt_messages.append(PromptMessage(content=text))
+    return prompt_messages
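
A short usage sketch of the conversion helpers above; it assumes the package is importable from the repository root and that langchain is installed:

```python
from core.model_providers.models.entity.message import (
    MessageType, PromptMessage, to_lc_messages, to_prompt_messages)

prompts = [
    PromptMessage(content='You are a helpful assistant.', type=MessageType.SYSTEM),
    PromptMessage(content='Hello!'),  # type defaults to MessageType.HUMAN
]

lc_messages = to_lc_messages(prompts)         # [SystemMessage(...), HumanMessage(...)]
round_trip = to_prompt_messages(lc_messages)  # back to PromptMessage objects
assert [m.content for m in round_trip] == [m.content for m in prompts]
```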

+ 59 - 0
api/core/model_providers/models/entity/model_params.py

@@ -0,0 +1,59 @@
+import enum
+from typing import Optional, TypeVar, Generic
+
+from langchain.load.serializable import Serializable
+from pydantic import BaseModel
+
+
+class ModelMode(enum.Enum):
+    COMPLETION = 'completion'
+    CHAT = 'chat'
+
+
+class ModelType(enum.Enum):
+    TEXT_GENERATION = 'text-generation'
+    EMBEDDINGS = 'embeddings'
+    SPEECH_TO_TEXT = 'speech2text'
+    IMAGE = 'image'
+    VIDEO = 'video'
+    MODERATION = 'moderation'
+
+    @staticmethod
+    def value_of(value):
+        for member in ModelType:
+            if member.value == value:
+                return member
+        raise ValueError(f"No matching enum found for value '{value}'")
+
+
+class ModelKwargs(BaseModel):
+    max_tokens: Optional[int]
+    temperature: Optional[float]
+    top_p: Optional[float]
+    presence_penalty: Optional[float]
+    frequency_penalty: Optional[float]
+
+
+class KwargRuleType(enum.Enum):
+    STRING = 'string'
+    INTEGER = 'integer'
+    FLOAT = 'float'
+
+
+T = TypeVar('T')
+
+
+class KwargRule(Generic[T], BaseModel):
+    enabled: bool = True
+    min: Optional[T] = None
+    max: Optional[T] = None
+    default: Optional[T] = None
+    alias: Optional[str] = None
+
+
+class ModelKwargsRules(BaseModel):
+    max_tokens: KwargRule = KwargRule[int](enabled=False)
+    temperature: KwargRule = KwargRule[float](enabled=False)
+    top_p: KwargRule = KwargRule[float](enabled=False)
+    presence_penalty: KwargRule = KwargRule[float](enabled=False)
+    frequency_penalty: KwargRule = KwargRule[float](enabled=False)
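
A hedged sketch of how a provider could declare its parameter rules with these types; the concrete bounds, defaults, and the `max_new_tokens` alias below are illustrative, not taken from any provider in this commit:

```python
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules

example_rules = ModelKwargsRules(
    # Renamed via `alias` to the argument name the provider actually expects.
    max_tokens=KwargRule[int](enabled=True, min=1, max=4096, default=256, alias='max_new_tokens'),
    temperature=KwargRule[float](enabled=True, min=0.0, max=1.0, default=0.7),
    top_p=KwargRule[float](enabled=True, min=0.0, max=1.0, default=1.0),
    # presence_penalty / frequency_penalty keep the disabled defaults and are dropped.
)
```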

+ 10 - 0
api/core/model_providers/models/entity/provider.py

@@ -0,0 +1,10 @@
+from enum import Enum
+
+
+class ProviderQuotaUnit(Enum):
+    TIMES = 'times'
+    TOKENS = 'tokens'
+
+
+class ModelFeature(Enum):
+    AGENT_THOUGHT = 'agent_thought'

+ 0 - 0
api/core/model_providers/models/llm/__init__.py


+ 107 - 0
api/core/model_providers/models/llm/anthropic_model.py

@@ -0,0 +1,107 @@
+import decimal
+import logging
+from typing import List, Optional, Any
+
+import anthropic
+from langchain.callbacks.manager import Callbacks
+from langchain.chat_models import ChatAnthropic
+from langchain.schema import LLMResult
+
+from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
+    LLMRateLimitError, LLMAuthorizationError
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+
+
+class AnthropicModel(BaseLLM):
+    model_mode: ModelMode = ModelMode.CHAT
+
+    def _init_client(self) -> Any:
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+        return ChatAnthropic(
+            model=self.name,
+            streaming=self.streaming,
+            callbacks=self.callbacks,
+            default_request_timeout=60,
+            **self.credentials,
+            **provider_model_kwargs
+        )
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        Run prediction using the given prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.generate([prompts], stop, callbacks)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        Get the number of tokens in the given prompt messages.
+
+        :param messages:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
+
+    def get_token_price(self, tokens: int, message_type: MessageType):
+        model_unit_prices = {
+            'claude-instant-1': {
+                'prompt': decimal.Decimal('1.63'),
+                'completion': decimal.Decimal('5.51'),
+            },
+            'claude-2': {
+                'prompt': decimal.Decimal('11.02'),
+                'completion': decimal.Decimal('32.68'),
+            },
+        }
+
+        if message_type in (MessageType.HUMAN, MessageType.SYSTEM):
+            unit_price = model_unit_prices[self.name]['prompt']
+        else:
+            unit_price = model_unit_prices[self.name]['completion']
+
+        tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
+                                                                     rounding=decimal.ROUND_HALF_UP)
+
+        total_price = tokens_per_1m * unit_price
+        return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
+
+    def get_currency(self):
+        return 'USD'
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
+        for k, v in provider_model_kwargs.items():
+            if hasattr(self.client, k):
+                setattr(self.client, k, v)
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, anthropic.APIConnectionError):
+            logging.warning("Failed to connect to Anthropic API.")
+            return LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {ex.__cause__}")
+        elif isinstance(ex, anthropic.RateLimitError):
+            return LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
+        elif isinstance(ex, anthropic.AuthenticationError):
+            return LLMAuthorizationError(f"Anthropic: {ex.message}")
+        elif isinstance(ex, anthropic.BadRequestError):
+            return LLMBadRequestError(f"Anthropic: {ex.message}")
+        elif isinstance(ex, anthropic.APIStatusError):
+            return LLMAPIUnavailableError(f"Anthropic: code: {ex.status_code}, cause: {ex.message}")
+        else:
+            return ex
+
+    @classmethod
+    def support_streaming(cls):
+        return True
+
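
Note that these unit prices are per million tokens (the divisor above is 1,000,000, unlike the per-1k tables elsewhere in this commit). A worked example using the claude-instant-1 prompt rate from the table:

```python
import decimal

tokens = 12345
unit_price = decimal.Decimal('1.63')  # claude-instant-1 prompt price per 1M tokens

tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(
    decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP)
total = (tokens_per_1m * unit_price).quantize(
    decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
print(total)  # 0.02012235 (USD)
```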

+ 177 - 0
api/core/model_providers/models/llm/azure_openai_model.py

@@ -0,0 +1,177 @@
+import decimal
+import logging
+from typing import List, Optional, Any
+
+import openai
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import LLMResult
+
+from core.model_providers.providers.base import BaseModelProvider
+from core.third_party.langchain.llms.azure_chat_open_ai import EnhanceAzureChatOpenAI
+from core.third_party.langchain.llms.azure_open_ai import EnhanceAzureOpenAI
+from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
+    LLMRateLimitError, LLMAuthorizationError
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+
+AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
+
+
+class AzureOpenAIModel(BaseLLM):
+    def __init__(self, model_provider: BaseModelProvider,
+                 name: str,
+                 model_kwargs: ModelKwargs,
+                 streaming: bool = False,
+                 callbacks: Callbacks = None):
+        if name == 'text-davinci-003':
+            self.model_mode = ModelMode.COMPLETION
+        else:
+            self.model_mode = ModelMode.CHAT
+
+        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
+
+    def _init_client(self) -> Any:
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+        if self.name == 'text-davinci-003':
+            client = EnhanceAzureOpenAI(
+                deployment_name=self.name,
+                streaming=self.streaming,
+                request_timeout=60,
+                openai_api_type='azure',
+                openai_api_version=AZURE_OPENAI_API_VERSION,
+                openai_api_key=self.credentials.get('openai_api_key'),
+                openai_api_base=self.credentials.get('openai_api_base'),
+                callbacks=self.callbacks,
+                **provider_model_kwargs
+            )
+        else:
+            extra_model_kwargs = {
+                'top_p': provider_model_kwargs.get('top_p'),
+                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
+                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
+            }
+
+            client = EnhanceAzureChatOpenAI(
+                deployment_name=self.name,
+                temperature=provider_model_kwargs.get('temperature'),
+                max_tokens=provider_model_kwargs.get('max_tokens'),
+                model_kwargs=extra_model_kwargs,
+                streaming=self.streaming,
+                request_timeout=60,
+                openai_api_type='azure',
+                openai_api_version=AZURE_OPENAI_API_VERSION,
+                openai_api_key=self.credentials.get('openai_api_key'),
+                openai_api_base=self.credentials.get('openai_api_base'),
+                callbacks=self.callbacks,
+            )
+
+        return client
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        Run prediction using the given prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.generate([prompts], stop, callbacks)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        Get the number of tokens in the given prompt messages.
+
+        :param messages:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        if isinstance(prompts, str):
+            return self._client.get_num_tokens(prompts)
+        else:
+            return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
+
+    def get_token_price(self, tokens: int, message_type: MessageType):
+        model_unit_prices = {
+            'gpt-4': {
+                'prompt': decimal.Decimal('0.03'),
+                'completion': decimal.Decimal('0.06'),
+            },
+            'gpt-4-32k': {
+                'prompt': decimal.Decimal('0.06'),
+                'completion': decimal.Decimal('0.12')
+            },
+            'gpt-35-turbo': {
+                'prompt': decimal.Decimal('0.0015'),
+                'completion': decimal.Decimal('0.002')
+            },
+            'gpt-35-turbo-16k': {
+                'prompt': decimal.Decimal('0.003'),
+                'completion': decimal.Decimal('0.004')
+            },
+            'text-davinci-003': {
+                'prompt': decimal.Decimal('0.02'),
+                'completion': decimal.Decimal('0.02')
+            },
+        }
+
+        base_model_name = self.credentials.get("base_model_name")
+        if message_type in (MessageType.HUMAN, MessageType.SYSTEM):
+            unit_price = model_unit_prices[base_model_name]['prompt']
+        else:
+            unit_price = model_unit_prices[base_model_name]['completion']
+
+        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
+                                                                  rounding=decimal.ROUND_HALF_UP)
+
+        total_price = tokens_per_1k * unit_price
+        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
+
+    def get_currency(self):
+        return 'USD'
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
+        if self.name == 'text-davinci-003':
+            for k, v in provider_model_kwargs.items():
+                if hasattr(self.client, k):
+                    setattr(self.client, k, v)
+        else:
+            extra_model_kwargs = {
+                'top_p': provider_model_kwargs.get('top_p'),
+                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
+                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
+            }
+
+            self.client.temperature = provider_model_kwargs.get('temperature')
+            self.client.max_tokens = provider_model_kwargs.get('max_tokens')
+            self.client.model_kwargs = extra_model_kwargs
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, openai.error.InvalidRequestError):
+            logging.warning("Invalid request to Azure OpenAI API.")
+            return LLMBadRequestError(str(ex))
+        elif isinstance(ex, openai.error.APIConnectionError):
+            logging.warning("Failed to connect to Azure OpenAI API.")
+            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
+            logging.warning("Azure OpenAI service unavailable.")
+            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, openai.error.RateLimitError):
+            return LLMRateLimitError('Azure ' + str(ex))
+        elif isinstance(ex, openai.error.AuthenticationError):
+            return LLMAuthorizationError('Azure ' + str(ex))
+        elif isinstance(ex, openai.error.OpenAIError):
+            return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
+        else:
+            return ex
+
+    @classmethod
+    def support_streaming(cls):
+        return True
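
One detail worth spelling out: Azure models are addressed by deployment name (`self.name`), which is user-chosen and carries no pricing information, so `get_token_price` keys the table by the `base_model_name` stored in the credentials. A minimal sketch under that assumption (the deployment name below is hypothetical):

```python
import decimal

# Mirrors one row of the price table in get_token_price above.
model_unit_prices = {
    'gpt-35-turbo': {'prompt': decimal.Decimal('0.0015'),
                     'completion': decimal.Decimal('0.002')},
}

deployment_name = 'my-gpt35-deployment'            # hypothetical; irrelevant to pricing
credentials = {'base_model_name': 'gpt-35-turbo'}  # what the lookup actually uses

print(model_unit_prices[credentials['base_model_name']]['prompt'])  # 0.0015
```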

+ 269 - 0
api/core/model_providers/models/llm/base.py

@@ -0,0 +1,269 @@
+from abc import abstractmethod
+from typing import List, Optional, Any, Union
+
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
+
+from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
+from core.model_providers.models.base import BaseProviderModel
+from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult
+from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
+from core.model_providers.providers.base import BaseModelProvider
+from core.third_party.langchain.llms.fake import FakeLLM
+
+
+class BaseLLM(BaseProviderModel):
+    model_mode: ModelMode = ModelMode.COMPLETION
+    name: str
+    model_kwargs: ModelKwargs
+    credentials: dict
+    streaming: bool = False
+    type: ModelType = ModelType.TEXT_GENERATION
+    deduct_quota: bool = True
+
+    def __init__(self, model_provider: BaseModelProvider,
+                 name: str,
+                 model_kwargs: ModelKwargs,
+                 streaming: bool = False,
+                 callbacks: Callbacks = None):
+        self.name = name
+        self.model_rules = model_provider.get_model_parameter_rules(name, self.type)
+        self.model_kwargs = model_kwargs if model_kwargs else ModelKwargs(
+            max_tokens=None,
+            temperature=None,
+            top_p=None,
+            presence_penalty=None,
+            frequency_penalty=None
+        )
+        self.credentials = model_provider.get_model_credentials(
+            model_name=name,
+            model_type=self.type
+        )
+        self.streaming = streaming
+
+        if streaming:
+            default_callback = DifyStreamingStdOutCallbackHandler()
+        else:
+            default_callback = DifyStdOutCallbackHandler()
+
+        if not callbacks:
+            callbacks = [default_callback]
+        else:
+            callbacks.append(default_callback)
+
+        self.callbacks = callbacks
+
+        client = self._init_client()
+        super().__init__(model_provider, client)
+
+    @abstractmethod
+    def _init_client(self) -> Any:
+        raise NotImplementedError
+
+    def run(self, messages: List[PromptMessage],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs) -> LLMRunResult:
+        """
+        Run prediction using the given prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        if self.deduct_quota:
+            self.model_provider.check_quota_over_limit()
+
+        if not callbacks:
+            callbacks = self.callbacks
+        else:
+            callbacks.extend(self.callbacks)
+
+        if kwargs.get('fake_response'):
+            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
+            fake_llm = FakeLLM(
+                response=kwargs['fake_response'],
+                num_token_func=self.get_num_tokens,
+                streaming=self.streaming,
+                callbacks=callbacks
+            )
+            result = fake_llm.generate([prompts])
+        else:
+            try:
+                result = self._run(
+                    messages=messages,
+                    stop=stop,
+                    callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
+                    **kwargs
+                )
+            except Exception as ex:
+                raise self.handle_exceptions(ex)
+
+        if isinstance(result.generations[0][0], ChatGeneration):
+            completion_content = result.generations[0][0].message.content
+        else:
+            completion_content = result.generations[0][0].text
+
+        if self.streaming and not self.support_streaming():
+            # use FakeLLM to simulate streaming when the current model does not support streaming but streaming was requested
+            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
+            fake_llm = FakeLLM(
+                response=completion_content,
+                num_token_func=self.get_num_tokens,
+                streaming=self.streaming,
+                callbacks=callbacks
+            )
+            fake_llm.generate([prompts])
+
+        if result.llm_output and result.llm_output.get('token_usage'):
+            prompt_tokens = result.llm_output['token_usage']['prompt_tokens']
+            completion_tokens = result.llm_output['token_usage']['completion_tokens']
+            total_tokens = result.llm_output['token_usage']['total_tokens']
+        else:
+            prompt_tokens = self.get_num_tokens(messages)
+            completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
+            total_tokens = prompt_tokens + completion_tokens
+
+        if self.deduct_quota:
+            self.model_provider.deduct_quota(total_tokens)
+
+        return LLMRunResult(
+            content=completion_content,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens
+        )
+
+    @abstractmethod
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        Run prediction using the given prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        Get the number of tokens in the given prompt messages.
+
+        :param messages:
+        :return:
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_token_price(self, tokens: int, message_type: MessageType):
+        """
+        Get the price for the given number of tokens.
+
+        :param tokens:
+        :param message_type:
+        :return:
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_currency(self):
+        """
+        Get the pricing currency of the model.
+
+        :return:
+        """
+        raise NotImplementedError
+
+    def get_model_kwargs(self):
+        return self.model_kwargs
+
+    def set_model_kwargs(self, model_kwargs: ModelKwargs):
+        self.model_kwargs = model_kwargs
+        self._set_model_kwargs(model_kwargs)
+
+    @abstractmethod
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        """
+        Handle LLM run exceptions.
+
+        :param ex:
+        :return:
+        """
+        raise NotImplementedError
+
+    def add_callbacks(self, callbacks: Callbacks):
+        """
+        Add callbacks to client.
+
+        :param callbacks:
+        :return:
+        """
+        if not self.client.callbacks:
+            self.client.callbacks = callbacks
+        else:
+            self.client.callbacks.extend(callbacks)
+
+    @classmethod
+    def support_streaming(cls):
+        return False
+
+    def _get_prompt_from_messages(self, messages: List[PromptMessage],
+                                  model_mode: Optional[ModelMode] = None) -> Union[str, List[BaseMessage]]:
+        if len(messages) == 0:
+            raise ValueError("prompt must not be empty.")
+
+        if not model_mode:
+            model_mode = self.model_mode
+
+        if model_mode == ModelMode.COMPLETION:
+            return messages[0].content
+        else:
+            chat_messages = []
+            for message in messages:
+                if message.type == MessageType.HUMAN:
+                    chat_messages.append(HumanMessage(content=message.content))
+                elif message.type == MessageType.ASSISTANT:
+                    chat_messages.append(AIMessage(content=message.content))
+                elif message.type == MessageType.SYSTEM:
+                    chat_messages.append(SystemMessage(content=message.content))
+
+            return chat_messages
+
+    def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
+        """
+        Convert generic model kwargs to provider-specific model kwargs.
+
+        :param model_rules:
+        :param model_kwargs:
+        :return:
+        """
+        model_kwargs_input = {}
+        for key, value in model_kwargs.dict().items():
+            rule = getattr(model_rules, key)
+            if not rule.enabled:
+                continue
+
+            if rule.alias:
+                key = rule.alias
+
+            if value is None:
+                value = rule.default
+
+            if value is not None:
+                if rule.min is not None:
+                    value = max(value, rule.min)
+
+                if rule.max is not None:
+                    value = min(value, rule.max)
+
+            model_kwargs_input[key] = value
+
+        return model_kwargs_input
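
A standalone trace of the rule handling in `_to_model_kwargs_input`: disabled rules are dropped, `None` values fall back to the rule default, values are clamped to `[min, max]`, and aliased keys are renamed. The function below mirrors the method for illustration only; the rule values are made up:

```python
from core.model_providers.models.entity.model_params import (
    KwargRule, ModelKwargs, ModelKwargsRules)

def to_model_kwargs_input(model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
    # Standalone mirror of BaseLLM._to_model_kwargs_input, for illustration only.
    result = {}
    for key, value in model_kwargs.dict().items():
        rule = getattr(model_rules, key)
        if not rule.enabled:
            continue  # disabled parameters are dropped entirely
        if value is None:
            value = rule.default
        if value is not None:
            value = max(value, rule.min) if rule.min is not None else value
            value = min(value, rule.max) if rule.max is not None else value
        result[rule.alias or key] = value
    return result

rules = ModelKwargsRules(
    temperature=KwargRule[float](enabled=True, min=0.0, max=1.0, default=0.7),
    max_tokens=KwargRule[int](enabled=True, min=1, max=4096, default=256, alias='max_new_tokens'),
)
kwargs = ModelKwargs(temperature=2.0, max_tokens=None, top_p=None,
                     presence_penalty=None, frequency_penalty=None)

print(to_model_kwargs_input(rules, kwargs))
# {'max_new_tokens': 256, 'temperature': 1.0}
```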

+ 70 - 0
api/core/model_providers/models/llm/chatglm_model.py

@@ -0,0 +1,70 @@
+import decimal
+from typing import List, Optional, Any
+
+from langchain.callbacks.manager import Callbacks
+from langchain.llms import ChatGLM
+from langchain.schema import LLMResult
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+
+
+class ChatGLMModel(BaseLLM):
+    model_mode: ModelMode = ModelMode.COMPLETION
+
+    def _init_client(self) -> Any:
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+        return ChatGLM(
+            callbacks=self.callbacks,
+            endpoint_url=self.credentials.get('api_base'),
+            **provider_model_kwargs
+        )
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        Run prediction using the given prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.generate([prompts], stop, callbacks)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        Get the number of tokens in the given prompt messages.
+
+        :param messages:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return max(self._client.get_num_tokens(prompts), 0)
+
+    def get_token_price(self, tokens: int, message_type: MessageType):
+        return decimal.Decimal('0')
+
+    def get_currency(self):
+        return 'RMB'
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
+        for k, v in provider_model_kwargs.items():
+            if hasattr(self.client, k):
+                setattr(self.client, k, v)
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, ValueError):
+            return LLMBadRequestError(f"ChatGLM: {str(ex)}")
+        else:
+            return ex
+
+    @classmethod
+    def support_streaming(cls):
+        return False

+ 82 - 0
api/core/model_providers/models/llm/huggingface_hub_model.py

@@ -0,0 +1,82 @@
+import decimal
+from typing import List, Optional, Any
+
+from langchain import HuggingFaceHub
+from langchain.callbacks.manager import Callbacks
+from langchain.llms import HuggingFaceEndpoint
+from langchain.schema import LLMResult
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+
+
+class HuggingfaceHubModel(BaseLLM):
+    model_mode: ModelMode = ModelMode.COMPLETION
+
+    def _init_client(self) -> Any:
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+        if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
+            client = HuggingFaceEndpoint(
+                endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
+                task='text2text-generation',
+                model_kwargs=provider_model_kwargs,
+                huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
+                callbacks=self.callbacks,
+            )
+        else:
+            client = HuggingFaceHub(
+                repo_id=self.name,
+                task=self.credentials['task_type'],
+                model_kwargs=provider_model_kwargs,
+                huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
+                callbacks=self.callbacks,
+            )
+
+        return client
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        Run prediction using the given prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.generate([prompts], stop, callbacks)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        Get the number of tokens in the given prompt messages.
+
+        :param messages:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.get_num_tokens(prompts)
+
+    def get_token_price(self, tokens: int, message_type: MessageType):
+        # price calculation is not supported for Hugging Face Hub models
+        return decimal.Decimal('0')
+
+    def get_currency(self):
+        return 'USD'
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
+        self.client.model_kwargs = provider_model_kwargs
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        return LLMBadRequestError(f"Huggingface Hub: {str(ex)}")
+
+    @classmethod
+    def support_streaming(cls):
+        return False
+

+ 70 - 0
api/core/model_providers/models/llm/minimax_model.py

@@ -0,0 +1,70 @@
+import decimal
+from typing import List, Optional, Any
+
+from langchain.callbacks.manager import Callbacks
+from langchain.llms import Minimax
+from langchain.schema import LLMResult
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+
+
+class MinimaxModel(BaseLLM):
+    model_mode: ModelMode = ModelMode.COMPLETION
+
+    def _init_client(self) -> Any:
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+        return Minimax(
+            model=self.name,
+            model_kwargs={
+                'stream': False
+            },
+            callbacks=self.callbacks,
+            **self.credentials,
+            **provider_model_kwargs
+        )
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        Run prediction using the given prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.generate([prompts], stop, callbacks)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        Get the number of tokens in the given prompt messages.
+
+        :param messages:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return max(self._client.get_num_tokens(prompts), 0)
+
+    def get_token_price(self, tokens: int, message_type: MessageType):
+        return decimal.Decimal('0')
+
+    def get_currency(self):
+        return 'RMB'
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
+        for k, v in provider_model_kwargs.items():
+            if hasattr(self.client, k):
+                setattr(self.client, k, v)
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, ValueError):
+            return LLMBadRequestError(f"Minimax: {str(ex)}")
+        else:
+            return ex

+ 219 - 0
api/core/model_providers/models/llm/openai_model.py

@@ -0,0 +1,219 @@
+import decimal
+import logging
+from typing import List, Optional, Any
+
+import openai
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import LLMResult
+
+from core.model_providers.providers.base import BaseModelProvider
+from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
+from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
+    LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
+from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from models.provider import ProviderType, ProviderQuotaType
+
+COMPLETION_MODELS = [
+    'text-davinci-003',  # 4,097 tokens
+]
+
+CHAT_MODELS = [
+    'gpt-4',  # 8,192 tokens
+    'gpt-4-32k',  # 32,768 tokens
+    'gpt-3.5-turbo',  # 4,096 tokens
+    'gpt-3.5-turbo-16k',  # 16,384 tokens
+]
+
+MODEL_MAX_TOKENS = {
+    'gpt-4': 8192,
+    'gpt-4-32k': 32768,
+    'gpt-3.5-turbo': 4096,
+    'gpt-3.5-turbo-16k': 16384,
+    'text-davinci-003': 4097,
+}
+
+
+class OpenAIModel(BaseLLM):
+    def __init__(self, model_provider: BaseModelProvider,
+                 name: str,
+                 model_kwargs: ModelKwargs,
+                 streaming: bool = False,
+                 callbacks: Callbacks = None):
+        if name in COMPLETION_MODELS:
+            self.model_mode = ModelMode.COMPLETION
+        else:
+            self.model_mode = ModelMode.CHAT
+
+        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
+
+    def _init_client(self) -> Any:
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+        if self.name in COMPLETION_MODELS:
+            client = EnhanceOpenAI(
+                model_name=self.name,
+                streaming=self.streaming,
+                callbacks=self.callbacks,
+                request_timeout=60,
+                **self.credentials,
+                **provider_model_kwargs
+            )
+        else:
+            # Fine-tuning is currently only available for the base models
+            # davinci, curie, babbage, and ada, so apart from the fixed
+            # completion models listed above, all fine-tuned models are
+            # also completion models.
+            extra_model_kwargs = {
+                'top_p': provider_model_kwargs.get('top_p'),
+                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
+                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
+            }
+
+            client = EnhanceChatOpenAI(
+                model_name=self.name,
+                temperature=provider_model_kwargs.get('temperature'),
+                max_tokens=provider_model_kwargs.get('max_tokens'),
+                model_kwargs=extra_model_kwargs,
+                streaming=self.streaming,
+                callbacks=self.callbacks,
+                request_timeout=60,
+                **self.credentials
+            )
+
+        return client
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        Run prediction using the given prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        if self.name == 'gpt-4' \
+                and self.model_provider.provider.provider_type == ProviderType.SYSTEM.value \
+                and self.model_provider.provider.quota_type == ProviderQuotaType.TRIAL.value:
+            raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 is currently not supported.")
+
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.generate([prompts], stop, callbacks)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        Get the number of tokens in the given prompt messages.
+
+        :param messages:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        if isinstance(prompts, str):
+            return self._client.get_num_tokens(prompts)
+        else:
+            return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
+
+    def get_token_price(self, tokens: int, message_type: MessageType):
+        model_unit_prices = {
+            'gpt-4': {
+                'prompt': decimal.Decimal('0.03'),
+                'completion': decimal.Decimal('0.06'),
+            },
+            'gpt-4-32k': {
+                'prompt': decimal.Decimal('0.06'),
+                'completion': decimal.Decimal('0.12')
+            },
+            'gpt-3.5-turbo': {
+                'prompt': decimal.Decimal('0.0015'),
+                'completion': decimal.Decimal('0.002')
+            },
+            'gpt-3.5-turbo-16k': {
+                'prompt': decimal.Decimal('0.003'),
+                'completion': decimal.Decimal('0.004')
+            },
+            'text-davinci-003': {
+                'prompt': decimal.Decimal('0.02'),
+                'completion': decimal.Decimal('0.02')
+            },
+        }
+
+        if message_type in (MessageType.HUMAN, MessageType.SYSTEM):
+            unit_price = model_unit_prices[self.name]['prompt']
+        else:
+            unit_price = model_unit_prices[self.name]['completion']
+
+        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
+                                                                  rounding=decimal.ROUND_HALF_UP)
+
+        total_price = tokens_per_1k * unit_price
+        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
+
+    def get_currency(self):
+        return 'USD'
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
+        if self.name in COMPLETION_MODELS:
+            for k, v in provider_model_kwargs.items():
+                if hasattr(self.client, k):
+                    setattr(self.client, k, v)
+        else:
+            extra_model_kwargs = {
+                'top_p': provider_model_kwargs.get('top_p'),
+                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
+                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
+            }
+
+            self.client.temperature = provider_model_kwargs.get('temperature')
+            self.client.max_tokens = provider_model_kwargs.get('max_tokens')
+            self.client.model_kwargs = extra_model_kwargs
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, openai.error.InvalidRequestError):
+            logging.warning("Invalid request to OpenAI API.")
+            return LLMBadRequestError(str(ex))
+        elif isinstance(ex, openai.error.APIConnectionError):
+            logging.warning("Failed to connect to OpenAI API.")
+            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
+            logging.warning("OpenAI service unavailable.")
+            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, openai.error.RateLimitError):
+            return LLMRateLimitError(str(ex))
+        elif isinstance(ex, openai.error.AuthenticationError):
+            return LLMAuthorizationError(str(ex))
+        elif isinstance(ex, openai.error.OpenAIError):
+            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
+        else:
+            return ex
+
+    @classmethod
+    def support_streaming(cls):
+        return True
+
+    # def is_model_valid_or_raise(self):
+    #     """
+    #     check is a valid model.
+    #
+    #     :return:
+    #     """
+    #     credentials = self._model_provider.get_credentials()
+    #
+    #     try:
+    #         result = openai.Model.retrieve(
+    #             id=self.name,
+    #             api_key=credentials.get('openai_api_key'),
+    #             request_timeout=60
+    #         )
+    #
+    #         if 'id' not in result or result['id'] != self.name:
+    #             raise LLMNotExistsError(f"OpenAI Model {self.name} not exists.")
+    #     except openai.error.OpenAIError as e:
+    #         raise LLMNotExistsError(f"OpenAI Model {self.name} not exists, cause: {e.__class__.__name__}:{str(e)}")
+    #     except Exception as e:
+    #         logging.exception("OpenAI Model retrieve failed.")
+    #         raise e
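
A quick illustration of the mode dispatch at the top of this class: names in the fixed `COMPLETION_MODELS` list get completion mode, everything else (including the chat models listed above) is treated as chat. A tiny sketch of the same rule:

```python
COMPLETION_MODELS = ['text-davinci-003']

def pick_mode(name: str) -> str:
    # Mirrors the branch in OpenAIModel.__init__.
    return 'completion' if name in COMPLETION_MODELS else 'chat'

assert pick_mode('text-davinci-003') == 'completion'
assert pick_mode('gpt-3.5-turbo') == 'chat'
```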

+ 103 - 0
api/core/model_providers/models/llm/replicate_model.py

@@ -0,0 +1,103 @@
+import decimal
+from typing import List, Optional, Any
+
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import LLMResult, SystemMessage, get_buffer_string
+from replicate.exceptions import ReplicateError, ModelError
+
+from core.model_providers.providers.base import BaseModelProvider
+from core.model_providers.error import LLMBadRequestError
+from core.third_party.langchain.llms.replicate_llm import EnhanceReplicate
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+
+
+class ReplicateModel(BaseLLM):
+    def __init__(self, model_provider: BaseModelProvider,
+                 name: str,
+                 model_kwargs: ModelKwargs,
+                 streaming: bool = False,
+                 callbacks: Callbacks = None):
+        self.model_mode = ModelMode.CHAT if name.endswith('-chat') else ModelMode.COMPLETION
+
+        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
+
+    def _init_client(self) -> Any:
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+
+        return EnhanceReplicate(
+            model=self.name + ':' + self.credentials.get('model_version'),
+            input=provider_model_kwargs,
+            streaming=self.streaming,
+            replicate_api_token=self.credentials.get('replicate_api_token'),
+            callbacks=self.callbacks,
+        )
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        Run prediction using the given prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        extra_kwargs = {}
+        if isinstance(prompts, list):
+            system_messages = [message for message in prompts if isinstance(message, SystemMessage)]
+            if system_messages:
+                # Replicate chat models take the system prompt as a separate input.
+                extra_kwargs['system_prompt'] = system_messages[0].content
+                prompts = [message for message in prompts if not isinstance(message, SystemMessage)]
+
+            prompts = get_buffer_string(prompts)
+
+        # The maximum length the generated tokens can have.
+        # Corresponds to the length of the input prompt + max_new_tokens.
+        if 'max_length' in self._client.input:
+            self._client.input['max_length'] = min(
+                self._client.input['max_length'] + self.get_num_tokens(messages),
+                self.model_rules.max_tokens.max
+            )
+
+        return self._client.generate([prompts], stop, callbacks, **extra_kwargs)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        Get the number of tokens in the given prompt messages.
+
+        :param messages:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        if isinstance(prompts, list):
+            prompts = get_buffer_string(prompts)
+
+        return self._client.get_num_tokens(prompts)
+
+    def get_token_price(self, tokens: int, message_type: MessageType):
+        # Replicate bills by prediction time (seconds), so token usage is not priced
+        return decimal.Decimal('0')
+
+    def get_currency(self):
+        return 'USD'
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
+        self.client.input = provider_model_kwargs
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, (ModelError, ReplicateError)):
+            return LLMBadRequestError(f"Replicate: {str(ex)}")
+        else:
+            return ex
+
+    @classmethod
+    def support_streaming(cls):
+        return True
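
A small sketch of the chat-prompt flattening in `_run` above: the system message is lifted into Replicate's separate `system_prompt` input, and the remaining turns are flattened to one string with langchain's `get_buffer_string` (assumes langchain is installed):

```python
from langchain.schema import AIMessage, HumanMessage, SystemMessage, get_buffer_string

messages = [
    SystemMessage(content='You are a pirate.'),
    HumanMessage(content='Hi!'),
    AIMessage(content='Ahoy!'),
]

system_prompt = messages[0].content          # passed separately as system_prompt
flattened = get_buffer_string(messages[1:])  # 'Human: Hi!\nAI: Ahoy!'
print(flattened)
```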

+ 73 - 0
api/core/model_providers/models/llm/spark_model.py

@@ -0,0 +1,73 @@
+import decimal
+from typing import List, Optional, Any
+
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import LLMResult
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.spark import ChatSpark
+from core.third_party.spark.spark_llm import SparkError
+
+
+class SparkModel(BaseLLM):
+    model_mode: ModelMode = ModelMode.CHAT
+
+    def _init_client(self) -> Any:
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+        return ChatSpark(
+            streaming=self.streaming,
+            callbacks=self.callbacks,
+            **self.credentials,
+            **provider_model_kwargs
+        )
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        Run prediction using the given prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.generate([prompts], stop, callbacks)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        Get the number of tokens in the given prompt messages.
+
+        :param messages:
+        :return:
+        """
+        contents = [message.content for message in messages]
+        return max(self._client.get_num_tokens("".join(contents)), 0)
+
+    def get_token_price(self, tokens: int, message_type: MessageType):
+        return decimal.Decimal('0')
+
+    def get_currency(self):
+        return 'RMB'
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
+        for k, v in provider_model_kwargs.items():
+            if hasattr(self.client, k):
+                setattr(self.client, k, v)
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, SparkError):
+            return LLMBadRequestError(f"Spark: {str(ex)}")
+        else:
+            return ex
+
+    @classmethod
+    def support_streaming(cls):
+        return True

Not all modified files are shown, because too many files were changed