
Model Runtime (#1858)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
takatost · 1 year ago · d069c668f8
100 changed files with 6613 additions and 3373 deletions
  1. + 58 - 0  .github/workflows/api-model-runtime-tests.yml
  2. + 0 - 38  .github/workflows/api-unit-tests.yml
  3. + 5 - 0  CONTRIBUTING.md
  4. + 15 - 0  api/.vscode/launch.json
  5. + 0 - 3  api/Dockerfile
  6. + 20 - 14  api/app.py
  7. + 15 - 13  api/commands.py
  8. + 1 - 1  api/config.py
  9. + 1 - 1  api/controllers/console/__init__.py
  10. + 44 - 34  api/controllers/console/app/app.py
  11. + 4 - 4  api/controllers/console/app/audio.py
  12. + 13 - 14  api/controllers/console/app/completion.py
  13. + 4 - 4  api/controllers/console/app/generator.py
  14. + 13 - 9  api/controllers/console/app/message.py
  15. + 1 - 1  api/controllers/console/app/model_config.py
  16. + 30 - 18  api/controllers/console/datasets/datasets.py
  17. + 11 - 6  api/controllers/console/datasets/datasets_document.py
  18. + 18 - 11  api/controllers/console/datasets/datasets_segments.py
  19. + 1 - 1  api/controllers/console/datasets/hit_testing.py
  20. + 3 - 4  api/controllers/console/explore/audio.py
  21. + 11 - 13  api/controllers/console/explore/completion.py
  22. + 17 - 12  api/controllers/console/explore/message.py
  23. + 3 - 4  api/controllers/console/universal_chat/audio.py
  24. + 8 - 9  api/controllers/console/universal_chat/chat.py
  25. + 3 - 4  api/controllers/console/universal_chat/message.py
  26. + 96 - 176  api/controllers/console/workspace/model_providers.py
  27. + 212 - 56  api/controllers/console/workspace/models.py
  28. + 0 - 131  api/controllers/console/workspace/providers.py
  29. + 0 - 1  api/controllers/console/workspace/workspace.py
  30. + 3 - 4  api/controllers/service_api/app/audio.py
  31. + 11 - 13  api/controllers/service_api/app/completion.py
  32. + 15 - 7  api/controllers/service_api/dataset/dataset.py
  33. + 1 - 1  api/controllers/service_api/dataset/document.py
  34. + 18 - 11  api/controllers/service_api/dataset/segment.py
  35. + 3 - 4  api/controllers/web/audio.py
  36. + 11 - 13  api/controllers/web/completion.py
  37. + 14 - 9  api/controllers/web/message.py
  38. + 101 - 0  api/core/agent/agent/agent_llm_callback.py
  39. + 33 - 12  api/core/agent/agent/calc_token_mixin.py
  40. + 38 - 16  api/core/agent/agent/multi_dataset_router_agent.py
  41. + 91 - 34  api/core/agent/agent/openai_function_call.py
  42. + 0 - 158  api/core/agent/agent/output_parser/retirver_dataset_agent.py
  43. + 18 - 14  api/core/agent/agent/structed_multi_dataset_router_agent.py
  44. + 22 - 13  api/core/agent/agent/structured_chat.py
  45. + 33 - 23  api/core/agent/agent_executor.py
  46. + 0 - 0  api/core/app_runner/__init__.py
  47. + 251 - 0  api/core/app_runner/agent_app_runner.py
  48. + 267 - 0  api/core/app_runner/app_runner.py
  49. + 363 - 0  api/core/app_runner/basic_app_runner.py
  50. + 483 - 0  api/core/app_runner/generate_task_pipeline.py
  51. + 138 - 0  api/core/app_runner/moderation_handler.py
  52. + 655 - 0  api/core/application_manager.py
  53. + 228 - 0  api/core/application_queue_manager.py
  54. + 110 - 65  api/core/callback_handler/agent_loop_gather_callback_handler.py
  55. + 0 - 74  api/core/callback_handler/dataset_tool_callback_handler.py
  56. + 0 - 16  api/core/callback_handler/entity/chain_result.py
  57. + 0 - 6  api/core/callback_handler/entity/dataset_query.py
  58. + 0 - 8  api/core/callback_handler/entity/llm_message.py
  59. + 56 - 6  api/core/callback_handler/index_tool_callback_handler.py
  60. + 0 - 284  api/core/callback_handler/llm_callback_handler.py
  61. + 0 - 76  api/core/callback_handler/main_chain_gather_callback_handler.py
  62. + 5 - 2  api/core/callback_handler/std_out_callback_handler.py
  63. + 21 - 8  api/core/chain/llm_chain.py
  64. + 0 - 501  api/core/completion.py
  65. + 0 - 517  api/core/conversation_message_task.py
  66. + 19 - 6  api/core/docstore/dataset_docstore.py
  67. + 26 - 13  api/core/embedding/cached_embedding.py
  68. + 0 - 0  api/core/entities/__init__.py
  69. + 265 - 0  api/core/entities/application_entities.py
  70. + 128 - 0  api/core/entities/message_entities.py
  71. + 71 - 0  api/core/entities/model_entities.py
  72. + 657 - 0  api/core/entities/provider_configuration.py
  73. + 67 - 0  api/core/entities/provider_entities.py
  74. + 118 - 0  api/core/entities/queue_entities.py
  75. + 0 - 0  api/core/errors/__init__.py
  76. + 0 - 20  api/core/errors/error.py
  77. + 0 - 0  api/core/external_data_tool/weather_search/__init__.py
  78. + 35 - 0  api/core/external_data_tool/weather_search/schema.json
  79. + 45 - 0  api/core/external_data_tool/weather_search/weather_search.py
  80. + 0 - 0  api/core/features/__init__.py
  81. + 325 - 0  api/core/features/agent_runner.py
  82. + 119 - 0  api/core/features/annotation_reply.py
  83. + 181 - 0  api/core/features/dataset_retrieval.py
  84. + 96 - 0  api/core/features/external_data_fetch.py
  85. + 32 - 0  api/core/features/hosting_moderation.py
  86. + 50 - 0  api/core/features/moderation.py
  87. + 5 - 5  api/core/file/file_obj.py
  88. + 63 - 40  api/core/generator/llm_generator.py
  89. + 14 - 0  api/core/helper/encrypter.py
  90. + 22 - 0  api/core/helper/lru_cache.py
  91. + 30 - 18  api/core/helper/moderation.py
  92. + 213 - 0  api/core/hosting_configuration.py
  93. + 7 - 11  api/core/index/index.py
  94. + 111 - 41  api/core/indexing_runner.py
  95. + 0 - 95  api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py
  96. + 0 - 36  api/core/memory/read_only_conversation_token_db_string_buffer_shared_memory.py
  97. + 109 - 0  api/core/memory/token_buffer_memory.py
  98. + 209 - 0  api/core/model_manager.py
  99. + 0 - 335  api/core/model_providers/model_factory.py
  100. + 0 - 276  api/core/model_providers/model_provider_factory.py

+ 58 - 0
.github/workflows/api-model-runtime-tests.yml

@@ -0,0 +1,58 @@
+name: Run Pytest
+
+on:
+  pull_request:
+    branches:
+      - main
+  push:
+    branches:
+      - deploy/dev
+      - feat/model-runtime
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+
+    env:
+      OPENAI_API_KEY: sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
+      AZURE_OPENAI_API_BASE: https://difyai-openai.openai.azure.com
+      AZURE_OPENAI_API_KEY: xxxxb1707exxxxxxxxxxaaxxxxxf94
+      ANTHROPIC_API_KEY: sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
+      CHATGLM_API_BASE: http://a.abc.com:11451
+      XINFERENCE_SERVER_URL: http://a.abc.com:11451
+      XINFERENCE_GENERATION_MODEL_UID: generate
+      XINFERENCE_CHAT_MODEL_UID: chat
+      XINFERENCE_EMBEDDINGS_MODEL_UID: embedding
+      XINFERENCE_RERANK_MODEL_UID: rerank
+      GOOGLE_API_KEY: abcdefghijklmnopqrstuvwxyz
+      HUGGINGFACE_API_KEY: hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu
+      HUGGINGFACE_TEXT_GEN_ENDPOINT_URL: a
+      HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL: b
+      HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL: c
+      MOCK_SWITCH: true
+
+
+    steps:
+    - name: Checkout code
+      uses: actions/checkout@v2
+
+    - name: Set up Python
+      uses: actions/setup-python@v2
+      with:
+        python-version: '3.10'
+
+    - name: Cache pip dependencies
+      uses: actions/cache@v2
+      with:
+        path: ~/.cache/pip
+        key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
+        restore-keys: ${{ runner.os }}-pip-
+
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install pytest
+        pip install -r api/requirements.txt
+
+    - name: Run pytest
+      run: pytest api/tests/integration_tests/model_runtime/anthropic api/tests/integration_tests/model_runtime/azure_openai api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py
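
Note: the job exports placeholder credentials (none of the keys above are real) and sets MOCK_SWITCH: true, so these model runtime integration tests exercise mocked provider clients rather than live APIs. As a rough illustration only — the fixture, FakeProviderClient, and test below are hypothetical sketches, not part of this commit — a suite can gate its mocking on that switch like so:

import os

import pytest

class FakeProviderClient:
    # Stand-in for a provider SDK client; not a real Dify class.
    def chat(self, prompt: str) -> str:
        return "mocked response"

def mock_enabled() -> bool:
    # MOCK_SWITCH is the environment variable the workflow above sets.
    return os.environ.get("MOCK_SWITCH", "false").lower() == "true"

@pytest.fixture
def provider_client():
    # Return a fake client when MOCK_SWITCH is on; a real client built
    # from genuine credentials would be constructed here otherwise.
    if mock_enabled():
        return FakeProviderClient()
    pytest.skip("live-API path is out of scope for this sketch")

def test_chat_returns_text(provider_client):
    assert provider_client.chat("hello") == "mocked response"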

+ 0 - 38
.github/workflows/api-unit-tests.yml

@@ -1,38 +0,0 @@
-name: Run Pytest
-
-on:
-  pull_request:
-    branches:
-      - main
-  push:
-    branches:
-      - deploy/dev
-
-jobs:
-  test:
-    runs-on: ubuntu-latest
-
-    steps:
-    - name: Checkout code
-      uses: actions/checkout@v2
-
-    - name: Set up Python
-      uses: actions/setup-python@v2
-      with:
-        python-version: '3.10'
-
-    - name: Cache pip dependencies
-      uses: actions/cache@v2
-      with:
-        path: ~/.cache/pip
-        key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
-        restore-keys: ${{ runner.os }}-pip-
-
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        pip install pytest
-        pip install -r api/requirements.txt
-
-    - name: Run pytest
-      run: pytest api/tests/unit_tests

+ 5 - 0
CONTRIBUTING.md

@@ -55,6 +55,11 @@ Did you have an issue, like a merge conflict, or don't know how to open a pull r
 
 Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/j3XRWSPBf7). We are here to help!
 
+
+### Provider Integrations
+If you see a model provider not yet supported by Dify that you'd like to use, follow these [steps](api/core/model_runtime/README.md) to submit a PR.
+
+
 ### i18n (Internationalization) Support
 
 We are looking for contributors to help with translations in other languages. If you are interested in helping, please join the [Discord Community Server](https://discord.gg/AhzKf7dNgk) and let us know.

+ 15 - 0
api/.vscode/launch.json

@@ -4,6 +4,21 @@
     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
     "version": "0.2.0",
     "configurations": [
+        {
+            "name": "Python: Celery",
+            "type": "python",
+            "request": "launch",
+            "module": "celery",
+            "justMyCode": true,
+            "args": ["-A", "app.celery", "worker", "-P", "gevent", "-c", "1", "--loglevel", "info", "-Q", "dataset,generation,mail"],
+            "envFile": "${workspaceFolder}/.env",
+            "env": {
+                "FLASK_APP": "app.py",
+                "FLASK_DEBUG": "1",
+                "GEVENT_SUPPORT": "True"
+            },
+            "console": "integratedTerminal"
+        },
         {
             "name": "Python: Flask",
             "type": "python",

+ 0 - 3
api/Dockerfile

@@ -34,9 +34,6 @@ RUN apt-get update \
 COPY --from=base /pkg /usr/local
 COPY . /app/api/
 
-RUN python -c "from transformers import GPT2TokenizerFast; GPT2TokenizerFast.from_pretrained('gpt2')"
-ENV TRANSFORMERS_OFFLINE true
-
 COPY docker/entrypoint.sh /entrypoint.sh
 RUN chmod +x /entrypoint.sh
 

+ 20 - 14
api/app.py

@@ -6,9 +6,12 @@ from werkzeug.exceptions import Unauthorized
 if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
     from gevent import monkey
     monkey.patch_all()
-    if os.environ.get("VECTOR_STORE") == 'milvus':
-        import grpc.experimental.gevent
-        grpc.experimental.gevent.init_gevent()
+    # if os.environ.get("VECTOR_STORE") == 'milvus':
+    import grpc.experimental.gevent
+    grpc.experimental.gevent.init_gevent()
+
+    import langchain
+    langchain.verbose = True
 
 import time
 import logging
@@ -18,9 +21,8 @@ import threading
 from flask import Flask, request, Response
 from flask_cors import CORS
 
-from core.model_providers.providers import hosted
 from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
-    ext_database, ext_storage, ext_mail, ext_code_based_extension
+    ext_database, ext_storage, ext_mail, ext_code_based_extension, ext_hosting_provider
 from extensions.ext_database import db
 from extensions.ext_login import login_manager
 
@@ -79,8 +81,6 @@ def create_app(test_config=None) -> Flask:
     register_blueprints(app)
     register_commands(app)
 
-    hosted.init_app(app)
-
     return app
 
 
@@ -95,6 +95,7 @@ def initialize_extensions(app):
     ext_celery.init_app(app)
     ext_login.init_app(app)
     ext_mail.init_app(app)
+    ext_hosting_provider.init_app(app)
     ext_sentry.init_app(app)
 
 
@@ -105,13 +106,18 @@ def load_user_from_request(request_from_flask_login):
     if request.blueprint == 'console':
         # Check if the user_id contains a dot, indicating the old format
         auth_header = request.headers.get('Authorization', '')
-        if ' ' not in auth_header:
-            raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
-        auth_scheme, auth_token = auth_header.split(None, 1)
-        auth_scheme = auth_scheme.lower()
-        if auth_scheme != 'bearer':
-            raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
-        
+        if not auth_header:
+            auth_token = request.args.get('_token')
+            if not auth_token:
+                raise Unauthorized('Invalid Authorization token.')
+        else:
+            if ' ' not in auth_header:
+                raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+            auth_scheme, auth_token = auth_header.split(None, 1)
+            auth_scheme = auth_scheme.lower()
+            if auth_scheme != 'bearer':
+                raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+
         decoded = PassportService().verify(auth_token)
         user_id = decoded.get('user_id')
 
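The reworked request loader above keeps the strict Bearer-header check but adds a fallback: when no Authorization header is present, the console token may be supplied as a _token query parameter. A hedged usage sketch — the endpoint path and token value are illustrative, not taken from this commit:

import requests

BASE = "http://localhost:5001"      # illustrative console API origin
TOKEN = "<console-access-token>"    # a token issued by PassportService

# Path 1: the usual Bearer header.
r1 = requests.get(f"{BASE}/console/api/apps",
                  headers={"Authorization": f"Bearer {TOKEN}"})

# Path 2: the new fallback — the same token as a `_token` query parameter,
# useful for clients that cannot set headers (e.g. browser audio/EventSource).
r2 = requests.get(f"{BASE}/console/api/apps", params={"_token": TOKEN})
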

+ 15 - 13
api/commands.py

@@ -12,16 +12,12 @@ import qdrant_client
 from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
 from tqdm import tqdm
 from flask import current_app, Flask
-from langchain.embeddings import OpenAIEmbeddings
 from werkzeug.exceptions import NotFound
 
 from core.embedding.cached_embedding import CacheEmbedding
 from core.index.index import IndexBuilder
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
-from core.model_providers.models.entity.model_params import ModelType
-from core.model_providers.providers.hosted import hosted_model_providers
-from core.model_providers.providers.openai_provider import OpenAIProvider
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
 from libs.password import password_pattern, valid_password, hash_password
 from libs.helper import email as email_validate
 from extensions.ext_database import db
@@ -327,6 +323,8 @@ def create_qdrant_indexes():
         except NotFound:
             break
 
+        model_manager = ModelManager()
+
         page += 1
         for dataset in datasets:
             if dataset.index_struct_dict:
@@ -334,19 +332,23 @@ def create_qdrant_indexes():
                     try:
                         click.echo('Create dataset qdrant index: {}'.format(dataset.id))
                         try:
-                            embedding_model = ModelFactory.get_embedding_model(
+                            embedding_model = model_manager.get_model_instance(
                                 tenant_id=dataset.tenant_id,
-                                model_provider_name=dataset.embedding_model_provider,
-                                model_name=dataset.embedding_model
+                                provider=dataset.embedding_model_provider,
+                                model_type=ModelType.TEXT_EMBEDDING,
+                                model=dataset.embedding_model
+
                             )
                         except Exception:
                             try:
-                                embedding_model = ModelFactory.get_embedding_model(
-                                    tenant_id=dataset.tenant_id
+                                embedding_model = model_manager.get_default_model_instance(
+                                    tenant_id=dataset.tenant_id,
+                                    model_type=ModelType.TEXT_EMBEDDING,
                                 )
-                                dataset.embedding_model = embedding_model.name
-                                dataset.embedding_model_provider = embedding_model.model_provider.provider_name
+                                dataset.embedding_model = embedding_model.model
+                                dataset.embedding_model_provider = embedding_model.provider
                             except Exception:
+
                                 provider = Provider(
                                     id='provider_id',
                                     tenant_id=dataset.tenant_id,

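This hunk shows the migration pattern repeated throughout the commit: ModelFactory.get_embedding_model(...) becomes ModelManager().get_model_instance(...) with an explicit ModelType, falling back to get_default_model_instance(...) when the dataset's own model cannot be resolved. Condensed from the diff above (the surrounding dataset object is assumed from the loop):

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType

model_manager = ModelManager()
try:
    # Resolve the exact embedding model the dataset was indexed with.
    embedding_model = model_manager.get_model_instance(
        tenant_id=dataset.tenant_id,
        provider=dataset.embedding_model_provider,
        model_type=ModelType.TEXT_EMBEDDING,
        model=dataset.embedding_model,
    )
except Exception:
    # Fall back to the tenant's default embedding model and record it.
    embedding_model = model_manager.get_default_model_instance(
        tenant_id=dataset.tenant_id,
        model_type=ModelType.TEXT_EMBEDDING,
    )
    dataset.embedding_model = embedding_model.model
    dataset.embedding_model_provider = embedding_model.provider
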
+ 1 - 1
api/config.py

@@ -87,7 +87,7 @@ class Config:
         # ------------------------
         # General Configurations.
         # ------------------------
-        self.CURRENT_VERSION = "0.3.34"
+        self.CURRENT_VERSION = "0.4.0"
         self.COMMIT_SHA = get_env('COMMIT_SHA')
         self.EDITION = "SELF_HOSTED"
         self.DEPLOY_ENV = get_env('DEPLOY_ENV')

+ 1 - 1
api/controllers/console/__init__.py

@@ -18,7 +18,7 @@ from .auth import login, oauth, data_source_oauth, activate
 from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
 
 # Import workspace controllers
-from .workspace import workspace, members, providers, model_providers, account, tool_providers, models
+from .workspace import workspace, members, model_providers, account, tool_providers, models
 
 # Import explore controllers
 from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio

+ 44 - 34
api/controllers/console/app/app.py

@@ -4,6 +4,10 @@ import logging
 from datetime import datetime
 
 from flask_login import current_user
+
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.provider_manager import ProviderManager
 from libs.login import login_required
 from flask_restful import Resource, reqparse, marshal_with, abort, inputs
 from werkzeug.exceptions import Forbidden
@@ -13,9 +17,7 @@ from controllers.console import api
 from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
-from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.model_provider_factory import ModelProviderFactory
+from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError
 from events.app_event import app_was_created, app_was_deleted
 from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \
     app_detail_fields_with_site
@@ -73,39 +75,41 @@ class AppListApi(Resource):
             raise Forbidden()
 
         try:
-            default_model = ModelFactory.get_text_generation_model(
-                tenant_id=current_user.current_tenant_id
+            provider_manager = ProviderManager()
+            default_model_entity = provider_manager.get_default_model(
+                tenant_id=current_user.current_tenant_id,
+                model_type=ModelType.LLM
             )
         except (ProviderTokenNotInitError, LLMBadRequestError):
-            default_model = None
+            default_model_entity = None
         except Exception as e:
             logging.exception(e)
-            default_model = None
+            default_model_entity = None
 
         if args['model_config'] is not None:
             # validate config
             model_config_dict = args['model_config']
 
             # get model provider
-            model_provider = ModelProviderFactory.get_preferred_model_provider(
-                current_user.current_tenant_id,
-                model_config_dict["model"]["provider"]
+            model_manager = ModelManager()
+            model_instance = model_manager.get_default_model_instance(
+                tenant_id=current_user.current_tenant_id,
+                model_type=ModelType.LLM
             )
 
-            if not model_provider:
-                if not default_model:
-                    raise ProviderNotInitializeError(
-                        f"No Default System Reasoning Model available. Please configure "
-                        f"in the Settings -> Model Provider.")
-                else:
-                    model_config_dict["model"]["provider"] = default_model.model_provider.provider_name
-                    model_config_dict["model"]["name"] = default_model.name
+            if not model_instance:
+                raise ProviderNotInitializeError(
+                    f"No Default System Reasoning Model available. Please configure "
+                    f"in the Settings -> Model Provider.")
+            else:
+                model_config_dict["model"]["provider"] = model_instance.provider
+                model_config_dict["model"]["name"] = model_instance.model
 
             model_configuration = AppModelConfigService.validate_configuration(
                 tenant_id=current_user.current_tenant_id,
                 account=current_user,
                 config=model_config_dict,
-                mode=args['mode']
+                app_mode=args['mode']
             )
 
             app = App(
@@ -129,21 +133,27 @@ class AppListApi(Resource):
             app_model_config = AppModelConfig(**model_config_template['model_config'])
 
             # get model provider
-            model_provider = ModelProviderFactory.get_preferred_model_provider(
-                current_user.current_tenant_id,
-                app_model_config.model_dict["provider"]
-            )
-
-            if not model_provider:
-                if not default_model:
-                    raise ProviderNotInitializeError(
-                        f"No Default System Reasoning Model available. Please configure "
-                        f"in the Settings -> Model Provider.")
-                else:
-                    model_dict = app_model_config.model_dict
-                    model_dict['provider'] = default_model.model_provider.provider_name
-                    model_dict['name'] = default_model.name
-                    app_model_config.model = json.dumps(model_dict)
+            model_manager = ModelManager()
+
+            try:
+                model_instance = model_manager.get_default_model_instance(
+                    tenant_id=current_user.current_tenant_id,
+                    model_type=ModelType.LLM
+                )
+            except ProviderTokenNotInitError:
+                raise ProviderNotInitializeError(
+                    f"No Default System Reasoning Model available. Please configure "
+                    f"in the Settings -> Model Provider.")
+
+            if not model_instance:
+                raise ProviderNotInitializeError(
+                    f"No Default System Reasoning Model available. Please configure "
+                    f"in the Settings -> Model Provider.")
+            else:
+                model_dict = app_model_config.model_dict
+                model_dict['provider'] = model_instance.provider
+                model_dict['name'] = model_instance.model
+                app_model_config.model = json.dumps(model_dict)
 
         app.name = args['name']
         app.mode = args['mode']

+ 4 - 4
api/controllers/console/app/audio.py

@@ -2,6 +2,8 @@
 import logging
 
 from flask import request
+
+from core.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
 from werkzeug.exceptions import InternalServerError
 
@@ -14,8 +16,7 @@ from controllers.console.app.error import AppUnavailableError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from flask_restful import Resource
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
@@ -56,8 +57,7 @@ class ChatMessageAudioApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
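
This is another pattern the commit applies across nearly every controller: the five granular LLM exceptions (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError) collapse into the single InvokeError base class from the new model runtime. Sketched from the diffs — do_service_call is a stand-in for whichever service method each controller wraps:

from controllers.console.app.error import (CompletionRequestError,
                                           ProviderModelCurrentlyNotSupportError,
                                           ProviderQuotaExceededError)
from core.errors.error import ModelCurrentlyNotSupportError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError

try:
    response = do_service_call()  # stand-in, not a real Dify symbol
except QuotaExceededError:
    raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
    raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
    # One handler replaces the old five-exception tuple.
    raise CompletionRequestError(str(e))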

+ 13 - 14
api/controllers/console/app/completion.py

@@ -5,6 +5,10 @@ from typing import Generator, Union
 
 import flask_login
 from flask import Response, stream_with_context
+
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import InvokeFrom
+from core.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
 from werkzeug.exceptions import InternalServerError, NotFound
 
@@ -16,9 +20,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
     ProviderModelCurrentlyNotSupportError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.conversation_message_task import PubHandler
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value
 from flask_restful import Resource, reqparse
 
@@ -56,7 +58,7 @@ class CompletionMessageApi(Resource):
                 app_model=app_model,
                 user=account,
                 args=args,
-                from_source='console',
+                invoke_from=InvokeFrom.DEBUGGER,
                 streaming=streaming,
                 is_model_config_override=True
             )
@@ -75,8 +77,7 @@ class CompletionMessageApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -97,7 +98,7 @@
 
         account = flask_login.current_user
 
-        PubHandler.stop(account, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
 
         return {'result': 'success'}, 200
 
@@ -132,7 +133,7 @@ class ChatMessageApi(Resource):
                 app_model=app_model,
                 user=account,
                 args=args,
-                from_source='console',
+                invoke_from=InvokeFrom.DEBUGGER,
                 streaming=streaming,
                 is_model_config_override=True
             )
@@ -151,8 +152,7 @@ class ChatMessageApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -182,9 +182,8 @@ def compact_response(response: Union[dict, Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
                 yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                    LLMRateLimitError, LLMAuthorizationError) as e:
-                yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
+            except InvokeError as e:
+                yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
             except ValueError as e:
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
             except Exception:
@@ -207,7 +206,7 @@
 
         account = flask_login.current_user
 
-        PubHandler.stop(account, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
 
         return {'result': 'success'}, 200
 

+ 4 - 4
api/controllers/console/app/generator.py

@@ -1,4 +1,6 @@
 from flask_login import current_user
+
+from core.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
 from flask_restful import Resource, reqparse
 
@@ -8,8 +10,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.generator.llm_generator import LLMGenerator
-from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
-    LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 
 
 class RuleGenerateApi(Resource):
@@ -36,8 +37,7 @@ class RuleGenerateApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
 
         return rules

+ 13 - 9
api/controllers/console/app/message.py

@@ -14,8 +14,9 @@ from controllers.console.app.error import CompletionRequestError, ProviderNotIni
     AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
-from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
-    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
 from fields.conversation_fields import message_detail_fields, annotation_fields
 from libs.helper import uuid_value
@@ -208,7 +209,13 @@ class MessageMoreLikeThisApi(Resource):
         app_model = _get_app(app_id, 'completion')
 
         try:
-            response = CompletionService.generate_more_like_this(app_model, current_user, message_id, streaming)
+            response = CompletionService.generate_more_like_this(
+                app_model=app_model,
+                user=current_user,
+                message_id=message_id,
+                invoke_from=InvokeFrom.DEBUGGER,
+                streaming=streaming
+            )
             return compact_response(response)
         except MessageNotExistsError:
             raise NotFound("Message Not Exists.")
@@ -220,8 +227,7 @@ class MessageMoreLikeThisApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -249,8 +255,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
             except ModelCurrentlyNotSupportError:
                 yield "data: " + json.dumps(
                     api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                    LLMRateLimitError, LLMAuthorizationError) as e:
+            except InvokeError as e:
                 yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
             except ValueError as e:
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
@@ -290,8 +295,7 @@ class MessageSuggestedQuestionApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except Exception:
             logging.exception("internal server error.")

+ 1 - 1
api/controllers/console/app/model_config.py

@@ -31,7 +31,7 @@ class ModelConfigResource(Resource):
             tenant_id=current_user.current_tenant_id,
             account=current_user,
             config=request.json,
-            mode=app.mode
+            app_mode=app.mode
         )
 
         new_app_model_config = AppModelConfig(

+ 30 - 18
api/controllers/console/datasets/datasets.py

@@ -4,6 +4,8 @@ from flask import request, current_app
 from flask_login import current_user
 
 from controllers.console.apikey import api_key_list, api_key_fields
+from core.model_runtime.entities.model_entities import ModelType
+from core.provider_manager import ProviderManager
 from libs.login import login_required
 from flask_restful import Resource, reqparse, marshal, marshal_with
 from werkzeug.exceptions import NotFound, Forbidden
@@ -14,8 +16,7 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.indexing_runner import IndexingRunner
-from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
-from core.model_providers.models.entity.model_params import ModelType
+from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from fields.app_fields import related_app_list
 from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
 from fields.document_fields import document_status_fields
@@ -23,7 +24,6 @@ from extensions.ext_database import db
 from models.dataset import DocumentSegment, Document
 from models.model import UploadFile, ApiToken
 from services.dataset_service import DatasetService, DocumentService
-from services.provider_service import ProviderService
 
 
 def _validate_name(name):
@@ -55,16 +55,20 @@ class DatasetListApi(Resource):
                                                           current_user.current_tenant_id, current_user)
 
         # check embedding setting
-        provider_service = ProviderService()
-        valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
-                                                                 ModelType.EMBEDDINGS.value)
-        # if len(valid_model_list) == 0:
-        #     raise ProviderNotInitializeError(
-        #         f"No Embedding Model available. Please configure a valid provider "
-        #         f"in the Settings -> Model Provider.")
+        provider_manager = ProviderManager()
+        configurations = provider_manager.get_configurations(
+            tenant_id=current_user.current_tenant_id
+        )
+
+        embedding_models = configurations.get_models(
+            model_type=ModelType.TEXT_EMBEDDING,
+            only_active=True
+        )
+
         model_names = []
-        for valid_model in valid_model_list:
-            model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
+        for embedding_model in embedding_models:
+            model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
+
         data = marshal(datasets, dataset_detail_fields)
         for item in data:
             if item['indexing_technique'] == 'high_quality':
@@ -75,6 +79,7 @@ class DatasetListApi(Resource):
                     item['embedding_available'] = False
             else:
                 item['embedding_available'] = True
+
         response = {
             'data': data,
             'has_more': len(datasets) == limit,
@@ -130,13 +135,20 @@ class DatasetApi(Resource):
             raise Forbidden(str(e))
         data = marshal(dataset, dataset_detail_fields)
         # check embedding setting
-        provider_service = ProviderService()
-        # get valid model list
-        valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
-                                                                 ModelType.EMBEDDINGS.value)
+        provider_manager = ProviderManager()
+        configurations = provider_manager.get_configurations(
+            tenant_id=current_user.current_tenant_id
+        )
+
+        embedding_models = configurations.get_models(
+            model_type=ModelType.TEXT_EMBEDDING,
+            only_active=True
+        )
+
         model_names = []
-        for valid_model in valid_model_list:
-            model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
+        for embedding_model in embedding_models:
+            model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
+
         if data['indexing_technique'] == 'high_quality':
             item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
             if item_model in model_names:

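Both dataset endpoints now enumerate embedding models through ProviderManager instead of ProviderService. The shared lookup shape, lifted from the diff above (current_user comes from the Flask-Login context):

from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager

provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(
    tenant_id=current_user.current_tenant_id
)

# Only models whose provider credentials are configured and active.
embedding_models = configurations.get_models(
    model_type=ModelType.TEXT_EMBEDDING,
    only_active=True
)

model_names = [f"{m.model}:{m.provider.provider}" for m in embedding_models]
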
+ 11 - 6
api/controllers/console/datasets/datasets_document.py

@@ -2,8 +2,12 @@
 from datetime import datetime
 from typing import List
 
-from flask import request, current_app
+from flask import request
 from flask_login import current_user
+
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from libs.login import login_required
 from flask_restful import Resource, fields, marshal, marshal_with, reqparse
 from sqlalchemy import desc, asc
@@ -18,9 +22,8 @@ from controllers.console.datasets.error import DocumentAlreadyFinishedError, Inv
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
 from core.indexing_runner import IndexingRunner
-from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
     LLMBadRequestError
-from core.model_providers.model_factory import ModelFactory
 from extensions.ext_redis import redis_client
 from fields.document_fields import document_with_segments_fields, document_fields, \
     dataset_and_document_fields, document_status_fields
@@ -272,10 +275,12 @@ class DatasetInitApi(Resource):
         args = parser.parse_args()
         if args['indexing_technique'] == 'high_quality':
             try:
-                ModelFactory.get_embedding_model(
-                    tenant_id=current_user.current_tenant_id
+                model_manager = ModelManager()
+                model_manager.get_default_model_instance(
+                    tenant_id=current_user.current_tenant_id,
+                    model_type=ModelType.TEXT_EMBEDDING
                 )
-            except LLMBadRequestError:
+            except InvokeAuthorizationError:
                 raise ProviderNotInitializeError(
                     f"No Embedding Model available. Please configure a valid provider "
                     f"in the Settings -> Model Provider.")

+ 18 - 11
api/controllers/console/datasets/datasets_segments.py

@@ -12,8 +12,9 @@ from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
-from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
-from core.model_providers.model_factory import ModelFactory
+from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
 from libs.login import login_required
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
@@ -133,10 +134,12 @@ class DatasetDocumentSegmentApi(Resource):
         if dataset.indexing_technique == 'high_quality':
             # check embedding model setting
             try:
-                ModelFactory.get_embedding_model(
+                model_manager = ModelManager()
+                model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
@@ -219,10 +222,12 @@ class DatasetDocumentSegmentAddApi(Resource):
         # check embedding model setting
         if dataset.indexing_technique == 'high_quality':
             try:
-                ModelFactory.get_embedding_model(
+                model_manager = ModelManager()
+                model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
@@ -269,10 +274,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         if dataset.indexing_technique == 'high_quality':
             # check embedding model setting
             try:
-                ModelFactory.get_embedding_model(
+                model_manager = ModelManager()
+                model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(

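All three segment endpoints above repeat the same pre-flight check, now routed through ModelManager.get_model_instance with the dataset's pinned provider and model. A sketch of that shared pattern, assuming a dataset object carrying the embedding_model_provider and embedding_model fields seen in the diff; the helper and its message are illustrative:

from core.errors.error import LLMBadRequestError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType

def check_dataset_embedding_model(tenant_id: str, dataset) -> None:
    # Hypothetical extraction of the check duplicated across the three
    # DatasetDocumentSegment* resources.
    try:
        ModelManager().get_model_instance(
            tenant_id=tenant_id,
            provider=dataset.embedding_model_provider,
            model_type=ModelType.TEXT_EMBEDDING,
            model=dataset.embedding_model
        )
    except LLMBadRequestError:
        raise RuntimeError('Dataset embedding model unavailable; rebind its provider.')
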
+ 1 - 1
api/controllers/console/datasets/hit_testing.py

@@ -12,7 +12,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
 from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
     LLMBadRequestError
 from fields.hit_testing_fields import hit_testing_record_fields
 from services.dataset_service import DatasetService

+ 3 - 4
api/controllers/console/explore/audio.py

@@ -11,8 +11,8 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
     NoAudioUploadedError, AudioTooLargeError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.explore.wraps import InstalledAppResource
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
     UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
@@ -53,8 +53,7 @@ class ChatAudioApi(InstalledAppResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e

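This is the first of many hunks that collapse five legacy except arms into a single except InvokeError. That only works if the runtime's invoke errors share a common base class; a self-contained sketch of the assumed hierarchy (subclass names follow the InvokeAuthorizationError spelling used elsewhere in this commit, simplified here):

class InvokeError(Exception):
    """Assumed base class for provider invocation failures."""

class InvokeConnectionError(InvokeError): ...
class InvokeServerUnavailableError(InvokeError): ...
class InvokeRateLimitError(InvokeError): ...
class InvokeAuthorizationError(InvokeError): ...
class InvokeBadRequestError(InvokeError): ...

try:
    raise InvokeRateLimitError('429 from upstream provider')
except InvokeError as e:
    # One arm now covers what previously took five LLM*Error clauses.
    print(f'wrapped as CompletionRequestError: {e}')
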
+ 11 - 13
api/controllers/console/explore/completion.py

@@ -15,9 +15,10 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.explore.error import NotCompletionAppError, NotChatAppError
 from controllers.console.explore.wraps import InstalledAppResource
-from core.conversation_message_task import PubHandler
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from extensions.ext_database import db
 from libs.helper import uuid_value
 from services.completion_service import CompletionService
@@ -50,7 +51,7 @@ class CompletionApi(InstalledAppResource):
                 app_model=app_model,
                 user=current_user,
                 args=args,
-                from_source='console',
+                invoke_from=InvokeFrom.EXPLORE,
                 streaming=streaming
             )
 
@@ -68,8 +69,7 @@ class CompletionApi(InstalledAppResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -84,7 +84,7 @@ class CompletionStopApi(InstalledAppResource):
         if app_model.mode != 'completion':
             raise NotCompletionAppError()
 
-        PubHandler.stop(current_user, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
 
         return {'result': 'success'}, 200
 
@@ -115,7 +115,7 @@ class ChatApi(InstalledAppResource):
                 app_model=app_model,
                 user=current_user,
                 args=args,
-                from_source='console',
+                invoke_from=InvokeFrom.EXPLORE,
                 streaming=streaming
             )
 
@@ -133,8 +133,7 @@ class ChatApi(InstalledAppResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -149,7 +148,7 @@ class ChatStopApi(InstalledAppResource):
         if app_model.mode != 'chat':
             raise NotChatAppError()
 
-        PubHandler.stop(current_user, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
 
         return {'result': 'success'}, 200
 
@@ -175,8 +174,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
             except ModelCurrentlyNotSupportError:
                 yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                    LLMRateLimitError, LLMAuthorizationError) as e:
+            except InvokeError as e:
                 yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
             except ValueError as e:
             except ValueError as e:
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

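The stop endpoints above drop PubHandler in favour of ApplicationQueueManager.set_stop_flag, keyed by task, invocation surface, and user so a task can only be stopped by its owner from the same surface. A toy sketch of that keying; the enum values and the flag store are assumptions, not code from this commit:

from enum import Enum

class InvokeFrom(Enum):
    # Assumed values mirroring core.entities.application_entities.
    EXPLORE = 'explore'
    WEB_APP = 'web-app'
    SERVICE_API = 'service-api'
    DEBUGGER = 'debugger'

_stop_flags: set[tuple[str, str, str]] = set()  # stand-in for the real store

def set_stop_flag(task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
    # The generate worker is expected to poll this key between chunks.
    _stop_flags.add((task_id, invoke_from.value, user_id))

set_stop_flag('task-123', InvokeFrom.EXPLORE, 'user-456')
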
+ 17 - 12
api/controllers/console/explore/message.py

@@ -5,7 +5,7 @@ from typing import Generator, Union
 
 from flask import stream_with_context, Response
 from flask_login import current_user
-from flask_restful import reqparse, fields, marshal_with
+from flask_restful import reqparse, marshal_with
 from flask_restful.inputs import int_range
 from werkzeug.exceptions import NotFound, InternalServerError
 
@@ -13,12 +13,14 @@ import services
 from controllers.console import api
 from controllers.console.app.error import AppMoreLikeThisDisabledError, ProviderNotInitializeError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
-from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError
+from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
+    NotChatAppError
 from controllers.console.explore.wraps import InstalledAppResource
-from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
-    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from fields.message_fields import message_infinite_scroll_pagination_fields
-from libs.helper import uuid_value, TimestampField
+from libs.helper import uuid_value
 from services.completion_service import CompletionService
 from services.errors.app import MoreLikeThisDisabledError
 from services.errors.conversation import ConversationNotExistsError
@@ -83,7 +85,13 @@ class MessageMoreLikeThisApi(InstalledAppResource):
         streaming = args['response_mode'] == 'streaming'
 
         try:
-            response = CompletionService.generate_more_like_this(app_model, current_user, message_id, streaming)
+            response = CompletionService.generate_more_like_this(
+                app_model=app_model,
+                user=current_user,
+                message_id=message_id,
+                invoke_from=InvokeFrom.EXPLORE,
+                streaming=streaming
+            )
             return compact_response(response)
         except MessageNotExistsError:
             raise NotFound("Message Not Exists.")
@@ -95,8 +103,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -123,8 +130,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
             except ModelCurrentlyNotSupportError:
                 yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                    LLMRateLimitError, LLMAuthorizationError) as e:
+            except InvokeError as e:
                 yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
             except ValueError as e:
             except ValueError as e:
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
@@ -162,8 +168,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except Exception:
             logging.exception("internal server error.")

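generate_more_like_this is now invoked with keyword arguments and an explicit invoke_from instead of four positional parameters. A sketch of the service signature this call site implies; inferred from the call, not copied from services/completion_service.py:

from typing import Generator, Union

class CompletionService:
    @classmethod
    def generate_more_like_this(cls, app_model, user, message_id: str,
                                invoke_from, streaming: bool = True
                                ) -> Union[dict, Generator]:
        # Signature inferred from the keyword call site above.
        raise NotImplementedError
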
+ 3 - 4
api/controllers/console/universal_chat/audio.py

@@ -11,8 +11,8 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
     NoAudioUploadedError, AudioTooLargeError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
     UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
@@ -53,8 +53,7 @@ class UniversalChatAudioApi(UniversalChatResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e

+ 8 - 9
api/controllers/console/universal_chat/chat.py

@@ -12,9 +12,10 @@ from controllers.console import api
 from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.conversation_message_task import PubHandler
-from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
-    LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService
 
@@ -68,7 +69,7 @@ class UniversalChatApi(UniversalChatResource):
                 app_model=app_model,
                 user=current_user,
                 args=args,
-                from_source='console',
+                invoke_from=InvokeFrom.EXPLORE,
                 streaming=True,
                 is_model_config_override=True,
             )
@@ -87,8 +88,7 @@ class UniversalChatApi(UniversalChatResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -99,7 +99,7 @@ class UniversalChatApi(UniversalChatResource):
 
 class UniversalChatStopApi(UniversalChatResource):
     def post(self, universal_app, task_id):
-        PubHandler.stop(current_user, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
 
         return {'result': 'success'}, 200
 
@@ -125,8 +125,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
             except ModelCurrentlyNotSupportError:
                 yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                    LLMRateLimitError, LLMAuthorizationError) as e:
+            except InvokeError as e:
                 yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
             except ValueError as e:
             except ValueError as e:
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

+ 3 - 4
api/controllers/console/universal_chat/message.py

@@ -12,8 +12,8 @@ from controllers.console.app.error import ProviderNotInitializeError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
-    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from libs.helper import uuid_value, TimestampField
 from services.errors.conversation import ConversationNotExistsError
 from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
@@ -132,8 +132,7 @@ class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except Exception:
             logging.exception("internal server error.")

+ 96 - 176
api/controllers/console/workspace/model_providers.py

@@ -1,16 +1,19 @@
+import io
+
+from flask import send_file
 from flask_login import current_user
-from libs.login import login_required
 from flask_restful import Resource, reqparse
 from werkzeug.exceptions import Forbidden
 
 from controllers.console import api
-from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.model_providers.error import LLMBadRequestError
-from core.model_providers.providers.base import CredentialsValidateFailedError
-from services.provider_service import ProviderService
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.utils.encoders import jsonable_encoder
+from libs.login import login_required
 from services.billing_service import BillingService
+from services.model_provider_service import ModelProviderService
 
 
 class ModelProviderListApi(Resource):
@@ -22,13 +25,36 @@ class ModelProviderListApi(Resource):
         tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
-        parser.add_argument('model_type', type=str, required=False, nullable=True, location='args')
+        parser.add_argument('model_type', type=str, required=False, nullable=True,
+                            choices=[mt.value for mt in ModelType], location='args')
         args = parser.parse_args()
 
-        provider_service = ProviderService()
-        provider_list = provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get('model_type'))
+        model_provider_service = ModelProviderService()
+        provider_list = model_provider_service.get_provider_list(
+            tenant_id=tenant_id,
+            model_type=args.get('model_type')
+        )
+
+        return jsonable_encoder({"data": provider_list})
+
+
+class ModelProviderCredentialApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider: str):
+        tenant_id = current_user.current_tenant_id
+
+        model_provider_service = ModelProviderService()
+        credentials = model_provider_service.get_provider_credentials(
+            tenant_id=tenant_id,
+            provider=provider
+        )
 
-        return provider_list
+        return {
+            "credentials": credentials
+        }
 
 
 class ModelProviderValidateApi(Resource):
@@ -36,21 +62,24 @@ class ModelProviderValidateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider_name: str):
+    def post(self, provider: str):
 
         parser = reqparse.RequestParser()
-        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
         args = parser.parse_args()
 
-        provider_service = ProviderService()
+        tenant_id = current_user.current_tenant_id
+
+        model_provider_service = ModelProviderService()
 
         result = True
         error = None
 
         try:
-            provider_service.custom_provider_config_validate(
-                provider_name=provider_name,
-                config=args['config']
+            model_provider_service.provider_credentials_validate(
+                tenant_id=tenant_id,
+                provider=provider,
+                credentials=args['credentials']
             )
         except CredentialsValidateFailedError as ex:
             result = False
@@ -64,26 +93,26 @@ class ModelProviderValidateApi(Resource):
         return response
 
 
-class ModelProviderUpdateApi(Resource):
+class ModelProviderApi(Resource):
 
     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider_name: str):
+    def post(self, provider: str):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
         args = parser.parse_args()
 
-        provider_service = ProviderService()
+        model_provider_service = ModelProviderService()
 
         try:
-            provider_service.save_custom_provider_config(
+            model_provider_service.save_provider_credentials(
                 tenant_id=current_user.current_tenant_id,
-                provider_name=provider_name,
-                config=args['config']
+                provider=provider,
+                credentials=args['credentials']
             )
         except CredentialsValidateFailedError as ex:
             raise ValueError(str(ex))
@@ -93,109 +122,36 @@ class ModelProviderUpdateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def delete(self, provider_name: str):
+    def delete(self, provider: str):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
 
-        provider_service = ProviderService()
-        provider_service.delete_custom_provider(
+        model_provider_service = ModelProviderService()
+        model_provider_service.remove_provider_credentials(
             tenant_id=current_user.current_tenant_id,
-            provider_name=provider_name
+            provider=provider
         )
 
         return {'result': 'success'}, 204
 
 
-class ModelProviderModelValidateApi(Resource):
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def post(self, provider_name: str):
-        parser = reqparse.RequestParser()
-        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
-        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
-        args = parser.parse_args()
-
-        provider_service = ProviderService()
-
-        result = True
-        error = None
-
-        try:
-            provider_service.custom_provider_model_config_validate(
-                provider_name=provider_name,
-                model_name=args['model_name'],
-                model_type=args['model_type'],
-                config=args['config']
-            )
-        except CredentialsValidateFailedError as ex:
-            result = False
-            error = str(ex)
-
-        response = {'result': 'success' if result else 'error'}
-
-        if not result:
-            response['error'] = error
-
-        return response
-
-
-class ModelProviderModelUpdateApi(Resource):
+class ModelProviderIconApi(Resource):
+    """
+    Get model provider icon
+    """
 
     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider_name: str):
-        if current_user.current_tenant.current_role not in ['admin', 'owner']:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
-        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
-        args = parser.parse_args()
-
-        provider_service = ProviderService()
-
-        try:
-            provider_service.add_or_save_custom_provider_model_config(
-                tenant_id=current_user.current_tenant_id,
-                provider_name=provider_name,
-                model_name=args['model_name'],
-                model_type=args['model_type'],
-                config=args['config']
-            )
-        except CredentialsValidateFailedError as ex:
-            raise ValueError(str(ex))
-
-        return {'result': 'success'}, 200
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def delete(self, provider_name: str):
-        if current_user.current_tenant.current_role not in ['admin', 'owner']:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
-        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
-        args = parser.parse_args()
-
-        provider_service = ProviderService()
-        provider_service.delete_custom_provider_model(
-            tenant_id=current_user.current_tenant_id,
-            provider_name=provider_name,
-            model_name=args['model_name'],
-            model_type=args['model_type']
+    def get(self, provider: str, icon_type: str, lang: str):
+        model_provider_service = ModelProviderService()
+        icon, mimetype = model_provider_service.get_model_provider_icon(
+            provider=provider,
+            icon_type=icon_type,
+            lang=lang
         )
 
-        return {'result': 'success'}, 204
+        return send_file(io.BytesIO(icon), mimetype=mimetype)
 
 
 class PreferredProviderTypeUpdateApi(Resource):
@@ -203,71 +159,36 @@ class PreferredProviderTypeUpdateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider_name: str):
+    def post(self, provider: str):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
 
+        tenant_id = current_user.current_tenant_id
+
         parser = reqparse.RequestParser()
         parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False,
                             choices=['system', 'custom'], location='json')
         args = parser.parse_args()
 
-        provider_service = ProviderService()
-        provider_service.switch_preferred_provider(
-            tenant_id=current_user.current_tenant_id,
-            provider_name=provider_name,
+        model_provider_service = ModelProviderService()
+        model_provider_service.switch_preferred_provider(
+            tenant_id=tenant_id,
+            provider=provider,
             preferred_provider_type=args['preferred_provider_type']
         )
 
         return {'result': 'success'}
 
 
-class ModelProviderModelParameterRuleApi(Resource):
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def get(self, provider_name: str):
-        parser = reqparse.RequestParser()
-        parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
-        args = parser.parse_args()
-
-        provider_service = ProviderService()
-
-        try:
-            parameter_rules = provider_service.get_model_parameter_rules(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=provider_name,
-                model_name=args['model_name'],
-                model_type='text-generation'
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"Current Text Generation Model is invalid. Please switch to the available model.")
-
-        rules = {
-            k: {
-                'enabled': v.enabled,
-                'min': v.min,
-                'max': v.max,
-                'default': v.default,
-                'precision': v.precision
-            }
-            for k, v in vars(parameter_rules).items()
-        }
-
-        return rules
-
-
 class ModelProviderPaymentCheckoutUrlApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def get(self, provider_name: str):
-        if provider_name != 'anthropic':
-            raise ValueError(f'provider name {provider_name} is invalid')
+    def get(self, provider: str):
+        if provider != 'anthropic':
+            raise ValueError(f'provider name {provider} is invalid')
 
-        data = BillingService.get_model_provider_payment_link(provider_name=provider_name,
+        data = BillingService.get_model_provider_payment_link(provider_name=provider,
                                                               tenant_id=current_user.current_tenant_id,
                                                               account_id=current_user.id)
         return data
@@ -277,11 +198,11 @@ class ModelProviderFreeQuotaSubmitApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider_name: str):
-        provider_service = ProviderService()
-        result = provider_service.free_quota_submit(
+    def post(self, provider: str):
+        model_provider_service = ModelProviderService()
+        result = model_provider_service.free_quota_submit(
             tenant_id=current_user.current_tenant_id,
-            provider_name=provider_name
+            provider=provider
         )
 
         return result
@@ -291,15 +212,15 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def get(self, provider_name: str):
+    def get(self, provider: str):
         parser = reqparse.RequestParser()
         parser.add_argument('token', type=str, required=False, nullable=True, location='args')
         args = parser.parse_args()
 
-        provider_service = ProviderService()
-        result = provider_service.free_quota_qualification_verify(
+        model_provider_service = ModelProviderService()
+        result = model_provider_service.free_quota_qualification_verify(
             tenant_id=current_user.current_tenant_id,
-            provider_name=provider_name,
+            provider=provider,
             token=args['token']
         )
 
@@ -307,19 +228,18 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
 
 
 api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
-api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
-api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
-api.add_resource(ModelProviderModelValidateApi,
-                 '/workspaces/current/model-providers/<string:provider_name>/models/validate')
-api.add_resource(ModelProviderModelUpdateApi,
-                 '/workspaces/current/model-providers/<string:provider_name>/models')
+
+api.add_resource(ModelProviderCredentialApi, '/workspaces/current/model-providers/<string:provider>/credentials')
+api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider>/credentials/validate')
+api.add_resource(ModelProviderApi, '/workspaces/current/model-providers/<string:provider>')
+api.add_resource(ModelProviderIconApi, '/workspaces/current/model-providers/<string:provider>/'
+                                       '<string:icon_type>/<string:lang>')
+
 api.add_resource(PreferredProviderTypeUpdateApi,
-                 '/workspaces/current/model-providers/<string:provider_name>/preferred-provider-type')
-api.add_resource(ModelProviderModelParameterRuleApi,
-                 '/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules')
+                 '/workspaces/current/model-providers/<string:provider>/preferred-provider-type')
 api.add_resource(ModelProviderPaymentCheckoutUrlApi,
-                 '/workspaces/current/model-providers/<string:provider_name>/checkout-url')
+                 '/workspaces/current/model-providers/<string:provider>/checkout-url')
 api.add_resource(ModelProviderFreeQuotaSubmitApi,
-                 '/workspaces/current/model-providers/<string:provider_name>/free-quota-submit')
+                 '/workspaces/current/model-providers/<string:provider>/free-quota-submit')
 api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi,
-                 '/workspaces/current/model-providers/<string:provider_name>/free-quota-qualification-verify')
+                 '/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify')

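The route table above renames the URL parameter from provider_name to provider and splits credential reads, validation, and writes across dedicated endpoints. A client-side sketch of the resulting round trip, assuming the console API mount point, auth header, provider name, and key values shown here, none of which come from the diff:

import requests

BASE = 'http://localhost:5001/console/api'  # assumed mount point
session = requests.Session()
session.headers['Authorization'] = 'Bearer <console-token>'  # placeholder

payload = {'credentials': {'openai_api_key': 'sk-...'}}  # illustrative

# 1. Dry-run the credentials (ModelProviderValidateApi).
check = session.post(
    f'{BASE}/workspaces/current/model-providers/openai/credentials/validate',
    json=payload)
print(check.json())  # {'result': 'success'} or {'result': 'error', 'error': ...}

# 2. Persist them (ModelProviderApi POST).
session.post(f'{BASE}/workspaces/current/model-providers/openai', json=payload)

# 3. Read them back (ModelProviderCredentialApi GET).
print(session.get(
    f'{BASE}/workspaces/current/model-providers/openai/credentials').json())
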
+ 212 - 56
api/controllers/console/workspace/models.py

@@ -1,16 +1,17 @@
 import logging
 
 from flask_login import current_user
-from libs.login import login_required
-from flask_restful import Resource, reqparse
+from flask_restful import reqparse, Resource
+from werkzeug.exceptions import Forbidden
 
 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.model_providers.model_provider_factory import ModelProviderFactory
-from core.model_providers.models.entity.model_params import ModelType
-from models.provider import ProviderType
-from services.provider_service import ProviderService
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.utils.encoders import jsonable_encoder
+from libs.login import login_required
+from services.model_provider_service import ModelProviderService
 
 
 class DefaultModelApi(Resource):
@@ -21,52 +22,20 @@ class DefaultModelApi(Resource):
     def get(self):
         parser = reqparse.RequestParser()
         parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
+                            choices=[mt.value for mt in ModelType], location='args')
         args = parser.parse_args()
 
         tenant_id = current_user.current_tenant_id
 
-        provider_service = ProviderService()
-        default_model = provider_service.get_default_model_of_model_type(
+        model_provider_service = ModelProviderService()
+        default_model_entity = model_provider_service.get_default_model_of_model_type(
             tenant_id=tenant_id,
             model_type=args['model_type']
         )
 
-        if not default_model:
-            return None
-
-        model_provider = ModelProviderFactory.get_preferred_model_provider(
-            tenant_id,
-            default_model.provider_name
-        )
-
-        if not model_provider:
-            return {
-                'model_name': default_model.model_name,
-                'model_type': default_model.model_type,
-                'model_provider': {
-                    'provider_name': default_model.provider_name
-                }
-            }
-
-        provider = model_provider.provider
-        rst = {
-            'model_name': default_model.model_name,
-            'model_type': default_model.model_type,
-            'model_provider': {
-                'provider_name': provider.provider_name,
-                'provider_type': provider.provider_type
-            }
-        }
-
-        model_provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name)
-        if provider.provider_type == ProviderType.SYSTEM.value:
-            rst['model_provider']['quota_type'] = provider.quota_type
-            rst['model_provider']['quota_unit'] = model_provider_rules['system_config']['quota_unit']
-            rst['model_provider']['quota_limit'] = provider.quota_limit
-            rst['model_provider']['quota_used'] = provider.quota_used
-
-        return rst
+        return jsonable_encoder({
+            "data": default_model_entity
+        })
 
     @setup_required
     @login_required
@@ -76,15 +45,26 @@ class DefaultModelApi(Resource):
         parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
         args = parser.parse_args()
 
-        provider_service = ProviderService()
+        tenant_id = current_user.current_tenant_id
+
+        model_provider_service = ModelProviderService()
         model_settings = args['model_settings']
         for model_setting in model_settings:
+            if 'model_type' not in model_setting or model_setting['model_type'] not in [mt.value for mt in ModelType]:
+                raise ValueError('invalid model type')
+
+            if 'provider' not in model_setting:
+                continue
+
+            if 'model' not in model_setting:
+                raise ValueError('invalid model')
+
             try:
-                provider_service.update_default_model_of_model_type(
-                    tenant_id=current_user.current_tenant_id,
+                model_provider_service.update_default_model_of_model_type(
+                    tenant_id=tenant_id,
                     model_type=model_setting['model_type'],
-                    provider_name=model_setting['provider_name'],
-                    model_name=model_setting['model_name']
+                    provider=model_setting['provider'],
+                    model=model_setting['model']
                 )
             except Exception:
                 logging.warning(f"{model_setting['model_type']} save error")
@@ -92,22 +72,198 @@ class DefaultModelApi(Resource):
         return {'result': 'success'}
 
 
-class ValidModelApi(Resource):
+class ModelProviderModelApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider):
+        tenant_id = current_user.current_tenant_id
+
+        model_provider_service = ModelProviderService()
+        models = model_provider_service.get_models_by_provider(
+            tenant_id=tenant_id,
+            provider=provider
+        )
+
+        return jsonable_encoder({
+            "data": models
+        })
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider: str):
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=[mt.value for mt in ModelType], location='json')
+        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        model_provider_service = ModelProviderService()
+
+        try:
+            model_provider_service.save_model_credentials(
+                tenant_id=tenant_id,
+                provider=provider,
+                model=args['model'],
+                model_type=args['model_type'],
+                credentials=args['credentials']
+            )
+        except CredentialsValidateFailedError as ex:
+            raise ValueError(str(ex))
+
+        return {'result': 'success'}, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, provider: str):
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=[mt.value for mt in ModelType], location='json')
+        args = parser.parse_args()
+
+        model_provider_service = ModelProviderService()
+        model_provider_service.remove_model_credentials(
+            tenant_id=tenant_id,
+            provider=provider,
+            model=args['model'],
+            model_type=args['model_type']
+        )
+
+        return {'result': 'success'}, 204
+
+
+class ModelProviderModelCredentialApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider: str):
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model', type=str, required=True, nullable=False, location='args')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=[mt.value for mt in ModelType], location='args')
+        args = parser.parse_args()
+
+        model_provider_service = ModelProviderService()
+        credentials = model_provider_service.get_model_credentials(
+            tenant_id=tenant_id,
+            provider=provider,
+            model_type=args['model_type'],
+            model=args['model']
+        )
+
+        return {
+            "credentials": credentials
+        }
+
+
+class ModelProviderModelValidateApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider: str):
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=[mt.value for mt in ModelType], location='json')
+        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        model_provider_service = ModelProviderService()
+
+        result = True
+        error = None
+
+        try:
+            model_provider_service.model_credentials_validate(
+                tenant_id=tenant_id,
+                provider=provider,
+                model=args['model'],
+                model_type=args['model_type'],
+                credentials=args['credentials']
+            )
+        except CredentialsValidateFailedError as ex:
+            result = False
+            error = str(ex)
+
+        response = {'result': 'success' if result else 'error'}
+
+        if not result:
+            response['error'] = error
+
+        return response
+
+
+class ModelProviderModelParameterRuleApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider: str):
+        parser = reqparse.RequestParser()
+        parser.add_argument('model', type=str, required=True, nullable=False, location='args')
+        args = parser.parse_args()
+
+        tenant_id = current_user.current_tenant_id
+
+        model_provider_service = ModelProviderService()
+        parameter_rules = model_provider_service.get_model_parameter_rules(
+            tenant_id=tenant_id,
+            provider=provider,
+            model=args['model']
+        )
+
+        return jsonable_encoder({
+            "data": parameter_rules
+        })
+
+
+class ModelProviderAvailableModelApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, model_type):
-        ModelType.value_of(model_type)
+        tenant_id = current_user.current_tenant_id

-        provider_service = ProviderService()
-        valid_models = provider_service.get_valid_model_list(
-            tenant_id=current_user.current_tenant_id,
+        model_provider_service = ModelProviderService()
+        models = model_provider_service.get_models_by_model_type(
+            tenant_id=tenant_id,
            model_type=model_type
        )

-        return valid_models
+        return jsonable_encoder({
+            "data": models
+        })
+

+api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers/<string:provider>/models')
+api.add_resource(ModelProviderModelCredentialApi,
+                 '/workspaces/current/model-providers/<string:provider>/models/credentials')
+api.add_resource(ModelProviderModelValidateApi,
+                 '/workspaces/current/model-providers/<string:provider>/models/credentials/validate')

+api.add_resource(ModelProviderModelParameterRuleApi,
+                 '/workspaces/current/model-providers/<string:provider>/models/parameter-rules')
+api.add_resource(ModelProviderAvailableModelApi, '/workspaces/current/models/model-types/<string:model_type>')
api.add_resource(DefaultModelApi, '/workspaces/current/default-model')
-api.add_resource(ValidModelApi, '/workspaces/current/models/model-type/<string:model_type>')
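The routes registered above replace both the old `ValidModelApi` route and the per-provider token endpoints deleted in the next file. A minimal sketch of driving the new endpoints with `requests`; the host, port, and bearer header are illustrative assumptions (the console API actually authenticates through the web session), while the paths and JSON shapes come straight from the hunks above:

```python
import requests

# Hypothetical host and auth; the console API's real session-based login
# is not shown here.
BASE = "http://localhost:5001/console/api"
HEADERS = {"Authorization": "Bearer <console-token>"}  # placeholder

# Save credentials for one model under a provider (ModelProviderModelApi.post)
resp = requests.post(
    f"{BASE}/workspaces/current/model-providers/openai/models",
    headers=HEADERS,
    json={
        "model": "gpt-3.5-turbo",            # illustrative model name
        "model_type": "llm",                 # must be a ModelType value
        "credentials": {"openai_api_key": "sk-..."},
    },
)
assert resp.json() == {"result": "success"}

# Validate credentials without saving them (ModelProviderModelValidateApi.post)
check = requests.post(
    f"{BASE}/workspaces/current/model-providers/openai/models/credentials/validate",
    headers=HEADERS,
    json={"model": "gpt-3.5-turbo", "model_type": "llm",
          "credentials": {"openai_api_key": "sk-..."}},
).json()
print(check["result"], check.get("error"))

# List active models of one type (ModelProviderAvailableModelApi.get)
models = requests.get(
    f"{BASE}/workspaces/current/models/model-types/llm",
    headers=HEADERS,
).json()["data"]
```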

+ 0 - 131
api/controllers/console/workspace/providers.py

@@ -1,131 +0,0 @@
-# -*- coding:utf-8 -*-
-from flask_login import current_user
-from libs.login import login_required
-from flask_restful import Resource, reqparse
-from werkzeug.exceptions import Forbidden
-
-from controllers.console import api
-from controllers.console.setup import setup_required
-from controllers.console.wraps import account_initialization_required
-from core.model_providers.providers.base import CredentialsValidateFailedError
-from models.provider import ProviderType
-from services.provider_service import ProviderService
-
-
-class ProviderListApi(Resource):
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def get(self):
-        tenant_id = current_user.current_tenant_id
-
-        """
-        If the type is AZURE_OPENAI, decode and return the four fields azure_api_type, azure_api_version,
-        azure_api_base and azure_api_key as an object, where azure_api_key shows the first 6 characters in
-        plaintext, the rest is replaced by * and the last two characters are shown in plaintext
-
-        If the type is other, decode and return the token field directly; the field shows the first 6 characters
-        in plaintext, the rest is replaced by * and the last two characters are shown in plaintext
-        """
-
-        provider_service = ProviderService()
-        provider_info_list = provider_service.get_provider_list(tenant_id)
-
-        provider_list = [
-            {
-                'provider_name': p['provider_name'],
-                'provider_type': p['provider_type'],
-                'is_valid': p['is_valid'],
-                'last_used': p['last_used'],
-                'is_enabled': p['is_valid'],
-                **({
-                       'quota_type': p['quota_type'],
-                       'quota_limit': p['quota_limit'],
-                       'quota_used': p['quota_used']
-                   } if p['provider_type'] == ProviderType.SYSTEM.value else {}),
-                'token': (p['config'] if p['provider_name'] != 'openai' else p['config']['openai_api_key'])
-                        if p['config'] else None
-            }
-            for name, provider_info in provider_info_list.items()
-            for p in provider_info['providers']
-        ]
-
-        return provider_list
-
-
-class ProviderTokenApi(Resource):
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def post(self, provider):
-        # The role of the current user in the ta table must be admin or owner
-        if current_user.current_tenant.current_role not in ['admin', 'owner']:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument('token', required=True, nullable=False, location='json')
-        args = parser.parse_args()
-
-        if provider == 'openai':
-            args['token'] = {
-                'openai_api_key': args['token']
-            }
-
-        provider_service = ProviderService()
-        try:
-            provider_service.save_custom_provider_config(
-                tenant_id=current_user.current_tenant_id,
-                provider_name=provider,
-                config=args['token']
-            )
-        except CredentialsValidateFailedError as ex:
-            raise ValueError(str(ex))
-
-        return {'result': 'success'}, 201
-
-
-class ProviderTokenValidateApi(Resource):
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def post(self, provider):
-        parser = reqparse.RequestParser()
-        parser.add_argument('token', required=True, nullable=False, location='json')
-        args = parser.parse_args()
-
-        provider_service = ProviderService()
-
-        if provider == 'openai':
-            args['token'] = {
-                'openai_api_key': args['token']
-            }
-
-        result = True
-        error = None
-
-        try:
-            provider_service.custom_provider_config_validate(
-                provider_name=provider,
-                config=args['token']
-            )
-        except CredentialsValidateFailedError as ex:
-            result = False
-            error = str(ex)
-
-        response = {'result': 'success' if result else 'error'}
-
-        if not result:
-            response['error'] = error
-
-        return response
-
-
-api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
-                 endpoint='workspaces_current_providers_token')  # PUT for updating provider token
-api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
-                 endpoint='workspaces_current_providers_token_validate')  # POST for validating provider token
-
-api.add_resource(ProviderListApi, '/workspaces/current/providers')  # GET for getting providers list

+ 0 - 1
api/controllers/console/workspace/workspace.py

@@ -34,7 +34,6 @@ tenant_fields = {
    'status': fields.String,
    'created_at': TimestampField,
    'role': fields.String,
-    'providers': fields.List(fields.Nested(provider_fields)),
    'in_trial': fields.Boolean,
    'trial_end_reason': fields.String,
    'custom_config': fields.Raw(attribute='custom_config'),

+ 3 - 4
api/controllers/service_api/app/audio.py

@@ -9,8 +9,8 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
    ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \
    ProviderNotSupportSpeechToTextError
from controllers.service_api.wraps import AppApiResource
-from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
from models.model import App, AppModelConfig
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
@@ -49,8 +49,7 @@ class AudioApi(AppApiResource):
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
            raise CompletionRequestError(str(e))
        except ValueError as e:
            raise e
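The same try/except collapse recurs in every controller touched below: five provider-specific `LLM*` exceptions are folded into the single `InvokeError` base class exported by the new model runtime. A self-contained sketch of the pattern, using local stand-in classes rather than the real imports:

```python
class InvokeError(Exception):
    """Stand-in for core.model_runtime.errors.invoke.InvokeError."""

class InvokeRateLimitError(InvokeError):
    """One of the concrete subclasses (auth, connection, rate limit, ...)."""

class CompletionRequestError(Exception):
    """Stand-in for the controller-facing error."""

def run(invoke):
    try:
        return invoke()
    except InvokeError as e:
        # every runtime failure now funnels through one base class
        raise CompletionRequestError(str(e))

# usage: any InvokeError subclass maps to CompletionRequestError
def failing_call():
    raise InvokeRateLimitError("rate limited")

try:
    run(failing_call)
except CompletionRequestError as e:
    print(e)  # -> rate limited
```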

+ 11 - 13
api/controllers/service_api/app/completion.py

@@ -13,9 +13,10 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
    ConversationCompletedError, CompletionRequestError, ProviderQuotaExceededError, \
    ProviderModelCurrentlyNotSupportError
from controllers.service_api.wraps import AppApiResource
-from core.conversation_message_task import PubHandler
-from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from services.completion_service import CompletionService

@@ -47,7 +48,7 @@ class CompletionApi(AppApiResource):
                app_model=app_model,
                user=end_user,
                args=args,
-                from_source='api',
+                invoke_from=InvokeFrom.SERVICE_API,
                streaming=streaming,
            )

@@ -65,8 +66,7 @@ class CompletionApi(AppApiResource):
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
            raise CompletionRequestError(str(e))
        except ValueError as e:
            raise e
@@ -80,7 +80,7 @@ class CompletionStopApi(AppApiResource):
        if app_model.mode != 'completion':
            raise AppUnavailableError()

-        PubHandler.stop(end_user, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

        return {'result': 'success'}, 200

@@ -112,7 +112,7 @@ class ChatApi(AppApiResource):
                app_model=app_model,
                user=end_user,
                args=args,
-                from_source='api',
+                invoke_from=InvokeFrom.SERVICE_API,
                streaming=streaming
            )

@@ -130,8 +130,7 @@ class ChatApi(AppApiResource):
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
            raise CompletionRequestError(str(e))
        except ValueError as e:
            raise e
@@ -145,7 +144,7 @@ class ChatStopApi(AppApiResource):
        if app_model.mode != 'chat':
            raise NotChatAppError()

-        PubHandler.stop(end_user, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

        return {'result': 'success'}, 200

@@ -171,8 +170,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
             except ModelCurrentlyNotSupportError:
                 yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                    LLMRateLimitError, LLMAuthorizationError) as e:
+            except InvokeError as e:
                 yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
             except ValueError as e:
             except ValueError as e:
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

+ 15 - 7
api/controllers/service_api/dataset/dataset.py

@@ -4,11 +4,11 @@ import services.dataset_service
from controllers.service_api import api
from controllers.service_api.dataset.error import DatasetNameDuplicateError
from controllers.service_api.wraps import DatasetApiResource
+from core.model_runtime.entities.model_entities import ModelType
+from core.provider_manager import ProviderManager
from libs.login import current_user
-from core.model_providers.models.entity.model_params import ModelType
from fields.dataset_fields import dataset_detail_fields
from services.dataset_service import DatasetService
-from services.provider_service import ProviderService


def _validate_name(name):
@@ -27,12 +27,20 @@ class DatasetApi(DatasetApiResource):
        datasets, total = DatasetService.get_datasets(page, limit, provider,
                                                      tenant_id, current_user)
        # check embedding setting
-        provider_service = ProviderService()
-        valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
-                                                                 ModelType.EMBEDDINGS.value)
+        provider_manager = ProviderManager()
+        configurations = provider_manager.get_configurations(
+            tenant_id=current_user.current_tenant_id
+        )
+
+        embedding_models = configurations.get_models(
+            model_type=ModelType.TEXT_EMBEDDING,
+            only_active=True
+        )
+
        model_names = []
-        for valid_model in valid_model_list:
-            model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
+        for embedding_model in embedding_models:
+            model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
+
        data = marshal(datasets, dataset_detail_fields)
        for item in data:
            if item['indexing_technique'] == 'high_quality':
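The embedding-model lookup above moves from `ProviderService.get_valid_model_list` to the new `ProviderManager`. A sketch of that lookup extracted into a helper; the helper itself is editorial, but every call and attribute (`get_configurations`, `get_models`, `.model`, `.provider.provider`) is taken from the hunk:

```python
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager

def active_embedding_model_names(tenant_id: str) -> list:
    configurations = ProviderManager().get_configurations(tenant_id=tenant_id)
    embedding_models = configurations.get_models(
        model_type=ModelType.TEXT_EMBEDDING,
        only_active=True
    )
    # "model:provider" strings, the format the controller compares against
    return [f"{m.model}:{m.provider.provider}" for m in embedding_models]
```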

+ 1 - 1
api/controllers/service_api/dataset/document.py

@@ -13,7 +13,7 @@ from controllers.service_api.dataset.error import ArchivedDocumentImmutableError
    NoFileUploadedError, TooManyFilesError
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
from libs.login import current_user
-from core.model_providers.error import ProviderTokenNotInitError
+from core.errors.error import ProviderTokenNotInitError
from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
from models.dataset import Dataset, Document, DocumentSegment

+ 18 - 11
api/controllers/service_api/dataset/segment.py

@@ -4,8 +4,9 @@ from werkzeug.exceptions import NotFound
from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
-from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
-from core.model_providers.model_factory import ModelFactory
+from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from fields.segment_fields import segment_fields
from models.dataset import Dataset, DocumentSegment
@@ -35,10 +36,12 @@ class SegmentApi(DatasetApiResource):
        # check embedding model setting
        if dataset.indexing_technique == 'high_quality':
            try:
-                ModelFactory.get_embedding_model(
+                model_manager = ModelManager()
+                model_manager.get_model_instance(
                    tenant_id=current_user.current_tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
@@ -77,10 +80,12 @@ class SegmentApi(DatasetApiResource):
        # check embedding model setting
        if dataset.indexing_technique == 'high_quality':
            try:
-                ModelFactory.get_embedding_model(
+                model_manager = ModelManager()
+                model_manager.get_model_instance(
                    tenant_id=current_user.current_tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
@@ -167,10 +172,12 @@ class DatasetSegmentApi(DatasetApiResource):
        if dataset.indexing_technique == 'high_quality':
            # check embedding model setting
            try:
-                ModelFactory.get_embedding_model(
+                model_manager = ModelManager()
+                model_manager.get_model_instance(
                    tenant_id=current_user.current_tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
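The same embedding-model availability check appears three times in this file. A sketch of it factored into one helper; the helper is editorial, while the argument names and the `LLMBadRequestError` contract are exactly those in the hunks:

```python
from core.errors.error import LLMBadRequestError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType

def ensure_embedding_model(tenant_id: str, dataset) -> None:
    """Resolve the dataset's embedding model; get_model_instance raises
    LLMBadRequestError upstream when the tenant has no usable credentials."""
    ModelManager().get_model_instance(
        tenant_id=tenant_id,
        provider=dataset.embedding_model_provider,
        model_type=ModelType.TEXT_EMBEDDING,
        model=dataset.embedding_model
    )
```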

+ 3 - 4
api/controllers/web/audio.py

@@ -10,8 +10,8 @@ from controllers.web.error import AppUnavailableError, ProviderNotInitializeErro
    ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
    UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.web.wraps import WebApiResource
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
    UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
@@ -51,8 +51,7 @@ class AudioApi(WebApiResource):
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
            raise CompletionRequestError(str(e))
        except ValueError as e:
            raise e

+ 11 - 13
api/controllers/web/completion.py

@@ -13,9 +13,10 @@ from controllers.web.error import AppUnavailableError, ConversationCompletedErro
    ProviderNotInitializeError, NotChatAppError, NotCompletionAppError, CompletionRequestError, \
    ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.web.wraps import WebApiResource
-from core.conversation_message_task import PubHandler
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from services.completion_service import CompletionService

@@ -44,7 +45,7 @@ class CompletionApi(WebApiResource):
                app_model=app_model,
                user=end_user,
                args=args,
-                from_source='api',
+                invoke_from=InvokeFrom.WEB_APP,
                streaming=streaming
            )

@@ -62,8 +63,7 @@ class CompletionApi(WebApiResource):
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
            raise CompletionRequestError(str(e))
        except ValueError as e:
            raise e
@@ -77,7 +77,7 @@ class CompletionStopApi(WebApiResource):
        if app_model.mode != 'completion':
            raise NotCompletionAppError()

-        PubHandler.stop(end_user, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)

        return {'result': 'success'}, 200

@@ -105,7 +105,7 @@ class ChatApi(WebApiResource):
                app_model=app_model,
                user=end_user,
                args=args,
-                from_source='api',
+                invoke_from=InvokeFrom.WEB_APP,
                streaming=streaming
            )

@@ -123,8 +123,7 @@ class ChatApi(WebApiResource):
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
            raise CompletionRequestError(str(e))
        except ValueError as e:
            raise e
@@ -138,7 +137,7 @@ class ChatStopApi(WebApiResource):
        if app_model.mode != 'chat':
            raise NotChatAppError()

-        PubHandler.stop(end_user, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)

        return {'result': 'success'}, 200

@@ -164,8 +163,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
             except ModelCurrentlyNotSupportError:
                 yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                    LLMRateLimitError, LLMAuthorizationError) as e:
+            except InvokeError as e:
                 yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
             except ValueError as e:
             except ValueError as e:
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

+ 14 - 9
api/controllers/web/message.py

@@ -14,8 +14,9 @@ from controllers.web.error import NotChatAppError, CompletionRequestError, Provi
    AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
    ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.web.wraps import WebApiResource
-from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
-    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value, TimestampField
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
@@ -117,7 +118,14 @@ class MessageMoreLikeThisApi(WebApiResource):
        streaming = args['response_mode'] == 'streaming'

        try:
-            response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app')
+            response = CompletionService.generate_more_like_this(
+                app_model=app_model,
+                user=end_user,
+                message_id=message_id,
+                invoke_from=InvokeFrom.WEB_APP,
+                streaming=streaming
+            )
+
            return compact_response(response)
        except MessageNotExistsError:
            raise NotFound("Message Not Exists.")
@@ -129,8 +137,7 @@ class MessageMoreLikeThisApi(WebApiResource):
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
            raise CompletionRequestError(str(e))
        except ValueError as e:
            raise e
@@ -157,8 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
             except ModelCurrentlyNotSupportError:
                 yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                    LLMRateLimitError, LLMAuthorizationError) as e:
+            except InvokeError as e:
                 yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
             except ValueError as e:
             except ValueError as e:
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
@@ -195,8 +201,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
            raise CompletionRequestError(str(e))
        except Exception:
            logging.exception("internal server error.")

+ 101 - 0
api/core/agent/agent/agent_llm_callback.py

@@ -0,0 +1,101 @@
+import logging
+from typing import Optional, List
+
+from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
+from core.model_runtime.callbacks.base_callback import Callback
+from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult
+from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
+from core.model_runtime.model_providers.__base.ai_model import AIModel
+
+logger = logging.getLogger(__name__)
+
+
+class AgentLLMCallback(Callback):
+
+    def __init__(self, agent_callback: AgentLoopGatherCallbackHandler) -> None:
+        self.agent_callback = agent_callback
+
+    def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
+                         prompt_messages: list[PromptMessage], model_parameters: dict,
+                         tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                         stream: bool = True, user: Optional[str] = None) -> None:
+        """
+        Before invoke callback
+
+        :param llm_instance: LLM instance
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        """
+        self.agent_callback.on_llm_before_invoke(
+            prompt_messages=prompt_messages
+        )
+
+    def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
+                     prompt_messages: list[PromptMessage], model_parameters: dict,
+                     tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                     stream: bool = True, user: Optional[str] = None):
+        """
+        On new chunk callback
+
+        :param llm_instance: LLM instance
+        :param chunk: chunk
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        """
+        pass
+
+    def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
+                        prompt_messages: list[PromptMessage], model_parameters: dict,
+                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                        stream: bool = True, user: Optional[str] = None) -> None:
+        """
+        After invoke callback
+
+        :param llm_instance: LLM instance
+        :param result: result
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        """
+        self.agent_callback.on_llm_after_invoke(
+            result=result
+        )
+
+    def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
+                        prompt_messages: list[PromptMessage], model_parameters: dict,
+                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                        stream: bool = True, user: Optional[str] = None) -> None:
+        """
+        Invoke error callback
+
+        :param llm_instance: LLM instance
+        :param ex: exception
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        """
+        self.agent_callback.on_llm_error(
+            error=ex
+        )
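`AgentLLMCallback` is a thin adapter: it receives the model runtime's before/after/error callback events and forwards them to the agent loop handler, leaving streamed chunks as a no-op. A minimal wiring sketch; the `AgentLoopGatherCallbackHandler` constructor arguments are elided because they are not part of this file:

```python
from core.agent.agent.agent_llm_callback import AgentLLMCallback

def build_llm_callback(agent_callback) -> AgentLLMCallback:
    # agent_callback is an AgentLoopGatherCallbackHandler instance; the
    # adapter delegates on_before_invoke / on_after_invoke / on_invoke_error
    # to it and ignores on_new_chunk.
    return AgentLLMCallback(agent_callback=agent_callback)
```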

+ 33 - 12
api/core/agent/agent/calc_token_mixin.py

@@ -1,28 +1,49 @@
-from typing import List
+from typing import List, cast

from langchain.schema import BaseMessage

-from core.model_providers.models.entity.message import to_prompt_messages
-from core.model_providers.models.llm.base import BaseLLM
+from core.entities.application_entities import ModelConfigEntity
+from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.model_runtime.entities.message_entities import PromptMessage
+from core.model_runtime.entities.model_entities import ModelPropertyKey
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel


class CalcTokenMixin:

-    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
-        return model_instance.get_num_tokens(to_prompt_messages(messages))
-
-    def get_message_rest_tokens(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
+    def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int:
        """
        Got the rest tokens available for the model after excluding messages tokens and completion max tokens

-        :param llm:
+        :param model_config:
        :param messages:
        :return:
        """
-        llm_max_tokens = model_instance.model_rules.max_tokens.max
-        completion_max_tokens = model_instance.model_kwargs.max_tokens
-        used_tokens = self.get_num_tokens_from_messages(model_instance, messages, **kwargs)
-        rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
+        model_type_instance = model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+
+        max_tokens = 0
+        for parameter_rule in model_config.model_schema.parameter_rules:
+            if (parameter_rule.name == 'max_tokens'
+                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
+                max_tokens = (model_config.parameters.get(parameter_rule.name)
+                              or model_config.parameters.get(parameter_rule.use_template)) or 0
+
+        if model_context_tokens is None:
+            return 0
+
+        if max_tokens is None:
+            max_tokens = 0
+
+        prompt_tokens = model_type_instance.get_num_tokens(
+            model_config.model,
+            model_config.credentials,
+            messages
+        )
+
+        rest_tokens = model_context_tokens - max_tokens - prompt_tokens

        return rest_tokens

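A worked instance of the budget this method now computes, rest_tokens = context_size - max_tokens - prompt_tokens; the numbers are illustrative only:

```python
context_size = 4096   # ModelPropertyKey.CONTEXT_SIZE for the model
max_tokens = 512      # completion budget from the 'max_tokens' parameter rule
prompt_tokens = 3000  # get_num_tokens(...) over the prompt messages

rest_tokens = context_size - max_tokens - prompt_tokens
assert rest_tokens == 584  # room left before summarization must kick in
```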

+ 38 - 16
api/core/agent/agent/multi_dataset_router_agent.py

@@ -1,4 +1,3 @@
-import json
from typing import Tuple, List, Any, Union, Sequence, Optional, cast

from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
@@ -6,13 +5,14 @@ from langchain.agents.openai_functions_agent.base import _format_intermediate_st
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
-from langchain.schema.language_model import BaseLanguageModel
+from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage
from langchain.tools import BaseTool
from pydantic import root_validator

-from core.model_providers.models.entity.message import to_prompt_messages
-from core.model_providers.models.llm.base import BaseLLM
+from core.entities.application_entities import ModelConfigEntity
+from core.model_manager import ModelInstance
+from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.model_runtime.entities.message_entities import PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM


@@ -20,7 +20,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
     """
     """
     An Multi Dataset Retrieve Agent driven by Router.
     An Multi Dataset Retrieve Agent driven by Router.
     """
     """
-    model_instance: BaseLLM
+    model_config: ModelConfigEntity

    class Config:
        """Configuration for this pydantic object."""
@@ -81,8 +81,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
                agent_decision.return_values['output'] = ''
            return agent_decision
        except Exception as e:
-            new_exception = self.model_instance.handle_exceptions(e)
-            raise new_exception
+            raise e

    def real_plan(
        self,
@@ -106,16 +105,39 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()
-        prompt_messages = to_prompt_messages(messages)
-        result = self.model_instance.run(
-            messages=prompt_messages,
-            functions=self.functions,
+        prompt_messages = lc_messages_to_prompt_messages(messages)
+
+        model_instance = ModelInstance(
+            provider_model_bundle=self.model_config.provider_model_bundle,
+            model=self.model_config.model,
+        )
+
+        tools = []
+        for function in self.functions:
+            tool = PromptMessageTool(
+                **function
+            )
+
+            tools.append(tool)
+
+        result = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            tools=tools,
+            stream=False,
+            model_parameters={
+                'temperature': 0.2,
+                'top_p': 0.3,
+                'max_tokens': 1500
+            }
        )

        ai_message = AIMessage(
-            content=result.content,
+            content=result.message.content or "",
            additional_kwargs={
-                'function_call': result.function_call
+                'function_call': {
+                    'id': result.message.tool_calls[0].id,
+                    **result.message.tool_calls[0].function.dict()
+                } if result.message.tool_calls else None
            }
        )

@@ -133,7 +155,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
    @classmethod
    def from_llm_and_tools(
            cls,
-            model_instance: BaseLLM,
+            model_config: ModelConfigEntity,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
@@ -147,7 +169,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
            system_message=system_message,
        )
        return cls(
-            model_instance=model_instance,
+            model_config=model_config,
            llm=FakeLLM(response=''),
            prompt=prompt,
            tools=tools,
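`real_plan` above now unpacks each OpenAI-style function dict into a `PromptMessageTool` before calling `invoke_llm`. The conversion in isolation, with an illustrative function definition (the import is the one added by this hunk):

```python
from core.model_runtime.entities.message_entities import PromptMessageTool

# an OpenAI-style function definition, as would appear in self.functions
functions = [{
    "name": "dataset_router",
    "description": "Route the query to the most relevant dataset.",
    "parameters": {"type": "object", "properties": {}, "required": []},
}]

# each dict's name/description/parameters become PromptMessageTool fields
tools = [PromptMessageTool(**function) for function in functions]
```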

+ 91 - 34
api/core/agent/agent/openai_function_call.py

@@ -1,4 +1,4 @@
-from typing import List, Tuple, Any, Union, Sequence, Optional
+from typing import List, Tuple, Any, Union, Sequence, Optional, cast

from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
@@ -13,18 +13,23 @@ from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage,
from langchain.tools import BaseTool
from pydantic import root_validator

+from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
from core.chain.llm_chain import LLMChain
-from core.model_providers.models.entity.message import to_prompt_messages
-from core.model_providers.models.llm.base import BaseLLM
+from core.entities.application_entities import ModelConfigEntity
+from core.model_manager import ModelInstance
+from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.third_party.langchain.llms.fake import FakeLLM


class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
-    summary_model_instance: BaseLLM = None
-    model_instance: BaseLLM
+    summary_model_config: ModelConfigEntity = None
+    model_config: ModelConfigEntity
+    agent_llm_callback: Optional[AgentLLMCallback] = None

    class Config:
        """Configuration for this pydantic object."""
@@ -38,13 +43,14 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
    @classmethod
    def from_llm_and_tools(
            cls,
-            model_instance: BaseLLM,
+            model_config: ModelConfigEntity,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
            system_message: Optional[SystemMessage] = SystemMessage(
                content="You are a helpful AI assistant."
            ),
+            agent_llm_callback: Optional[AgentLLMCallback] = None,
            **kwargs: Any,
    ) -> BaseSingleActionAgent:
        prompt = cls.create_prompt(
@@ -52,11 +58,12 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
            system_message=system_message,
        )
        return cls(
-            model_instance=model_instance,
+            model_config=model_config,
            llm=FakeLLM(response=''),
            prompt=prompt,
            tools=tools,
            callback_manager=callback_manager,
+            agent_llm_callback=agent_llm_callback,
             **kwargs,
             **kwargs,
         )
         )
 
 
@@ -67,28 +74,49 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
         :param query:
         :return:
         """
-        original_max_tokens = self.model_instance.model_kwargs.max_tokens
-        self.model_instance.model_kwargs.max_tokens = 40
+        original_max_tokens = 0
+        for parameter_rule in self.model_config.model_schema.parameter_rules:
+            if (parameter_rule.name == 'max_tokens'
+                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
+                original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
+                              or self.model_config.parameters.get(parameter_rule.use_template)) or 0
+
+        self.model_config.parameters['max_tokens'] = 40
 
         prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
         messages = prompt.to_messages()
 
         try:
-            prompt_messages = to_prompt_messages(messages)
-            result = self.model_instance.run(
-                messages=prompt_messages,
-                functions=self.functions,
-                callbacks=None
+            prompt_messages = lc_messages_to_prompt_messages(messages)
+            model_instance = ModelInstance(
+                provider_model_bundle=self.model_config.provider_model_bundle,
+                model=self.model_config.model,
             )
-        except Exception as e:
-            new_exception = self.model_instance.handle_exceptions(e)
-            raise new_exception
 
-        function_call = result.function_call
+            tools = []
+            for function in self.functions:
+                tool = PromptMessageTool(
+                    **function
+                )
+
+                tools.append(tool)
+
+            result = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                tools=tools,
+                stream=False,
+                model_parameters={
+                    'temperature': 0.2,
+                    'top_p': 0.3,
+                    'max_tokens': 1500
+                }
+            )
+        except Exception as e:
+            raise e
 
-        self.model_instance.model_kwargs.max_tokens = original_max_tokens
+        self.model_config.parameters['max_tokens'] = original_max_tokens
 
-        return True if function_call else False
+        return True if result.message.tool_calls else False
 
     def plan(
             self,
@@ -113,22 +141,46 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
         prompt = self.prompt.format_prompt(**full_inputs)
         messages = prompt.to_messages()
 
+        prompt_messages = lc_messages_to_prompt_messages(messages)
+
         # summarize messages if rest_tokens < 0
         try:
-            messages = self.summarize_messages_if_needed(messages, functions=self.functions)
+            prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
         except ExceededLLMTokensLimitError as e:
             return AgentFinish(return_values={"output": str(e)}, log=str(e))
 
-        prompt_messages = to_prompt_messages(messages)
-        result = self.model_instance.run(
-            messages=prompt_messages,
-            functions=self.functions,
+        model_instance = ModelInstance(
+            provider_model_bundle=self.model_config.provider_model_bundle,
+            model=self.model_config.model,
+        )
+
+        tools = []
+        for function in self.functions:
+            tool = PromptMessageTool(
+                **function
+            )
+
+            tools.append(tool)
+
+        result = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            tools=tools,
+            stream=False,
+            callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
+            model_parameters={
+                'temperature': 0.2,
+                'top_p': 0.3,
+                'max_tokens': 1500
+            }
         )
 
         ai_message = AIMessage(
-            content=result.content,
+            content=result.message.content or "",
             additional_kwargs={
-                'function_call': result.function_call
+                'function_call': {
+                    'id': result.message.tool_calls[0].id,
+                    **result.message.tool_calls[0].function.dict()
+                } if result.message.tool_calls else None
             }
         )
         agent_decision = _parse_ai_message(ai_message)
@@ -158,9 +210,14 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
         except ValueError:
             return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
 
-    def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
+    def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]:
         # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
-        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
+        rest_tokens = self.get_message_rest_tokens(
+            self.model_config,
+            messages,
+            **kwargs
+        )
+
         rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
         if rest_tokens >= 0:
             return messages
@@ -210,19 +267,19 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
             ai_prefix="AI",
         )
 
-        chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
+        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
        return chain.predict(summary=existing_summary, new_lines=new_lines)
 
-    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
+    def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage], **kwargs) -> int:
         """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
 
         Official documentation: https://github.com/openai/openai-cookbook/blob/
         main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
-        if model_instance.model_provider.provider_name == 'azure_openai':
-            model = model_instance.base_model_name
+        if model_config.provider == 'azure_openai':
+            model = model_config.model
             model = model.replace("gpt-35", "gpt-3.5")
         else:
-            model = model_instance.base_model_name
+            model = model_config.credentials.get("base_model_name")
 
         tiktoken_ = _import_tiktoken()
         try:

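The hunks above replace the legacy `model_instance.run(messages=..., functions=...)` call with the Model Runtime path: build a ModelInstance from the provider bundle, wrap each function spec in a PromptMessageTool, call invoke_llm with stream=False, and read the decision from result.message.tool_calls. A minimal sketch of that pattern, assuming the entity classes behave as this diff shows; the UserPromptMessage import, the echo tool spec, and the parameter values are illustrative, not from the commit:

from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessageTool, UserPromptMessage


def wants_tool_call(model_config, query: str) -> bool:
    # Build a runtime-level instance from the shared provider bundle,
    # exactly as should_use_agent() does above.
    model_instance = ModelInstance(
        provider_model_bundle=model_config.provider_model_bundle,
        model=model_config.model,
    )

    # Function specs are plain dicts; PromptMessageTool(**spec) validates them.
    echo_tool = PromptMessageTool(
        name="echo",
        description="Echo the user's text back verbatim.",
        parameters={
            "type": "object",
            "properties": {"text": {"type": "string"}},
            "required": ["text"],
        },
    )

    # stream=False returns a single LLMResult instead of a chunk generator.
    result = model_instance.invoke_llm(
        prompt_messages=[UserPromptMessage(content=query)],
        tools=[echo_tool],
        stream=False,
        model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 40},
    )

    # The agent only needs to know whether the model chose a tool.
    return bool(result.message.tool_calls)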
+ 0 - 158
api/core/agent/agent/output_parser/retirver_dataset_agent.py

@@ -1,158 +0,0 @@
-import json
-from typing import Tuple, List, Any, Union, Sequence, Optional, cast
-
-from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
-from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.callbacks.manager import Callbacks
-from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
-from langchain.schema.language_model import BaseLanguageModel
-from langchain.tools import BaseTool
-from pydantic import root_validator
-
-from core.model_providers.models.entity.message import to_prompt_messages
-from core.model_providers.models.llm.base import BaseLLM
-from core.third_party.langchain.llms.fake import FakeLLM
-from core.tool.dataset_retriever_tool import DatasetRetrieverTool
-
-
-class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
-    """
-    An Multi Dataset Retrieve Agent driven by Router.
-    """
-    model_instance: BaseLLM
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        arbitrary_types_allowed = True
-
-    @root_validator
-    def validate_llm(cls, values: dict) -> dict:
-        return values
-
-    def should_use_agent(self, query: str):
-        """
-        return should use agent
-
-        :param query:
-        :return:
-        """
-        return True
-
-    def plan(
-        self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
-        callbacks: Callbacks = None,
-        **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
-
-        Args:
-            intermediate_steps: Steps the LLM has taken to date, along with observations
-            **kwargs: User inputs.
-
-        Returns:
-            Action specifying what tool to use.
-        """
-        if len(self.tools) == 0:
-            return AgentFinish(return_values={"output": ''}, log='')
-        elif len(self.tools) == 1:
-            tool = next(iter(self.tools))
-            tool = cast(DatasetRetrieverTool, tool)
-            rst = tool.run(tool_input={'query': kwargs['input']})
-            # output = ''
-            # rst_json = json.loads(rst)
-            # for item in rst_json:
-            #     output += f'{item["content"]}\n'
-            return AgentFinish(return_values={"output": rst}, log=rst)
-
-        if intermediate_steps:
-            _, observation = intermediate_steps[-1]
-            return AgentFinish(return_values={"output": observation}, log=observation)
-
-        try:
-            agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
-            if isinstance(agent_decision, AgentAction):
-                tool_inputs = agent_decision.tool_input
-                if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
-                    tool_inputs['query'] = kwargs['input']
-                    agent_decision.tool_input = tool_inputs
-            else:
-                agent_decision.return_values['output'] = ''
-            return agent_decision
-        except Exception as e:
-            new_exception = self.model_instance.handle_exceptions(e)
-            raise new_exception
-
-    def real_plan(
-        self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
-        callbacks: Callbacks = None,
-        **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
-
-        Args:
-            intermediate_steps: Steps the LLM has taken to date, along with observations
-            **kwargs: User inputs.
-
-        Returns:
-            Action specifying what tool to use.
-        """
-        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
-        selected_inputs = {
-            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
-        }
-        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
-        prompt = self.prompt.format_prompt(**full_inputs)
-        messages = prompt.to_messages()
-        prompt_messages = to_prompt_messages(messages)
-        result = self.model_instance.run(
-            messages=prompt_messages,
-            functions=self.functions,
-        )
-
-        ai_message = AIMessage(
-            content=result.content,
-            additional_kwargs={
-                'function_call': result.function_call
-            }
-        )
-
-        agent_decision = _parse_ai_message(ai_message)
-        return agent_decision
-
-    async def aplan(
-            self,
-            intermediate_steps: List[Tuple[AgentAction, str]],
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        raise NotImplementedError()
-
-    @classmethod
-    def from_llm_and_tools(
-            cls,
-            model_instance: BaseLLM,
-            tools: Sequence[BaseTool],
-            callback_manager: Optional[BaseCallbackManager] = None,
-            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
-            system_message: Optional[SystemMessage] = SystemMessage(
-                content="You are a helpful AI assistant."
-            ),
-            **kwargs: Any,
-    ) -> BaseSingleActionAgent:
-        prompt = cls.create_prompt(
-            extra_prompt_messages=extra_prompt_messages,
-            system_message=system_message,
-        )
-        return cls(
-            model_instance=model_instance,
-            llm=FakeLLM(response=''),
-            prompt=prompt,
-            tools=tools,
-            callback_manager=callback_manager,
-            **kwargs,
-        )

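The deleted module above duplicated router logic that lives on in api/core/agent/agent/multi_dataset_router_agent.py (see the file list). Its key trick is the fast path in plan(): skip the LLM entirely when routing is unambiguous. A condensed sketch of that decision, with the multi-dataset branch elided; the helper name route_datasets is illustrative:

from typing import Union

from langchain.schema import AgentAction, AgentFinish
from langchain.tools import BaseTool


def route_datasets(tools: list[BaseTool], query: str) -> Union[AgentAction, AgentFinish]:
    # No dataset tools attached: finish immediately with an empty answer.
    if not tools:
        return AgentFinish(return_values={"output": ""}, log="")

    # Exactly one dataset: query it directly, no LLM round-trip needed.
    if len(tools) == 1:
        observation = tools[0].run(tool_input={"query": query})
        return AgentFinish(return_values={"output": observation}, log=observation)

    # Two or more datasets: only now is the function-calling model asked
    # to pick one (real_plan() in the original code).
    raise NotImplementedError("LLM-based routing elided in this sketch")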
+ 18 - 14
api/core/agent/agent/structed_multi_dataset_router_agent.py

@@ -12,9 +12,7 @@ from langchain.tools import BaseTool
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
 
 from core.chain.llm_chain import LLMChain
-from core.model_providers.models.entity.model_params import ModelMode
-from core.model_providers.models.llm.base import BaseLLM
-from core.tool.dataset_retriever_tool import DatasetRetrieverTool
+from core.entities.application_entities import ModelConfigEntity
 
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
 The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
@@ -69,10 +67,10 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
         return True
 
     def plan(
-        self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
-        callbacks: Callbacks = None,
-        **kwargs: Any,
+            self,
+            intermediate_steps: List[Tuple[AgentAction, str]],
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
         """Given input, decided what to do.
 
@@ -101,8 +99,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
         try:
             full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
         except Exception as e:
-            new_exception = self.llm_chain.model_instance.handle_exceptions(e)
-            raise new_exception
+            raise e
 
         try:
             agent_decision = self.output_parser.parse(full_output)
@@ -119,6 +116,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
         except OutputParserException:
             return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
                                           "I don't know how to respond to that."}, "")
+
     @classmethod
     def create_prompt(
             cls,
@@ -182,7 +180,7 @@ Thought: {agent_scratchpad}
         return PromptTemplate(template=template, input_variables=input_variables)
 
     def _construct_scratchpad(
-        self, intermediate_steps: List[Tuple[AgentAction, str]]
+            self, intermediate_steps: List[Tuple[AgentAction, str]]
     ) -> str:
         agent_scratchpad = ""
         for action, observation in intermediate_steps:
@@ -193,7 +191,7 @@ Thought: {agent_scratchpad}
             raise ValueError("agent_scratchpad should be of type string.")
         if agent_scratchpad:
             llm_chain = cast(LLMChain, self.llm_chain)
-            if llm_chain.model_instance.model_mode == ModelMode.CHAT:
+            if llm_chain.model_config.mode == "chat":
                 return (
                     f"This was your previous work "
                     f"(but I haven't seen any of it! I only see what "
@@ -207,7 +205,7 @@ Thought: {agent_scratchpad}
     @classmethod
     def from_llm_and_tools(
             cls,
-            model_instance: BaseLLM,
+            model_config: ModelConfigEntity,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             output_parser: Optional[AgentOutputParser] = None,
@@ -221,7 +219,7 @@ Thought: {agent_scratchpad}
     ) -> Agent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
-        if model_instance.model_mode == ModelMode.CHAT:
+        if model_config.mode == "chat":
             prompt = cls.create_prompt(
                 tools,
                 prefix=prefix,
@@ -238,10 +236,16 @@ Thought: {agent_scratchpad}
                 format_instructions=format_instructions,
                 input_variables=input_variables
             )
+
         llm_chain = LLMChain(
-            model_instance=model_instance,
+            model_config=model_config,
             prompt=prompt,
             callback_manager=callback_manager,
+            parameters={
+                'temperature': 0.2,
+                'top_p': 0.3,
+                'max_tokens': 1500
+            }
         )
         tool_names = [tool.name for tool in tools]
         _output_parser = output_parser
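Note the new parameters kwarg on LLMChain in the hunk above: instead of inheriting whatever sampling settings the app configured, the router chain pins low temperature and top_p so its JSON action blobs stay parseable. A minimal usage sketch, assuming a ModelConfigEntity named model_config is already in scope and that core.chain.llm_chain.LLMChain accepts the keywords exactly as this diff passes them:

from langchain.prompts import PromptTemplate

from core.chain.llm_chain import LLMChain

router_chain = LLMChain(
    model_config=model_config,  # ModelConfigEntity, assumed in scope
    prompt=PromptTemplate.from_template("{input}"),
    parameters={
        "temperature": 0.2,  # keep the tool-choice output near-deterministic
        "top_p": 0.3,
        "max_tokens": 1500,
    },
)
answer = router_chain.predict(input="Which dataset should answer this query?")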

+ 22 - 13
api/core/agent/agent/structured_chat.py

@@ -13,10 +13,11 @@ from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage,
 from langchain.tools import BaseTool
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
 
+from core.agent.agent.agent_llm_callback import AgentLLMCallback
 from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
 from core.chain.llm_chain import LLMChain
-from core.model_providers.models.entity.model_params import ModelMode
-from core.model_providers.models.llm.base import BaseLLM
+from core.entities.application_entities import ModelConfigEntity
+from core.entities.message_entities import lc_messages_to_prompt_messages
 
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
 The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
@@ -54,7 +55,7 @@ Action:
 class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
-    summary_model_instance: BaseLLM = None
+    summary_model_config: ModelConfigEntity = None
 
     class Config:
         """Configuration for this pydantic object."""
@@ -82,7 +83,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
 
         Args:
             intermediate_steps: Steps the LLM has taken to date,
-                along with observations
+                along with observatons
             callbacks: Callbacks to run.
             **kwargs: User inputs.
 
@@ -96,15 +97,16 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
         if prompts:
             messages = prompts[0].to_messages()
 
-        rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages)
+        prompt_messages = lc_messages_to_prompt_messages(messages)
+
+        rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
         if rest_tokens < 0:
             full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
 
         try:
             full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
         except Exception as e:
-            new_exception = self.llm_chain.model_instance.handle_exceptions(e)
-            raise new_exception
+            raise e
 
         try:
             agent_decision = self.output_parser.parse(full_output)
@@ -119,7 +121,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
                                           "I don't know how to respond to that."}, "")
                                           "I don't know how to respond to that."}, "")
 
     def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
-        if len(intermediate_steps) >= 2 and self.summary_model_instance:
+        if len(intermediate_steps) >= 2 and self.summary_model_config:
             should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
             should_summary_messages = [AIMessage(content=observation)
                                        for _, observation in should_summary_intermediate_steps]
@@ -153,7 +155,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
             ai_prefix="AI",
         )
 
-        chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
+        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
         return chain.predict(summary=existing_summary, new_lines=new_lines)
 
     @classmethod
@@ -229,7 +231,7 @@ Thought: {agent_scratchpad}
             raise ValueError("agent_scratchpad should be of type string.")
         if agent_scratchpad:
             llm_chain = cast(LLMChain, self.llm_chain)
-            if llm_chain.model_instance.model_mode == ModelMode.CHAT:
+            if llm_chain.model_config.mode == "chat":
                 return (
                     f"This was your previous work "
                     f"(but I haven't seen any of it! I only see what "
@@ -243,7 +245,7 @@ Thought: {agent_scratchpad}
     @classmethod
     def from_llm_and_tools(
             cls,
-            model_instance: BaseLLM,
+            model_config: ModelConfigEntity,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             output_parser: Optional[AgentOutputParser] = None,
@@ -253,11 +255,12 @@ Thought: {agent_scratchpad}
             format_instructions: str = FORMAT_INSTRUCTIONS,
             input_variables: Optional[List[str]] = None,
             memory_prompts: Optional[List[BasePromptTemplate]] = None,
+            agent_llm_callback: Optional[AgentLLMCallback] = None,
             **kwargs: Any,
     ) -> Agent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
-        if model_instance.model_mode == ModelMode.CHAT:
+        if model_config.mode == "chat":
             prompt = cls.create_prompt(
                 tools,
                 prefix=prefix,
@@ -275,9 +278,15 @@ Thought: {agent_scratchpad}
                 input_variables=input_variables,
             )
         llm_chain = LLMChain(
-            model_instance=model_instance,
+            model_config=model_config,
             prompt=prompt,
             callback_manager=callback_manager,
+            agent_llm_callback=agent_llm_callback,
+            parameters={
+                'temperature': 0.2,
+                'top_p': 0.3,
+                'max_tokens': 1500
+            }
         )
         tool_names = [tool.name for tool in tools]
         _output_parser = output_parser
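Both summarizing agents above share one technique: when the remaining context budget goes negative, fold the oldest observations into a rolling summary produced by the summary model (SUMMARY_PROMPT plus summary_model_config) and continue. The class bookkeeping tracks moving_summary_buffer and moving_summary_index over intermediate_steps; stripped of that, the shape of the technique is roughly this sketch, where count_tokens and summarize are stand-ins for the real token counter and LLM chain:

from typing import Callable


def trim_history(messages: list[str],
                 count_tokens: Callable[[list[str]], int],
                 summarize: Callable[[str, str], str],
                 budget: int) -> list[str]:
    """Summarize oldest messages until the remainder fits the token budget.

    summarize(existing_summary, new_lines) stands in for the SUMMARY_PROMPT
    chain; count_tokens stands in for get_message_rest_tokens' counting.
    """
    summary = ""
    while messages and count_tokens([summary] + messages) > budget:
        # Fold the oldest message into the rolling summary.
        summary = summarize(summary, messages.pop(0))
    return ([summary] if summary else []) + messages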

+ 33 - 23
api/core/agent/agent_executor.py

@@ -4,10 +4,10 @@ from typing import Union, Optional
 
 from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
 from langchain.callbacks.manager import Callbacks
-from langchain.memory.chat_memory import BaseChatMemory
 from langchain.tools import BaseTool
 from pydantic import BaseModel, Extra
 
+from core.agent.agent.agent_llm_callback import AgentLLMCallback
 from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
 from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
 from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
@@ -15,9 +15,11 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
 from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
 from langchain.agents import AgentExecutor as LCAgentExecutor
 
+from core.entities.application_entities import ModelConfigEntity
+from core.entities.message_entities import prompt_messages_to_lc_messages
 from core.helper import moderation
-from core.model_providers.error import LLMError
-from core.model_providers.models.llm.base import BaseLLM
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.errors.invoke import InvokeError
 from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 
@@ -31,14 +33,15 @@ class PlanningStrategy(str, enum.Enum):
 
 class AgentConfiguration(BaseModel):
     strategy: PlanningStrategy
-    model_instance: BaseLLM
+    model_config: ModelConfigEntity
     tools: list[BaseTool]
-    summary_model_instance: BaseLLM = None
-    memory: Optional[BaseChatMemory] = None
+    summary_model_config: Optional[ModelConfigEntity] = None
+    memory: Optional[TokenBufferMemory] = None
     callbacks: Callbacks = None
     max_iterations: int = 6
     max_execution_time: Optional[float] = None
     early_stopping_method: str = "generate"
+    agent_llm_callback: Optional[AgentLLMCallback] = None
     # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
 
     class Config:
@@ -62,34 +65,42 @@ class AgentExecutor:
     def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
         if self.configuration.strategy == PlanningStrategy.REACT:
             agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
-                model_instance=self.configuration.model_instance,
+                model_config=self.configuration.model_config,
                 tools=self.configuration.tools,
                 output_parser=StructuredChatOutputParser(),
-                summary_model_instance=self.configuration.summary_model_instance
-                if self.configuration.summary_model_instance else None,
+                summary_model_config=self.configuration.summary_model_config
+                if self.configuration.summary_model_config else None,
+                agent_llm_callback=self.configuration.agent_llm_callback,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
             agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
-                model_instance=self.configuration.model_instance,
+                model_config=self.configuration.model_config,
                 tools=self.configuration.tools,
-                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
-                summary_model_instance=self.configuration.summary_model_instance
-                if self.configuration.summary_model_instance else None,
+                extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
+                if self.configuration.memory else None,  # used for read chat histories memory
+                summary_model_config=self.configuration.summary_model_config
+                if self.configuration.summary_model_config else None,
+                agent_llm_callback=self.configuration.agent_llm_callback,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.ROUTER:
-            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
+            self.configuration.tools = [t for t in self.configuration.tools
+                                        if isinstance(t, DatasetRetrieverTool)
+                                        or isinstance(t, DatasetMultiRetrieverTool)]
             agent = MultiDatasetRouterAgent.from_llm_and_tools(
-                model_instance=self.configuration.model_instance,
+                model_config=self.configuration.model_config,
                 tools=self.configuration.tools,
-                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
+                extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
+                if self.configuration.memory else None,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
-            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
+            self.configuration.tools = [t for t in self.configuration.tools
+                                        if isinstance(t, DatasetRetrieverTool)
+                                        or isinstance(t, DatasetMultiRetrieverTool)]
             agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
-                model_instance=self.configuration.model_instance,
+                model_config=self.configuration.model_config,
                 tools=self.configuration.tools,
                 output_parser=StructuredChatOutputParser(),
                 verbose=True
@@ -104,11 +115,11 @@ class AgentExecutor:
 
     def run(self, query: str) -> AgentExecuteResult:
         moderation_result = moderation.check_moderation(
-            self.configuration.model_instance.model_provider,
+            self.configuration.model_config,
             query
         )
 
-        if not moderation_result:
+        if moderation_result:
             return AgentExecuteResult(
                 output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
                 strategy=self.configuration.strategy,
@@ -118,7 +129,6 @@ class AgentExecutor:
         agent_executor = LCAgentExecutor.from_agent_and_tools(
             agent=self.agent,
             tools=self.configuration.tools,
-            memory=self.configuration.memory,
             max_iterations=self.configuration.max_iterations,
             max_execution_time=self.configuration.max_execution_time,
             early_stopping_method=self.configuration.early_stopping_method,
@@ -126,8 +136,8 @@ class AgentExecutor:
         )
 
         try:
-            output = agent_executor.run(query)
-        except LLMError as ex:
+            output = agent_executor.run(input=query)
+        except InvokeError as ex:
             raise ex
         except Exception as ex:
             logging.exception("agent_executor run failed")
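One behavioral change above is easy to miss: the moderation guard flipped from `if not moderation_result:` to `if moderation_result:`, so check_moderation now evidently returns truthy when the query violates policy rather than when it passes. A sketch of the resulting control flow, assuming that inverted contract; guarded_run is an illustrative wrapper, not part of the commit:

from core.helper import moderation


def guarded_run(agent_executor, model_config, query: str) -> str:
    # Truthy result now means "flagged": short-circuit with the canned reply
    # instead of handing the query to the agent.
    if moderation.check_moderation(model_config, query):
        return ("I apologize for any confusion, but I'm an AI assistant "
                "to be helpful, harmless, and honest.")
    return agent_executor.run(input=query)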

+ 0 - 0
api/core/model_providers/models/__init__.py → api/core/app_runner/__init__.py


+ 251 - 0
api/core/app_runner/agent_app_runner.py

@@ -0,0 +1,251 @@
+import json
+import logging
+from typing import cast
+
+from core.agent.agent.agent_llm_callback import AgentLLMCallback
+from core.app_runner.app_runner import AppRunner
+from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
+from core.entities.application_entities import ApplicationGenerateEntity, PromptTemplateEntity, ModelConfigEntity
+from core.application_queue_manager import ApplicationQueueManager
+from core.features.agent_runner import AgentRunnerFeature
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_manager import ModelInstance
+from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from extensions.ext_database import db
+from models.model import Conversation, Message, App, MessageChain, MessageAgentThought
+
+logger = logging.getLogger(__name__)
+
+
+class AgentApplicationRunner(AppRunner):
+    """
+    Agent Application Runner
+    """
+
+    def run(self, application_generate_entity: ApplicationGenerateEntity,
+            queue_manager: ApplicationQueueManager,
+            conversation: Conversation,
+            message: Message) -> None:
+        """
+        Run agent application
+        :param application_generate_entity: application generate entity
+        :param queue_manager: application queue manager
+        :param conversation: conversation
+        :param message: message
+        :return:
+        """
+        app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
+        if not app_record:
+            raise ValueError(f"App not found")
+
+        app_orchestration_config = application_generate_entity.app_orchestration_config_entity
+
+        inputs = application_generate_entity.inputs
+        query = application_generate_entity.query
+        files = application_generate_entity.files
+
+        # Pre-calculate the number of tokens of the prompt messages,
+        # and return the rest number of tokens by model context token size limit and max token size limit.
+        # If the rest number of tokens is not enough, raise exception.
+        # Include: prompt template, inputs, query(optional), files(optional)
+        # Not Include: memory, external data, dataset context
+        self.get_pre_calculate_rest_tokens(
+            app_record=app_record,
+            model_config=app_orchestration_config.model_config,
+            prompt_template_entity=app_orchestration_config.prompt_template,
+            inputs=inputs,
+            files=files,
+            query=query
+        )
+
+        memory = None
+        if application_generate_entity.conversation_id:
+            # get memory of conversation (read-only)
+            model_instance = ModelInstance(
+                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
+                model=app_orchestration_config.model_config.model
+            )
+
+            memory = TokenBufferMemory(
+                conversation=conversation,
+                model_instance=model_instance
+            )
+
+        # reorganize all inputs and template to prompt messages
+        # Include: prompt template, inputs, query(optional), files(optional)
+        #          memory(optional)
+        prompt_messages, stop = self.originze_prompt_messages(
+            app_record=app_record,
+            model_config=app_orchestration_config.model_config,
+            prompt_template_entity=app_orchestration_config.prompt_template,
+            inputs=inputs,
+            files=files,
+            query=query,
+            context=None,
+            memory=memory
+        )
+
+        # Create MessageChain
+        message_chain = self._init_message_chain(
+            message=message,
+            query=query
+        )
+
+        # add agent callback to record agent thoughts
+        agent_callback = AgentLoopGatherCallbackHandler(
+            model_config=app_orchestration_config.model_config,
+            message=message,
+            queue_manager=queue_manager,
+            message_chain=message_chain
+        )
+
+        # init LLM Callback
+        agent_llm_callback = AgentLLMCallback(
+            agent_callback=agent_callback
+        )
+
+        agent_runner = AgentRunnerFeature(
+            tenant_id=application_generate_entity.tenant_id,
+            app_orchestration_config=app_orchestration_config,
+            model_config=app_orchestration_config.model_config,
+            config=app_orchestration_config.agent,
+            queue_manager=queue_manager,
+            message=message,
+            user_id=application_generate_entity.user_id,
+            agent_llm_callback=agent_llm_callback,
+            callback=agent_callback,
+            memory=memory
+        )
+
+        # agent run
+        result = agent_runner.run(
+            query=query,
+            invoke_from=application_generate_entity.invoke_from
+        )
+
+        if result:
+            self._save_message_chain(
+                message_chain=message_chain,
+                output_text=result
+            )
+
+        if (result
+                and app_orchestration_config.prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE
+                and app_orchestration_config.prompt_template.simple_prompt_template
+        ):
+            # Direct output if agent result exists and has pre prompt
+            self.direct_output(
+                queue_manager=queue_manager,
+                app_orchestration_config=app_orchestration_config,
+                prompt_messages=prompt_messages,
+                stream=application_generate_entity.stream,
+                text=result,
+                usage=self._get_usage_of_all_agent_thoughts(
+                    model_config=app_orchestration_config.model_config,
+                    message=message
+                )
+            )
+        else:
+            # As normal LLM run, agent result as context
+            context = result
+
+            # reorganize all inputs and template to prompt messages
+            # Include: prompt template, inputs, query(optional), files(optional)
+            #          memory(optional), external data, dataset context(optional)
+            prompt_messages, stop = self.originze_prompt_messages(
+                app_record=app_record,
+                model_config=app_orchestration_config.model_config,
+                prompt_template_entity=app_orchestration_config.prompt_template,
+                inputs=inputs,
+                files=files,
+                query=query,
+                context=context,
+                memory=memory
+            )
+
+            # Re-calculate the max tokens if sum(prompt_token +  max_tokens) over model token limit
+            self.recale_llm_max_tokens(
+                model_config=app_orchestration_config.model_config,
+                prompt_messages=prompt_messages
+            )
+
+            # Invoke model
+            model_instance = ModelInstance(
+                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
+                model=app_orchestration_config.model_config.model
+            )
+
+            invoke_result = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters=app_orchestration_config.model_config.parameters,
+                stop=stop,
+                stream=application_generate_entity.stream,
+                user=application_generate_entity.user_id,
+            )
+
+            # handle invoke result
+            self._handle_invoke_result(
+                invoke_result=invoke_result,
+                queue_manager=queue_manager,
+                stream=application_generate_entity.stream
+            )
+
+    def _init_message_chain(self, message: Message, query: str) -> MessageChain:
+        """
+        Init MessageChain
+        :param message: message
+        :param query: query
+        :return:
+        """
+        message_chain = MessageChain(
+            message_id=message.id,
+            type="AgentExecutor",
+            input=json.dumps({
+                "input": query
+            })
+        )
+
+        db.session.add(message_chain)
+        db.session.commit()
+
+        return message_chain
+
+    def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
+        """
+        Save MessageChain
+        :param message_chain: message chain
+        :param output_text: output text
+        :return:
+        """
+        message_chain.output = json.dumps({
+            "output": output_text
+        })
+        db.session.commit()
+
+    def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
+                                         message: Message) -> LLMUsage:
+        """
+        Get usage of all agent thoughts
+        :param model_config: model config
+        :param message: message
+        :return:
+        """
+        agent_thoughts = (db.session.query(MessageAgentThought)
+                          .filter(MessageAgentThought.message_id == message.id).all())
+
+        all_message_tokens = 0
+        all_answer_tokens = 0
+        for agent_thought in agent_thoughts:
+            all_message_tokens += agent_thought.message_tokens
+            all_answer_tokens += agent_thought.answer_tokens
+
+        model_type_instance = model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        return model_type_instance._calc_response_usage(
+            model_config.model,
+            model_config.credentials,
+            all_message_tokens,
+            all_answer_tokens
+        )

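The AppRunner hunks below budget the model's context window in two passes: get_pre_calculate_rest_tokens rejects a request whose prompt plus reserved completion cannot fit, while recale_llm_max_tokens later shrinks the completion budget instead, never below 16 tokens. Distilled to the arithmetic, with illustrative numbers rather than anything from the commit:

def rest_tokens(context_size: int, max_tokens: int, prompt_tokens: int) -> int:
    # Tokens left for memory/context after reserving the completion budget;
    # a negative result triggers InvokeBadRequestError in the diff below.
    return context_size - max_tokens - prompt_tokens


def clamp_max_tokens(context_size: int, max_tokens: int, prompt_tokens: int) -> int:
    # If prompt + completion would overflow the window, shrink the completion
    # budget instead, but keep at least 16 tokens of headroom.
    if prompt_tokens + max_tokens > context_size:
        return max(context_size - prompt_tokens, 16)
    return max_tokens


# Example: a 4096-token window, 1500 tokens reserved, 3000-token prompt.
assert rest_tokens(4096, 1500, 3000) == -404       # pre-check fails
assert clamp_max_tokens(4096, 1500, 3000) == 1096  # re-calc shrinks instead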
+ 267 - 0
api/core/app_runner/app_runner.py

@@ -0,0 +1,267 @@
+import time
+from typing import cast, Optional, List, Tuple, Generator, Union
+
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, AppOrchestrationConfigEntity
+from core.file.file_obj import FileObj
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from core.model_runtime.entities.message_entities import PromptMessage, AssistantPromptMessage
+from core.model_runtime.entities.model_entities import ModelPropertyKey
+from core.model_runtime.errors.invoke import InvokeBadRequestError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.prompt_transform import PromptTransform
+from models.model import App
+
+
+class AppRunner:
+    def get_pre_calculate_rest_tokens(self, app_record: App,
+                                      model_config: ModelConfigEntity,
+                                      prompt_template_entity: PromptTemplateEntity,
+                                      inputs: dict[str, str],
+                                      files: list[FileObj],
+                                      query: Optional[str] = None) -> int:
+        """
+        Get pre calculate rest tokens
+        :param app_record: app record
+        :param model_config: model config entity
+        :param prompt_template_entity: prompt template entity
+        :param inputs: inputs
+        :param files: files
+        :param query: query
+        :return:
+        """
+        model_type_instance = model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+
+        max_tokens = 0
+        for parameter_rule in model_config.model_schema.parameter_rules:
+            if (parameter_rule.name == 'max_tokens'
+                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
+                max_tokens = (model_config.parameters.get(parameter_rule.name)
+                              or model_config.parameters.get(parameter_rule.use_template)) or 0
+
+        if model_context_tokens is None:
+            return -1
+
+        if max_tokens is None:
+            max_tokens = 0
+
+        # get prompt messages without memory and context
+        prompt_messages, stop = self.originze_prompt_messages(
+            app_record=app_record,
+            model_config=model_config,
+            prompt_template_entity=prompt_template_entity,
+            inputs=inputs,
+            files=files,
+            query=query
+        )
+
+        prompt_tokens = model_type_instance.get_num_tokens(
+            model_config.model,
+            model_config.credentials,
+            prompt_messages
+        )
+
+        rest_tokens = model_context_tokens - max_tokens - prompt_tokens
+        if rest_tokens < 0:
+            raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
+                                        "or shrink the max token, or switch to a llm with a larger token limit size.")
+
+        return rest_tokens
+
+    def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
+                              prompt_messages: List[PromptMessage]):
+        # recalc max_tokens if sum(prompt_token +  max_tokens) over model token limit
+        model_type_instance = model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+
+        max_tokens = 0
+        for parameter_rule in model_config.model_schema.parameter_rules:
+            if (parameter_rule.name == 'max_tokens'
+                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
+                max_tokens = (model_config.parameters.get(parameter_rule.name)
+                              or model_config.parameters.get(parameter_rule.use_template)) or 0
+
+        if model_context_tokens is None:
+            return -1
+
+        if max_tokens is None:
+            max_tokens = 0
+
+        prompt_tokens = model_type_instance.get_num_tokens(
+            model_config.model,
+            model_config.credentials,
+            prompt_messages
+        )
+
+        if prompt_tokens + max_tokens > model_context_tokens:
+            max_tokens = max(model_context_tokens - prompt_tokens, 16)
+
+            for parameter_rule in model_config.model_schema.parameter_rules:
+                if (parameter_rule.name == 'max_tokens'
+                        or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
+                    model_config.parameters[parameter_rule.name] = max_tokens
+
+    def originze_prompt_messages(self, app_record: App,
+                                 model_config: ModelConfigEntity,
+                                 prompt_template_entity: PromptTemplateEntity,
+                                 inputs: dict[str, str],
+                                 files: list[FileObj],
+                                 query: Optional[str] = None,
+                                 context: Optional[str] = None,
+                                 memory: Optional[TokenBufferMemory] = None) \
+            -> Tuple[List[PromptMessage], Optional[List[str]]]:
+        """
+        Organize prompt messages
+        :param app_record: app record
+        :param model_config: model config entity
+        :param prompt_template_entity: prompt template entity
+        :param inputs: inputs
+        :param files: files
+        :param query: query
+        :param context: context
+        :param memory: memory
+        :return:
+        """
+        prompt_transform = PromptTransform()
+
+        # get prompt without memory and context
+        if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
+            prompt_messages, stop = prompt_transform.get_prompt(
+                app_mode=app_record.mode,
+                prompt_template_entity=prompt_template_entity,
+                inputs=inputs,
+                query=query if query else '',
+                files=files,
+                context=context,
+                memory=memory,
+                model_config=model_config
+            )
+        else:
+            prompt_messages = prompt_transform.get_advanced_prompt(
+                app_mode=app_record.mode,
+                prompt_template_entity=prompt_template_entity,
+                inputs=inputs,
+                query=query,
+                files=files,
+                context=context,
+                memory=memory,
+                model_config=model_config
+            )
+            stop = model_config.stop
+
+        return prompt_messages, stop
+
+    def direct_output(self, queue_manager: ApplicationQueueManager,
+                      app_orchestration_config: AppOrchestrationConfigEntity,
+                      prompt_messages: list,
+                      text: str,
+                      stream: bool,
+                      usage: Optional[LLMUsage] = None) -> None:
+        """
+        Direct output
+        :param queue_manager: application queue manager
+        :param app_orchestration_config: app orchestration config
+        :param prompt_messages: prompt messages
+        :param text: text
+        :param stream: stream
+        :param usage: usage
+        :return:
+        """
+        if stream:
+            index = 0
+            for token in text:  # stream the text back character by character
+                queue_manager.publish_chunk_message(LLMResultChunk(
+                    model=app_orchestration_config.model_config.model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=AssistantPromptMessage(content=token)
+                    )
+                ))
+                index += 1
+                time.sleep(0.01)
+
+        queue_manager.publish_message_end(
+            llm_result=LLMResult(
+                model=app_orchestration_config.model_config.model,
+                prompt_messages=prompt_messages,
+                message=AssistantPromptMessage(content=text),
+                usage=usage if usage else LLMUsage.empty_usage()
+            )
+        )
+
+    def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
+                              queue_manager: ApplicationQueueManager,
+                              stream: bool) -> None:
+        """
+        Handle invoke result
+        :param invoke_result: invoke result
+        :param queue_manager: application queue manager
+        :param stream: stream
+        :return:
+        """
+        if not stream:
+            self._handle_invoke_result_direct(
+                invoke_result=invoke_result,
+                queue_manager=queue_manager
+            )
+        else:
+            self._handle_invoke_result_stream(
+                invoke_result=invoke_result,
+                queue_manager=queue_manager
+            )
+
+    def _handle_invoke_result_direct(self, invoke_result: LLMResult,
+                                     queue_manager: ApplicationQueueManager) -> None:
+        """
Handle invoke result in blocking mode
+        :param invoke_result: invoke result
+        :param queue_manager: application queue manager
+        :return:
+        """
+        queue_manager.publish_message_end(
+            llm_result=invoke_result
+        )
+
+    def _handle_invoke_result_stream(self, invoke_result: Generator,
+                                     queue_manager: ApplicationQueueManager) -> None:
+        """
Handle invoke result in stream mode
+        :param invoke_result: invoke result
+        :param queue_manager: application queue manager
+        :return:
+        """
+        model = None
+        prompt_messages = []
+        text = ''
+        usage = None
+        for result in invoke_result:
+            queue_manager.publish_chunk_message(result)
+
+            text += result.delta.message.content
+
+            if not model:
+                model = result.model
+
+            if not prompt_messages:
+                prompt_messages = result.prompt_messages
+
+            if not usage and result.delta.usage:
+                usage = result.delta.usage
+
+        llm_result = LLMResult(
+            model=model,
+            prompt_messages=prompt_messages,
+            message=AssistantPromptMessage(content=text),
+            usage=usage if usage else LLMUsage.empty_usage()
+        )
+
+        queue_manager.publish_message_end(
+            llm_result=llm_result
+        )

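Note: the two token-budget helpers above (get_pre_calculate_rest_tokens and recale_llm_max_tokens) share one piece of arithmetic: the model's context size must cover prompt_tokens plus the max_tokens completion budget. A minimal, self-contained sketch of that arithmetic follows; the function names and sample numbers are illustrative only, not part of the runtime API.

# Sketch of the token budgeting used above; all names and values are illustrative.
def rest_tokens(context_size: int, max_tokens: int, prompt_tokens: int) -> int:
    """Tokens left for memory/context after reserving the completion budget."""
    rest = context_size - max_tokens - prompt_tokens
    if rest < 0:
        raise ValueError("prompt too long: reduce the prompt, shrink max_tokens, "
                         "or switch to a model with a larger context size")
    return rest

def clamp_max_tokens(context_size: int, max_tokens: int, prompt_tokens: int) -> int:
    """Shrink max_tokens (down to a floor of 16) if the request would overflow."""
    if prompt_tokens + max_tokens > context_size:
        return max(context_size - prompt_tokens, 16)
    return max_tokens

assert rest_tokens(4096, 512, 1000) == 2584
assert clamp_max_tokens(4096, 4000, 1000) == 3096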
+ 363 - 0
api/core/app_runner/basic_app_runner.py

@@ -0,0 +1,363 @@
+import logging
+from typing import Tuple, Optional
+
+from core.app_runner.app_runner import AppRunner
+from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \
+    AppOrchestrationConfigEntity, InvokeFrom, ExternalDataVariableEntity, DatasetEntity
+from core.application_queue_manager import ApplicationQueueManager
+from core.features.annotation_reply import AnnotationReplyFeature
+from core.features.dataset_retrieval import DatasetRetrievalFeature
+from core.features.external_data_fetch import ExternalDataFetchFeature
+from core.features.hosting_moderation import HostingModerationFeature
+from core.features.moderation import ModerationFeature
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_manager import ModelInstance
+from core.model_runtime.entities.message_entities import PromptMessage
+from core.moderation.base import ModerationException
+from core.prompt.prompt_transform import AppMode
+from extensions.ext_database import db
+from models.model import Conversation, Message, App, MessageAnnotation
+
+logger = logging.getLogger(__name__)
+
+
+class BasicApplicationRunner(AppRunner):
+    """
+    Basic Application Runner
+    """
+
+    def run(self, application_generate_entity: ApplicationGenerateEntity,
+            queue_manager: ApplicationQueueManager,
+            conversation: Conversation,
+            message: Message) -> None:
+        """
+        Run application
+        :param application_generate_entity: application generate entity
+        :param queue_manager: application queue manager
+        :param conversation: conversation
+        :param message: message
+        :return:
+        """
+        app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
+        if not app_record:
+            raise ValueError(f"App not found")
+
+        app_orchestration_config = application_generate_entity.app_orchestration_config_entity
+
+        inputs = application_generate_entity.inputs
+        query = application_generate_entity.query
+        files = application_generate_entity.files
+
+        # Pre-calculate the number of tokens of the prompt messages, and return
+        # the number of tokens remaining within the model context size limit after
+        # reserving max_tokens; raise an exception if not enough tokens remain.
+        # Includes: prompt template, inputs, query(optional), files(optional)
+        # Excludes: memory, external data, dataset context
+        self.get_pre_calculate_rest_tokens(
+            app_record=app_record,
+            model_config=app_orchestration_config.model_config,
+            prompt_template_entity=app_orchestration_config.prompt_template,
+            inputs=inputs,
+            files=files,
+            query=query
+        )
+
+        memory = None
+        if application_generate_entity.conversation_id:
+            # get memory of conversation (read-only)
+            model_instance = ModelInstance(
+                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
+                model=app_orchestration_config.model_config.model
+            )
+
+            memory = TokenBufferMemory(
+                conversation=conversation,
+                model_instance=model_instance
+            )
+
+        # organize all inputs and template to prompt messages
+        # Include: prompt template, inputs, query(optional), files(optional)
+        #          memory(optional)
+        prompt_messages, stop = self.originze_prompt_messages(
+            app_record=app_record,
+            model_config=app_orchestration_config.model_config,
+            prompt_template_entity=app_orchestration_config.prompt_template,
+            inputs=inputs,
+            files=files,
+            query=query,
+            memory=memory
+        )
+
+        # moderation
+        try:
+            # process sensitive_word_avoidance
+            _, inputs, query = self.moderation_for_inputs(
+                app_id=app_record.id,
+                tenant_id=application_generate_entity.tenant_id,
+                app_orchestration_config_entity=app_orchestration_config,
+                inputs=inputs,
+                query=query,
+            )
+        except ModerationException as e:
+            self.direct_output(
+                queue_manager=queue_manager,
+                app_orchestration_config=app_orchestration_config,
+                prompt_messages=prompt_messages,
+                text=str(e),
+                stream=application_generate_entity.stream
+            )
+            return
+
+        if query:
+            # annotation reply
+            annotation_reply = self.query_app_annotations_to_reply(
+                app_record=app_record,
+                message=message,
+                query=query,
+                user_id=application_generate_entity.user_id,
+                invoke_from=application_generate_entity.invoke_from
+            )
+
+            if annotation_reply:
+                queue_manager.publish_annotation_reply(
+                    message_annotation_id=annotation_reply.id
+                )
+                self.direct_output(
+                    queue_manager=queue_manager,
+                    app_orchestration_config=app_orchestration_config,
+                    prompt_messages=prompt_messages,
+                    text=annotation_reply.content,
+                    stream=application_generate_entity.stream
+                )
+                return
+
+            # fill in variable inputs from external data tools if they exist
+            external_data_tools = app_orchestration_config.external_data_variables
+            if external_data_tools:
+                inputs = self.fill_in_inputs_from_external_data_tools(
+                    tenant_id=app_record.tenant_id,
+                    app_id=app_record.id,
+                    external_data_tools=external_data_tools,
+                    inputs=inputs,
+                    query=query
+                )
+
+        # get context from datasets
+        context = None
+        if app_orchestration_config.dataset:
+            context = self.retrieve_dataset_context(
+                tenant_id=app_record.tenant_id,
+                app_record=app_record,
+                queue_manager=queue_manager,
+                model_config=app_orchestration_config.model_config,
+                show_retrieve_source=app_orchestration_config.show_retrieve_source,
+                dataset_config=app_orchestration_config.dataset,
+                message=message,
+                inputs=inputs,
+                query=query,
+                user_id=application_generate_entity.user_id,
+                invoke_from=application_generate_entity.invoke_from,
+                memory=memory
+            )
+
+        # reorganize all inputs and template to prompt messages
+        # Include: prompt template, inputs, query(optional), files(optional)
+        #          memory(optional), external data, dataset context(optional)
+        prompt_messages, stop = self.originze_prompt_messages(
+            app_record=app_record,
+            model_config=app_orchestration_config.model_config,
+            prompt_template_entity=app_orchestration_config.prompt_template,
+            inputs=inputs,
+            files=files,
+            query=query,
+            context=context,
+            memory=memory
+        )
+
+        # check hosting moderation
+        hosting_moderation_result = self.check_hosting_moderation(
+            application_generate_entity=application_generate_entity,
+            queue_manager=queue_manager,
+            prompt_messages=prompt_messages
+        )
+
+        if hosting_moderation_result:
+            return
+
+        # Recalculate max_tokens if prompt_tokens + max_tokens exceeds the model token limit
+        self.recale_llm_max_tokens(
+            model_config=app_orchestration_config.model_config,
+            prompt_messages=prompt_messages
+        )
+
+        # Invoke model
+        model_instance = ModelInstance(
+            provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
+            model=app_orchestration_config.model_config.model
+        )
+
+        invoke_result = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            model_parameters=app_orchestration_config.model_config.parameters,
+            stop=stop,
+            stream=application_generate_entity.stream,
+            user=application_generate_entity.user_id,
+        )
+
+        # handle invoke result
+        self._handle_invoke_result(
+            invoke_result=invoke_result,
+            queue_manager=queue_manager,
+            stream=application_generate_entity.stream
+        )
+
+    def moderation_for_inputs(self, app_id: str,
+                              tenant_id: str,
+                              app_orchestration_config_entity: AppOrchestrationConfigEntity,
+                              inputs: dict,
+                              query: str) -> Tuple[bool, dict, str]:
+        """
+        Process sensitive_word_avoidance.
+        :param app_id: app id
+        :param tenant_id: tenant id
+        :param app_orchestration_config_entity: app orchestration config entity
+        :param inputs: inputs
+        :param query: query
+        :return:
+        """
+        moderation_feature = ModerationFeature()
+        return moderation_feature.check(
+            app_id=app_id,
+            tenant_id=tenant_id,
+            app_orchestration_config_entity=app_orchestration_config_entity,
+            inputs=inputs,
+            query=query,
+        )
+
+    def query_app_annotations_to_reply(self, app_record: App,
+                                       message: Message,
+                                       query: str,
+                                       user_id: str,
+                                       invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
+        """
+        Query app annotations to reply
+        :param app_record: app record
+        :param message: message
+        :param query: query
+        :param user_id: user id
+        :param invoke_from: invoke from
+        :return:
+        """
+        annotation_reply_feature = AnnotationReplyFeature()
+        return annotation_reply_feature.query(
+            app_record=app_record,
+            message=message,
+            query=query,
+            user_id=user_id,
+            invoke_from=invoke_from
+        )
+
+    def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
+                                                app_id: str,
+                                                external_data_tools: list[ExternalDataVariableEntity],
+                                                inputs: dict,
+                                                query: str) -> dict:
+        """
+        Fill in variable inputs from external data tools if they exist.
+
+        :param tenant_id: workspace id
+        :param app_id: app id
+        :param external_data_tools: external data tools configs
+        :param inputs: the inputs
+        :param query: the query
+        :return: the filled inputs
+        """
+        external_data_fetch_feature = ExternalDataFetchFeature()
+        return external_data_fetch_feature.fetch(
+            tenant_id=tenant_id,
+            app_id=app_id,
+            external_data_tools=external_data_tools,
+            inputs=inputs,
+            query=query
+        )
+
+    def retrieve_dataset_context(self, tenant_id: str,
+                                 app_record: App,
+                                 queue_manager: ApplicationQueueManager,
+                                 model_config: ModelConfigEntity,
+                                 dataset_config: DatasetEntity,
+                                 show_retrieve_source: bool,
+                                 message: Message,
+                                 inputs: dict,
+                                 query: str,
+                                 user_id: str,
+                                 invoke_from: InvokeFrom,
+                                 memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
+        """
+        Retrieve dataset context
+        :param tenant_id: tenant id
+        :param app_record: app record
+        :param queue_manager: queue manager
+        :param model_config: model config
+        :param dataset_config: dataset config
+        :param show_retrieve_source: show retrieve source
+        :param message: message
+        :param inputs: inputs
+        :param query: query
+        :param user_id: user id
+        :param invoke_from: invoke from
+        :param memory: memory
+        :return:
+        """
+        hit_callback = DatasetIndexToolCallbackHandler(
+            queue_manager,
+            app_record.id,
+            message.id,
+            user_id,
+            invoke_from
+        )
+
+        if (app_record.mode == AppMode.COMPLETION.value and dataset_config
+                and dataset_config.retrieve_config.query_variable):
+            query = inputs.get(dataset_config.retrieve_config.query_variable, "")
+
+        dataset_retrieval = DatasetRetrievalFeature()
+        return dataset_retrieval.retrieve(
+            tenant_id=tenant_id,
+            model_config=model_config,
+            config=dataset_config,
+            query=query,
+            invoke_from=invoke_from,
+            show_retrieve_source=show_retrieve_source,
+            hit_callback=hit_callback,
+            memory=memory
+        )
+
+    def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity,
+                                 queue_manager: ApplicationQueueManager,
+                                 prompt_messages: list[PromptMessage]) -> bool:
+        """
+        Check hosting moderation
+        :param application_generate_entity: application generate entity
+        :param queue_manager: queue manager
+        :param prompt_messages: prompt messages
+        :return:
+        """
+        hosting_moderation_feature = HostingModerationFeature()
+        moderation_result = hosting_moderation_feature.check(
+            application_generate_entity=application_generate_entity,
+            prompt_messages=prompt_messages
+        )
+
+        if moderation_result:
+            self.direct_output(
+                queue_manager=queue_manager,
+                app_orchestration_config=application_generate_entity.app_orchestration_config_entity,
+                prompt_messages=prompt_messages,
+                text="I apologize for any confusion, " \
+                     "but I'm an AI assistant to be helpful, harmless, and honest.",
+                stream=application_generate_entity.stream
+            )
+
+        return moderation_result

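Note: BasicApplicationRunner.run() short-circuits in two places before the model is ever invoked: input moderation (stream the ModerationException message and return) and annotation reply (stream the stored answer and return). A condensed sketch of that early-exit control flow, with stand-in types replacing the real runner and queue machinery:

# Condensed sketch of the moderation early-exit in run(); stand-ins only.
class ModerationException(Exception):
    pass

def check_inputs(query: str) -> str:
    if "forbidden" in query:
        raise ModerationException("Your input violates our usage policy.")
    return query

def run(query: str) -> str:
    try:
        query = check_inputs(query)
    except ModerationException as e:
        return str(e)  # real runner: self.direct_output(...) streams this text
    # real runner: annotation reply -> external data -> dataset context -> invoke LLM
    return f"LLM answer to: {query}"

print(run("hello"))           # LLM answer to: hello
print(run("forbidden topic")) # Your input violates our usage policy.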
+ 483 - 0
api/core/app_runner/generate_task_pipeline.py

@@ -0,0 +1,483 @@
+import json
+import logging
+import time
+from typing import Union, Generator, cast, Optional
+
+from pydantic import BaseModel
+
+from core.app_runner.moderation_handler import OutputModerationHandler, ModerationRule
+from core.entities.application_entities import ApplicationGenerateEntity
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.queue_entities import QueueErrorEvent, QueueStopEvent, QueueMessageEndEvent, \
+    QueueRetrieverResourcesEvent, QueueAgentThoughtEvent, QueuePingEvent, QueueMessageEvent, QueueMessageReplaceEvent, \
+    AnnotationReplyEvent
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, \
+    TextPromptMessageContent, PromptMessageContentType, ImagePromptMessageContent, PromptMessage
+from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.prompt_template import PromptTemplateParser
+from events.message_event import message_was_created
+from extensions.ext_database import db
+from models.model import Message, Conversation, MessageAgentThought
+from services.annotation_service import AppAnnotationService
+
+logger = logging.getLogger(__name__)
+
+
+class TaskState(BaseModel):
+    """
+    TaskState entity
+    """
+    llm_result: LLMResult
+    metadata: dict = {}
+
+
+class GenerateTaskPipeline:
+    """
GenerateTaskPipeline generates the stream output and manages state for an application.
+    """
+
+    def __init__(self, application_generate_entity: ApplicationGenerateEntity,
+                 queue_manager: ApplicationQueueManager,
+                 conversation: Conversation,
+                 message: Message) -> None:
+        """
+        Initialize GenerateTaskPipeline.
+        :param application_generate_entity: application generate entity
+        :param queue_manager: queue manager
+        :param conversation: conversation
+        :param message: message
+        """
+        self._application_generate_entity = application_generate_entity
+        self._queue_manager = queue_manager
+        self._conversation = conversation
+        self._message = message
+        self._task_state = TaskState(
+            llm_result=LLMResult(
+                model=self._application_generate_entity.app_orchestration_config_entity.model_config.model,
+                prompt_messages=[],
+                message=AssistantPromptMessage(content=""),
+                usage=LLMUsage.empty_usage()
+            )
+        )
+        self._start_at = time.perf_counter()
+        self._output_moderation_handler = self._init_output_moderation()
+
+    def process(self, stream: bool) -> Union[dict, Generator]:
+        """
+        Process generate task pipeline.
+        :return:
+        """
+        if stream:
+            return self._process_stream_response()
+        else:
+            return self._process_blocking_response()
+
+    def _process_blocking_response(self) -> dict:
+        """
+        Process blocking response.
+        :return:
+        """
+        for queue_message in self._queue_manager.listen():
+            event = queue_message.event
+
+            if isinstance(event, QueueErrorEvent):
+                raise self._handle_error(event)
+            elif isinstance(event, QueueRetrieverResourcesEvent):
+                self._task_state.metadata['retriever_resources'] = event.retriever_resources
+            elif isinstance(event, AnnotationReplyEvent):
+                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
+                if annotation:
+                    account = annotation.account
+                    self._task_state.metadata['annotation_reply'] = {
+                        'id': annotation.id,
+                        'account': {
+                            'id': annotation.account_id,
+                            'name': account.name if account else 'Dify user'
+                        }
+                    }
+
+                    self._task_state.llm_result.message.content = annotation.content
+            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
+                if isinstance(event, QueueMessageEndEvent):
+                    self._task_state.llm_result = event.llm_result
+                else:
+                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
+                    model = model_config.model
+                    model_type_instance = model_config.provider_model_bundle.model_type_instance
+                    model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+                    # calculate num tokens
+                    prompt_tokens = 0
+                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
+                        prompt_tokens = model_type_instance.get_num_tokens(
+                            model,
+                            model_config.credentials,
+                            self._task_state.llm_result.prompt_messages
+                        )
+
+                    completion_tokens = 0
+                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
+                        completion_tokens = model_type_instance.get_num_tokens(
+                            model,
+                            model_config.credentials,
+                            [self._task_state.llm_result.message]
+                        )
+
+                    credentials = model_config.credentials
+
+                    # transform usage
+                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
+                        model,
+                        credentials,
+                        prompt_tokens,
+                        completion_tokens
+                    )
+
+                # response moderation
+                if self._output_moderation_handler:
+                    self._output_moderation_handler.stop_thread()
+
+                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
+                        completion=self._task_state.llm_result.message.content,
+                        public_event=False
+                    )
+
+                # Save message
+                self._save_message(self._task_state.llm_result)
+
+                response = {
+                    'event': 'message',
+                    'task_id': self._application_generate_entity.task_id,
+                    'id': self._message.id,
+                    'mode': self._conversation.mode,
+                    'answer': self._task_state.llm_result.message.content,
+                    'metadata': {},
+                    'created_at': int(self._message.created_at.timestamp())
+                }
+
+                if self._conversation.mode == 'chat':
+                    response['conversation_id'] = self._conversation.id
+
+                if self._task_state.metadata:
+                    response['metadata'] = self._task_state.metadata
+
+                return response
+            else:
+                continue
+
+    def _process_stream_response(self) -> Generator:
+        """
+        Process stream response.
+        :return:
+        """
+        for message in self._queue_manager.listen():
+            event = message.event
+
+            if isinstance(event, QueueErrorEvent):
+                raise self._handle_error(event)
+            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
+                if isinstance(event, QueueMessageEndEvent):
+                    self._task_state.llm_result = event.llm_result
+                else:
+                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
+                    model = model_config.model
+                    model_type_instance = model_config.provider_model_bundle.model_type_instance
+                    model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+                    # calculate num tokens
+                    prompt_tokens = 0
+                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
+                        prompt_tokens = model_type_instance.get_num_tokens(
+                            model,
+                            model_config.credentials,
+                            self._task_state.llm_result.prompt_messages
+                        )
+
+                    completion_tokens = 0
+                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
+                        completion_tokens = model_type_instance.get_num_tokens(
+                            model,
+                            model_config.credentials,
+                            [self._task_state.llm_result.message]
+                        )
+
+                    credentials = model_config.credentials
+
+                    # transform usage
+                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
+                        model,
+                        credentials,
+                        prompt_tokens,
+                        completion_tokens
+                    )
+
+                # response moderation
+                if self._output_moderation_handler:
+                    self._output_moderation_handler.stop_thread()
+
+                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
+                        completion=self._task_state.llm_result.message.content,
+                        public_event=False
+                    )
+
+                    self._output_moderation_handler = None
+
+                    replace_response = {
+                        'event': 'message_replace',
+                        'task_id': self._application_generate_entity.task_id,
+                        'message_id': self._message.id,
+                        'answer': self._task_state.llm_result.message.content,
+                        'created_at': int(self._message.created_at.timestamp())
+                    }
+
+                    if self._conversation.mode == 'chat':
+                        replace_response['conversation_id'] = self._conversation.id
+
+                    yield self._yield_response(replace_response)
+
+                # Save message
+                self._save_message(self._task_state.llm_result)
+
+                response = {
+                    'event': 'message_end',
+                    'task_id': self._application_generate_entity.task_id,
+                    'id': self._message.id,
+                }
+
+                if self._conversation.mode == 'chat':
+                    response['conversation_id'] = self._conversation.id
+
+                if self._task_state.metadata:
+                    response['metadata'] = self._task_state.metadata
+
+                yield self._yield_response(response)
+            elif isinstance(event, QueueRetrieverResourcesEvent):
+                self._task_state.metadata['retriever_resources'] = event.retriever_resources
+            elif isinstance(event, AnnotationReplyEvent):
+                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
+                if annotation:
+                    account = annotation.account
+                    self._task_state.metadata['annotation_reply'] = {
+                        'id': annotation.id,
+                        'account': {
+                            'id': annotation.account_id,
+                            'name': account.name if account else 'Dify user'
+                        }
+                    }
+
+                    self._task_state.llm_result.message.content = annotation.content
+            elif isinstance(event, QueueAgentThoughtEvent):
+                agent_thought = (
+                    db.session.query(MessageAgentThought)
+                    .filter(MessageAgentThought.id == event.agent_thought_id)
+                    .first()
+                )
+
+                if agent_thought:
+                    response = {
+                        'event': 'agent_thought',
+                        'id': agent_thought.id,
+                        'task_id': self._application_generate_entity.task_id,
+                        'message_id': self._message.id,
+                        'position': agent_thought.position,
+                        'thought': agent_thought.thought,
+                        'tool': agent_thought.tool,
+                        'tool_input': agent_thought.tool_input,
+                        'created_at': int(self._message.created_at.timestamp())
+                    }
+
+                    if self._conversation.mode == 'chat':
+                        response['conversation_id'] = self._conversation.id
+
+                    yield self._yield_response(response)
+            elif isinstance(event, QueueMessageEvent):
+                chunk = event.chunk
+                delta_text = chunk.delta.message.content
+                if delta_text is None:
+                    continue
+
+                if not self._task_state.llm_result.prompt_messages:
+                    self._task_state.llm_result.prompt_messages = chunk.prompt_messages
+
+                if self._output_moderation_handler:
+                    if self._output_moderation_handler.should_direct_output():
+                        # stop subscribe new token when output moderation should direct output
+                        self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
+                        self._queue_manager.publish_chunk_message(LLMResultChunk(
+                            model=self._task_state.llm_result.model,
+                            prompt_messages=self._task_state.llm_result.prompt_messages,
+                            delta=LLMResultChunkDelta(
+                                index=0,
+                                message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
+                            )
+                        ))
+                        self._queue_manager.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION))
+                        continue
+                    else:
+                        self._output_moderation_handler.append_new_token(delta_text)
+
+                self._task_state.llm_result.message.content += delta_text
+                response = self._handle_chunk(delta_text)
+                yield self._yield_response(response)
+            elif isinstance(event, QueueMessageReplaceEvent):
+                response = {
+                    'event': 'message_replace',
+                    'task_id': self._application_generate_entity.task_id,
+                    'message_id': self._message.id,
+                    'answer': event.text,
+                    'created_at': int(self._message.created_at.timestamp())
+                }
+
+                if self._conversation.mode == 'chat':
+                    response['conversation_id'] = self._conversation.id
+
+                yield self._yield_response(response)
+            elif isinstance(event, QueuePingEvent):
+                yield "event: ping\n\n"
+            else:
+                continue
+
+    def _save_message(self, llm_result: LLMResult) -> None:
+        """
+        Save message.
+        :param llm_result: llm result
+        :return:
+        """
+        usage = llm_result.usage
+
+        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
+
+        self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
+        self._message.message_tokens = usage.prompt_tokens
+        self._message.message_unit_price = usage.prompt_unit_price
+        self._message.message_price_unit = usage.prompt_price_unit
+        self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
+            if llm_result.message.content else ''
+        self._message.answer_tokens = usage.completion_tokens
+        self._message.answer_unit_price = usage.completion_unit_price
+        self._message.answer_price_unit = usage.completion_price_unit
+        self._message.provider_response_latency = time.perf_counter() - self._start_at
+        self._message.total_price = usage.total_price
+
+        db.session.commit()
+
+        message_was_created.send(
+            self._message,
+            application_generate_entity=self._application_generate_entity,
+            conversation=self._conversation,
+            is_first_message=self._application_generate_entity.conversation_id is None,
+            extras=self._application_generate_entity.extras
+        )
+
+    def _handle_chunk(self, text: str) -> dict:
+        """
Build the streamed message chunk response.
+        :param text: text
+        :return:
+        """
+        response = {
+            'event': 'message',
+            'id': self._message.id,
+            'task_id': self._application_generate_entity.task_id,
+            'message_id': self._message.id,
+            'answer': text,
+            'created_at': int(self._message.created_at.timestamp())
+        }
+
+        if self._conversation.mode == 'chat':
+            response['conversation_id'] = self._conversation.id
+
+        return response
+
+    def _handle_error(self, event: QueueErrorEvent) -> Exception:
+        """
+        Handle error event.
+        :param event: event
+        :return:
+        """
+        logger.debug("error: %s", event.error)
+        e = event.error
+
+        if isinstance(e, InvokeAuthorizationError):
+            return InvokeAuthorizationError('Incorrect API key provided')
+        elif isinstance(e, InvokeError) or isinstance(e, ValueError):
+            return e
+        else:
+            return Exception(e.description if getattr(e, 'description', None) is not None else str(e))
+
+    def _yield_response(self, response: dict) -> str:
+        """
+        Yield response.
+        :param response: response
+        :return:
+        """
+        return "data: " + json.dumps(response) + "\n\n"
+
+    def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]:
+        """
+        Prompt messages to prompt for saving.
+        :param prompt_messages: prompt messages
+        :return:
+        """
+        prompts = []
+        if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat':
+            for prompt_message in prompt_messages:
+                if prompt_message.role == PromptMessageRole.USER:
+                    role = 'user'
+                elif prompt_message.role == PromptMessageRole.ASSISTANT:
+                    role = 'assistant'
+                elif prompt_message.role == PromptMessageRole.SYSTEM:
+                    role = 'system'
+                else:
+                    continue
+
+                text = ''
+                files = []
+                if isinstance(prompt_message.content, list):
+                    for content in prompt_message.content:
+                        if content.type == PromptMessageContentType.TEXT:
+                            content = cast(TextPromptMessageContent, content)
+                            text += content.data
+                        else:
+                            content = cast(ImagePromptMessageContent, content)
+                            files.append({
+                                "type": 'image',
+                                "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
+                                "detail": content.detail.value
+                            })
+                else:
+                    text = prompt_message.content
+
+                prompts.append({
+                    "role": role,
+                    "text": text,
+                    "files": files
+                })
+        else:
+            prompts.append({
+                "role": 'user',
+                "text": prompt_messages[0].content
+            })
+
+        return prompts
+
+    def _init_output_moderation(self) -> Optional[OutputModerationHandler]:
+        """
+        Init output moderation.
+        :return:
+        """
+        app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity
+        sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance
+
+        if sensitive_word_avoidance:
+            return OutputModerationHandler(
+                tenant_id=self._application_generate_entity.tenant_id,
+                app_id=self._application_generate_entity.app_id,
+                rule=ModerationRule(
+                    type=sensitive_word_avoidance.type,
+                    config=sensitive_word_avoidance.config
+                ),
+                on_message_replace_func=self._queue_manager.publish_message_replace
+            )

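Note: _yield_response() frames every streamed payload as "data: <json>\n\n", and keepalives arrive as "event: ping\n\n" (see _process_stream_response). A minimal client-side sketch of consuming that framing; parse_sse and the sample chunks are illustrative, not part of the API:

import json

def parse_sse(chunks):
    for chunk in chunks:
        chunk = chunk.strip()
        if chunk == "event: ping":
            continue  # keepalive, no payload
        if chunk.startswith("data: "):
            yield json.loads(chunk[len("data: "):])

sample = [
    'data: {"event": "message", "answer": "Hel"}\n\n',
    "event: ping\n\n",
    'data: {"event": "message", "answer": "lo"}\n\n',
    'data: {"event": "message_end", "id": "msg-1"}\n\n',
]

answer = ""
for payload in parse_sse(sample):
    if payload["event"] == "message":
        answer += payload["answer"]
    elif payload["event"] == "message_end":
        break
print(answer)  # Hello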
+ 138 - 0
api/core/app_runner/moderation_handler.py

@@ -0,0 +1,138 @@
+import logging
+import threading
+import time
+from typing import Any, Optional, Dict
+
+from flask import current_app, Flask
+from pydantic import BaseModel
+
+from core.moderation.base import ModerationAction, ModerationOutputsResult
+from core.moderation.factory import ModerationFactory
+
+logger = logging.getLogger(__name__)
+
+
+class ModerationRule(BaseModel):
+    type: str
+    config: Dict[str, Any]
+
+
+class OutputModerationHandler(BaseModel):
+    DEFAULT_BUFFER_SIZE: int = 300
+
+    tenant_id: str
+    app_id: str
+
+    rule: ModerationRule
+    on_message_replace_func: Any
+
+    thread: Optional[threading.Thread] = None
+    thread_running: bool = True
+    buffer: str = ''
+    is_final_chunk: bool = False
+    final_output: Optional[str] = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def should_direct_output(self):
+        return self.final_output is not None
+
+    def get_final_output(self):
+        return self.final_output
+
+    def append_new_token(self, token: str):
+        self.buffer += token
+
+        if not self.thread:
+            self.thread = self.start_thread()
+
+    def moderation_completion(self, completion: str, public_event: bool = False) -> str:
+        self.buffer = completion
+        self.is_final_chunk = True
+
+        result = self.moderation(
+            tenant_id=self.tenant_id,
+            app_id=self.app_id,
+            moderation_buffer=completion
+        )
+
+        if not result or not result.flagged:
+            return completion
+
+        if result.action == ModerationAction.DIRECT_OUTPUT:
+            final_output = result.preset_response
+        else:
+            final_output = result.text
+
+        if public_event:
+            self.on_message_replace_func(final_output)
+
+        return final_output
+
+    def start_thread(self) -> threading.Thread:
+        buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
+        thread = threading.Thread(target=self.worker, kwargs={
+            'flask_app': current_app._get_current_object(),
+            'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
+        })
+
+        thread.start()
+
+        return thread
+
+    def stop_thread(self):
+        if self.thread and self.thread.is_alive():
+            self.thread_running = False
+
+    def worker(self, flask_app: Flask, buffer_size: int):
+        with flask_app.app_context():
+            current_length = 0
+            while self.thread_running:
+                moderation_buffer = self.buffer
+                buffer_length = len(moderation_buffer)
+                if not self.is_final_chunk:
+                    chunk_length = buffer_length - current_length
+                    if 0 <= chunk_length < buffer_size:
+                        time.sleep(1)
+                        continue
+
+                current_length = buffer_length
+
+                result = self.moderation(
+                    tenant_id=self.tenant_id,
+                    app_id=self.app_id,
+                    moderation_buffer=moderation_buffer
+                )
+
+                if not result or not result.flagged:
+                    continue
+
+                if result.action == ModerationAction.DIRECT_OUTPUT:
+                    final_output = result.preset_response
+                    self.final_output = final_output
+                else:
+                    final_output = result.text + self.buffer[len(moderation_buffer):]
+
+                # trigger replace event
+                if self.thread_running:
+                    self.on_message_replace_func(final_output)
+
+                if result.action == ModerationAction.DIRECT_OUTPUT:
+                    break
+
+    def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
+        try:
+            moderation_factory = ModerationFactory(
+                name=self.rule.type,
+                app_id=app_id,
+                tenant_id=tenant_id,
+                config=self.rule.config
+            )
+
+            result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
+            return result
+        except Exception as e:
+            logger.error("Moderation Output error: %s", e)
+
+        return None

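Note: the worker() loop above only re-runs moderation once the buffer has grown by at least buffer_size characters since the last check (the final chunk is always checked via moderation_completion). A toy sketch of that batching behavior; moderate_stream and check are stand-ins for the handler and ModerationFactory.moderation_for_outputs():

def moderate_stream(tokens, buffer_size=10, check=lambda text: "bad" in text):
    buffer = ""
    checked_length = 0
    for token in tokens:
        buffer += token
        if len(buffer) - checked_length < buffer_size:
            continue  # not enough new text since the last check
        checked_length = len(buffer)
        if check(buffer):
            return "[output replaced by moderation]"
    # final pass over the full text, like moderation_completion()
    return "[output replaced by moderation]" if check(buffer) else buffer

print(moderate_stream(["all ", "good ", "text ", "here"]))  # all good text here
print(moderate_stream(["some ", "bad ", "words ", "here"])) # [output replaced by moderation]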
+ 655 - 0
api/core/application_manager.py

@@ -0,0 +1,655 @@
+import json
+import logging
+import threading
+import uuid
+from typing import cast, Optional, Any, Union, Generator, Tuple
+
+from flask import Flask, current_app
+from pydantic import ValidationError
+
+from core.app_runner.agent_app_runner import AgentApplicationRunner
+from core.app_runner.basic_app_runner import BasicApplicationRunner
+from core.app_runner.generate_task_pipeline import GenerateTaskPipeline
+from core.entities.application_entities import ApplicationGenerateEntity, AppOrchestrationConfigEntity, \
+    ModelConfigEntity, PromptTemplateEntity, AdvancedChatPromptTemplateEntity, \
+    AdvancedCompletionPromptTemplateEntity, ExternalDataVariableEntity, DatasetEntity, DatasetRetrieveConfigEntity, \
+    AgentEntity, AgentToolEntity, FileUploadEntity, SensitiveWordAvoidanceEntity, InvokeFrom
+from core.entities.model_entities import ModelStatus
+from core.file.file_obj import FileObj
+from core.errors.error import QuotaExceededError, ProviderTokenNotInitError, ModelCurrentlyNotSupportError
+from core.model_runtime.entities.message_entities import PromptMessageRole
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.prompt_template import PromptTemplateParser
+from core.provider_manager import ProviderManager
+from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException
+from extensions.ext_database import db
+from models.account import Account
+from models.model import EndUser, Conversation, Message, MessageFile, App
+
+logger = logging.getLogger(__name__)
+
+
+class ApplicationManager:
+    """
This class is responsible for managing the application generation flow.
+    """
+
+    def generate(self, tenant_id: str,
+                 app_id: str,
+                 app_model_config_id: str,
+                 app_model_config_dict: dict,
+                 app_model_config_override: bool,
+                 user: Union[Account, EndUser],
+                 invoke_from: InvokeFrom,
+                 inputs: dict[str, str],
+                 query: Optional[str] = None,
+                 files: Optional[list[FileObj]] = None,
+                 conversation: Optional[Conversation] = None,
+                 stream: bool = False,
+                 extras: Optional[dict[str, Any]] = None) \
+            -> Union[dict, Generator]:
+        """
+        Generate App response.
+
+        :param tenant_id: workspace ID
+        :param app_id: app ID
+        :param app_model_config_id: app model config id
+        :param app_model_config_dict: app model config dict
+        :param app_model_config_override: app model config override
+        :param user: account or end user
+        :param invoke_from: invoke from source
+        :param inputs: inputs
+        :param query: query
+        :param files: file obj list
+        :param conversation: conversation
+        :param stream: is stream
+        :param extras: extras
+        """
+        # init task id
+        task_id = str(uuid.uuid4())
+
+        # init application generate entity
+        application_generate_entity = ApplicationGenerateEntity(
+            task_id=task_id,
+            tenant_id=tenant_id,
+            app_id=app_id,
+            app_model_config_id=app_model_config_id,
+            app_model_config_dict=app_model_config_dict,
+            app_orchestration_config_entity=self._convert_from_app_model_config_dict(
+                tenant_id=tenant_id,
+                app_model_config_dict=app_model_config_dict
+            ),
+            app_model_config_override=app_model_config_override,
+            conversation_id=conversation.id if conversation else None,
+            inputs=conversation.inputs if conversation else inputs,
+            query=query.replace('\x00', '') if query else None,
+            files=files if files else [],
+            user_id=user.id,
+            stream=stream,
+            invoke_from=invoke_from,
+            extras=extras
+        )
+
+        # init generate records
+        (
+            conversation,
+            message
+        ) = self._init_generate_records(application_generate_entity)
+
+        # init queue manager
+        queue_manager = ApplicationQueueManager(
+            task_id=application_generate_entity.task_id,
+            user_id=application_generate_entity.user_id,
+            invoke_from=application_generate_entity.invoke_from,
+            conversation_id=conversation.id,
+            app_mode=conversation.mode,
+            message_id=message.id
+        )
+
+        # new thread
+        worker_thread = threading.Thread(target=self._generate_worker, kwargs={
+            'flask_app': current_app._get_current_object(),
+            'application_generate_entity': application_generate_entity,
+            'queue_manager': queue_manager,
+            'conversation_id': conversation.id,
+            'message_id': message.id,
+        })
+
+        worker_thread.start()
+
+        # return response or stream generator
+        return self._handle_response(
+            application_generate_entity=application_generate_entity,
+            queue_manager=queue_manager,
+            conversation=conversation,
+            message=message,
+            stream=stream
+        )
+
+    def _generate_worker(self, flask_app: Flask,
+                         application_generate_entity: ApplicationGenerateEntity,
+                         queue_manager: ApplicationQueueManager,
+                         conversation_id: str,
+                         message_id: str) -> None:
+        """
+        Generate worker in a new thread.
+        :param flask_app: Flask app
+        :param application_generate_entity: application generate entity
+        :param queue_manager: queue manager
+        :param conversation_id: conversation ID
+        :param message_id: message ID
+        :return:
+        """
+        with flask_app.app_context():
+            try:
+                # get conversation and message
+                conversation = self._get_conversation(conversation_id)
+                message = self._get_message(message_id)
+
+                if application_generate_entity.app_orchestration_config_entity.agent:
+                    # agent app
+                    runner = AgentApplicationRunner()
+                    runner.run(
+                        application_generate_entity=application_generate_entity,
+                        queue_manager=queue_manager,
+                        conversation=conversation,
+                        message=message
+                    )
+                else:
+                    # basic app
+                    runner = BasicApplicationRunner()
+                    runner.run(
+                        application_generate_entity=application_generate_entity,
+                        queue_manager=queue_manager,
+                        conversation=conversation,
+                        message=message
+                    )
+            except ConversationTaskStoppedException:
+                pass
+            except InvokeAuthorizationError:
+                queue_manager.publish_error(InvokeAuthorizationError('Incorrect API key provided'))
+            except ValidationError as e:
+                logger.exception("Validation Error when generating")
+                queue_manager.publish_error(e)
+            except (ValueError, InvokeError) as e:
+                queue_manager.publish_error(e)
+            except Exception as e:
+                logger.exception("Unknown Error when generating")
+                queue_manager.publish_error(e)
+            finally:
+                db.session.remove()
+
+    def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
+                         queue_manager: ApplicationQueueManager,
+                         conversation: Conversation,
+                         message: Message,
+                         stream: bool = False) -> Union[dict, Generator]:
+        """
+        Handle response.
+        :param application_generate_entity: application generate entity
+        :param queue_manager: queue manager
+        :param conversation: conversation
+        :param message: message
+        :param stream: is stream
+        :return:
+        """
+        # init generate task pipeline
+        generate_task_pipeline = GenerateTaskPipeline(
+            application_generate_entity=application_generate_entity,
+            queue_manager=queue_manager,
+            conversation=conversation,
+            message=message
+        )
+
+        try:
+            return generate_task_pipeline.process(stream=stream)
+        except ValueError as e:
+            if e.args[0] == "I/O operation on closed file.":  # ignore this error
+                raise ConversationTaskStoppedException()
+            else:
+                logger.exception(e)
+                raise e
+        finally:
+            db.session.remove()
+
+    def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
+            -> AppOrchestrationConfigEntity:
+        """
+        Convert app model config dict to entity.
+        :param tenant_id: tenant ID
+        :param app_model_config_dict: app model config dict
+        :raises ProviderTokenNotInitError: provider token not init error
+        :return: app orchestration config entity
+        """
+        properties = {}
+
+        copy_app_model_config_dict = app_model_config_dict.copy()
+
+        provider_manager = ProviderManager()
+        provider_model_bundle = provider_manager.get_provider_model_bundle(
+            tenant_id=tenant_id,
+            provider=copy_app_model_config_dict['model']['provider'],
+            model_type=ModelType.LLM
+        )
+
+        provider_name = provider_model_bundle.configuration.provider.provider
+        model_name = copy_app_model_config_dict['model']['name']
+
+        model_type_instance = provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        # check model credentials
+        model_credentials = provider_model_bundle.configuration.get_current_credentials(
+            model_type=ModelType.LLM,
+            model=copy_app_model_config_dict['model']['name']
+        )
+
+        if model_credentials is None:
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials are not initialized.")
+
+        # check model
+        provider_model = provider_model_bundle.configuration.get_provider_model(
+            model=copy_app_model_config_dict['model']['name'],
+            model_type=ModelType.LLM
+        )
+
+        if provider_model is None:
+            raise ValueError(f"Model {model_name} does not exist.")
+
+        if provider_model.status == ModelStatus.NO_CONFIGURE:
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials are not initialized.")
+        elif provider_model.status == ModelStatus.NO_PERMISSION:
+            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} is currently not supported.")
+        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
+            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
+
+        # model config
+        completion_params = copy_app_model_config_dict['model'].get('completion_params') or {}
+        stop = []
+        if 'stop' in completion_params:
+            stop = completion_params['stop']
+            del completion_params['stop']
+
+        # get model mode
+        model_mode = copy_app_model_config_dict['model'].get('mode')
+        if not model_mode:
+            mode_enum = model_type_instance.get_model_mode(
+                model=copy_app_model_config_dict['model']['name'],
+                credentials=model_credentials
+            )
+
+            model_mode = mode_enum.value
+
+        model_schema = model_type_instance.get_model_schema(
+            copy_app_model_config_dict['model']['name'],
+            model_credentials
+        )
+
+        if not model_schema:
+            raise ValueError(f"Model {model_name} does not exist.")
+
+        properties['model_config'] = ModelConfigEntity(
+            provider=copy_app_model_config_dict['model']['provider'],
+            model=copy_app_model_config_dict['model']['name'],
+            model_schema=model_schema,
+            mode=model_mode,
+            provider_model_bundle=provider_model_bundle,
+            credentials=model_credentials,
+            parameters=completion_params,
+            stop=stop,
+        )
+
+        # prompt template
+        prompt_type = PromptTemplateEntity.PromptType.value_of(copy_app_model_config_dict['prompt_type'])
+        if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
+            simple_prompt_template = copy_app_model_config_dict.get("pre_prompt", "")
+            properties['prompt_template'] = PromptTemplateEntity(
+                prompt_type=prompt_type,
+                simple_prompt_template=simple_prompt_template
+            )
+        else:
+            advanced_chat_prompt_template = None
+            chat_prompt_config = copy_app_model_config_dict.get("chat_prompt_config", {})
+            if chat_prompt_config:
+                chat_prompt_messages = []
+                for message in chat_prompt_config.get("prompt", []):
+                    chat_prompt_messages.append({
+                        "text": message["text"],
+                        "role": PromptMessageRole.value_of(message["role"])
+                    })
+
+                advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
+                    messages=chat_prompt_messages
+                )
+
+            advanced_completion_prompt_template = None
+            completion_prompt_config = copy_app_model_config_dict.get("completion_prompt_config", {})
+            if completion_prompt_config:
+                completion_prompt_template_params = {
+                    'prompt': completion_prompt_config['prompt']['text'],
+                }
+
+                if 'conversation_histories_role' in completion_prompt_config:
+                    completion_prompt_template_params['role_prefix'] = {
+                        'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
+                        'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
+                    }
+
+                advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
+                    **completion_prompt_template_params
+                )
+
+            properties['prompt_template'] = PromptTemplateEntity(
+                prompt_type=prompt_type,
+                advanced_chat_prompt_template=advanced_chat_prompt_template,
+                advanced_completion_prompt_template=advanced_completion_prompt_template
+            )
+
+        # external data variables
+        properties['external_data_variables'] = []
+        external_data_tools = copy_app_model_config_dict.get('external_data_tools', [])
+        for external_data_tool in external_data_tools:
+            if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
+                continue
+
+            properties['external_data_variables'].append(
+                ExternalDataVariableEntity(
+                    variable=external_data_tool['variable'],
+                    type=external_data_tool['type'],
+                    config=external_data_tool['config']
+                )
+            )
+
+        # show retrieve source
+        show_retrieve_source = False
+        retriever_resource_dict = copy_app_model_config_dict.get('retriever_resource')
+        if retriever_resource_dict:
+            if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']:
+                show_retrieve_source = True
+
+        properties['show_retrieve_source'] = show_retrieve_source
+
+        agent_dict = copy_app_model_config_dict.get('agent_mode')
+        if agent_dict and agent_dict.get('enabled'):
+            if agent_dict['strategy'] in ['router', 'react_router']:
+                dataset_ids = []
+                for tool in agent_dict.get('tools', []):
+                    key = list(tool.keys())[0]
+
+                    if key != 'dataset':
+                        continue
+
+                    tool_item = tool[key]
+
+                    if "enabled" not in tool_item or not tool_item["enabled"]:
+                        continue
+
+                    dataset_id = tool_item['id']
+                    dataset_ids.append(dataset_id)
+
+                dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'})
+                query_variable = copy_app_model_config_dict.get('dataset_query_variable')
+                if dataset_configs['retrieval_model'] == 'single':
+                    properties['dataset'] = DatasetEntity(
+                        dataset_ids=dataset_ids,
+                        retrieve_config=DatasetRetrieveConfigEntity(
+                            query_variable=query_variable,
+                            retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
+                                dataset_configs['retrieval_model']
+                            ),
+                            single_strategy=agent_dict['strategy']
+                        )
+                    )
+                else:
+                    properties['dataset'] = DatasetEntity(
+                        dataset_ids=dataset_ids,
+                        retrieve_config=DatasetRetrieveConfigEntity(
+                            query_variable=query_variable,
+                            retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
+                                dataset_configs['retrieval_model']
+                            ),
+                            top_k=dataset_configs.get('top_k'),
+                            score_threshold=dataset_configs.get('score_threshold'),
+                            reranking_model=dataset_configs.get('reranking_model')
+                        )
+                    )
+            else:
+                if agent_dict['strategy'] == 'react':
+                    strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
+                else:
+                    strategy = AgentEntity.Strategy.FUNCTION_CALLING
+
+                agent_tools = []
+                for tool in agent_dict.get('tools', []):
+                    key = list(tool.keys())[0]
+                    tool_item = tool[key]
+
+                    if "enabled" not in tool_item or not tool_item["enabled"]:
+                        continue
+
+                    agent_tools.append(AgentToolEntity(tool_id=key, config=tool_item))
+
+                properties['agent'] = AgentEntity(
+                    provider=properties['model_config'].provider,
+                    model=properties['model_config'].model,
+                    strategy=strategy,
+                    tools=agent_tools
+                )
+
+        # file upload
+        file_upload_dict = copy_app_model_config_dict.get('file_upload')
+        if file_upload_dict:
+            if 'image' in file_upload_dict and file_upload_dict['image']:
+                if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
+                    properties['file_upload'] = FileUploadEntity(
+                        image_config={
+                            'number_limits': file_upload_dict['image']['number_limits'],
+                            'detail': file_upload_dict['image']['detail'],
+                            'transfer_methods': file_upload_dict['image']['transfer_methods']
+                        }
+                    )
+
+        # opening statement
+        properties['opening_statement'] = copy_app_model_config_dict.get('opening_statement')
+
+        # suggested questions after answer
+        suggested_questions_after_answer_dict = copy_app_model_config_dict.get('suggested_questions_after_answer')
+        if suggested_questions_after_answer_dict:
+            if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']:
+                properties['suggested_questions_after_answer'] = True
+
+        # more like this
+        more_like_this_dict = copy_app_model_config_dict.get('more_like_this')
+        if more_like_this_dict:
+            if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
+                properties['more_like_this'] = True
+
+        # speech to text
+        speech_to_text_dict = copy_app_model_config_dict.get('speech_to_text')
+        if speech_to_text_dict:
+            if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
+                properties['speech_to_text'] = True
+
+        # sensitive word avoidance
+        sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
+        if sensitive_word_avoidance_dict:
+            if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']:
+                properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity(
+                    type=sensitive_word_avoidance_dict.get('type'),
+                    config=sensitive_word_avoidance_dict.get('config'),
+                )
+
+        return AppOrchestrationConfigEntity(**properties)
+
+    def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
+            -> Tuple[Conversation, Message]:
+        """
+        Initialize generate records
+        :param application_generate_entity: application generate entity
+        :return:
+        """
+        app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
+
+        model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+        model_schema = model_type_instance.get_model_schema(
+            model=app_orchestration_config_entity.model_config.model,
+            credentials=app_orchestration_config_entity.model_config.credentials
+        )
+
+        app_record = (db.session.query(App)
+                      .filter(App.id == application_generate_entity.app_id).first())
+
+        app_mode = app_record.mode
+
+        # get from source
+        end_user_id = None
+        account_id = None
+        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+            from_source = 'api'
+            end_user_id = application_generate_entity.user_id
+        else:
+            from_source = 'console'
+            account_id = application_generate_entity.user_id
+
+        override_model_configs = None
+        if application_generate_entity.app_model_config_override:
+            override_model_configs = application_generate_entity.app_model_config_dict
+
+        introduction = ''
+        if app_mode == 'chat':
+            # get conversation introduction
+            introduction = self._get_conversation_introduction(application_generate_entity)
+
+        if not application_generate_entity.conversation_id:
+            conversation = Conversation(
+                app_id=app_record.id,
+                app_model_config_id=application_generate_entity.app_model_config_id,
+                model_provider=app_orchestration_config_entity.model_config.provider,
+                model_id=app_orchestration_config_entity.model_config.model,
+                override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
+                mode=app_mode,
+                name='New conversation',
+                inputs=application_generate_entity.inputs,
+                introduction=introduction,
+                system_instruction="",
+                system_instruction_tokens=0,
+                status='normal',
+                from_source=from_source,
+                from_end_user_id=end_user_id,
+                from_account_id=account_id,
+            )
+
+            db.session.add(conversation)
+            db.session.commit()
+        else:
+            conversation = (
+                db.session.query(Conversation)
+                .filter(
+                    Conversation.id == application_generate_entity.conversation_id,
+                    Conversation.app_id == app_record.id
+                ).first()
+            )
+
+        currency = model_schema.pricing.currency if model_schema.pricing else 'USD'
+
+        message = Message(
+            app_id=app_record.id,
+            model_provider=app_orchestration_config_entity.model_config.provider,
+            model_id=app_orchestration_config_entity.model_config.model,
+            override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
+            conversation_id=conversation.id,
+            inputs=application_generate_entity.inputs,
+            query=application_generate_entity.query or "",
+            message="",
+            message_tokens=0,
+            message_unit_price=0,
+            message_price_unit=0,
+            answer="",
+            answer_tokens=0,
+            answer_unit_price=0,
+            answer_price_unit=0,
+            provider_response_latency=0,
+            total_price=0,
+            currency=currency,
+            from_source=from_source,
+            from_end_user_id=end_user_id,
+            from_account_id=account_id,
+            agent_based=app_orchestration_config_entity.agent is not None
+        )
+
+        db.session.add(message)
+        db.session.commit()
+
+        for file in application_generate_entity.files:
+            message_file = MessageFile(
+                message_id=message.id,
+                type=file.type.value,
+                transfer_method=file.transfer_method.value,
+                url=file.url,
+                upload_file_id=file.upload_file_id,
+                created_by_role=('account' if account_id else 'end_user'),
+                created_by=account_id or end_user_id,
+            )
+            db.session.add(message_file)
+            db.session.commit()
+
+        return conversation, message
+
+    def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str:
+        """
+        Get conversation introduction
+        :param application_generate_entity: application generate entity
+        :return: conversation introduction
+        """
+        app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
+        introduction = app_orchestration_config_entity.opening_statement
+
+        if introduction:
+            try:
+                inputs = application_generate_entity.inputs
+                prompt_template = PromptTemplateParser(template=introduction)
+                prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+                introduction = prompt_template.format(prompt_inputs)
+            except KeyError:
+                pass
+
+        return introduction
+
+    def _get_conversation(self, conversation_id: str) -> Conversation:
+        """
+        Get conversation by conversation id
+        :param conversation_id: conversation id
+        :return: conversation
+        """
+        conversation = (
+            db.session.query(Conversation)
+            .filter(Conversation.id == conversation_id)
+            .first()
+        )
+
+        return conversation
+
+    def _get_message(self, message_id: str) -> Message:
+        """
+        Get message by message id
+        :param message_id: message id
+        :return: message
+        """
+        message = (
+            db.session.query(Message)
+            .filter(Message.id == message_id)
+            .first()
+        )
+
+        return message
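
Note on the opening-statement handling above: _get_conversation_introduction substitutes only the variables actually present in the user inputs, and any residual KeyError is deliberately swallowed so a malformed template cannot break generation. A minimal sketch of that filtering step, assuming the import path and the variable_keys/format API used at the call site (the template and inputs here are made up):

    from core.prompt.prompt_template import PromptTemplateParser  # import path is an assumption

    inputs = {'name': 'Alice'}  # hypothetical user inputs; note 'topic' is absent
    prompt_template = PromptTemplateParser(template='Hi {{name}}, ask me about {{topic}}.')
    # keep only the variables the caller actually supplied
    prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
    introduction = prompt_template.format(prompt_inputs)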

+ 228 - 0
api/core/application_queue_manager.py

@@ -0,0 +1,228 @@
+import queue
+import time
+from typing import Generator, Any
+
+from sqlalchemy.orm import DeclarativeMeta
+
+from core.entities.application_entities import InvokeFrom
+from core.entities.queue_entities import QueueStopEvent, AppQueueEvent, QueuePingEvent, QueueErrorEvent, \
+    QueueAgentThoughtEvent, QueueMessageEndEvent, QueueRetrieverResourcesEvent, QueueMessageReplaceEvent, \
+    QueueMessageEvent, QueueMessage, AnnotationReplyEvent
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
+from extensions.ext_redis import redis_client
+from models.model import MessageAgentThought
+
+
+class ApplicationQueueManager:
+    def __init__(self, task_id: str,
+                 user_id: str,
+                 invoke_from: InvokeFrom,
+                 conversation_id: str,
+                 app_mode: str,
+                 message_id: str) -> None:
+        if not user_id:
+            raise ValueError("user_id is required")
+
+        self._task_id = task_id
+        self._user_id = user_id
+        self._invoke_from = invoke_from
+        self._conversation_id = str(conversation_id)
+        self._app_mode = app_mode
+        self._message_id = str(message_id)
+
+        user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
+        redis_client.setex(
+            ApplicationQueueManager._generate_task_belong_cache_key(self._task_id),
+            1800,
+            f"{user_prefix}-{self._user_id}"
+        )
+
+        self._q = queue.Queue()
+
+    def listen(self) -> Generator:
+        """
+        Listen to queue
+        :return:
+        """
+        # stop listening after 10 minutes at most
+        listen_timeout = 600
+        start_time = time.time()
+        last_ping_time = 0
+
+        while True:
+            try:
+                message = self._q.get(timeout=1)
+                if message is None:
+                    break
+
+                yield message
+            except queue.Empty:
+                continue
+            finally:
+                elapsed_time = time.time() - start_time
+                if elapsed_time >= listen_timeout or self._is_stopped():
+                    # publish two messages to make sure the client can receive the stop signal
+                    # and stops listening after the stop signal is processed
+                    self.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))
+                    self.stop_listen()
+
+                if elapsed_time // 10 > last_ping_time:
+                    self.publish(QueuePingEvent())
+                    last_ping_time = elapsed_time // 10
+
+    def stop_listen(self) -> None:
+        """
+        Stop listening to the queue
+        :return:
+        """
+        self._q.put(None)
+
+    def publish_chunk_message(self, chunk: LLMResultChunk) -> None:
+        """
+        Publish chunk message to channel
+
+        :param chunk: chunk
+        :return:
+        """
+        self.publish(QueueMessageEvent(
+            chunk=chunk
+        ))
+
+    def publish_message_replace(self, text: str) -> None:
+        """
+        Publish message replace
+        :param text: text
+        :return:
+        """
+        self.publish(QueueMessageReplaceEvent(
+            text=text
+        ))
+
+    def publish_retriever_resources(self, retriever_resources: list[dict]) -> None:
+        """
+        Publish retriever resources
+        :return:
+        """
+        self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources))
+
+    def publish_annotation_reply(self, message_annotation_id: str) -> None:
+        """
+        Publish annotation reply
+        :param message_annotation_id: message annotation id
+        :return:
+        """
+        self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id))
+
+    def publish_message_end(self, llm_result: LLMResult) -> None:
+        """
+        Publish message end
+        :param llm_result: llm result
+        :return:
+        """
+        self.publish(QueueMessageEndEvent(llm_result=llm_result))
+        self.stop_listen()
+
+    def publish_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
+        """
+        Publish agent thought
+        :param message_agent_thought: message agent thought
+        :return:
+        """
+        self.publish(QueueAgentThoughtEvent(
+            agent_thought_id=message_agent_thought.id
+        ))
+
+    def publish_error(self, e) -> None:
+        """
+        Publish error
+        :param e: error
+        :return:
+        """
+        self.publish(QueueErrorEvent(
+            error=e
+        ))
+        self.stop_listen()
+
+    def publish(self, event: AppQueueEvent) -> None:
+        """
+        Publish event to queue
+        :param event:
+        :return:
+        """
+        self._check_for_sqlalchemy_models(event.dict())
+
+        message = QueueMessage(
+            task_id=self._task_id,
+            message_id=self._message_id,
+            conversation_id=self._conversation_id,
+            app_mode=self._app_mode,
+            event=event
+        )
+
+        self._q.put(message)
+
+        if isinstance(event, QueueStopEvent):
+            self.stop_listen()
+
+    @classmethod
+    def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
+        """
+        Set task stop flag
+        :return:
+        """
+        result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
+        if result is None:
+            return
+
+        user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
+        if result != f"{user_prefix}-{user_id}":
+            return
+
+        stopped_cache_key = cls._generate_stopped_cache_key(task_id)
+        redis_client.setex(stopped_cache_key, 600, 1)
+
+    def _is_stopped(self) -> bool:
+        """
+        Check if task is stopped
+        :return:
+        """
+        stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id)
+        result = redis_client.get(stopped_cache_key)
+        if result is not None:
+            redis_client.delete(stopped_cache_key)
+            return True
+
+        return False
+
+    @classmethod
+    def _generate_task_belong_cache_key(cls, task_id: str) -> str:
+        """
+        Generate task belong cache key
+        :param task_id: task id
+        :return:
+        """
+        return f"generate_task_belong:{task_id}"
+
+    @classmethod
+    def _generate_stopped_cache_key(cls, task_id: str) -> str:
+        """
+        Generate stopped cache key
+        :param task_id: task id
+        :return:
+        """
+        return f"generate_task_stopped:{task_id}"
+
+    def _check_for_sqlalchemy_models(self, data: Any):
+        # recursively walk the dict/list structure produced by event.dict()
+        if isinstance(data, dict):
+            for key, value in data.items():
+                self._check_for_sqlalchemy_models(value)
+        elif isinstance(data, list):
+            for item in data:
+                self._check_for_sqlalchemy_models(item)
+        else:
+            if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
+                raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
+                                "that cause thread safety issues is not allowed.")
+
+
+class ConversationTaskStoppedException(Exception):
+    pass
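
Taken together, ApplicationQueueManager is a single-producer/single-consumer bridge: a worker thread publishes events while the request thread drains listen(), and stop_listen() ends the stream by enqueuing a None sentinel. Stripped of Redis, pings, and the event types, the underlying pattern reduces to this self-contained sketch (not Dify code):

    import queue
    import threading
    import time

    q = queue.Queue()

    def worker():
        for token in ['Hello', ',', ' world']:
            q.put(token)   # stands in for publish(...)
            time.sleep(0.1)
        q.put(None)        # stands in for stop_listen(): the sentinel ends the consumer

    threading.Thread(target=worker).start()

    while True:
        try:
            message = q.get(timeout=1)  # listen() polls with a 1-second timeout
        except queue.Empty:
            continue  # timeouts let listen() check stop flags and emit pings
        if message is None:
            break
        print(message, end='')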

+ 110 - 65
api/core/callback_handler/agent_loop_gather_callback_handler.py

@@ -2,30 +2,40 @@ import json
 import logging
 import time
 
-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, List, Union, Optional, cast
 
 from langchain.agents import openai_functions_agent, openai_functions_multi_agent
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
 
+from core.application_queue_manager import ApplicationQueueManager
 from core.callback_handler.entity.agent_loop import AgentLoop
-from core.conversation_message_task import ConversationMessageTask
-from core.model_providers.models.entity.message import PromptMessage
-from core.model_providers.models.llm.base import BaseLLM
+from core.entities.application_entities import ModelConfigEntity
+from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
+from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessage
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from extensions.ext_database import db
+from models.model import MessageChain, MessageAgentThought, Message
 
 
 class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
     raise_error: bool = True
 
-    def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
+    def __init__(self, model_config: ModelConfigEntity,
+                 queue_manager: ApplicationQueueManager,
+                 message: Message,
+                 message_chain: MessageChain) -> None:
         """Initialize callback handler."""
-        self.model_instance = model_instance
-        self.conversation_message_task = conversation_message_task
+        self.model_config = model_config
+        self.queue_manager = queue_manager
+        self.message = message
+        self.message_chain = message_chain
+        model_type_instance = self.model_config.provider_model_bundle.model_type_instance
+        self.model_type_instance = cast(LargeLanguageModel, model_type_instance)
         self._agent_loops = []
         self._current_loop = None
         self._message_agent_thought = None
-        self.current_chain = None
 
     @property
     def agent_loops(self) -> List[AgentLoop]:
@@ -46,66 +56,61 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         """Whether to ignore chain callbacks."""
         return True
 
-    def on_chat_model_start(
-            self,
-            serialized: Dict[str, Any],
-            messages: List[List[BaseMessage]],
-            **kwargs: Any
-    ) -> Any:
-        if not self._current_loop:
-            # Agent start with a LLM query
-            self._current_loop = AgentLoop(
-                position=len(self._agent_loops) + 1,
-                prompt="\n".join([message.content for message in messages[0]]),
-                status='llm_started',
-                started_at=time.perf_counter()
-            )
-
-    def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
-    ) -> None:
-        """Print out the prompts."""
-        # serialized={'name': 'OpenAI'}
-        # prompts=['Answer the following questions...\nThought:']
-        # kwargs={}
+    def on_llm_before_invoke(self, prompt_messages: list[PromptMessage]) -> None:
         if not self._current_loop:
             # Agent start with a LLM query
             self._current_loop = AgentLoop(
                 position=len(self._agent_loops) + 1,
-                prompt=prompts[0],
+                prompt="\n".join([prompt_message.content for prompt_message in prompt_messages]),
                 status='llm_started',
                 started_at=time.perf_counter()
             )
 
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        """Do nothing."""
-        # kwargs={}
+    def on_llm_after_invoke(self, result: RuntimeLLMResult) -> None:
         if self._current_loop and self._current_loop.status == 'llm_started':
             self._current_loop.status = 'llm_end'
-            if response.llm_output:
-                self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
+            if result.usage:
+                self._current_loop.prompt_tokens = result.usage.prompt_tokens
             else:
-                self._current_loop.prompt_tokens = self.model_instance.get_num_tokens(
-                    [PromptMessage(content=self._current_loop.prompt)]
+                self._current_loop.prompt_tokens = self.model_type_instance.get_num_tokens(
+                    model=self.model_config.model,
+                    credentials=self.model_config.credentials,
+                    prompt_messages=[UserPromptMessage(content=self._current_loop.prompt)]
                 )
-            completion_generation = response.generations[0][0]
-            if isinstance(completion_generation, ChatGeneration):
-                completion_message = completion_generation.message
-                if 'function_call' in completion_message.additional_kwargs:
-                    self._current_loop.completion \
-                        = json.dumps({'function_call': completion_message.additional_kwargs['function_call']})
-                else:
-                    self._current_loop.completion = response.generations[0][0].text
+
+            completion_message = result.message
+            if completion_message.tool_calls:
+                self._current_loop.completion \
+                    = json.dumps({'function_call': completion_message.tool_calls})
             else:
-                self._current_loop.completion = completion_generation.text
 
-            if response.llm_output:
-                self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
+                self._current_loop.completion = completion_message.content
 
+            if result.usage:
+                self._current_loop.completion_tokens = result.usage.completion_tokens
             else:
-                self._current_loop.completion_tokens = self.model_instance.get_num_tokens(
-                    [PromptMessage(content=self._current_loop.completion)]
+                self._current_loop.completion_tokens = self.model_type_instance.get_num_tokens(
+                    model=self.model_config.model,
+                    credentials=self.model_config.credentials,
+                    prompt_messages=[AssistantPromptMessage(content=self._current_loop.completion)]
                 )
 
+    def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            **kwargs: Any
+    ) -> Any:
+        pass
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        pass
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        """Do nothing."""
+        pass
+
     def on_llm_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
@@ -150,10 +155,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             if completion is not None:
                 self._current_loop.completion = completion
 
-            self._message_agent_thought = self.conversation_message_task.on_agent_start(
-                self.current_chain,
-                self._current_loop
-            )
+            self._message_agent_thought = self._init_agent_thought()
 
     def on_tool_end(
         self,
@@ -176,9 +178,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.completed_at = time.perf_counter()
             self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
 
-            self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_instance, self._current_loop
-            )
+            self._complete_agent_thought(self._message_agent_thought)
 
             self._agent_loops.append(self._current_loop)
             self._current_loop = None
@@ -202,17 +202,62 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.completed_at = time.perf_counter()
             self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
             self._current_loop.thought = '[DONE]'
-            self._message_agent_thought = self.conversation_message_task.on_agent_start(
-                self.current_chain,
-                self._current_loop
-            )
+            self._message_agent_thought = self._init_agent_thought()
 
-            self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_instance, self._current_loop
-            )
+            self._complete_agent_thought(self._message_agent_thought)
 
             self._agent_loops.append(self._current_loop)
             self._current_loop = None
             self._message_agent_thought = None
         elif not self._current_loop and self._agent_loops:
             self._agent_loops[-1].status = 'agent_finish'
+
+    def _init_agent_thought(self) -> MessageAgentThought:
+        message_agent_thought = MessageAgentThought(
+            message_id=self.message.id,
+            message_chain_id=self.message_chain.id,
+            position=self._current_loop.position,
+            thought=self._current_loop.thought,
+            tool=self._current_loop.tool_name,
+            tool_input=self._current_loop.tool_input,
+            message=self._current_loop.prompt,
+            message_price_unit=0,
+            answer=self._current_loop.completion,
+            answer_price_unit=0,
+            created_by_role=('account' if self.message.from_source == 'console' else 'end_user'),
+            created_by=(self.message.from_account_id
+                        if self.message.from_source == 'console' else self.message.from_end_user_id)
+        )
+
+        db.session.add(message_agent_thought)
+        db.session.commit()
+
+        self.queue_manager.publish_agent_thought(message_agent_thought)
+
+        return message_agent_thought
+
+    def _complete_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
+        loop_message_tokens = self._current_loop.prompt_tokens
+        loop_answer_tokens = self._current_loop.completion_tokens
+
+        # transform usage
+        llm_usage = self.model_type_instance._calc_response_usage(
+            self.model_config.model,
+            self.model_config.credentials,
+            loop_message_tokens,
+            loop_answer_tokens
+        )
+
+        message_agent_thought.observation = self._current_loop.tool_output
+        message_agent_thought.tool_process_data = ''  # currently not supported
+        message_agent_thought.message_token = loop_message_tokens
+        message_agent_thought.message_unit_price = llm_usage.prompt_unit_price
+        message_agent_thought.message_price_unit = llm_usage.prompt_price_unit
+        message_agent_thought.answer_token = loop_answer_tokens
+        message_agent_thought.answer_unit_price = llm_usage.completion_unit_price
+        message_agent_thought.answer_price_unit = llm_usage.completion_price_unit
+        message_agent_thought.latency = self._current_loop.latency
+        message_agent_thought.tokens = self._current_loop.prompt_tokens + self._current_loop.completion_tokens
+        message_agent_thought.total_price = llm_usage.total_price
+        message_agent_thought.currency = llm_usage.currency
+        db.session.commit()
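
For orientation, each agent loop opened in on_llm_before_invoke is closed by a tool or agent-finish callback, recording a 1-based position and a wall-clock latency measured with time.perf_counter(). The bookkeeping reduces to this self-contained sketch (plain dicts stand in for the AgentLoop entity):

    import time

    agent_loops = []
    started_at = time.perf_counter()    # loop opens when the LLM is invoked
    # ... the LLM call and the tool invocation happen here ...
    completed_at = time.perf_counter()  # loop closes when the tool (or agent) finishes
    agent_loops.append({
        'position': len(agent_loops) + 1,
        'latency': completed_at - started_at,
    })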

+ 0 - 74
api/core/callback_handler/dataset_tool_callback_handler.py

@@ -1,74 +0,0 @@
-import json
-import logging
-from json import JSONDecodeError
-
-from typing import Any, Dict, List, Union, Optional
-
-from langchain.callbacks.base import BaseCallbackHandler
-
-from core.callback_handler.entity.dataset_query import DatasetQueryObj
-from core.conversation_message_task import ConversationMessageTask
-
-
-class DatasetToolCallbackHandler(BaseCallbackHandler):
-    """Callback Handler that prints to std out."""
-    raise_error: bool = True
-
-    def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
-        """Initialize callback handler."""
-        self.queries = []
-        self.conversation_message_task = conversation_message_task
-
-    @property
-    def always_verbose(self) -> bool:
-        """Whether to call verbose callbacks even if verbose is False."""
-        return True
-
-    @property
-    def ignore_llm(self) -> bool:
-        """Whether to ignore LLM callbacks."""
-        return True
-
-    @property
-    def ignore_chain(self) -> bool:
-        """Whether to ignore chain callbacks."""
-        return True
-
-    @property
-    def ignore_agent(self) -> bool:
-        """Whether to ignore agent callbacks."""
-        return False
-
-    def on_tool_start(
-        self,
-        serialized: Dict[str, Any],
-        input_str: str,
-        **kwargs: Any,
-    ) -> None:
-        tool_name: str = serialized.get('name')
-        dataset_id = tool_name.removeprefix('dataset-')
-
-        try:
-            input_dict = json.loads(input_str.replace("'", "\""))
-            query = input_dict.get('query')
-        except JSONDecodeError:
-            query = input_str
-
-        self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))
-
-    def on_tool_end(
-        self,
-        output: str,
-        color: Optional[str] = None,
-        observation_prefix: Optional[str] = None,
-        llm_prefix: Optional[str] = None,
-        **kwargs: Any,
-    ) -> None:
-        pass
-
-
-    def on_tool_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        logging.debug("Dataset tool on_llm_error: %s", error)
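
The removed handler recovered the dataset id from the tool name and tolerated both JSON-ish and plain-string tool inputs; that parsing reduces to this self-contained sketch (the tool name and input are hypothetical):

    import json
    from json import JSONDecodeError

    tool_name = 'dataset-123e4567'  # hypothetical tool name
    dataset_id = tool_name.removeprefix('dataset-')

    input_str = "{'query': 'what is dify?'}"  # LangChain often passes single-quoted dicts
    try:
        query = json.loads(input_str.replace("'", "\"")).get('query')
    except JSONDecodeError:
        query = input_str  # fall back to the raw string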

+ 0 - 16
api/core/callback_handler/entity/chain_result.py

@@ -1,16 +0,0 @@
-from pydantic import BaseModel
-
-
-class ChainResult(BaseModel):
-    type: str = None
-    prompt: dict = None
-    completion: dict = None
-
-    status: str = 'chain_started'
-    completed: bool = False
-
-    started_at: float = None
-    completed_at: float = None
-
-    agent_result: dict = None
-    """only when type is 'AgentExecutor'"""

+ 0 - 6
api/core/callback_handler/entity/dataset_query.py

@@ -1,6 +0,0 @@
-from pydantic import BaseModel
-
-
-class DatasetQueryObj(BaseModel):
-    dataset_id: str = None
-    query: str = None

+ 0 - 8
api/core/callback_handler/entity/llm_message.py

@@ -1,8 +0,0 @@
-from pydantic import BaseModel
-
-
-class LLMMessage(BaseModel):
-    prompt: str = ''
-    prompt_tokens: int = 0
-    completion: str = ''
-    completion_tokens: int = 0

+ 56 - 6
api/core/callback_handler/index_tool_callback_handler.py

@@ -1,17 +1,44 @@
-from typing import List
+from typing import List, Union
 
 from langchain.schema import Document
 
-from core.conversation_message_task import ConversationMessageTask
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import InvokeFrom
 from extensions.ext_database import db
-from models.dataset import DocumentSegment
+from models.dataset import DocumentSegment, DatasetQuery
+from models.model import DatasetRetrieverResource
 
 
 class DatasetIndexToolCallbackHandler:
     """Callback handler for dataset tool."""
 
-    def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
-        self.conversation_message_task = conversation_message_task
+    def __init__(self, queue_manager: ApplicationQueueManager,
+                 app_id: str,
+                 message_id: str,
+                 user_id: str,
+                 invoke_from: InvokeFrom) -> None:
+        self._queue_manager = queue_manager
+        self._app_id = app_id
+        self._message_id = message_id
+        self._user_id = user_id
+        self._invoke_from = invoke_from
+
+    def on_query(self, query: str, dataset_id: str) -> None:
+        """
+        Handle query.
+        """
+        dataset_query = DatasetQuery(
+            dataset_id=dataset_id,
+            content=query,
+            source='app',
+            source_app_id=self._app_id,
+            created_by_role=('account'
+                             if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
+            created_by=self._user_id
+        )
+
+        db.session.add(dataset_query)
+        db.session.commit()
 
 
     def on_tool_end(self, documents: List[Document]) -> None:
     def on_tool_end(self, documents: List[Document]) -> None:
         """Handle tool end."""
         """Handle tool end."""
@@ -30,4 +57,27 @@ class DatasetIndexToolCallbackHandler:
 
 
     def return_retriever_resource_info(self, resource: List):
     def return_retriever_resource_info(self, resource: List):
         """Handle return_retriever_resource_info."""
         """Handle return_retriever_resource_info."""
-        self.conversation_message_task.on_dataset_query_finish(resource)
+        if resource and len(resource) > 0:
+            for item in resource:
+                dataset_retriever_resource = DatasetRetrieverResource(
+                    message_id=self._message_id,
+                    position=item.get('position'),
+                    dataset_id=item.get('dataset_id'),
+                    dataset_name=item.get('dataset_name'),
+                    document_id=item.get('document_id'),
+                    document_name=item.get('document_name'),
+                    data_source_type=item.get('data_source_type'),
+                    segment_id=item.get('segment_id'),
+                    score=item.get('score'),
+                    hit_count=item.get('hit_count'),
+                    word_count=item.get('word_count'),
+                    segment_position=item.get('segment_position'),
+                    index_node_hash=item.get('index_node_hash'),
+                    content=item.get('content'),
+                    retriever_from=item.get('retriever_from'),
+                    created_by=self._user_id
+                )
+                db.session.add(dataset_retriever_resource)
+                db.session.commit()
+
+        self._queue_manager.publish_retriever_resources(resource)
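
A small note on the resource mapping above: dict.get already returns None when a key is absent, so no explicit membership check is needed around those lookups. A quick self-contained illustration (the entry is hypothetical):

    item = {'position': 1, 'score': 0.82}  # hypothetical retriever resource entry
    print(item.get('score'))      # 0.82
    print(item.get('hit_count'))  # None: missing keys fall back to None by default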

+ 0 - 284
api/core/callback_handler/llm_callback_handler.py

@@ -1,284 +0,0 @@
-import logging
-import threading
-import time
-from typing import Any, Dict, List, Union, Optional
-
-from flask import Flask, current_app
-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import LLMResult, BaseMessage
-from pydantic import BaseModel
-
-from core.callback_handler.entity.llm_message import LLMMessage
-from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
-    ConversationTaskInterruptException
-from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage, LCHumanMessageWithFiles, \
-    ImagePromptMessageFile
-from core.model_providers.models.llm.base import BaseLLM
-from core.moderation.base import ModerationOutputsResult, ModerationAction
-from core.moderation.factory import ModerationFactory
-
-
-class ModerationRule(BaseModel):
-    type: str
-    config: Dict[str, Any]
-
-
-class LLMCallbackHandler(BaseCallbackHandler):
-    raise_error: bool = True
-
-    def __init__(self, model_instance: BaseLLM,
-                 conversation_message_task: ConversationMessageTask):
-        self.model_instance = model_instance
-        self.llm_message = LLMMessage()
-        self.start_at = None
-        self.conversation_message_task = conversation_message_task
-
-        self.output_moderation_handler = None
-        self.init_output_moderation()
-
-    def init_output_moderation(self):
-        app_model_config = self.conversation_message_task.app_model_config
-        sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict
-
-        if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"):
-            self.output_moderation_handler = OutputModerationHandler(
-                tenant_id=self.conversation_message_task.tenant_id,
-                app_id=self.conversation_message_task.app.id,
-                rule=ModerationRule(
-                    type=sensitive_word_avoidance_dict.get("type"),
-                    config=sensitive_word_avoidance_dict.get("config")
-                ),
-                on_message_replace_func=self.conversation_message_task.on_message_replace
-            )
-
-    @property
-    def always_verbose(self) -> bool:
-        """Whether to call verbose callbacks even if verbose is False."""
-        return True
-
-    def on_chat_model_start(
-            self,
-            serialized: Dict[str, Any],
-            messages: List[List[BaseMessage]],
-            **kwargs: Any
-    ) -> Any:
-        real_prompts = []
-        for message in messages[0]:
-            if message.type == 'human':
-                role = 'user'
-            elif message.type == 'ai':
-                role = 'assistant'
-            else:
-                role = 'system'
-
-            real_prompts.append({
-                "role": role,
-                "text": message.content,
-                "files": [{
-                    "type": file.type.value,
-                    "data": file.data[:10] + '...[TRUNCATED]...' + file.data[-10:],
-                    "detail": file.detail.value if isinstance(file, ImagePromptMessageFile) else None,
-                } for file in (message.files if isinstance(message, LCHumanMessageWithFiles) else [])]
-            })
-
-        self.llm_message.prompt = real_prompts
-        self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0]))
-
-    def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
-    ) -> None:
-        self.llm_message.prompt = [{
-            "role": 'user',
-            "text": prompts[0]
-        }]
-
-        self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
-
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        if self.output_moderation_handler:
-            self.output_moderation_handler.stop_thread()
-
-            self.llm_message.completion = self.output_moderation_handler.moderation_completion(
-                completion=response.generations[0][0].text,
-                public_event=True if self.conversation_message_task.streaming else False
-            )
-        else:
-            self.llm_message.completion = response.generations[0][0].text
-
-        if not self.conversation_message_task.streaming:
-            self.conversation_message_task.append_message_text(self.llm_message.completion)
-
-        if response.llm_output and 'token_usage' in response.llm_output:
-            if 'prompt_tokens' in response.llm_output['token_usage']:
-                self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
-
-            if 'completion_tokens' in response.llm_output['token_usage']:
-                self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
-            else:
-                self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
-                    [PromptMessage(content=self.llm_message.completion)])
-        else:
-            self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
-                [PromptMessage(content=self.llm_message.completion)])
-
-        self.conversation_message_task.save_message(self.llm_message)
-
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        if self.output_moderation_handler and self.output_moderation_handler.should_direct_output():
-            # stop subscribe new token when output moderation should direct output
-            ex = ConversationTaskInterruptException()
-            self.on_llm_error(error=ex)
-            raise ex
-
-        try:
-            self.conversation_message_task.append_message_text(token)
-            self.llm_message.completion += token
-
-            if self.output_moderation_handler:
-                self.output_moderation_handler.append_new_token(token)
-        except ConversationTaskStoppedException as ex:
-            self.on_llm_error(error=ex)
-            raise ex
-
-    def on_llm_error(
-            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        if self.output_moderation_handler:
-            self.output_moderation_handler.stop_thread()
-
-        if isinstance(error, ConversationTaskStoppedException):
-            if self.conversation_message_task.streaming:
-                self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
-                    [PromptMessage(content=self.llm_message.completion)]
-                )
-                self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
-        if isinstance(error, ConversationTaskInterruptException):
-            self.llm_message.completion = self.output_moderation_handler.get_final_output()
-            self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
-                [PromptMessage(content=self.llm_message.completion)]
-            )
-            self.conversation_message_task.save_message(llm_message=self.llm_message)
-        else:
-            logging.debug("on_llm_error: %s", error)
-
-
-class OutputModerationHandler(BaseModel):
-    DEFAULT_BUFFER_SIZE: int = 300
-
-    tenant_id: str
-    app_id: str
-
-    rule: ModerationRule
-    on_message_replace_func: Any
-
-    thread: Optional[threading.Thread] = None
-    thread_running: bool = True
-    buffer: str = ''
-    is_final_chunk: bool = False
-    final_output: Optional[str] = None
-
-    class Config:
-        arbitrary_types_allowed = True
-
-    def should_direct_output(self):
-        return self.final_output is not None
-
-    def get_final_output(self):
-        return self.final_output
-
-    def append_new_token(self, token: str):
-        self.buffer += token
-
-        if not self.thread:
-            self.thread = self.start_thread()
-
-    def moderation_completion(self, completion: str, public_event: bool = False) -> str:
-        self.buffer = completion
-        self.is_final_chunk = True
-
-        result = self.moderation(
-            tenant_id=self.tenant_id,
-            app_id=self.app_id,
-            moderation_buffer=completion
-        )
-
-        if not result or not result.flagged:
-            return completion
-
-        if result.action == ModerationAction.DIRECT_OUTPUT:
-            final_output = result.preset_response
-        else:
-            final_output = result.text
-
-        if public_event:
-            self.on_message_replace_func(final_output)
-
-        return final_output
-
-    def start_thread(self) -> threading.Thread:
-        buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
-        thread = threading.Thread(target=self.worker, kwargs={
-            'flask_app': current_app._get_current_object(),
-            'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
-        })
-
-        thread.start()
-
-        return thread
-
-    def stop_thread(self):
-        if self.thread and self.thread.is_alive():
-            self.thread_running = False
-
-    def worker(self, flask_app: Flask, buffer_size: int):
-        with flask_app.app_context():
-            current_length = 0
-            while self.thread_running:
-                moderation_buffer = self.buffer
-                buffer_length = len(moderation_buffer)
-                if not self.is_final_chunk:
-                    chunk_length = buffer_length - current_length
-                    if 0 <= chunk_length < buffer_size:
-                        time.sleep(1)
-                        continue
-
-                current_length = buffer_length
-
-                result = self.moderation(
-                    tenant_id=self.tenant_id,
-                    app_id=self.app_id,
-                    moderation_buffer=moderation_buffer
-                )
-
-                if not result or not result.flagged:
-                    continue
-
-                if result.action == ModerationAction.DIRECT_OUTPUT:
-                    final_output = result.preset_response
-                    self.final_output = final_output
-                else:
-                    final_output = result.text + self.buffer[len(moderation_buffer):]
-
-                # trigger replace event
-                if self.thread_running:
-                    self.on_message_replace_func(final_output)
-
-                if result.action == ModerationAction.DIRECT_OUTPUT:
-                    break
-
-    def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
-        try:
-            moderation_factory = ModerationFactory(
-                name=self.rule.type,
-                app_id=app_id,
-                tenant_id=tenant_id,
-                config=self.rule.config
-            )
-
-            result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
-            return result
-        except Exception as e:
-            logging.error("Moderation Output error: %s", e)
-
-        return None
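The removed OutputModerationHandler above ran moderation over streaming output in a background thread, re-checking the buffer only once at least buffer_size new characters had accumulated (the final chunk is always checked). A simplified sketch of that gating loop; get_buffer, is_final_chunk, and moderate are illustrative stand-ins for the handler's attributes, and the stop conditions are condensed:

import time

def moderation_gate_loop(get_buffer, is_final_chunk, moderate, buffer_size: int = 300):
    # Re-run moderation only when enough new text has arrived since the last check.
    current_length = 0
    while True:
        moderation_buffer = get_buffer()
        buffer_length = len(moderation_buffer)
        if not is_final_chunk():
            chunk_length = buffer_length - current_length
            if 0 <= chunk_length < buffer_size:
                time.sleep(1)  # too little new text; poll again in a second
                continue
        current_length = buffer_length
        if moderate(moderation_buffer) or is_final_chunk():
            break  # flagged with DIRECT_OUTPUT, or the stream has ended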

+ 0 - 76
api/core/callback_handler/main_chain_gather_callback_handler.py

@@ -1,76 +0,0 @@
-import logging
-import time
-
-from typing import Any, Dict, Union
-
-from langchain.callbacks.base import BaseCallbackHandler
-
-from core.callback_handler.entity.chain_result import ChainResult
-from core.conversation_message_task import ConversationMessageTask
-
-
-class MainChainGatherCallbackHandler(BaseCallbackHandler):
-    """Callback Handler that prints to std out."""
-    raise_error: bool = True
-
-    def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
-        """Initialize callback handler."""
-        self._current_chain_result = None
-        self._current_chain_message = None
-        self.conversation_message_task = conversation_message_task
-        self.agent_callback = None
-
-    def clear_chain_results(self) -> None:
-        self._current_chain_result = None
-        self._current_chain_message = None
-        if self.agent_callback:
-            self.agent_callback.current_chain = None
-
-    @property
-    def always_verbose(self) -> bool:
-        """Whether to call verbose callbacks even if verbose is False."""
-        return True
-
-    @property
-    def ignore_llm(self) -> bool:
-        """Whether to ignore LLM callbacks."""
-        return True
-
-    @property
-    def ignore_agent(self) -> bool:
-        """Whether to ignore agent callbacks."""
-        return True
-
-    def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
-        """Print out that we are entering a chain."""
-        if not self._current_chain_result:
-            chain_type = serialized['id'][-1]
-            if chain_type:
-                self._current_chain_result = ChainResult(
-                    type=chain_type,
-                    prompt=inputs,
-                    started_at=time.perf_counter()
-                )
-                self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
-                if self.agent_callback:
-                    self.agent_callback.current_chain = self._current_chain_message
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        """Print out that we finished a chain."""
-        if self._current_chain_result and self._current_chain_result.status == 'chain_started':
-            self._current_chain_result.status = 'chain_ended'
-            self._current_chain_result.completion = outputs
-            self._current_chain_result.completed = True
-            self._current_chain_result.completed_at = time.perf_counter()
-
-            self.conversation_message_task.on_chain_end(self._current_chain_message, self._current_chain_result)
-
-            self.clear_chain_results()
-
-    def on_chain_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        logging.debug("Dataset tool on_chain_error: %s", error)
-        self.clear_chain_results()
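The gatherer removed here followed a standard langchain callback pattern: open a result record on the first on_chain_start, stamp it with time.perf_counter() timings, and close it in on_chain_end. The timing skeleton, stripped of the ChainResult and ConversationMessageTask bookkeeping:

import time
from typing import Any, Dict

from langchain.callbacks.base import BaseCallbackHandler


class ChainTimingHandler(BaseCallbackHandler):
    """Times a chain run between on_chain_start and on_chain_end."""

    def __init__(self) -> None:
        self._started_at = None

    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None:
        if self._started_at is None:  # only the outermost chain is timed
            self._started_at = time.perf_counter()

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        if self._started_at is not None:
            print(f"chain finished in {time.perf_counter() - self._started_at:.3f}s")
            self._started_at = None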

+ 5 - 2
api/core/callback_handler/std_out_callback_handler.py

@@ -79,8 +79,11 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
         """Run on agent action."""
         """Run on agent action."""
         tool = action.tool
         tool = action.tool
         tool_input = action.tool_input
         tool_input = action.tool_input
-        action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
-        thought = action.log[:action_name_position].strip() if action.log else ''
+        try:
+            action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
+            thought = action.log[:action_name_position].strip() if action.log else ''
+        except ValueError:
+            thought = ''
 
 
         log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}"
         log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}"
         print_text("\n[on_agent_action]\n" + log + "\n", color='green')
         print_text("\n[on_agent_action]\n" + log + "\n", color='green')
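The guard added above matters because str.index raises ValueError when the needle is absent (unlike str.find, which returns -1), so the old one-liner crashed on agent logs that end with a final answer and contain no Action block. A minimal repro:

log = "I now know the final answer."  # an agent log without a "\nAction:" marker
try:
    thought = log[:log.index("\nAction:") + 1].strip()
except ValueError:
    thought = ''  # str.index raised because the marker is missing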

+ 21 - 8
api/core/chain/llm_chain.py

@@ -5,15 +5,19 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.schema import LLMResult, Generation
 from langchain.schema.language_model import BaseLanguageModel

-from core.model_providers.models.entity.message import to_prompt_messages
-from core.model_providers.models.llm.base import BaseLLM
+from core.agent.agent.agent_llm_callback import AgentLLMCallback
+from core.entities.application_entities import ModelConfigEntity
+from core.model_manager import ModelInstance
+from core.entities.message_entities import lc_messages_to_prompt_messages
 from core.third_party.langchain.llms.fake import FakeLLM


 class LLMChain(LCLLMChain):
-    model_instance: BaseLLM
+    model_config: ModelConfigEntity
     """The language model instance to use."""
     llm: BaseLanguageModel = FakeLLM(response="")
+    parameters: Dict[str, Any] = {}
+    agent_llm_callback: Optional[AgentLLMCallback] = None

     def generate(
         self,
@@ -23,14 +27,23 @@
         """Generate LLM result from inputs."""
         prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
         messages = prompts[0].to_messages()
-        prompt_messages = to_prompt_messages(messages)
-        result = self.model_instance.run(
-            messages=prompt_messages,
-            stop=stop
+        prompt_messages = lc_messages_to_prompt_messages(messages)
+
+        model_instance = ModelInstance(
+            provider_model_bundle=self.model_config.provider_model_bundle,
+            model=self.model_config.model,
+        )
+
+        result = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            stream=False,
+            stop=stop,
+            callbacks=[self.agent_llm_callback] if self.agent_llm_callback else None,
+            model_parameters=self.parameters
         )

         generations = [
-            [Generation(text=result.content)]
+            [Generation(text=result.message.content)]
         ]

         return LLMResult(generations=generations)
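After this change every chain call goes through ModelInstance.invoke_llm in non-streaming mode instead of the legacy BaseLLM.run. A hedged usage sketch; model_config (a ModelConfigEntity resolved elsewhere in this PR) and prompt_template (an ordinary langchain prompt) are assumed inputs, not shown in this hunk:

# model_config and prompt_template are assumptions for illustration.
chain = LLMChain(
    model_config=model_config,        # carries the provider_model_bundle and model name
    parameters={'temperature': 0.7},  # forwarded to invoke_llm as model_parameters
    prompt=prompt_template,
)
result = chain.generate(input_list=[{'input': 'Hello'}])
text = result.generations[0][0].text  # full completion, since stream=False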

+ 0 - 501
api/core/completion.py

@@ -1,501 +0,0 @@
-import concurrent
-import json
-import logging
-from concurrent.futures import ThreadPoolExecutor
-from typing import Optional, List, Union, Tuple
-
-from flask import current_app, Flask
-from requests.exceptions import ChunkedEncodingError
-
-from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
-from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
-from core.callback_handler.llm_callback_handler import LLMCallbackHandler
-from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
-    ConversationTaskInterruptException
-from core.embedding.cached_embedding import CacheEmbedding
-from core.external_data_tool.factory import ExternalDataToolFactory
-from core.file.file_obj import FileObj
-from core.index.vector_index.vector_index import VectorIndex
-from core.model_providers.error import LLMBadRequestError
-from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
-    ReadOnlyConversationTokenDBBufferSharedMemory
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import PromptMessage, PromptMessageFile
-from core.model_providers.models.llm.base import BaseLLM
-from core.orchestrator_rule_parser import OrchestratorRuleParser
-from core.prompt.prompt_template import PromptTemplateParser
-from core.prompt.prompt_transform import PromptTransform
-from models.dataset import Dataset
-from models.model import App, AppModelConfig, Account, Conversation, EndUser
-from core.moderation.base import ModerationException, ModerationAction
-from core.moderation.factory import ModerationFactory
-from services.annotation_service import AppAnnotationService
-from services.dataset_service import DatasetCollectionBindingService
-
-
-class Completion:
-    @classmethod
-    def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
-                 files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
-                 streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
-                 auto_generate_name: bool = True, from_source: str = 'console'):
-        """
-        errors: ProviderTokenNotInitError
-        """
-        query = PromptTemplateParser.remove_template_variables(query)
-
-        memory = None
-        if conversation:
-            # get memory of conversation (read-only)
-            memory = cls.get_memory_from_conversation(
-                tenant_id=app.tenant_id,
-                app_model_config=app_model_config,
-                conversation=conversation,
-                return_messages=False
-            )
-
-            inputs = conversation.inputs
-
-        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
-            tenant_id=app.tenant_id,
-            model_config=app_model_config.model_dict,
-            streaming=streaming
-        )
-
-        conversation_message_task = ConversationMessageTask(
-            task_id=task_id,
-            app=app,
-            app_model_config=app_model_config,
-            user=user,
-            conversation=conversation,
-            is_override=is_override,
-            inputs=inputs,
-            query=query,
-            files=files,
-            streaming=streaming,
-            model_instance=final_model_instance,
-            auto_generate_name=auto_generate_name
-        )
-
-        prompt_message_files = [file.prompt_message_file for file in files]
-
-        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
-            mode=app.mode,
-            model_instance=final_model_instance,
-            app_model_config=app_model_config,
-            query=query,
-            inputs=inputs,
-            files=prompt_message_files
-        )
-
-        # init orchestrator rule parser
-        orchestrator_rule_parser = OrchestratorRuleParser(
-            tenant_id=app.tenant_id,
-            app_model_config=app_model_config
-        )
-
-        try:
-            chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
-
-            try:
-                # process sensitive_word_avoidance
-                inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query)
-            except ModerationException as e:
-                cls.run_final_llm(
-                    model_instance=final_model_instance,
-                    mode=app.mode,
-                    app_model_config=app_model_config,
-                    query=query,
-                    inputs=inputs,
-                    files=prompt_message_files,
-                    agent_execute_result=None,
-                    conversation_message_task=conversation_message_task,
-                    memory=memory,
-                    fake_response=str(e)
-                )
-                return
-            # check annotation reply
-            annotation_reply = cls.query_app_annotations_to_reply(conversation_message_task, from_source)
-            if annotation_reply:
-                return
-            # fill in variable inputs from external data tools if exists
-            external_data_tools = app_model_config.external_data_tools_list
-            if external_data_tools:
-                inputs = cls.fill_in_inputs_from_external_data_tools(
-                    tenant_id=app.tenant_id,
-                    app_id=app.id,
-                    external_data_tools=external_data_tools,
-                    inputs=inputs,
-                    query=query
-                )
-
-            # get agent executor
-            agent_executor = orchestrator_rule_parser.to_agent_executor(
-                conversation_message_task=conversation_message_task,
-                memory=memory,
-                rest_tokens=rest_tokens_for_context_and_memory,
-                chain_callback=chain_callback,
-                tenant_id=app.tenant_id,
-                retriever_from=retriever_from
-            )
-
-            query_for_agent = cls.get_query_for_agent(app, app_model_config, query, inputs)
-
-            # run agent executor
-            agent_execute_result = None
-            if query_for_agent and agent_executor:
-                should_use_agent = agent_executor.should_use_agent(query_for_agent)
-                if should_use_agent:
-                    agent_execute_result = agent_executor.run(query_for_agent)
-
-            # When no extra pre prompt is specified,
-            # the output of the agent can be used directly as the main output content without calling LLM again
-            fake_response = None
-            if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
-                    and agent_execute_result.strategy not in [PlanningStrategy.ROUTER,
-                                                              PlanningStrategy.REACT_ROUTER]:
-                fake_response = agent_execute_result.output
-
-            # run the final llm
-            cls.run_final_llm(
-                model_instance=final_model_instance,
-                mode=app.mode,
-                app_model_config=app_model_config,
-                query=query,
-                inputs=inputs,
-                files=prompt_message_files,
-                agent_execute_result=agent_execute_result,
-                conversation_message_task=conversation_message_task,
-                memory=memory,
-                fake_response=fake_response
-            )
-        except (ConversationTaskInterruptException, ConversationTaskStoppedException):
-            return
-        except ChunkedEncodingError as e:
-            # Interrupted by the LLM provider (e.g. OpenAI); handle it.
-            logging.warning(f'ChunkedEncodingError: {e}')
-            return
-
-    @classmethod
-    def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict,
-                              query: str):
-        if not app_model_config.sensitive_word_avoidance_dict['enabled']:
-            return inputs, query
-
-        type = app_model_config.sensitive_word_avoidance_dict['type']
-
-        moderation = ModerationFactory(type, app_id, tenant_id,
-                                       app_model_config.sensitive_word_avoidance_dict['config'])
-        moderation_result = moderation.moderation_for_inputs(inputs, query)
-
-        if not moderation_result.flagged:
-            return inputs, query
-
-        if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
-            raise ModerationException(moderation_result.preset_response)
-        elif moderation_result.action == ModerationAction.OVERRIDED:
-            inputs = moderation_result.inputs
-            query = moderation_result.query
-
-        return inputs, query
-
-    @classmethod
-    def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict],
-                                                inputs: dict, query: str) -> dict:
-        """
-        Fill in variable inputs from external data tools if exists.
-
-        :param tenant_id: workspace id
-        :param app_id: app id
-        :param external_data_tools: external data tools configs
-        :param inputs: the inputs
-        :param query: the query
-        :return: the filled inputs
-        """
-        # Group tools by type and config
-        grouped_tools = {}
-        for tool in external_data_tools:
-            if not tool.get("enabled"):
-                continue
-
-            tool_key = (tool.get("type"), json.dumps(tool.get("config"), sort_keys=True))
-            grouped_tools.setdefault(tool_key, []).append(tool)
-
-        results = {}
-        with ThreadPoolExecutor() as executor:
-            futures = {}
-            for tool in external_data_tools:
-                if not tool.get("enabled"):
-                    continue
-
-                future = executor.submit(
-                    cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, tool,
-                    inputs, query
-                )
-
-                futures[future] = tool
-
-            for future in concurrent.futures.as_completed(futures):
-                tool_variable, result = future.result()
-                results[tool_variable] = result
-
-        inputs.update(results)
-        return inputs
-
-    @classmethod
-    def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, external_data_tool: dict,
-                                 inputs: dict, query: str) -> Tuple[Optional[str], Optional[str]]:
-        with flask_app.app_context():
-            tool_variable = external_data_tool.get("variable")
-            tool_type = external_data_tool.get("type")
-            tool_config = external_data_tool.get("config")
-
-            external_data_tool_factory = ExternalDataToolFactory(
-                name=tool_type,
-                tenant_id=tenant_id,
-                app_id=app_id,
-                variable=tool_variable,
-                config=tool_config
-            )
-
-            # query external data tool
-            result = external_data_tool_factory.query(
-                inputs=inputs,
-                query=query
-            )
-
-            return tool_variable, result
-
-    @classmethod
-    def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
-        if app.mode != 'completion':
-            return query
-
-        return inputs.get(app_model_config.dataset_query_variable, "")
-
-    @classmethod
-    def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
-                      inputs: dict,
-                      files: List[PromptMessageFile],
-                      agent_execute_result: Optional[AgentExecuteResult],
-                      conversation_message_task: ConversationMessageTask,
-                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
-                      fake_response: Optional[str]):
-        prompt_transform = PromptTransform()
-
-        # get llm prompt
-        if app_model_config.prompt_type == 'simple':
-            prompt_messages, stop_words = prompt_transform.get_prompt(
-                app_mode=mode,
-                pre_prompt=app_model_config.pre_prompt,
-                inputs=inputs,
-                query=query,
-                files=files,
-                context=agent_execute_result.output if agent_execute_result else None,
-                memory=memory,
-                model_instance=model_instance
-            )
-        else:
-            prompt_messages = prompt_transform.get_advanced_prompt(
-                app_mode=mode,
-                app_model_config=app_model_config,
-                inputs=inputs,
-                query=query,
-                files=files,
-                context=agent_execute_result.output if agent_execute_result else None,
-                memory=memory,
-                model_instance=model_instance
-            )
-
-            model_config = app_model_config.model_dict
-            completion_params = model_config.get("completion_params", {})
-            stop_words = completion_params.get("stop", [])
-
-        cls.recale_llm_max_tokens(
-            model_instance=model_instance,
-            prompt_messages=prompt_messages,
-        )
-
-        response = model_instance.run(
-            messages=prompt_messages,
-            stop=stop_words if stop_words else None,
-            callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
-            fake_response=fake_response
-        )
-        return response
-
-    @classmethod
-    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
-                                         max_token_limit: int) -> str:
-        """Get memory messages."""
-        memory.max_token_limit = max_token_limit
-        memory_key = memory.memory_variables[0]
-        external_context = memory.load_memory_variables({})
-        return external_context[memory_key]
-
-    @classmethod
-    def query_app_annotations_to_reply(cls, conversation_message_task: ConversationMessageTask,
-                                       from_source: str) -> bool:
-        """Get memory messages."""
-        app_model_config = conversation_message_task.app_model_config
-        app = conversation_message_task.app
-        annotation_reply = app_model_config.annotation_reply_dict
-        if annotation_reply['enabled']:
-            try:
-                score_threshold = annotation_reply.get('score_threshold', 1)
-                embedding_provider_name = annotation_reply['embedding_model']['embedding_provider_name']
-                embedding_model_name = annotation_reply['embedding_model']['embedding_model_name']
-                # get embedding model
-                embedding_model = ModelFactory.get_embedding_model(
-                    tenant_id=app.tenant_id,
-                    model_provider_name=embedding_provider_name,
-                    model_name=embedding_model_name
-                )
-                embeddings = CacheEmbedding(embedding_model)
-
-                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                    embedding_provider_name,
-                    embedding_model_name,
-                    'annotation'
-                )
-
-                dataset = Dataset(
-                    id=app.id,
-                    tenant_id=app.tenant_id,
-                    indexing_technique='high_quality',
-                    embedding_model_provider=embedding_provider_name,
-                    embedding_model=embedding_model_name,
-                    collection_binding_id=dataset_collection_binding.id
-                )
-
-                vector_index = VectorIndex(
-                    dataset=dataset,
-                    config=current_app.config,
-                    embeddings=embeddings,
-                    attributes=['doc_id', 'annotation_id', 'app_id']
-                )
-
-                documents = vector_index.search(
-                    conversation_message_task.query,
-                    search_type='similarity_score_threshold',
-                    search_kwargs={
-                        'k': 1,
-                        'score_threshold': score_threshold,
-                        'filter': {
-                            'group_id': [dataset.id]
-                        }
-                    }
-                )
-                if documents:
-                    annotation_id = documents[0].metadata['annotation_id']
-                    score = documents[0].metadata['score']
-                    annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
-                    if annotation:
-                        conversation_message_task.annotation_end(annotation.content, annotation.id, annotation.account.name)
-                        # insert annotation history
-                        AppAnnotationService.add_annotation_history(annotation.id,
-                                                                    app.id,
-                                                                    annotation.question,
-                                                                    annotation.content,
-                                                                    conversation_message_task.query,
-                                                                    conversation_message_task.user.id,
-                                                                    conversation_message_task.message.id,
-                                                                    from_source,
-                                                                    score)
-                        return True
-            except Exception as e:
-                logging.warning(f'Query annotation failed, exception: {str(e)}.')
-                return False
-        return False
-
-    @classmethod
-    def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
-                                     conversation: Conversation,
-                                     **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
-        # only for calc token in memory
-        memory_model_instance = ModelFactory.get_text_generation_model_from_model_config(
-            tenant_id=tenant_id,
-            model_config=app_model_config.model_dict
-        )
-
-        # use llm config from conversation
-        memory = ReadOnlyConversationTokenDBBufferSharedMemory(
-            conversation=conversation,
-            model_instance=memory_model_instance,
-            max_token_limit=kwargs.get("max_token_limit", 2048),
-            memory_key=kwargs.get("memory_key", "chat_history"),
-            return_messages=kwargs.get("return_messages", True),
-            input_key=kwargs.get("input_key", "input"),
-            output_key=kwargs.get("output_key", "output"),
-            message_limit=kwargs.get("message_limit", 10),
-        )
-
-        return memory
-
-    @classmethod
-    def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
-                                 query: str, inputs: dict, files: List[PromptMessageFile]) -> int:
-        model_limited_tokens = model_instance.model_rules.max_tokens.max
-        max_tokens = model_instance.get_model_kwargs().max_tokens
-
-        if model_limited_tokens is None:
-            return -1
-
-        if max_tokens is None:
-            max_tokens = 0
-
-        prompt_transform = PromptTransform()
-
-        # get prompt without memory and context
-        if app_model_config.prompt_type == 'simple':
-            prompt_messages, _ = prompt_transform.get_prompt(
-                app_mode=mode,
-                pre_prompt=app_model_config.pre_prompt,
-                inputs=inputs,
-                query=query,
-                files=files,
-                context=None,
-                memory=None,
-                model_instance=model_instance
-            )
-        else:
-            prompt_messages = prompt_transform.get_advanced_prompt(
-                app_mode=mode,
-                app_model_config=app_model_config,
-                inputs=inputs,
-                query=query,
-                files=files,
-                context=None,
-                memory=None,
-                model_instance=model_instance
-            )
-
-        prompt_tokens = model_instance.get_num_tokens(prompt_messages)
-        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
-        if rest_tokens < 0:
-            raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
-                                     "or shrink the max token, or switch to a llm with a larger token limit size.")
-
-        return rest_tokens
-
-    @classmethod
-    def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]):
-        # recalc max_tokens if sum(prompt_token +  max_tokens) over model token limit
-        model_limited_tokens = model_instance.model_rules.max_tokens.max
-        max_tokens = model_instance.get_model_kwargs().max_tokens
-
-        if model_limited_tokens is None:
-            return
-
-        if max_tokens is None:
-            max_tokens = 0
-
-        prompt_tokens = model_instance.get_num_tokens(prompt_messages)
-
-        if prompt_tokens + max_tokens > model_limited_tokens:
-            max_tokens = max(model_limited_tokens - prompt_tokens, 16)
-
-            # update model instance max tokens
-            model_kwargs = model_instance.get_model_kwargs()
-            model_kwargs.max_tokens = max_tokens
-            model_instance.set_model_kwargs(model_kwargs)
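Two helpers at the end of the removed Completion class carried the prompt token budget: get_validate_rest_tokens reserved room for context and memory, and recale_llm_max_tokens shrank max_tokens when prompt plus completion would overflow the model limit. The arithmetic, extracted as a self-contained sketch:

def rest_tokens_for_context_and_memory(model_limit: int, max_tokens: int, prompt_tokens: int) -> int:
    # Tokens left for context/memory after the prompt and the reserved completion.
    rest = model_limit - max_tokens - prompt_tokens
    if rest < 0:
        raise ValueError("Query or prefix prompt is too long for this model")
    return rest


def recalc_max_tokens(model_limit: int, max_tokens: int, prompt_tokens: int) -> int:
    # Shrink max_tokens so prompt + completion fits, keeping at least 16 output tokens.
    if prompt_tokens + max_tokens > model_limit:
        return max(model_limit - prompt_tokens, 16)
    return max_tokens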

+ 0 - 517
api/core/conversation_message_task.py

@@ -1,517 +0,0 @@
-import json
-import time
-from typing import Optional, Union, List
-
-from core.callback_handler.entity.agent_loop import AgentLoop
-from core.callback_handler.entity.dataset_query import DatasetQueryObj
-from core.callback_handler.entity.llm_message import LLMMessage
-from core.callback_handler.entity.chain_result import ChainResult
-from core.file.file_obj import FileObj
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import to_prompt_messages, MessageType, PromptMessageFile
-from core.model_providers.models.llm.base import BaseLLM
-from core.prompt.prompt_builder import PromptBuilder
-from core.prompt.prompt_template import PromptTemplateParser
-from events.message_event import message_was_created
-from extensions.ext_database import db
-from extensions.ext_redis import redis_client
-from models.dataset import DatasetQuery
-from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
-    MessageChain, DatasetRetrieverResource, MessageFile
-
-
-class ConversationMessageTask:
-    def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
-                 inputs: dict, query: str, files: List[FileObj], streaming: bool,
-                 model_instance: BaseLLM, conversation: Optional[Conversation] = None, is_override: bool = False,
-                 auto_generate_name: bool = True):
-        self.start_at = time.perf_counter()
-
-        self.task_id = task_id
-
-        self.app = app
-        self.tenant_id = app.tenant_id
-        self.app_model_config = app_model_config
-        self.is_override = is_override
-
-        self.user = user
-        self.inputs = inputs
-        self.query = query
-        self.files = files
-        self.streaming = streaming
-
-        self.conversation = conversation
-        self.is_new_conversation = False
-
-        self.model_instance = model_instance
-
-        self.message = None
-
-        self.retriever_resource = None
-        self.auto_generate_name = auto_generate_name
-
-        self.model_dict = self.app_model_config.model_dict
-        self.provider_name = self.model_dict.get('provider')
-        self.model_name = self.model_dict.get('name')
-        self.mode = app.mode
-
-        self.init()
-
-        self._pub_handler = PubHandler(
-            user=self.user,
-            task_id=self.task_id,
-            message=self.message,
-            conversation=self.conversation,
-            chain_pub=False,  # disabled currently
-            agent_thought_pub=True
-        )
-
-    def init(self):
-
-        override_model_configs = None
-        if self.is_override:
-            override_model_configs = self.app_model_config.to_dict()
-
-        introduction = ''
-        system_instruction = ''
-        system_instruction_tokens = 0
-        if self.mode == 'chat':
-            introduction = self.app_model_config.opening_statement
-            if introduction:
-                prompt_template = PromptTemplateParser(template=introduction)
-                prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs}
-                try:
-                    introduction = prompt_template.format(prompt_inputs)
-                except KeyError:
-                    pass
-
-            if self.app_model_config.pre_prompt:
-                system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
-                system_instruction = system_message.content
-                model_instance = ModelFactory.get_text_generation_model(
-                    tenant_id=self.tenant_id,
-                    model_provider_name=self.provider_name,
-                    model_name=self.model_name
-                )
-                system_instruction_tokens = model_instance.get_num_tokens(to_prompt_messages([system_message]))
-
-        if not self.conversation:
-            self.is_new_conversation = True
-            self.conversation = Conversation(
-                app_id=self.app.id,
-                app_model_config_id=self.app_model_config.id,
-                model_provider=self.provider_name,
-                model_id=self.model_name,
-                override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
-                mode=self.mode,
-                name='New conversation',
-                inputs=self.inputs,
-                introduction=introduction,
-                system_instruction=system_instruction,
-                system_instruction_tokens=system_instruction_tokens,
-                status='normal',
-                from_source=('console' if isinstance(self.user, Account) else 'api'),
-                from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
-                from_account_id=(self.user.id if isinstance(self.user, Account) else None),
-            )
-
-            db.session.add(self.conversation)
-            db.session.commit()
-
-        self.message = Message(
-            app_id=self.app.id,
-            model_provider=self.provider_name,
-            model_id=self.model_name,
-            override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
-            conversation_id=self.conversation.id,
-            inputs=self.inputs,
-            query=self.query,
-            message="",
-            message_tokens=0,
-            message_unit_price=0,
-            message_price_unit=0,
-            answer="",
-            answer_tokens=0,
-            answer_unit_price=0,
-            answer_price_unit=0,
-            provider_response_latency=0,
-            total_price=0,
-            currency=self.model_instance.get_currency(),
-            from_source=('console' if isinstance(self.user, Account) else 'api'),
-            from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
-            from_account_id=(self.user.id if isinstance(self.user, Account) else None),
-            agent_based=self.app_model_config.agent_mode_dict.get('enabled'),
-        )
-
-        db.session.add(self.message)
-        db.session.commit()
-
-        for file in self.files:
-            message_file = MessageFile(
-                message_id=self.message.id,
-                type=file.type.value,
-                transfer_method=file.transfer_method.value,
-                url=file.url,
-                upload_file_id=file.upload_file_id,
-                created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
-                created_by=self.user.id
-            )
-            db.session.add(message_file)
-            db.session.commit()
-
-    def append_message_text(self, text: str):
-        if text is not None:
-            self._pub_handler.pub_text(text)
-
-    def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
-        message_tokens = llm_message.prompt_tokens
-        answer_tokens = llm_message.completion_tokens
-
-        message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER)
-        message_price_unit = self.model_instance.get_price_unit(MessageType.USER)
-        answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
-        answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
-
-        message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER)
-        answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
-        total_price = message_total_price + answer_total_price
-
-        self.message.message = llm_message.prompt
-        self.message.message_tokens = message_tokens
-        self.message.message_unit_price = message_unit_price
-        self.message.message_price_unit = message_price_unit
-        self.message.answer = PromptTemplateParser.remove_template_variables(
-            llm_message.completion.strip()) if llm_message.completion else ''
-        self.message.answer_tokens = answer_tokens
-        self.message.answer_unit_price = answer_unit_price
-        self.message.answer_price_unit = answer_price_unit
-        self.message.provider_response_latency = time.perf_counter() - self.start_at
-        self.message.total_price = total_price
-
-        db.session.commit()
-
-        message_was_created.send(
-            self.message,
-            conversation=self.conversation,
-            is_first_message=self.is_new_conversation,
-            auto_generate_name=self.auto_generate_name
-        )
-
-        if not by_stopped:
-            self.end()
-
-    def init_chain(self, chain_result: ChainResult):
-        message_chain = MessageChain(
-            message_id=self.message.id,
-            type=chain_result.type,
-            input=json.dumps(chain_result.prompt),
-            output=''
-        )
-
-        db.session.add(message_chain)
-        db.session.commit()
-
-        return message_chain
-
-    def on_chain_end(self, message_chain: MessageChain, chain_result: ChainResult):
-        message_chain.output = json.dumps(chain_result.completion)
-        db.session.commit()
-
-        self._pub_handler.pub_chain(message_chain)
-
-    def on_agent_start(self, message_chain: MessageChain, agent_loop: AgentLoop) -> MessageAgentThought:
-        message_agent_thought = MessageAgentThought(
-            message_id=self.message.id,
-            message_chain_id=message_chain.id,
-            position=agent_loop.position,
-            thought=agent_loop.thought,
-            tool=agent_loop.tool_name,
-            tool_input=agent_loop.tool_input,
-            message=agent_loop.prompt,
-            message_price_unit=0,
-            answer=agent_loop.completion,
-            answer_price_unit=0,
-            created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
-            created_by=self.user.id
-        )
-
-        db.session.add(message_agent_thought)
-        db.session.commit()
-
-        self._pub_handler.pub_agent_thought(message_agent_thought)
-
-        return message_agent_thought
-
-    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
-                     agent_loop: AgentLoop):
-        agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER)
-        agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER)
-        agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
-        agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
-
-        loop_message_tokens = agent_loop.prompt_tokens
-        loop_answer_tokens = agent_loop.completion_tokens
-
-        loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER)
-        loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
-        loop_total_price = loop_message_total_price + loop_answer_total_price
-
-        message_agent_thought.observation = agent_loop.tool_output
-        message_agent_thought.tool_process_data = ''  # currently not supported
-        message_agent_thought.message_token = loop_message_tokens
-        message_agent_thought.message_unit_price = agent_message_unit_price
-        message_agent_thought.message_price_unit = agent_message_price_unit
-        message_agent_thought.answer_token = loop_answer_tokens
-        message_agent_thought.answer_unit_price = agent_answer_unit_price
-        message_agent_thought.answer_price_unit = agent_answer_price_unit
-        message_agent_thought.latency = agent_loop.latency
-        message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
-        message_agent_thought.total_price = loop_total_price
-        message_agent_thought.currency = agent_model_instance.get_currency()
-        db.session.commit()
-
-    def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
-        dataset_query = DatasetQuery(
-            dataset_id=dataset_query_obj.dataset_id,
-            content=dataset_query_obj.query,
-            source='app',
-            source_app_id=self.app.id,
-            created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
-            created_by=self.user.id
-        )
-
-        db.session.add(dataset_query)
-        db.session.commit()
-
-    def on_dataset_query_finish(self, resource: List):
-        if resource and len(resource) > 0:
-            for item in resource:
-                dataset_retriever_resource = DatasetRetrieverResource(
-                    message_id=self.message.id,
-                    position=item.get('position'),
-                    dataset_id=item.get('dataset_id'),
-                    dataset_name=item.get('dataset_name'),
-                    document_id=item.get('document_id'),
-                    document_name=item.get('document_name'),
-                    data_source_type=item.get('data_source_type'),
-                    segment_id=item.get('segment_id'),
-                    score=item.get('score') if 'score' in item else None,
-                    hit_count=item.get('hit_count') if 'hit_count' in item else None,
-                    word_count=item.get('word_count') if 'word_count' in item else None,
-                    segment_position=item.get('segment_position') if 'segment_position' in item else None,
-                    index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
-                    content=item.get('content'),
-                    retriever_from=item.get('retriever_from'),
-                    created_by=self.user.id
-                )
-                db.session.add(dataset_retriever_resource)
-                db.session.commit()
-            self.retriever_resource = resource
-
-    def on_message_replace(self, text: str):
-        if text is not None:
-            self._pub_handler.pub_message_replace(text)
-
-    def message_end(self):
-        self._pub_handler.pub_message_end(self.retriever_resource)
-
-    def end(self):
-        self._pub_handler.pub_message_end(self.retriever_resource)
-        self._pub_handler.pub_end()
-
-    def annotation_end(self, text: str, annotation_id: str, annotation_author_name: str):
-        self._pub_handler.pub_annotation(text, annotation_id, annotation_author_name, self.start_at)
-        self._pub_handler.pub_end()
-
-
-class PubHandler:
-    def __init__(self, user: Union[Account, EndUser], task_id: str,
-                 message: Message, conversation: Conversation,
-                 chain_pub: bool = False, agent_thought_pub: bool = False):
-        self._channel = PubHandler.generate_channel_name(user, task_id)
-        self._stopped_cache_key = PubHandler.generate_stopped_cache_key(user, task_id)
-
-        self._task_id = task_id
-        self._message = message
-        self._conversation = conversation
-        self._chain_pub = chain_pub
-        self._agent_thought_pub = agent_thought_pub
-
-    @classmethod
-    def generate_channel_name(cls, user: Union[Account, EndUser], task_id: str):
-        if not user:
-            raise ValueError("user is required")
-
-        user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
-        return "generate_result:{}-{}".format(user_str, task_id)
-
-    @classmethod
-    def generate_stopped_cache_key(cls, user: Union[Account, EndUser], task_id: str):
-        user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
-        return "generate_result_stopped:{}-{}".format(user_str, task_id)
-
-    def pub_text(self, text: str):
-        content = {
-            'event': 'message',
-            'data': {
-                'task_id': self._task_id,
-                'message_id': str(self._message.id),
-                'text': text,
-                'mode': self._conversation.mode,
-                'conversation_id': str(self._conversation.id)
-            }
-        }
-
-        redis_client.publish(self._channel, json.dumps(content))
-
-        if self._is_stopped():
-            self.pub_end()
-            raise ConversationTaskStoppedException()
-
-    def pub_message_replace(self, text: str):
-        content = {
-            'event': 'message_replace',
-            'data': {
-                'task_id': self._task_id,
-                'message_id': str(self._message.id),
-                'text': text,
-                'mode': self._conversation.mode,
-                'conversation_id': str(self._conversation.id)
-            }
-        }
-
-        redis_client.publish(self._channel, json.dumps(content))
-
-        if self._is_stopped():
-            self.pub_end()
-            raise ConversationTaskStoppedException()
-
-    def pub_chain(self, message_chain: MessageChain):
-        if self._chain_pub:
-            content = {
-                'event': 'chain',
-                'data': {
-                    'task_id': self._task_id,
-                    'message_id': self._message.id,
-                    'chain_id': message_chain.id,
-                    'type': message_chain.type,
-                    'input': json.loads(message_chain.input),
-                    'output': json.loads(message_chain.output),
-                    'mode': self._conversation.mode,
-                    'conversation_id': self._conversation.id
-                }
-            }
-
-            redis_client.publish(self._channel, json.dumps(content))
-
-        if self._is_stopped():
-            self.pub_end()
-            raise ConversationTaskStoppedException()
-
-    def pub_agent_thought(self, message_agent_thought: MessageAgentThought):
-        if self._agent_thought_pub:
-            content = {
-                'event': 'agent_thought',
-                'data': {
-                    'id': message_agent_thought.id,
-                    'task_id': self._task_id,
-                    'message_id': self._message.id,
-                    'chain_id': message_agent_thought.message_chain_id,
-                    'position': message_agent_thought.position,
-                    'thought': message_agent_thought.thought,
-                    'tool': message_agent_thought.tool,
-                    'tool_input': message_agent_thought.tool_input,
-                    'mode': self._conversation.mode,
-                    'conversation_id': self._conversation.id
-                }
-            }
-
-            redis_client.publish(self._channel, json.dumps(content))
-
-        if self._is_stopped():
-            self.pub_end()
-            raise ConversationTaskStoppedException()
-
-    def pub_message_end(self, retriever_resource: List):
-        content = {
-            'event': 'message_end',
-            'data': {
-                'task_id': self._task_id,
-                'message_id': self._message.id,
-                'mode': self._conversation.mode,
-                'conversation_id': self._conversation.id,
-            }
-        }
-        if retriever_resource:
-            content['data']['retriever_resources'] = retriever_resource
-        redis_client.publish(self._channel, json.dumps(content))
-
-        if self._is_stopped():
-            self.pub_end()
-            raise ConversationTaskStoppedException()
-
-    def pub_annotation(self, text: str, annotation_id: str, annotation_author_name: str, start_at: float):
-        content = {
-            'event': 'annotation',
-            'data': {
-                'task_id': self._task_id,
-                'message_id': self._message.id,
-                'mode': self._conversation.mode,
-                'conversation_id': self._conversation.id,
-                'text': text,
-                'annotation_id': annotation_id,
-                'annotation_author_name': annotation_author_name
-            }
-        }
-        self._message.answer = text
-        self._message.provider_response_latency = time.perf_counter() - start_at
-
-        db.session.commit()
-
-        redis_client.publish(self._channel, json.dumps(content))
-
-        if self._is_stopped():
-            self.pub_end()
-            raise ConversationTaskStoppedException()
-
-    def pub_end(self):
-        content = {
-            'event': 'end',
-        }
-
-        redis_client.publish(self._channel, json.dumps(content))
-
-    @classmethod
-    def pub_error(cls, user: Union[Account, EndUser], task_id: str, e):
-        content = {
-            'error': type(e).__name__,
-            'description': e.description if getattr(e, 'description', None) is not None else str(e)
-        }
-
-        channel = cls.generate_channel_name(user, task_id)
-        redis_client.publish(channel, json.dumps(content))
-
-    def _is_stopped(self):
-        return redis_client.get(self._stopped_cache_key) is not None
-
-    @classmethod
-    def ping(cls, user: Union[Account, EndUser], task_id: str):
-        content = {
-            'event': 'ping'
-        }
-
-        channel = cls.generate_channel_name(user, task_id)
-        redis_client.publish(channel, json.dumps(content))
-
-    @classmethod
-    def stop(cls, user: Union[Account, EndUser], task_id: str):
-        stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)
-        redis_client.setex(stopped_cache_key, 600, 1)
-
-
-class ConversationTaskStoppedException(Exception):
-    pass
-
-
-class ConversationTaskInterruptException(Exception):
-    pass
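The removed PubHandler streamed events ('message', 'message_replace', 'agent_thought', 'message_end', 'annotation', 'end', ...) to clients over redis pub/sub on a per-user, per-task channel. A sketch of a consumer for that protocol, assuming a plain redis-py client; the user and task ids are placeholders:

import json

import redis

r = redis.Redis()
channel = "generate_result:account-{}-{}".format("user-id", "task-id")  # placeholder ids

pubsub = r.pubsub()
pubsub.subscribe(channel)
for item in pubsub.listen():
    if item['type'] != 'message':
        continue  # skip the subscribe confirmation frame
    content = json.loads(item['data'])
    event = content.get('event')
    if event == 'message':
        print(content['data']['text'], end='')  # incremental completion text
    elif event == 'end':
        break  # pub_end closes the stream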

+ 19 - 6
api/core/docstore/dataset_docstore.py

@@ -1,9 +1,11 @@
-from typing import Any, Dict, Optional, Sequence
+from typing import Any, Dict, Optional, Sequence, cast

 from langchain.schema import Document
 from sqlalchemy import func

-from core.model_providers.model_factory import ModelFactory
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from extensions.ext_database import db
 from models.dataset import Dataset, DocumentSegment

@@ -69,10 +71,12 @@ class DatasetDocumentStore:
             max_position = 0
         embedding_model = None
         if self._dataset.indexing_technique == 'high_quality':
-            embedding_model = ModelFactory.get_embedding_model(
+            model_manager = ModelManager()
+            embedding_model = model_manager.get_model_instance(
                 tenant_id=self._dataset.tenant_id,
-                model_provider_name=self._dataset.embedding_model_provider,
-                model_name=self._dataset.embedding_model
+                provider=self._dataset.embedding_model_provider,
+                model_type=ModelType.TEXT_EMBEDDING,
+                model=self._dataset.embedding_model
             )

         for doc in docs:
@@ -89,7 +93,16 @@ class DatasetDocumentStore:
                 )

             # calc embedding use tokens
-            tokens = embedding_model.get_num_tokens(doc.page_content) if embedding_model else 0
+            if embedding_model:
+                model_type_instance = embedding_model.model_type_instance
+                model_type_instance = cast(TextEmbeddingModel, model_type_instance)
+                tokens = model_type_instance.get_num_tokens(
+                    model=embedding_model.model,
+                    credentials=embedding_model.credentials,
+                    texts=[doc.page_content]
+                )
+            else:
+                tokens = 0

             if not segment_document:
                 max_position += 1
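The replacement path above resolves the embedding model through ModelManager and counts tokens on the typed model instance. Condensed into a standalone sketch; the tenant, provider, and model values are placeholders:

from typing import cast

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel

embedding_model = ModelManager().get_model_instance(
    tenant_id='tenant-id',           # placeholder
    provider='openai',               # placeholder
    model_type=ModelType.TEXT_EMBEDDING,
    model='text-embedding-ada-002',  # placeholder
)
model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
tokens = model_type_instance.get_num_tokens(
    model=embedding_model.model,
    credentials=embedding_model.credentials,
    texts=['segment content to be embedded'],
)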

+ 26 - 13
api/core/embedding/cached_embedding.py

@@ -1,19 +1,22 @@
 import logging
-from typing import List
+from typing import List, Optional

 import numpy as np
 from langchain.embeddings.base import Embeddings
 from sqlalchemy.exc import IntegrityError

-from core.model_providers.models.embedding.base import BaseEmbedding
+from core.model_manager import ModelInstance
 from extensions.ext_database import db
 from libs import helper
 from models.dataset import Embedding

+logger = logging.getLogger(__name__)
+

 class CacheEmbedding(Embeddings):
-    def __init__(self, embeddings: BaseEmbedding):
-        self._embeddings = embeddings
+    def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> None:
+        self._model_instance = model_instance
+        self._user = user

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Embed search docs."""
@@ -22,7 +25,7 @@ class CacheEmbedding(Embeddings):
         embedding_queue_indices = []
         for i, text in enumerate(texts):
             hash = helper.generate_text_hash(text)
-            embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
+            embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
             if embedding:
                 text_embeddings[i] = embedding.get_embedding()
             else:
@@ -30,15 +33,21 @@ class CacheEmbedding(Embeddings):

         if embedding_queue_indices:
             try:
-                embedding_results = self._embeddings.client.embed_documents([texts[i] for i in embedding_queue_indices])
+                embedding_result = self._model_instance.invoke_text_embedding(
+                    texts=[texts[i] for i in embedding_queue_indices],
+                    user=self._user
+                )
+
+                embedding_results = embedding_result.embeddings
             except Exception as ex:
-                raise self._embeddings.handle_exceptions(ex)
+                logger.error('Failed to embed documents: %s', ex)
+                raise ex

             for i, indice in enumerate(embedding_queue_indices):
                 hash = helper.generate_text_hash(texts[indice])

                 try:
-                    embedding = Embedding(model_name=self._embeddings.name, hash=hash)
+                    embedding = Embedding(model_name=self._model_instance.model, hash=hash)
                     vector = embedding_results[i]
                     normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
                     text_embeddings[indice] = normalized_embedding
@@ -58,18 +67,23 @@ class CacheEmbedding(Embeddings):
         """Embed query text."""
         # use doc embedding cache or store if not exists
         hash = helper.generate_text_hash(text)
-        embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
+        embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
         if embedding:
             return embedding.get_embedding()

         try:
-            embedding_results = self._embeddings.client.embed_query(text)
+            embedding_result = self._model_instance.invoke_text_embedding(
+                texts=[text],
+                user=self._user
+            )
+
+            embedding_results = embedding_result.embeddings[0]
             embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
         except Exception as ex:
-            raise self._embeddings.handle_exceptions(ex)
+            raise ex

         try:
-            embedding = Embedding(model_name=self._embeddings.name, hash=hash)
+            embedding = Embedding(model_name=self._model_instance.model, hash=hash)
             embedding.set_embedding(embedding_results)
             db.session.add(embedding)
             db.session.commit()
@@ -79,4 +93,3 @@ class CacheEmbedding(Embeddings):
             logging.exception('Failed to add embedding to db')

         return embedding_results
-
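
CacheEmbedding now wraps a ModelInstance rather than the old BaseEmbedding wrapper: vectors are cached per (model_name, text hash) in the Embedding table, and only cache misses reach invoke_text_embedding. A usage sketch under those assumptions (the user id is a placeholder):

    from core.embedding.cached_embedding import CacheEmbedding

    # model_instance: a TEXT_EMBEDDING ModelInstance from ModelManager, as in the sketch above
    embeddings = CacheEmbedding(model_instance, user='end-user-id')

    # the first call embeds via the provider; identical texts afterwards hit the cache
    vectors = embeddings.embed_documents(['hello world', 'hello world'])
    query_vector = embeddings.embed_query('hello world')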

+ 0 - 0
api/core/model_providers/models/embedding/__init__.py → api/core/entities/__init__.py


+ 265 - 0
api/core/entities/application_entities.py

@@ -0,0 +1,265 @@
+from enum import Enum
+from typing import Optional, Any, cast
+
+from pydantic import BaseModel
+
+from core.entities.provider_configuration import ProviderModelBundle
+from core.file.file_obj import FileObj
+from core.model_runtime.entities.message_entities import PromptMessageRole
+from core.model_runtime.entities.model_entities import AIModelEntity
+
+
+class ModelConfigEntity(BaseModel):
+    """
+    Model Config Entity.
+    """
+    provider: str
+    model: str
+    model_schema: AIModelEntity
+    mode: str
+    provider_model_bundle: ProviderModelBundle
+    credentials: dict[str, Any] = {}
+    parameters: dict[str, Any] = {}
+    stop: list[str] = []
+
+
+class AdvancedChatMessageEntity(BaseModel):
+    """
+    Advanced Chat Message Entity.
+    """
+    text: str
+    role: PromptMessageRole
+
+
+class AdvancedChatPromptTemplateEntity(BaseModel):
+    """
+    Advanced Chat Prompt Template Entity.
+    """
+    messages: list[AdvancedChatMessageEntity]
+
+
+class AdvancedCompletionPromptTemplateEntity(BaseModel):
+    """
+    Advanced Completion Prompt Template Entity.
+    """
+    class RolePrefixEntity(BaseModel):
+        """
+        Role Prefix Entity.
+        """
+        user: str
+        assistant: str
+
+    prompt: str
+    role_prefix: Optional[RolePrefixEntity] = None
+
+
+class PromptTemplateEntity(BaseModel):
+    """
+    Prompt Template Entity.
+    """
+    class PromptType(Enum):
+        """
+        Prompt Type.
+        'simple', 'advanced'
+        """
+        SIMPLE = 'simple'
+        ADVANCED = 'advanced'
+
+        @classmethod
+        def value_of(cls, value: str) -> 'PromptType':
+            """
+            Get value of given mode.
+
+            :param value: mode value
+            :return: mode
+            """
+            for mode in cls:
+                if mode.value == value:
+                    return mode
+            raise ValueError(f'invalid prompt type value {value}')
+
+    prompt_type: PromptType
+    simple_prompt_template: Optional[str] = None
+    advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None
+    advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
+
+
+class ExternalDataVariableEntity(BaseModel):
+    """
+    External Data Variable Entity.
+    """
+    variable: str
+    type: str
+    config: dict[str, Any] = {}
+
+
+class DatasetRetrieveConfigEntity(BaseModel):
+    """
+    Dataset Retrieve Config Entity.
+    """
+    class RetrieveStrategy(Enum):
+        """
+        Dataset Retrieve Strategy.
+        'single' or 'multiple'
+        """
+        SINGLE = 'single'
+        MULTIPLE = 'multiple'
+
+        @classmethod
+        def value_of(cls, value: str) -> 'RetrieveStrategy':
+            """
+            Get value of given mode.
+
+            :param value: mode value
+            :return: mode
+            """
+            for mode in cls:
+                if mode.value == value:
+                    return mode
+            raise ValueError(f'invalid retrieve strategy value {value}')
+
+    query_variable: Optional[str] = None  # Only when app mode is completion
+
+    retrieve_strategy: RetrieveStrategy
+    single_strategy: Optional[str] = None  # for temp
+    top_k: Optional[int] = None
+    score_threshold: Optional[float] = None
+    reranking_model: Optional[dict] = None
+
+
+class DatasetEntity(BaseModel):
+    """
+    Dataset Config Entity.
+    """
+    dataset_ids: list[str]
+    retrieve_config: DatasetRetrieveConfigEntity
+
+
+class SensitiveWordAvoidanceEntity(BaseModel):
+    """
+    Sensitive Word Avoidance Entity.
+    """
+    type: str
+    config: dict[str, Any] = {}
+
+
+class FileUploadEntity(BaseModel):
+    """
+    File Upload Entity.
+    """
+    image_config: Optional[dict[str, Any]] = None
+
+
+class AgentToolEntity(BaseModel):
+    """
+    Agent Tool Entity.
+    """
+    tool_id: str
+    config: dict[str, Any] = {}
+
+
+class AgentEntity(BaseModel):
+    """
+    Agent Entity.
+    """
+    class Strategy(Enum):
+        """
+        Agent Strategy.
+        """
+        CHAIN_OF_THOUGHT = 'chain-of-thought'
+        FUNCTION_CALLING = 'function-calling'
+
+    provider: str
+    model: str
+    strategy: Strategy
+    tools: list[AgentToolEntity] = []
+
+
+class AppOrchestrationConfigEntity(BaseModel):
+    """
+    App Orchestration Config Entity.
+    """
+    model_config: ModelConfigEntity
+    prompt_template: PromptTemplateEntity
+    external_data_variables: list[ExternalDataVariableEntity] = []
+    agent: Optional[AgentEntity] = None
+
+    # features
+    dataset: Optional[DatasetEntity] = None
+    file_upload: Optional[FileUploadEntity] = None
+    opening_statement: Optional[str] = None
+    suggested_questions_after_answer: bool = False
+    show_retrieve_source: bool = False
+    more_like_this: bool = False
+    speech_to_text: bool = False
+    sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
+
+
+class InvokeFrom(Enum):
+    """
+    Invoke From.
+    """
+    SERVICE_API = 'service-api'
+    WEB_APP = 'web-app'
+    EXPLORE = 'explore'
+    DEBUGGER = 'debugger'
+
+    @classmethod
+    def value_of(cls, value: str) -> 'InvokeFrom':
+        """
+        Get value of given mode.
+
+        :param value: mode value
+        :return: mode
+        """
+        for mode in cls:
+            if mode.value == value:
+                return mode
+        raise ValueError(f'invalid invoke from value {value}')
+
+    def to_source(self) -> str:
+        """
+        Get source of invoke from.
+
+        :return: source
+        """
+        if self == InvokeFrom.WEB_APP:
+            return 'web_app'
+        elif self == InvokeFrom.DEBUGGER:
+            return 'dev'
+        elif self == InvokeFrom.EXPLORE:
+            return 'explore_app'
+        elif self == InvokeFrom.SERVICE_API:
+            return 'api'
+
+        return 'dev'
+
+
+class ApplicationGenerateEntity(BaseModel):
+    """
+    Application Generate Entity.
+    """
+    task_id: str
+    tenant_id: str
+
+    app_id: str
+    app_model_config_id: str
+    # for save
+    app_model_config_dict: dict
+    app_model_config_override: bool
+
+    # Converted from app_model_config to an entity object, or overridden directly by external input
+    app_orchestration_config_entity: AppOrchestrationConfigEntity
+
+    conversation_id: Optional[str] = None
+    inputs: dict[str, str]
+    query: Optional[str] = None
+    files: list[FileObj] = []
+    user_id: str
+
+    # extras
+    stream: bool
+    invoke_from: InvokeFrom
+
+    # extra parameters, like: auto_generate_conversation_name
+    extras: dict[str, Any] = {}

+ 128 - 0
api/core/entities/message_entities.py

@@ -0,0 +1,128 @@
+import enum
+from typing import Any, cast
+
+from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
+from pydantic import BaseModel
+
+from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage, TextPromptMessageContent, \
+    ImagePromptMessageContent, AssistantPromptMessage, SystemPromptMessage, ToolPromptMessage
+
+
+class PromptMessageFileType(enum.Enum):
+    IMAGE = 'image'
+
+    @staticmethod
+    def value_of(value):
+        for member in PromptMessageFileType:
+            if member.value == value:
+                return member
+        raise ValueError(f"No matching enum found for value '{value}'")
+
+
+class PromptMessageFile(BaseModel):
+    type: PromptMessageFileType
+    data: Any
+
+
+class ImagePromptMessageFile(PromptMessageFile):
+    class DETAIL(enum.Enum):
+        LOW = 'low'
+        HIGH = 'high'
+
+    type: PromptMessageFileType = PromptMessageFileType.IMAGE
+    detail: DETAIL = DETAIL.LOW
+
+
+class LCHumanMessageWithFiles(HumanMessage):
+    # content: Union[str, List[Union[str, Dict]]]
+    content: str
+    files: list[PromptMessageFile]
+
+
+def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMessage]:
+    prompt_messages = []
+    for message in messages:
+        if isinstance(message, HumanMessage):
+            if isinstance(message, LCHumanMessageWithFiles):
+                file_prompt_message_contents = []
+                for file in message.files:
+                    if file.type == PromptMessageFileType.IMAGE:
+                        file = cast(ImagePromptMessageFile, file)
+                        file_prompt_message_contents.append(ImagePromptMessageContent(
+                            data=file.data,
+                            detail=ImagePromptMessageContent.DETAIL.HIGH
+                            if file.detail.value == "high" else ImagePromptMessageContent.DETAIL.LOW
+                        ))
+
+                prompt_message_contents = [TextPromptMessageContent(data=message.content)]
+                prompt_message_contents.extend(file_prompt_message_contents)
+
+                prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+            else:
+                prompt_messages.append(UserPromptMessage(content=message.content))
+        elif isinstance(message, AIMessage):
+            message_kwargs = {
+                'content': message.content
+            }
+
+            if 'function_call' in message.additional_kwargs:
+                message_kwargs['tool_calls'] = [
+                    AssistantPromptMessage.ToolCall(
+                        id=message.additional_kwargs['function_call']['id'],
+                        type='function',
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=message.additional_kwargs['function_call']['name'],
+                            arguments=message.additional_kwargs['function_call']['arguments']
+                        )
+                    )
+                ]
+
+            prompt_messages.append(AssistantPromptMessage(**message_kwargs))
+        elif isinstance(message, SystemMessage):
+            prompt_messages.append(SystemPromptMessage(content=message.content))
+        elif isinstance(message, FunctionMessage):
+            prompt_messages.append(ToolPromptMessage(content=message.content, tool_call_id=message.name))
+
+    return prompt_messages
+
+
+def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list[BaseMessage]:
+    messages = []
+    for prompt_message in prompt_messages:
+        if isinstance(prompt_message, UserPromptMessage):
+            if isinstance(prompt_message.content, str):
+                messages.append(HumanMessage(content=prompt_message.content))
+            else:
+                message_contents = []
+                for content in prompt_message.content:
+                    if isinstance(content, TextPromptMessageContent):
+                        message_contents.append(content.data)
+                    elif isinstance(content, ImagePromptMessageContent):
+                        message_contents.append({
+                            'type': 'image',
+                            'data': content.data,
+                            'detail': content.detail.value
+                        })
+
+                messages.append(HumanMessage(content=message_contents))
+        elif isinstance(prompt_message, AssistantPromptMessage):
+            message_kwargs = {
+                'content': prompt_message.content
+            }
+
+            if prompt_message.tool_calls:
+                message_kwargs['additional_kwargs'] = {
+                    'function_call': {
+                        'id': prompt_message.tool_calls[0].id,
+                        'name': prompt_message.tool_calls[0].function.name,
+                        'arguments': prompt_message.tool_calls[0].function.arguments
+                    }
+                }
+
+            messages.append(AIMessage(**message_kwargs))
+        elif isinstance(prompt_message, SystemPromptMessage):
+            messages.append(SystemMessage(content=prompt_message.content))
+        elif isinstance(prompt_message, ToolPromptMessage):
+            messages.append(FunctionMessage(name=prompt_message.tool_call_id, content=prompt_message.content))
+
+    return messages
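
The two helpers are intended to be symmetric: lc_messages_to_prompt_messages lifts LangChain messages (including image attachments on LCHumanMessageWithFiles) into model-runtime PromptMessages, and prompt_messages_to_lc_messages maps back, packing the first tool call into additional_kwargs['function_call']. A round-trip sketch for the plain-text case:

    from langchain.schema import HumanMessage, SystemMessage

    from core.entities.message_entities import (
        lc_messages_to_prompt_messages,
        prompt_messages_to_lc_messages,
    )

    lc_messages = [
        SystemMessage(content='You are a helpful assistant.'),
        HumanMessage(content='What is the weather like?'),
    ]

    prompt_messages = lc_messages_to_prompt_messages(lc_messages)
    round_tripped = prompt_messages_to_lc_messages(prompt_messages)
    assert [m.content for m in round_tripped] == [m.content for m in lc_messages]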

+ 71 - 0
api/core/entities/model_entities.py

@@ -0,0 +1,71 @@
+from enum import Enum
+from typing import Optional
+
+from pydantic import BaseModel
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import ProviderModel, ModelType
+from core.model_runtime.entities.provider_entities import SimpleProviderEntity, ProviderEntity
+
+
+class ModelStatus(Enum):
+    """
+    Enum class for model status.
+    """
+    ACTIVE = "active"
+    NO_CONFIGURE = "no-configure"
+    QUOTA_EXCEEDED = "quota-exceeded"
+    NO_PERMISSION = "no-permission"
+
+
+class SimpleModelProviderEntity(BaseModel):
+    """
+    Simple provider.
+    """
+    provider: str
+    label: I18nObject
+    icon_small: Optional[I18nObject] = None
+    icon_large: Optional[I18nObject] = None
+    supported_model_types: list[ModelType]
+
+    def __init__(self, provider_entity: ProviderEntity) -> None:
+        """
+        Init simple provider.
+
+        :param provider_entity: provider entity
+        """
+        super().__init__(
+            provider=provider_entity.provider,
+            label=provider_entity.label,
+            icon_small=provider_entity.icon_small,
+            icon_large=provider_entity.icon_large,
+            supported_model_types=provider_entity.supported_model_types
+        )
+
+
+class ModelWithProviderEntity(ProviderModel):
+    """
+    Model with provider entity.
+    """
+    provider: SimpleModelProviderEntity
+    status: ModelStatus
+
+
+class DefaultModelProviderEntity(BaseModel):
+    """
+    Default model provider entity.
+    """
+    provider: str
+    label: I18nObject
+    icon_small: Optional[I18nObject] = None
+    icon_large: Optional[I18nObject] = None
+    supported_model_types: list[ModelType]
+
+
+class DefaultModelEntity(BaseModel):
+    """
+    Default model entity.
+    """
+    model: str
+    model_type: ModelType
+    provider: DefaultModelProviderEntity

+ 657 - 0
api/core/entities/provider_configuration.py

@@ -0,0 +1,657 @@
+import datetime
+import json
+import time
+from json import JSONDecodeError
+from typing import Optional, List, Dict, Tuple, Iterator
+
+from pydantic import BaseModel
+
+from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, SimpleModelProviderEntity
+from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
+from core.helper import encrypter
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
+from core.model_runtime.model_providers import model_provider_factory
+from core.model_runtime.model_providers.__base.ai_model import AIModel
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+from core.model_runtime.utils import encoders
+from extensions.ext_database import db
+from models.provider import ProviderType, Provider, ProviderModel, TenantPreferredModelProvider
+
+
+class ProviderConfiguration(BaseModel):
+    """
+    Model class for provider configuration.
+    """
+    tenant_id: str
+    provider: ProviderEntity
+    preferred_provider_type: ProviderType
+    using_provider_type: ProviderType
+    system_configuration: SystemConfiguration
+    custom_configuration: CustomConfiguration
+
+    def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
+        """
+        Get current credentials.
+
+        :param model_type: model type
+        :param model: model name
+        :return:
+        """
+        if self.using_provider_type == ProviderType.SYSTEM:
+            return self.system_configuration.credentials
+        else:
+            if self.custom_configuration.models:
+                for model_configuration in self.custom_configuration.models:
+                    if model_configuration.model_type == model_type and model_configuration.model == model:
+                        return model_configuration.credentials
+
+            if self.custom_configuration.provider:
+                return self.custom_configuration.provider.credentials
+            else:
+                return None
+
+    def get_system_configuration_status(self) -> SystemConfigurationStatus:
+        """
+        Get system configuration status.
+        :return:
+        """
+        if self.system_configuration.enabled is False:
+            return SystemConfigurationStatus.UNSUPPORTED
+
+        current_quota_type = self.system_configuration.current_quota_type
+        current_quota_configuration = next(
+            (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
+            None
+        )
+
+        return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \
+            SystemConfigurationStatus.QUOTA_EXCEEDED
+
+    def is_custom_configuration_available(self) -> bool:
+        """
+        Check custom configuration available.
+        :return:
+        """
+        return (self.custom_configuration.provider is not None
+                or len(self.custom_configuration.models) > 0)
+
+    def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
+        """
+        Get custom credentials.
+
+        :param obfuscated: obfuscated secret data in credentials
+        :return:
+        """
+        if self.custom_configuration.provider is None:
+            return None
+
+        credentials = self.custom_configuration.provider.credentials
+        if not obfuscated:
+            return credentials
+
+        # Obfuscate credentials
+        return self._obfuscated_credentials(
+            credentials=credentials,
+            credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
+            if self.provider.provider_credential_schema else []
+        )
+
+    def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]:
+        """
+        Validate custom credentials.
+        :param credentials: provider credentials
+        :return:
+        """
+        # get provider
+        provider_record = db.session.query(Provider) \
+            .filter(
+            Provider.tenant_id == self.tenant_id,
+            Provider.provider_name == self.provider.provider,
+            Provider.provider_type == ProviderType.CUSTOM.value
+        ).first()
+
+        # Get provider credential secret variables
+        provider_credential_secret_variables = self._extract_secret_variables(
+            self.provider.provider_credential_schema.credential_form_schemas
+            if self.provider.provider_credential_schema else []
+        )
+
+        if provider_record:
+            try:
+                original_credentials = json.loads(provider_record.encrypted_config) if provider_record.encrypted_config else {}
+            except JSONDecodeError:
+                original_credentials = {}
+
+            # decrypt credentials
+            for key, value in credentials.items():
+                if key in provider_credential_secret_variables:
+                    # a '[__HIDDEN__]' secret input means "keep the original value"
+                    if value == '[__HIDDEN__]' and key in original_credentials:
+                        credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
+
+        model_provider_factory.provider_credentials_validate(
+            self.provider.provider,
+            credentials
+        )
+
+        for key, value in credentials.items():
+            if key in provider_credential_secret_variables:
+                credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
+
+        return provider_record, credentials
+
+    def add_or_update_custom_credentials(self, credentials: dict) -> None:
+        """
+        Add or update custom provider credentials.
+        :param credentials:
+        :return:
+        """
+        # validate custom provider config
+        provider_record, credentials = self.custom_credentials_validate(credentials)
+
+        # save provider
+        # Note: Do not switch the preferred provider, which allows users to use quotas first
+        if provider_record:
+            provider_record.encrypted_config = json.dumps(credentials)
+            provider_record.is_valid = True
+            provider_record.updated_at = datetime.datetime.utcnow()
+            db.session.commit()
+        else:
+            provider_record = Provider(
+                tenant_id=self.tenant_id,
+                provider_name=self.provider.provider,
+                provider_type=ProviderType.CUSTOM.value,
+                encrypted_config=json.dumps(credentials),
+                is_valid=True
+            )
+            db.session.add(provider_record)
+            db.session.commit()
+
+        self.switch_preferred_provider_type(ProviderType.CUSTOM)
+
+    def delete_custom_credentials(self) -> None:
+        """
+        Delete custom provider credentials.
+        :return:
+        """
+        # get provider
+        provider_record = db.session.query(Provider) \
+            .filter(
+            Provider.tenant_id == self.tenant_id,
+            Provider.provider_name == self.provider.provider,
+            Provider.provider_type == ProviderType.CUSTOM.value
+        ).first()
+
+        # delete provider
+        if provider_record:
+            self.switch_preferred_provider_type(ProviderType.SYSTEM)
+
+            db.session.delete(provider_record)
+            db.session.commit()
+
+    def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
+            -> Optional[dict]:
+        """
+        Get custom model credentials.
+
+        :param model_type: model type
+        :param model: model name
+        :param obfuscated: obfuscated secret data in credentials
+        :return:
+        """
+        if not self.custom_configuration.models:
+            return None
+
+        for model_configuration in self.custom_configuration.models:
+            if model_configuration.model_type == model_type and model_configuration.model == model:
+                credentials = model_configuration.credentials
+                if not obfuscated:
+                    return credentials
+
+                # Obfuscate credentials
+                return self._obfuscated_credentials(
+                    credentials=credentials,
+                    credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
+                    if self.provider.model_credential_schema else []
+                )
+
+        return None
+
+    def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
+            -> Tuple[ProviderModel, dict]:
+        """
+        Validate custom model credentials.
+
+        :param model_type: model type
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        # get provider model
+        provider_model_record = db.session.query(ProviderModel) \
+            .filter(
+            ProviderModel.tenant_id == self.tenant_id,
+            ProviderModel.provider_name == self.provider.provider,
+            ProviderModel.model_name == model,
+            ProviderModel.model_type == model_type.to_origin_model_type()
+        ).first()
+
+        # Get provider credential secret variables
+        provider_credential_secret_variables = self._extract_secret_variables(
+            self.provider.model_credential_schema.credential_form_schemas
+            if self.provider.model_credential_schema else []
+        )
+
+        if provider_model_record:
+            try:
+                original_credentials = json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
+            except JSONDecodeError:
+                original_credentials = {}
+
+            # decrypt credentials
+            for key, value in credentials.items():
+                if key in provider_credential_secret_variables:
+                    # a '[__HIDDEN__]' secret input means "keep the original value"
+                    if value == '[__HIDDEN__]' and key in original_credentials:
+                        credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
+
+        model_provider_factory.model_credentials_validate(
+            provider=self.provider.provider,
+            model_type=model_type,
+            model=model,
+            credentials=credentials
+        )
+
+        model_schema = (
+            model_provider_factory.get_provider_instance(self.provider.provider)
+            .get_model_instance(model_type)._get_customizable_model_schema(
+                model=model,
+                credentials=credentials
+            )
+        )
+
+        if model_schema:
+            credentials['schema'] = json.dumps(encoders.jsonable_encoder(model_schema))
+
+        for key, value in credentials.items():
+            if key in provider_credential_secret_variables:
+                credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
+
+        return provider_model_record, credentials
+
+    def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
+        """
+        Add or update custom model credentials.
+
+        :param model_type: model type
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        # validate custom model config
+        provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)
+
+        # save provider model
+        # Note: Do not switch the preferred provider, which allows users to use quotas first
+        if provider_model_record:
+            provider_model_record.encrypted_config = json.dumps(credentials)
+            provider_model_record.is_valid = True
+            provider_model_record.updated_at = datetime.datetime.utcnow()
+            db.session.commit()
+        else:
+            provider_model_record = ProviderModel(
+                tenant_id=self.tenant_id,
+                provider_name=self.provider.provider,
+                model_name=model,
+                model_type=model_type.to_origin_model_type(),
+                encrypted_config=json.dumps(credentials),
+                is_valid=True
+            )
+            db.session.add(provider_model_record)
+            db.session.commit()
+
+    def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
+        """
+        Delete custom model credentials.
+        :param model_type: model type
+        :param model: model name
+        :return:
+        """
+        # get provider model
+        provider_model_record = db.session.query(ProviderModel) \
+            .filter(
+            ProviderModel.tenant_id == self.tenant_id,
+            ProviderModel.provider_name == self.provider.provider,
+            ProviderModel.model_name == model,
+            ProviderModel.model_type == model_type.to_origin_model_type()
+        ).first()
+
+        # delete provider model
+        if provider_model_record:
+            db.session.delete(provider_model_record)
+            db.session.commit()
+
+    def get_provider_instance(self) -> ModelProvider:
+        """
+        Get provider instance.
+        :return:
+        """
+        return model_provider_factory.get_provider_instance(self.provider.provider)
+
+    def get_model_type_instance(self, model_type: ModelType) -> AIModel:
+        """
+        Get current model type instance.
+
+        :param model_type: model type
+        :return:
+        """
+        # Get provider instance
+        provider_instance = self.get_provider_instance()
+
+        # Get model instance of LLM
+        return provider_instance.get_model_instance(model_type)
+
+    def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
+        """
+        Switch preferred provider type.
+        :param provider_type:
+        :return:
+        """
+        if provider_type == self.preferred_provider_type:
+            return
+
+        if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
+            return
+
+        # get preferred provider
+        preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
+            .filter(
+            TenantPreferredModelProvider.tenant_id == self.tenant_id,
+            TenantPreferredModelProvider.provider_name == self.provider.provider
+        ).first()
+
+        if preferred_model_provider:
+            preferred_model_provider.preferred_provider_type = provider_type.value
+        else:
+            preferred_model_provider = TenantPreferredModelProvider(
+                tenant_id=self.tenant_id,
+                provider_name=self.provider.provider,
+                preferred_provider_type=provider_type.value
+            )
+            db.session.add(preferred_model_provider)
+
+        db.session.commit()
+
+    def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
+        """
+        Extract secret input form variables.
+
+        :param credential_form_schemas:
+        :return:
+        """
+        secret_input_form_variables = []
+        for credential_form_schema in credential_form_schemas:
+            if credential_form_schema.type == FormType.SECRET_INPUT:
+                secret_input_form_variables.append(credential_form_schema.variable)
+
+        return secret_input_form_variables
+
+    def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
+        """
+        Obfuscated credentials.
+
+        :param credentials: credentials
+        :param credential_form_schemas: credential form schemas
+        :return:
+        """
+        # Get provider credential secret variables
+        credential_secret_variables = self._extract_secret_variables(
+            credential_form_schemas
+        )
+
+        # Obfuscate provider credentials
+        copy_credentials = credentials.copy()
+        for key, value in copy_credentials.items():
+            if key in credential_secret_variables:
+                copy_credentials[key] = encrypter.obfuscated_token(value)
+
+        return copy_credentials
+
+    def get_provider_model(self, model_type: ModelType,
+                           model: str,
+                           only_active: bool = False) -> Optional[ModelWithProviderEntity]:
+        """
+        Get provider model.
+        :param model_type: model type
+        :param model: model name
+        :param only_active: return active model only
+        :return:
+        """
+        provider_models = self.get_provider_models(model_type, only_active)
+
+        for provider_model in provider_models:
+            if provider_model.model == model:
+                return provider_model
+
+        return None
+
+    def get_provider_models(self, model_type: Optional[ModelType] = None,
+                            only_active: bool = False) -> list[ModelWithProviderEntity]:
+        """
+        Get provider models.
+        :param model_type: model type
+        :param only_active: only active models
+        :return:
+        """
+        provider_instance = self.get_provider_instance()
+
+        model_types = []
+        if model_type:
+            model_types.append(model_type)
+        else:
+            model_types = provider_instance.get_provider_schema().supported_model_types
+
+        if self.using_provider_type == ProviderType.SYSTEM:
+            provider_models = self._get_system_provider_models(
+                model_types=model_types,
+                provider_instance=provider_instance
+            )
+        else:
+            provider_models = self._get_custom_provider_models(
+                model_types=model_types,
+                provider_instance=provider_instance
+            )
+
+        if only_active:
+            provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
+
+        # re-sort provider_models by model type
+        return sorted(provider_models, key=lambda x: x.model_type.value)
+
+    def _get_system_provider_models(self,
+                                    model_types: list[ModelType],
+                                    provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
+        """
+        Get system provider models.
+
+        :param model_types: model types
+        :param provider_instance: provider instance
+        :return:
+        """
+        provider_models = []
+        for model_type in model_types:
+            provider_models.extend(
+                [
+                    ModelWithProviderEntity(
+                        **m.dict(),
+                        provider=SimpleModelProviderEntity(self.provider),
+                        status=ModelStatus.ACTIVE
+                    )
+                    for m in provider_instance.models(model_type)
+                ]
+            )
+
+        for quota_configuration in self.system_configuration.quota_configurations:
+            if self.system_configuration.current_quota_type != quota_configuration.quota_type:
+                continue
+
+            restrict_llms = quota_configuration.restrict_llms
+            if not restrict_llms:
+                break
+
+            # if the llm is not in the restricted llm list, mark it as no_permission
+            for m in provider_models:
+                if m.model_type == ModelType.LLM and m.model not in restrict_llms:
+                    m.status = ModelStatus.NO_PERMISSION
+                elif not quota_configuration.is_valid:
+                    m.status = ModelStatus.QUOTA_EXCEEDED
+
+        return provider_models
+
+    def _get_custom_provider_models(self,
+                                    model_types: list[ModelType],
+                                    provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
+        """
+        Get custom provider models.
+
+        :param model_types: model types
+        :param provider_instance: provider instance
+        :return:
+        """
+        provider_models = []
+
+        credentials = None
+        if self.custom_configuration.provider:
+            credentials = self.custom_configuration.provider.credentials
+
+        for model_type in model_types:
+            if model_type not in self.provider.supported_model_types:
+                continue
+
+            models = provider_instance.models(model_type)
+            for m in models:
+                provider_models.append(
+                    ModelWithProviderEntity(
+                        **m.dict(),
+                        provider=SimpleModelProviderEntity(self.provider),
+                        status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
+                    )
+                )
+
+        # custom models
+        for model_configuration in self.custom_configuration.models:
+            if model_configuration.model_type not in model_types:
+                continue
+
+            custom_model_schema = (
+                provider_instance.get_model_instance(model_configuration.model_type)
+                .get_customizable_model_schema_from_credentials(
+                    model_configuration.model,
+                    model_configuration.credentials
+                )
+            )
+
+            if not custom_model_schema:
+                continue
+
+            provider_models.append(
+                ModelWithProviderEntity(
+                    **custom_model_schema.dict(),
+                    provider=SimpleModelProviderEntity(self.provider),
+                    status=ModelStatus.ACTIVE
+                )
+            )
+
+        return provider_models
+
+
+class ProviderConfigurations(BaseModel):
+    """
+    Model class for provider configuration dict.
+    """
+    tenant_id: str
+    configurations: Dict[str, ProviderConfiguration] = {}
+
+    def __init__(self, tenant_id: str):
+        super().__init__(tenant_id=tenant_id)
+
+    def get_models(self,
+                   provider: Optional[str] = None,
+                   model_type: Optional[ModelType] = None,
+                   only_active: bool = False) \
+            -> list[ModelWithProviderEntity]:
+        """
+        Get available models.
+
+        If the preferred provider type is `system`:
+          use the current **system mode** if the provider supports it;
+          if no system quota is available, fall back to **custom credential mode**;
+          if no model is configured in custom mode either, treat it as no_configure.
+        system > custom > no_configure
+
+        If the preferred provider type is `custom`:
+          use custom mode if custom credentials are configured;
+          otherwise use the current **system mode** if supported;
+          if no system quota is available, treat it as no_configure.
+        custom > system > no_configure
+
+        If the real mode is `system`, use system credentials to get models:
+          paid quotas > provider free quotas > system free quotas,
+          including pre-defined models (except GPT-4, whose status is marked `no_permission`).
+        If the real mode is `custom`, use workspace custom credentials to get models,
+          including pre-defined models and manually appended custom models.
+        If the real mode is `no_configure`, return only pre-defined models from the `model runtime`
+          (model status marked `no_configure` if the preferred provider type is `custom`, otherwise `quota_exceeded`).
+        Only models whose status is `active` are available.
+
+        :param provider: provider name
+        :param model_type: model type
+        :param only_active: only active models
+        :return:
+        """
+        all_models = []
+        for provider_configuration in self.values():
+            if provider and provider_configuration.provider.provider != provider:
+                continue
+
+            all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
+
+        return all_models
+
+    def to_list(self) -> List[ProviderConfiguration]:
+        """
+        Convert to list.
+
+        :return:
+        """
+        return list(self.values())
+
+    def __getitem__(self, key):
+        return self.configurations[key]
+
+    def __setitem__(self, key, value):
+        self.configurations[key] = value
+
+    def __iter__(self):
+        return iter(self.configurations)
+
+    def values(self) -> Iterator[ProviderConfiguration]:
+        return self.configurations.values()
+
+    def get(self, key, default=None):
+        return self.configurations.get(key, default)
+
+
+class ProviderModelBundle(BaseModel):
+    """
+    Provider model bundle.
+    """
+    configuration: ProviderConfiguration
+    provider_instance: ModelProvider
+    model_type_instance: AIModel
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
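
ProviderConfigurations behaves like a dict keyed by provider name, so the resolution rules spelled out in the get_models docstring are reachable with plain lookups. A read-only sketch, assuming a ProviderConfigurations object has already been built for the tenant elsewhere (e.g. by a provider manager):

    from core.model_runtime.entities.model_entities import ModelType

    # provider_configurations: a ProviderConfigurations instance loaded elsewhere
    openai_configuration = provider_configurations.get('openai')
    if openai_configuration:
        # only models whose status resolved to ModelStatus.ACTIVE
        active_llms = openai_configuration.get_provider_models(
            model_type=ModelType.LLM,
            only_active=True
        )

    # or query across every configured provider at once
    all_active_models = provider_configurations.get_models(only_active=True)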

+ 67 - 0
api/core/entities/provider_entities.py

@@ -0,0 +1,67 @@
+from enum import Enum
+from typing import Optional
+
+from pydantic import BaseModel
+
+from core.model_runtime.entities.model_entities import ModelType
+from models.provider import ProviderQuotaType
+
+
+class QuotaUnit(Enum):
+    TIMES = 'times'
+    TOKENS = 'tokens'
+
+
+class SystemConfigurationStatus(Enum):
+    """
+    Enum class for system configuration status.
+    """
+    ACTIVE = 'active'
+    QUOTA_EXCEEDED = 'quota-exceeded'
+    UNSUPPORTED = 'unsupported'
+
+
+class QuotaConfiguration(BaseModel):
+    """
+    Model class for provider quota configuration.
+    """
+    quota_type: ProviderQuotaType
+    quota_unit: QuotaUnit
+    quota_limit: int
+    quota_used: int
+    is_valid: bool
+    restrict_llms: list[str] = []
+
+
+class SystemConfiguration(BaseModel):
+    """
+    Model class for provider system configuration.
+    """
+    enabled: bool
+    current_quota_type: Optional[ProviderQuotaType] = None
+    quota_configurations: list[QuotaConfiguration] = []
+    credentials: Optional[dict] = None
+
+
+class CustomProviderConfiguration(BaseModel):
+    """
+    Model class for provider custom configuration.
+    """
+    credentials: dict
+
+
+class CustomModelConfiguration(BaseModel):
+    """
+    Model class for provider custom model configuration.
+    """
+    model: str
+    model_type: ModelType
+    credentials: dict
+
+
+class CustomConfiguration(BaseModel):
+    """
+    Model class for provider custom configuration.
+    """
+    provider: Optional[CustomProviderConfiguration] = None
+    models: list[CustomModelConfiguration] = []

+ 118 - 0
api/core/entities/queue_entities.py

@@ -0,0 +1,118 @@
+from enum import Enum
+from typing import Any
+
+from pydantic import BaseModel
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
+
+
+class QueueEvent(Enum):
+    """
+    QueueEvent enum
+    """
+    MESSAGE = "message"
+    MESSAGE_REPLACE = "message-replace"
+    MESSAGE_END = "message-end"
+    RETRIEVER_RESOURCES = "retriever-resources"
+    ANNOTATION_REPLY = "annotation-reply"
+    AGENT_THOUGHT = "agent-thought"
+    ERROR = "error"
+    PING = "ping"
+    STOP = "stop"
+
+
+class AppQueueEvent(BaseModel):
+    """
+    QueueEvent entity
+    """
+    event: QueueEvent
+
+
+class QueueMessageEvent(AppQueueEvent):
+    """
+    QueueMessageEvent entity
+    """
+    event = QueueEvent.MESSAGE
+    chunk: LLMResultChunk
+    
+    
+class QueueMessageReplaceEvent(AppQueueEvent):
+    """
+    QueueMessageReplaceEvent entity
+    """
+    event = QueueEvent.MESSAGE_REPLACE
+    text: str
+
+
+class QueueRetrieverResourcesEvent(AppQueueEvent):
+    """
+    QueueRetrieverResourcesEvent entity
+    """
+    event = QueueEvent.RETRIEVER_RESOURCES
+    retriever_resources: list[dict]
+
+
+class AnnotationReplyEvent(AppQueueEvent):
+    """
+    AnnotationReplyEvent entity
+    """
+    event = QueueEvent.ANNOTATION_REPLY
+    message_annotation_id: str
+
+
+class QueueMessageEndEvent(AppQueueEvent):
+    """
+    QueueMessageEndEvent entity
+    """
+    event = QueueEvent.MESSAGE_END
+    llm_result: LLMResult
+
+    
+class QueueAgentThoughtEvent(AppQueueEvent):
+    """
+    QueueAgentThoughtEvent entity
+    """
+    event = QueueEvent.AGENT_THOUGHT
+    agent_thought_id: str
+    
+    
+class QueueErrorEvent(AppQueueEvent):
+    """
+    QueueErrorEvent entity
+    """
+    event = QueueEvent.ERROR
+    error: Any
+
+
+class QueuePingEvent(AppQueueEvent):
+    """
+    QueuePingEvent entity
+    """
+    event = QueueEvent.PING
+
+
+class QueueStopEvent(AppQueueEvent):
+    """
+    QueueStopEvent entity
+    """
+    class StopBy(Enum):
+        """
+        Stop by enum
+        """
+        USER_MANUAL = "user-manual"
+        ANNOTATION_REPLY = "annotation-reply"
+        OUTPUT_MODERATION = "output-moderation"
+
+    event = QueueEvent.STOP
+    stopped_by: StopBy
+
+
+class QueueMessage(BaseModel):
+    """
+    QueueMessage entity
+    """
+    task_id: str
+    message_id: str
+    conversation_id: str
+    app_mode: str
+    event: AppQueueEvent

+ 0 - 0
api/core/model_providers/models/entity/__init__.py → api/core/errors/__init__.py


+ 0 - 20
api/core/model_providers/error.py → api/core/errors/error.py

@@ -14,26 +14,6 @@ class LLMBadRequestError(LLMError):
     description = "Bad Request"
     description = "Bad Request"
 
 
 
 
-class LLMAPIConnectionError(LLMError):
-    """Raised when the LLM returns API connection error."""
-    description = "API Connection Error"
-
-
-class LLMAPIUnavailableError(LLMError):
-    """Raised when the LLM returns API unavailable error."""
-    description = "API Unavailable Error"
-
-
-class LLMRateLimitError(LLMError):
-    """Raised when the LLM returns rate limit error."""
-    description = "Rate Limit Error"
-
-
-class LLMAuthorizationError(LLMError):
-    """Raised when the LLM returns authorization error."""
-    description = "Authorization Error"
-
-
 class ProviderTokenNotInitError(Exception):
 class ProviderTokenNotInitError(Exception):
     """
     """
     Custom exception raised when the provider token is not initialized.
     Custom exception raised when the provider token is not initialized.

+ 0 - 0
api/core/model_providers/models/llm/__init__.py → api/core/external_data_tool/weather_search/__init__.py


+ 35 - 0
api/core/external_data_tool/weather_search/schema.json

@@ -0,0 +1,35 @@
+{
+    "label": {
+        "en-US": "Weather Search",
+        "zh-Hans": "天气查询"
+    },
+    "form_schema": [
+        {
+            "type": "select",
+            "label": {
+                "en-US": "Temperature Unit",
+                "zh-Hans": "温度单位"
+            },
+            "variable": "temperature_unit",
+            "required": true,
+            "options": [
+                {
+                    "label": {
+                        "en-US": "Fahrenheit",
+                        "zh-Hans": "华氏度"
+                    },
+                    "value": "fahrenheit"
+                },
+                {
+                    "label": {
+                        "en-US": "Centigrade",
+                        "zh-Hans": "摄氏度"
+                    },
+                    "value": "centigrade"
+                }
+            ],
+            "default": "centigrade",
+            "placeholder": "Please select temperature unit"
+        }
+    ]
+}

+ 45 - 0
api/core/external_data_tool/weather_search/weather_search.py

@@ -0,0 +1,45 @@
+from typing import Optional
+
+from core.external_data_tool.base import ExternalDataTool
+
+
+class WeatherSearch(ExternalDataTool):
+    """
+    The name of the custom type must be unique and must match the directory and file name.
+    """
+    name: str = "weather_search"
+
+    @classmethod
+    def validate_config(cls, tenant_id: str, config: dict) -> None:
+        """
+        schema.json validation. It will be called when the user saves the config.
+
+        Example:
+            .. code-block:: python
+                config = {
+                    "temperature_unit": "centigrade"
+                }
+
+        :param tenant_id: the id of the workspace
+        :param config: the variables of form config
+        :return:
+        """
+
+        if not config.get('temperature_unit'):
+            raise ValueError('temperature unit is required')
+
+    def query(self, inputs: dict, query: Optional[str] = None) -> str:
+        """
+        Query the external data tool.
+
+        :param inputs: user inputs
+        :param query: the query of chat app
+        :return: the tool query result
+        """
+        city = inputs.get('city')
+        temperature_unit = self.config.get('temperature_unit')
+
+        if temperature_unit == 'fahrenheit':
+            return f'Weather in {city} is 32°F'
+        else:
+            return f'Weather in {city} is 0°C'
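
WeatherSearch doubles as the template for custom external data tools: validate_config is invoked when the config is saved, and query is invoked by the platform at prompt-build time with the app inputs. A sketch of the validation call (the tenant id is a placeholder; instantiation is left to the platform, which wires tenant_id, variable, and config through the ExternalDataTool base class):

    from core.external_data_tool.weather_search.weather_search import WeatherSearch

    # raises ValueError when a required form variable is missing
    WeatherSearch.validate_config(
        tenant_id='tenant-id',
        config={'temperature_unit': 'fahrenheit'}
    )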

+ 0 - 0
api/core/model_providers/models/moderation/__init__.py → api/core/features/__init__.py


+ 325 - 0
api/core/features/agent_runner.py

@@ -0,0 +1,325 @@
+import logging
+from typing import cast, Optional, List
+
+from langchain import WikipediaAPIWrapper
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.tools import BaseTool, WikipediaQueryRun, Tool
+from pydantic import BaseModel, Field
+
+from core.agent.agent.agent_llm_callback import AgentLLMCallback
+from core.agent.agent_executor import PlanningStrategy, AgentConfiguration, AgentExecutor
+from core.application_queue_manager import ApplicationQueueManager
+from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
+from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
+from core.entities.application_entities import ModelConfigEntity, InvokeFrom, \
+    AgentEntity, AgentToolEntity, AppOrchestrationConfigEntity
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.model_entities import ModelFeature, ModelType
+from core.model_runtime.model_providers import model_provider_factory
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.tool.current_datetime_tool import DatetimeTool
+from core.tool.dataset_retriever_tool import DatasetRetrieverTool
+from core.tool.provider.serpapi_provider import SerpAPIToolProvider
+from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
+from core.tool.web_reader_tool import WebReaderTool
+from extensions.ext_database import db
+from models.dataset import Dataset
+from models.model import Message
+
+logger = logging.getLogger(__name__)
+
+
+class AgentRunnerFeature:
+    def __init__(self, tenant_id: str,
+                 app_orchestration_config: AppOrchestrationConfigEntity,
+                 model_config: ModelConfigEntity,
+                 config: AgentEntity,
+                 queue_manager: ApplicationQueueManager,
+                 message: Message,
+                 user_id: str,
+                 agent_llm_callback: AgentLLMCallback,
+                 callback: AgentLoopGatherCallbackHandler,
+                 memory: Optional[TokenBufferMemory] = None,) -> None:
+        """
+        Agent runner
+        :param tenant_id: tenant id
+        :param app_orchestration_config: app orchestration config
+        :param model_config: model config
+        :param config: agent config
+        :param queue_manager: queue manager
+        :param message: message
+        :param user_id: user id
+        :param agent_llm_callback: agent llm callback
+        :param callback: callback
+        :param memory: memory
+        """
+        self.tenant_id = tenant_id
+        self.app_orchestration_config = app_orchestration_config
+        self.model_config = model_config
+        self.config = config
+        self.queue_manager = queue_manager
+        self.message = message
+        self.user_id = user_id
+        self.agent_llm_callback = agent_llm_callback
+        self.callback = callback
+        self.memory = memory
+
+    def run(self, query: str,
+            invoke_from: InvokeFrom) -> Optional[str]:
+        """
+        Retrieve agent loop result.
+        :param query: query
+        :param invoke_from: invoke from
+        :return:
+        """
+        provider = self.config.provider
+        model = self.config.model
+        tool_configs = self.config.tools
+
+        # check model is support tool calling
+        provider_instance = model_provider_factory.get_provider_instance(provider=provider)
+        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        # get model schema
+        model_schema = model_type_instance.get_model_schema(
+            model=model,
+            credentials=self.model_config.credentials
+        )
+
+        if not model_schema:
+            return None
+
+        planning_strategy = PlanningStrategy.REACT
+        features = model_schema.features
+        if features:
+            if ModelFeature.TOOL_CALL in features \
+                    or ModelFeature.MULTI_TOOL_CALL in features:
+                planning_strategy = PlanningStrategy.FUNCTION_CALL
+
+        tools = self.to_tools(
+            tool_configs=tool_configs,
+            invoke_from=invoke_from,
+            callbacks=[self.callback, DifyStdOutCallbackHandler()],
+        )
+
+        if len(tools) == 0:
+            return None
+
+        agent_configuration = AgentConfiguration(
+            strategy=planning_strategy,
+            model_config=self.model_config,
+            tools=tools,
+            memory=self.memory,
+            max_iterations=10,
+            max_execution_time=400.0,
+            early_stopping_method="generate",
+            agent_llm_callback=self.agent_llm_callback,
+            callbacks=[self.callback, DifyStdOutCallbackHandler()]
+        )
+
+        agent_executor = AgentExecutor(agent_configuration)
+
+        try:
+            # check if should use agent
+            should_use_agent = agent_executor.should_use_agent(query)
+            if not should_use_agent:
+                return None
+
+            result = agent_executor.run(query)
+            return result.output
+        except Exception:
+            logger.exception("agent_executor run failed")
+            return None
+
+    def to_tools(self, tool_configs: list[AgentToolEntity],
+                 invoke_from: InvokeFrom,
+                 callbacks: list[BaseCallbackHandler]) \
+            -> Optional[List[BaseTool]]:
+        """
+        Convert tool configs to tools
+        :param tool_configs: tool configs
+        :param invoke_from: invoke from
+        :param callbacks: callbacks
+        """
+        tools = []
+        for tool_config in tool_configs:
+            tool = None
+            if tool_config.tool_id == "dataset":
+                tool = self.to_dataset_retriever_tool(
+                    tool_config=tool_config.config,
+                    invoke_from=invoke_from
+                )
+            elif tool_config.tool_id == "web_reader":
+                tool = self.to_web_reader_tool(
+                    tool_config=tool_config.config,
+                    invoke_from=invoke_from
+                )
+            elif tool_config.tool_id == "google_search":
+                tool = self.to_google_search_tool(
+                    tool_config=tool_config.config,
+                    invoke_from=invoke_from
+                )
+            elif tool_config.tool_id == "wikipedia":
+                tool = self.to_wikipedia_tool(
+                    tool_config=tool_config.config,
+                    invoke_from=invoke_from
+                )
+            elif tool_config.tool_id == "current_datetime":
+                tool = self.to_current_datetime_tool(
+                    tool_config=tool_config.config,
+                    invoke_from=invoke_from
+                )
+
+            if tool:
+                if tool.callbacks is not None:
+                    tool.callbacks.extend(callbacks)
+                else:
+                    tool.callbacks = callbacks
+
+                tools.append(tool)
+
+        return tools
+
+    def to_dataset_retriever_tool(self, tool_config: dict,
+                                  invoke_from: InvokeFrom) \
+            -> Optional[BaseTool]:
+        """
+        A dataset tool is a tool that can be used to retrieve information from a dataset
+        :param tool_config: tool config
+        :param invoke_from: invoke from
+        """
+        show_retrieve_source = self.app_orchestration_config.show_retrieve_source
+
+        hit_callback = DatasetIndexToolCallbackHandler(
+            queue_manager=self.queue_manager,
+            app_id=self.message.app_id,
+            message_id=self.message.id,
+            user_id=self.user_id,
+            invoke_from=invoke_from
+        )
+
+        # get dataset from dataset id
+        dataset = db.session.query(Dataset).filter(
+            Dataset.tenant_id == self.tenant_id,
+            Dataset.id == tool_config.get("id")
+        ).first()
+
+        # skip if the dataset does not exist
+        if not dataset:
+            return None
+
+        # skip if the dataset has no available documents
+        if dataset.available_document_count == 0:
+            return None
+
+        # get retrieval model config
+        default_retrieval_model = {
+            'search_method': 'semantic_search',
+            'reranking_enable': False,
+            'reranking_model': {
+                'reranking_provider_name': '',
+                'reranking_model_name': ''
+            },
+            'top_k': 2,
+            'score_threshold_enabled': False
+        }
+
+        retrieval_model_config = dataset.retrieval_model \
+            if dataset.retrieval_model else default_retrieval_model
+
+        # get top k
+        top_k = retrieval_model_config['top_k']
+
+        # get score threshold
+        score_threshold = None
+        score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
+        if score_threshold_enabled:
+            score_threshold = retrieval_model_config.get("score_threshold")
+
+        tool = DatasetRetrieverTool.from_dataset(
+            dataset=dataset,
+            top_k=top_k,
+            score_threshold=score_threshold,
+            hit_callbacks=[hit_callback],
+            return_resource=show_retrieve_source,
+            retriever_from=invoke_from.to_source()
+        )
+
+        return tool
+
+    def to_web_reader_tool(self, tool_config: dict,
+                           invoke_from: InvokeFrom) -> Optional[BaseTool]:
+        """
+        A tool for reading web pages
+        :param tool_config: tool config
+        :param invoke_from: invoke from
+        :return:
+        """
+        model_parameters = {
+            "temperature": 0,
+            "max_tokens": 500
+        }
+
+        tool = WebReaderTool(
+            model_config=self.model_config,
+            model_parameters=model_parameters,
+            max_chunk_length=4000,
+            continue_reading=True
+        )
+
+        return tool
+
+    def to_google_search_tool(self, tool_config: dict,
+                              invoke_from: InvokeFrom) -> Optional[BaseTool]:
+        """
+        A tool for performing a Google search and extracting snippets and webpages
+        :param tool_config: tool config
+        :param invoke_from: invoke from
+        :return:
+        """
+        tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
+        func_kwargs = tool_provider.credentials_to_func_kwargs()
+        if not func_kwargs:
+            return None
+
+        tool = Tool(
+            name="google_search",
+            description="A tool for performing a Google search and extracting snippets and webpages "
+                        "when you need to search for something you don't know or when your information "
+                        "is not up to date. "
+                        "Input should be a search query.",
+            func=OptimizedSerpAPIWrapper(**func_kwargs).run,
+            args_schema=OptimizedSerpAPIInput
+        )
+
+        return tool
+
+    def to_current_datetime_tool(self, tool_config: dict,
+                                 invoke_from: InvokeFrom) -> Optional[BaseTool]:
+        """
+        A tool for getting the current date and time
+        :param tool_config: tool config
+        :param invoke_from: invoke from
+        :return:
+        """
+        return DatetimeTool()
+
+    def to_wikipedia_tool(self, tool_config: dict,
+                          invoke_from: InvokeFrom) -> Optional[BaseTool]:
+        """
+        A tool for searching Wikipedia
+        :param tool_config: tool config
+        :param invoke_from: invoke from
+        :return:
+        """
+        class WikipediaInput(BaseModel):
+            query: str = Field(..., description="search query.")
+
+        return WikipediaQueryRun(
+            name="wikipedia",
+            api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
+            args_schema=WikipediaInput
+        )
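
For illustration, a minimal sketch of how a caller might drive this feature end to end. The surrounding objects (queue manager, callbacks, memory) are assumed to come from the application manager introduced in this PR, and `app_orchestration_config.agent` as the source of the AgentEntity is an assumption, not part of this diff:

    runner = AgentRunnerFeature(
        tenant_id=tenant_id,
        app_orchestration_config=app_orchestration_config,
        model_config=model_config,
        config=app_orchestration_config.agent,  # assumed source of the AgentEntity
        queue_manager=queue_manager,
        message=message,
        user_id=user_id,
        agent_llm_callback=agent_llm_callback,
        callback=agent_loop_callback,
        memory=memory,
    )

    # Returns the agent loop output, or None when the model schema is
    # unavailable, no tools are configured, or the agent is not needed.
    answer = runner.run(query="What's the weather in Berlin?",
                        invoke_from=InvokeFrom.WEB_APP)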

+ 119 - 0
api/core/features/annotation_reply.py

@@ -0,0 +1,119 @@
+import logging
+from typing import Optional
+
+from flask import current_app
+
+from core.embedding.cached_embedding import CacheEmbedding
+from core.entities.application_entities import InvokeFrom
+from core.index.vector_index.vector_index import VectorIndex
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from extensions.ext_database import db
+from models.dataset import Dataset
+from models.model import App, Message, AppAnnotationSetting, MessageAnnotation
+from services.annotation_service import AppAnnotationService
+from services.dataset_service import DatasetCollectionBindingService
+
+logger = logging.getLogger(__name__)
+
+
+class AnnotationReplyFeature:
+    def query(self, app_record: App,
+              message: Message,
+              query: str,
+              user_id: str,
+              invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
+        """
+        Query app annotations to reply
+        :param app_record: app record
+        :param message: message
+        :param query: query
+        :param user_id: user id
+        :param invoke_from: invoke from
+        :return:
+        """
+        annotation_setting = db.session.query(AppAnnotationSetting).filter(
+            AppAnnotationSetting.app_id == app_record.id).first()
+
+        if not annotation_setting:
+            return None
+
+        collection_binding_detail = annotation_setting.collection_binding_detail
+
+        try:
+            score_threshold = annotation_setting.score_threshold or 1
+            embedding_provider_name = collection_binding_detail.provider_name
+            embedding_model_name = collection_binding_detail.model_name
+
+            model_manager = ModelManager()
+            model_instance = model_manager.get_model_instance(
+                tenant_id=app_record.tenant_id,
+                provider=embedding_provider_name,
+                model_type=ModelType.TEXT_EMBEDDING,
+                model=embedding_model_name
+            )
+
+            # get embedding model
+            embeddings = CacheEmbedding(model_instance)
+
+            dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+                embedding_provider_name,
+                embedding_model_name,
+                'annotation'
+            )
+
+            dataset = Dataset(
+                id=app_record.id,
+                tenant_id=app_record.tenant_id,
+                indexing_technique='high_quality',
+                embedding_model_provider=embedding_provider_name,
+                embedding_model=embedding_model_name,
+                collection_binding_id=dataset_collection_binding.id
+            )
+
+            vector_index = VectorIndex(
+                dataset=dataset,
+                config=current_app.config,
+                embeddings=embeddings,
+                attributes=['doc_id', 'annotation_id', 'app_id']
+            )
+
+            documents = vector_index.search(
+                query=query,
+                search_type='similarity_score_threshold',
+                search_kwargs={
+                    'k': 1,
+                    'score_threshold': score_threshold,
+                    'filter': {
+                        'group_id': [dataset.id]
+                    }
+                }
+            )
+
+            if documents:
+                annotation_id = documents[0].metadata['annotation_id']
+                score = documents[0].metadata['score']
+                annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
+                if annotation:
+                    if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]:
+                        from_source = 'api'
+                    else:
+                        from_source = 'console'
+
+                    # insert annotation history
+                    AppAnnotationService.add_annotation_history(annotation.id,
+                                                                app_record.id,
+                                                                annotation.question,
+                                                                annotation.content,
+                                                                query,
+                                                                user_id,
+                                                                message.id,
+                                                                from_source,
+                                                                score)
+
+                    return annotation
+        except Exception as e:
+            logger.warning(f'Query annotation failed, exception: {str(e)}.')
+            return None
+
+        return None
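
In short, the feature embeds the incoming query, searches the app's annotation collection for the single closest match above the configured score threshold, records the hit in the annotation history, and returns the stored annotation. A hedged usage sketch (the ORM objects are assumed to be loaded by the caller):

    annotation = AnnotationReplyFeature().query(
        app_record=app_record,    # models.model.App
        message=message,          # models.model.Message
        query="How do I reset my password?",
        user_id=user_id,
        invoke_from=InvokeFrom.WEB_APP,
    )

    if annotation:
        # Serve the stored answer instead of invoking the LLM.
        answer = annotation.content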

+ 181 - 0
api/core/features/dataset_retrieval.py

@@ -0,0 +1,181 @@
+from typing import cast, Optional, List
+
+from langchain.tools import BaseTool
+
+from core.agent.agent_executor import PlanningStrategy, AgentConfiguration, AgentExecutor
+from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.entities.application_entities import DatasetEntity, ModelConfigEntity, InvokeFrom, DatasetRetrieveConfigEntity
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.model_entities import ModelFeature
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
+from core.tool.dataset_retriever_tool import DatasetRetrieverTool
+from extensions.ext_database import db
+from models.dataset import Dataset
+
+
+class DatasetRetrievalFeature:
+    def retrieve(self, tenant_id: str,
+                 model_config: ModelConfigEntity,
+                 config: DatasetEntity,
+                 query: str,
+                 invoke_from: InvokeFrom,
+                 show_retrieve_source: bool,
+                 hit_callback: DatasetIndexToolCallbackHandler,
+                 memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
+        """
+        Retrieve dataset.
+        :param tenant_id: tenant id
+        :param model_config: model config
+        :param config: dataset config
+        :param query: query
+        :param invoke_from: invoke from
+        :param show_retrieve_source: show retrieve source
+        :param hit_callback: hit callback
+        :param memory: memory
+        :return:
+        """
+        dataset_ids = config.dataset_ids
+        retrieve_config = config.retrieve_config
+
+        # check whether the model supports tool calling
+        model_type_instance = model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        # get model schema
+        model_schema = model_type_instance.get_model_schema(
+            model=model_config.model,
+            credentials=model_config.credentials
+        )
+
+        if not model_schema:
+            return None
+
+        planning_strategy = PlanningStrategy.REACT_ROUTER
+        features = model_schema.features
+        if features:
+            if ModelFeature.TOOL_CALL in features \
+                    or ModelFeature.MULTI_TOOL_CALL in features:
+                planning_strategy = PlanningStrategy.ROUTER
+
+        dataset_retriever_tools = self.to_dataset_retriever_tool(
+            tenant_id=tenant_id,
+            dataset_ids=dataset_ids,
+            retrieve_config=retrieve_config,
+            return_resource=show_retrieve_source,
+            invoke_from=invoke_from,
+            hit_callback=hit_callback
+        )
+
+        if len(dataset_retriever_tools) == 0:
+            return None
+
+        agent_configuration = AgentConfiguration(
+            strategy=planning_strategy,
+            model_config=model_config,
+            tools=dataset_retriever_tools,
+            memory=memory,
+            max_iterations=10,
+            max_execution_time=400.0,
+            early_stopping_method="generate"
+        )
+
+        agent_executor = AgentExecutor(agent_configuration)
+
+        should_use_agent = agent_executor.should_use_agent(query)
+        if not should_use_agent:
+            return None
+
+        result = agent_executor.run(query)
+
+        return result.output
+
+    def to_dataset_retriever_tool(self, tenant_id: str,
+                                  dataset_ids: list[str],
+                                  retrieve_config: DatasetRetrieveConfigEntity,
+                                  return_resource: bool,
+                                  invoke_from: InvokeFrom,
+                                  hit_callback: DatasetIndexToolCallbackHandler) \
+            -> Optional[List[BaseTool]]:
+        """
+        A dataset tool is a tool that can be used to retrieve information from a dataset
+        :param tenant_id: tenant id
+        :param dataset_ids: dataset ids
+        :param retrieve_config: retrieve config
+        :param return_resource: return resource
+        :param invoke_from: invoke from
+        :param hit_callback: hit callback
+        """
+        tools = []
+        available_datasets = []
+        for dataset_id in dataset_ids:
+            # get dataset from dataset id
+            dataset = db.session.query(Dataset).filter(
+                Dataset.tenant_id == tenant_id,
+                Dataset.id == dataset_id
+            ).first()
+
+            # skip if the dataset does not exist
+            if not dataset:
+                continue
+
+            # skip if the dataset has no available documents
+            if dataset.available_document_count == 0:
+                continue
+
+            available_datasets.append(dataset)
+
+        if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
+            # get retrieval model config
+            default_retrieval_model = {
+                'search_method': 'semantic_search',
+                'reranking_enable': False,
+                'reranking_model': {
+                    'reranking_provider_name': '',
+                    'reranking_model_name': ''
+                },
+                'top_k': 2,
+                'score_threshold_enabled': False
+            }
+
+            for dataset in available_datasets:
+                retrieval_model_config = dataset.retrieval_model \
+                    if dataset.retrieval_model else default_retrieval_model
+
+                # get top k
+                top_k = retrieval_model_config['top_k']
+
+                # get score threshold
+                score_threshold = None
+                score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
+                if score_threshold_enabled:
+                    score_threshold = retrieval_model_config.get("score_threshold")
+
+                tool = DatasetRetrieverTool.from_dataset(
+                    dataset=dataset,
+                    top_k=top_k,
+                    score_threshold=score_threshold,
+                    hit_callbacks=[hit_callback],
+                    return_resource=return_resource,
+                    retriever_from=invoke_from.to_source()
+                )
+
+                tools.append(tool)
+        elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
+            tool = DatasetMultiRetrieverTool.from_dataset(
+                dataset_ids=[dataset.id for dataset in available_datasets],
+                tenant_id=tenant_id,
+                top_k=retrieve_config.top_k or 2,
+                score_threshold=(retrieve_config.score_threshold or 0.5)
+                if retrieve_config.reranking_model.get('score_threshold_enabled', False) else None,
+                hit_callbacks=[hit_callback],
+                return_resource=return_resource,
+                retriever_from=invoke_from.to_source(),
+                reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'),
+                reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name')
+            )
+
+            tools.append(tool)
+
+        return tools
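
The two retrieve strategies differ in shape: SINGLE builds one DatasetRetrieverTool per available dataset and lets the router agent pick among them, while MULTIPLE builds a single DatasetMultiRetrieverTool that queries all datasets and reranks the merged results. A hedged sketch of invoking the feature (the argument objects are assumed to be prepared by the app runner):

    feature = DatasetRetrievalFeature()
    context = feature.retrieve(
        tenant_id=tenant_id,
        model_config=model_config,
        config=dataset_config,   # DatasetEntity with dataset_ids and retrieve_config
        query=query,
        invoke_from=InvokeFrom.SERVICE_API,
        show_retrieve_source=True,
        hit_callback=hit_callback,
        memory=memory,
    )
    # `context` is None when no dataset is usable or the router decides
    # retrieval is unnecessary; otherwise it is the agent's output string.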

+ 96 - 0
api/core/features/external_data_fetch.py

@@ -0,0 +1,96 @@
+import logging
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Tuple, Optional
+
+from flask import current_app, Flask
+
+from core.entities.application_entities import ExternalDataVariableEntity
+from core.external_data_tool.factory import ExternalDataToolFactory
+
+logger = logging.getLogger(__name__)
+
+
+class ExternalDataFetchFeature:
+    def fetch(self, tenant_id: str,
+              app_id: str,
+              external_data_tools: list[ExternalDataVariableEntity],
+              inputs: dict,
+              query: str) -> dict:
+        """
+        Fill in variable inputs from external data tools if they exist.
+
+        :param tenant_id: workspace id
+        :param app_id: app id
+        :param external_data_tools: external data tools configs
+        :param inputs: the inputs
+        :param query: the query
+        :return: the filled inputs
+        """
+        # each external data tool is queried in its own worker thread
+
+        results = {}
+        with ThreadPoolExecutor() as executor:
+            futures = {}
+            for tool in external_data_tools:
+                future = executor.submit(
+                    self._query_external_data_tool,
+                    current_app._get_current_object(),
+                    tenant_id,
+                    app_id,
+                    tool,
+                    inputs,
+                    query
+                )
+
+                futures[future] = tool
+
+            for future in as_completed(futures):
+                tool_variable, result = future.result()
+                results[tool_variable] = result
+
+        inputs.update(results)
+        return inputs
+
+    def _query_external_data_tool(self, flask_app: Flask,
+                                  tenant_id: str,
+                                  app_id: str,
+                                  external_data_tool: ExternalDataVariableEntity,
+                                  inputs: dict,
+                                  query: str) -> Tuple[Optional[str], Optional[str]]:
+        """
+        Query external data tool.
+        :param flask_app: flask app
+        :param tenant_id: tenant id
+        :param app_id: app id
+        :param external_data_tool: external data tool
+        :param inputs: inputs
+        :param query: query
+        :return:
+        """
+        with flask_app.app_context():
+            tool_variable = external_data_tool.variable
+            tool_type = external_data_tool.type
+            tool_config = external_data_tool.config
+
+            external_data_tool_factory = ExternalDataToolFactory(
+                name=tool_type,
+                tenant_id=tenant_id,
+                app_id=app_id,
+                variable=tool_variable,
+                config=tool_config
+            )
+
+            # query external data tool
+            result = external_data_tool_factory.query(
+                inputs=inputs,
+                query=query
+            )
+
+            return tool_variable, result
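
Because each tool call runs in its own thread, fetches are concurrent and every worker re-enters the Flask app context before touching extensions. A hedged sketch of the call site (the tool configs are assumed to come from the app orchestration config; the literal inputs are hypothetical):

    inputs = ExternalDataFetchFeature().fetch(
        tenant_id=tenant_id,
        app_id=app_id,
        external_data_tools=external_data_tools,  # list[ExternalDataVariableEntity]
        inputs={"user_name": "alice"},            # hypothetical form inputs
        query=query,
    )
    # Each tool's `variable` key is now filled in `inputs` and can be
    # substituted into the prompt template.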

+ 32 - 0
api/core/features/hosting_moderation.py

@@ -0,0 +1,32 @@
+import logging
+
+from core.entities.application_entities import ApplicationGenerateEntity
+from core.helper import moderation
+from core.model_runtime.entities.message_entities import PromptMessage
+
+logger = logging.getLogger(__name__)
+
+
+class HostingModerationFeature:
+    def check(self, application_generate_entity: ApplicationGenerateEntity,
+              prompt_messages: list[PromptMessage]) -> bool:
+        """
+        Check hosting moderation
+        :param application_generate_entity: application generate entity
+        :param prompt_messages: prompt messages
+        :return:
+        """
+        app_orchestration_config = application_generate_entity.app_orchestration_config_entity
+        model_config = app_orchestration_config.model_config
+
+        text = ""
+        for prompt_message in prompt_messages:
+            if isinstance(prompt_message.content, str):
+                text += prompt_message.content + "\n"
+
+        moderation_result = moderation.check_moderation(
+            model_config,
+            text
+        )
+
+        return moderation_result
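
A hedged sketch of how a pipeline might gate a hosted-model request on this check; the surrounding error handling is an assumption, not part of this diff:

    if HostingModerationFeature().check(application_generate_entity, prompt_messages):
        # The concatenated prompt text was flagged: stop before invoking
        # the model and let the caller surface a refusal message.
        raise ValueError("Content flagged by hosting moderation.")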

+ 50 - 0
api/core/features/moderation.py

@@ -0,0 +1,50 @@
+import logging
+from typing import Tuple
+
+from core.entities.application_entities import AppOrchestrationConfigEntity
+from core.moderation.base import ModerationAction, ModerationException
+from core.moderation.factory import ModerationFactory
+
+logger = logging.getLogger(__name__)
+
+
+class ModerationFeature:
+    def check(self, app_id: str,
+              tenant_id: str,
+              app_orchestration_config_entity: AppOrchestrationConfigEntity,
+              inputs: dict,
+              query: str) -> Tuple[bool, dict, str]:
+        """
+        Process sensitive_word_avoidance.
+        :param app_id: app id
+        :param tenant_id: tenant id
+        :param app_orchestration_config_entity: app orchestration config entity
+        :param inputs: inputs
+        :param query: query
+        :return:
+        """
+        if not app_orchestration_config_entity.sensitive_word_avoidance:
+            return False, inputs, query
+
+        sensitive_word_avoidance_config = app_orchestration_config_entity.sensitive_word_avoidance
+        moderation_type = sensitive_word_avoidance_config.type
+
+        moderation_factory = ModerationFactory(
+            name=moderation_type,
+            app_id=app_id,
+            tenant_id=tenant_id,
+            config=sensitive_word_avoidance_config.config
+        )
+
+        moderation_result = moderation_factory.moderation_for_inputs(inputs, query)
+
+        if not moderation_result.flagged:
+            return False, inputs, query
+
+        if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
+            raise ModerationException(moderation_result.preset_response)
+        elif moderation_result.action == ModerationAction.OVERRIDED:
+            inputs = moderation_result.inputs
+            query = moderation_result.query
+
+        return True, inputs, query
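
Callers are expected to branch on both the return value and the exception. A hedged sketch of the call site (variable wiring is assumed):

    try:
        flagged, inputs, query = ModerationFeature().check(
            app_id=app_id,
            tenant_id=tenant_id,
            app_orchestration_config_entity=app_orchestration_config,
            inputs=inputs,
            query=query,
        )
    except ModerationException as e:
        # DIRECT_OUTPUT: answer with the preset response instead of generating.
        answer = str(e)
    else:
        # If flagged is True, the OVERRIDED action rewrote `inputs`/`query`,
        # which must be used in place of the originals from here on.
        pass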

+ 5 - 5
api/core/file/file_obj.py

@@ -4,7 +4,7 @@ from typing import Optional
 from pydantic import BaseModel
 
 from core.file.upload_file_parser import UploadFileParser
-from core.model_providers.models.entity.message import PromptMessageFile, ImagePromptMessageFile
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent
 from extensions.ext_database import db
 from models.model import UploadFile
 
@@ -50,14 +50,14 @@ class FileObj(BaseModel):
         return self._get_data(force_url=True)
 
     @property
-    def prompt_message_file(self) -> PromptMessageFile:
+    def prompt_message_content(self) -> ImagePromptMessageContent:
         if self.type == FileType.IMAGE:
             image_config = self.file_config.get('image')
 
-            return ImagePromptMessageFile(
+            return ImagePromptMessageContent(
                 data=self.data,
-                detail=ImagePromptMessageFile.DETAIL.HIGH
-                if image_config.get("detail") == "high" else ImagePromptMessageFile.DETAIL.LOW
+                detail=ImagePromptMessageContent.DETAIL.HIGH
+                if image_config.get("detail") == "high" else ImagePromptMessageContent.DETAIL.LOW
             )
 
     def _get_data(self, force_url: bool = False) -> Optional[str]:

+ 63 - 40
api/core/generator/llm_generator.py

@@ -3,10 +3,10 @@ import logging
 
 from langchain.schema import OutputParserException
 
-from core.model_providers.error import LLMError, ProviderTokenNotInitError
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import PromptMessage, MessageType
-from core.model_providers.models.entity.model_params import ModelKwargs
+from core.model_manager import ModelManager
+from core.model_runtime.entities.message_entities import UserPromptMessage, SystemPromptMessage
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
 
 from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
@@ -26,17 +26,22 @@ class LLMGenerator:
 
         prompt += query + "\n"
 
-        model_instance = ModelFactory.get_text_generation_model(
+        model_manager = ModelManager()
+        model_instance = model_manager.get_default_model_instance(
             tenant_id=tenant_id,
-            model_kwargs=ModelKwargs(
-                temperature=1,
-                max_tokens=100
-            )
+            model_type=ModelType.LLM,
         )
 
-        prompts = [PromptMessage(content=prompt)]
-        response = model_instance.run(prompts)
-        answer = response.content
+        prompts = [UserPromptMessage(content=prompt)]
+        response = model_instance.invoke_llm(
+            prompt_messages=prompts,
+            model_parameters={
+                "max_tokens": 100,
+                "temperature": 1
+            },
+            stream=False
+        )
+        answer = response.message.content
 
         result_dict = json.loads(answer)
         answer = result_dict['Your Output']
@@ -62,22 +67,28 @@ class LLMGenerator:
         })
 
         try:
-            model_instance = ModelFactory.get_text_generation_model(
+            model_manager = ModelManager()
+            model_instance = model_manager.get_default_model_instance(
                 tenant_id=tenant_id,
-                model_kwargs=ModelKwargs(
-                    max_tokens=256,
-                    temperature=0
-                )
+                model_type=ModelType.LLM,
             )
-        except ProviderTokenNotInitError:
+        except InvokeAuthorizationError:
             return []
 
-        prompt_messages = [PromptMessage(content=prompt)]
+        prompt_messages = [UserPromptMessage(content=prompt)]
 
         try:
-            output = model_instance.run(prompt_messages)
-            questions = output_parser.parse(output.content)
-        except LLMError:
+            response = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters={
+                    "max_tokens": 256,
+                    "temperature": 0
+                },
+                stream=False
+            )
+
+            questions = output_parser.parse(response.message.content)
+        except InvokeError:
             questions = []
         except Exception as e:
             logging.exception(e)
@@ -105,20 +116,26 @@ class LLMGenerator:
             remove_template_variables=False
         )
 
-        model_instance = ModelFactory.get_text_generation_model(
+        model_manager = ModelManager()
+        model_instance = model_manager.get_default_model_instance(
             tenant_id=tenant_id,
-            model_kwargs=ModelKwargs(
-                max_tokens=512,
-                temperature=0
-            )
+            model_type=ModelType.LLM,
         )
 
-        prompt_messages = [PromptMessage(content=prompt)]
+        prompt_messages = [UserPromptMessage(content=prompt)]
 
         try:
-            output = model_instance.run(prompt_messages)
-            rule_config = output_parser.parse(output.content)
-        except LLMError as e:
+            response = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters={
+                    "max_tokens": 512,
+                    "temperature": 0
+                },
+                stream=False
+            )
+
+            rule_config = output_parser.parse(response.message.content)
+        except InvokeError as e:
             raise e
         except OutputParserException:
             raise ValueError('Please give a valid input for intended audience or hoping to solve problems.')
@@ -136,18 +153,24 @@ class LLMGenerator:
     def generate_qa_document(cls, tenant_id: str, query, document_language: str):
         prompt = GENERATOR_QA_PROMPT.format(language=document_language)
 
-        model_instance = ModelFactory.get_text_generation_model(
+        model_manager = ModelManager()
+        model_instance = model_manager.get_default_model_instance(
             tenant_id=tenant_id,
-            model_kwargs=ModelKwargs(
-                max_tokens=2000
-            )
+            model_type=ModelType.LLM,
         )
 
-        prompts = [
-            PromptMessage(content=prompt, type=MessageType.SYSTEM),
-            PromptMessage(content=query)
+        prompt_messages = [
+            SystemPromptMessage(content=prompt),
+            UserPromptMessage(content=query)
         ]
 
-        response = model_instance.run(prompts)
-        answer = response.content
+        response = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            model_parameters={
+                "max_tokens": 2000
+            },
+            stream=False
+        )
+
+        answer = response.message.content
         return answer.strip()

+ 14 - 0
api/core/helper/encrypter.py

@@ -18,3 +18,17 @@ def encrypt_token(tenant_id: str, token: str):
 
 def decrypt_token(tenant_id: str, token: str):
     return rsa.decrypt(base64.b64decode(token), tenant_id)
+
+
+def batch_decrypt_token(tenant_id: str, tokens: list[str]):
+    rsa_key, cipher_rsa = rsa.get_decrypt_decoding(tenant_id)
+
+    return [rsa.decrypt_token_with_decoding(base64.b64decode(token), rsa_key, cipher_rsa) for token in tokens]
+
+
+def get_decrypt_decoding(tenant_id: str):
+    return rsa.get_decrypt_decoding(tenant_id)
+
+
+def decrypt_token_with_decoding(token: str, rsa_key, cipher_rsa):
+    return rsa.decrypt_token_with_decoding(base64.b64decode(token), rsa_key, cipher_rsa)
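
The point of the batch variants is to amortize the RSA key load: the tenant's private key is decoded once and reused for every token. A hedged equivalent of batch_decrypt_token written out at a call site (`encrypted_tokens` is a hypothetical list of stored base64 tokens):

    rsa_key, cipher_rsa = get_decrypt_decoding(tenant_id)
    plaintexts = [
        decrypt_token_with_decoding(token, rsa_key, cipher_rsa)
        for token in encrypted_tokens
    ]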

+ 22 - 0
api/core/helper/lru_cache.py

@@ -0,0 +1,22 @@
+from collections import OrderedDict
+from typing import Any
+
+
+class LRUCache:
+    def __init__(self, capacity: int):
+        self.cache = OrderedDict()
+        self.capacity = capacity
+
+    def get(self, key: Any) -> Any:
+        if key not in self.cache:
+            return None
+        else:
+            self.cache.move_to_end(key)  # move the key to the end of the OrderedDict
+            return self.cache[key]
+
+    def put(self, key: Any, value: Any) -> None:
+        if key in self.cache:
+            self.cache.move_to_end(key)
+        self.cache[key] = value
+        if len(self.cache) > self.capacity:
+            self.cache.popitem(last=False)  # pop the first item
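
A short usage sketch of the eviction behaviour:

    cache = LRUCache(capacity=2)
    cache.put("a", 1)
    cache.put("b", 2)
    cache.get("a")       # touches "a", so "b" becomes least recently used
    cache.put("c", 3)    # exceeds capacity and evicts "b"
    assert cache.get("b") is None
    assert cache.get("a") == 1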

+ 30 - 18
api/core/helper/moderation.py

@@ -1,18 +1,27 @@
 import logging
 import random
 
-import openai
-
-from core.model_providers.error import LLMBadRequestError
-from core.model_providers.providers.base import BaseModelProvider
-from core.model_providers.providers.hosted import hosted_config, hosted_model_providers
+from core.entities.application_entities import ModelConfigEntity
+from core.model_runtime.errors.invoke import InvokeBadRequestError
+from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
+from extensions.ext_hosting_provider import hosting_configuration
 from models.provider import ProviderType
 
+logger = logging.getLogger(__name__)
+
+
+def check_moderation(model_config: ModelConfigEntity, text: str) -> bool:
+    moderation_config = hosting_configuration.moderation_config
+    if (moderation_config and moderation_config.enabled is True
+            and 'openai' in hosting_configuration.provider_map
+            and hosting_configuration.provider_map['openai'].enabled is True
+    ):
+        using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type
+        provider_name = model_config.provider
+        if using_provider_type == ProviderType.SYSTEM \
+                and provider_name in moderation_config.providers:
+            hosting_openai_config = hosting_configuration.provider_map['openai']
 
-def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
-    if hosted_config.moderation.enabled is True and hosted_model_providers.openai:
-        if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
-                and model_provider.provider_name in hosted_config.moderation.providers:
             # 2000 text per chunk
             length = 2000
             text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
@@ -23,14 +32,17 @@ def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
             text_chunk = random.choice(text_chunks)
 
             try:
-                moderation_result = openai.Moderation.create(input=text_chunk,
-                                                             api_key=hosted_model_providers.openai.api_key)
+                model_type_instance = OpenAIModerationModel()
+                moderation_result = model_type_instance.invoke(
+                    model='text-moderation-stable',
+                    credentials=hosting_openai_config.credentials,
+                    text=text_chunk
+                )
+
+                if moderation_result is True:
+                    return True
             except Exception as ex:
-                logging.exception(ex)
-                raise LLMBadRequestError('Rate limit exceeded, please try again later.')
-
-            for result in moderation_result.results:
-                if result['flagged'] is True:
-                    return False
+                logger.exception(ex)
+                raise InvokeBadRequestError('Rate limit exceeded, please try again later.')
 
-    return True
+    return False
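
Note the flipped contract: the old helper returned False when content was flagged, while this version returns True on a flagged chunk and False otherwise. A hedged sketch of the new call site (the handler is hypothetical, not part of this diff):

    if check_moderation(model_config, prompt_text):
        # The sampled chunk was flagged by the hosted OpenAI moderation model.
        handle_flagged_content()  # hypothetical caller-defined handler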

+ 213 - 0
api/core/hosting_configuration.py

@@ -0,0 +1,213 @@
+import os
+from typing import Optional
+
+from flask import Flask
+from pydantic import BaseModel
+
+from core.entities.provider_entities import QuotaUnit
+from models.provider import ProviderQuotaType
+
+
+class HostingQuota(BaseModel):
+    quota_type: ProviderQuotaType
+    restrict_llms: list[str] = []
+
+
+class TrialHostingQuota(HostingQuota):
+    quota_type: ProviderQuotaType = ProviderQuotaType.TRIAL
+    quota_limit: int = 0
+    """Quota limit for the hosting provider models. -1 means unlimited."""
+
+
+class PaidHostingQuota(HostingQuota):
+    quota_type: ProviderQuotaType = ProviderQuotaType.PAID
+    stripe_price_id: Optional[str] = None
+    increase_quota: int = 1
+    min_quantity: int = 20
+    max_quantity: int = 100
+
+
+class FreeHostingQuota(HostingQuota):
+    quota_type: ProviderQuotaType = ProviderQuotaType.FREE
+
+
+class HostingProvider(BaseModel):
+    enabled: bool = False
+    credentials: Optional[dict] = None
+    quota_unit: Optional[QuotaUnit] = None
+    quotas: list[HostingQuota] = []
+
+
+class HostedModerationConfig(BaseModel):
+    enabled: bool = False
+    providers: list[str] = []
+
+
+class HostingConfiguration:
+    provider_map: dict[str, HostingProvider] = {}
+    moderation_config: Optional[HostedModerationConfig] = None
+
+    def init_app(self, app: Flask):
+        if app.config.get('EDITION') != 'CLOUD':
+            return
+
+        self.provider_map["openai"] = self.init_openai()
+        self.provider_map["anthropic"] = self.init_anthropic()
+        self.provider_map["minimax"] = self.init_minimax()
+        self.provider_map["spark"] = self.init_spark()
+        self.provider_map["zhipuai"] = self.init_zhipuai()
+
+        self.moderation_config = self.init_moderation_config()
+
+    def init_openai(self) -> HostingProvider:
+        quota_unit = QuotaUnit.TIMES
+        if os.environ.get("HOSTED_OPENAI_ENABLED") and os.environ.get("HOSTED_OPENAI_ENABLED").lower() == 'true':
+            credentials = {
+                "openai_api_key": os.environ.get("HOSTED_OPENAI_API_KEY"),
+            }
+
+            if os.environ.get("HOSTED_OPENAI_API_BASE"):
+                credentials["openai_api_base"] = os.environ.get("HOSTED_OPENAI_API_BASE")
+
+            if os.environ.get("HOSTED_OPENAI_API_ORGANIZATION"):
+                credentials["openai_organization"] = os.environ.get("HOSTED_OPENAI_API_ORGANIZATION")
+
+            quotas = []
+            hosted_quota_limit = int(os.environ.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
+            if hosted_quota_limit == -1 or hosted_quota_limit > 0:
+                trial_quota = TrialHostingQuota(
+                    quota_limit=hosted_quota_limit,
+                    restrict_llms=[
+                        "gpt-3.5-turbo",
+                        "gpt-3.5-turbo-1106",
+                        "gpt-3.5-turbo-instruct",
+                        "gpt-3.5-turbo-16k",
+                        "text-davinci-003"
+                    ]
+                )
+                quotas.append(trial_quota)
+
+            if os.environ.get("HOSTED_OPENAI_PAID_ENABLED") and os.environ.get(
+                    "HOSTED_OPENAI_PAID_ENABLED").lower() == 'true':
+                paid_quota = PaidHostingQuota(
+                    stripe_price_id=os.environ.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"),
+                    increase_quota=int(os.environ.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA", "1")),
+                    min_quantity=int(os.environ.get("HOSTED_OPENAI_PAID_MIN_QUANTITY", "1")),
+                    max_quantity=int(os.environ.get("HOSTED_OPENAI_PAID_MAX_QUANTITY", "1"))
+                )
+                quotas.append(paid_quota)
+
+            return HostingProvider(
+                enabled=True,
+                credentials=credentials,
+                quota_unit=quota_unit,
+                quotas=quotas
+            )
+
+        return HostingProvider(
+            enabled=False,
+            quota_unit=quota_unit,
+        )
+
+    def init_anthropic(self) -> HostingProvider:
+        quota_unit = QuotaUnit.TOKENS
+        if os.environ.get("HOSTED_ANTHROPIC_ENABLED") and os.environ.get("HOSTED_ANTHROPIC_ENABLED").lower() == 'true':
+            credentials = {
+                "anthropic_api_key": os.environ.get("HOSTED_ANTHROPIC_API_KEY"),
+            }
+
+            if os.environ.get("HOSTED_ANTHROPIC_API_BASE"):
+                credentials["anthropic_api_url"] = os.environ.get("HOSTED_ANTHROPIC_API_BASE")
+
+            quotas = []
+            hosted_quota_limit = int(os.environ.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0"))
+            if hosted_quota_limit == -1 or hosted_quota_limit > 0:
+                trial_quota = TrialHostingQuota(
+                    quota_limit=hosted_quota_limit
+                )
+                quotas.append(trial_quota)
+
+            if os.environ.get("HOSTED_ANTHROPIC_PAID_ENABLED") and os.environ.get(
+                    "HOSTED_ANTHROPIC_PAID_ENABLED").lower() == 'true':
+                paid_quota = PaidHostingQuota(
+                    stripe_price_id=os.environ.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
+                    increase_quota=int(os.environ.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA", "1000000")),
+                    min_quantity=int(os.environ.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY", "20")),
+                    max_quantity=int(os.environ.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY", "100"))
+                )
+                quotas.append(paid_quota)
+
+            return HostingProvider(
+                enabled=True,
+                credentials=credentials,
+                quota_unit=quota_unit,
+                quotas=quotas
+            )
+
+        return HostingProvider(
+            enabled=False,
+            quota_unit=quota_unit,
+        )
+
+    def init_minimax(self) -> HostingProvider:
+        quota_unit = QuotaUnit.TOKENS
+        if os.environ.get("HOSTED_MINIMAX_ENABLED") and os.environ.get("HOSTED_MINIMAX_ENABLED").lower() == 'true':
+            quotas = [FreeHostingQuota()]
+
+            return HostingProvider(
+                enabled=True,
+                credentials=None,  # use credentials from the provider
+                quota_unit=quota_unit,
+                quotas=quotas
+            )
+
+        return HostingProvider(
+            enabled=False,
+            quota_unit=quota_unit,
+        )
+
+    def init_spark(self) -> HostingProvider:
+        quota_unit = QuotaUnit.TOKENS
+        if os.environ.get("HOSTED_SPARK_ENABLED") and os.environ.get("HOSTED_SPARK_ENABLED").lower() == 'true':
+            quotas = [FreeHostingQuota()]
+
+            return HostingProvider(
+                enabled=True,
+                credentials=None,  # use credentials from the provider
+                quota_unit=quota_unit,
+                quotas=quotas
+            )
+
+        return HostingProvider(
+            enabled=False,
+            quota_unit=quota_unit,
+        )
+
+    def init_zhipuai(self) -> HostingProvider:
+        quota_unit = QuotaUnit.TOKENS
+        if os.environ.get("HOSTED_ZHIPUAI_ENABLED") and os.environ.get("HOSTED_ZHIPUAI_ENABLED").lower() == 'true':
+            quotas = [FreeHostingQuota()]
+
+            return HostingProvider(
+                enabled=True,
+                credentials=None,  # use credentials from the provider
+                quota_unit=quota_unit,
+                quotas=quotas
+            )
+
+        return HostingProvider(
+            enabled=False,
+            quota_unit=quota_unit,
+        )
+
+    def init_moderation_config(self) -> HostedModerationConfig:
+        if os.environ.get("HOSTED_MODERATION_ENABLED") and os.environ.get("HOSTED_MODERATION_ENABLED").lower() == 'true' \
+                and os.environ.get("HOSTED_MODERATION_PROVIDERS"):
+            return HostedModerationConfig(
+                enabled=True,
+                providers=os.environ.get("HOSTED_MODERATION_PROVIDERS").split(',')
+            )
+
+        return HostedModerationConfig(
+            enabled=False
+        )
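
Providers are switched on purely through environment variables read at startup. A hedged sketch of enabling the hosted OpenAI provider (variable names are taken from init_openai above; the key value is a placeholder, and `app` is assumed to be a Flask app with EDITION set to 'CLOUD'):

    import os

    os.environ["HOSTED_OPENAI_ENABLED"] = "true"
    os.environ["HOSTED_OPENAI_API_KEY"] = "sk-..."
    os.environ["HOSTED_OPENAI_QUOTA_LIMIT"] = "200"

    hosting_configuration = HostingConfiguration()
    hosting_configuration.init_app(app)  # no-op unless EDITION == 'CLOUD'

    openai_hosting = hosting_configuration.provider_map["openai"]
    assert openai_hosting.enabled and openai_hosting.quotas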

+ 7 - 11
api/core/index/index.py

@@ -1,18 +1,12 @@
-import json
-
 from flask import current_app
 from langchain.embeddings import OpenAIEmbeddings
 
 from core.embedding.cached_embedding import CacheEmbedding
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
 from core.index.vector_index.vector_index import VectorIndex
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
-from core.model_providers.models.entity.model_params import ModelKwargs
-from core.model_providers.models.llm.openai_model import OpenAIModel
-from core.model_providers.providers.openai_provider import OpenAIProvider
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
 from models.dataset import Dataset
-from models.provider import Provider, ProviderType
 
 
 class IndexBuilder:
@@ -22,10 +16,12 @@ class IndexBuilder:
             if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
                 return None
 
-            embedding_model = ModelFactory.get_embedding_model(
+            model_manager = ModelManager()
+            embedding_model = model_manager.get_model_instance(
                 tenant_id=dataset.tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
+                model_type=ModelType.TEXT_EMBEDDING,
+                provider=dataset.embedding_model_provider,
+                model=dataset.embedding_model
             )
 
             embeddings = CacheEmbedding(embedding_model)

+ 111 - 41
api/core/indexing_runner.py

@@ -18,9 +18,11 @@ from core.data_loader.loader.notion import NotionLoader
 from core.docstore.dataset_docstore import DatasetDocumentStore
 from core.docstore.dataset_docstore import DatasetDocumentStore
 from core.generator.llm_generator import LLMGenerator
 from core.generator.llm_generator import LLMGenerator
 from core.index.index import IndexBuilder
 from core.index.index import IndexBuilder
-from core.model_providers.error import ProviderTokenNotInitError
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import MessageType
+from core.model_manager import ModelManager
+from core.errors.error import ProviderTokenNotInitError
+from core.model_runtime.entities.model_entities import ModelType, PriceType
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
 from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
 from extensions.ext_database import db
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_redis import redis_client
@@ -36,6 +38,7 @@ class IndexingRunner:
 
 
     def __init__(self):
     def __init__(self):
         self.storage = storage
         self.storage = storage
+        self.model_manager = ModelManager()
 
 
     def run(self, dataset_documents: List[DatasetDocument]):
     def run(self, dataset_documents: List[DatasetDocument]):
         """Run the indexing process."""
         """Run the indexing process."""
@@ -210,7 +213,7 @@ class IndexingRunner:
         """
         """
         Estimate the indexing for the document.
         Estimate the indexing for the document.
         """
         """
-        embedding_model = None
+        embedding_model_instance = None
         if dataset_id:
         if dataset_id:
             dataset = Dataset.query.filter_by(
             dataset = Dataset.query.filter_by(
                 id=dataset_id
                 id=dataset_id
@@ -218,15 +221,17 @@ class IndexingRunner:
             if not dataset:
             if not dataset:
                 raise ValueError('Dataset not found.')
                 raise ValueError('Dataset not found.')
             if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
             if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
-                embedding_model = ModelFactory.get_embedding_model(
-                    tenant_id=dataset.tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                embedding_model_instance = self.model_manager.get_model_instance(
+                    tenant_id=tenant_id,
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                 )
                 )
         else:
         else:
             if indexing_technique == 'high_quality':
             if indexing_technique == 'high_quality':
-                embedding_model = ModelFactory.get_embedding_model(
-                    tenant_id=tenant_id
+                embedding_model_instance = self.model_manager.get_default_model_instance(
+                    tenant_id=tenant_id,
+                    model_type=ModelType.TEXT_EMBEDDING,
                 )
                 )
         tokens = 0
         tokens = 0
         preview_texts = []
         preview_texts = []
@@ -255,32 +260,56 @@ class IndexingRunner:
             for document in documents:
             for document in documents:
                 if len(preview_texts) < 5:
                 if len(preview_texts) < 5:
                     preview_texts.append(document.page_content)
                     preview_texts.append(document.page_content)
-                if indexing_technique == 'high_quality' or embedding_model:
-                    tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
+                if indexing_technique == 'high_quality' or embedding_model_instance:
+                    embedding_model_type_instance = embedding_model_instance.model_type_instance
+                    embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
+                    tokens += embedding_model_type_instance.get_num_tokens(
+                        model=embedding_model_instance.model,
+                        credentials=embedding_model_instance.credentials,
+                        texts=[self.filter_string(document.page_content)]
+                    )
 
 
         if doc_form and doc_form == 'qa_model':
         if doc_form and doc_form == 'qa_model':
-            text_generation_model = ModelFactory.get_text_generation_model(
-                tenant_id=tenant_id
+            model_instance = self.model_manager.get_default_model_instance(
+                tenant_id=tenant_id,
+                model_type=ModelType.LLM
             )
             )
+
+            model_type_instance = model_instance.model_type_instance
+            model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
             if len(preview_texts) > 0:
             if len(preview_texts) > 0:
                 # qa model document
                 # qa model document
                 response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
                 response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
                                                              doc_language)
                                                              doc_language)
                 document_qa_list = self.format_split_text(response)
                 document_qa_list = self.format_split_text(response)
+                price_info = model_type_instance.get_price(
+                    model=model_instance.model,
+                    credentials=model_instance.credentials,
+                    price_type=PriceType.INPUT,
+                    tokens=total_segments * 2000,
+                )
                 return {
                     "total_segments": total_segments * 20,
                     "tokens": total_segments * 2000,
-                    "total_price": '{:f}'.format(
-                        text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
-                    "currency": embedding_model.get_currency(),
+                    "total_price": '{:f}'.format(price_info.total_amount),
+                    "currency": price_info.currency,
                     "qa_preview": document_qa_list,
                     "qa_preview": document_qa_list,
                     "preview": preview_texts
                     "preview": preview_texts
                 }
                 }
+        if embedding_model_instance:
+            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
+            embedding_price_info = embedding_model_type_instance.get_price(
+                model=embedding_model_instance.model,
+                credentials=embedding_model_instance.credentials,
+                price_type=PriceType.INPUT,
+                tokens=tokens
+            )
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
-            "currency": embedding_model.get_currency() if embedding_model else 'USD',
+            "total_price": '{:f}'.format(embedding_price_info.total_amount) if embedding_model_instance else 0,
+            "currency": embedding_price_info.currency if embedding_model_instance else 'USD',
             "preview": preview_texts
             "preview": preview_texts
         }
         }
 
@@ -290,7 +319,7 @@ class IndexingRunner:
         """
         """
         Estimate the indexing for the document.
         Estimate the indexing for the document.
         """
         """
-        embedding_model = None
+        embedding_model_instance = None
         if dataset_id:
             dataset = Dataset.query.filter_by(
                 id=dataset_id
@@ -298,15 +327,17 @@ class IndexingRunner:
             if not dataset:
                 raise ValueError('Dataset not found.')
             if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
-                embedding_model = ModelFactory.get_embedding_model(
-                    tenant_id=dataset.tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                embedding_model_instance = self.model_manager.get_model_instance(
+                    tenant_id=tenant_id,
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                 )
         else:
             if indexing_technique == 'high_quality':
-                embedding_model = ModelFactory.get_embedding_model(
-                    tenant_id=tenant_id
+                embedding_model_instance = self.model_manager.get_default_model_instance(
+                    tenant_id=tenant_id,
+                    model_type=ModelType.TEXT_EMBEDDING
                 )
         # load data from notion
         tokens = 0
@@ -349,35 +380,63 @@ class IndexingRunner:
                     processing_rule=processing_rule
                 )
                 total_segments += len(documents)
+
+                embedding_model_type_instance = None
+                if embedding_model_instance:
+                    embedding_model_type_instance = embedding_model_instance.model_type_instance
+                    embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
+
                 for document in documents:
                     if len(preview_texts) < 5:
                         preview_texts.append(document.page_content)
-                    if indexing_technique == 'high_quality' or embedding_model:
-                        tokens += embedding_model.get_num_tokens(document.page_content)
+                    if indexing_technique == 'high_quality' or embedding_model_instance:
+                        tokens += embedding_model_type_instance.get_num_tokens(
+                            model=embedding_model_instance.model,
+                            credentials=embedding_model_instance.credentials,
+                            texts=[document.page_content]
+                        )
 
         if doc_form and doc_form == 'qa_model':
-            text_generation_model = ModelFactory.get_text_generation_model(
-                tenant_id=tenant_id
+            model_instance = self.model_manager.get_default_model_instance(
+                tenant_id=tenant_id,
+                model_type=ModelType.LLM
             )
+
+            model_type_instance = model_instance.model_type_instance
+            model_type_instance = cast(LargeLanguageModel, model_type_instance)
             if len(preview_texts) > 0:
                 # qa model document
                 response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
                                                              doc_language)
                 document_qa_list = self.format_split_text(response)
+
+                price_info = model_type_instance.get_price(
+                    model=model_instance.model,
+                    credentials=model_instance.credentials,
+                    price_type=PriceType.INPUT,
+                    tokens=total_segments * 2000,
+                )
+
                 return {
                     "total_segments": total_segments * 20,
                     "tokens": total_segments * 2000,
-                    "total_price": '{:f}'.format(
-                        text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
-                    "currency": embedding_model.get_currency(),
+                    "total_price": '{:f}'.format(price_info.total_amount),
+                    "currency": price_info.currency,
                     "qa_preview": document_qa_list,
                     "qa_preview": document_qa_list,
                     "preview": preview_texts
                     "preview": preview_texts
                 }
                 }
+
+        embedding_price_info = None
+        if embedding_model_instance:
+            embedding_model_type_instance = embedding_model_instance.model_type_instance
+            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
+            embedding_price_info = embedding_model_type_instance.get_price(
+                model=embedding_model_instance.model,
+                credentials=embedding_model_instance.credentials,
+                price_type=PriceType.INPUT,
+                tokens=tokens
+            )
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
-            "currency": embedding_model.get_currency() if embedding_model else 'USD',
+            "total_price": '{:f}'.format(embedding_price_info.total_amount) if embedding_model_instance else 0,
+            "currency": embedding_price_info.currency if embedding_model_instance else 'USD',
             "preview": preview_texts
             "preview": preview_texts
         }
         }
 
@@ -656,25 +715,36 @@ class IndexingRunner:
         """
         """
         vector_index = IndexBuilder.get_index(dataset, 'high_quality')
         vector_index = IndexBuilder.get_index(dataset, 'high_quality')
         keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
         keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
-        embedding_model = None
+        embedding_model_instance = None
         if dataset.indexing_technique == 'high_quality':
-            embedding_model = ModelFactory.get_embedding_model(
+            embedding_model_instance = self.model_manager.get_model_instance(
                 tenant_id=dataset.tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
+                provider=dataset.embedding_model_provider,
+                model_type=ModelType.TEXT_EMBEDDING,
+                model=dataset.embedding_model
             )

         # chunk nodes by chunk size
         indexing_start_at = time.perf_counter()
         tokens = 0
         chunk_size = 100
+
+        embedding_model_type_instance = None
+        if embedding_model_instance:
+            embedding_model_type_instance = embedding_model_instance.model_type_instance
+            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
+
         for i in range(0, len(documents), chunk_size):
             # check document is paused
             self._check_document_paused_status(dataset_document.id)
             chunk_documents = documents[i:i + chunk_size]
-            if dataset.indexing_technique == 'high_quality' or embedding_model:
+            if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
                 tokens += sum(
-                    embedding_model.get_num_tokens(document.page_content)
+                    embedding_model_type_instance.get_num_tokens(
+                        embedding_model_instance.model,
+                        embedding_model_instance.credentials,
+                        [document.page_content]
+                    )
                     for document in chunk_documents
                 )
 

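Taken together, the IndexingRunner hunks above swap the removed ModelFactory helpers for the new ModelManager API: token counts now come from TextEmbeddingModel.get_num_tokens and cost estimates from get_price. A minimal sketch of that flow in isolation, assuming a configured tenant (the tenant id below is a placeholder) and assuming PriceType lives in model_entities alongside ModelType:

from typing import cast

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType, PriceType  # PriceType path assumed
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel

model_manager = ModelManager()

# Resolve the tenant's default text-embedding model; raises ProviderTokenNotInitError if unset.
embedding_model_instance = model_manager.get_default_model_instance(
    tenant_id='tenant-id-placeholder',  # hypothetical tenant id
    model_type=ModelType.TEXT_EMBEDDING
)

embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)

# Count tokens without invoking the model.
tokens = embedding_model_type_instance.get_num_tokens(
    model=embedding_model_instance.model,
    credentials=embedding_model_instance.credentials,
    texts=['first chunk', 'second chunk']
)

# Convert the token count into an input-side cost estimate.
price_info = embedding_model_type_instance.get_price(
    model=embedding_model_instance.model,
    credentials=embedding_model_instance.credentials,
    price_type=PriceType.INPUT,
    tokens=tokens
)
print(tokens, price_info.total_amount, price_info.currency)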
+ 0 - 95
api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py

@@ -1,95 +0,0 @@
-from typing import Any, List, Dict
-
-from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import get_buffer_string, BaseMessage
-
-from core.file.message_file_parser import MessageFileParser
-from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages
-from core.model_providers.models.llm.base import BaseLLM
-from extensions.ext_database import db
-from models.model import Conversation, Message
-
-
-class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
-    conversation: Conversation
-    human_prefix: str = "Human"
-    ai_prefix: str = "Assistant"
-    model_instance: BaseLLM
-    memory_key: str = "chat_history"
-    max_token_limit: int = 2000
-    message_limit: int = 10
-
-    @property
-    def buffer(self) -> List[BaseMessage]:
-        """String buffer of memory."""
-        app_model = self.conversation.app
-
-        # fetch limited messages desc, and return reversed
-        messages = db.session.query(Message).filter(
-            Message.conversation_id == self.conversation.id,
-            Message.answer != ''
-        ).order_by(Message.created_at.desc()).limit(self.message_limit).all()
-
-        messages = list(reversed(messages))
-        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=self.conversation.app_id)
-
-        chat_messages: List[PromptMessage] = []
-        for message in messages:
-            files = message.message_files
-            if files:
-                file_objs = message_file_parser.transform_message_files(
-                    files, message.app_model_config
-                )
-
-                prompt_message_files = [file_obj.prompt_message_file for file_obj in file_objs]
-                chat_messages.append(PromptMessage(
-                    content=message.query,
-                    type=MessageType.USER,
-                    files=prompt_message_files
-                ))
-            else:
-                chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
-
-            chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
-
-        if not chat_messages:
-            return []
-
-        # prune the chat message if it exceeds the max token limit
-        curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
-        if curr_buffer_length > self.max_token_limit:
-            pruned_memory = []
-            while curr_buffer_length > self.max_token_limit and chat_messages:
-                pruned_memory.append(chat_messages.pop(0))
-                curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
-
-        return to_lc_messages(chat_messages)
-
-    @property
-    def memory_variables(self) -> List[str]:
-        """Will always return list of memory variables.
-
-        :meta private:
-        """
-        return [self.memory_key]
-
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        """Return history buffer."""
-        buffer: Any = self.buffer
-        if self.return_messages:
-            final_buffer: Any = buffer
-        else:
-            final_buffer = get_buffer_string(
-                buffer,
-                human_prefix=self.human_prefix,
-                ai_prefix=self.ai_prefix,
-            )
-        return {self.memory_key: final_buffer}
-
-    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
-        """Nothing should be saved or changed"""
-        pass
-
-    def clear(self) -> None:
-        """Nothing to clear, got a memory like a vault."""
-        pass

+ 0 - 36
api/core/memory/read_only_conversation_token_db_string_buffer_shared_memory.py

@@ -1,36 +0,0 @@
-from typing import Any, List, Dict
-
-from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import get_buffer_string
-
-from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
-    ReadOnlyConversationTokenDBBufferSharedMemory
-
-
-class ReadOnlyConversationTokenDBStringBufferSharedMemory(BaseChatMemory):
-    memory: ReadOnlyConversationTokenDBBufferSharedMemory
-
-    @property
-    def memory_variables(self) -> List[str]:
-        """Return memory variables."""
-        return self.memory.memory_variables
-
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
-        """Load memory variables from memory."""
-        buffer: Any = self.memory.buffer
-
-        final_buffer = get_buffer_string(
-            buffer,
-            human_prefix=self.memory.human_prefix,
-            ai_prefix=self.memory.ai_prefix,
-        )
-
-        return {self.memory.memory_key: final_buffer}
-
-    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
-        """Nothing should be saved or changed"""
-        pass
-
-    def clear(self) -> None:
-        """Nothing to clear, got a memory like a vault."""
-        pass

+ 109 - 0
api/core/memory/token_buffer_memory.py

@@ -0,0 +1,109 @@
+from core.file.message_file_parser import MessageFileParser
+from core.model_manager import ModelInstance
+from core.model_runtime.entities.message_entities import PromptMessage, TextPromptMessageContent, UserPromptMessage, \
+    AssistantPromptMessage, PromptMessageRole
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.model_providers import model_provider_factory
+from extensions.ext_database import db
+from models.model import Conversation, Message
+
+
+class TokenBufferMemory:
+    def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> None:
+        self.conversation = conversation
+        self.model_instance = model_instance
+
+    def get_history_prompt_messages(self, max_token_limit: int = 2000,
+                                    message_limit: int = 10) -> list[PromptMessage]:
+        """
+        Get history prompt messages.
+        :param max_token_limit: max token limit
+        :param message_limit: message limit
+        """
+        app_record = self.conversation.app
+
+        # fetch limited messages, and return reversed
+        messages = db.session.query(Message).filter(
+            Message.conversation_id == self.conversation.id,
+            Message.answer != ''
+        ).order_by(Message.created_at.desc()).limit(message_limit).all()
+
+        messages = list(reversed(messages))
+        message_file_parser = MessageFileParser(
+            tenant_id=app_record.tenant_id,
+            app_id=app_record.id
+        )
+
+        prompt_messages = []
+        for message in messages:
+            files = message.message_files
+            if files:
+                file_objs = message_file_parser.transform_message_files(
+                    files, message.app_model_config
+                )
+
+                prompt_message_contents = [TextPromptMessageContent(data=message.query)]
+                for file_obj in file_objs:
+                    prompt_message_contents.append(file_obj.prompt_message_content)
+
+                prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+            else:
+                prompt_messages.append(UserPromptMessage(content=message.query))
+
+            prompt_messages.append(AssistantPromptMessage(content=message.answer))
+
+        if not prompt_messages:
+            return []
+
+        # prune the chat message if it exceeds the max token limit
+        provider_instance = model_provider_factory.get_provider_instance(self.model_instance.provider)
+        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
+
+        curr_message_tokens = model_type_instance.get_num_tokens(
+            self.model_instance.model,
+            self.model_instance.credentials,
+            prompt_messages
+        )
+
+        if curr_message_tokens > max_token_limit:
+            pruned_memory = []
+            while curr_message_tokens > max_token_limit and prompt_messages:
+                pruned_memory.append(prompt_messages.pop(0))
+                curr_message_tokens = model_type_instance.get_num_tokens(
+                    self.model_instance.model,
+                    self.model_instance.credentials,
+                    prompt_messages
+                )
+
+        return prompt_messages
+
+    def get_history_prompt_text(self, human_prefix: str = "Human",
+                                ai_prefix: str = "Assistant",
+                                max_token_limit: int = 2000,
+                                message_limit: int = 10) -> str:
+        """
+        Get history prompt text.
+        :param human_prefix: human prefix
+        :param ai_prefix: ai prefix
+        :param max_token_limit: max token limit
+        :param message_limit: message limit
+        :return:
+        """
+        prompt_messages = self.get_history_prompt_messages(
+            max_token_limit=max_token_limit,
+            message_limit=message_limit
+        )
+
+        string_messages = []
+        for m in prompt_messages:
+            if m.role == PromptMessageRole.USER:
+                role = human_prefix
+            elif m.role == PromptMessageRole.ASSISTANT:
+                role = ai_prefix
+            else:
+                continue
+
+            message = f"{role}: {m.content}"
+            string_messages.append(message)
+
+        return "\n".join(string_messages)

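The two LangChain-backed memory classes deleted above are replaced by this single TokenBufferMemory. A hedged usage sketch, assuming an existing Conversation row loaded elsewhere and a default LLM configured for the tenant (the tenant id is a placeholder):

from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType

# `conversation` is assumed: a models.model.Conversation loaded elsewhere.
model_instance = ModelManager().get_default_model_instance(
    tenant_id='tenant-id-placeholder',  # hypothetical tenant id
    model_type=ModelType.LLM
)

memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

# Structured history, pruned from the oldest message to fit the token budget.
prompt_messages = memory.get_history_prompt_messages(max_token_limit=2000, message_limit=10)

# Or a flat "Human:/Assistant:" transcript for completion-style prompts.
history_text = memory.get_history_prompt_text(human_prefix='Human', ai_prefix='Assistant')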
+ 209 - 0
api/core/model_manager.py

@@ -0,0 +1,209 @@
+from typing import Optional, Union, Generator, cast, List, IO
+
+from core.entities.provider_configuration import ProviderModelBundle
+from core.errors.error import ProviderTokenNotInitError
+from core.model_runtime.callbacks.base_callback import Callback
+from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.entities.rerank_entities import RerankResult
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
+from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.provider_manager import ProviderManager
+
+
+class ModelInstance:
+    """
+    Model instance class
+    """
+
+    def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None:
+        self._provider_model_bundle = provider_model_bundle
+        self.model = model
+        self.provider = provider_model_bundle.configuration.provider.provider
+        self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
+        self.model_type_instance = self._provider_model_bundle.model_type_instance
+
+    def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict:
+        """
+        Fetch credentials from provider model bundle
+        :param provider_model_bundle: provider model bundle
+        :param model: model name
+        :return:
+        """
+        credentials = provider_model_bundle.configuration.get_current_credentials(
+            model_type=provider_model_bundle.model_type_instance.model_type,
+            model=model
+        )
+
+        if credentials is None:
+            raise ProviderTokenNotInitError(f"Model {model} credentials are not initialized.")
+
+        return credentials
+
+    def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
+                   tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                   stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
+            -> Union[LLMResult, Generator]:
+        """
+        Invoke large language model
+
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :param callbacks: callbacks
+        :return: full response or stream response chunk generator result
+        """
+        if not isinstance(self.model_type_instance, LargeLanguageModel):
+            raise Exception("Model type instance is not LargeLanguageModel")
+
+        self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
+        return self.model_type_instance.invoke(
+            model=self.model,
+            credentials=self.credentials,
+            prompt_messages=prompt_messages,
+            model_parameters=model_parameters,
+            tools=tools,
+            stop=stop,
+            stream=stream,
+            user=user,
+            callbacks=callbacks
+        )
+
+    def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
+            -> TextEmbeddingResult:
+        """
+        Invoke text embedding model
+
+        :param texts: texts to embed
+        :param user: unique user id
+        :return: embeddings result
+        """
+        if not isinstance(self.model_type_instance, TextEmbeddingModel):
+            raise Exception("Model type instance is not TextEmbeddingModel")
+
+        self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
+        return self.model_type_instance.invoke(
+            model=self.model,
+            credentials=self.credentials,
+            texts=texts,
+            user=user
+        )
+
+    def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
+                      user: Optional[str] = None) \
+            -> RerankResult:
+        """
+        Invoke rerank model
+
+        :param query: search query
+        :param docs: docs for reranking
+        :param score_threshold: score threshold
+        :param top_n: top n
+        :param user: unique user id
+        :return: rerank result
+        """
+        if not isinstance(self.model_type_instance, RerankModel):
+            raise Exception("Model type instance is not RerankModel")
+
+        self.model_type_instance = cast(RerankModel, self.model_type_instance)
+        return self.model_type_instance.invoke(
+            model=self.model,
+            credentials=self.credentials,
+            query=query,
+            docs=docs,
+            score_threshold=score_threshold,
+            top_n=top_n,
+            user=user
+        )
+
+    def invoke_moderation(self, text: str, user: Optional[str] = None) \
+            -> bool:
+        """
+        Invoke moderation model
+
+        :param text: text to moderate
+        :param user: unique user id
+        :return: false if text is safe, true otherwise
+        """
+        if not isinstance(self.model_type_instance, ModerationModel):
+            raise Exception("Model type instance is not ModerationModel")
+
+        self.model_type_instance = cast(ModerationModel, self.model_type_instance)
+        return self.model_type_instance.invoke(
+            model=self.model,
+            credentials=self.credentials,
+            text=text,
+            user=user
+        )
+
+    def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
+            -> str:
+        """
+        Invoke speech-to-text model
+
+        :param file: audio file
+        :param user: unique user id
+        :return: text for given audio file
+        """
+        if not isinstance(self.model_type_instance, Speech2TextModel):
+            raise Exception("Model type instance is not Speech2TextModel")
+
+        self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
+        return self.model_type_instance.invoke(
+            model=self.model,
+            credentials=self.credentials,
+            file=file,
+            user=user
+        )
+
+
+class ModelManager:
+    def __init__(self) -> None:
+        self._provider_manager = ProviderManager()
+
+    def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
+        """
+        Get model instance
+        :param tenant_id: tenant id
+        :param provider: provider name
+        :param model_type: model type
+        :param model: model name
+        :return:
+        """
+        provider_model_bundle = self._provider_manager.get_provider_model_bundle(
+            tenant_id=tenant_id,
+            provider=provider,
+            model_type=model_type
+        )
+
+        return ModelInstance(provider_model_bundle, model)
+
+    def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
+        """
+        Get default model instance
+        :param tenant_id: tenant id
+        :param model_type: model type
+        :return:
+        """
+        default_model_entity = self._provider_manager.get_default_model(
+            tenant_id=tenant_id,
+            model_type=model_type
+        )
+
+        if not default_model_entity:
+            raise ProviderTokenNotInitError(f"Default model not found for {model_type}")
+
+        return self.get_model_instance(
+            tenant_id=tenant_id,
+            provider=default_model_entity.provider.provider,
+            model_type=model_type,
+            model=default_model_entity.model
+        )

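A minimal sketch of the call path this new class enables, using only the signatures defined above; the tenant id is a placeholder, and the shape of the returned LLMResult (an assistant message with content) is assumed from llm_entities:

from core.model_manager import ModelManager
from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType

model_manager = ModelManager()

# Resolve the tenant's default LLM; raises ProviderTokenNotInitError if none is set.
model_instance = model_manager.get_default_model_instance(
    tenant_id='tenant-id-placeholder',  # hypothetical tenant id
    model_type=ModelType.LLM
)

# Blocking invocation returns a full LLMResult; stream=True would yield chunks instead.
result = model_instance.invoke_llm(
    prompt_messages=[UserPromptMessage(content='Hello!')],
    model_parameters={'temperature': 0.7},
    stream=False
)
print(result.message.content)  # assumed LLMResult shape: an assistant message with content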
+ 0 - 335
api/core/model_providers/model_factory.py

@@ -1,335 +0,0 @@
-from typing import Optional
-
-from langchain.callbacks.base import Callbacks
-
-from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
-from core.model_providers.model_provider_factory import ModelProviderFactory, DEFAULT_MODELS
-from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.embedding.base import BaseEmbedding
-from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
-from core.model_providers.models.llm.base import BaseLLM
-from core.model_providers.models.moderation.base import BaseModeration
-from core.model_providers.models.reranking.base import BaseReranking
-from core.model_providers.models.speech2text.base import BaseSpeech2Text
-from extensions.ext_database import db
-from models.provider import TenantDefaultModel
-
-
-class ModelFactory:
-
-    @classmethod
-    def get_text_generation_model_from_model_config(cls, tenant_id: str,
-                                                    model_config: dict,
-                                                    streaming: bool = False,
-                                                    callbacks: Callbacks = None) -> Optional[BaseLLM]:
-        provider_name = model_config.get("provider")
-        model_name = model_config.get("name")
-        completion_params = model_config.get("completion_params", {})
-
-        return cls.get_text_generation_model(
-            tenant_id=tenant_id,
-            model_provider_name=provider_name,
-            model_name=model_name,
-            model_kwargs=ModelKwargs(
-                temperature=completion_params.get('temperature', 0),
-                max_tokens=completion_params.get('max_tokens', 256),
-                top_p=completion_params.get('top_p', 0),
-                frequency_penalty=completion_params.get('frequency_penalty', 0.1),
-                presence_penalty=completion_params.get('presence_penalty', 0.1)
-            ),
-            streaming=streaming,
-            callbacks=callbacks
-        )
-
-    @classmethod
-    def get_text_generation_model(cls,
-                                  tenant_id: str,
-                                  model_provider_name: Optional[str] = None,
-                                  model_name: Optional[str] = None,
-                                  model_kwargs: Optional[ModelKwargs] = None,
-                                  streaming: bool = False,
-                                  callbacks: Callbacks = None,
-                                  deduct_quota: bool = True) -> Optional[BaseLLM]:
-        """
-        get text generation model.
-
-        :param tenant_id: a string representing the ID of the tenant.
-        :param model_provider_name:
-        :param model_name:
-        :param model_kwargs:
-        :param streaming:
-        :param callbacks:
-        :param deduct_quota:
-        :return:
-        """
-        is_default_model = False
-        if model_provider_name is None and model_name is None:
-            default_model = cls.get_default_model(tenant_id, ModelType.TEXT_GENERATION)
-
-            if not default_model:
-                raise LLMBadRequestError(f"Default model is not available. "
-                                         f"Please configure a Default System Reasoning Model "
-                                         f"in the Settings -> Model Provider.")
-
-            model_provider_name = default_model.provider_name
-            model_name = default_model.model_name
-            is_default_model = True
-
-        # get model provider
-        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
-
-        if not model_provider:
-            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
-
-        # init text generation model
-        model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION)
-
-        try:
-            model_instance = model_class(
-                model_provider=model_provider,
-                name=model_name,
-                model_kwargs=model_kwargs,
-                streaming=streaming,
-                callbacks=callbacks
-            )
-        except LLMBadRequestError as e:
-            if is_default_model:
-                raise LLMBadRequestError(f"Default model {model_name} is not available. "
-                                         f"Please check your model provider credentials.")
-            else:
-                raise e
-
-        if is_default_model or not deduct_quota:
-            model_instance.deduct_quota = False
-
-        return model_instance
-
-    @classmethod
-    def get_embedding_model(cls,
-                            tenant_id: str,
-                            model_provider_name: Optional[str] = None,
-                            model_name: Optional[str] = None) -> Optional[BaseEmbedding]:
-        """
-        get embedding model.
-
-        :param tenant_id: a string representing the ID of the tenant.
-        :param model_provider_name:
-        :param model_name:
-        :return:
-        """
-        if model_provider_name is None and model_name is None:
-            default_model = cls.get_default_model(tenant_id, ModelType.EMBEDDINGS)
-
-            if not default_model:
-                raise LLMBadRequestError(f"Default model is not available. "
-                                         f"Please configure a Default Embedding Model "
-                                         f"in the Settings -> Model Provider.")
-
-            model_provider_name = default_model.provider_name
-            model_name = default_model.model_name
-
-        # get model provider
-        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
-
-        if not model_provider:
-            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
-
-        # init embedding model
-        model_class = model_provider.get_model_class(model_type=ModelType.EMBEDDINGS)
-        return model_class(
-            model_provider=model_provider,
-            name=model_name
-        )
-
-
-    @classmethod
-    def get_reranking_model(cls,
-                            tenant_id: str,
-                            model_provider_name: Optional[str] = None,
-                            model_name: Optional[str] = None) -> Optional[BaseReranking]:
-        """
-        get reranking model.
-
-        :param tenant_id: a string representing the ID of the tenant.
-        :param model_provider_name:
-        :param model_name:
-        :return:
-        """
-        if (model_provider_name is None or len(model_provider_name) == 0) and (model_name is None or len(model_name) == 0):
-            default_model = cls.get_default_model(tenant_id, ModelType.RERANKING)
-
-            if not default_model:
-                raise LLMBadRequestError(f"Default model is not available. "
-                                         f"Please configure a Default Reranking Model "
-                                         f"in the Settings -> Model Provider.")
-
-            model_provider_name = default_model.provider_name
-            model_name = default_model.model_name
-
-        # get model provider
-        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
-
-        if not model_provider:
-            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
-
-        # init reranking model
-        model_class = model_provider.get_model_class(model_type=ModelType.RERANKING)
-        return model_class(
-            model_provider=model_provider,
-            name=model_name
-        )
-
-    @classmethod
-    def get_speech2text_model(cls,
-                              tenant_id: str,
-                              model_provider_name: Optional[str] = None,
-                              model_name: Optional[str] = None) -> Optional[BaseSpeech2Text]:
-        """
-        get speech to text model.
-
-        :param tenant_id: a string representing the ID of the tenant.
-        :param model_provider_name:
-        :param model_name:
-        :return:
-        """
-        if model_provider_name is None and model_name is None:
-            default_model = cls.get_default_model(tenant_id, ModelType.SPEECH_TO_TEXT)
-
-            if not default_model:
-                raise LLMBadRequestError(f"Default model is not available. "
-                                         f"Please configure a Default Speech-to-Text Model "
-                                         f"in the Settings -> Model Provider.")
-
-            model_provider_name = default_model.provider_name
-            model_name = default_model.model_name
-
-        # get model provider
-        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
-
-        if not model_provider:
-            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
-
-        # init speech to text model
-        model_class = model_provider.get_model_class(model_type=ModelType.SPEECH_TO_TEXT)
-        return model_class(
-            model_provider=model_provider,
-            name=model_name
-        )
-
-    @classmethod
-    def get_moderation_model(cls,
-                             tenant_id: str,
-                             model_provider_name: str,
-                             model_name: str) -> Optional[BaseModeration]:
-        """
-        get moderation model.
-
-        :param tenant_id: a string representing the ID of the tenant.
-        :param model_provider_name:
-        :param model_name:
-        :return:
-        """
-        # get model provider
-        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
-
-        if not model_provider:
-            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
-
-        # init moderation model
-        model_class = model_provider.get_model_class(model_type=ModelType.MODERATION)
-        return model_class(
-            model_provider=model_provider,
-            name=model_name
-        )
-
-    @classmethod
-    def get_default_model(cls, tenant_id: str, model_type: ModelType) -> TenantDefaultModel:
-        """
-        get default model of model type.
-
-        :param tenant_id:
-        :param model_type:
-        :return:
-        """
-        # get default model
-        default_model = db.session.query(TenantDefaultModel) \
-            .filter(
-            TenantDefaultModel.tenant_id == tenant_id,
-            TenantDefaultModel.model_type == model_type.value
-        ).first()
-
-        if not default_model:
-            model_provider_rules = ModelProviderFactory.get_provider_rules()
-            for model_provider_name, model_provider_rule in model_provider_rules.items():
-                model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
-                if not model_provider:
-                    continue
-
-                model_list = model_provider.get_supported_model_list(model_type)
-                if model_list:
-                    model_info = model_list[0]
-                    default_model = TenantDefaultModel(
-                        tenant_id=tenant_id,
-                        model_type=model_type.value,
-                        provider_name=model_provider_name,
-                        model_name=model_info['id']
-                    )
-                    db.session.add(default_model)
-                    db.session.commit()
-                    break
-
-        return default_model
-
-    @classmethod
-    def update_default_model(cls,
-                             tenant_id: str,
-                             model_type: ModelType,
-                             provider_name: str,
-                             model_name: str) -> TenantDefaultModel:
-        """
-        update default model of model type.
-
-        :param tenant_id:
-        :param model_type:
-        :param provider_name:
-        :param model_name:
-        :return:
-        """
-        model_provider_name = ModelProviderFactory.get_provider_names()
-        if provider_name not in model_provider_name:
-            raise ValueError(f'Invalid provider name: {provider_name}')
-
-        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name)
-
-        if not model_provider:
-            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
-
-        model_list = model_provider.get_supported_model_list(model_type)
-        model_ids = [model['id'] for model in model_list]
-        if model_name not in model_ids:
-            raise ValueError(f'Invalid model name: {model_name}')
-
-        # get default model
-        default_model = db.session.query(TenantDefaultModel) \
-            .filter(
-            TenantDefaultModel.tenant_id == tenant_id,
-            TenantDefaultModel.model_type == model_type.value
-        ).first()
-
-        if default_model:
-            # update default model
-            default_model.provider_name = provider_name
-            default_model.model_name = model_name
-            db.session.commit()
-        else:
-            # create default model
-            default_model = TenantDefaultModel(
-                tenant_id=tenant_id,
-                model_type=model_type.value,
-                provider_name=provider_name,
-                model_name=model_name,
-            )
-            db.session.add(default_model)
-            db.session.commit()
-
-        return default_model

+ 0 - 276
api/core/model_providers/model_provider_factory.py

@@ -1,276 +0,0 @@
-from typing import Type
-
-from sqlalchemy.exc import IntegrityError
-
-from core.model_providers.models.entity.model_params import ModelType
-from core.model_providers.providers.base import BaseModelProvider
-from core.model_providers.rules import provider_rules
-from extensions.ext_database import db
-from models.provider import TenantPreferredModelProvider, ProviderType, Provider, ProviderQuotaType
-
-DEFAULT_MODELS = {
-    ModelType.TEXT_GENERATION.value: {
-        'provider_name': 'openai',
-        'model_name': 'gpt-3.5-turbo',
-    },
-    ModelType.EMBEDDINGS.value: {
-        'provider_name': 'openai',
-        'model_name': 'text-embedding-ada-002',
-    },
-    ModelType.SPEECH_TO_TEXT.value: {
-        'provider_name': 'openai',
-        'model_name': 'whisper-1',
-    }
-}
-
-
-class ModelProviderFactory:
-    @classmethod
-    def get_model_provider_class(cls, provider_name: str) -> Type[BaseModelProvider]:
-        if provider_name == 'openai':
-            from core.model_providers.providers.openai_provider import OpenAIProvider
-            return OpenAIProvider
-        elif provider_name == 'anthropic':
-            from core.model_providers.providers.anthropic_provider import AnthropicProvider
-            return AnthropicProvider
-        elif provider_name == 'minimax':
-            from core.model_providers.providers.minimax_provider import MinimaxProvider
-            return MinimaxProvider
-        elif provider_name == 'spark':
-            from core.model_providers.providers.spark_provider import SparkProvider
-            return SparkProvider
-        elif provider_name == 'tongyi':
-            from core.model_providers.providers.tongyi_provider import TongyiProvider
-            return TongyiProvider
-        elif provider_name == 'wenxin':
-            from core.model_providers.providers.wenxin_provider import WenxinProvider
-            return WenxinProvider
-        elif provider_name == 'zhipuai':
-            from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
-            return ZhipuAIProvider
-        elif provider_name == 'chatglm':
-            from core.model_providers.providers.chatglm_provider import ChatGLMProvider
-            return ChatGLMProvider
-        elif provider_name == 'baichuan':
-            from core.model_providers.providers.baichuan_provider import BaichuanProvider
-            return BaichuanProvider
-        elif provider_name == 'azure_openai':
-            from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
-            return AzureOpenAIProvider
-        elif provider_name == 'replicate':
-            from core.model_providers.providers.replicate_provider import ReplicateProvider
-            return ReplicateProvider
-        elif provider_name == 'huggingface_hub':
-            from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
-            return HuggingfaceHubProvider
-        elif provider_name == 'xinference':
-            from core.model_providers.providers.xinference_provider import XinferenceProvider
-            return XinferenceProvider
-        elif provider_name == 'openllm':
-            from core.model_providers.providers.openllm_provider import OpenLLMProvider
-            return OpenLLMProvider
-        elif provider_name == 'localai':
-            from core.model_providers.providers.localai_provider import LocalAIProvider
-            return LocalAIProvider
-        elif provider_name == 'cohere':
-            from core.model_providers.providers.cohere_provider import CohereProvider
-            return CohereProvider
-        elif provider_name == 'jina':
-            from core.model_providers.providers.jina_provider import JinaProvider
-            return JinaProvider
-        else:
-            raise NotImplementedError
-
-    @classmethod
-    def get_provider_names(cls):
-        """
-        Returns a list of provider names.
-        """
-        return list(provider_rules.keys())
-
-    @classmethod
-    def get_provider_rules(cls):
-        """
-        Returns a list of provider rules.
-
-        :return:
-        """
-        return provider_rules
-
-    @classmethod
-    def get_provider_rule(cls, provider_name: str):
-        """
-        Returns provider rule.
-        """
-        return provider_rules[provider_name]
-
-    @classmethod
-    def get_preferred_model_provider(cls, tenant_id: str, model_provider_name: str):
-        """
-        get preferred model provider.
-
-        :param tenant_id: a string representing the ID of the tenant.
-        :param model_provider_name:
-        :return:
-        """
-        # get preferred provider
-        preferred_provider = cls._get_preferred_provider(tenant_id, model_provider_name)
-        if not preferred_provider or not preferred_provider.is_valid:
-            return None
-
-        # init model provider
-        model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)
-        return model_provider_class(provider=preferred_provider)
-
-    @classmethod
-    def get_preferred_type_by_preferred_model_provider(cls,
-                                                       tenant_id: str,
-                                                       model_provider_name: str,
-                                                       preferred_model_provider: TenantPreferredModelProvider):
-        """
-        get preferred provider type by preferred model provider.
-
-        :param model_provider_name:
-        :param preferred_model_provider:
-        :return:
-        """
-        if not preferred_model_provider:
-            model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
-            support_provider_types = model_provider_rules['support_provider_types']
-
-            if ProviderType.CUSTOM.value in support_provider_types:
-                custom_provider = db.session.query(Provider) \
-                    .filter(
-                        Provider.tenant_id == tenant_id,
-                        Provider.provider_name == model_provider_name,
-                        Provider.provider_type == ProviderType.CUSTOM.value,
-                        Provider.is_valid == True
-                    ).first()
-
-                if custom_provider:
-                    return ProviderType.CUSTOM.value
-
-            model_provider = cls.get_model_provider_class(model_provider_name)
-
-            if ProviderType.SYSTEM.value in support_provider_types \
-                    and model_provider.is_provider_type_system_supported():
-                return ProviderType.SYSTEM.value
-            elif ProviderType.CUSTOM.value in support_provider_types:
-                return ProviderType.CUSTOM.value
-        else:
-            return preferred_model_provider.preferred_provider_type
-
-    @classmethod
-    def _get_preferred_provider(cls, tenant_id: str, model_provider_name: str):
-        """
-        get preferred provider of tenant.
-
-        :param tenant_id:
-        :param model_provider_name:
-        :return:
-        """
-        # get preferred provider type
-        preferred_provider_type = cls._get_preferred_provider_type(tenant_id, model_provider_name)
-
-        # get providers by preferred provider type
-        providers = db.session.query(Provider) \
-            .filter(
-                Provider.tenant_id == tenant_id,
-                Provider.provider_name == model_provider_name,
-                Provider.provider_type == preferred_provider_type
-            ).all()
-
-        no_system_provider = False
-        if preferred_provider_type == ProviderType.SYSTEM.value:
-            quota_type_to_provider_dict = {}
-            for provider in providers:
-                quota_type_to_provider_dict[provider.quota_type] = provider
-
-            model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
-            for quota_type_enum in ProviderQuotaType:
-                quota_type = quota_type_enum.value
-                if quota_type in model_provider_rules['system_config']['supported_quota_types']:
-                    if quota_type in quota_type_to_provider_dict.keys():
-                        provider = quota_type_to_provider_dict[quota_type]
-                        if provider.is_valid and provider.quota_limit > provider.quota_used:
-                            return provider
-                    elif quota_type == ProviderQuotaType.TRIAL.value:
-                        try:
-                            provider = Provider(
-                                tenant_id=tenant_id,
-                                provider_name=model_provider_name,
-                                provider_type=ProviderType.SYSTEM.value,
-                                is_valid=True,
-                                quota_type=ProviderQuotaType.TRIAL.value,
-                                quota_limit=model_provider_rules['system_config']['quota_limit'],
-                                quota_used=0
-                            )
-                            db.session.add(provider)
-                            db.session.commit()
-                        except IntegrityError:
-                            db.session.rollback()
-                            provider = db.session.query(Provider) \
-                                .filter(
-                                Provider.tenant_id == tenant_id,
-                                Provider.provider_name == model_provider_name,
-                                Provider.provider_type == ProviderType.SYSTEM.value,
-                                Provider.quota_type == ProviderQuotaType.TRIAL.value
-                            ).first()
-
-                        if provider.quota_limit == 0:
-                            return None
-
-                        return provider
-
-            no_system_provider = True
-
-        if no_system_provider:
-            providers = db.session.query(Provider) \
-                .filter(
-                Provider.tenant_id == tenant_id,
-                Provider.provider_name == model_provider_name,
-                Provider.provider_type == ProviderType.CUSTOM.value
-            ).all()
-
-        if preferred_provider_type == ProviderType.CUSTOM.value or no_system_provider:
-            if providers:
-                return providers[0]
-            else:
-                try:
-                    provider = Provider(
-                        tenant_id=tenant_id,
-                        provider_name=model_provider_name,
-                        provider_type=ProviderType.CUSTOM.value,
-                        is_valid=False
-                    )
-                    db.session.add(provider)
-                    db.session.commit()
-                except IntegrityError:
-                    db.session.rollback()
-                    provider = db.session.query(Provider) \
-                        .filter(
-                            Provider.tenant_id == tenant_id,
-                            Provider.provider_name == model_provider_name,
-                            Provider.provider_type == ProviderType.CUSTOM.value
-                        ).first()
-
-                return provider
-
-        return None
-
-    @classmethod
-    def _get_preferred_provider_type(cls, tenant_id: str, model_provider_name: str):
-        """
-        get preferred provider type of tenant.
-
-        :param tenant_id:
-        :param model_provider_name:
-        :return:
-        """
-        preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
-            .filter(
-            TenantPreferredModelProvider.tenant_id == tenant_id,
-            TenantPreferredModelProvider.provider_name == model_provider_name
-        ).first()
-
-        return cls.get_preferred_type_by_preferred_model_provider(tenant_id, model_provider_name, preferred_model_provider)

Some files were not shown because too many files changed in this diff