Bläddra i källkod

Model Runtime (#1858)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
takatost 1 år sedan
förälder
incheckning
d069c668f8
100 ändrade filer med 6613 tillägg och 3373 borttagningar
  1. 58 0
      .github/workflows/api-model-runtime-tests.yml
  2. 0 38
      .github/workflows/api-unit-tests.yml
  3. 5 0
      CONTRIBUTING.md
  4. 15 0
      api/.vscode/launch.json
  5. 0 3
      api/Dockerfile
  6. 20 14
      api/app.py
  7. 15 13
      api/commands.py
  8. 1 1
      api/config.py
  9. 1 1
      api/controllers/console/__init__.py
  10. 44 34
      api/controllers/console/app/app.py
  11. 4 4
      api/controllers/console/app/audio.py
  12. 13 14
      api/controllers/console/app/completion.py
  13. 4 4
      api/controllers/console/app/generator.py
  14. 13 9
      api/controllers/console/app/message.py
  15. 1 1
      api/controllers/console/app/model_config.py
  16. 30 18
      api/controllers/console/datasets/datasets.py
  17. 11 6
      api/controllers/console/datasets/datasets_document.py
  18. 18 11
      api/controllers/console/datasets/datasets_segments.py
  19. 1 1
      api/controllers/console/datasets/hit_testing.py
  20. 3 4
      api/controllers/console/explore/audio.py
  21. 11 13
      api/controllers/console/explore/completion.py
  22. 17 12
      api/controllers/console/explore/message.py
  23. 3 4
      api/controllers/console/universal_chat/audio.py
  24. 8 9
      api/controllers/console/universal_chat/chat.py
  25. 3 4
      api/controllers/console/universal_chat/message.py
  26. 96 176
      api/controllers/console/workspace/model_providers.py
  27. 212 56
      api/controllers/console/workspace/models.py
  28. 0 131
      api/controllers/console/workspace/providers.py
  29. 0 1
      api/controllers/console/workspace/workspace.py
  30. 3 4
      api/controllers/service_api/app/audio.py
  31. 11 13
      api/controllers/service_api/app/completion.py
  32. 15 7
      api/controllers/service_api/dataset/dataset.py
  33. 1 1
      api/controllers/service_api/dataset/document.py
  34. 18 11
      api/controllers/service_api/dataset/segment.py
  35. 3 4
      api/controllers/web/audio.py
  36. 11 13
      api/controllers/web/completion.py
  37. 14 9
      api/controllers/web/message.py
  38. 101 0
      api/core/agent/agent/agent_llm_callback.py
  39. 33 12
      api/core/agent/agent/calc_token_mixin.py
  40. 38 16
      api/core/agent/agent/multi_dataset_router_agent.py
  41. 91 34
      api/core/agent/agent/openai_function_call.py
  42. 0 158
      api/core/agent/agent/output_parser/retirver_dataset_agent.py
  43. 18 14
      api/core/agent/agent/structed_multi_dataset_router_agent.py
  44. 22 13
      api/core/agent/agent/structured_chat.py
  45. 33 23
      api/core/agent/agent_executor.py
  46. 0 0
      api/core/app_runner/__init__.py
  47. 251 0
      api/core/app_runner/agent_app_runner.py
  48. 267 0
      api/core/app_runner/app_runner.py
  49. 363 0
      api/core/app_runner/basic_app_runner.py
  50. 483 0
      api/core/app_runner/generate_task_pipeline.py
  51. 138 0
      api/core/app_runner/moderation_handler.py
  52. 655 0
      api/core/application_manager.py
  53. 228 0
      api/core/application_queue_manager.py
  54. 110 65
      api/core/callback_handler/agent_loop_gather_callback_handler.py
  55. 0 74
      api/core/callback_handler/dataset_tool_callback_handler.py
  56. 0 16
      api/core/callback_handler/entity/chain_result.py
  57. 0 6
      api/core/callback_handler/entity/dataset_query.py
  58. 0 8
      api/core/callback_handler/entity/llm_message.py
  59. 56 6
      api/core/callback_handler/index_tool_callback_handler.py
  60. 0 284
      api/core/callback_handler/llm_callback_handler.py
  61. 0 76
      api/core/callback_handler/main_chain_gather_callback_handler.py
  62. 5 2
      api/core/callback_handler/std_out_callback_handler.py
  63. 21 8
      api/core/chain/llm_chain.py
  64. 0 501
      api/core/completion.py
  65. 0 517
      api/core/conversation_message_task.py
  66. 19 6
      api/core/docstore/dataset_docstore.py
  67. 26 13
      api/core/embedding/cached_embedding.py
  68. 0 0
      api/core/entities/__init__.py
  69. 265 0
      api/core/entities/application_entities.py
  70. 128 0
      api/core/entities/message_entities.py
  71. 71 0
      api/core/entities/model_entities.py
  72. 657 0
      api/core/entities/provider_configuration.py
  73. 67 0
      api/core/entities/provider_entities.py
  74. 118 0
      api/core/entities/queue_entities.py
  75. 0 0
      api/core/errors/__init__.py
  76. 0 20
      api/core/errors/error.py
  77. 0 0
      api/core/external_data_tool/weather_search/__init__.py
  78. 35 0
      api/core/external_data_tool/weather_search/schema.json
  79. 45 0
      api/core/external_data_tool/weather_search/weather_search.py
  80. 0 0
      api/core/features/__init__.py
  81. 325 0
      api/core/features/agent_runner.py
  82. 119 0
      api/core/features/annotation_reply.py
  83. 181 0
      api/core/features/dataset_retrieval.py
  84. 96 0
      api/core/features/external_data_fetch.py
  85. 32 0
      api/core/features/hosting_moderation.py
  86. 50 0
      api/core/features/moderation.py
  87. 5 5
      api/core/file/file_obj.py
  88. 63 40
      api/core/generator/llm_generator.py
  89. 14 0
      api/core/helper/encrypter.py
  90. 22 0
      api/core/helper/lru_cache.py
  91. 30 18
      api/core/helper/moderation.py
  92. 213 0
      api/core/hosting_configuration.py
  93. 7 11
      api/core/index/index.py
  94. 111 41
      api/core/indexing_runner.py
  95. 0 95
      api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py
  96. 0 36
      api/core/memory/read_only_conversation_token_db_string_buffer_shared_memory.py
  97. 109 0
      api/core/memory/token_buffer_memory.py
  98. 209 0
      api/core/model_manager.py
  99. 0 335
      api/core/model_providers/model_factory.py
  100. 0 276
      api/core/model_providers/model_provider_factory.py

+ 58 - 0
.github/workflows/api-model-runtime-tests.yml

@@ -0,0 +1,58 @@
+name: Run Pytest
+
+on:
+  pull_request:
+    branches:
+      - main
+  push:
+    branches:
+      - deploy/dev
+      - feat/model-runtime
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+
+    env:
+      OPENAI_API_KEY: sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
+      AZURE_OPENAI_API_BASE: https://difyai-openai.openai.azure.com
+      AZURE_OPENAI_API_KEY: xxxxb1707exxxxxxxxxxaaxxxxxf94
+      ANTHROPIC_API_KEY: sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
+      CHATGLM_API_BASE: http://a.abc.com:11451
+      XINFERENCE_SERVER_URL: http://a.abc.com:11451
+      XINFERENCE_GENERATION_MODEL_UID: generate
+      XINFERENCE_CHAT_MODEL_UID: chat
+      XINFERENCE_EMBEDDINGS_MODEL_UID: embedding
+      XINFERENCE_RERANK_MODEL_UID: rerank
+      GOOGLE_API_KEY: abcdefghijklmnopqrstuvwxyz
+      HUGGINGFACE_API_KEY: hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu
+      HUGGINGFACE_TEXT_GEN_ENDPOINT_URL: a
+      HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL: b
+      HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL: c
+      MOCK_SWITCH: true
+
+
+    steps:
+    - name: Checkout code
+      uses: actions/checkout@v2
+
+    - name: Set up Python
+      uses: actions/setup-python@v2
+      with:
+        python-version: '3.10'
+
+    - name: Cache pip dependencies
+      uses: actions/cache@v2
+      with:
+        path: ~/.cache/pip
+        key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
+        restore-keys: ${{ runner.os }}-pip-
+
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install pytest
+        pip install -r api/requirements.txt
+
+    - name: Run pytest
+      run: pytest api/tests/integration_tests/model_runtime/anthropic api/tests/integration_tests/model_runtime/azure_openai api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py

+ 0 - 38
.github/workflows/api-unit-tests.yml

@@ -1,38 +0,0 @@
-name: Run Pytest
-
-on:
-  pull_request:
-    branches:
-      - main
-  push:
-    branches:
-      - deploy/dev
-
-jobs:
-  test:
-    runs-on: ubuntu-latest
-
-    steps:
-    - name: Checkout code
-      uses: actions/checkout@v2
-
-    - name: Set up Python
-      uses: actions/setup-python@v2
-      with:
-        python-version: '3.10'
-
-    - name: Cache pip dependencies
-      uses: actions/cache@v2
-      with:
-        path: ~/.cache/pip
-        key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
-        restore-keys: ${{ runner.os }}-pip-
-
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        pip install pytest
-        pip install -r api/requirements.txt
-
-    - name: Run pytest
-      run: pytest api/tests/unit_tests

+ 5 - 0
CONTRIBUTING.md

@@ -55,6 +55,11 @@ Did you have an issue, like a merge conflict, or don't know how to open a pull r
 
 Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/j3XRWSPBf7). We are here to help!
 
+
+### Provider Integrations
+If you see a model provider not yet supported by Dify that you'd like to use, follow these [steps](api/core/model_runtime/README.md) to submit a PR.
+
+
 ### i18n (Internationalization) Support
 
 We are looking for contributors to help with translations in other languages. If you are interested in helping, please join the [Discord Community Server](https://discord.gg/AhzKf7dNgk) and let us know.  

+ 15 - 0
api/.vscode/launch.json

@@ -4,6 +4,21 @@
     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
     "version": "0.2.0",
     "configurations": [
+        {
+            "name": "Python: Celery",
+            "type": "python",
+            "request": "launch",
+            "module": "celery",
+            "justMyCode": true,
+            "args": ["-A", "app.celery", "worker", "-P", "gevent", "-c", "1", "--loglevel", "info", "-Q", "dataset,generation,mail"],
+            "envFile": "${workspaceFolder}/.env",
+            "env": {
+                "FLASK_APP": "app.py",
+                "FLASK_DEBUG": "1",
+                "GEVENT_SUPPORT": "True"
+            },
+            "console": "integratedTerminal"
+        },
         {
             "name": "Python: Flask",
             "type": "python",

+ 0 - 3
api/Dockerfile

@@ -34,9 +34,6 @@ RUN apt-get update \
 COPY --from=base /pkg /usr/local
 COPY . /app/api/
 
-RUN python -c "from transformers import GPT2TokenizerFast; GPT2TokenizerFast.from_pretrained('gpt2')"
-ENV TRANSFORMERS_OFFLINE true
-
 COPY docker/entrypoint.sh /entrypoint.sh
 RUN chmod +x /entrypoint.sh
 

+ 20 - 14
api/app.py

@@ -6,9 +6,12 @@ from werkzeug.exceptions import Unauthorized
 if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
     from gevent import monkey
     monkey.patch_all()
-    if os.environ.get("VECTOR_STORE") == 'milvus':
-        import grpc.experimental.gevent
-        grpc.experimental.gevent.init_gevent()
+    # if os.environ.get("VECTOR_STORE") == 'milvus':
+    import grpc.experimental.gevent
+    grpc.experimental.gevent.init_gevent()
+
+    import langchain
+    langchain.verbose = True
 
 import time
 import logging
@@ -18,9 +21,8 @@ import threading
 from flask import Flask, request, Response
 from flask_cors import CORS
 
-from core.model_providers.providers import hosted
 from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
-    ext_database, ext_storage, ext_mail, ext_code_based_extension
+    ext_database, ext_storage, ext_mail, ext_code_based_extension, ext_hosting_provider
 from extensions.ext_database import db
 from extensions.ext_login import login_manager
 
@@ -79,8 +81,6 @@ def create_app(test_config=None) -> Flask:
     register_blueprints(app)
     register_commands(app)
 
-    hosted.init_app(app)
-
     return app
 
 
@@ -95,6 +95,7 @@ def initialize_extensions(app):
     ext_celery.init_app(app)
     ext_login.init_app(app)
     ext_mail.init_app(app)
+    ext_hosting_provider.init_app(app)
     ext_sentry.init_app(app)
 
 
@@ -105,13 +106,18 @@ def load_user_from_request(request_from_flask_login):
     if request.blueprint == 'console':
         # Check if the user_id contains a dot, indicating the old format
         auth_header = request.headers.get('Authorization', '')
-        if ' ' not in auth_header:
-            raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
-        auth_scheme, auth_token = auth_header.split(None, 1)
-        auth_scheme = auth_scheme.lower()
-        if auth_scheme != 'bearer':
-            raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
-        
+        if not auth_header:
+            auth_token = request.args.get('_token')
+            if not auth_token:
+                raise Unauthorized('Invalid Authorization token.')
+        else:
+            if ' ' not in auth_header:
+                raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+            auth_scheme, auth_token = auth_header.split(None, 1)
+            auth_scheme = auth_scheme.lower()
+            if auth_scheme != 'bearer':
+                raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+
         decoded = PassportService().verify(auth_token)
         user_id = decoded.get('user_id')
 

+ 15 - 13
api/commands.py

@@ -12,16 +12,12 @@ import qdrant_client
 from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
 from tqdm import tqdm
 from flask import current_app, Flask
-from langchain.embeddings import OpenAIEmbeddings
 from werkzeug.exceptions import NotFound
 
 from core.embedding.cached_embedding import CacheEmbedding
 from core.index.index import IndexBuilder
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
-from core.model_providers.models.entity.model_params import ModelType
-from core.model_providers.providers.hosted import hosted_model_providers
-from core.model_providers.providers.openai_provider import OpenAIProvider
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
 from libs.password import password_pattern, valid_password, hash_password
 from libs.helper import email as email_validate
 from extensions.ext_database import db
@@ -327,6 +323,8 @@ def create_qdrant_indexes():
         except NotFound:
             break
 
+        model_manager = ModelManager()
+
         page += 1
         for dataset in datasets:
             if dataset.index_struct_dict:
@@ -334,19 +332,23 @@ def create_qdrant_indexes():
                     try:
                         click.echo('Create dataset qdrant index: {}'.format(dataset.id))
                         try:
-                            embedding_model = ModelFactory.get_embedding_model(
+                            embedding_model = model_manager.get_model_instance(
                                 tenant_id=dataset.tenant_id,
-                                model_provider_name=dataset.embedding_model_provider,
-                                model_name=dataset.embedding_model
+                                provider=dataset.embedding_model_provider,
+                                model_type=ModelType.TEXT_EMBEDDING,
+                                model=dataset.embedding_model
+
                             )
                         except Exception:
                             try:
-                                embedding_model = ModelFactory.get_embedding_model(
-                                    tenant_id=dataset.tenant_id
+                                embedding_model = model_manager.get_default_model_instance(
+                                    tenant_id=dataset.tenant_id,
+                                    model_type=ModelType.TEXT_EMBEDDING,
                                 )
-                                dataset.embedding_model = embedding_model.name
-                                dataset.embedding_model_provider = embedding_model.model_provider.provider_name
+                                dataset.embedding_model = embedding_model.model
+                                dataset.embedding_model_provider = embedding_model.provider
                             except Exception:
+
                                 provider = Provider(
                                     id='provider_id',
                                     tenant_id=dataset.tenant_id,

+ 1 - 1
api/config.py

@@ -87,7 +87,7 @@ class Config:
         # ------------------------
         # General Configurations.
         # ------------------------
-        self.CURRENT_VERSION = "0.3.34"
+        self.CURRENT_VERSION = "0.4.0"
         self.COMMIT_SHA = get_env('COMMIT_SHA')
         self.EDITION = "SELF_HOSTED"
         self.DEPLOY_ENV = get_env('DEPLOY_ENV')

+ 1 - 1
api/controllers/console/__init__.py

@@ -18,7 +18,7 @@ from .auth import login, oauth, data_source_oauth, activate
 from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
 
 # Import workspace controllers
-from .workspace import workspace, members, providers, model_providers, account, tool_providers, models
+from .workspace import workspace, members, model_providers, account, tool_providers, models
 
 # Import explore controllers
 from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio

+ 44 - 34
api/controllers/console/app/app.py

@@ -4,6 +4,10 @@ import logging
 from datetime import datetime
 
 from flask_login import current_user
+
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.provider_manager import ProviderManager
 from libs.login import login_required
 from flask_restful import Resource, reqparse, marshal_with, abort, inputs
 from werkzeug.exceptions import Forbidden
@@ -13,9 +17,7 @@ from controllers.console import api
 from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
-from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.model_provider_factory import ModelProviderFactory
+from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError
 from events.app_event import app_was_created, app_was_deleted
 from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \
     app_detail_fields_with_site
@@ -73,39 +75,41 @@ class AppListApi(Resource):
             raise Forbidden()
 
         try:
-            default_model = ModelFactory.get_text_generation_model(
-                tenant_id=current_user.current_tenant_id
+            provider_manager = ProviderManager()
+            default_model_entity = provider_manager.get_default_model(
+                tenant_id=current_user.current_tenant_id,
+                model_type=ModelType.LLM
             )
         except (ProviderTokenNotInitError, LLMBadRequestError):
-            default_model = None
+            default_model_entity = None
         except Exception as e:
             logging.exception(e)
-            default_model = None
+            default_model_entity = None
 
         if args['model_config'] is not None:
             # validate config
             model_config_dict = args['model_config']
 
             # get model provider
-            model_provider = ModelProviderFactory.get_preferred_model_provider(
-                current_user.current_tenant_id,
-                model_config_dict["model"]["provider"]
+            model_manager = ModelManager()
+            model_instance = model_manager.get_default_model_instance(
+                tenant_id=current_user.current_tenant_id,
+                model_type=ModelType.LLM
             )
 
-            if not model_provider:
-                if not default_model:
-                    raise ProviderNotInitializeError(
-                        f"No Default System Reasoning Model available. Please configure "
-                        f"in the Settings -> Model Provider.")
-                else:
-                    model_config_dict["model"]["provider"] = default_model.model_provider.provider_name
-                    model_config_dict["model"]["name"] = default_model.name
+            if not model_instance:
+                raise ProviderNotInitializeError(
+                    f"No Default System Reasoning Model available. Please configure "
+                    f"in the Settings -> Model Provider.")
+            else:
+                model_config_dict["model"]["provider"] = model_instance.provider
+                model_config_dict["model"]["name"] = model_instance.model
 
             model_configuration = AppModelConfigService.validate_configuration(
                 tenant_id=current_user.current_tenant_id,
                 account=current_user,
                 config=model_config_dict,
-                mode=args['mode']
+                app_mode=args['mode']
             )
 
             app = App(
@@ -129,21 +133,27 @@ class AppListApi(Resource):
             app_model_config = AppModelConfig(**model_config_template['model_config'])
 
             # get model provider
-            model_provider = ModelProviderFactory.get_preferred_model_provider(
-                current_user.current_tenant_id,
-                app_model_config.model_dict["provider"]
-            )
-
-            if not model_provider:
-                if not default_model:
-                    raise ProviderNotInitializeError(
-                        f"No Default System Reasoning Model available. Please configure "
-                        f"in the Settings -> Model Provider.")
-                else:
-                    model_dict = app_model_config.model_dict
-                    model_dict['provider'] = default_model.model_provider.provider_name
-                    model_dict['name'] = default_model.name
-                    app_model_config.model = json.dumps(model_dict)
+            model_manager = ModelManager()
+
+            try:
+                model_instance = model_manager.get_default_model_instance(
+                    tenant_id=current_user.current_tenant_id,
+                    model_type=ModelType.LLM
+                )
+            except ProviderTokenNotInitError:
+                raise ProviderNotInitializeError(
+                    f"No Default System Reasoning Model available. Please configure "
+                    f"in the Settings -> Model Provider.")
+
+            if not model_instance:
+                raise ProviderNotInitializeError(
+                    f"No Default System Reasoning Model available. Please configure "
+                    f"in the Settings -> Model Provider.")
+            else:
+                model_dict = app_model_config.model_dict
+                model_dict['provider'] = model_instance.provider
+                model_dict['name'] = model_instance.model
+                app_model_config.model = json.dumps(model_dict)
 
         app.name = args['name']
         app.mode = args['mode']

+ 4 - 4
api/controllers/console/app/audio.py

@@ -2,6 +2,8 @@
 import logging
 
 from flask import request
+
+from core.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
 from werkzeug.exceptions import InternalServerError
 
@@ -14,8 +16,7 @@ from controllers.console.app.error import AppUnavailableError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from flask_restful import Resource
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
@@ -56,8 +57,7 @@ class ChatMessageAudioApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e

+ 13 - 14
api/controllers/console/app/completion.py

@@ -5,6 +5,10 @@ from typing import Generator, Union
 
 import flask_login
 from flask import Response, stream_with_context
+
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import InvokeFrom
+from core.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
 from werkzeug.exceptions import InternalServerError, NotFound
 
@@ -16,9 +20,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
     ProviderModelCurrentlyNotSupportError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.conversation_message_task import PubHandler
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value
 from flask_restful import Resource, reqparse
 
@@ -56,7 +58,7 @@ class CompletionMessageApi(Resource):
                 app_model=app_model,
                 user=account,
                 args=args,
-                from_source='console',
+                invoke_from=InvokeFrom.DEBUGGER,
                 streaming=streaming,
                 is_model_config_override=True
             )
@@ -75,8 +77,7 @@ class CompletionMessageApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -97,7 +98,7 @@ class CompletionMessageStopApi(Resource):
 
         account = flask_login.current_user
 
-        PubHandler.stop(account, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
 
         return {'result': 'success'}, 200
 
@@ -132,7 +133,7 @@ class ChatMessageApi(Resource):
                 app_model=app_model,
                 user=account,
                 args=args,
-                from_source='console',
+                invoke_from=InvokeFrom.DEBUGGER,
                 streaming=streaming,
                 is_model_config_override=True
             )
@@ -151,8 +152,7 @@ class ChatMessageApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -182,9 +182,8 @@ def compact_response(response: Union[dict, Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
                 yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                    LLMRateLimitError, LLMAuthorizationError) as e:
-                yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
+            except InvokeError as e:
+                yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
             except ValueError as e:
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
             except Exception:
@@ -207,7 +206,7 @@ class ChatMessageStopApi(Resource):
 
         account = flask_login.current_user
 
-        PubHandler.stop(account, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
 
         return {'result': 'success'}, 200
 

+ 4 - 4
api/controllers/console/app/generator.py

@@ -1,4 +1,6 @@
 from flask_login import current_user
+
+from core.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
 from flask_restful import Resource, reqparse
 
@@ -8,8 +10,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.generator.llm_generator import LLMGenerator
-from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
-    LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 
 
 class RuleGenerateApi(Resource):
@@ -36,8 +37,7 @@ class RuleGenerateApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
 
         return rules

+ 13 - 9
api/controllers/console/app/message.py

@@ -14,8 +14,9 @@ from controllers.console.app.error import CompletionRequestError, ProviderNotIni
     AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
-from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
-    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
 from fields.conversation_fields import message_detail_fields, annotation_fields
 from libs.helper import uuid_value
@@ -208,7 +209,13 @@ class MessageMoreLikeThisApi(Resource):
         app_model = _get_app(app_id, 'completion')
 
         try:
-            response = CompletionService.generate_more_like_this(app_model, current_user, message_id, streaming)
+            response = CompletionService.generate_more_like_this(
+                app_model=app_model,
+                user=current_user,
+                message_id=message_id,
+                invoke_from=InvokeFrom.DEBUGGER,
+                streaming=streaming
+            )
             return compact_response(response)
         except MessageNotExistsError:
             raise NotFound("Message Not Exists.")
@@ -220,8 +227,7 @@ class MessageMoreLikeThisApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -249,8 +255,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
             except ModelCurrentlyNotSupportError:
                 yield "data: " + json.dumps(
                     api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                    LLMRateLimitError, LLMAuthorizationError) as e:
+            except InvokeError as e:
                 yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
             except ValueError as e:
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
@@ -290,8 +295,7 @@ class MessageSuggestedQuestionApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except Exception:
             logging.exception("internal server error.")

+ 1 - 1
api/controllers/console/app/model_config.py

@@ -31,7 +31,7 @@ class ModelConfigResource(Resource):
             tenant_id=current_user.current_tenant_id,
             account=current_user,
             config=request.json,
-            mode=app.mode
+            app_mode=app.mode
         )
 
         new_app_model_config = AppModelConfig(

+ 30 - 18
api/controllers/console/datasets/datasets.py

@@ -4,6 +4,8 @@ from flask import request, current_app
 from flask_login import current_user
 
 from controllers.console.apikey import api_key_list, api_key_fields
+from core.model_runtime.entities.model_entities import ModelType
+from core.provider_manager import ProviderManager
 from libs.login import login_required
 from flask_restful import Resource, reqparse, marshal, marshal_with
 from werkzeug.exceptions import NotFound, Forbidden
@@ -14,8 +16,7 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.indexing_runner import IndexingRunner
-from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
-from core.model_providers.models.entity.model_params import ModelType
+from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from fields.app_fields import related_app_list
 from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
 from fields.document_fields import document_status_fields
@@ -23,7 +24,6 @@ from extensions.ext_database import db
 from models.dataset import DocumentSegment, Document
 from models.model import UploadFile, ApiToken
 from services.dataset_service import DatasetService, DocumentService
-from services.provider_service import ProviderService
 
 
 def _validate_name(name):
@@ -55,16 +55,20 @@ class DatasetListApi(Resource):
                                                           current_user.current_tenant_id, current_user)
 
         # check embedding setting
-        provider_service = ProviderService()
-        valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
-                                                                 ModelType.EMBEDDINGS.value)
-        # if len(valid_model_list) == 0:
-        #     raise ProviderNotInitializeError(
-        #         f"No Embedding Model available. Please configure a valid provider "
-        #         f"in the Settings -> Model Provider.")
+        provider_manager = ProviderManager()
+        configurations = provider_manager.get_configurations(
+            tenant_id=current_user.current_tenant_id
+        )
+
+        embedding_models = configurations.get_models(
+            model_type=ModelType.TEXT_EMBEDDING,
+            only_active=True
+        )
+
         model_names = []
-        for valid_model in valid_model_list:
-            model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
+        for embedding_model in embedding_models:
+            model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
+
         data = marshal(datasets, dataset_detail_fields)
         for item in data:
             if item['indexing_technique'] == 'high_quality':
@@ -75,6 +79,7 @@ class DatasetListApi(Resource):
                     item['embedding_available'] = False
             else:
                 item['embedding_available'] = True
+
         response = {
             'data': data,
             'has_more': len(datasets) == limit,
@@ -130,13 +135,20 @@ class DatasetApi(Resource):
             raise Forbidden(str(e))
         data = marshal(dataset, dataset_detail_fields)
         # check embedding setting
-        provider_service = ProviderService()
-        # get valid model list
-        valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
-                                                                 ModelType.EMBEDDINGS.value)
+        provider_manager = ProviderManager()
+        configurations = provider_manager.get_configurations(
+            tenant_id=current_user.current_tenant_id
+        )
+
+        embedding_models = configurations.get_models(
+            model_type=ModelType.TEXT_EMBEDDING,
+            only_active=True
+        )
+
         model_names = []
-        for valid_model in valid_model_list:
-            model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
+        for embedding_model in embedding_models:
+            model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
+
         if data['indexing_technique'] == 'high_quality':
             item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
             if item_model in model_names:

+ 11 - 6
api/controllers/console/datasets/datasets_document.py

@@ -2,8 +2,12 @@
 from datetime import datetime
 from typing import List
 
-from flask import request, current_app
+from flask import request
 from flask_login import current_user
+
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from libs.login import login_required
 from flask_restful import Resource, fields, marshal, marshal_with, reqparse
 from sqlalchemy import desc, asc
@@ -18,9 +22,8 @@ from controllers.console.datasets.error import DocumentAlreadyFinishedError, Inv
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
 from core.indexing_runner import IndexingRunner
-from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
     LLMBadRequestError
-from core.model_providers.model_factory import ModelFactory
 from extensions.ext_redis import redis_client
 from fields.document_fields import document_with_segments_fields, document_fields, \
     dataset_and_document_fields, document_status_fields
@@ -272,10 +275,12 @@ class DatasetInitApi(Resource):
         args = parser.parse_args()
         if args['indexing_technique'] == 'high_quality':
             try:
-                ModelFactory.get_embedding_model(
-                    tenant_id=current_user.current_tenant_id
+                model_manager = ModelManager()
+                model_manager.get_default_model_instance(
+                    tenant_id=current_user.current_tenant_id,
+                    model_type=ModelType.TEXT_EMBEDDING
                 )
-            except LLMBadRequestError:
+            except InvokeAuthorizationError:
                 raise ProviderNotInitializeError(
                     f"No Embedding Model available. Please configure a valid provider "
                     f"in the Settings -> Model Provider.")

+ 18 - 11
api/controllers/console/datasets/datasets_segments.py

@@ -12,8 +12,9 @@ from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
-from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
-from core.model_providers.model_factory import ModelFactory
+from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
 from libs.login import login_required
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
@@ -133,10 +134,12 @@ class DatasetDocumentSegmentApi(Resource):
         if dataset.indexing_technique == 'high_quality':
             # check embedding model setting
             try:
-                ModelFactory.get_embedding_model(
+                model_manager = ModelManager()
+                model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
@@ -219,10 +222,12 @@ class DatasetDocumentSegmentAddApi(Resource):
         # check embedding model setting
         if dataset.indexing_technique == 'high_quality':
             try:
-                ModelFactory.get_embedding_model(
+                model_manager = ModelManager()
+                model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
@@ -269,10 +274,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         if dataset.indexing_technique == 'high_quality':
             # check embedding model setting
             try:
-                ModelFactory.get_embedding_model(
+                model_manager = ModelManager()
+                model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(

+ 1 - 1
api/controllers/console/datasets/hit_testing.py

@@ -12,7 +12,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
 from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
     LLMBadRequestError
 from fields.hit_testing_fields import hit_testing_record_fields
 from services.dataset_service import DatasetService

+ 3 - 4
api/controllers/console/explore/audio.py

@@ -11,8 +11,8 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
     NoAudioUploadedError, AudioTooLargeError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.explore.wraps import InstalledAppResource
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
     UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
@@ -53,8 +53,7 @@ class ChatAudioApi(InstalledAppResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e

+ 11 - 13
api/controllers/console/explore/completion.py

@@ -15,9 +15,10 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.explore.error import NotCompletionAppError, NotChatAppError
 from controllers.console.explore.wraps import InstalledAppResource
-from core.conversation_message_task import PubHandler
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from extensions.ext_database import db
 from libs.helper import uuid_value
 from services.completion_service import CompletionService
@@ -50,7 +51,7 @@ class CompletionApi(InstalledAppResource):
                 app_model=app_model,
                 user=current_user,
                 args=args,
-                from_source='console',
+                invoke_from=InvokeFrom.EXPLORE,
                 streaming=streaming
             )
 
@@ -68,8 +69,7 @@ class CompletionApi(InstalledAppResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -84,7 +84,7 @@ class CompletionStopApi(InstalledAppResource):
         if app_model.mode != 'completion':
             raise NotCompletionAppError()
 
-        PubHandler.stop(current_user, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
 
         return {'result': 'success'}, 200
 
@@ -115,7 +115,7 @@ class ChatApi(InstalledAppResource):
                 app_model=app_model,
                 user=current_user,
                 args=args,
-                from_source='console',
+                invoke_from=InvokeFrom.EXPLORE,
                 streaming=streaming
             )
 
@@ -133,8 +133,7 @@ class ChatApi(InstalledAppResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -149,7 +148,7 @@ class ChatStopApi(InstalledAppResource):
         if app_model.mode != 'chat':
             raise NotChatAppError()
 
-        PubHandler.stop(current_user, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
 
         return {'result': 'success'}, 200
 
@@ -175,8 +174,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
                 yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                    LLMRateLimitError, LLMAuthorizationError) as e:
+            except InvokeError as e:
                 yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
             except ValueError as e:
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

+ 17 - 12
api/controllers/console/explore/message.py

@@ -5,7 +5,7 @@ from typing import Generator, Union
 
 from flask import stream_with_context, Response
 from flask_login import current_user
-from flask_restful import reqparse, fields, marshal_with
+from flask_restful import reqparse, marshal_with
 from flask_restful.inputs import int_range
 from werkzeug.exceptions import NotFound, InternalServerError
 
@@ -13,12 +13,14 @@ import services
 from controllers.console import api
 from controllers.console.app.error import AppMoreLikeThisDisabledError, ProviderNotInitializeError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
-from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError
+from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
+    NotChatAppError
 from controllers.console.explore.wraps import InstalledAppResource
-from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
-    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from fields.message_fields import message_infinite_scroll_pagination_fields
-from libs.helper import uuid_value, TimestampField
+from libs.helper import uuid_value
 from services.completion_service import CompletionService
 from services.errors.app import MoreLikeThisDisabledError
 from services.errors.conversation import ConversationNotExistsError
@@ -83,7 +85,13 @@ class MessageMoreLikeThisApi(InstalledAppResource):
         streaming = args['response_mode'] == 'streaming'
 
         try:
-            response = CompletionService.generate_more_like_this(app_model, current_user, message_id, streaming)
+            response = CompletionService.generate_more_like_this(
+                app_model=app_model,
+                user=current_user,
+                message_id=message_id,
+                invoke_from=InvokeFrom.EXPLORE,
+                streaming=streaming
+            )
             return compact_response(response)
         except MessageNotExistsError:
             raise NotFound("Message Not Exists.")
@@ -95,8 +103,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -123,8 +130,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
                 yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                    LLMRateLimitError, LLMAuthorizationError) as e:
+            except InvokeError as e:
                 yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
             except ValueError as e:
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
@@ -162,8 +168,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except Exception:
             logging.exception("internal server error.")

+ 3 - 4
api/controllers/console/universal_chat/audio.py

@@ -11,8 +11,8 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
     NoAudioUploadedError, AudioTooLargeError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
     UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
@@ -53,8 +53,7 @@ class UniversalChatAudioApi(UniversalChatResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e

+ 8 - 9
api/controllers/console/universal_chat/chat.py

@@ -12,9 +12,10 @@ from controllers.console import api
 from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.conversation_message_task import PubHandler
-from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
-    LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService
 
@@ -68,7 +69,7 @@ class UniversalChatApi(UniversalChatResource):
                 app_model=app_model,
                 user=current_user,
                 args=args,
-                from_source='console',
+                invoke_from=InvokeFrom.EXPLORE,
                 streaming=True,
                 is_model_config_override=True,
             )
@@ -87,8 +88,7 @@ class UniversalChatApi(UniversalChatResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -99,7 +99,7 @@ class UniversalChatApi(UniversalChatResource):
 
 class UniversalChatStopApi(UniversalChatResource):
     def post(self, universal_app, task_id):
-        PubHandler.stop(current_user, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
 
         return {'result': 'success'}, 200
 
@@ -125,8 +125,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
                 yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                    LLMRateLimitError, LLMAuthorizationError) as e:
+            except InvokeError as e:
                 yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
             except ValueError as e:
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

+ 3 - 4
api/controllers/console/universal_chat/message.py

@@ -12,8 +12,8 @@ from controllers.console.app.error import ProviderNotInitializeError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
-    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from libs.helper import uuid_value, TimestampField
 from services.errors.conversation import ConversationNotExistsError
 from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
@@ -132,8 +132,7 @@ class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except Exception:
             logging.exception("internal server error.")

+ 96 - 176
api/controllers/console/workspace/model_providers.py

@@ -1,16 +1,19 @@
+import io
+
+from flask import send_file
 from flask_login import current_user
-from libs.login import login_required
 from flask_restful import Resource, reqparse
 from werkzeug.exceptions import Forbidden
 
 from controllers.console import api
-from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.model_providers.error import LLMBadRequestError
-from core.model_providers.providers.base import CredentialsValidateFailedError
-from services.provider_service import ProviderService
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.utils.encoders import jsonable_encoder
+from libs.login import login_required
 from services.billing_service import BillingService
+from services.model_provider_service import ModelProviderService
 
 
 class ModelProviderListApi(Resource):
@@ -22,13 +25,36 @@ class ModelProviderListApi(Resource):
         tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
-        parser.add_argument('model_type', type=str, required=False, nullable=True, location='args')
+        parser.add_argument('model_type', type=str, required=False, nullable=True,
+                            choices=[mt.value for mt in ModelType], location='args')
         args = parser.parse_args()
 
-        provider_service = ProviderService()
-        provider_list = provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get('model_type'))
+        model_provider_service = ModelProviderService()
+        provider_list = model_provider_service.get_provider_list(
+            tenant_id=tenant_id,
+            model_type=args.get('model_type')
+        )
+
+        return jsonable_encoder({"data": provider_list})
+
+
+class ModelProviderCredentialApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider: str):
+        tenant_id = current_user.current_tenant_id
+
+        model_provider_service = ModelProviderService()
+        credentials = model_provider_service.get_provider_credentials(
+            tenant_id=tenant_id,
+            provider=provider
+        )
 
-        return provider_list
+        return {
+            "credentials": credentials
+        }
 
 
 class ModelProviderValidateApi(Resource):
@@ -36,21 +62,24 @@ class ModelProviderValidateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider_name: str):
+    def post(self, provider: str):
 
         parser = reqparse.RequestParser()
-        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
         args = parser.parse_args()
 
-        provider_service = ProviderService()
+        tenant_id = current_user.current_tenant_id
+
+        model_provider_service = ModelProviderService()
 
         result = True
         error = None
 
         try:
-            provider_service.custom_provider_config_validate(
-                provider_name=provider_name,
-                config=args['config']
+            model_provider_service.provider_credentials_validate(
+                tenant_id=tenant_id,
+                provider=provider,
+                credentials=args['credentials']
             )
         except CredentialsValidateFailedError as ex:
             result = False
@@ -64,26 +93,26 @@ class ModelProviderValidateApi(Resource):
         return response
 
 
-class ModelProviderUpdateApi(Resource):
+class ModelProviderApi(Resource):
 
     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider_name: str):
+    def post(self, provider: str):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
         args = parser.parse_args()
 
-        provider_service = ProviderService()
+        model_provider_service = ModelProviderService()
 
         try:
-            provider_service.save_custom_provider_config(
+            model_provider_service.save_provider_credentials(
                 tenant_id=current_user.current_tenant_id,
-                provider_name=provider_name,
-                config=args['config']
+                provider=provider,
+                credentials=args['credentials']
             )
         except CredentialsValidateFailedError as ex:
             raise ValueError(str(ex))
@@ -93,109 +122,36 @@ class ModelProviderUpdateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def delete(self, provider_name: str):
+    def delete(self, provider: str):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
 
-        provider_service = ProviderService()
-        provider_service.delete_custom_provider(
+        model_provider_service = ModelProviderService()
+        model_provider_service.remove_provider_credentials(
             tenant_id=current_user.current_tenant_id,
-            provider_name=provider_name
+            provider=provider
         )
 
         return {'result': 'success'}, 204
 
 
-class ModelProviderModelValidateApi(Resource):
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def post(self, provider_name: str):
-        parser = reqparse.RequestParser()
-        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
-        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
-        args = parser.parse_args()
-
-        provider_service = ProviderService()
-
-        result = True
-        error = None
-
-        try:
-            provider_service.custom_provider_model_config_validate(
-                provider_name=provider_name,
-                model_name=args['model_name'],
-                model_type=args['model_type'],
-                config=args['config']
-            )
-        except CredentialsValidateFailedError as ex:
-            result = False
-            error = str(ex)
-
-        response = {'result': 'success' if result else 'error'}
-
-        if not result:
-            response['error'] = error
-
-        return response
-
-
-class ModelProviderModelUpdateApi(Resource):
+class ModelProviderIconApi(Resource):
+    """
+    Get model provider icon
+    """
 
     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider_name: str):
-        if current_user.current_tenant.current_role not in ['admin', 'owner']:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
-        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
-        args = parser.parse_args()
-
-        provider_service = ProviderService()
-
-        try:
-            provider_service.add_or_save_custom_provider_model_config(
-                tenant_id=current_user.current_tenant_id,
-                provider_name=provider_name,
-                model_name=args['model_name'],
-                model_type=args['model_type'],
-                config=args['config']
-            )
-        except CredentialsValidateFailedError as ex:
-            raise ValueError(str(ex))
-
-        return {'result': 'success'}, 200
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def delete(self, provider_name: str):
-        if current_user.current_tenant.current_role not in ['admin', 'owner']:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
-        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
-        args = parser.parse_args()
-
-        provider_service = ProviderService()
-        provider_service.delete_custom_provider_model(
-            tenant_id=current_user.current_tenant_id,
-            provider_name=provider_name,
-            model_name=args['model_name'],
-            model_type=args['model_type']
+    def get(self, provider: str, icon_type: str, lang: str):
+        model_provider_service = ModelProviderService()
+        icon, mimetype = model_provider_service.get_model_provider_icon(
+            provider=provider,
+            icon_type=icon_type,
+            lang=lang
         )
 
-        return {'result': 'success'}, 204
+        return send_file(io.BytesIO(icon), mimetype=mimetype)
 
 
 class PreferredProviderTypeUpdateApi(Resource):
@@ -203,71 +159,36 @@ class PreferredProviderTypeUpdateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider_name: str):
+    def post(self, provider: str):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
 
+        tenant_id = current_user.current_tenant_id
+
         parser = reqparse.RequestParser()
         parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False,
                             choices=['system', 'custom'], location='json')
         args = parser.parse_args()
 
-        provider_service = ProviderService()
-        provider_service.switch_preferred_provider(
-            tenant_id=current_user.current_tenant_id,
-            provider_name=provider_name,
+        model_provider_service = ModelProviderService()
+        model_provider_service.switch_preferred_provider(
+            tenant_id=tenant_id,
+            provider=provider,
             preferred_provider_type=args['preferred_provider_type']
         )
 
         return {'result': 'success'}
 
 
-class ModelProviderModelParameterRuleApi(Resource):
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def get(self, provider_name: str):
-        parser = reqparse.RequestParser()
-        parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
-        args = parser.parse_args()
-
-        provider_service = ProviderService()
-
-        try:
-            parameter_rules = provider_service.get_model_parameter_rules(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=provider_name,
-                model_name=args['model_name'],
-                model_type='text-generation'
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"Current Text Generation Model is invalid. Please switch to the available model.")
-
-        rules = {
-            k: {
-                'enabled': v.enabled,
-                'min': v.min,
-                'max': v.max,
-                'default': v.default,
-                'precision': v.precision
-            }
-            for k, v in vars(parameter_rules).items()
-        }
-
-        return rules
-
-
 class ModelProviderPaymentCheckoutUrlApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def get(self, provider_name: str):
-        if provider_name != 'anthropic':
-            raise ValueError(f'provider name {provider_name} is invalid')
+    def get(self, provider: str):
+        if provider != 'anthropic':
+            raise ValueError(f'provider name {provider} is invalid')
 
-        data = BillingService.get_model_provider_payment_link(provider_name=provider_name,
+        data = BillingService.get_model_provider_payment_link(provider_name=provider,
                                                               tenant_id=current_user.current_tenant_id,
                                                               account_id=current_user.id)
         return data
@@ -277,11 +198,11 @@ class ModelProviderFreeQuotaSubmitApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider_name: str):
-        provider_service = ProviderService()
-        result = provider_service.free_quota_submit(
+    def post(self, provider: str):
+        model_provider_service = ModelProviderService()
+        result = model_provider_service.free_quota_submit(
             tenant_id=current_user.current_tenant_id,
-            provider_name=provider_name
+            provider=provider
         )
 
         return result
@@ -291,15 +212,15 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def get(self, provider_name: str):
+    def get(self, provider: str):
         parser = reqparse.RequestParser()
         parser.add_argument('token', type=str, required=False, nullable=True, location='args')
         args = parser.parse_args()
 
-        provider_service = ProviderService()
-        result = provider_service.free_quota_qualification_verify(
+        model_provider_service = ModelProviderService()
+        result = model_provider_service.free_quota_qualification_verify(
             tenant_id=current_user.current_tenant_id,
-            provider_name=provider_name,
+            provider=provider,
             token=args['token']
         )
 
@@ -307,19 +228,18 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
 
 
 api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
-api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
-api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
-api.add_resource(ModelProviderModelValidateApi,
-                 '/workspaces/current/model-providers/<string:provider_name>/models/validate')
-api.add_resource(ModelProviderModelUpdateApi,
-                 '/workspaces/current/model-providers/<string:provider_name>/models')
+
+api.add_resource(ModelProviderCredentialApi, '/workspaces/current/model-providers/<string:provider>/credentials')
+api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider>/credentials/validate')
+api.add_resource(ModelProviderApi, '/workspaces/current/model-providers/<string:provider>')
+api.add_resource(ModelProviderIconApi, '/workspaces/current/model-providers/<string:provider>/'
+                                       '<string:icon_type>/<string:lang>')
+
 api.add_resource(PreferredProviderTypeUpdateApi,
-                 '/workspaces/current/model-providers/<string:provider_name>/preferred-provider-type')
-api.add_resource(ModelProviderModelParameterRuleApi,
-                 '/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules')
+                 '/workspaces/current/model-providers/<string:provider>/preferred-provider-type')
 api.add_resource(ModelProviderPaymentCheckoutUrlApi,
-                 '/workspaces/current/model-providers/<string:provider_name>/checkout-url')
+                 '/workspaces/current/model-providers/<string:provider>/checkout-url')
 api.add_resource(ModelProviderFreeQuotaSubmitApi,
-                 '/workspaces/current/model-providers/<string:provider_name>/free-quota-submit')
+                 '/workspaces/current/model-providers/<string:provider>/free-quota-submit')
 api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi,
-                 '/workspaces/current/model-providers/<string:provider_name>/free-quota-qualification-verify')
+                 '/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify')

+ 212 - 56
api/controllers/console/workspace/models.py

@@ -1,16 +1,17 @@
 import logging
 
 from flask_login import current_user
-from libs.login import login_required
-from flask_restful import Resource, reqparse
+from flask_restful import reqparse, Resource
+from werkzeug.exceptions import Forbidden
 
 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.model_providers.model_provider_factory import ModelProviderFactory
-from core.model_providers.models.entity.model_params import ModelType
-from models.provider import ProviderType
-from services.provider_service import ProviderService
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.utils.encoders import jsonable_encoder
+from libs.login import login_required
+from services.model_provider_service import ModelProviderService
 
 
 class DefaultModelApi(Resource):
@@ -21,52 +22,20 @@ class DefaultModelApi(Resource):
     def get(self):
         parser = reqparse.RequestParser()
         parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
+                            choices=[mt.value for mt in ModelType], location='args')
         args = parser.parse_args()
 
         tenant_id = current_user.current_tenant_id
 
-        provider_service = ProviderService()
-        default_model = provider_service.get_default_model_of_model_type(
+        model_provider_service = ModelProviderService()
+        default_model_entity = model_provider_service.get_default_model_of_model_type(
             tenant_id=tenant_id,
             model_type=args['model_type']
         )
 
-        if not default_model:
-            return None
-
-        model_provider = ModelProviderFactory.get_preferred_model_provider(
-            tenant_id,
-            default_model.provider_name
-        )
-
-        if not model_provider:
-            return {
-                'model_name': default_model.model_name,
-                'model_type': default_model.model_type,
-                'model_provider': {
-                    'provider_name': default_model.provider_name
-                }
-            }
-
-        provider = model_provider.provider
-        rst = {
-            'model_name': default_model.model_name,
-            'model_type': default_model.model_type,
-            'model_provider': {
-                'provider_name': provider.provider_name,
-                'provider_type': provider.provider_type
-            }
-        }
-
-        model_provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name)
-        if provider.provider_type == ProviderType.SYSTEM.value:
-            rst['model_provider']['quota_type'] = provider.quota_type
-            rst['model_provider']['quota_unit'] = model_provider_rules['system_config']['quota_unit']
-            rst['model_provider']['quota_limit'] = provider.quota_limit
-            rst['model_provider']['quota_used'] = provider.quota_used
-
-        return rst
+        return jsonable_encoder({
+            "data": default_model_entity
+        })
 
     @setup_required
     @login_required
@@ -76,15 +45,26 @@ class DefaultModelApi(Resource):
         parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
         args = parser.parse_args()
 
-        provider_service = ProviderService()
+        tenant_id = current_user.current_tenant_id
+
+        model_provider_service = ModelProviderService()
         model_settings = args['model_settings']
         for model_setting in model_settings:
+            if 'model_type' not in model_setting or model_setting['model_type'] not in [mt.value for mt in ModelType]:
+                raise ValueError('invalid model type')
+
+            if 'provider' not in model_setting:
+                continue
+
+            if 'model' not in model_setting:
+                raise ValueError('invalid model')
+
             try:
-                provider_service.update_default_model_of_model_type(
-                    tenant_id=current_user.current_tenant_id,
+                model_provider_service.update_default_model_of_model_type(
+                    tenant_id=tenant_id,
                     model_type=model_setting['model_type'],
-                    provider_name=model_setting['provider_name'],
-                    model_name=model_setting['model_name']
+                    provider=model_setting['provider'],
+                    model=model_setting['model']
                 )
             except Exception:
                 logging.warning(f"{model_setting['model_type']} save error")
@@ -92,22 +72,198 @@ class DefaultModelApi(Resource):
         return {'result': 'success'}
 
 
-class ValidModelApi(Resource):
+class ModelProviderModelApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider):
+        tenant_id = current_user.current_tenant_id
+
+        model_provider_service = ModelProviderService()
+        models = model_provider_service.get_models_by_provider(
+            tenant_id=tenant_id,
+            provider=provider
+        )
+
+        return jsonable_encoder({
+            "data": models
+        })
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider: str):
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=[mt.value for mt in ModelType], location='json')
+        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        model_provider_service = ModelProviderService()
+
+        try:
+            model_provider_service.save_model_credentials(
+                tenant_id=tenant_id,
+                provider=provider,
+                model=args['model'],
+                model_type=args['model_type'],
+                credentials=args['credentials']
+            )
+        except CredentialsValidateFailedError as ex:
+            raise ValueError(str(ex))
+
+        return {'result': 'success'}, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, provider: str):
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=[mt.value for mt in ModelType], location='json')
+        args = parser.parse_args()
+
+        model_provider_service = ModelProviderService()
+        model_provider_service.remove_model_credentials(
+            tenant_id=tenant_id,
+            provider=provider,
+            model=args['model'],
+            model_type=args['model_type']
+        )
+
+        return {'result': 'success'}, 204
+
+
+class ModelProviderModelCredentialApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider: str):
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model', type=str, required=True, nullable=False, location='args')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=[mt.value for mt in ModelType], location='args')
+        args = parser.parse_args()
+
+        model_provider_service = ModelProviderService()
+        credentials = model_provider_service.get_model_credentials(
+            tenant_id=tenant_id,
+            provider=provider,
+            model_type=args['model_type'],
+            model=args['model']
+        )
+
+        return {
+            "credentials": credentials
+        }
+
+
+class ModelProviderModelValidateApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider: str):
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=[mt.value for mt in ModelType], location='json')
+        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        model_provider_service = ModelProviderService()
+
+        result = True
+        error = None
+
+        try:
+            model_provider_service.model_credentials_validate(
+                tenant_id=tenant_id,
+                provider=provider,
+                model=args['model'],
+                model_type=args['model_type'],
+                credentials=args['credentials']
+            )
+        except CredentialsValidateFailedError as ex:
+            result = False
+            error = str(ex)
+
+        response = {'result': 'success' if result else 'error'}
+
+        if not result:
+            response['error'] = error
+
+        return response
+
+
+class ModelProviderModelParameterRuleApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider: str):
+        parser = reqparse.RequestParser()
+        parser.add_argument('model', type=str, required=True, nullable=False, location='args')
+        args = parser.parse_args()
+
+        tenant_id = current_user.current_tenant_id
+
+        model_provider_service = ModelProviderService()
+        parameter_rules = model_provider_service.get_model_parameter_rules(
+            tenant_id=tenant_id,
+            provider=provider,
+            model=args['model']
+        )
+
+        return jsonable_encoder({
+            "data": parameter_rules
+        })
+
+
+class ModelProviderAvailableModelApi(Resource):
 
     @setup_required
     @login_required
     @account_initialization_required
     def get(self, model_type):
-        ModelType.value_of(model_type)
+        tenant_id = current_user.current_tenant_id
 
-        provider_service = ProviderService()
-        valid_models = provider_service.get_valid_model_list(
-            tenant_id=current_user.current_tenant_id,
+        model_provider_service = ModelProviderService()
+        models = model_provider_service.get_models_by_model_type(
+            tenant_id=tenant_id,
             model_type=model_type
         )
 
-        return valid_models
+        return jsonable_encoder({
+            "data": models
+        })
+
 
+api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers/<string:provider>/models')
+api.add_resource(ModelProviderModelCredentialApi,
+                 '/workspaces/current/model-providers/<string:provider>/models/credentials')
+api.add_resource(ModelProviderModelValidateApi,
+                 '/workspaces/current/model-providers/<string:provider>/models/credentials/validate')
 
+api.add_resource(ModelProviderModelParameterRuleApi,
+                 '/workspaces/current/model-providers/<string:provider>/models/parameter-rules')
+api.add_resource(ModelProviderAvailableModelApi, '/workspaces/current/models/model-types/<string:model_type>')
 api.add_resource(DefaultModelApi, '/workspaces/current/default-model')
-api.add_resource(ValidModelApi, '/workspaces/current/models/model-type/<string:model_type>')

+ 0 - 131
api/controllers/console/workspace/providers.py

@@ -1,131 +0,0 @@
-# -*- coding:utf-8 -*-
-from flask_login import current_user
-from libs.login import login_required
-from flask_restful import Resource, reqparse
-from werkzeug.exceptions import Forbidden
-
-from controllers.console import api
-from controllers.console.setup import setup_required
-from controllers.console.wraps import account_initialization_required
-from core.model_providers.providers.base import CredentialsValidateFailedError
-from models.provider import ProviderType
-from services.provider_service import ProviderService
-
-
-class ProviderListApi(Resource):
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def get(self):
-        tenant_id = current_user.current_tenant_id
-
-        """
-        If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:, 
-        azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the 
-        rest is replaced by * and the last two bits are displayed in plaintext
-        
-        If the type is other, decode and return the Token field directly, the field displays the first 6 bits in 
-        plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
-        """
-
-        provider_service = ProviderService()
-        provider_info_list = provider_service.get_provider_list(tenant_id)
-
-        provider_list = [
-            {
-                'provider_name': p['provider_name'],
-                'provider_type': p['provider_type'],
-                'is_valid': p['is_valid'],
-                'last_used': p['last_used'],
-                'is_enabled': p['is_valid'],
-                **({
-                       'quota_type': p['quota_type'],
-                       'quota_limit': p['quota_limit'],
-                       'quota_used': p['quota_used']
-                   } if p['provider_type'] == ProviderType.SYSTEM.value else {}),
-                'token': (p['config'] if p['provider_name'] != 'openai' else p['config']['openai_api_key'])
-                        if p['config'] else None
-            }
-            for name, provider_info in provider_info_list.items()
-            for p in provider_info['providers']
-        ]
-
-        return provider_list
-
-
-class ProviderTokenApi(Resource):
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def post(self, provider):
-        # The role of the current user in the ta table must be admin or owner
-        if current_user.current_tenant.current_role not in ['admin', 'owner']:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument('token', required=True, nullable=False, location='json')
-        args = parser.parse_args()
-
-        if provider == 'openai':
-            args['token'] = {
-                'openai_api_key': args['token']
-            }
-
-        provider_service = ProviderService()
-        try:
-            provider_service.save_custom_provider_config(
-                tenant_id=current_user.current_tenant_id,
-                provider_name=provider,
-                config=args['token']
-            )
-        except CredentialsValidateFailedError as ex:
-            raise ValueError(str(ex))
-
-        return {'result': 'success'}, 201
-
-
-class ProviderTokenValidateApi(Resource):
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def post(self, provider):
-        parser = reqparse.RequestParser()
-        parser.add_argument('token', required=True, nullable=False, location='json')
-        args = parser.parse_args()
-
-        provider_service = ProviderService()
-
-        if provider == 'openai':
-            args['token'] = {
-                'openai_api_key': args['token']
-            }
-
-        result = True
-        error = None
-
-        try:
-            provider_service.custom_provider_config_validate(
-                provider_name=provider,
-                config=args['token']
-            )
-        except CredentialsValidateFailedError as ex:
-            result = False
-            error = str(ex)
-
-        response = {'result': 'success' if result else 'error'}
-
-        if not result:
-            response['error'] = error
-
-        return response
-
-
-api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
-                 endpoint='workspaces_current_providers_token')  # PUT for updating provider token
-api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
-                 endpoint='workspaces_current_providers_token_validate')  # POST for validating provider token
-
-api.add_resource(ProviderListApi, '/workspaces/current/providers')  # GET for getting providers list

+ 0 - 1
api/controllers/console/workspace/workspace.py

@@ -34,7 +34,6 @@ tenant_fields = {
     'status': fields.String,
     'created_at': TimestampField,
     'role': fields.String,
-    'providers': fields.List(fields.Nested(provider_fields)),
     'in_trial': fields.Boolean,
     'trial_end_reason': fields.String,
     'custom_config': fields.Raw(attribute='custom_config'),

+ 3 - 4
api/controllers/service_api/app/audio.py

@@ -9,8 +9,8 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
     ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \
     ProviderNotSupportSpeechToTextError
 from controllers.service_api.wraps import AppApiResource
-from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from models.model import App, AppModelConfig
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
@@ -49,8 +49,7 @@ class AudioApi(AppApiResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e

+ 11 - 13
api/controllers/service_api/app/completion.py

@@ -13,9 +13,10 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
     ConversationCompletedError, CompletionRequestError, ProviderQuotaExceededError, \
     ProviderModelCurrentlyNotSupportError
 from controllers.service_api.wraps import AppApiResource
-from core.conversation_message_task import PubHandler
-from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService
 
@@ -47,7 +48,7 @@ class CompletionApi(AppApiResource):
                 app_model=app_model,
                 user=end_user,
                 args=args,
-                from_source='api',
+                invoke_from=InvokeFrom.SERVICE_API,
                 streaming=streaming,
             )
 
@@ -65,8 +66,7 @@ class CompletionApi(AppApiResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -80,7 +80,7 @@ class CompletionStopApi(AppApiResource):
         if app_model.mode != 'completion':
             raise AppUnavailableError()
 
-        PubHandler.stop(end_user, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
 
         return {'result': 'success'}, 200
 
@@ -112,7 +112,7 @@ class ChatApi(AppApiResource):
                 app_model=app_model,
                 user=end_user,
                 args=args,
-                from_source='api',
+                invoke_from=InvokeFrom.SERVICE_API,
                 streaming=streaming
             )
 
@@ -130,8 +130,7 @@ class ChatApi(AppApiResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -145,7 +144,7 @@ class ChatStopApi(AppApiResource):
         if app_model.mode != 'chat':
             raise NotChatAppError()
 
-        PubHandler.stop(end_user, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
 
         return {'result': 'success'}, 200
 
@@ -171,8 +170,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
                 yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                    LLMRateLimitError, LLMAuthorizationError) as e:
+            except InvokeError as e:
                 yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
             except ValueError as e:
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

+ 15 - 7
api/controllers/service_api/dataset/dataset.py

@@ -4,11 +4,11 @@ import services.dataset_service
 from controllers.service_api import api
 from controllers.service_api.dataset.error import DatasetNameDuplicateError
 from controllers.service_api.wraps import DatasetApiResource
+from core.model_runtime.entities.model_entities import ModelType
+from core.provider_manager import ProviderManager
 from libs.login import current_user
-from core.model_providers.models.entity.model_params import ModelType
 from fields.dataset_fields import dataset_detail_fields
 from services.dataset_service import DatasetService
-from services.provider_service import ProviderService
 
 
 def _validate_name(name):
@@ -27,12 +27,20 @@ class DatasetApi(DatasetApiResource):
         datasets, total = DatasetService.get_datasets(page, limit, provider,
                                                       tenant_id, current_user)
         # check embedding setting
-        provider_service = ProviderService()
-        valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
-                                                                 ModelType.EMBEDDINGS.value)
+        provider_manager = ProviderManager()
+        configurations = provider_manager.get_configurations(
+            tenant_id=current_user.current_tenant_id
+        )
+
+        embedding_models = configurations.get_models(
+            model_type=ModelType.TEXT_EMBEDDING,
+            only_active=True
+        )
+
         model_names = []
-        for valid_model in valid_model_list:
-            model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
+        for embedding_model in embedding_models:
+            model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
+
         data = marshal(datasets, dataset_detail_fields)
         for item in data:
             if item['indexing_technique'] == 'high_quality':

+ 1 - 1
api/controllers/service_api/dataset/document.py

@@ -13,7 +13,7 @@ from controllers.service_api.dataset.error import ArchivedDocumentImmutableError
     NoFileUploadedError, TooManyFilesError
 from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
 from libs.login import current_user
-from core.model_providers.error import ProviderTokenNotInitError
+from core.errors.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from fields.document_fields import document_fields, document_status_fields
 from models.dataset import Dataset, Document, DocumentSegment

+ 18 - 11
api/controllers/service_api/dataset/segment.py

@@ -4,8 +4,9 @@ from werkzeug.exceptions import NotFound
 from controllers.service_api import api
 from controllers.service_api.app.error import ProviderNotInitializeError
 from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
-from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
-from core.model_providers.model_factory import ModelFactory
+from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
 from extensions.ext_database import db
 from fields.segment_fields import segment_fields
 from models.dataset import Dataset, DocumentSegment
@@ -35,10 +36,12 @@ class SegmentApi(DatasetApiResource):
         # check embedding model setting
         if dataset.indexing_technique == 'high_quality':
             try:
-                ModelFactory.get_embedding_model(
+                model_manager = ModelManager()
+                model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
@@ -77,10 +80,12 @@ class SegmentApi(DatasetApiResource):
         # check embedding model setting
         if dataset.indexing_technique == 'high_quality':
             try:
-                ModelFactory.get_embedding_model(
+                model_manager = ModelManager()
+                model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
@@ -167,10 +172,12 @@ class DatasetSegmentApi(DatasetApiResource):
         if dataset.indexing_technique == 'high_quality':
             # check embedding model setting
             try:
-                ModelFactory.get_embedding_model(
+                model_manager = ModelManager()
+                model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(

+ 3 - 4
api/controllers/web/audio.py

@@ -10,8 +10,8 @@ from controllers.web.error import AppUnavailableError, ProviderNotInitializeErro
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.web.wraps import WebApiResource
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
     UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
@@ -51,8 +51,7 @@ class AudioApi(WebApiResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e

+ 11 - 13
api/controllers/web/completion.py

@@ -13,9 +13,10 @@ from controllers.web.error import AppUnavailableError, ConversationCompletedErro
     ProviderNotInitializeError, NotChatAppError, NotCompletionAppError, CompletionRequestError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
 from controllers.web.wraps import WebApiResource
-from core.conversation_message_task import PubHandler
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService
 
@@ -44,7 +45,7 @@ class CompletionApi(WebApiResource):
                 app_model=app_model,
                 user=end_user,
                 args=args,
-                from_source='api',
+                invoke_from=InvokeFrom.WEB_APP,
                 streaming=streaming
             )
 
@@ -62,8 +63,7 @@ class CompletionApi(WebApiResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -77,7 +77,7 @@ class CompletionStopApi(WebApiResource):
         if app_model.mode != 'completion':
             raise NotCompletionAppError()
 
-        PubHandler.stop(end_user, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
 
         return {'result': 'success'}, 200
 
@@ -105,7 +105,7 @@ class ChatApi(WebApiResource):
                 app_model=app_model,
                 user=end_user,
                 args=args,
-                from_source='api',
+                invoke_from=InvokeFrom.WEB_APP,
                 streaming=streaming
             )
 
@@ -123,8 +123,7 @@ class ChatApi(WebApiResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -138,7 +137,7 @@ class ChatStopApi(WebApiResource):
         if app_model.mode != 'chat':
             raise NotChatAppError()
 
-        PubHandler.stop(end_user, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
 
         return {'result': 'success'}, 200
 
@@ -164,8 +163,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
                 yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                    LLMRateLimitError, LLMAuthorizationError) as e:
+            except InvokeError as e:
                 yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
             except ValueError as e:
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

+ 14 - 9
api/controllers/web/message.py

@@ -14,8 +14,9 @@ from controllers.web.error import NotChatAppError, CompletionRequestError, Provi
     AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
 from controllers.web.wraps import WebApiResource
-from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
-    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from libs.helper import uuid_value, TimestampField
 from services.completion_service import CompletionService
 from services.errors.app import MoreLikeThisDisabledError
@@ -117,7 +118,14 @@ class MessageMoreLikeThisApi(WebApiResource):
         streaming = args['response_mode'] == 'streaming'
 
         try:
-            response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app')
+            response = CompletionService.generate_more_like_this(
+                app_model=app_model,
+                user=end_user,
+                message_id=message_id,
+                invoke_from=InvokeFrom.WEB_APP,
+                streaming=streaming
+            )
+
             return compact_response(response)
         except MessageNotExistsError:
             raise NotFound("Message Not Exists.")
@@ -129,8 +137,7 @@ class MessageMoreLikeThisApi(WebApiResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -157,8 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
                 yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                    LLMRateLimitError, LLMAuthorizationError) as e:
+            except InvokeError as e:
                 yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
             except ValueError as e:
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
@@ -195,8 +201,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except Exception:
             logging.exception("internal server error.")

+ 101 - 0
api/core/agent/agent/agent_llm_callback.py

@@ -0,0 +1,101 @@
+import logging
+from typing import Optional, List
+
+from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
+from core.model_runtime.callbacks.base_callback import Callback
+from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult
+from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
+from core.model_runtime.model_providers.__base.ai_model import AIModel
+
+logger = logging.getLogger(__name__)
+
+
+class AgentLLMCallback(Callback):
+
+    def __init__(self, agent_callback: AgentLoopGatherCallbackHandler) -> None:
+        self.agent_callback = agent_callback
+
+    def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
+                         prompt_messages: list[PromptMessage], model_parameters: dict,
+                         tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                         stream: bool = True, user: Optional[str] = None) -> None:
+        """
+        Before invoke callback
+
+        :param llm_instance: LLM instance
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        """
+        self.agent_callback.on_llm_before_invoke(
+            prompt_messages=prompt_messages
+        )
+
+    def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
+                     prompt_messages: list[PromptMessage], model_parameters: dict,
+                     tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                     stream: bool = True, user: Optional[str] = None):
+        """
+        On new chunk callback
+
+        :param llm_instance: LLM instance
+        :param chunk: chunk
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        """
+        pass
+
+    def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
+                        prompt_messages: list[PromptMessage], model_parameters: dict,
+                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                        stream: bool = True, user: Optional[str] = None) -> None:
+        """
+        After invoke callback
+
+        :param llm_instance: LLM instance
+        :param result: result
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        """
+        self.agent_callback.on_llm_after_invoke(
+            result=result
+        )
+
+    def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
+                        prompt_messages: list[PromptMessage], model_parameters: dict,
+                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                        stream: bool = True, user: Optional[str] = None) -> None:
+        """
+        Invoke error callback
+
+        :param llm_instance: LLM instance
+        :param ex: exception
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        """
+        self.agent_callback.on_llm_error(
+            error=ex
+        )

+ 33 - 12
api/core/agent/agent/calc_token_mixin.py

@@ -1,28 +1,49 @@
-from typing import List
+from typing import List, cast
 
 from langchain.schema import BaseMessage
 
-from core.model_providers.models.entity.message import to_prompt_messages
-from core.model_providers.models.llm.base import BaseLLM
+from core.entities.application_entities import ModelConfigEntity
+from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.model_runtime.entities.message_entities import PromptMessage
+from core.model_runtime.entities.model_entities import ModelPropertyKey
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 
 
 class CalcTokenMixin:
 
-    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
-        return model_instance.get_num_tokens(to_prompt_messages(messages))
-
-    def get_message_rest_tokens(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
+    def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int:
         """
         Got the rest tokens available for the model after excluding messages tokens and completion max tokens
 
-        :param llm:
+        :param model_config:
         :param messages:
         :return:
         """
-        llm_max_tokens = model_instance.model_rules.max_tokens.max
-        completion_max_tokens = model_instance.model_kwargs.max_tokens
-        used_tokens = self.get_num_tokens_from_messages(model_instance, messages, **kwargs)
-        rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
+        model_type_instance = model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+
+        max_tokens = 0
+        for parameter_rule in model_config.model_schema.parameter_rules:
+            if (parameter_rule.name == 'max_tokens'
+                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
+                max_tokens = (model_config.parameters.get(parameter_rule.name)
+                              or model_config.parameters.get(parameter_rule.use_template)) or 0
+
+        if model_context_tokens is None:
+            return 0
+
+        if max_tokens is None:
+            max_tokens = 0
+
+        prompt_tokens = model_type_instance.get_num_tokens(
+            model_config.model,
+            model_config.credentials,
+            messages
+        )
+
+        rest_tokens = model_context_tokens - max_tokens - prompt_tokens
 
         return rest_tokens
 

+ 38 - 16
api/core/agent/agent/multi_dataset_router_agent.py

@@ -1,4 +1,3 @@
-import json
 from typing import Tuple, List, Any, Union, Sequence, Optional, cast
 
 from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
@@ -6,13 +5,14 @@ from langchain.agents.openai_functions_agent.base import _format_intermediate_st
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
-from langchain.schema.language_model import BaseLanguageModel
+from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage
 from langchain.tools import BaseTool
 from pydantic import root_validator
 
-from core.model_providers.models.entity.message import to_prompt_messages
-from core.model_providers.models.llm.base import BaseLLM
+from core.entities.application_entities import ModelConfigEntity
+from core.model_manager import ModelInstance
+from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.model_runtime.entities.message_entities import PromptMessageTool
 from core.third_party.langchain.llms.fake import FakeLLM
 
 
@@ -20,7 +20,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
     """
     An Multi Dataset Retrieve Agent driven by Router.
     """
-    model_instance: BaseLLM
+    model_config: ModelConfigEntity
 
     class Config:
         """Configuration for this pydantic object."""
@@ -81,8 +81,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
                 agent_decision.return_values['output'] = ''
             return agent_decision
         except Exception as e:
-            new_exception = self.model_instance.handle_exceptions(e)
-            raise new_exception
+            raise e
 
     def real_plan(
         self,
@@ -106,16 +105,39 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
         full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
         prompt = self.prompt.format_prompt(**full_inputs)
         messages = prompt.to_messages()
-        prompt_messages = to_prompt_messages(messages)
-        result = self.model_instance.run(
-            messages=prompt_messages,
-            functions=self.functions,
+        prompt_messages = lc_messages_to_prompt_messages(messages)
+
+        model_instance = ModelInstance(
+            provider_model_bundle=self.model_config.provider_model_bundle,
+            model=self.model_config.model,
+        )
+
+        tools = []
+        for function in self.functions:
+            tool = PromptMessageTool(
+                **function
+            )
+
+            tools.append(tool)
+
+        result = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            tools=tools,
+            stream=False,
+            model_parameters={
+                'temperature': 0.2,
+                'top_p': 0.3,
+                'max_tokens': 1500
+            }
         )
 
         ai_message = AIMessage(
-            content=result.content,
+            content=result.message.content or "",
             additional_kwargs={
-                'function_call': result.function_call
+                'function_call': {
+                    'id': result.message.tool_calls[0].id,
+                    **result.message.tool_calls[0].function.dict()
+                } if result.message.tool_calls else None
             }
         )
 
@@ -133,7 +155,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
     @classmethod
     def from_llm_and_tools(
             cls,
-            model_instance: BaseLLM,
+            model_config: ModelConfigEntity,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
@@ -147,7 +169,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
             system_message=system_message,
         )
         return cls(
-            model_instance=model_instance,
+            model_config=model_config,
             llm=FakeLLM(response=''),
             prompt=prompt,
             tools=tools,

+ 91 - 34
api/core/agent/agent/openai_function_call.py

@@ -1,4 +1,4 @@
-from typing import List, Tuple, Any, Union, Sequence, Optional
+from typing import List, Tuple, Any, Union, Sequence, Optional, cast
 
 from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
 from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
@@ -13,18 +13,23 @@ from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage,
 from langchain.tools import BaseTool
 from pydantic import root_validator
 
+from core.agent.agent.agent_llm_callback import AgentLLMCallback
 from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
 from core.chain.llm_chain import LLMChain
-from core.model_providers.models.entity.message import to_prompt_messages
-from core.model_providers.models.llm.base import BaseLLM
+from core.entities.application_entities import ModelConfigEntity
+from core.model_manager import ModelInstance
+from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.third_party.langchain.llms.fake import FakeLLM
 
 
 class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
-    summary_model_instance: BaseLLM = None
-    model_instance: BaseLLM
+    summary_model_config: ModelConfigEntity = None
+    model_config: ModelConfigEntity
+    agent_llm_callback: Optional[AgentLLMCallback] = None
 
     class Config:
         """Configuration for this pydantic object."""
@@ -38,13 +43,14 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
     @classmethod
     def from_llm_and_tools(
             cls,
-            model_instance: BaseLLM,
+            model_config: ModelConfigEntity,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
             system_message: Optional[SystemMessage] = SystemMessage(
                 content="You are a helpful AI assistant."
             ),
+            agent_llm_callback: Optional[AgentLLMCallback] = None,
             **kwargs: Any,
     ) -> BaseSingleActionAgent:
         prompt = cls.create_prompt(
@@ -52,11 +58,12 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
             system_message=system_message,
         )
         return cls(
-            model_instance=model_instance,
+            model_config=model_config,
             llm=FakeLLM(response=''),
             prompt=prompt,
             tools=tools,
             callback_manager=callback_manager,
+            agent_llm_callback=agent_llm_callback,
             **kwargs,
         )
 
@@ -67,28 +74,49 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
         :param query:
         :return:
         """
-        original_max_tokens = self.model_instance.model_kwargs.max_tokens
-        self.model_instance.model_kwargs.max_tokens = 40
+        original_max_tokens = 0
+        for parameter_rule in self.model_config.model_schema.parameter_rules:
+            if (parameter_rule.name == 'max_tokens'
+                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
+                original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
+                              or self.model_config.parameters.get(parameter_rule.use_template)) or 0
+
+        self.model_config.parameters['max_tokens'] = 40
 
         prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
         messages = prompt.to_messages()
 
         try:
-            prompt_messages = to_prompt_messages(messages)
-            result = self.model_instance.run(
-                messages=prompt_messages,
-                functions=self.functions,
-                callbacks=None
+            prompt_messages = lc_messages_to_prompt_messages(messages)
+            model_instance = ModelInstance(
+                provider_model_bundle=self.model_config.provider_model_bundle,
+                model=self.model_config.model,
             )
-        except Exception as e:
-            new_exception = self.model_instance.handle_exceptions(e)
-            raise new_exception
 
-        function_call = result.function_call
+            tools = []
+            for function in self.functions:
+                tool = PromptMessageTool(
+                    **function
+                )
+
+                tools.append(tool)
+
+            result = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                tools=tools,
+                stream=False,
+                model_parameters={
+                    'temperature': 0.2,
+                    'top_p': 0.3,
+                    'max_tokens': 1500
+                }
+            )
+        except Exception as e:
+            raise e
 
-        self.model_instance.model_kwargs.max_tokens = original_max_tokens
+        self.model_config.parameters['max_tokens'] = original_max_tokens
 
-        return True if function_call else False
+        return True if result.message.tool_calls else False
 
     def plan(
             self,
@@ -113,22 +141,46 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
         prompt = self.prompt.format_prompt(**full_inputs)
         messages = prompt.to_messages()
 
+        prompt_messages = lc_messages_to_prompt_messages(messages)
+
         # summarize messages if rest_tokens < 0
         try:
-            messages = self.summarize_messages_if_needed(messages, functions=self.functions)
+            prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
         except ExceededLLMTokensLimitError as e:
             return AgentFinish(return_values={"output": str(e)}, log=str(e))
 
-        prompt_messages = to_prompt_messages(messages)
-        result = self.model_instance.run(
-            messages=prompt_messages,
-            functions=self.functions,
+        model_instance = ModelInstance(
+            provider_model_bundle=self.model_config.provider_model_bundle,
+            model=self.model_config.model,
+        )
+
+        tools = []
+        for function in self.functions:
+            tool = PromptMessageTool(
+                **function
+            )
+
+            tools.append(tool)
+
+        result = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            tools=tools,
+            stream=False,
+            callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
+            model_parameters={
+                'temperature': 0.2,
+                'top_p': 0.3,
+                'max_tokens': 1500
+            }
         )
 
         ai_message = AIMessage(
-            content=result.content,
+            content=result.message.content or "",
             additional_kwargs={
-                'function_call': result.function_call
+                'function_call': {
+                    'id': result.message.tool_calls[0].id,
+                    **result.message.tool_calls[0].function.dict()
+                } if result.message.tool_calls else None
             }
         )
         agent_decision = _parse_ai_message(ai_message)
@@ -158,9 +210,14 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
         except ValueError:
             return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
 
-    def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
+    def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]:
         # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
-        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
+        rest_tokens = self.get_message_rest_tokens(
+            self.model_config,
+            messages,
+            **kwargs
+        )
+
         rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
         if rest_tokens >= 0:
             return messages
@@ -210,19 +267,19 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
             ai_prefix="AI",
         )
 
-        chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
+        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
         return chain.predict(summary=existing_summary, new_lines=new_lines)
 
-    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
+    def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage], **kwargs) -> int:
         """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
 
         Official documentation: https://github.com/openai/openai-cookbook/blob/
         main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
-        if model_instance.model_provider.provider_name == 'azure_openai':
-            model = model_instance.base_model_name
+        if model_config.provider == 'azure_openai':
+            model = model_config.model
             model = model.replace("gpt-35", "gpt-3.5")
         else:
-            model = model_instance.base_model_name
+            model = model_config.credentials.get("base_model_name")
 
         tiktoken_ = _import_tiktoken()
         try:

+ 0 - 158
api/core/agent/agent/output_parser/retirver_dataset_agent.py

@@ -1,158 +0,0 @@
-import json
-from typing import Tuple, List, Any, Union, Sequence, Optional, cast
-
-from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
-from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.callbacks.manager import Callbacks
-from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
-from langchain.schema.language_model import BaseLanguageModel
-from langchain.tools import BaseTool
-from pydantic import root_validator
-
-from core.model_providers.models.entity.message import to_prompt_messages
-from core.model_providers.models.llm.base import BaseLLM
-from core.third_party.langchain.llms.fake import FakeLLM
-from core.tool.dataset_retriever_tool import DatasetRetrieverTool
-
-
-class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
-    """
-    An Multi Dataset Retrieve Agent driven by Router.
-    """
-    model_instance: BaseLLM
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        arbitrary_types_allowed = True
-
-    @root_validator
-    def validate_llm(cls, values: dict) -> dict:
-        return values
-
-    def should_use_agent(self, query: str):
-        """
-        return should use agent
-
-        :param query:
-        :return:
-        """
-        return True
-
-    def plan(
-        self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
-        callbacks: Callbacks = None,
-        **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
-
-        Args:
-            intermediate_steps: Steps the LLM has taken to date, along with observations
-            **kwargs: User inputs.
-
-        Returns:
-            Action specifying what tool to use.
-        """
-        if len(self.tools) == 0:
-            return AgentFinish(return_values={"output": ''}, log='')
-        elif len(self.tools) == 1:
-            tool = next(iter(self.tools))
-            tool = cast(DatasetRetrieverTool, tool)
-            rst = tool.run(tool_input={'query': kwargs['input']})
-            # output = ''
-            # rst_json = json.loads(rst)
-            # for item in rst_json:
-            #     output += f'{item["content"]}\n'
-            return AgentFinish(return_values={"output": rst}, log=rst)
-
-        if intermediate_steps:
-            _, observation = intermediate_steps[-1]
-            return AgentFinish(return_values={"output": observation}, log=observation)
-
-        try:
-            agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
-            if isinstance(agent_decision, AgentAction):
-                tool_inputs = agent_decision.tool_input
-                if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
-                    tool_inputs['query'] = kwargs['input']
-                    agent_decision.tool_input = tool_inputs
-            else:
-                agent_decision.return_values['output'] = ''
-            return agent_decision
-        except Exception as e:
-            new_exception = self.model_instance.handle_exceptions(e)
-            raise new_exception
-
-    def real_plan(
-        self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
-        callbacks: Callbacks = None,
-        **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
-
-        Args:
-            intermediate_steps: Steps the LLM has taken to date, along with observations
-            **kwargs: User inputs.
-
-        Returns:
-            Action specifying what tool to use.
-        """
-        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
-        selected_inputs = {
-            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
-        }
-        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
-        prompt = self.prompt.format_prompt(**full_inputs)
-        messages = prompt.to_messages()
-        prompt_messages = to_prompt_messages(messages)
-        result = self.model_instance.run(
-            messages=prompt_messages,
-            functions=self.functions,
-        )
-
-        ai_message = AIMessage(
-            content=result.content,
-            additional_kwargs={
-                'function_call': result.function_call
-            }
-        )
-
-        agent_decision = _parse_ai_message(ai_message)
-        return agent_decision
-
-    async def aplan(
-            self,
-            intermediate_steps: List[Tuple[AgentAction, str]],
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        raise NotImplementedError()
-
-    @classmethod
-    def from_llm_and_tools(
-            cls,
-            model_instance: BaseLLM,
-            tools: Sequence[BaseTool],
-            callback_manager: Optional[BaseCallbackManager] = None,
-            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
-            system_message: Optional[SystemMessage] = SystemMessage(
-                content="You are a helpful AI assistant."
-            ),
-            **kwargs: Any,
-    ) -> BaseSingleActionAgent:
-        prompt = cls.create_prompt(
-            extra_prompt_messages=extra_prompt_messages,
-            system_message=system_message,
-        )
-        return cls(
-            model_instance=model_instance,
-            llm=FakeLLM(response=''),
-            prompt=prompt,
-            tools=tools,
-            callback_manager=callback_manager,
-            **kwargs,
-        )

+ 18 - 14
api/core/agent/agent/structed_multi_dataset_router_agent.py

@@ -12,9 +12,7 @@ from langchain.tools import BaseTool
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
 
 from core.chain.llm_chain import LLMChain
-from core.model_providers.models.entity.model_params import ModelMode
-from core.model_providers.models.llm.base import BaseLLM
-from core.tool.dataset_retriever_tool import DatasetRetrieverTool
+from core.entities.application_entities import ModelConfigEntity
 
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
 The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
@@ -69,10 +67,10 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
         return True
 
     def plan(
-        self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
-        callbacks: Callbacks = None,
-        **kwargs: Any,
+            self,
+            intermediate_steps: List[Tuple[AgentAction, str]],
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
         """Given input, decided what to do.
 
@@ -101,8 +99,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
         try:
             full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
         except Exception as e:
-            new_exception = self.llm_chain.model_instance.handle_exceptions(e)
-            raise new_exception
+            raise e
 
         try:
             agent_decision = self.output_parser.parse(full_output)
@@ -119,6 +116,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
         except OutputParserException:
             return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
                                           "I don't know how to respond to that."}, "")
+
     @classmethod
     def create_prompt(
             cls,
@@ -182,7 +180,7 @@ Thought: {agent_scratchpad}
         return PromptTemplate(template=template, input_variables=input_variables)
 
     def _construct_scratchpad(
-        self, intermediate_steps: List[Tuple[AgentAction, str]]
+            self, intermediate_steps: List[Tuple[AgentAction, str]]
     ) -> str:
         agent_scratchpad = ""
         for action, observation in intermediate_steps:
@@ -193,7 +191,7 @@ Thought: {agent_scratchpad}
             raise ValueError("agent_scratchpad should be of type string.")
         if agent_scratchpad:
             llm_chain = cast(LLMChain, self.llm_chain)
-            if llm_chain.model_instance.model_mode == ModelMode.CHAT:
+            if llm_chain.model_config.mode == "chat":
                 return (
                     f"This was your previous work "
                     f"(but I haven't seen any of it! I only see what "
@@ -207,7 +205,7 @@ Thought: {agent_scratchpad}
     @classmethod
     def from_llm_and_tools(
             cls,
-            model_instance: BaseLLM,
+            model_config: ModelConfigEntity,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             output_parser: Optional[AgentOutputParser] = None,
@@ -221,7 +219,7 @@ Thought: {agent_scratchpad}
     ) -> Agent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
-        if model_instance.model_mode == ModelMode.CHAT:
+        if model_config.mode == "chat":
             prompt = cls.create_prompt(
                 tools,
                 prefix=prefix,
@@ -238,10 +236,16 @@ Thought: {agent_scratchpad}
                 format_instructions=format_instructions,
                 input_variables=input_variables
             )
+
         llm_chain = LLMChain(
-            model_instance=model_instance,
+            model_config=model_config,
             prompt=prompt,
             callback_manager=callback_manager,
+            parameters={
+                'temperature': 0.2,
+                'top_p': 0.3,
+                'max_tokens': 1500
+            }
         )
         tool_names = [tool.name for tool in tools]
         _output_parser = output_parser

+ 22 - 13
api/core/agent/agent/structured_chat.py

@@ -13,10 +13,11 @@ from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage,
 from langchain.tools import BaseTool
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
 
+from core.agent.agent.agent_llm_callback import AgentLLMCallback
 from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
 from core.chain.llm_chain import LLMChain
-from core.model_providers.models.entity.model_params import ModelMode
-from core.model_providers.models.llm.base import BaseLLM
+from core.entities.application_entities import ModelConfigEntity
+from core.entities.message_entities import lc_messages_to_prompt_messages
 
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
 The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
@@ -54,7 +55,7 @@ Action:
 class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
-    summary_model_instance: BaseLLM = None
+    summary_model_config: ModelConfigEntity = None
 
     class Config:
         """Configuration for this pydantic object."""
@@ -82,7 +83,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
 
         Args:
             intermediate_steps: Steps the LLM has taken to date,
-                along with observations
+                along with observatons
             callbacks: Callbacks to run.
             **kwargs: User inputs.
 
@@ -96,15 +97,16 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
         if prompts:
             messages = prompts[0].to_messages()
 
-        rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages)
+        prompt_messages = lc_messages_to_prompt_messages(messages)
+
+        rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
         if rest_tokens < 0:
             full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
 
         try:
             full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
         except Exception as e:
-            new_exception = self.llm_chain.model_instance.handle_exceptions(e)
-            raise new_exception
+            raise e
 
         try:
             agent_decision = self.output_parser.parse(full_output)
@@ -119,7 +121,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
                                           "I don't know how to respond to that."}, "")
 
     def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
-        if len(intermediate_steps) >= 2 and self.summary_model_instance:
+        if len(intermediate_steps) >= 2 and self.summary_model_config:
             should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
             should_summary_messages = [AIMessage(content=observation)
                                        for _, observation in should_summary_intermediate_steps]
@@ -153,7 +155,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
             ai_prefix="AI",
         )
 
-        chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
+        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
         return chain.predict(summary=existing_summary, new_lines=new_lines)
 
     @classmethod
@@ -229,7 +231,7 @@ Thought: {agent_scratchpad}
             raise ValueError("agent_scratchpad should be of type string.")
         if agent_scratchpad:
             llm_chain = cast(LLMChain, self.llm_chain)
-            if llm_chain.model_instance.model_mode == ModelMode.CHAT:
+            if llm_chain.model_config.mode == "chat":
                 return (
                     f"This was your previous work "
                     f"(but I haven't seen any of it! I only see what "
@@ -243,7 +245,7 @@ Thought: {agent_scratchpad}
     @classmethod
     def from_llm_and_tools(
             cls,
-            model_instance: BaseLLM,
+            model_config: ModelConfigEntity,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             output_parser: Optional[AgentOutputParser] = None,
@@ -253,11 +255,12 @@ Thought: {agent_scratchpad}
             format_instructions: str = FORMAT_INSTRUCTIONS,
             input_variables: Optional[List[str]] = None,
             memory_prompts: Optional[List[BasePromptTemplate]] = None,
+            agent_llm_callback: Optional[AgentLLMCallback] = None,
             **kwargs: Any,
     ) -> Agent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
-        if model_instance.model_mode == ModelMode.CHAT:
+        if model_config.mode == "chat":
             prompt = cls.create_prompt(
                 tools,
                 prefix=prefix,
@@ -275,9 +278,15 @@ Thought: {agent_scratchpad}
                 input_variables=input_variables,
             )
         llm_chain = LLMChain(
-            model_instance=model_instance,
+            model_config=model_config,
             prompt=prompt,
             callback_manager=callback_manager,
+            agent_llm_callback=agent_llm_callback,
+            parameters={
+                'temperature': 0.2,
+                'top_p': 0.3,
+                'max_tokens': 1500
+            }
         )
         tool_names = [tool.name for tool in tools]
         _output_parser = output_parser

+ 33 - 23
api/core/agent/agent_executor.py

@@ -4,10 +4,10 @@ from typing import Union, Optional
 
 from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
 from langchain.callbacks.manager import Callbacks
-from langchain.memory.chat_memory import BaseChatMemory
 from langchain.tools import BaseTool
 from pydantic import BaseModel, Extra
 
+from core.agent.agent.agent_llm_callback import AgentLLMCallback
 from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
 from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
 from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
@@ -15,9 +15,11 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
 from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
 from langchain.agents import AgentExecutor as LCAgentExecutor
 
+from core.entities.application_entities import ModelConfigEntity
+from core.entities.message_entities import prompt_messages_to_lc_messages
 from core.helper import moderation
-from core.model_providers.error import LLMError
-from core.model_providers.models.llm.base import BaseLLM
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.errors.invoke import InvokeError
 from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 
@@ -31,14 +33,15 @@ class PlanningStrategy(str, enum.Enum):
 
 class AgentConfiguration(BaseModel):
     strategy: PlanningStrategy
-    model_instance: BaseLLM
+    model_config: ModelConfigEntity
     tools: list[BaseTool]
-    summary_model_instance: BaseLLM = None
-    memory: Optional[BaseChatMemory] = None
+    summary_model_config: Optional[ModelConfigEntity] = None
+    memory: Optional[TokenBufferMemory] = None
     callbacks: Callbacks = None
     max_iterations: int = 6
     max_execution_time: Optional[float] = None
     early_stopping_method: str = "generate"
+    agent_llm_callback: Optional[AgentLLMCallback] = None
     # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
 
     class Config:
@@ -62,34 +65,42 @@ class AgentExecutor:
     def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
         if self.configuration.strategy == PlanningStrategy.REACT:
             agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
-                model_instance=self.configuration.model_instance,
+                model_config=self.configuration.model_config,
                 tools=self.configuration.tools,
                 output_parser=StructuredChatOutputParser(),
-                summary_model_instance=self.configuration.summary_model_instance
-                if self.configuration.summary_model_instance else None,
+                summary_model_config=self.configuration.summary_model_config
+                if self.configuration.summary_model_config else None,
+                agent_llm_callback=self.configuration.agent_llm_callback,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
             agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
-                model_instance=self.configuration.model_instance,
+                model_config=self.configuration.model_config,
                 tools=self.configuration.tools,
-                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
-                summary_model_instance=self.configuration.summary_model_instance
-                if self.configuration.summary_model_instance else None,
+                extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
+                if self.configuration.memory else None,  # used for read chat histories memory
+                summary_model_config=self.configuration.summary_model_config
+                if self.configuration.summary_model_config else None,
+                agent_llm_callback=self.configuration.agent_llm_callback,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.ROUTER:
-            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
+            self.configuration.tools = [t for t in self.configuration.tools
+                                        if isinstance(t, DatasetRetrieverTool)
+                                        or isinstance(t, DatasetMultiRetrieverTool)]
             agent = MultiDatasetRouterAgent.from_llm_and_tools(
-                model_instance=self.configuration.model_instance,
+                model_config=self.configuration.model_config,
                 tools=self.configuration.tools,
-                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
+                extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
+                if self.configuration.memory else None,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
-            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
+            self.configuration.tools = [t for t in self.configuration.tools
+                                        if isinstance(t, DatasetRetrieverTool)
+                                        or isinstance(t, DatasetMultiRetrieverTool)]
             agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
-                model_instance=self.configuration.model_instance,
+                model_config=self.configuration.model_config,
                 tools=self.configuration.tools,
                 output_parser=StructuredChatOutputParser(),
                 verbose=True
@@ -104,11 +115,11 @@ class AgentExecutor:
 
     def run(self, query: str) -> AgentExecuteResult:
         moderation_result = moderation.check_moderation(
-            self.configuration.model_instance.model_provider,
+            self.configuration.model_config,
             query
         )
 
-        if not moderation_result:
+        if moderation_result:
             return AgentExecuteResult(
                 output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
                 strategy=self.configuration.strategy,
@@ -118,7 +129,6 @@ class AgentExecutor:
         agent_executor = LCAgentExecutor.from_agent_and_tools(
             agent=self.agent,
             tools=self.configuration.tools,
-            memory=self.configuration.memory,
             max_iterations=self.configuration.max_iterations,
             max_execution_time=self.configuration.max_execution_time,
             early_stopping_method=self.configuration.early_stopping_method,
@@ -126,8 +136,8 @@ class AgentExecutor:
         )
 
         try:
-            output = agent_executor.run(query)
-        except LLMError as ex:
+            output = agent_executor.run(input=query)
+        except InvokeError as ex:
             raise ex
         except Exception as ex:
             logging.exception("agent_executor run failed")

+ 0 - 0
api/core/model_providers/models/__init__.py → api/core/app_runner/__init__.py


+ 251 - 0
api/core/app_runner/agent_app_runner.py

@@ -0,0 +1,251 @@
+import json
+import logging
+from typing import cast
+
+from core.agent.agent.agent_llm_callback import AgentLLMCallback
+from core.app_runner.app_runner import AppRunner
+from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
+from core.entities.application_entities import ApplicationGenerateEntity, PromptTemplateEntity, ModelConfigEntity
+from core.application_queue_manager import ApplicationQueueManager
+from core.features.agent_runner import AgentRunnerFeature
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_manager import ModelInstance
+from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from extensions.ext_database import db
+from models.model import Conversation, Message, App, MessageChain, MessageAgentThought
+
+logger = logging.getLogger(__name__)
+
+
+class AgentApplicationRunner(AppRunner):
+    """
+    Agent Application Runner
+    """
+
+    def run(self, application_generate_entity: ApplicationGenerateEntity,
+            queue_manager: ApplicationQueueManager,
+            conversation: Conversation,
+            message: Message) -> None:
+        """
+        Run agent application
+        :param application_generate_entity: application generate entity
+        :param queue_manager: application queue manager
+        :param conversation: conversation
+        :param message: message
+        :return:
+        """
+        app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
+        if not app_record:
+            raise ValueError(f"App not found")
+
+        app_orchestration_config = application_generate_entity.app_orchestration_config_entity
+
+        inputs = application_generate_entity.inputs
+        query = application_generate_entity.query
+        files = application_generate_entity.files
+
+        # Pre-calculate the number of tokens of the prompt messages,
+        # and return the rest number of tokens by model context token size limit and max token size limit.
+        # If the rest number of tokens is not enough, raise exception.
+        # Include: prompt template, inputs, query(optional), files(optional)
+        # Not Include: memory, external data, dataset context
+        self.get_pre_calculate_rest_tokens(
+            app_record=app_record,
+            model_config=app_orchestration_config.model_config,
+            prompt_template_entity=app_orchestration_config.prompt_template,
+            inputs=inputs,
+            files=files,
+            query=query
+        )
+
+        memory = None
+        if application_generate_entity.conversation_id:
+            # get memory of conversation (read-only)
+            model_instance = ModelInstance(
+                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
+                model=app_orchestration_config.model_config.model
+            )
+
+            memory = TokenBufferMemory(
+                conversation=conversation,
+                model_instance=model_instance
+            )
+
+        # reorganize all inputs and template to prompt messages
+        # Include: prompt template, inputs, query(optional), files(optional)
+        #          memory(optional)
+        prompt_messages, stop = self.originze_prompt_messages(
+            app_record=app_record,
+            model_config=app_orchestration_config.model_config,
+            prompt_template_entity=app_orchestration_config.prompt_template,
+            inputs=inputs,
+            files=files,
+            query=query,
+            context=None,
+            memory=memory
+        )
+
+        # Create MessageChain
+        message_chain = self._init_message_chain(
+            message=message,
+            query=query
+        )
+
+        # add agent callback to record agent thoughts
+        agent_callback = AgentLoopGatherCallbackHandler(
+            model_config=app_orchestration_config.model_config,
+            message=message,
+            queue_manager=queue_manager,
+            message_chain=message_chain
+        )
+
+        # init LLM Callback
+        agent_llm_callback = AgentLLMCallback(
+            agent_callback=agent_callback
+        )
+
+        agent_runner = AgentRunnerFeature(
+            tenant_id=application_generate_entity.tenant_id,
+            app_orchestration_config=app_orchestration_config,
+            model_config=app_orchestration_config.model_config,
+            config=app_orchestration_config.agent,
+            queue_manager=queue_manager,
+            message=message,
+            user_id=application_generate_entity.user_id,
+            agent_llm_callback=agent_llm_callback,
+            callback=agent_callback,
+            memory=memory
+        )
+
+        # agent run
+        result = agent_runner.run(
+            query=query,
+            invoke_from=application_generate_entity.invoke_from
+        )
+
+        if result:
+            self._save_message_chain(
+                message_chain=message_chain,
+                output_text=result
+            )
+
+        if (result
+                and app_orchestration_config.prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE
+                and app_orchestration_config.prompt_template.simple_prompt_template
+        ):
+            # Direct output if agent result exists and has pre prompt
+            self.direct_output(
+                queue_manager=queue_manager,
+                app_orchestration_config=app_orchestration_config,
+                prompt_messages=prompt_messages,
+                stream=application_generate_entity.stream,
+                text=result,
+                usage=self._get_usage_of_all_agent_thoughts(
+                    model_config=app_orchestration_config.model_config,
+                    message=message
+                )
+            )
+        else:
+            # As normal LLM run, agent result as context
+            context = result
+
+            # reorganize all inputs and template to prompt messages
+            # Include: prompt template, inputs, query(optional), files(optional)
+            #          memory(optional), external data, dataset context(optional)
+            prompt_messages, stop = self.originze_prompt_messages(
+                app_record=app_record,
+                model_config=app_orchestration_config.model_config,
+                prompt_template_entity=app_orchestration_config.prompt_template,
+                inputs=inputs,
+                files=files,
+                query=query,
+                context=context,
+                memory=memory
+            )
+
+            # Re-calculate the max tokens if sum(prompt_token +  max_tokens) over model token limit
+            self.recale_llm_max_tokens(
+                model_config=app_orchestration_config.model_config,
+                prompt_messages=prompt_messages
+            )
+
+            # Invoke model
+            model_instance = ModelInstance(
+                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
+                model=app_orchestration_config.model_config.model
+            )
+
+            invoke_result = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters=app_orchestration_config.model_config.parameters,
+                stop=stop,
+                stream=application_generate_entity.stream,
+                user=application_generate_entity.user_id,
+            )
+
+            # handle invoke result
+            self._handle_invoke_result(
+                invoke_result=invoke_result,
+                queue_manager=queue_manager,
+                stream=application_generate_entity.stream
+            )
+
+    def _init_message_chain(self, message: Message, query: str) -> MessageChain:
+        """
+        Init MessageChain
+        :param message: message
+        :param query: query
+        :return:
+        """
+        message_chain = MessageChain(
+            message_id=message.id,
+            type="AgentExecutor",
+            input=json.dumps({
+                "input": query
+            })
+        )
+
+        db.session.add(message_chain)
+        db.session.commit()
+
+        return message_chain
+
+    def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
+        """
+        Save MessageChain
+        :param message_chain: message chain
+        :param output_text: output text
+        :return:
+        """
+        message_chain.output = json.dumps({
+            "output": output_text
+        })
+        db.session.commit()
+
+    def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
+                                         message: Message) -> LLMUsage:
+        """
+        Get usage of all agent thoughts
+        :param model_config: model config
+        :param message: message
+        :return:
+        """
+        agent_thoughts = (db.session.query(MessageAgentThought)
+                          .filter(MessageAgentThought.message_id == message.id).all())
+
+        all_message_tokens = 0
+        all_answer_tokens = 0
+        for agent_thought in agent_thoughts:
+            all_message_tokens += agent_thought.message_tokens
+            all_answer_tokens += agent_thought.answer_tokens
+
+        model_type_instance = model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        return model_type_instance._calc_response_usage(
+            model_config.model,
+            model_config.credentials,
+            all_message_tokens,
+            all_answer_tokens
+        )

+ 267 - 0
api/core/app_runner/app_runner.py

@@ -0,0 +1,267 @@
+import time
+from typing import cast, Optional, List, Tuple, Generator, Union
+
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, AppOrchestrationConfigEntity
+from core.file.file_obj import FileObj
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from core.model_runtime.entities.message_entities import PromptMessage, AssistantPromptMessage
+from core.model_runtime.entities.model_entities import ModelPropertyKey
+from core.model_runtime.errors.invoke import InvokeBadRequestError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.prompt_transform import PromptTransform
+from models.model import App
+
+
+class AppRunner:
+    def get_pre_calculate_rest_tokens(self, app_record: App,
+                                      model_config: ModelConfigEntity,
+                                      prompt_template_entity: PromptTemplateEntity,
+                                      inputs: dict[str, str],
+                                      files: list[FileObj],
+                                      query: Optional[str] = None) -> int:
+        """
+        Get pre calculate rest tokens
+        :param app_record: app record
+        :param model_config: model config entity
+        :param prompt_template_entity: prompt template entity
+        :param inputs: inputs
+        :param files: files
+        :param query: query
+        :return:
+        """
+        model_type_instance = model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+
+        max_tokens = 0
+        for parameter_rule in model_config.model_schema.parameter_rules:
+            if (parameter_rule.name == 'max_tokens'
+                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
+                max_tokens = (model_config.parameters.get(parameter_rule.name)
+                              or model_config.parameters.get(parameter_rule.use_template)) or 0
+
+        if model_context_tokens is None:
+            return -1
+
+        if max_tokens is None:
+            max_tokens = 0
+
+        # get prompt messages without memory and context
+        prompt_messages, stop = self.originze_prompt_messages(
+            app_record=app_record,
+            model_config=model_config,
+            prompt_template_entity=prompt_template_entity,
+            inputs=inputs,
+            files=files,
+            query=query
+        )
+
+        prompt_tokens = model_type_instance.get_num_tokens(
+            model_config.model,
+            model_config.credentials,
+            prompt_messages
+        )
+
+        rest_tokens = model_context_tokens - max_tokens - prompt_tokens
+        if rest_tokens < 0:
+            raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
+                                        "or shrink the max token, or switch to a llm with a larger token limit size.")
+
+        return rest_tokens
+
+    def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
+                              prompt_messages: List[PromptMessage]):
+        # recalc max_tokens if sum(prompt_token +  max_tokens) over model token limit
+        model_type_instance = model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+
+        max_tokens = 0
+        for parameter_rule in model_config.model_schema.parameter_rules:
+            if (parameter_rule.name == 'max_tokens'
+                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
+                max_tokens = (model_config.parameters.get(parameter_rule.name)
+                              or model_config.parameters.get(parameter_rule.use_template)) or 0
+
+        if model_context_tokens is None:
+            return -1
+
+        if max_tokens is None:
+            max_tokens = 0
+
+        prompt_tokens = model_type_instance.get_num_tokens(
+            model_config.model,
+            model_config.credentials,
+            prompt_messages
+        )
+
+        if prompt_tokens + max_tokens > model_context_tokens:
+            max_tokens = max(model_context_tokens - prompt_tokens, 16)
+
+            for parameter_rule in model_config.model_schema.parameter_rules:
+                if (parameter_rule.name == 'max_tokens'
+                        or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
+                    model_config.parameters[parameter_rule.name] = max_tokens
+
+    def originze_prompt_messages(self, app_record: App,
+                                 model_config: ModelConfigEntity,
+                                 prompt_template_entity: PromptTemplateEntity,
+                                 inputs: dict[str, str],
+                                 files: list[FileObj],
+                                 query: Optional[str] = None,
+                                 context: Optional[str] = None,
+                                 memory: Optional[TokenBufferMemory] = None) \
+            -> Tuple[List[PromptMessage], Optional[List[str]]]:
+        """
+        Organize prompt messages
+        :param context:
+        :param app_record: app record
+        :param model_config: model config entity
+        :param prompt_template_entity: prompt template entity
+        :param inputs: inputs
+        :param files: files
+        :param query: query
+        :param memory: memory
+        :return:
+        """
+        prompt_transform = PromptTransform()
+
+        # get prompt without memory and context
+        if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
+            prompt_messages, stop = prompt_transform.get_prompt(
+                app_mode=app_record.mode,
+                prompt_template_entity=prompt_template_entity,
+                inputs=inputs,
+                query=query if query else '',
+                files=files,
+                context=context,
+                memory=memory,
+                model_config=model_config
+            )
+        else:
+            prompt_messages = prompt_transform.get_advanced_prompt(
+                app_mode=app_record.mode,
+                prompt_template_entity=prompt_template_entity,
+                inputs=inputs,
+                query=query,
+                files=files,
+                context=context,
+                memory=memory,
+                model_config=model_config
+            )
+            stop = model_config.stop
+
+        return prompt_messages, stop
+
+    def direct_output(self, queue_manager: ApplicationQueueManager,
+                      app_orchestration_config: AppOrchestrationConfigEntity,
+                      prompt_messages: list,
+                      text: str,
+                      stream: bool,
+                      usage: Optional[LLMUsage] = None) -> None:
+        """
+        Direct output
+        :param queue_manager: application queue manager
+        :param app_orchestration_config: app orchestration config
+        :param prompt_messages: prompt messages
+        :param text: text
+        :param stream: stream
+        :param usage: usage
+        :return:
+        """
+        if stream:
+            index = 0
+            for token in text:
+                queue_manager.publish_chunk_message(LLMResultChunk(
+                    model=app_orchestration_config.model_config.model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=AssistantPromptMessage(content=token)
+                    )
+                ))
+                index += 1
+                time.sleep(0.01)
+
+        queue_manager.publish_message_end(
+            llm_result=LLMResult(
+                model=app_orchestration_config.model_config.model,
+                prompt_messages=prompt_messages,
+                message=AssistantPromptMessage(content=text),
+                usage=usage if usage else LLMUsage.empty_usage()
+            )
+        )
+
+    def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
+                              queue_manager: ApplicationQueueManager,
+                              stream: bool) -> None:
+        """
+        Handle invoke result
+        :param invoke_result: invoke result
+        :param queue_manager: application queue manager
+        :param stream: stream
+        :return:
+        """
+        if not stream:
+            self._handle_invoke_result_direct(
+                invoke_result=invoke_result,
+                queue_manager=queue_manager
+            )
+        else:
+            self._handle_invoke_result_stream(
+                invoke_result=invoke_result,
+                queue_manager=queue_manager
+            )
+
+    def _handle_invoke_result_direct(self, invoke_result: LLMResult,
+                                     queue_manager: ApplicationQueueManager) -> None:
+        """
+        Handle invoke result direct
+        :param invoke_result: invoke result
+        :param queue_manager: application queue manager
+        :return:
+        """
+        queue_manager.publish_message_end(
+            llm_result=invoke_result
+        )
+
+    def _handle_invoke_result_stream(self, invoke_result: Generator,
+                                     queue_manager: ApplicationQueueManager) -> None:
+        """
+        Handle invoke result
+        :param invoke_result: invoke result
+        :param queue_manager: application queue manager
+        :return:
+        """
+        model = None
+        prompt_messages = []
+        text = ''
+        usage = None
+        for result in invoke_result:
+            queue_manager.publish_chunk_message(result)
+
+            text += result.delta.message.content
+
+            if not model:
+                model = result.model
+
+            if not prompt_messages:
+                prompt_messages = result.prompt_messages
+
+            if not usage and result.delta.usage:
+                usage = result.delta.usage
+
+        llm_result = LLMResult(
+            model=model,
+            prompt_messages=prompt_messages,
+            message=AssistantPromptMessage(content=text),
+            usage=usage
+        )
+
+        queue_manager.publish_message_end(
+            llm_result=llm_result
+        )

+ 363 - 0
api/core/app_runner/basic_app_runner.py

@@ -0,0 +1,363 @@
+import logging
+from typing import Tuple, Optional
+
+from core.app_runner.app_runner import AppRunner
+from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \
+    AppOrchestrationConfigEntity, InvokeFrom, ExternalDataVariableEntity, DatasetEntity
+from core.application_queue_manager import ApplicationQueueManager
+from core.features.annotation_reply import AnnotationReplyFeature
+from core.features.dataset_retrieval import DatasetRetrievalFeature
+from core.features.external_data_fetch import ExternalDataFetchFeature
+from core.features.hosting_moderation import HostingModerationFeature
+from core.features.moderation import ModerationFeature
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_manager import ModelInstance
+from core.model_runtime.entities.message_entities import PromptMessage
+from core.moderation.base import ModerationException
+from core.prompt.prompt_transform import AppMode
+from extensions.ext_database import db
+from models.model import Conversation, Message, App, MessageAnnotation
+
+logger = logging.getLogger(__name__)
+
+
+class BasicApplicationRunner(AppRunner):
+    """
+    Basic Application Runner
+    """
+
+    def run(self, application_generate_entity: ApplicationGenerateEntity,
+            queue_manager: ApplicationQueueManager,
+            conversation: Conversation,
+            message: Message) -> None:
+        """
+        Run application
+        :param application_generate_entity: application generate entity
+        :param queue_manager: application queue manager
+        :param conversation: conversation
+        :param message: message
+        :return:
+        """
+        app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
+        if not app_record:
+            raise ValueError(f"App not found")
+
+        app_orchestration_config = application_generate_entity.app_orchestration_config_entity
+
+        inputs = application_generate_entity.inputs
+        query = application_generate_entity.query
+        files = application_generate_entity.files
+
+        # Pre-calculate the number of tokens of the prompt messages,
+        # and return the rest number of tokens by model context token size limit and max token size limit.
+        # If the rest number of tokens is not enough, raise exception.
+        # Include: prompt template, inputs, query(optional), files(optional)
+        # Not Include: memory, external data, dataset context
+        self.get_pre_calculate_rest_tokens(
+            app_record=app_record,
+            model_config=app_orchestration_config.model_config,
+            prompt_template_entity=app_orchestration_config.prompt_template,
+            inputs=inputs,
+            files=files,
+            query=query
+        )
+
+        memory = None
+        if application_generate_entity.conversation_id:
+            # get memory of conversation (read-only)
+            model_instance = ModelInstance(
+                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
+                model=app_orchestration_config.model_config.model
+            )
+
+            memory = TokenBufferMemory(
+                conversation=conversation,
+                model_instance=model_instance
+            )
+
+        # organize all inputs and template to prompt messages
+        # Include: prompt template, inputs, query(optional), files(optional)
+        #          memory(optional)
+        prompt_messages, stop = self.originze_prompt_messages(
+            app_record=app_record,
+            model_config=app_orchestration_config.model_config,
+            prompt_template_entity=app_orchestration_config.prompt_template,
+            inputs=inputs,
+            files=files,
+            query=query,
+            memory=memory
+        )
+
+        # moderation
+        try:
+            # process sensitive_word_avoidance
+            _, inputs, query = self.moderation_for_inputs(
+                app_id=app_record.id,
+                tenant_id=application_generate_entity.tenant_id,
+                app_orchestration_config_entity=app_orchestration_config,
+                inputs=inputs,
+                query=query,
+            )
+        except ModerationException as e:
+            self.direct_output(
+                queue_manager=queue_manager,
+                app_orchestration_config=app_orchestration_config,
+                prompt_messages=prompt_messages,
+                text=str(e),
+                stream=application_generate_entity.stream
+            )
+            return
+
+        if query:
+            # annotation reply
+            annotation_reply = self.query_app_annotations_to_reply(
+                app_record=app_record,
+                message=message,
+                query=query,
+                user_id=application_generate_entity.user_id,
+                invoke_from=application_generate_entity.invoke_from
+            )
+
+            if annotation_reply:
+                queue_manager.publish_annotation_reply(
+                    message_annotation_id=annotation_reply.id
+                )
+                self.direct_output(
+                    queue_manager=queue_manager,
+                    app_orchestration_config=app_orchestration_config,
+                    prompt_messages=prompt_messages,
+                    text=annotation_reply.content,
+                    stream=application_generate_entity.stream
+                )
+                return
+
+            # fill in variable inputs from external data tools if exists
+            external_data_tools = app_orchestration_config.external_data_variables
+            if external_data_tools:
+                inputs = self.fill_in_inputs_from_external_data_tools(
+                    tenant_id=app_record.tenant_id,
+                    app_id=app_record.id,
+                    external_data_tools=external_data_tools,
+                    inputs=inputs,
+                    query=query
+                )
+
+        # get context from datasets
+        context = None
+        if app_orchestration_config.dataset:
+            context = self.retrieve_dataset_context(
+                tenant_id=app_record.tenant_id,
+                app_record=app_record,
+                queue_manager=queue_manager,
+                model_config=app_orchestration_config.model_config,
+                show_retrieve_source=app_orchestration_config.show_retrieve_source,
+                dataset_config=app_orchestration_config.dataset,
+                message=message,
+                inputs=inputs,
+                query=query,
+                user_id=application_generate_entity.user_id,
+                invoke_from=application_generate_entity.invoke_from,
+                memory=memory
+            )
+
+        # reorganize all inputs and template to prompt messages
+        # Include: prompt template, inputs, query(optional), files(optional)
+        #          memory(optional), external data, dataset context(optional)
+        prompt_messages, stop = self.originze_prompt_messages(
+            app_record=app_record,
+            model_config=app_orchestration_config.model_config,
+            prompt_template_entity=app_orchestration_config.prompt_template,
+            inputs=inputs,
+            files=files,
+            query=query,
+            context=context,
+            memory=memory
+        )
+
+        # check hosting moderation
+        hosting_moderation_result = self.check_hosting_moderation(
+            application_generate_entity=application_generate_entity,
+            queue_manager=queue_manager,
+            prompt_messages=prompt_messages
+        )
+
+        if hosting_moderation_result:
+            return
+
+        # Re-calculate the max tokens if sum(prompt_token +  max_tokens) over model token limit
+        self.recale_llm_max_tokens(
+            model_config=app_orchestration_config.model_config,
+            prompt_messages=prompt_messages
+        )
+
+        # Invoke model
+        model_instance = ModelInstance(
+            provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
+            model=app_orchestration_config.model_config.model
+        )
+
+        invoke_result = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            model_parameters=app_orchestration_config.model_config.parameters,
+            stop=stop,
+            stream=application_generate_entity.stream,
+            user=application_generate_entity.user_id,
+        )
+
+        # handle invoke result
+        self._handle_invoke_result(
+            invoke_result=invoke_result,
+            queue_manager=queue_manager,
+            stream=application_generate_entity.stream
+        )
+
+    def moderation_for_inputs(self, app_id: str,
+                              tenant_id: str,
+                              app_orchestration_config_entity: AppOrchestrationConfigEntity,
+                              inputs: dict,
+                              query: str) -> Tuple[bool, dict, str]:
+        """
+        Process sensitive_word_avoidance.
+        :param app_id: app id
+        :param tenant_id: tenant id
+        :param app_orchestration_config_entity: app orchestration config entity
+        :param inputs: inputs
+        :param query: query
+        :return:
+        """
+        moderation_feature = ModerationFeature()
+        return moderation_feature.check(
+            app_id=app_id,
+            tenant_id=tenant_id,
+            app_orchestration_config_entity=app_orchestration_config_entity,
+            inputs=inputs,
+            query=query,
+        )
+
+    def query_app_annotations_to_reply(self, app_record: App,
+                                       message: Message,
+                                       query: str,
+                                       user_id: str,
+                                       invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
+        """
+        Query app annotations to reply
+        :param app_record: app record
+        :param message: message
+        :param query: query
+        :param user_id: user id
+        :param invoke_from: invoke from
+        :return:
+        """
+        annotation_reply_feature = AnnotationReplyFeature()
+        return annotation_reply_feature.query(
+            app_record=app_record,
+            message=message,
+            query=query,
+            user_id=user_id,
+            invoke_from=invoke_from
+        )
+
+    def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
+                                                app_id: str,
+                                                external_data_tools: list[ExternalDataVariableEntity],
+                                                inputs: dict,
+                                                query: str) -> dict:
+        """
+        Fill in variable inputs from external data tools if exists.
+
+        :param tenant_id: workspace id
+        :param app_id: app id
+        :param external_data_tools: external data tools configs
+        :param inputs: the inputs
+        :param query: the query
+        :return: the filled inputs
+        """
+        external_data_fetch_feature = ExternalDataFetchFeature()
+        return external_data_fetch_feature.fetch(
+            tenant_id=tenant_id,
+            app_id=app_id,
+            external_data_tools=external_data_tools,
+            inputs=inputs,
+            query=query
+        )
+
+    def retrieve_dataset_context(self, tenant_id: str,
+                                 app_record: App,
+                                 queue_manager: ApplicationQueueManager,
+                                 model_config: ModelConfigEntity,
+                                 dataset_config: DatasetEntity,
+                                 show_retrieve_source: bool,
+                                 message: Message,
+                                 inputs: dict,
+                                 query: str,
+                                 user_id: str,
+                                 invoke_from: InvokeFrom,
+                                 memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
+        """
+        Retrieve dataset context
+        :param tenant_id: tenant id
+        :param app_record: app record
+        :param queue_manager: queue manager
+        :param model_config: model config
+        :param dataset_config: dataset config
+        :param show_retrieve_source: show retrieve source
+        :param message: message
+        :param inputs: inputs
+        :param query: query
+        :param user_id: user id
+        :param invoke_from: invoke from
+        :param memory: memory
+        :return:
+        """
+        hit_callback = DatasetIndexToolCallbackHandler(
+            queue_manager,
+            app_record.id,
+            message.id,
+            user_id,
+            invoke_from
+        )
+
+        if (app_record.mode == AppMode.COMPLETION.value and dataset_config
+                and dataset_config.retrieve_config.query_variable):
+            query = inputs.get(dataset_config.retrieve_config.query_variable, "")
+
+        dataset_retrieval = DatasetRetrievalFeature()
+        return dataset_retrieval.retrieve(
+            tenant_id=tenant_id,
+            model_config=model_config,
+            config=dataset_config,
+            query=query,
+            invoke_from=invoke_from,
+            show_retrieve_source=show_retrieve_source,
+            hit_callback=hit_callback,
+            memory=memory
+        )
+
+    def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity,
+                                 queue_manager: ApplicationQueueManager,
+                                 prompt_messages: list[PromptMessage]) -> bool:
+        """
+        Check hosting moderation
+        :param application_generate_entity: application generate entity
+        :param queue_manager: queue manager
+        :param prompt_messages: prompt messages
+        :return:
+        """
+        hosting_moderation_feature = HostingModerationFeature()
+        moderation_result = hosting_moderation_feature.check(
+            application_generate_entity=application_generate_entity,
+            prompt_messages=prompt_messages
+        )
+
+        if moderation_result:
+            self.direct_output(
+                queue_manager=queue_manager,
+                app_orchestration_config=application_generate_entity.app_orchestration_config_entity,
+                prompt_messages=prompt_messages,
+                text="I apologize for any confusion, " \
+                     "but I'm an AI assistant to be helpful, harmless, and honest.",
+                stream=application_generate_entity.stream
+            )
+
+        return moderation_result

+ 483 - 0
api/core/app_runner/generate_task_pipeline.py

@@ -0,0 +1,483 @@
+import json
+import logging
+import time
+from typing import Union, Generator, cast, Optional
+
+from pydantic import BaseModel
+
+from core.app_runner.moderation_handler import OutputModerationHandler, ModerationRule
+from core.entities.application_entities import ApplicationGenerateEntity
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.queue_entities import QueueErrorEvent, QueueStopEvent, QueueMessageEndEvent, \
+    QueueRetrieverResourcesEvent, QueueAgentThoughtEvent, QueuePingEvent, QueueMessageEvent, QueueMessageReplaceEvent, \
+    AnnotationReplyEvent
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, \
+    TextPromptMessageContent, PromptMessageContentType, ImagePromptMessageContent, PromptMessage
+from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.prompt_template import PromptTemplateParser
+from events.message_event import message_was_created
+from extensions.ext_database import db
+from models.model import Message, Conversation, MessageAgentThought
+from services.annotation_service import AppAnnotationService
+
+logger = logging.getLogger(__name__)
+
+
+class TaskState(BaseModel):
+    """
+    TaskState entity
+    """
+    llm_result: LLMResult
+    metadata: dict = {}
+
+
+class GenerateTaskPipeline:
+    """
+    GenerateTaskPipeline is a class that generate stream output and state management for Application.
+    """
+
+    def __init__(self, application_generate_entity: ApplicationGenerateEntity,
+                 queue_manager: ApplicationQueueManager,
+                 conversation: Conversation,
+                 message: Message) -> None:
+        """
+        Initialize GenerateTaskPipeline.
+        :param application_generate_entity: application generate entity
+        :param queue_manager: queue manager
+        :param conversation: conversation
+        :param message: message
+        """
+        self._application_generate_entity = application_generate_entity
+        self._queue_manager = queue_manager
+        self._conversation = conversation
+        self._message = message
+        self._task_state = TaskState(
+            llm_result=LLMResult(
+                model=self._application_generate_entity.app_orchestration_config_entity.model_config.model,
+                prompt_messages=[],
+                message=AssistantPromptMessage(content=""),
+                usage=LLMUsage.empty_usage()
+            )
+        )
+        self._start_at = time.perf_counter()
+        self._output_moderation_handler = self._init_output_moderation()
+
+    def process(self, stream: bool) -> Union[dict, Generator]:
+        """
+        Process generate task pipeline.
+        :return:
+        """
+        if stream:
+            return self._process_stream_response()
+        else:
+            return self._process_blocking_response()
+
+    def _process_blocking_response(self) -> dict:
+        """
+        Process blocking response.
+        :return:
+        """
+        for queue_message in self._queue_manager.listen():
+            event = queue_message.event
+
+            if isinstance(event, QueueErrorEvent):
+                raise self._handle_error(event)
+            elif isinstance(event, QueueRetrieverResourcesEvent):
+                self._task_state.metadata['retriever_resources'] = event.retriever_resources
+            elif isinstance(event, AnnotationReplyEvent):
+                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
+                if annotation:
+                    account = annotation.account
+                    self._task_state.metadata['annotation_reply'] = {
+                        'id': annotation.id,
+                        'account': {
+                            'id': annotation.account_id,
+                            'name': account.name if account else 'Dify user'
+                        }
+                    }
+
+                    self._task_state.llm_result.message.content = annotation.content
+            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
+                if isinstance(event, QueueMessageEndEvent):
+                    self._task_state.llm_result = event.llm_result
+                else:
+                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
+                    model = model_config.model
+                    model_type_instance = model_config.provider_model_bundle.model_type_instance
+                    model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+                    # calculate num tokens
+                    prompt_tokens = 0
+                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
+                        prompt_tokens = model_type_instance.get_num_tokens(
+                            model,
+                            model_config.credentials,
+                            self._task_state.llm_result.prompt_messages
+                        )
+
+                    completion_tokens = 0
+                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
+                        completion_tokens = model_type_instance.get_num_tokens(
+                            model,
+                            model_config.credentials,
+                            [self._task_state.llm_result.message]
+                        )
+
+                    credentials = model_config.credentials
+
+                    # transform usage
+                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
+                        model,
+                        credentials,
+                        prompt_tokens,
+                        completion_tokens
+                    )
+
+                # response moderation
+                if self._output_moderation_handler:
+                    self._output_moderation_handler.stop_thread()
+
+                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
+                        completion=self._task_state.llm_result.message.content,
+                        public_event=False
+                    )
+
+                # Save message
+                self._save_message(event.llm_result)
+
+                response = {
+                    'event': 'message',
+                    'task_id': self._application_generate_entity.task_id,
+                    'id': self._message.id,
+                    'mode': self._conversation.mode,
+                    'answer': event.llm_result.message.content,
+                    'metadata': {},
+                    'created_at': int(self._message.created_at.timestamp())
+                }
+
+                if self._conversation.mode == 'chat':
+                    response['conversation_id'] = self._conversation.id
+
+                if self._task_state.metadata:
+                    response['metadata'] = self._task_state.metadata
+
+                return response
+            else:
+                continue
+
+    def _process_stream_response(self) -> Generator:
+        """
+        Process stream response.
+        :return:
+        """
+        for message in self._queue_manager.listen():
+            event = message.event
+
+            if isinstance(event, QueueErrorEvent):
+                raise self._handle_error(event)
+            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
+                if isinstance(event, QueueMessageEndEvent):
+                    self._task_state.llm_result = event.llm_result
+                else:
+                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
+                    model = model_config.model
+                    model_type_instance = model_config.provider_model_bundle.model_type_instance
+                    model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+                    # calculate num tokens
+                    prompt_tokens = 0
+                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
+                        prompt_tokens = model_type_instance.get_num_tokens(
+                            model,
+                            model_config.credentials,
+                            self._task_state.llm_result.prompt_messages
+                        )
+
+                    completion_tokens = 0
+                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
+                        completion_tokens = model_type_instance.get_num_tokens(
+                            model,
+                            model_config.credentials,
+                            [self._task_state.llm_result.message]
+                        )
+
+                    credentials = model_config.credentials
+
+                    # transform usage
+                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
+                        model,
+                        credentials,
+                        prompt_tokens,
+                        completion_tokens
+                    )
+
+                # response moderation
+                if self._output_moderation_handler:
+                    self._output_moderation_handler.stop_thread()
+
+                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
+                        completion=self._task_state.llm_result.message.content,
+                        public_event=False
+                    )
+
+                    self._output_moderation_handler = None
+
+                    replace_response = {
+                        'event': 'message_replace',
+                        'task_id': self._application_generate_entity.task_id,
+                        'message_id': self._message.id,
+                        'answer': self._task_state.llm_result.message.content,
+                        'created_at': int(self._message.created_at.timestamp())
+                    }
+
+                    if self._conversation.mode == 'chat':
+                        replace_response['conversation_id'] = self._conversation.id
+
+                    yield self._yield_response(replace_response)
+
+                # Save message
+                self._save_message(self._task_state.llm_result)
+
+                response = {
+                    'event': 'message_end',
+                    'task_id': self._application_generate_entity.task_id,
+                    'id': self._message.id,
+                }
+
+                if self._conversation.mode == 'chat':
+                    response['conversation_id'] = self._conversation.id
+
+                if self._task_state.metadata:
+                    response['metadata'] = self._task_state.metadata
+
+                yield self._yield_response(response)
+            elif isinstance(event, QueueRetrieverResourcesEvent):
+                self._task_state.metadata['retriever_resources'] = event.retriever_resources
+            elif isinstance(event, AnnotationReplyEvent):
+                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
+                if annotation:
+                    account = annotation.account
+                    self._task_state.metadata['annotation_reply'] = {
+                        'id': annotation.id,
+                        'account': {
+                            'id': annotation.account_id,
+                            'name': account.name if account else 'Dify user'
+                        }
+                    }
+
+                    self._task_state.llm_result.message.content = annotation.content
+            elif isinstance(event, QueueAgentThoughtEvent):
+                agent_thought = (
+                    db.session.query(MessageAgentThought)
+                    .filter(MessageAgentThought.id == event.agent_thought_id)
+                    .first()
+                )
+
+                if agent_thought:
+                    response = {
+                        'event': 'agent_thought',
+                        'id': agent_thought.id,
+                        'task_id': self._application_generate_entity.task_id,
+                        'message_id': self._message.id,
+                        'position': agent_thought.position,
+                        'thought': agent_thought.thought,
+                        'tool': agent_thought.tool,
+                        'tool_input': agent_thought.tool_input,
+                        'created_at': int(self._message.created_at.timestamp())
+                    }
+
+                    if self._conversation.mode == 'chat':
+                        response['conversation_id'] = self._conversation.id
+
+                    yield self._yield_response(response)
+            elif isinstance(event, QueueMessageEvent):
+                chunk = event.chunk
+                delta_text = chunk.delta.message.content
+                if delta_text is None:
+                    continue
+
+                if not self._task_state.llm_result.prompt_messages:
+                    self._task_state.llm_result.prompt_messages = chunk.prompt_messages
+
+                if self._output_moderation_handler:
+                    if self._output_moderation_handler.should_direct_output():
+                        # stop subscribe new token when output moderation should direct output
+                        self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
+                        self._queue_manager.publish_chunk_message(LLMResultChunk(
+                            model=self._task_state.llm_result.model,
+                            prompt_messages=self._task_state.llm_result.prompt_messages,
+                            delta=LLMResultChunkDelta(
+                                index=0,
+                                message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
+                            )
+                        ))
+                        self._queue_manager.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION))
+                        continue
+                    else:
+                        self._output_moderation_handler.append_new_token(delta_text)
+
+                self._task_state.llm_result.message.content += delta_text
+                response = self._handle_chunk(delta_text)
+                yield self._yield_response(response)
+            elif isinstance(event, QueueMessageReplaceEvent):
+                response = {
+                    'event': 'message_replace',
+                    'task_id': self._application_generate_entity.task_id,
+                    'message_id': self._message.id,
+                    'answer': event.text,
+                    'created_at': int(self._message.created_at.timestamp())
+                }
+
+                if self._conversation.mode == 'chat':
+                    response['conversation_id'] = self._conversation.id
+
+                yield self._yield_response(response)
+            elif isinstance(event, QueuePingEvent):
+                yield "event: ping\n\n"
+            else:
+                continue
+
+    def _save_message(self, llm_result: LLMResult) -> None:
+        """
+        Save message.
+        :param llm_result: llm result
+        :return:
+        """
+        usage = llm_result.usage
+
+        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
+
+        self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
+        self._message.message_tokens = usage.prompt_tokens
+        self._message.message_unit_price = usage.prompt_unit_price
+        self._message.message_price_unit = usage.prompt_price_unit
+        self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
+            if llm_result.message.content else ''
+        self._message.answer_tokens = usage.completion_tokens
+        self._message.answer_unit_price = usage.completion_unit_price
+        self._message.answer_price_unit = usage.completion_price_unit
+        self._message.provider_response_latency = time.perf_counter() - self._start_at
+        self._message.total_price = usage.total_price
+
+        db.session.commit()
+
+        message_was_created.send(
+            self._message,
+            application_generate_entity=self._application_generate_entity,
+            conversation=self._conversation,
+            is_first_message=self._application_generate_entity.conversation_id is None,
+            extras=self._application_generate_entity.extras
+        )
+
+    def _handle_chunk(self, text: str) -> dict:
+        """
+        Handle completed event.
+        :param text: text
+        :return:
+        """
+        response = {
+            'event': 'message',
+            'id': self._message.id,
+            'task_id': self._application_generate_entity.task_id,
+            'message_id': self._message.id,
+            'answer': text,
+            'created_at': int(self._message.created_at.timestamp())
+        }
+
+        if self._conversation.mode == 'chat':
+            response['conversation_id'] = self._conversation.id
+
+        return response
+
+    def _handle_error(self, event: QueueErrorEvent) -> Exception:
+        """
+        Handle error event.
+        :param event: event
+        :return:
+        """
+        logger.debug("error: %s", event.error)
+        e = event.error
+
+        if isinstance(e, InvokeAuthorizationError):
+            return InvokeAuthorizationError('Incorrect API key provided')
+        elif isinstance(e, InvokeError) or isinstance(e, ValueError):
+            return e
+        else:
+            return Exception(e.description if getattr(e, 'description', None) is not None else str(e))
+
+    def _yield_response(self, response: dict) -> str:
+        """
+        Yield response.
+        :param response: response
+        :return:
+        """
+        return "data: " + json.dumps(response) + "\n\n"
+
+    def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]:
+        """
+        Prompt messages to prompt for saving.
+        :param prompt_messages: prompt messages
+        :return:
+        """
+        prompts = []
+        if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat':
+            for prompt_message in prompt_messages:
+                if prompt_message.role == PromptMessageRole.USER:
+                    role = 'user'
+                elif prompt_message.role == PromptMessageRole.ASSISTANT:
+                    role = 'assistant'
+                elif prompt_message.role == PromptMessageRole.SYSTEM:
+                    role = 'system'
+                else:
+                    continue
+
+                text = ''
+                files = []
+                if isinstance(prompt_message.content, list):
+                    for content in prompt_message.content:
+                        if content.type == PromptMessageContentType.TEXT:
+                            content = cast(TextPromptMessageContent, content)
+                            text += content.data
+                        else:
+                            content = cast(ImagePromptMessageContent, content)
+                            files.append({
+                                "type": 'image',
+                                "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
+                                "detail": content.detail.value
+                            })
+                else:
+                    text = prompt_message.content
+
+                prompts.append({
+                    "role": role,
+                    "text": text,
+                    "files": files
+                })
+        else:
+            prompts.append({
+                "role": 'user',
+                "text": prompt_messages[0].content
+            })
+
+        return prompts
+
+    def _init_output_moderation(self) -> Optional[OutputModerationHandler]:
+        """
+        Init output moderation.
+        :return:
+        """
+        app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity
+        sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance
+
+        if sensitive_word_avoidance:
+            return OutputModerationHandler(
+                tenant_id=self._application_generate_entity.tenant_id,
+                app_id=self._application_generate_entity.app_id,
+                rule=ModerationRule(
+                    type=sensitive_word_avoidance.type,
+                    config=sensitive_word_avoidance.config
+                ),
+                on_message_replace_func=self._queue_manager.publish_message_replace
+            )

+ 138 - 0
api/core/app_runner/moderation_handler.py

@@ -0,0 +1,138 @@
+import logging
+import threading
+import time
+from typing import Any, Optional, Dict
+
+from flask import current_app, Flask
+from pydantic import BaseModel
+
+from core.moderation.base import ModerationAction, ModerationOutputsResult
+from core.moderation.factory import ModerationFactory
+
+logger = logging.getLogger(__name__)
+
+
+class ModerationRule(BaseModel):
+    type: str
+    config: Dict[str, Any]
+
+
+class OutputModerationHandler(BaseModel):
+    DEFAULT_BUFFER_SIZE: int = 300
+
+    tenant_id: str
+    app_id: str
+
+    rule: ModerationRule
+    on_message_replace_func: Any
+
+    thread: Optional[threading.Thread] = None
+    thread_running: bool = True
+    buffer: str = ''
+    is_final_chunk: bool = False
+    final_output: Optional[str] = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def should_direct_output(self):
+        return self.final_output is not None
+
+    def get_final_output(self):
+        return self.final_output
+
+    def append_new_token(self, token: str):
+        self.buffer += token
+
+        if not self.thread:
+            self.thread = self.start_thread()
+
+    def moderation_completion(self, completion: str, public_event: bool = False) -> str:
+        self.buffer = completion
+        self.is_final_chunk = True
+
+        result = self.moderation(
+            tenant_id=self.tenant_id,
+            app_id=self.app_id,
+            moderation_buffer=completion
+        )
+
+        if not result or not result.flagged:
+            return completion
+
+        if result.action == ModerationAction.DIRECT_OUTPUT:
+            final_output = result.preset_response
+        else:
+            final_output = result.text
+
+        if public_event:
+            self.on_message_replace_func(final_output)
+
+        return final_output
+
+    def start_thread(self) -> threading.Thread:
+        buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
+        thread = threading.Thread(target=self.worker, kwargs={
+            'flask_app': current_app._get_current_object(),
+            'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
+        })
+
+        thread.start()
+
+        return thread
+
+    def stop_thread(self):
+        if self.thread and self.thread.is_alive():
+            self.thread_running = False
+
+    def worker(self, flask_app: Flask, buffer_size: int):
+        with flask_app.app_context():
+            current_length = 0
+            while self.thread_running:
+                moderation_buffer = self.buffer
+                buffer_length = len(moderation_buffer)
+                if not self.is_final_chunk:
+                    chunk_length = buffer_length - current_length
+                    if 0 <= chunk_length < buffer_size:
+                        time.sleep(1)
+                        continue
+
+                current_length = buffer_length
+
+                result = self.moderation(
+                    tenant_id=self.tenant_id,
+                    app_id=self.app_id,
+                    moderation_buffer=moderation_buffer
+                )
+
+                if not result or not result.flagged:
+                    continue
+
+                if result.action == ModerationAction.DIRECT_OUTPUT:
+                    final_output = result.preset_response
+                    self.final_output = final_output
+                else:
+                    final_output = result.text + self.buffer[len(moderation_buffer):]
+
+                # trigger replace event
+                if self.thread_running:
+                    self.on_message_replace_func(final_output)
+
+                if result.action == ModerationAction.DIRECT_OUTPUT:
+                    break
+
+    def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
+        try:
+            moderation_factory = ModerationFactory(
+                name=self.rule.type,
+                app_id=app_id,
+                tenant_id=tenant_id,
+                config=self.rule.config
+            )
+
+            result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
+            return result
+        except Exception as e:
+            logger.error("Moderation Output error: %s", e)
+
+        return None

+ 655 - 0
api/core/application_manager.py

@@ -0,0 +1,655 @@
+import json
+import logging
+import threading
+import uuid
+from typing import cast, Optional, Any, Union, Generator, Tuple
+
+from flask import Flask, current_app
+from pydantic import ValidationError
+
+from core.app_runner.agent_app_runner import AgentApplicationRunner
+from core.app_runner.basic_app_runner import BasicApplicationRunner
+from core.app_runner.generate_task_pipeline import GenerateTaskPipeline
+from core.entities.application_entities import ApplicationGenerateEntity, AppOrchestrationConfigEntity, \
+    ModelConfigEntity, PromptTemplateEntity, AdvancedChatPromptTemplateEntity, \
+    AdvancedCompletionPromptTemplateEntity, ExternalDataVariableEntity, DatasetEntity, DatasetRetrieveConfigEntity, \
+    AgentEntity, AgentToolEntity, FileUploadEntity, SensitiveWordAvoidanceEntity, InvokeFrom
+from core.entities.model_entities import ModelStatus
+from core.file.file_obj import FileObj
+from core.errors.error import QuotaExceededError, ProviderTokenNotInitError, ModelCurrentlyNotSupportError
+from core.model_runtime.entities.message_entities import PromptMessageRole
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.prompt_template import PromptTemplateParser
+from core.provider_manager import ProviderManager
+from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException
+from extensions.ext_database import db
+from models.account import Account
+from models.model import EndUser, Conversation, Message, MessageFile, App
+
+logger = logging.getLogger(__name__)
+
+
+class ApplicationManager:
+    """
+    This class is responsible for managing application
+    """
+
+    def generate(self, tenant_id: str,
+                 app_id: str,
+                 app_model_config_id: str,
+                 app_model_config_dict: dict,
+                 app_model_config_override: bool,
+                 user: Union[Account, EndUser],
+                 invoke_from: InvokeFrom,
+                 inputs: dict[str, str],
+                 query: Optional[str] = None,
+                 files: Optional[list[FileObj]] = None,
+                 conversation: Optional[Conversation] = None,
+                 stream: bool = False,
+                 extras: Optional[dict[str, Any]] = None) \
+            -> Union[dict, Generator]:
+        """
+        Generate App response.
+
+        :param tenant_id: workspace ID
+        :param app_id: app ID
+        :param app_model_config_id: app model config id
+        :param app_model_config_dict: app model config dict
+        :param app_model_config_override: app model config override
+        :param user: account or end user
+        :param invoke_from: invoke from source
+        :param inputs: inputs
+        :param query: query
+        :param files: file obj list
+        :param conversation: conversation
+        :param stream: is stream
+        :param extras: extras
+        """
+        # init task id
+        task_id = str(uuid.uuid4())
+
+        # init application generate entity
+        application_generate_entity = ApplicationGenerateEntity(
+            task_id=task_id,
+            tenant_id=tenant_id,
+            app_id=app_id,
+            app_model_config_id=app_model_config_id,
+            app_model_config_dict=app_model_config_dict,
+            app_orchestration_config_entity=self._convert_from_app_model_config_dict(
+                tenant_id=tenant_id,
+                app_model_config_dict=app_model_config_dict
+            ),
+            app_model_config_override=app_model_config_override,
+            conversation_id=conversation.id if conversation else None,
+            inputs=conversation.inputs if conversation else inputs,
+            query=query.replace('\x00', '') if query else None,
+            files=files if files else [],
+            user_id=user.id,
+            stream=stream,
+            invoke_from=invoke_from,
+            extras=extras
+        )
+
+        # init generate records
+        (
+            conversation,
+            message
+        ) = self._init_generate_records(application_generate_entity)
+
+        # init queue manager
+        queue_manager = ApplicationQueueManager(
+            task_id=application_generate_entity.task_id,
+            user_id=application_generate_entity.user_id,
+            invoke_from=application_generate_entity.invoke_from,
+            conversation_id=conversation.id,
+            app_mode=conversation.mode,
+            message_id=message.id
+        )
+
+        # new thread
+        worker_thread = threading.Thread(target=self._generate_worker, kwargs={
+            'flask_app': current_app._get_current_object(),
+            'application_generate_entity': application_generate_entity,
+            'queue_manager': queue_manager,
+            'conversation_id': conversation.id,
+            'message_id': message.id,
+        })
+
+        worker_thread.start()
+
+        # return response or stream generator
+        return self._handle_response(
+            application_generate_entity=application_generate_entity,
+            queue_manager=queue_manager,
+            conversation=conversation,
+            message=message,
+            stream=stream
+        )
+
+    def _generate_worker(self, flask_app: Flask,
+                         application_generate_entity: ApplicationGenerateEntity,
+                         queue_manager: ApplicationQueueManager,
+                         conversation_id: str,
+                         message_id: str) -> None:
+        """
+        Generate worker in a new thread.
+        :param flask_app: Flask app
+        :param application_generate_entity: application generate entity
+        :param queue_manager: queue manager
+        :param conversation_id: conversation ID
+        :param message_id: message ID
+        :return:
+        """
+        with flask_app.app_context():
+            try:
+                # get conversation and message
+                conversation = self._get_conversation(conversation_id)
+                message = self._get_message(message_id)
+
+                if application_generate_entity.app_orchestration_config_entity.agent:
+                    # agent app
+                    runner = AgentApplicationRunner()
+                    runner.run(
+                        application_generate_entity=application_generate_entity,
+                        queue_manager=queue_manager,
+                        conversation=conversation,
+                        message=message
+                    )
+                else:
+                    # basic app
+                    runner = BasicApplicationRunner()
+                    runner.run(
+                        application_generate_entity=application_generate_entity,
+                        queue_manager=queue_manager,
+                        conversation=conversation,
+                        message=message
+                    )
+            except ConversationTaskStoppedException:
+                pass
+            except InvokeAuthorizationError:
+                queue_manager.publish_error(InvokeAuthorizationError('Incorrect API key provided'))
+            except ValidationError as e:
+                logger.exception("Validation Error when generating")
+                queue_manager.publish_error(e)
+            except (ValueError, InvokeError) as e:
+                queue_manager.publish_error(e)
+            except Exception as e:
+                logger.exception("Unknown Error when generating")
+                queue_manager.publish_error(e)
+            finally:
+                db.session.remove()
+
+    def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
+                         queue_manager: ApplicationQueueManager,
+                         conversation: Conversation,
+                         message: Message,
+                         stream: bool = False) -> Union[dict, Generator]:
+        """
+        Handle response.
+        :param application_generate_entity: application generate entity
+        :param queue_manager: queue manager
+        :param conversation: conversation
+        :param message: message
+        :param stream: is stream
+        :return:
+        """
+        # init generate task pipeline
+        generate_task_pipeline = GenerateTaskPipeline(
+            application_generate_entity=application_generate_entity,
+            queue_manager=queue_manager,
+            conversation=conversation,
+            message=message
+        )
+
+        try:
+            return generate_task_pipeline.process(stream=stream)
+        except ValueError as e:
+            if e.args[0] == "I/O operation on closed file.":  # ignore this error
+                raise ConversationTaskStoppedException()
+            else:
+                logger.exception(e)
+                raise e
+        finally:
+            db.session.remove()
+
+    def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
+            -> AppOrchestrationConfigEntity:
+        """
+        Convert app model config dict to entity.
+        :param tenant_id: tenant ID
+        :param app_model_config_dict: app model config dict
+        :raises ProviderTokenNotInitError: provider token not init error
+        :return: app orchestration config entity
+        """
+        properties = {}
+
+        copy_app_model_config_dict = app_model_config_dict.copy()
+
+        provider_manager = ProviderManager()
+        provider_model_bundle = provider_manager.get_provider_model_bundle(
+            tenant_id=tenant_id,
+            provider=copy_app_model_config_dict['model']['provider'],
+            model_type=ModelType.LLM
+        )
+
+        provider_name = provider_model_bundle.configuration.provider.provider
+        model_name = copy_app_model_config_dict['model']['name']
+
+        model_type_instance = provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        # check model credentials
+        model_credentials = provider_model_bundle.configuration.get_current_credentials(
+            model_type=ModelType.LLM,
+            model=copy_app_model_config_dict['model']['name']
+        )
+
+        if model_credentials is None:
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
+
+        # check model
+        provider_model = provider_model_bundle.configuration.get_provider_model(
+            model=copy_app_model_config_dict['model']['name'],
+            model_type=ModelType.LLM
+        )
+
+        if provider_model is None:
+            model_name = copy_app_model_config_dict['model']['name']
+            raise ValueError(f"Model {model_name} not exist.")
+
+        if provider_model.status == ModelStatus.NO_CONFIGURE:
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
+        elif provider_model.status == ModelStatus.NO_PERMISSION:
+            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
+        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
+            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
+
+        # model config
+        completion_params = copy_app_model_config_dict['model'].get('completion_params')
+        stop = []
+        if 'stop' in completion_params:
+            stop = completion_params['stop']
+            del completion_params['stop']
+
+        # get model mode
+        model_mode = copy_app_model_config_dict['model'].get('mode')
+        if not model_mode:
+            mode_enum = model_type_instance.get_model_mode(
+                model=copy_app_model_config_dict['model']['name'],
+                credentials=model_credentials
+            )
+
+            model_mode = mode_enum.value
+
+        model_schema = model_type_instance.get_model_schema(
+            copy_app_model_config_dict['model']['name'],
+            model_credentials
+        )
+
+        if not model_schema:
+            raise ValueError(f"Model {model_name} not exist.")
+
+        properties['model_config'] = ModelConfigEntity(
+            provider=copy_app_model_config_dict['model']['provider'],
+            model=copy_app_model_config_dict['model']['name'],
+            model_schema=model_schema,
+            mode=model_mode,
+            provider_model_bundle=provider_model_bundle,
+            credentials=model_credentials,
+            parameters=completion_params,
+            stop=stop,
+        )
+
+        # prompt template
+        prompt_type = PromptTemplateEntity.PromptType.value_of(copy_app_model_config_dict['prompt_type'])
+        if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
+            simple_prompt_template = copy_app_model_config_dict.get("pre_prompt", "")
+            properties['prompt_template'] = PromptTemplateEntity(
+                prompt_type=prompt_type,
+                simple_prompt_template=simple_prompt_template
+            )
+        else:
+            advanced_chat_prompt_template = None
+            chat_prompt_config = copy_app_model_config_dict.get("chat_prompt_config", {})
+            if chat_prompt_config:
+                chat_prompt_messages = []
+                for message in chat_prompt_config.get("prompt", []):
+                    chat_prompt_messages.append({
+                        "text": message["text"],
+                        "role": PromptMessageRole.value_of(message["role"])
+                    })
+
+                advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
+                    messages=chat_prompt_messages
+                )
+
+            advanced_completion_prompt_template = None
+            completion_prompt_config = copy_app_model_config_dict.get("completion_prompt_config", {})
+            if completion_prompt_config:
+                completion_prompt_template_params = {
+                    'prompt': completion_prompt_config['prompt']['text'],
+                }
+
+                if 'conversation_histories_role' in completion_prompt_config:
+                    completion_prompt_template_params['role_prefix'] = {
+                        'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
+                        'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
+                    }
+
+                advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
+                    **completion_prompt_template_params
+                )
+
+            properties['prompt_template'] = PromptTemplateEntity(
+                prompt_type=prompt_type,
+                advanced_chat_prompt_template=advanced_chat_prompt_template,
+                advanced_completion_prompt_template=advanced_completion_prompt_template
+            )
+
+        # external data variables
+        properties['external_data_variables'] = []
+        external_data_tools = copy_app_model_config_dict.get('external_data_tools', [])
+        for external_data_tool in external_data_tools:
+            if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
+                continue
+
+            properties['external_data_variables'].append(
+                ExternalDataVariableEntity(
+                    variable=external_data_tool['variable'],
+                    type=external_data_tool['type'],
+                    config=external_data_tool['config']
+                )
+            )
+
+        # show retrieve source
+        show_retrieve_source = False
+        retriever_resource_dict = copy_app_model_config_dict.get('retriever_resource')
+        if retriever_resource_dict:
+            if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']:
+                show_retrieve_source = True
+
+        properties['show_retrieve_source'] = show_retrieve_source
+
+        if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \
+                and 'enabled' in copy_app_model_config_dict['agent_mode'] and copy_app_model_config_dict['agent_mode'][
+            'enabled']:
+            agent_dict = copy_app_model_config_dict.get('agent_mode')
+            if agent_dict['strategy'] in ['router', 'react_router']:
+                dataset_ids = []
+                for tool in agent_dict.get('tools', []):
+                    key = list(tool.keys())[0]
+
+                    if key != 'dataset':
+                        continue
+
+                    tool_item = tool[key]
+
+                    if "enabled" not in tool_item or not tool_item["enabled"]:
+                        continue
+
+                    dataset_id = tool_item['id']
+                    dataset_ids.append(dataset_id)
+
+                dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'})
+                query_variable = copy_app_model_config_dict.get('dataset_query_variable')
+                if dataset_configs['retrieval_model'] == 'single':
+                    properties['dataset'] = DatasetEntity(
+                        dataset_ids=dataset_ids,
+                        retrieve_config=DatasetRetrieveConfigEntity(
+                            query_variable=query_variable,
+                            retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
+                                dataset_configs['retrieval_model']
+                            ),
+                            single_strategy=agent_dict['strategy']
+                        )
+                    )
+                else:
+                    properties['dataset'] = DatasetEntity(
+                        dataset_ids=dataset_ids,
+                        retrieve_config=DatasetRetrieveConfigEntity(
+                            query_variable=query_variable,
+                            retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
+                                dataset_configs['retrieval_model']
+                            ),
+                            top_k=dataset_configs.get('top_k'),
+                            score_threshold=dataset_configs.get('score_threshold'),
+                            reranking_model=dataset_configs.get('reranking_model')
+                        )
+                    )
+            else:
+                if agent_dict['strategy'] == 'react':
+                    strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
+                else:
+                    strategy = AgentEntity.Strategy.FUNCTION_CALLING
+
+                agent_tools = []
+                for tool in agent_dict.get('tools', []):
+                    key = list(tool.keys())[0]
+                    tool_item = tool[key]
+
+                    agent_tool_properties = {
+                        "tool_id": key
+                    }
+
+                    if "enabled" not in tool_item or not tool_item["enabled"]:
+                        continue
+
+                    agent_tool_properties["config"] = tool_item
+                    agent_tools.append(AgentToolEntity(**agent_tool_properties))
+
+                properties['agent'] = AgentEntity(
+                    provider=properties['model_config'].provider,
+                    model=properties['model_config'].model,
+                    strategy=strategy,
+                    tools=agent_tools
+                )
+
+        # file upload
+        file_upload_dict = copy_app_model_config_dict.get('file_upload')
+        if file_upload_dict:
+            if 'image' in file_upload_dict and file_upload_dict['image']:
+                if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
+                    properties['file_upload'] = FileUploadEntity(
+                        image_config={
+                            'number_limits': file_upload_dict['image']['number_limits'],
+                            'detail': file_upload_dict['image']['detail'],
+                            'transfer_methods': file_upload_dict['image']['transfer_methods']
+                        }
+                    )
+
+        # opening statement
+        properties['opening_statement'] = copy_app_model_config_dict.get('opening_statement')
+
+        # suggested questions after answer
+        suggested_questions_after_answer_dict = copy_app_model_config_dict.get('suggested_questions_after_answer')
+        if suggested_questions_after_answer_dict:
+            if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']:
+                properties['suggested_questions_after_answer'] = True
+
+        # more like this
+        more_like_this_dict = copy_app_model_config_dict.get('more_like_this')
+        if more_like_this_dict:
+            if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
+                properties['more_like_this'] = copy_app_model_config_dict.get('opening_statement')
+
+        # speech to text
+        speech_to_text_dict = copy_app_model_config_dict.get('speech_to_text')
+        if speech_to_text_dict:
+            if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
+                properties['speech_to_text'] = True
+
+        # sensitive word avoidance
+        sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
+        if sensitive_word_avoidance_dict:
+            if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']:
+                properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity(
+                    type=sensitive_word_avoidance_dict.get('type'),
+                    config=sensitive_word_avoidance_dict.get('config'),
+                )
+
+        return AppOrchestrationConfigEntity(**properties)
+
+    def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
+            -> Tuple[Conversation, Message]:
+        """
+        Initialize generate records
+        :param application_generate_entity: application generate entity
+        :return:
+        """
+        app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
+
+        model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+        model_schema = model_type_instance.get_model_schema(
+            model=app_orchestration_config_entity.model_config.model,
+            credentials=app_orchestration_config_entity.model_config.credentials
+        )
+
+        app_record = (db.session.query(App)
+                      .filter(App.id == application_generate_entity.app_id).first())
+
+        app_mode = app_record.mode
+
+        # get from source
+        end_user_id = None
+        account_id = None
+        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+            from_source = 'api'
+            end_user_id = application_generate_entity.user_id
+        else:
+            from_source = 'console'
+            account_id = application_generate_entity.user_id
+
+        override_model_configs = None
+        if application_generate_entity.app_model_config_override:
+            override_model_configs = application_generate_entity.app_model_config_dict
+
+        introduction = ''
+        if app_mode == 'chat':
+            # get conversation introduction
+            introduction = self._get_conversation_introduction(application_generate_entity)
+
+        if not application_generate_entity.conversation_id:
+            conversation = Conversation(
+                app_id=app_record.id,
+                app_model_config_id=application_generate_entity.app_model_config_id,
+                model_provider=app_orchestration_config_entity.model_config.provider,
+                model_id=app_orchestration_config_entity.model_config.model,
+                override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
+                mode=app_mode,
+                name='New conversation',
+                inputs=application_generate_entity.inputs,
+                introduction=introduction,
+                system_instruction="",
+                system_instruction_tokens=0,
+                status='normal',
+                from_source=from_source,
+                from_end_user_id=end_user_id,
+                from_account_id=account_id,
+            )
+
+            db.session.add(conversation)
+            db.session.commit()
+        else:
+            conversation = (
+                db.session.query(Conversation)
+                .filter(
+                    Conversation.id == application_generate_entity.conversation_id,
+                    Conversation.app_id == app_record.id
+                ).first()
+            )
+
+        currency = model_schema.pricing.currency if model_schema.pricing else 'USD'
+
+        message = Message(
+            app_id=app_record.id,
+            model_provider=app_orchestration_config_entity.model_config.provider,
+            model_id=app_orchestration_config_entity.model_config.model,
+            override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
+            conversation_id=conversation.id,
+            inputs=application_generate_entity.inputs,
+            query=application_generate_entity.query or "",
+            message="",
+            message_tokens=0,
+            message_unit_price=0,
+            message_price_unit=0,
+            answer="",
+            answer_tokens=0,
+            answer_unit_price=0,
+            answer_price_unit=0,
+            provider_response_latency=0,
+            total_price=0,
+            currency=currency,
+            from_source=from_source,
+            from_end_user_id=end_user_id,
+            from_account_id=account_id,
+            agent_based=app_orchestration_config_entity.agent is not None
+        )
+
+        db.session.add(message)
+        db.session.commit()
+
+        for file in application_generate_entity.files:
+            message_file = MessageFile(
+                message_id=message.id,
+                type=file.type.value,
+                transfer_method=file.transfer_method.value,
+                url=file.url,
+                upload_file_id=file.upload_file_id,
+                created_by_role=('account' if account_id else 'end_user'),
+                created_by=account_id or end_user_id,
+            )
+            db.session.add(message_file)
+            db.session.commit()
+
+        return conversation, message
+
+    def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str:
+        """
+        Get conversation introduction
+        :param application_generate_entity: application generate entity
+        :return: conversation introduction
+        """
+        app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
+        introduction = app_orchestration_config_entity.opening_statement
+
+        if introduction:
+            try:
+                inputs = application_generate_entity.inputs
+                prompt_template = PromptTemplateParser(template=introduction)
+                prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+                introduction = prompt_template.format(prompt_inputs)
+            except KeyError:
+                pass
+
+        return introduction
+
+    def _get_conversation(self, conversation_id: str) -> Conversation:
+        """
+        Get conversation by conversation id
+        :param conversation_id: conversation id
+        :return: conversation
+        """
+        conversation = (
+            db.session.query(Conversation)
+            .filter(Conversation.id == conversation_id)
+            .first()
+        )
+
+        return conversation
+
+    def _get_message(self, message_id: str) -> Message:
+        """
+        Get message by message id
+        :param message_id: message id
+        :return: message
+        """
+        message = (
+            db.session.query(Message)
+            .filter(Message.id == message_id)
+            .first()
+        )
+
+        return message

+ 228 - 0
api/core/application_queue_manager.py

@@ -0,0 +1,228 @@
+import queue
+import time
+from typing import Generator, Any
+
+from sqlalchemy.orm import DeclarativeMeta
+
+from core.entities.application_entities import InvokeFrom
+from core.entities.queue_entities import QueueStopEvent, AppQueueEvent, QueuePingEvent, QueueErrorEvent, \
+    QueueAgentThoughtEvent, QueueMessageEndEvent, QueueRetrieverResourcesEvent, QueueMessageReplaceEvent, \
+    QueueMessageEvent, QueueMessage, AnnotationReplyEvent
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
+from extensions.ext_redis import redis_client
+from models.model import MessageAgentThought
+
+
+class ApplicationQueueManager:
+    def __init__(self, task_id: str,
+                 user_id: str,
+                 invoke_from: InvokeFrom,
+                 conversation_id: str,
+                 app_mode: str,
+                 message_id: str) -> None:
+        if not user_id:
+            raise ValueError("user is required")
+
+        self._task_id = task_id
+        self._user_id = user_id
+        self._invoke_from = invoke_from
+        self._conversation_id = str(conversation_id)
+        self._app_mode = app_mode
+        self._message_id = str(message_id)
+
+        user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
+        redis_client.setex(ApplicationQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}")
+
+        q = queue.Queue()
+
+        self._q = q
+
+    def listen(self) -> Generator:
+        """
+        Listen to queue
+        :return:
+        """
+        # wait for 10 minutes to stop listen
+        listen_timeout = 600
+        start_time = time.time()
+        last_ping_time = 0
+
+        while True:
+            try:
+                message = self._q.get(timeout=1)
+                if message is None:
+                    break
+
+                yield message
+            except queue.Empty:
+                continue
+            finally:
+                elapsed_time = time.time() - start_time
+                if elapsed_time >= listen_timeout or self._is_stopped():
+                    # publish two messages to make sure the client can receive the stop signal
+                    # and stop listening after the stop signal processed
+                    self.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))
+                    self.stop_listen()
+
+                if elapsed_time // 10 > last_ping_time:
+                    self.publish(QueuePingEvent())
+                    last_ping_time = elapsed_time // 10
+
+    def stop_listen(self) -> None:
+        """
+        Stop listen to queue
+        :return:
+        """
+        self._q.put(None)
+
+    def publish_chunk_message(self, chunk: LLMResultChunk) -> None:
+        """
+        Publish chunk message to channel
+
+        :param chunk: chunk
+        :return:
+        """
+        self.publish(QueueMessageEvent(
+            chunk=chunk
+        ))
+
+    def publish_message_replace(self, text: str) -> None:
+        """
+        Publish message replace
+        :param text: text
+        :return:
+        """
+        self.publish(QueueMessageReplaceEvent(
+            text=text
+        ))
+
+    def publish_retriever_resources(self, retriever_resources: list[dict]) -> None:
+        """
+        Publish retriever resources
+        :return:
+        """
+        self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources))
+
+    def publish_annotation_reply(self, message_annotation_id: str) -> None:
+        """
+        Publish annotation reply
+        :param message_annotation_id: message annotation id
+        :return:
+        """
+        self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id))
+
+    def publish_message_end(self, llm_result: LLMResult) -> None:
+        """
+        Publish message end
+        :param llm_result: llm result
+        :return:
+        """
+        self.publish(QueueMessageEndEvent(llm_result=llm_result))
+        self.stop_listen()
+
+    def publish_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
+        """
+        Publish agent thought
+        :param message_agent_thought: message agent thought
+        :return:
+        """
+        self.publish(QueueAgentThoughtEvent(
+            agent_thought_id=message_agent_thought.id
+        ))
+
+    def publish_error(self, e) -> None:
+        """
+        Publish error
+        :param e: error
+        :return:
+        """
+        self.publish(QueueErrorEvent(
+            error=e
+        ))
+        self.stop_listen()
+
+    def publish(self, event: AppQueueEvent) -> None:
+        """
+        Publish event to queue
+        :param event:
+        :return:
+        """
+        self._check_for_sqlalchemy_models(event.dict())
+
+        message = QueueMessage(
+            task_id=self._task_id,
+            message_id=self._message_id,
+            conversation_id=self._conversation_id,
+            app_mode=self._app_mode,
+            event=event
+        )
+
+        self._q.put(message)
+
+        if isinstance(event, QueueStopEvent):
+            self.stop_listen()
+
+    @classmethod
+    def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
+        """
+        Set task stop flag
+        :return:
+        """
+        result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
+        if result is None:
+            return
+
+        user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
+        if result != f"{user_prefix}-{user_id}":
+            return
+
+        stopped_cache_key = cls._generate_stopped_cache_key(task_id)
+        redis_client.setex(stopped_cache_key, 600, 1)
+
+    def _is_stopped(self) -> bool:
+        """
+        Check if task is stopped
+        :return:
+        """
+        stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id)
+        result = redis_client.get(stopped_cache_key)
+        if result is not None:
+            redis_client.delete(stopped_cache_key)
+            return True
+
+        return False
+
+    @classmethod
+    def _generate_task_belong_cache_key(cls, task_id: str) -> str:
+        """
+        Generate task belong cache key
+        :param task_id: task id
+        :return:
+        """
+        return f"generate_task_belong:{task_id}"
+
+    @classmethod
+    def _generate_stopped_cache_key(cls, task_id: str) -> str:
+        """
+        Generate stopped cache key
+        :param task_id: task id
+        :return:
+        """
+        return f"generate_task_stopped:{task_id}"
+
+    def _check_for_sqlalchemy_models(self, data: Any):
+        # from entity to dict or list
+        if isinstance(data, dict):
+            for key, value in data.items():
+                self._check_for_sqlalchemy_models(value)
+        elif isinstance(data, list):
+            for item in data:
+                self._check_for_sqlalchemy_models(item)
+        else:
+            if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
+                raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
+                                "that cause thread safety issues is not allowed.")
+
+
+class ConversationTaskStoppedException(Exception):
+    pass

+ 110 - 65
api/core/callback_handler/agent_loop_gather_callback_handler.py

@@ -2,30 +2,40 @@ import json
 import logging
 import time
 
-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, List, Union, Optional, cast
 
 from langchain.agents import openai_functions_agent, openai_functions_multi_agent
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
 
+from core.application_queue_manager import ApplicationQueueManager
 from core.callback_handler.entity.agent_loop import AgentLoop
-from core.conversation_message_task import ConversationMessageTask
-from core.model_providers.models.entity.message import PromptMessage
-from core.model_providers.models.llm.base import BaseLLM
+from core.entities.application_entities import ModelConfigEntity
+from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
+from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessage
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from extensions.ext_database import db
+from models.model import MessageChain, MessageAgentThought, Message
 
 
 class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
     raise_error: bool = True
 
-    def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
+    def __init__(self, model_config: ModelConfigEntity,
+                 queue_manager: ApplicationQueueManager,
+                 message: Message,
+                 message_chain: MessageChain) -> None:
         """Initialize callback handler."""
-        self.model_instance = model_instance
-        self.conversation_message_task = conversation_message_task
+        self.model_config = model_config
+        self.queue_manager = queue_manager
+        self.message = message
+        self.message_chain = message_chain
+        model_type_instance = self.model_config.provider_model_bundle.model_type_instance
+        self.model_type_instance = cast(LargeLanguageModel, model_type_instance)
         self._agent_loops = []
         self._current_loop = None
         self._message_agent_thought = None
-        self.current_chain = None
 
     @property
     def agent_loops(self) -> List[AgentLoop]:
@@ -46,66 +56,61 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         """Whether to ignore chain callbacks."""
         return True
 
-    def on_chat_model_start(
-            self,
-            serialized: Dict[str, Any],
-            messages: List[List[BaseMessage]],
-            **kwargs: Any
-    ) -> Any:
-        if not self._current_loop:
-            # Agent start with a LLM query
-            self._current_loop = AgentLoop(
-                position=len(self._agent_loops) + 1,
-                prompt="\n".join([message.content for message in messages[0]]),
-                status='llm_started',
-                started_at=time.perf_counter()
-            )
-
-    def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
-    ) -> None:
-        """Print out the prompts."""
-        # serialized={'name': 'OpenAI'}
-        # prompts=['Answer the following questions...\nThought:']
-        # kwargs={}
+    def on_llm_before_invoke(self, prompt_messages: list[PromptMessage]) -> None:
         if not self._current_loop:
             # Agent start with a LLM query
             self._current_loop = AgentLoop(
                 position=len(self._agent_loops) + 1,
-                prompt=prompts[0],
+                prompt="\n".join([prompt_message.content for prompt_message in prompt_messages]),
                 status='llm_started',
                 started_at=time.perf_counter()
             )
 
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        """Do nothing."""
-        # kwargs={}
+    def on_llm_after_invoke(self, result: RuntimeLLMResult) -> None:
         if self._current_loop and self._current_loop.status == 'llm_started':
             self._current_loop.status = 'llm_end'
-            if response.llm_output:
-                self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
+            if result.usage:
+                self._current_loop.prompt_tokens = result.usage.prompt_tokens
             else:
-                self._current_loop.prompt_tokens = self.model_instance.get_num_tokens(
-                    [PromptMessage(content=self._current_loop.prompt)]
+                self._current_loop.prompt_tokens = self.model_type_instance.get_num_tokens(
+                    model=self.model_config.model,
+                    credentials=self.model_config.credentials,
+                    prompt_messages=[UserPromptMessage(content=self._current_loop.prompt)]
                 )
-            completion_generation = response.generations[0][0]
-            if isinstance(completion_generation, ChatGeneration):
-                completion_message = completion_generation.message
-                if 'function_call' in completion_message.additional_kwargs:
-                    self._current_loop.completion \
-                        = json.dumps({'function_call': completion_message.additional_kwargs['function_call']})
-                else:
-                    self._current_loop.completion = response.generations[0][0].text
+
+            completion_message = result.message
+            if completion_message.tool_calls:
+                self._current_loop.completion \
+                    = json.dumps({'function_call': completion_message.tool_calls})
             else:
-                self._current_loop.completion = completion_generation.text
+                self._current_loop.completion = completion_message.content
 
-            if response.llm_output:
-                self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
+            if result.usage:
+                self._current_loop.completion_tokens = result.usage.completion_tokens
             else:
-                self._current_loop.completion_tokens = self.model_instance.get_num_tokens(
-                    [PromptMessage(content=self._current_loop.completion)]
+                self._current_loop.completion_tokens = self.model_type_instance.get_num_tokens(
+                    model=self.model_config.model,
+                    credentials=self.model_config.credentials,
+                    prompt_messages=[AssistantPromptMessage(content=self._current_loop.completion)]
                 )
 
+    def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            **kwargs: Any
+    ) -> Any:
+        pass
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        pass
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        """Do nothing."""
+        pass
+
     def on_llm_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
@@ -150,10 +155,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             if completion is not None:
                 self._current_loop.completion = completion
 
-            self._message_agent_thought = self.conversation_message_task.on_agent_start(
-                self.current_chain,
-                self._current_loop
-            )
+            self._message_agent_thought = self._init_agent_thought()
 
     def on_tool_end(
         self,
@@ -176,9 +178,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.completed_at = time.perf_counter()
             self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
 
-            self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_instance, self._current_loop
-            )
+            self._complete_agent_thought(self._message_agent_thought)
 
             self._agent_loops.append(self._current_loop)
             self._current_loop = None
@@ -202,17 +202,62 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.completed_at = time.perf_counter()
             self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
             self._current_loop.thought = '[DONE]'
-            self._message_agent_thought = self.conversation_message_task.on_agent_start(
-                self.current_chain,
-                self._current_loop
-            )
+            self._message_agent_thought = self._init_agent_thought()
 
-            self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_instance, self._current_loop
-            )
+            self._complete_agent_thought(self._message_agent_thought)
 
             self._agent_loops.append(self._current_loop)
             self._current_loop = None
             self._message_agent_thought = None
         elif not self._current_loop and self._agent_loops:
             self._agent_loops[-1].status = 'agent_finish'
+
+    def _init_agent_thought(self) -> MessageAgentThought:
+        message_agent_thought = MessageAgentThought(
+            message_id=self.message.id,
+            message_chain_id=self.message_chain.id,
+            position=self._current_loop.position,
+            thought=self._current_loop.thought,
+            tool=self._current_loop.tool_name,
+            tool_input=self._current_loop.tool_input,
+            message=self._current_loop.prompt,
+            message_price_unit=0,
+            answer=self._current_loop.completion,
+            answer_price_unit=0,
+            created_by_role=('account' if self.message.from_source == 'console' else 'end_user'),
+            created_by=(self.message.from_account_id
+                        if self.message.from_source == 'console' else self.message.from_end_user_id)
+        )
+
+        db.session.add(message_agent_thought)
+        db.session.commit()
+
+        self.queue_manager.publish_agent_thought(message_agent_thought)
+
+        return message_agent_thought
+
+    def _complete_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
+        loop_message_tokens = self._current_loop.prompt_tokens
+        loop_answer_tokens = self._current_loop.completion_tokens
+
+        # transform usage
+        llm_usage = self.model_type_instance._calc_response_usage(
+            self.model_config.model,
+            self.model_config.credentials,
+            loop_message_tokens,
+            loop_answer_tokens
+        )
+
+        message_agent_thought.observation = self._current_loop.tool_output
+        message_agent_thought.tool_process_data = ''  # currently not support
+        message_agent_thought.message_token = loop_message_tokens
+        message_agent_thought.message_unit_price = llm_usage.prompt_unit_price
+        message_agent_thought.message_price_unit = llm_usage.prompt_price_unit
+        message_agent_thought.answer_token = loop_answer_tokens
+        message_agent_thought.answer_unit_price = llm_usage.completion_unit_price
+        message_agent_thought.answer_price_unit = llm_usage.completion_price_unit
+        message_agent_thought.latency = self._current_loop.latency
+        message_agent_thought.tokens = self._current_loop.prompt_tokens + self._current_loop.completion_tokens
+        message_agent_thought.total_price = llm_usage.total_price
+        message_agent_thought.currency = llm_usage.currency
+        db.session.commit()

+ 0 - 74
api/core/callback_handler/dataset_tool_callback_handler.py

@@ -1,74 +0,0 @@
-import json
-import logging
-from json import JSONDecodeError
-
-from typing import Any, Dict, List, Union, Optional
-
-from langchain.callbacks.base import BaseCallbackHandler
-
-from core.callback_handler.entity.dataset_query import DatasetQueryObj
-from core.conversation_message_task import ConversationMessageTask
-
-
-class DatasetToolCallbackHandler(BaseCallbackHandler):
-    """Callback Handler that prints to std out."""
-    raise_error: bool = True
-
-    def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
-        """Initialize callback handler."""
-        self.queries = []
-        self.conversation_message_task = conversation_message_task
-
-    @property
-    def always_verbose(self) -> bool:
-        """Whether to call verbose callbacks even if verbose is False."""
-        return True
-
-    @property
-    def ignore_llm(self) -> bool:
-        """Whether to ignore LLM callbacks."""
-        return True
-
-    @property
-    def ignore_chain(self) -> bool:
-        """Whether to ignore chain callbacks."""
-        return True
-
-    @property
-    def ignore_agent(self) -> bool:
-        """Whether to ignore agent callbacks."""
-        return False
-
-    def on_tool_start(
-        self,
-        serialized: Dict[str, Any],
-        input_str: str,
-        **kwargs: Any,
-    ) -> None:
-        tool_name: str = serialized.get('name')
-        dataset_id = tool_name.removeprefix('dataset-')
-
-        try:
-            input_dict = json.loads(input_str.replace("'", "\""))
-            query = input_dict.get('query')
-        except JSONDecodeError:
-            query = input_str
-
-        self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))
-
-    def on_tool_end(
-        self,
-        output: str,
-        color: Optional[str] = None,
-        observation_prefix: Optional[str] = None,
-        llm_prefix: Optional[str] = None,
-        **kwargs: Any,
-    ) -> None:
-        pass
-
-
-    def on_tool_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        logging.debug("Dataset tool on_llm_error: %s", error)

+ 0 - 16
api/core/callback_handler/entity/chain_result.py

@@ -1,16 +0,0 @@
-from pydantic import BaseModel
-
-
-class ChainResult(BaseModel):
-    type: str = None
-    prompt: dict = None
-    completion: dict = None
-
-    status: str = 'chain_started'
-    completed: bool = False
-
-    started_at: float = None
-    completed_at: float = None
-
-    agent_result: dict = None
-    """only when type is 'AgentExecutor'"""

+ 0 - 6
api/core/callback_handler/entity/dataset_query.py

@@ -1,6 +0,0 @@
-from pydantic import BaseModel
-
-
-class DatasetQueryObj(BaseModel):
-    dataset_id: str = None
-    query: str = None

+ 0 - 8
api/core/callback_handler/entity/llm_message.py

@@ -1,8 +0,0 @@
-from pydantic import BaseModel
-
-
-class LLMMessage(BaseModel):
-    prompt: str = ''
-    prompt_tokens: int = 0
-    completion: str = ''
-    completion_tokens: int = 0

+ 56 - 6
api/core/callback_handler/index_tool_callback_handler.py

@@ -1,17 +1,44 @@
-from typing import List
+from typing import List, Union
 
 from langchain.schema import Document
 
-from core.conversation_message_task import ConversationMessageTask
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import InvokeFrom
 from extensions.ext_database import db
-from models.dataset import DocumentSegment
+from models.dataset import DocumentSegment, DatasetQuery
+from models.model import DatasetRetrieverResource
 
 
 class DatasetIndexToolCallbackHandler:
     """Callback handler for dataset tool."""
 
-    def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
-        self.conversation_message_task = conversation_message_task
+    def __init__(self, queue_manager: ApplicationQueueManager,
+                 app_id: str,
+                 message_id: str,
+                 user_id: str,
+                 invoke_from: InvokeFrom) -> None:
+        self._queue_manager = queue_manager
+        self._app_id = app_id
+        self._message_id = message_id
+        self._user_id = user_id
+        self._invoke_from = invoke_from
+
+    def on_query(self, query: str, dataset_id: str) -> None:
+        """
+        Handle query.
+        """
+        dataset_query = DatasetQuery(
+            dataset_id=dataset_id,
+            content=query,
+            source='app',
+            source_app_id=self._app_id,
+            created_by_role=('account'
+                             if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
+            created_by=self._user_id
+        )
+
+        db.session.add(dataset_query)
+        db.session.commit()
 
     def on_tool_end(self, documents: List[Document]) -> None:
         """Handle tool end."""
@@ -30,4 +57,27 @@ class DatasetIndexToolCallbackHandler:
 
     def return_retriever_resource_info(self, resource: List):
         """Handle return_retriever_resource_info."""
-        self.conversation_message_task.on_dataset_query_finish(resource)
+        if resource and len(resource) > 0:
+            for item in resource:
+                dataset_retriever_resource = DatasetRetrieverResource(
+                    message_id=self._message_id,
+                    position=item.get('position'),
+                    dataset_id=item.get('dataset_id'),
+                    dataset_name=item.get('dataset_name'),
+                    document_id=item.get('document_id'),
+                    document_name=item.get('document_name'),
+                    data_source_type=item.get('data_source_type'),
+                    segment_id=item.get('segment_id'),
+                    score=item.get('score') if 'score' in item else None,
+                    hit_count=item.get('hit_count') if 'hit_count' else None,
+                    word_count=item.get('word_count') if 'word_count' in item else None,
+                    segment_position=item.get('segment_position') if 'segment_position' in item else None,
+                    index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
+                    content=item.get('content'),
+                    retriever_from=item.get('retriever_from'),
+                    created_by=self._user_id
+                )
+                db.session.add(dataset_retriever_resource)
+                db.session.commit()
+
+        self._queue_manager.publish_retriever_resources(resource)

+ 0 - 284
api/core/callback_handler/llm_callback_handler.py

@@ -1,284 +0,0 @@
-import logging
-import threading
-import time
-from typing import Any, Dict, List, Union, Optional
-
-from flask import Flask, current_app
-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import LLMResult, BaseMessage
-from pydantic import BaseModel
-
-from core.callback_handler.entity.llm_message import LLMMessage
-from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
-    ConversationTaskInterruptException
-from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage, LCHumanMessageWithFiles, \
-    ImagePromptMessageFile
-from core.model_providers.models.llm.base import BaseLLM
-from core.moderation.base import ModerationOutputsResult, ModerationAction
-from core.moderation.factory import ModerationFactory
-
-
-class ModerationRule(BaseModel):
-    type: str
-    config: Dict[str, Any]
-
-
-class LLMCallbackHandler(BaseCallbackHandler):
-    raise_error: bool = True
-
-    def __init__(self, model_instance: BaseLLM,
-                 conversation_message_task: ConversationMessageTask):
-        self.model_instance = model_instance
-        self.llm_message = LLMMessage()
-        self.start_at = None
-        self.conversation_message_task = conversation_message_task
-
-        self.output_moderation_handler = None
-        self.init_output_moderation()
-
-    def init_output_moderation(self):
-        app_model_config = self.conversation_message_task.app_model_config
-        sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict
-
-        if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"):
-            self.output_moderation_handler = OutputModerationHandler(
-                tenant_id=self.conversation_message_task.tenant_id,
-                app_id=self.conversation_message_task.app.id,
-                rule=ModerationRule(
-                    type=sensitive_word_avoidance_dict.get("type"),
-                    config=sensitive_word_avoidance_dict.get("config")
-                ),
-                on_message_replace_func=self.conversation_message_task.on_message_replace
-            )
-
-    @property
-    def always_verbose(self) -> bool:
-        """Whether to call verbose callbacks even if verbose is False."""
-        return True
-
-    def on_chat_model_start(
-            self,
-            serialized: Dict[str, Any],
-            messages: List[List[BaseMessage]],
-            **kwargs: Any
-    ) -> Any:
-        real_prompts = []
-        for message in messages[0]:
-            if message.type == 'human':
-                role = 'user'
-            elif message.type == 'ai':
-                role = 'assistant'
-            else:
-                role = 'system'
-
-            real_prompts.append({
-                "role": role,
-                "text": message.content,
-                "files": [{
-                    "type": file.type.value,
-                    "data": file.data[:10] + '...[TRUNCATED]...' + file.data[-10:],
-                    "detail": file.detail.value if isinstance(file, ImagePromptMessageFile) else None,
-                } for file in (message.files if isinstance(message, LCHumanMessageWithFiles) else [])]
-            })
-
-        self.llm_message.prompt = real_prompts
-        self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0]))
-
-    def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
-    ) -> None:
-        self.llm_message.prompt = [{
-            "role": 'user',
-            "text": prompts[0]
-        }]
-
-        self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
-
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        if self.output_moderation_handler:
-            self.output_moderation_handler.stop_thread()
-
-            self.llm_message.completion = self.output_moderation_handler.moderation_completion(
-                completion=response.generations[0][0].text,
-                public_event=True if self.conversation_message_task.streaming else False
-            )
-        else:
-            self.llm_message.completion = response.generations[0][0].text
-
-        if not self.conversation_message_task.streaming:
-            self.conversation_message_task.append_message_text(self.llm_message.completion)
-
-        if response.llm_output and 'token_usage' in response.llm_output:
-            if 'prompt_tokens' in response.llm_output['token_usage']:
-                self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
-
-            if 'completion_tokens' in response.llm_output['token_usage']:
-                self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
-            else:
-                self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
-                    [PromptMessage(content=self.llm_message.completion)])
-        else:
-            self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
-                [PromptMessage(content=self.llm_message.completion)])
-
-        self.conversation_message_task.save_message(self.llm_message)
-
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        if self.output_moderation_handler and self.output_moderation_handler.should_direct_output():
-            # stop subscribe new token when output moderation should direct output
-            ex = ConversationTaskInterruptException()
-            self.on_llm_error(error=ex)
-            raise ex
-
-        try:
-            self.conversation_message_task.append_message_text(token)
-            self.llm_message.completion += token
-
-            if self.output_moderation_handler:
-                self.output_moderation_handler.append_new_token(token)
-        except ConversationTaskStoppedException as ex:
-            self.on_llm_error(error=ex)
-            raise ex
-
-    def on_llm_error(
-            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        if self.output_moderation_handler:
-            self.output_moderation_handler.stop_thread()
-
-        if isinstance(error, ConversationTaskStoppedException):
-            if self.conversation_message_task.streaming:
-                self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
-                    [PromptMessage(content=self.llm_message.completion)]
-                )
-                self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
-        if isinstance(error, ConversationTaskInterruptException):
-            self.llm_message.completion = self.output_moderation_handler.get_final_output()
-            self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
-                [PromptMessage(content=self.llm_message.completion)]
-            )
-            self.conversation_message_task.save_message(llm_message=self.llm_message)
-        else:
-            logging.debug("on_llm_error: %s", error)
-
-
-class OutputModerationHandler(BaseModel):
-    DEFAULT_BUFFER_SIZE: int = 300
-
-    tenant_id: str
-    app_id: str
-
-    rule: ModerationRule
-    on_message_replace_func: Any
-
-    thread: Optional[threading.Thread] = None
-    thread_running: bool = True
-    buffer: str = ''
-    is_final_chunk: bool = False
-    final_output: Optional[str] = None
-
-    class Config:
-        arbitrary_types_allowed = True
-
-    def should_direct_output(self):
-        return self.final_output is not None
-
-    def get_final_output(self):
-        return self.final_output
-
-    def append_new_token(self, token: str):
-        self.buffer += token
-
-        if not self.thread:
-            self.thread = self.start_thread()
-
-    def moderation_completion(self, completion: str, public_event: bool = False) -> str:
-        self.buffer = completion
-        self.is_final_chunk = True
-
-        result = self.moderation(
-            tenant_id=self.tenant_id,
-            app_id=self.app_id,
-            moderation_buffer=completion
-        )
-
-        if not result or not result.flagged:
-            return completion
-
-        if result.action == ModerationAction.DIRECT_OUTPUT:
-            final_output = result.preset_response
-        else:
-            final_output = result.text
-
-        if public_event:
-            self.on_message_replace_func(final_output)
-
-        return final_output
-
-    def start_thread(self) -> threading.Thread:
-        buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
-        thread = threading.Thread(target=self.worker, kwargs={
-            'flask_app': current_app._get_current_object(),
-            'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
-        })
-
-        thread.start()
-
-        return thread
-
-    def stop_thread(self):
-        if self.thread and self.thread.is_alive():
-            self.thread_running = False
-
-    def worker(self, flask_app: Flask, buffer_size: int):
-        with flask_app.app_context():
-            current_length = 0
-            while self.thread_running:
-                moderation_buffer = self.buffer
-                buffer_length = len(moderation_buffer)
-                if not self.is_final_chunk:
-                    chunk_length = buffer_length - current_length
-                    if 0 <= chunk_length < buffer_size:
-                        time.sleep(1)
-                        continue
-
-                current_length = buffer_length
-
-                result = self.moderation(
-                    tenant_id=self.tenant_id,
-                    app_id=self.app_id,
-                    moderation_buffer=moderation_buffer
-                )
-
-                if not result or not result.flagged:
-                    continue
-
-                if result.action == ModerationAction.DIRECT_OUTPUT:
-                    final_output = result.preset_response
-                    self.final_output = final_output
-                else:
-                    final_output = result.text + self.buffer[len(moderation_buffer):]
-
-                # trigger replace event
-                if self.thread_running:
-                    self.on_message_replace_func(final_output)
-
-                if result.action == ModerationAction.DIRECT_OUTPUT:
-                    break
-
-    def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
-        try:
-            moderation_factory = ModerationFactory(
-                name=self.rule.type,
-                app_id=app_id,
-                tenant_id=tenant_id,
-                config=self.rule.config
-            )
-
-            result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
-            return result
-        except Exception as e:
-            logging.error("Moderation Output error: %s", e)
-
-        return None

+ 0 - 76
api/core/callback_handler/main_chain_gather_callback_handler.py

@@ -1,76 +0,0 @@
-import logging
-import time
-
-from typing import Any, Dict, Union
-
-from langchain.callbacks.base import BaseCallbackHandler
-
-from core.callback_handler.entity.chain_result import ChainResult
-from core.conversation_message_task import ConversationMessageTask
-
-
-class MainChainGatherCallbackHandler(BaseCallbackHandler):
-    """Callback Handler that prints to std out."""
-    raise_error: bool = True
-
-    def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
-        """Initialize callback handler."""
-        self._current_chain_result = None
-        self._current_chain_message = None
-        self.conversation_message_task = conversation_message_task
-        self.agent_callback = None
-
-    def clear_chain_results(self) -> None:
-        self._current_chain_result = None
-        self._current_chain_message = None
-        if self.agent_callback:
-            self.agent_callback.current_chain = None
-
-    @property
-    def always_verbose(self) -> bool:
-        """Whether to call verbose callbacks even if verbose is False."""
-        return True
-
-    @property
-    def ignore_llm(self) -> bool:
-        """Whether to ignore LLM callbacks."""
-        return True
-
-    @property
-    def ignore_agent(self) -> bool:
-        """Whether to ignore agent callbacks."""
-        return True
-
-    def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
-        """Print out that we are entering a chain."""
-        if not self._current_chain_result:
-            chain_type = serialized['id'][-1]
-            if chain_type:
-                self._current_chain_result = ChainResult(
-                    type=chain_type,
-                    prompt=inputs,
-                    started_at=time.perf_counter()
-                )
-                self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
-                if self.agent_callback:
-                    self.agent_callback.current_chain = self._current_chain_message
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        """Print out that we finished a chain."""
-        if self._current_chain_result and self._current_chain_result.status == 'chain_started':
-            self._current_chain_result.status = 'chain_ended'
-            self._current_chain_result.completion = outputs
-            self._current_chain_result.completed = True
-            self._current_chain_result.completed_at = time.perf_counter()
-
-            self.conversation_message_task.on_chain_end(self._current_chain_message, self._current_chain_result)
-
-            self.clear_chain_results()
-
-    def on_chain_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        logging.debug("Dataset tool on_chain_error: %s", error)
-        self.clear_chain_results()

+ 5 - 2
api/core/callback_handler/std_out_callback_handler.py

@@ -79,8 +79,11 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
         """Run on agent action."""
         tool = action.tool
         tool_input = action.tool_input
-        action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
-        thought = action.log[:action_name_position].strip() if action.log else ''
+        try:
+            action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
+            thought = action.log[:action_name_position].strip() if action.log else ''
+        except ValueError:
+            thought = ''
 
         log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}"
         print_text("\n[on_agent_action]\n" + log + "\n", color='green')

+ 21 - 8
api/core/chain/llm_chain.py

@@ -5,15 +5,19 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.schema import LLMResult, Generation
 from langchain.schema.language_model import BaseLanguageModel
 
-from core.model_providers.models.entity.message import to_prompt_messages
-from core.model_providers.models.llm.base import BaseLLM
+from core.agent.agent.agent_llm_callback import AgentLLMCallback
+from core.entities.application_entities import ModelConfigEntity
+from core.model_manager import ModelInstance
+from core.entities.message_entities import lc_messages_to_prompt_messages
 from core.third_party.langchain.llms.fake import FakeLLM
 
 
 class LLMChain(LCLLMChain):
-    model_instance: BaseLLM
+    model_config: ModelConfigEntity
     """The language model instance to use."""
     llm: BaseLanguageModel = FakeLLM(response="")
+    parameters: Dict[str, Any] = {}
+    agent_llm_callback: Optional[AgentLLMCallback] = None
 
     def generate(
         self,
@@ -23,14 +27,23 @@ class LLMChain(LCLLMChain):
         """Generate LLM result from inputs."""
         prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
         messages = prompts[0].to_messages()
-        prompt_messages = to_prompt_messages(messages)
-        result = self.model_instance.run(
-            messages=prompt_messages,
-            stop=stop
+        prompt_messages = lc_messages_to_prompt_messages(messages)
+
+        model_instance = ModelInstance(
+            provider_model_bundle=self.model_config.provider_model_bundle,
+            model=self.model_config.model,
+        )
+
+        result = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            stream=False,
+            stop=stop,
+            callbacks=[self.agent_llm_callback] if self.agent_llm_callback else None,
+            model_parameters=self.parameters
         )
 
         generations = [
-            [Generation(text=result.content)]
+            [Generation(text=result.message.content)]
         ]
 
         return LLMResult(generations=generations)

+ 0 - 501
api/core/completion.py

@@ -1,501 +0,0 @@
-import concurrent
-import json
-import logging
-from concurrent.futures import ThreadPoolExecutor
-from typing import Optional, List, Union, Tuple
-
-from flask import current_app, Flask
-from requests.exceptions import ChunkedEncodingError
-
-from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
-from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
-from core.callback_handler.llm_callback_handler import LLMCallbackHandler
-from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
-    ConversationTaskInterruptException
-from core.embedding.cached_embedding import CacheEmbedding
-from core.external_data_tool.factory import ExternalDataToolFactory
-from core.file.file_obj import FileObj
-from core.index.vector_index.vector_index import VectorIndex
-from core.model_providers.error import LLMBadRequestError
-from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
-    ReadOnlyConversationTokenDBBufferSharedMemory
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import PromptMessage, PromptMessageFile
-from core.model_providers.models.llm.base import BaseLLM
-from core.orchestrator_rule_parser import OrchestratorRuleParser
-from core.prompt.prompt_template import PromptTemplateParser
-from core.prompt.prompt_transform import PromptTransform
-from models.dataset import Dataset
-from models.model import App, AppModelConfig, Account, Conversation, EndUser
-from core.moderation.base import ModerationException, ModerationAction
-from core.moderation.factory import ModerationFactory
-from services.annotation_service import AppAnnotationService
-from services.dataset_service import DatasetCollectionBindingService
-
-
-class Completion:
-    @classmethod
-    def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
-                 files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
-                 streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
-                 auto_generate_name: bool = True, from_source: str = 'console'):
-        """
-        errors: ProviderTokenNotInitError
-        """
-        query = PromptTemplateParser.remove_template_variables(query)
-
-        memory = None
-        if conversation:
-            # get memory of conversation (read-only)
-            memory = cls.get_memory_from_conversation(
-                tenant_id=app.tenant_id,
-                app_model_config=app_model_config,
-                conversation=conversation,
-                return_messages=False
-            )
-
-            inputs = conversation.inputs
-
-        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
-            tenant_id=app.tenant_id,
-            model_config=app_model_config.model_dict,
-            streaming=streaming
-        )
-
-        conversation_message_task = ConversationMessageTask(
-            task_id=task_id,
-            app=app,
-            app_model_config=app_model_config,
-            user=user,
-            conversation=conversation,
-            is_override=is_override,
-            inputs=inputs,
-            query=query,
-            files=files,
-            streaming=streaming,
-            model_instance=final_model_instance,
-            auto_generate_name=auto_generate_name
-        )
-
-        prompt_message_files = [file.prompt_message_file for file in files]
-
-        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
-            mode=app.mode,
-            model_instance=final_model_instance,
-            app_model_config=app_model_config,
-            query=query,
-            inputs=inputs,
-            files=prompt_message_files
-        )
-
-        # init orchestrator rule parser
-        orchestrator_rule_parser = OrchestratorRuleParser(
-            tenant_id=app.tenant_id,
-            app_model_config=app_model_config
-        )
-
-        try:
-            chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
-
-            try:
-                # process sensitive_word_avoidance
-                inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query)
-            except ModerationException as e:
-                cls.run_final_llm(
-                    model_instance=final_model_instance,
-                    mode=app.mode,
-                    app_model_config=app_model_config,
-                    query=query,
-                    inputs=inputs,
-                    files=prompt_message_files,
-                    agent_execute_result=None,
-                    conversation_message_task=conversation_message_task,
-                    memory=memory,
-                    fake_response=str(e)
-                )
-                return
-            # check annotation reply
-            annotation_reply = cls.query_app_annotations_to_reply(conversation_message_task, from_source)
-            if annotation_reply:
-                return
-            # fill in variable inputs from external data tools if exists
-            external_data_tools = app_model_config.external_data_tools_list
-            if external_data_tools:
-                inputs = cls.fill_in_inputs_from_external_data_tools(
-                    tenant_id=app.tenant_id,
-                    app_id=app.id,
-                    external_data_tools=external_data_tools,
-                    inputs=inputs,
-                    query=query
-                )
-
-            # get agent executor
-            agent_executor = orchestrator_rule_parser.to_agent_executor(
-                conversation_message_task=conversation_message_task,
-                memory=memory,
-                rest_tokens=rest_tokens_for_context_and_memory,
-                chain_callback=chain_callback,
-                tenant_id=app.tenant_id,
-                retriever_from=retriever_from
-            )
-
-            query_for_agent = cls.get_query_for_agent(app, app_model_config, query, inputs)
-
-            # run agent executor
-            agent_execute_result = None
-            if query_for_agent and agent_executor:
-                should_use_agent = agent_executor.should_use_agent(query_for_agent)
-                if should_use_agent:
-                    agent_execute_result = agent_executor.run(query_for_agent)
-
-            # When no extra pre prompt is specified,
-            # the output of the agent can be used directly as the main output content without calling LLM again
-            fake_response = None
-            if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
-                    and agent_execute_result.strategy not in [PlanningStrategy.ROUTER,
-                                                              PlanningStrategy.REACT_ROUTER]:
-                fake_response = agent_execute_result.output
-
-            # run the final llm
-            cls.run_final_llm(
-                model_instance=final_model_instance,
-                mode=app.mode,
-                app_model_config=app_model_config,
-                query=query,
-                inputs=inputs,
-                files=prompt_message_files,
-                agent_execute_result=agent_execute_result,
-                conversation_message_task=conversation_message_task,
-                memory=memory,
-                fake_response=fake_response
-            )
-        except (ConversationTaskInterruptException, ConversationTaskStoppedException):
-            return
-        except ChunkedEncodingError as e:
-            # Interrupt by LLM (like OpenAI), handle it.
-            logging.warning(f'ChunkedEncodingError: {e}')
-            return
-
-    @classmethod
-    def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict,
-                              query: str):
-        if not app_model_config.sensitive_word_avoidance_dict['enabled']:
-            return inputs, query
-
-        type = app_model_config.sensitive_word_avoidance_dict['type']
-
-        moderation = ModerationFactory(type, app_id, tenant_id,
-                                       app_model_config.sensitive_word_avoidance_dict['config'])
-        moderation_result = moderation.moderation_for_inputs(inputs, query)
-
-        if not moderation_result.flagged:
-            return inputs, query
-
-        if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
-            raise ModerationException(moderation_result.preset_response)
-        elif moderation_result.action == ModerationAction.OVERRIDED:
-            inputs = moderation_result.inputs
-            query = moderation_result.query
-
-        return inputs, query
-
-    @classmethod
-    def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict],
-                                                inputs: dict, query: str) -> dict:
-        """
-        Fill in variable inputs from external data tools if exists.
-
-        :param tenant_id: workspace id
-        :param app_id: app id
-        :param external_data_tools: external data tools configs
-        :param inputs: the inputs
-        :param query: the query
-        :return: the filled inputs
-        """
-        # Group tools by type and config
-        grouped_tools = {}
-        for tool in external_data_tools:
-            if not tool.get("enabled"):
-                continue
-
-            tool_key = (tool.get("type"), json.dumps(tool.get("config"), sort_keys=True))
-            grouped_tools.setdefault(tool_key, []).append(tool)
-
-        results = {}
-        with ThreadPoolExecutor() as executor:
-            futures = {}
-            for tool in external_data_tools:
-                if not tool.get("enabled"):
-                    continue
-
-                future = executor.submit(
-                    cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, tool,
-                    inputs, query
-                )
-
-                futures[future] = tool
-
-            for future in concurrent.futures.as_completed(futures):
-                tool_variable, result = future.result()
-                results[tool_variable] = result
-
-        inputs.update(results)
-        return inputs
-
-    @classmethod
-    def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, external_data_tool: dict,
-                                 inputs: dict, query: str) -> Tuple[Optional[str], Optional[str]]:
-        with flask_app.app_context():
-            tool_variable = external_data_tool.get("variable")
-            tool_type = external_data_tool.get("type")
-            tool_config = external_data_tool.get("config")
-
-            external_data_tool_factory = ExternalDataToolFactory(
-                name=tool_type,
-                tenant_id=tenant_id,
-                app_id=app_id,
-                variable=tool_variable,
-                config=tool_config
-            )
-
-            # query external data tool
-            result = external_data_tool_factory.query(
-                inputs=inputs,
-                query=query
-            )
-
-            return tool_variable, result
-
-    @classmethod
-    def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
-        if app.mode != 'completion':
-            return query
-
-        return inputs.get(app_model_config.dataset_query_variable, "")
-
-    @classmethod
-    def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
-                      inputs: dict,
-                      files: List[PromptMessageFile],
-                      agent_execute_result: Optional[AgentExecuteResult],
-                      conversation_message_task: ConversationMessageTask,
-                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
-                      fake_response: Optional[str]):
-        prompt_transform = PromptTransform()
-
-        # get llm prompt
-        if app_model_config.prompt_type == 'simple':
-            prompt_messages, stop_words = prompt_transform.get_prompt(
-                app_mode=mode,
-                pre_prompt=app_model_config.pre_prompt,
-                inputs=inputs,
-                query=query,
-                files=files,
-                context=agent_execute_result.output if agent_execute_result else None,
-                memory=memory,
-                model_instance=model_instance
-            )
-        else:
-            prompt_messages = prompt_transform.get_advanced_prompt(
-                app_mode=mode,
-                app_model_config=app_model_config,
-                inputs=inputs,
-                query=query,
-                files=files,
-                context=agent_execute_result.output if agent_execute_result else None,
-                memory=memory,
-                model_instance=model_instance
-            )
-
-            model_config = app_model_config.model_dict
-            completion_params = model_config.get("completion_params", {})
-            stop_words = completion_params.get("stop", [])
-
-        cls.recale_llm_max_tokens(
-            model_instance=model_instance,
-            prompt_messages=prompt_messages,
-        )
-
-        response = model_instance.run(
-            messages=prompt_messages,
-            stop=stop_words if stop_words else None,
-            callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
-            fake_response=fake_response
-        )
-        return response
-
-    @classmethod
-    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
-                                         max_token_limit: int) -> str:
-        """Get memory messages."""
-        memory.max_token_limit = max_token_limit
-        memory_key = memory.memory_variables[0]
-        external_context = memory.load_memory_variables({})
-        return external_context[memory_key]
-
-    @classmethod
-    def query_app_annotations_to_reply(cls, conversation_message_task: ConversationMessageTask,
-                                       from_source: str) -> bool:
-        """Get memory messages."""
-        app_model_config = conversation_message_task.app_model_config
-        app = conversation_message_task.app
-        annotation_reply = app_model_config.annotation_reply_dict
-        if annotation_reply['enabled']:
-            try:
-                score_threshold = annotation_reply.get('score_threshold', 1)
-                embedding_provider_name = annotation_reply['embedding_model']['embedding_provider_name']
-                embedding_model_name = annotation_reply['embedding_model']['embedding_model_name']
-                # get embedding model
-                embedding_model = ModelFactory.get_embedding_model(
-                    tenant_id=app.tenant_id,
-                    model_provider_name=embedding_provider_name,
-                    model_name=embedding_model_name
-                )
-                embeddings = CacheEmbedding(embedding_model)
-
-                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                    embedding_provider_name,
-                    embedding_model_name,
-                    'annotation'
-                )
-
-                dataset = Dataset(
-                    id=app.id,
-                    tenant_id=app.tenant_id,
-                    indexing_technique='high_quality',
-                    embedding_model_provider=embedding_provider_name,
-                    embedding_model=embedding_model_name,
-                    collection_binding_id=dataset_collection_binding.id
-                )
-
-                vector_index = VectorIndex(
-                    dataset=dataset,
-                    config=current_app.config,
-                    embeddings=embeddings,
-                    attributes=['doc_id', 'annotation_id', 'app_id']
-                )
-
-                documents = vector_index.search(
-                    conversation_message_task.query,
-                    search_type='similarity_score_threshold',
-                    search_kwargs={
-                        'k': 1,
-                        'score_threshold': score_threshold,
-                        'filter': {
-                            'group_id': [dataset.id]
-                        }
-                    }
-                )
-                if documents:
-                    annotation_id = documents[0].metadata['annotation_id']
-                    score = documents[0].metadata['score']
-                    annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
-                    if annotation:
-                        conversation_message_task.annotation_end(annotation.content, annotation.id, annotation.account.name)
-                        # insert annotation history
-                        AppAnnotationService.add_annotation_history(annotation.id,
-                                                                    app.id,
-                                                                    annotation.question,
-                                                                    annotation.content,
-                                                                    conversation_message_task.query,
-                                                                    conversation_message_task.user.id,
-                                                                    conversation_message_task.message.id,
-                                                                    from_source,
-                                                                    score)
-                        return True
-            except Exception as e:
-                logging.warning(f'Query annotation failed, exception: {str(e)}.')
-                return False
-        return False
-
-    @classmethod
-    def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
-                                     conversation: Conversation,
-                                     **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
-        # only for calc token in memory
-        memory_model_instance = ModelFactory.get_text_generation_model_from_model_config(
-            tenant_id=tenant_id,
-            model_config=app_model_config.model_dict
-        )
-
-        # use llm config from conversation
-        memory = ReadOnlyConversationTokenDBBufferSharedMemory(
-            conversation=conversation,
-            model_instance=memory_model_instance,
-            max_token_limit=kwargs.get("max_token_limit", 2048),
-            memory_key=kwargs.get("memory_key", "chat_history"),
-            return_messages=kwargs.get("return_messages", True),
-            input_key=kwargs.get("input_key", "input"),
-            output_key=kwargs.get("output_key", "output"),
-            message_limit=kwargs.get("message_limit", 10),
-        )
-
-        return memory
-
-    @classmethod
-    def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
-                                 query: str, inputs: dict, files: List[PromptMessageFile]) -> int:
-        model_limited_tokens = model_instance.model_rules.max_tokens.max
-        max_tokens = model_instance.get_model_kwargs().max_tokens
-
-        if model_limited_tokens is None:
-            return -1
-
-        if max_tokens is None:
-            max_tokens = 0
-
-        prompt_transform = PromptTransform()
-
-        # get prompt without memory and context
-        if app_model_config.prompt_type == 'simple':
-            prompt_messages, _ = prompt_transform.get_prompt(
-                app_mode=mode,
-                pre_prompt=app_model_config.pre_prompt,
-                inputs=inputs,
-                query=query,
-                files=files,
-                context=None,
-                memory=None,
-                model_instance=model_instance
-            )
-        else:
-            prompt_messages = prompt_transform.get_advanced_prompt(
-                app_mode=mode,
-                app_model_config=app_model_config,
-                inputs=inputs,
-                query=query,
-                files=files,
-                context=None,
-                memory=None,
-                model_instance=model_instance
-            )
-
-        prompt_tokens = model_instance.get_num_tokens(prompt_messages)
-        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
-        if rest_tokens < 0:
-            raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
-                                     "or shrink the max token, or switch to a llm with a larger token limit size.")
-
-        return rest_tokens
-
-    @classmethod
-    def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]):
-        # recalc max_tokens if sum(prompt_token +  max_tokens) over model token limit
-        model_limited_tokens = model_instance.model_rules.max_tokens.max
-        max_tokens = model_instance.get_model_kwargs().max_tokens
-
-        if model_limited_tokens is None:
-            return
-
-        if max_tokens is None:
-            max_tokens = 0
-
-        prompt_tokens = model_instance.get_num_tokens(prompt_messages)
-
-        if prompt_tokens + max_tokens > model_limited_tokens:
-            max_tokens = max(model_limited_tokens - prompt_tokens, 16)
-
-            # update model instance max tokens
-            model_kwargs = model_instance.get_model_kwargs()
-            model_kwargs.max_tokens = max_tokens
-            model_instance.set_model_kwargs(model_kwargs)

+ 0 - 517
api/core/conversation_message_task.py

@@ -1,517 +0,0 @@
-import json
-import time
-from typing import Optional, Union, List
-
-from core.callback_handler.entity.agent_loop import AgentLoop
-from core.callback_handler.entity.dataset_query import DatasetQueryObj
-from core.callback_handler.entity.llm_message import LLMMessage
-from core.callback_handler.entity.chain_result import ChainResult
-from core.file.file_obj import FileObj
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import to_prompt_messages, MessageType, PromptMessageFile
-from core.model_providers.models.llm.base import BaseLLM
-from core.prompt.prompt_builder import PromptBuilder
-from core.prompt.prompt_template import PromptTemplateParser
-from events.message_event import message_was_created
-from extensions.ext_database import db
-from extensions.ext_redis import redis_client
-from models.dataset import DatasetQuery
-from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
-    MessageChain, DatasetRetrieverResource, MessageFile
-
-
-class ConversationMessageTask:
-    def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
-                 inputs: dict, query: str, files: List[FileObj], streaming: bool,
-                 model_instance: BaseLLM, conversation: Optional[Conversation] = None, is_override: bool = False,
-                 auto_generate_name: bool = True):
-        self.start_at = time.perf_counter()
-
-        self.task_id = task_id
-
-        self.app = app
-        self.tenant_id = app.tenant_id
-        self.app_model_config = app_model_config
-        self.is_override = is_override
-
-        self.user = user
-        self.inputs = inputs
-        self.query = query
-        self.files = files
-        self.streaming = streaming
-
-        self.conversation = conversation
-        self.is_new_conversation = False
-
-        self.model_instance = model_instance
-
-        self.message = None
-
-        self.retriever_resource = None
-        self.auto_generate_name = auto_generate_name
-
-        self.model_dict = self.app_model_config.model_dict
-        self.provider_name = self.model_dict.get('provider')
-        self.model_name = self.model_dict.get('name')
-        self.mode = app.mode
-
-        self.init()
-
-        self._pub_handler = PubHandler(
-            user=self.user,
-            task_id=self.task_id,
-            message=self.message,
-            conversation=self.conversation,
-            chain_pub=False,  # disabled currently
-            agent_thought_pub=True
-        )
-
-    def init(self):
-
-        override_model_configs = None
-        if self.is_override:
-            override_model_configs = self.app_model_config.to_dict()
-
-        introduction = ''
-        system_instruction = ''
-        system_instruction_tokens = 0
-        if self.mode == 'chat':
-            introduction = self.app_model_config.opening_statement
-            if introduction:
-                prompt_template = PromptTemplateParser(template=introduction)
-                prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs}
-                try:
-                    introduction = prompt_template.format(prompt_inputs)
-                except KeyError:
-                    pass
-
-            if self.app_model_config.pre_prompt:
-                system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
-                system_instruction = system_message.content
-                model_instance = ModelFactory.get_text_generation_model(
-                    tenant_id=self.tenant_id,
-                    model_provider_name=self.provider_name,
-                    model_name=self.model_name
-                )
-                system_instruction_tokens = model_instance.get_num_tokens(to_prompt_messages([system_message]))
-
-        if not self.conversation:
-            self.is_new_conversation = True
-            self.conversation = Conversation(
-                app_id=self.app.id,
-                app_model_config_id=self.app_model_config.id,
-                model_provider=self.provider_name,
-                model_id=self.model_name,
-                override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
-                mode=self.mode,
-                name='New conversation',
-                inputs=self.inputs,
-                introduction=introduction,
-                system_instruction=system_instruction,
-                system_instruction_tokens=system_instruction_tokens,
-                status='normal',
-                from_source=('console' if isinstance(self.user, Account) else 'api'),
-                from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
-                from_account_id=(self.user.id if isinstance(self.user, Account) else None),
-            )
-
-            db.session.add(self.conversation)
-            db.session.commit()
-
-        self.message = Message(
-            app_id=self.app.id,
-            model_provider=self.provider_name,
-            model_id=self.model_name,
-            override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
-            conversation_id=self.conversation.id,
-            inputs=self.inputs,
-            query=self.query,
-            message="",
-            message_tokens=0,
-            message_unit_price=0,
-            message_price_unit=0,
-            answer="",
-            answer_tokens=0,
-            answer_unit_price=0,
-            answer_price_unit=0,
-            provider_response_latency=0,
-            total_price=0,
-            currency=self.model_instance.get_currency(),
-            from_source=('console' if isinstance(self.user, Account) else 'api'),
-            from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
-            from_account_id=(self.user.id if isinstance(self.user, Account) else None),
-            agent_based=self.app_model_config.agent_mode_dict.get('enabled'),
-        )
-
-        db.session.add(self.message)
-        db.session.commit()
-
-        for file in self.files:
-            message_file = MessageFile(
-                message_id=self.message.id,
-                type=file.type.value,
-                transfer_method=file.transfer_method.value,
-                url=file.url,
-                upload_file_id=file.upload_file_id,
-                created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
-                created_by=self.user.id
-            )
-            db.session.add(message_file)
-            db.session.commit()
-
-    def append_message_text(self, text: str):
-        if text is not None:
-            self._pub_handler.pub_text(text)
-
-    def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
-        message_tokens = llm_message.prompt_tokens
-        answer_tokens = llm_message.completion_tokens
-
-        message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER)
-        message_price_unit = self.model_instance.get_price_unit(MessageType.USER)
-        answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
-        answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
-
-        message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER)
-        answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
-        total_price = message_total_price + answer_total_price
-
-        self.message.message = llm_message.prompt
-        self.message.message_tokens = message_tokens
-        self.message.message_unit_price = message_unit_price
-        self.message.message_price_unit = message_price_unit
-        self.message.answer = PromptTemplateParser.remove_template_variables(
-            llm_message.completion.strip()) if llm_message.completion else ''
-        self.message.answer_tokens = answer_tokens
-        self.message.answer_unit_price = answer_unit_price
-        self.message.answer_price_unit = answer_price_unit
-        self.message.provider_response_latency = time.perf_counter() - self.start_at
-        self.message.total_price = total_price
-
-        db.session.commit()
-
-        message_was_created.send(
-            self.message,
-            conversation=self.conversation,
-            is_first_message=self.is_new_conversation,
-            auto_generate_name=self.auto_generate_name
-        )
-
-        if not by_stopped:
-            self.end()
-
-    def init_chain(self, chain_result: ChainResult):
-        message_chain = MessageChain(
-            message_id=self.message.id,
-            type=chain_result.type,
-            input=json.dumps(chain_result.prompt),
-            output=''
-        )
-
-        db.session.add(message_chain)
-        db.session.commit()
-
-        return message_chain
-
-    def on_chain_end(self, message_chain: MessageChain, chain_result: ChainResult):
-        message_chain.output = json.dumps(chain_result.completion)
-        db.session.commit()
-
-        self._pub_handler.pub_chain(message_chain)
-
-    def on_agent_start(self, message_chain: MessageChain, agent_loop: AgentLoop) -> MessageAgentThought:
-        message_agent_thought = MessageAgentThought(
-            message_id=self.message.id,
-            message_chain_id=message_chain.id,
-            position=agent_loop.position,
-            thought=agent_loop.thought,
-            tool=agent_loop.tool_name,
-            tool_input=agent_loop.tool_input,
-            message=agent_loop.prompt,
-            message_price_unit=0,
-            answer=agent_loop.completion,
-            answer_price_unit=0,
-            created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
-            created_by=self.user.id
-        )
-
-        db.session.add(message_agent_thought)
-        db.session.commit()
-
-        self._pub_handler.pub_agent_thought(message_agent_thought)
-
-        return message_agent_thought
-
-    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
-                     agent_loop: AgentLoop):
-        agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER)
-        agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER)
-        agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
-        agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
-
-        loop_message_tokens = agent_loop.prompt_tokens
-        loop_answer_tokens = agent_loop.completion_tokens
-
-        loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER)
-        loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
-        loop_total_price = loop_message_total_price + loop_answer_total_price
-
-        message_agent_thought.observation = agent_loop.tool_output
-        message_agent_thought.tool_process_data = ''  # currently not support
-        message_agent_thought.message_token = loop_message_tokens
-        message_agent_thought.message_unit_price = agent_message_unit_price
-        message_agent_thought.message_price_unit = agent_message_price_unit
-        message_agent_thought.answer_token = loop_answer_tokens
-        message_agent_thought.answer_unit_price = agent_answer_unit_price
-        message_agent_thought.answer_price_unit = agent_answer_price_unit
-        message_agent_thought.latency = agent_loop.latency
-        message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
-        message_agent_thought.total_price = loop_total_price
-        message_agent_thought.currency = agent_model_instance.get_currency()
-        db.session.commit()
-
-    def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
-        dataset_query = DatasetQuery(
-            dataset_id=dataset_query_obj.dataset_id,
-            content=dataset_query_obj.query,
-            source='app',
-            source_app_id=self.app.id,
-            created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
-            created_by=self.user.id
-        )
-
-        db.session.add(dataset_query)
-        db.session.commit()
-
-    def on_dataset_query_finish(self, resource: List):
-        if resource and len(resource) > 0:
-            for item in resource:
-                dataset_retriever_resource = DatasetRetrieverResource(
-                    message_id=self.message.id,
-                    position=item.get('position'),
-                    dataset_id=item.get('dataset_id'),
-                    dataset_name=item.get('dataset_name'),
-                    document_id=item.get('document_id'),
-                    document_name=item.get('document_name'),
-                    data_source_type=item.get('data_source_type'),
-                    segment_id=item.get('segment_id'),
-                    score=item.get('score') if 'score' in item else None,
-                    hit_count=item.get('hit_count') if 'hit_count' else None,
-                    word_count=item.get('word_count') if 'word_count' in item else None,
-                    segment_position=item.get('segment_position') if 'segment_position' in item else None,
-                    index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
-                    content=item.get('content'),
-                    retriever_from=item.get('retriever_from'),
-                    created_by=self.user.id
-                )
-                db.session.add(dataset_retriever_resource)
-                db.session.commit()
-            self.retriever_resource = resource
-
-    def on_message_replace(self, text: str):
-        if text is not None:
-            self._pub_handler.pub_message_replace(text)
-
-    def message_end(self):
-        self._pub_handler.pub_message_end(self.retriever_resource)
-
-    def end(self):
-        self._pub_handler.pub_message_end(self.retriever_resource)
-        self._pub_handler.pub_end()
-
-    def annotation_end(self, text: str, annotation_id: str, annotation_author_name: str):
-        self._pub_handler.pub_annotation(text, annotation_id, annotation_author_name, self.start_at)
-        self._pub_handler.pub_end()
-
-
-class PubHandler:
-    def __init__(self, user: Union[Account, EndUser], task_id: str,
-                 message: Message, conversation: Conversation,
-                 chain_pub: bool = False, agent_thought_pub: bool = False):
-        self._channel = PubHandler.generate_channel_name(user, task_id)
-        self._stopped_cache_key = PubHandler.generate_stopped_cache_key(user, task_id)
-
-        self._task_id = task_id
-        self._message = message
-        self._conversation = conversation
-        self._chain_pub = chain_pub
-        self._agent_thought_pub = agent_thought_pub
-
-    @classmethod
-    def generate_channel_name(cls, user: Union[Account, EndUser], task_id: str):
-        if not user:
-            raise ValueError("user is required")
-
-        user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
-        return "generate_result:{}-{}".format(user_str, task_id)
-
-    @classmethod
-    def generate_stopped_cache_key(cls, user: Union[Account, EndUser], task_id: str):
-        user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
-        return "generate_result_stopped:{}-{}".format(user_str, task_id)
-
-    def pub_text(self, text: str):
-        content = {
-            'event': 'message',
-            'data': {
-                'task_id': self._task_id,
-                'message_id': str(self._message.id),
-                'text': text,
-                'mode': self._conversation.mode,
-                'conversation_id': str(self._conversation.id)
-            }
-        }
-
-        redis_client.publish(self._channel, json.dumps(content))
-
-        if self._is_stopped():
-            self.pub_end()
-            raise ConversationTaskStoppedException()
-
-    def pub_message_replace(self, text: str):
-        content = {
-            'event': 'message_replace',
-            'data': {
-                'task_id': self._task_id,
-                'message_id': str(self._message.id),
-                'text': text,
-                'mode': self._conversation.mode,
-                'conversation_id': str(self._conversation.id)
-            }
-        }
-
-        redis_client.publish(self._channel, json.dumps(content))
-
-        if self._is_stopped():
-            self.pub_end()
-            raise ConversationTaskStoppedException()
-
-    def pub_chain(self, message_chain: MessageChain):
-        if self._chain_pub:
-            content = {
-                'event': 'chain',
-                'data': {
-                    'task_id': self._task_id,
-                    'message_id': self._message.id,
-                    'chain_id': message_chain.id,
-                    'type': message_chain.type,
-                    'input': json.loads(message_chain.input),
-                    'output': json.loads(message_chain.output),
-                    'mode': self._conversation.mode,
-                    'conversation_id': self._conversation.id
-                }
-            }
-
-            redis_client.publish(self._channel, json.dumps(content))
-
-        if self._is_stopped():
-            self.pub_end()
-            raise ConversationTaskStoppedException()
-
-    def pub_agent_thought(self, message_agent_thought: MessageAgentThought):
-        if self._agent_thought_pub:
-            content = {
-                'event': 'agent_thought',
-                'data': {
-                    'id': message_agent_thought.id,
-                    'task_id': self._task_id,
-                    'message_id': self._message.id,
-                    'chain_id': message_agent_thought.message_chain_id,
-                    'position': message_agent_thought.position,
-                    'thought': message_agent_thought.thought,
-                    'tool': message_agent_thought.tool,
-                    'tool_input': message_agent_thought.tool_input,
-                    'mode': self._conversation.mode,
-                    'conversation_id': self._conversation.id
-                }
-            }
-
-            redis_client.publish(self._channel, json.dumps(content))
-
-        if self._is_stopped():
-            self.pub_end()
-            raise ConversationTaskStoppedException()
-
-    def pub_message_end(self, retriever_resource: List):
-        content = {
-            'event': 'message_end',
-            'data': {
-                'task_id': self._task_id,
-                'message_id': self._message.id,
-                'mode': self._conversation.mode,
-                'conversation_id': self._conversation.id,
-            }
-        }
-        if retriever_resource:
-            content['data']['retriever_resources'] = retriever_resource
-        redis_client.publish(self._channel, json.dumps(content))
-
-        if self._is_stopped():
-            self.pub_end()
-            raise ConversationTaskStoppedException()
-
-    def pub_annotation(self, text: str, annotation_id: str, annotation_author_name: str, start_at: float):
-        content = {
-            'event': 'annotation',
-            'data': {
-                'task_id': self._task_id,
-                'message_id': self._message.id,
-                'mode': self._conversation.mode,
-                'conversation_id': self._conversation.id,
-                'text': text,
-                'annotation_id': annotation_id,
-                'annotation_author_name': annotation_author_name
-            }
-        }
-        self._message.answer = text
-        self._message.provider_response_latency = time.perf_counter() - start_at
-
-        db.session.commit()
-
-        redis_client.publish(self._channel, json.dumps(content))
-
-        if self._is_stopped():
-            self.pub_end()
-            raise ConversationTaskStoppedException()
-
-    def pub_end(self):
-        content = {
-            'event': 'end',
-        }
-
-        redis_client.publish(self._channel, json.dumps(content))
-
-    @classmethod
-    def pub_error(cls, user: Union[Account, EndUser], task_id: str, e):
-        content = {
-            'error': type(e).__name__,
-            'description': e.description if getattr(e, 'description', None) is not None else str(e)
-        }
-
-        channel = cls.generate_channel_name(user, task_id)
-        redis_client.publish(channel, json.dumps(content))
-
-    def _is_stopped(self):
-        return redis_client.get(self._stopped_cache_key) is not None
-
-    @classmethod
-    def ping(cls, user: Union[Account, EndUser], task_id: str):
-        content = {
-            'event': 'ping'
-        }
-
-        channel = cls.generate_channel_name(user, task_id)
-        redis_client.publish(channel, json.dumps(content))
-
-    @classmethod
-    def stop(cls, user: Union[Account, EndUser], task_id: str):
-        stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)
-        redis_client.setex(stopped_cache_key, 600, 1)
-
-
-class ConversationTaskStoppedException(Exception):
-    pass
-
-
-class ConversationTaskInterruptException(Exception):
-    pass

+ 19 - 6
api/core/docstore/dataset_docstore.py

@@ -1,9 +1,11 @@
-from typing import Any, Dict, Optional, Sequence
+from typing import Any, Dict, Optional, Sequence, cast
 
 from langchain.schema import Document
 from sqlalchemy import func
 
-from core.model_providers.model_factory import ModelFactory
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from extensions.ext_database import db
 from models.dataset import Dataset, DocumentSegment
 
@@ -69,10 +71,12 @@ class DatasetDocumentStore:
             max_position = 0
         embedding_model = None
         if self._dataset.indexing_technique == 'high_quality':
-            embedding_model = ModelFactory.get_embedding_model(
+            model_manager = ModelManager()
+            embedding_model = model_manager.get_model_instance(
                 tenant_id=self._dataset.tenant_id,
-                model_provider_name=self._dataset.embedding_model_provider,
-                model_name=self._dataset.embedding_model
+                provider=self._dataset.embedding_model_provider,
+                model_type=ModelType.TEXT_EMBEDDING,
+                model=self._dataset.embedding_model
             )
 
         for doc in docs:
@@ -89,7 +93,16 @@ class DatasetDocumentStore:
                 )
 
             # calc embedding use tokens
-            tokens = embedding_model.get_num_tokens(doc.page_content) if embedding_model else 0
+            if embedding_model:
+                model_type_instance = embedding_model.model_type_instance
+                model_type_instance = cast(TextEmbeddingModel, model_type_instance)
+                tokens = model_type_instance.get_num_tokens(
+                    model=embedding_model.model,
+                    credentials=embedding_model.credentials,
+                    texts=[doc.page_content]
+                )
+            else:
+                tokens = 0
 
             if not segment_document:
                 max_position += 1

+ 26 - 13
api/core/embedding/cached_embedding.py

@@ -1,19 +1,22 @@
 import logging
-from typing import List
+from typing import List, Optional
 
 import numpy as np
 from langchain.embeddings.base import Embeddings
 from sqlalchemy.exc import IntegrityError
 
-from core.model_providers.models.embedding.base import BaseEmbedding
+from core.model_manager import ModelInstance
 from extensions.ext_database import db
 from libs import helper
 from models.dataset import Embedding
 
+logger = logging.getLogger(__name__)
+
 
 class CacheEmbedding(Embeddings):
-    def __init__(self, embeddings: BaseEmbedding):
-        self._embeddings = embeddings
+    def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> None:
+        self._model_instance = model_instance
+        self._user = user
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Embed search docs."""
@@ -22,7 +25,7 @@ class CacheEmbedding(Embeddings):
         embedding_queue_indices = []
         for i, text in enumerate(texts):
             hash = helper.generate_text_hash(text)
-            embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
+            embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
             if embedding:
                 text_embeddings[i] = embedding.get_embedding()
             else:
@@ -30,15 +33,21 @@ class CacheEmbedding(Embeddings):
 
         if embedding_queue_indices:
             try:
-                embedding_results = self._embeddings.client.embed_documents([texts[i] for i in embedding_queue_indices])
+                embedding_result = self._model_instance.invoke_text_embedding(
+                    texts=[texts[i] for i in embedding_queue_indices],
+                    user=self._user
+                )
+
+                embedding_results = embedding_result.embeddings
             except Exception as ex:
-                raise self._embeddings.handle_exceptions(ex)
+                logger.error('Failed to embed documents: ', ex)
+                raise ex
 
             for i, indice in enumerate(embedding_queue_indices):
                 hash = helper.generate_text_hash(texts[indice])
 
                 try:
-                    embedding = Embedding(model_name=self._embeddings.name, hash=hash)
+                    embedding = Embedding(model_name=self._model_instance.model, hash=hash)
                     vector = embedding_results[i]
                     normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
                     text_embeddings[indice] = normalized_embedding
@@ -58,18 +67,23 @@ class CacheEmbedding(Embeddings):
         """Embed query text."""
         # use doc embedding cache or store if not exists
         hash = helper.generate_text_hash(text)
-        embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
+        embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
         if embedding:
             return embedding.get_embedding()
 
         try:
-            embedding_results = self._embeddings.client.embed_query(text)
+            embedding_result = self._model_instance.invoke_text_embedding(
+                texts=[text],
+                user=self._user
+            )
+
+            embedding_results = embedding_result.embeddings[0]
             embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
         except Exception as ex:
-            raise self._embeddings.handle_exceptions(ex)
+            raise ex
 
         try:
-            embedding = Embedding(model_name=self._embeddings.name, hash=hash)
+            embedding = Embedding(model_name=self._model_instance.model, hash=hash)
             embedding.set_embedding(embedding_results)
             db.session.add(embedding)
             db.session.commit()
@@ -79,4 +93,3 @@ class CacheEmbedding(Embeddings):
             logging.exception('Failed to add embedding to db')
 
         return embedding_results
-

+ 0 - 0
api/core/model_providers/models/embedding/__init__.py → api/core/entities/__init__.py


+ 265 - 0
api/core/entities/application_entities.py

@@ -0,0 +1,265 @@
+from enum import Enum
+from typing import Optional, Any, cast
+
+from pydantic import BaseModel
+
+from core.entities.provider_configuration import ProviderModelBundle
+from core.file.file_obj import FileObj
+from core.model_runtime.entities.message_entities import PromptMessageRole
+from core.model_runtime.entities.model_entities import AIModelEntity
+
+
+class ModelConfigEntity(BaseModel):
+    """
+    Model Config Entity.
+    """
+    provider: str
+    model: str
+    model_schema: AIModelEntity
+    mode: str
+    provider_model_bundle: ProviderModelBundle
+    credentials: dict[str, Any] = {}
+    parameters: dict[str, Any] = {}
+    stop: list[str] = []
+
+
+class AdvancedChatMessageEntity(BaseModel):
+    """
+    Advanced Chat Message Entity.
+    """
+    text: str
+    role: PromptMessageRole
+
+
+class AdvancedChatPromptTemplateEntity(BaseModel):
+    """
+    Advanced Chat Prompt Template Entity.
+    """
+    messages: list[AdvancedChatMessageEntity]
+
+
+class AdvancedCompletionPromptTemplateEntity(BaseModel):
+    """
+    Advanced Completion Prompt Template Entity.
+    """
+    class RolePrefixEntity(BaseModel):
+        """
+        Role Prefix Entity.
+        """
+        user: str
+        assistant: str
+
+    prompt: str
+    role_prefix: Optional[RolePrefixEntity] = None
+
+
+class PromptTemplateEntity(BaseModel):
+    """
+    Prompt Template Entity.
+    """
+    class PromptType(Enum):
+        """
+        Prompt Type.
+        'simple', 'advanced'
+        """
+        SIMPLE = 'simple'
+        ADVANCED = 'advanced'
+
+        @classmethod
+        def value_of(cls, value: str) -> 'PromptType':
+            """
+            Get value of given mode.
+
+            :param value: mode value
+            :return: mode
+            """
+            for mode in cls:
+                if mode.value == value:
+                    return mode
+            raise ValueError(f'invalid prompt type value {value}')
+
+    prompt_type: PromptType
+    simple_prompt_template: Optional[str] = None
+    advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None
+    advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
+
+
+class ExternalDataVariableEntity(BaseModel):
+    """
+    External Data Variable Entity.
+    """
+    variable: str
+    type: str
+    config: dict[str, Any] = {}
+
+
+class DatasetRetrieveConfigEntity(BaseModel):
+    """
+    Dataset Retrieve Config Entity.
+    """
+    class RetrieveStrategy(Enum):
+        """
+        Dataset Retrieve Strategy.
+        'single' or 'multiple'
+        """
+        SINGLE = 'single'
+        MULTIPLE = 'multiple'
+
+        @classmethod
+        def value_of(cls, value: str) -> 'RetrieveStrategy':
+            """
+            Get value of given mode.
+
+            :param value: mode value
+            :return: mode
+            """
+            for mode in cls:
+                if mode.value == value:
+                    return mode
+            raise ValueError(f'invalid retrieve strategy value {value}')
+
+    query_variable: Optional[str] = None  # Only when app mode is completion
+
+    retrieve_strategy: RetrieveStrategy
+    single_strategy: Optional[str] = None  # for temp
+    top_k: Optional[int] = None
+    score_threshold: Optional[float] = None
+    reranking_model: Optional[dict] = None
+
+
+class DatasetEntity(BaseModel):
+    """
+    Dataset Config Entity.
+    """
+    dataset_ids: list[str]
+    retrieve_config: DatasetRetrieveConfigEntity
+
+
+class SensitiveWordAvoidanceEntity(BaseModel):
+    """
+    Sensitive Word Avoidance Entity.
+    """
+    type: str
+    config: dict[str, Any] = {}
+
+
+class FileUploadEntity(BaseModel):
+    """
+    File Upload Entity.
+    """
+    image_config: Optional[dict[str, Any]] = None
+
+
+class AgentToolEntity(BaseModel):
+    """
+    Agent Tool Entity.
+    """
+    tool_id: str
+    config: dict[str, Any] = {}
+
+
+class AgentEntity(BaseModel):
+    """
+    Agent Entity.
+    """
+    class Strategy(Enum):
+        """
+        Agent Strategy.
+        """
+        CHAIN_OF_THOUGHT = 'chain-of-thought'
+        FUNCTION_CALLING = 'function-calling'
+
+    provider: str
+    model: str
+    strategy: Strategy
+    tools: list[AgentToolEntity] = []
+
+
+class AppOrchestrationConfigEntity(BaseModel):
+    """
+    App Orchestration Config Entity.
+    """
+    model_config: ModelConfigEntity
+    prompt_template: PromptTemplateEntity
+    external_data_variables: list[ExternalDataVariableEntity] = []
+    agent: Optional[AgentEntity] = None
+
+    # features
+    dataset: Optional[DatasetEntity] = None
+    file_upload: Optional[FileUploadEntity] = None
+    opening_statement: Optional[str] = None
+    suggested_questions_after_answer: bool = False
+    show_retrieve_source: bool = False
+    more_like_this: bool = False
+    speech_to_text: bool = False
+    sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
+
+
+class InvokeFrom(Enum):
+    """
+    Invoke From.
+    """
+    SERVICE_API = 'service-api'
+    WEB_APP = 'web-app'
+    EXPLORE = 'explore'
+    DEBUGGER = 'debugger'
+
+    @classmethod
+    def value_of(cls, value: str) -> 'InvokeFrom':
+        """
+        Get value of given mode.
+
+        :param value: mode value
+        :return: mode
+        """
+        for mode in cls:
+            if mode.value == value:
+                return mode
+        raise ValueError(f'invalid invoke from value {value}')
+
+    def to_source(self) -> str:
+        """
+        Get source of invoke from.
+
+        :return: source
+        """
+        if self == InvokeFrom.WEB_APP:
+            return 'web_app'
+        elif self == InvokeFrom.DEBUGGER:
+            return 'dev'
+        elif self == InvokeFrom.EXPLORE:
+            return 'explore_app'
+        elif self == InvokeFrom.SERVICE_API:
+            return 'api'
+
+        return 'dev'
+
+
+class ApplicationGenerateEntity(BaseModel):
+    """
+    Application Generate Entity.
+    """
+    task_id: str
+    tenant_id: str
+
+    app_id: str
+    app_model_config_id: str
+    # for save
+    app_model_config_dict: dict
+    app_model_config_override: bool
+
+    # Converted from app_model_config to Entity object, or directly covered by external input
+    app_orchestration_config_entity: AppOrchestrationConfigEntity
+
+    conversation_id: Optional[str] = None
+    inputs: dict[str, str]
+    query: Optional[str] = None
+    files: list[FileObj] = []
+    user_id: str
+
+    # extras
+    stream: bool
+    invoke_from: InvokeFrom
+
+    # extra parameters, like: auto_generate_conversation_name
+    extras: dict[str, Any] = {}

+ 128 - 0
api/core/entities/message_entities.py

@@ -0,0 +1,128 @@
+import enum
+from typing import Any, cast
+
+from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
+from pydantic import BaseModel
+
+from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage, TextPromptMessageContent, \
+    ImagePromptMessageContent, AssistantPromptMessage, SystemPromptMessage, ToolPromptMessage
+
+
+class PromptMessageFileType(enum.Enum):
+    IMAGE = 'image'
+
+    @staticmethod
+    def value_of(value):
+        for member in PromptMessageFileType:
+            if member.value == value:
+                return member
+        raise ValueError(f"No matching enum found for value '{value}'")
+
+
+class PromptMessageFile(BaseModel):
+    type: PromptMessageFileType
+    data: Any
+
+
+class ImagePromptMessageFile(PromptMessageFile):
+    class DETAIL(enum.Enum):
+        LOW = 'low'
+        HIGH = 'high'
+
+    type: PromptMessageFileType = PromptMessageFileType.IMAGE
+    detail: DETAIL = DETAIL.LOW
+
+
+class LCHumanMessageWithFiles(HumanMessage):
+    # content: Union[str, List[Union[str, Dict]]]
+    content: str
+    files: list[PromptMessageFile]
+
+
+def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMessage]:
+    prompt_messages = []
+    for message in messages:
+        if isinstance(message, HumanMessage):
+            if isinstance(message, LCHumanMessageWithFiles):
+                file_prompt_message_contents = []
+                for file in message.files:
+                    if file.type == PromptMessageFileType.IMAGE:
+                        file = cast(ImagePromptMessageFile, file)
+                        file_prompt_message_contents.append(ImagePromptMessageContent(
+                            data=file.data,
+                            detail=ImagePromptMessageContent.DETAIL.HIGH
+                            if file.detail.value == "high" else ImagePromptMessageContent.DETAIL.LOW
+                        ))
+
+                prompt_message_contents = [TextPromptMessageContent(data=message.content)]
+                prompt_message_contents.extend(file_prompt_message_contents)
+
+                prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+            else:
+                prompt_messages.append(UserPromptMessage(content=message.content))
+        elif isinstance(message, AIMessage):
+            message_kwargs = {
+                'content': message.content
+            }
+
+            if 'function_call' in message.additional_kwargs:
+                message_kwargs['tool_calls'] = [
+                    AssistantPromptMessage.ToolCall(
+                        id=message.additional_kwargs['function_call']['id'],
+                        type='function',
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=message.additional_kwargs['function_call']['name'],
+                            arguments=message.additional_kwargs['function_call']['arguments']
+                        )
+                    )
+                ]
+
+            prompt_messages.append(AssistantPromptMessage(**message_kwargs))
+        elif isinstance(message, SystemMessage):
+            prompt_messages.append(SystemPromptMessage(content=message.content))
+        elif isinstance(message, FunctionMessage):
+            prompt_messages.append(ToolPromptMessage(content=message.content, tool_call_id=message.name))
+
+    return prompt_messages
+
+
+def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list[BaseMessage]:
+    messages = []
+    for prompt_message in prompt_messages:
+        if isinstance(prompt_message, UserPromptMessage):
+            if isinstance(prompt_message.content, str):
+                messages.append(HumanMessage(content=prompt_message.content))
+            else:
+                message_contents = []
+                for content in prompt_message.content:
+                    if isinstance(content, TextPromptMessageContent):
+                        message_contents.append(content.data)
+                    elif isinstance(content, ImagePromptMessageContent):
+                        message_contents.append({
+                            'type': 'image',
+                            'data': content.data,
+                            'detail': content.detail.value
+                        })
+
+                messages.append(HumanMessage(content=message_contents))
+        elif isinstance(prompt_message, AssistantPromptMessage):
+            message_kwargs = {
+                'content': prompt_message.content
+            }
+
+            if prompt_message.tool_calls:
+                message_kwargs['additional_kwargs'] = {
+                    'function_call': {
+                        'id': prompt_message.tool_calls[0].id,
+                        'name': prompt_message.tool_calls[0].function.name,
+                        'arguments': prompt_message.tool_calls[0].function.arguments
+                    }
+                }
+
+            messages.append(AIMessage(**message_kwargs))
+        elif isinstance(prompt_message, SystemPromptMessage):
+            messages.append(SystemMessage(content=prompt_message.content))
+        elif isinstance(prompt_message, ToolPromptMessage):
+            messages.append(FunctionMessage(name=prompt_message.tool_call_id, content=prompt_message.content))
+
+    return messages

+ 71 - 0
api/core/entities/model_entities.py

@@ -0,0 +1,71 @@
+from enum import Enum
+from typing import Optional
+
+from pydantic import BaseModel
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import ProviderModel, ModelType
+from core.model_runtime.entities.provider_entities import SimpleProviderEntity, ProviderEntity
+
+
+class ModelStatus(Enum):
+    """
+    Enum class for model status.
+    """
+    ACTIVE = "active"
+    NO_CONFIGURE = "no-configure"
+    QUOTA_EXCEEDED = "quota-exceeded"
+    NO_PERMISSION = "no-permission"
+
+
+class SimpleModelProviderEntity(BaseModel):
+    """
+    Simple provider.
+    """
+    provider: str
+    label: I18nObject
+    icon_small: Optional[I18nObject] = None
+    icon_large: Optional[I18nObject] = None
+    supported_model_types: list[ModelType]
+
+    def __init__(self, provider_entity: ProviderEntity) -> None:
+        """
+        Init simple provider.
+
+        :param provider_entity: provider entity
+        """
+        super().__init__(
+            provider=provider_entity.provider,
+            label=provider_entity.label,
+            icon_small=provider_entity.icon_small,
+            icon_large=provider_entity.icon_large,
+            supported_model_types=provider_entity.supported_model_types
+        )
+
+
+class ModelWithProviderEntity(ProviderModel):
+    """
+    Model with provider entity.
+    """
+    provider: SimpleModelProviderEntity
+    status: ModelStatus
+
+
+class DefaultModelProviderEntity(BaseModel):
+    """
+    Default model provider entity.
+    """
+    provider: str
+    label: I18nObject
+    icon_small: Optional[I18nObject] = None
+    icon_large: Optional[I18nObject] = None
+    supported_model_types: list[ModelType]
+
+
+class DefaultModelEntity(BaseModel):
+    """
+    Default model entity.
+    """
+    model: str
+    model_type: ModelType
+    provider: DefaultModelProviderEntity

+ 657 - 0
api/core/entities/provider_configuration.py

@@ -0,0 +1,657 @@
+import datetime
+import json
+import time
+from json import JSONDecodeError
+from typing import Optional, List, Dict, Tuple, Iterator
+
+from pydantic import BaseModel
+
+from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, SimpleModelProviderEntity
+from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
+from core.helper import encrypter
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
+from core.model_runtime.model_providers import model_provider_factory
+from core.model_runtime.model_providers.__base.ai_model import AIModel
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+from core.model_runtime.utils import encoders
+from extensions.ext_database import db
+from models.provider import ProviderType, Provider, ProviderModel, TenantPreferredModelProvider
+
+
+class ProviderConfiguration(BaseModel):
+    """
+    Model class for provider configuration.
+    """
+    tenant_id: str
+    provider: ProviderEntity
+    preferred_provider_type: ProviderType
+    using_provider_type: ProviderType
+    system_configuration: SystemConfiguration
+    custom_configuration: CustomConfiguration
+
+    def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
+        """
+        Get current credentials.
+
+        :param model_type: model type
+        :param model: model name
+        :return:
+        """
+        if self.using_provider_type == ProviderType.SYSTEM:
+            return self.system_configuration.credentials
+        else:
+            if self.custom_configuration.models:
+                for model_configuration in self.custom_configuration.models:
+                    if model_configuration.model_type == model_type and model_configuration.model == model:
+                        return model_configuration.credentials
+
+            if self.custom_configuration.provider:
+                return self.custom_configuration.provider.credentials
+            else:
+                return None
+
+    def get_system_configuration_status(self) -> SystemConfigurationStatus:
+        """
+        Get system configuration status.
+        :return:
+        """
+        if self.system_configuration.enabled is False:
+            return SystemConfigurationStatus.UNSUPPORTED
+
+        current_quota_type = self.system_configuration.current_quota_type
+        current_quota_configuration = next(
+            (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
+            None
+        )
+
+        return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \
+            SystemConfigurationStatus.QUOTA_EXCEEDED
+
+    def is_custom_configuration_available(self) -> bool:
+        """
+        Check custom configuration available.
+        :return:
+        """
+        return (self.custom_configuration.provider is not None
+                or len(self.custom_configuration.models) > 0)
+
+    def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
+        """
+        Get custom credentials.
+
+        :param obfuscated: obfuscated secret data in credentials
+        :return:
+        """
+        if self.custom_configuration.provider is None:
+            return None
+
+        credentials = self.custom_configuration.provider.credentials
+        if not obfuscated:
+            return credentials
+
+        # Obfuscate credentials
+        return self._obfuscated_credentials(
+            credentials=credentials,
+            credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
+            if self.provider.provider_credential_schema else []
+        )
+
+    def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]:
+        """
+        Validate custom credentials.
+        :param credentials: provider credentials
+        :return:
+        """
+        # get provider
+        provider_record = db.session.query(Provider) \
+            .filter(
+            Provider.tenant_id == self.tenant_id,
+            Provider.provider_name == self.provider.provider,
+            Provider.provider_type == ProviderType.CUSTOM.value
+        ).first()
+
+        # Get provider credential secret variables
+        provider_credential_secret_variables = self._extract_secret_variables(
+            self.provider.provider_credential_schema.credential_form_schemas
+            if self.provider.provider_credential_schema else []
+        )
+
+        if provider_record:
+            try:
+                original_credentials = json.loads(provider_record.encrypted_config) if provider_record.encrypted_config else {}
+            except JSONDecodeError:
+                original_credentials = {}
+
+            # encrypt credentials
+            for key, value in credentials.items():
+                if key in provider_credential_secret_variables:
+                    # if send [__HIDDEN__] in secret input, it will be same as original value
+                    if value == '[__HIDDEN__]' and key in original_credentials:
+                        credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
+
+        model_provider_factory.provider_credentials_validate(
+            self.provider.provider,
+            credentials
+        )
+
+        for key, value in credentials.items():
+            if key in provider_credential_secret_variables:
+                credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
+
+        return provider_record, credentials
+
+    def add_or_update_custom_credentials(self, credentials: dict) -> None:
+        """
+        Add or update custom provider credentials.
+        :param credentials:
+        :return:
+        """
+        # validate custom provider config
+        provider_record, credentials = self.custom_credentials_validate(credentials)
+
+        # save provider
+        # Note: Do not switch the preferred provider, which allows users to use quotas first
+        if provider_record:
+            provider_record.encrypted_config = json.dumps(credentials)
+            provider_record.is_valid = True
+            provider_record.updated_at = datetime.datetime.utcnow()
+            db.session.commit()
+        else:
+            provider_record = Provider(
+                tenant_id=self.tenant_id,
+                provider_name=self.provider.provider,
+                provider_type=ProviderType.CUSTOM.value,
+                encrypted_config=json.dumps(credentials),
+                is_valid=True
+            )
+            db.session.add(provider_record)
+            db.session.commit()
+
+        self.switch_preferred_provider_type(ProviderType.CUSTOM)
+
+    def delete_custom_credentials(self) -> None:
+        """
+        Delete custom provider credentials.
+        :return:
+        """
+        # get provider
+        provider_record = db.session.query(Provider) \
+            .filter(
+            Provider.tenant_id == self.tenant_id,
+            Provider.provider_name == self.provider.provider,
+            Provider.provider_type == ProviderType.CUSTOM.value
+        ).first()
+
+        # delete provider
+        if provider_record:
+            self.switch_preferred_provider_type(ProviderType.SYSTEM)
+
+            db.session.delete(provider_record)
+            db.session.commit()
+
+    def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
+            -> Optional[dict]:
+        """
+        Get custom model credentials.
+
+        :param model_type: model type
+        :param model: model name
+        :param obfuscated: obfuscated secret data in credentials
+        :return:
+        """
+        if not self.custom_configuration.models:
+            return None
+
+        for model_configuration in self.custom_configuration.models:
+            if model_configuration.model_type == model_type and model_configuration.model == model:
+                credentials = model_configuration.credentials
+                if not obfuscated:
+                    return credentials
+
+                # Obfuscate credentials
+                return self._obfuscated_credentials(
+                    credentials=credentials,
+                    credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
+                    if self.provider.model_credential_schema else []
+                )
+
+        return None
+
+    def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
+            -> Tuple[ProviderModel, dict]:
+        """
+        Validate custom model credentials.
+
+        :param model_type: model type
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        # get provider model
+        provider_model_record = db.session.query(ProviderModel) \
+            .filter(
+            ProviderModel.tenant_id == self.tenant_id,
+            ProviderModel.provider_name == self.provider.provider,
+            ProviderModel.model_name == model,
+            ProviderModel.model_type == model_type.to_origin_model_type()
+        ).first()
+
+        # Get provider credential secret variables
+        provider_credential_secret_variables = self._extract_secret_variables(
+            self.provider.model_credential_schema.credential_form_schemas
+            if self.provider.model_credential_schema else []
+        )
+
+        if provider_model_record:
+            try:
+                original_credentials = json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
+            except JSONDecodeError:
+                original_credentials = {}
+
+            # decrypt credentials
+            for key, value in credentials.items():
+                if key in provider_credential_secret_variables:
+                    # if send [__HIDDEN__] in secret input, it will be same as original value
+                    if value == '[__HIDDEN__]' and key in original_credentials:
+                        credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
+
+        model_provider_factory.model_credentials_validate(
+            provider=self.provider.provider,
+            model_type=model_type,
+            model=model,
+            credentials=credentials
+        )
+
+        model_schema = (
+            model_provider_factory.get_provider_instance(self.provider.provider)
+            .get_model_instance(model_type)._get_customizable_model_schema(
+                model=model,
+                credentials=credentials
+            )
+        )
+
+        if model_schema:
+            credentials['schema'] = json.dumps(encoders.jsonable_encoder(model_schema))
+
+        for key, value in credentials.items():
+            if key in provider_credential_secret_variables:
+                credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
+
+        return provider_model_record, credentials
+
+    def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
+        """
+        Add or update custom model credentials.
+
+        :param model_type: model type
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        # validate custom model config
+        provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)
+
+        # save provider model
+        # Note: Do not switch the preferred provider, which allows users to use quotas first
+        if provider_model_record:
+            provider_model_record.encrypted_config = json.dumps(credentials)
+            provider_model_record.is_valid = True
+            provider_model_record.updated_at = datetime.datetime.utcnow()
+            db.session.commit()
+        else:
+            provider_model_record = ProviderModel(
+                tenant_id=self.tenant_id,
+                provider_name=self.provider.provider,
+                model_name=model,
+                model_type=model_type.to_origin_model_type(),
+                encrypted_config=json.dumps(credentials),
+                is_valid=True
+            )
+            db.session.add(provider_model_record)
+            db.session.commit()
+
+    def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
+        """
+        Delete custom model credentials.
+        :param model_type: model type
+        :param model: model name
+        :return:
+        """
+        # get provider model
+        provider_model_record = db.session.query(ProviderModel) \
+            .filter(
+            ProviderModel.tenant_id == self.tenant_id,
+            ProviderModel.provider_name == self.provider.provider,
+            ProviderModel.model_name == model,
+            ProviderModel.model_type == model_type.to_origin_model_type()
+        ).first()
+
+        # delete provider model
+        if provider_model_record:
+            db.session.delete(provider_model_record)
+            db.session.commit()
+
+    def get_provider_instance(self) -> ModelProvider:
+        """
+        Get provider instance.
+        :return:
+        """
+        return model_provider_factory.get_provider_instance(self.provider.provider)
+
+    def get_model_type_instance(self, model_type: ModelType) -> AIModel:
+        """
+        Get current model type instance.
+
+        :param model_type: model type
+        :return:
+        """
+        # Get provider instance
+        provider_instance = self.get_provider_instance()
+
+        # Get model instance of LLM
+        return provider_instance.get_model_instance(model_type)
+
+    def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
+        """
+        Switch preferred provider type.
+        :param provider_type:
+        :return:
+        """
+        if provider_type == self.preferred_provider_type:
+            return
+
+        if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
+            return
+
+        # get preferred provider
+        preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
+            .filter(
+            TenantPreferredModelProvider.tenant_id == self.tenant_id,
+            TenantPreferredModelProvider.provider_name == self.provider.provider
+        ).first()
+
+        if preferred_model_provider:
+            preferred_model_provider.preferred_provider_type = provider_type.value
+        else:
+            preferred_model_provider = TenantPreferredModelProvider(
+                tenant_id=self.tenant_id,
+                provider_name=self.provider.provider,
+                preferred_provider_type=provider_type.value
+            )
+            db.session.add(preferred_model_provider)
+
+        db.session.commit()
+
+    def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
+        """
+        Extract secret input form variables.
+
+        :param credential_form_schemas:
+        :return:
+        """
+        secret_input_form_variables = []
+        for credential_form_schema in credential_form_schemas:
+            if credential_form_schema.type == FormType.SECRET_INPUT:
+                secret_input_form_variables.append(credential_form_schema.variable)
+
+        return secret_input_form_variables
+
+    def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
+        """
+        Obfuscated credentials.
+
+        :param credentials: credentials
+        :param credential_form_schemas: credential form schemas
+        :return:
+        """
+        # Get provider credential secret variables
+        credential_secret_variables = self._extract_secret_variables(
+            credential_form_schemas
+        )
+
+        # Obfuscate provider credentials
+        copy_credentials = credentials.copy()
+        for key, value in copy_credentials.items():
+            if key in credential_secret_variables:
+                copy_credentials[key] = encrypter.obfuscated_token(value)
+
+        return copy_credentials
+
+    def get_provider_model(self, model_type: ModelType,
+                           model: str,
+                           only_active: bool = False) -> Optional[ModelWithProviderEntity]:
+        """
+        Get provider model.
+        :param model_type: model type
+        :param model: model name
+        :param only_active: return active model only
+        :return:
+        """
+        provider_models = self.get_provider_models(model_type, only_active)
+
+        for provider_model in provider_models:
+            if provider_model.model == model:
+                return provider_model
+
+        return None
+
+    def get_provider_models(self, model_type: Optional[ModelType] = None,
+                            only_active: bool = False) -> list[ModelWithProviderEntity]:
+        """
+        Get provider models.
+        :param model_type: model type
+        :param only_active: only active models
+        :return:
+        """
+        provider_instance = self.get_provider_instance()
+
+        model_types = []
+        if model_type:
+            model_types.append(model_type)
+        else:
+            model_types = provider_instance.get_provider_schema().supported_model_types
+
+        if self.using_provider_type == ProviderType.SYSTEM:
+            provider_models = self._get_system_provider_models(
+                model_types=model_types,
+                provider_instance=provider_instance
+            )
+        else:
+            provider_models = self._get_custom_provider_models(
+                model_types=model_types,
+                provider_instance=provider_instance
+            )
+
+        if only_active:
+            provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
+
+        # resort provider_models
+        return sorted(provider_models, key=lambda x: x.model_type.value)
+
+    def _get_system_provider_models(self,
+                                    model_types: list[ModelType],
+                                    provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
+        """
+        Get system provider models.
+
+        :param model_types: model types
+        :param provider_instance: provider instance
+        :return:
+        """
+        provider_models = []
+        for model_type in model_types:
+            provider_models.extend(
+                [
+                    ModelWithProviderEntity(
+                        **m.dict(),
+                        provider=SimpleModelProviderEntity(self.provider),
+                        status=ModelStatus.ACTIVE
+                    )
+                    for m in provider_instance.models(model_type)
+                ]
+            )
+
+        for quota_configuration in self.system_configuration.quota_configurations:
+            if self.system_configuration.current_quota_type != quota_configuration.quota_type:
+                continue
+
+            restrict_llms = quota_configuration.restrict_llms
+            if not restrict_llms:
+                break
+
+            # if llm name not in restricted llm list, remove it
+            for m in provider_models:
+                if m.model_type == ModelType.LLM and m.model not in restrict_llms:
+                    m.status = ModelStatus.NO_PERMISSION
+                elif not quota_configuration.is_valid:
+                    m.status = ModelStatus.QUOTA_EXCEEDED
+
+        return provider_models
+
+    def _get_custom_provider_models(self,
+                                    model_types: list[ModelType],
+                                    provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
+        """
+        Get custom provider models.
+
+        :param model_types: model types
+        :param provider_instance: provider instance
+        :return:
+        """
+        provider_models = []
+
+        credentials = None
+        if self.custom_configuration.provider:
+            credentials = self.custom_configuration.provider.credentials
+
+        for model_type in model_types:
+            if model_type not in self.provider.supported_model_types:
+                continue
+
+            models = provider_instance.models(model_type)
+            for m in models:
+                provider_models.append(
+                    ModelWithProviderEntity(
+                        **m.dict(),
+                        provider=SimpleModelProviderEntity(self.provider),
+                        status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
+                    )
+                )
+
+        # custom models
+        for model_configuration in self.custom_configuration.models:
+            if model_configuration.model_type not in model_types:
+                continue
+
+            custom_model_schema = (
+                provider_instance.get_model_instance(model_configuration.model_type)
+                .get_customizable_model_schema_from_credentials(
+                    model_configuration.model,
+                    model_configuration.credentials
+                )
+            )
+
+            if not custom_model_schema:
+                continue
+
+            provider_models.append(
+                ModelWithProviderEntity(
+                    **custom_model_schema.dict(),
+                    provider=SimpleModelProviderEntity(self.provider),
+                    status=ModelStatus.ACTIVE
+                )
+            )
+
+        return provider_models
+
+
+class ProviderConfigurations(BaseModel):
+    """
+    Model class for provider configuration dict.
+    """
+    tenant_id: str
+    configurations: Dict[str, ProviderConfiguration] = {}
+
+    def __init__(self, tenant_id: str):
+        super().__init__(tenant_id=tenant_id)
+
+    def get_models(self,
+                   provider: Optional[str] = None,
+                   model_type: Optional[ModelType] = None,
+                   only_active: bool = False) \
+            -> list[ModelWithProviderEntity]:
+        """
+        Get available models.
+
+        If preferred provider type is `system`:
+          Get the current **system mode** if provider supported,
+          if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
+          If there is no model configured in custom mode, it is treated as no_configure.
+        system > custom > no_configure
+
+        If preferred provider type is `custom`:
+          If custom credentials are configured, it is treated as custom mode.
+          Otherwise, get the current **system mode** if supported,
+          If all system modes are not available (no quota), it is treated as no_configure.
+        custom > system > no_configure
+
+        If real mode is `system`, use system credentials to get models,
+          paid quotas > provider free quotas > system free quotas
+          include pre-defined models (exclude GPT-4, status marked as `no_permission`).
+        If real mode is `custom`, use workspace custom credentials to get models,
+          include pre-defined models, custom models(manual append).
+        If real mode is `no_configure`, only return pre-defined models from `model runtime`.
+          (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
+        model status marked as `active` is available.
+
+        :param provider: provider name
+        :param model_type: model type
+        :param only_active: only active models
+        :return:
+        """
+        all_models = []
+        for provider_configuration in self.values():
+            if provider and provider_configuration.provider.provider != provider:
+                continue
+
+            all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
+
+        return all_models
+
+    def to_list(self) -> List[ProviderConfiguration]:
+        """
+        Convert to list.
+
+        :return:
+        """
+        return list(self.values())
+
+    def __getitem__(self, key):
+        return self.configurations[key]
+
+    def __setitem__(self, key, value):
+        self.configurations[key] = value
+
+    def __iter__(self):
+        return iter(self.configurations)
+
+    def values(self) -> Iterator[ProviderConfiguration]:
+        return self.configurations.values()
+
+    def get(self, key, default=None):
+        return self.configurations.get(key, default)
+
+
+class ProviderModelBundle(BaseModel):
+    """
+    Provider model bundle.
+    """
+    configuration: ProviderConfiguration
+    provider_instance: ModelProvider
+    model_type_instance: AIModel
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True

+ 67 - 0
api/core/entities/provider_entities.py

@@ -0,0 +1,67 @@
+from enum import Enum
+from typing import Optional
+
+from pydantic import BaseModel
+
+from core.model_runtime.entities.model_entities import ModelType
+from models.provider import ProviderQuotaType
+
+
+class QuotaUnit(Enum):
+    TIMES = 'times'
+    TOKENS = 'tokens'
+
+
+class SystemConfigurationStatus(Enum):
+    """
+    Enum class for system configuration status.
+    """
+    ACTIVE = 'active'
+    QUOTA_EXCEEDED = 'quota-exceeded'
+    UNSUPPORTED = 'unsupported'
+
+
+class QuotaConfiguration(BaseModel):
+    """
+    Model class for provider quota configuration.
+    """
+    quota_type: ProviderQuotaType
+    quota_unit: QuotaUnit
+    quota_limit: int
+    quota_used: int
+    is_valid: bool
+    restrict_llms: list[str] = []
+
+
+class SystemConfiguration(BaseModel):
+    """
+    Model class for provider system configuration.
+    """
+    enabled: bool
+    current_quota_type: Optional[ProviderQuotaType] = None
+    quota_configurations: list[QuotaConfiguration] = []
+    credentials: Optional[dict] = None
+
+
+class CustomProviderConfiguration(BaseModel):
+    """
+    Model class for provider custom configuration.
+    """
+    credentials: dict
+
+
+class CustomModelConfiguration(BaseModel):
+    """
+    Model class for provider custom model configuration.
+    """
+    model: str
+    model_type: ModelType
+    credentials: dict
+
+
+class CustomConfiguration(BaseModel):
+    """
+    Model class for provider custom configuration.
+    """
+    provider: Optional[CustomProviderConfiguration] = None
+    models: list[CustomModelConfiguration] = []

+ 118 - 0
api/core/entities/queue_entities.py

@@ -0,0 +1,118 @@
+from enum import Enum
+from typing import Any
+
+from pydantic import BaseModel
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
+
+
+class QueueEvent(Enum):
+    """
+    QueueEvent enum
+    """
+    MESSAGE = "message"
+    MESSAGE_REPLACE = "message-replace"
+    MESSAGE_END = "message-end"
+    RETRIEVER_RESOURCES = "retriever-resources"
+    ANNOTATION_REPLY = "annotation-reply"
+    AGENT_THOUGHT = "agent-thought"
+    ERROR = "error"
+    PING = "ping"
+    STOP = "stop"
+
+
+class AppQueueEvent(BaseModel):
+    """
+    QueueEvent entity
+    """
+    event: QueueEvent
+
+
+class QueueMessageEvent(AppQueueEvent):
+    """
+    QueueMessageEvent entity
+    """
+    event = QueueEvent.MESSAGE
+    chunk: LLMResultChunk
+    
+    
+class QueueMessageReplaceEvent(AppQueueEvent):
+    """
+    QueueMessageReplaceEvent entity
+    """
+    event = QueueEvent.MESSAGE_REPLACE
+    text: str
+
+
+class QueueRetrieverResourcesEvent(AppQueueEvent):
+    """
+    QueueRetrieverResourcesEvent entity
+    """
+    event = QueueEvent.RETRIEVER_RESOURCES
+    retriever_resources: list[dict]
+
+
+class AnnotationReplyEvent(AppQueueEvent):
+    """
+    AnnotationReplyEvent entity
+    """
+    event = QueueEvent.ANNOTATION_REPLY
+    message_annotation_id: str
+
+
+class QueueMessageEndEvent(AppQueueEvent):
+    """
+    QueueMessageEndEvent entity
+    """
+    event = QueueEvent.MESSAGE_END
+    llm_result: LLMResult
+
+    
+class QueueAgentThoughtEvent(AppQueueEvent):
+    """
+    QueueAgentThoughtEvent entity
+    """
+    event = QueueEvent.AGENT_THOUGHT
+    agent_thought_id: str
+    
+    
+class QueueErrorEvent(AppQueueEvent):
+    """
+    QueueErrorEvent entity
+    """
+    event = QueueEvent.ERROR
+    error: Any
+
+
+class QueuePingEvent(AppQueueEvent):
+    """
+    QueuePingEvent entity
+    """
+    event = QueueEvent.PING
+
+
+class QueueStopEvent(AppQueueEvent):
+    """
+    QueueStopEvent entity
+    """
+    class StopBy(Enum):
+        """
+        Stop by enum
+        """
+        USER_MANUAL = "user-manual"
+        ANNOTATION_REPLY = "annotation-reply"
+        OUTPUT_MODERATION = "output-moderation"
+
+    event = QueueEvent.STOP
+    stopped_by: StopBy
+
+
+class QueueMessage(BaseModel):
+    """
+    QueueMessage entity
+    """
+    task_id: str
+    message_id: str
+    conversation_id: str
+    app_mode: str
+    event: AppQueueEvent

+ 0 - 0
api/core/model_providers/models/entity/__init__.py → api/core/errors/__init__.py


+ 0 - 20
api/core/model_providers/error.py → api/core/errors/error.py

@@ -14,26 +14,6 @@ class LLMBadRequestError(LLMError):
     description = "Bad Request"
 
 
-class LLMAPIConnectionError(LLMError):
-    """Raised when the LLM returns API connection error."""
-    description = "API Connection Error"
-
-
-class LLMAPIUnavailableError(LLMError):
-    """Raised when the LLM returns API unavailable error."""
-    description = "API Unavailable Error"
-
-
-class LLMRateLimitError(LLMError):
-    """Raised when the LLM returns rate limit error."""
-    description = "Rate Limit Error"
-
-
-class LLMAuthorizationError(LLMError):
-    """Raised when the LLM returns authorization error."""
-    description = "Authorization Error"
-
-
 class ProviderTokenNotInitError(Exception):
     """
     Custom exception raised when the provider token is not initialized.

+ 0 - 0
api/core/model_providers/models/llm/__init__.py → api/core/external_data_tool/weather_search/__init__.py


+ 35 - 0
api/core/external_data_tool/weather_search/schema.json

@@ -0,0 +1,35 @@
+{
+    "label": {
+        "en-US": "Weather Search",
+        "zh-Hans": "天气查询"
+    },
+    "form_schema": [
+        {
+            "type": "select",
+            "label": {
+                "en-US": "Temperature Unit",
+                "zh-Hans": "温度单位"
+            },
+            "variable": "temperature_unit",
+            "required": true,
+            "options": [
+                {
+                    "label": {
+                        "en-US": "Fahrenheit",
+                        "zh-Hans": "华氏度"
+                    },
+                    "value": "fahrenheit"
+                },
+                {
+                    "label": {
+                        "en-US": "Centigrade",
+                        "zh-Hans": "摄氏度"
+                    },
+                    "value": "centigrade"
+                }
+            ],
+            "default": "centigrade",
+            "placeholder": "Please select temperature unit"
+        }
+    ]
+}

+ 45 - 0
api/core/external_data_tool/weather_search/weather_search.py

@@ -0,0 +1,45 @@
+from typing import Optional
+
+from core.external_data_tool.base import ExternalDataTool
+
+
+class WeatherSearch(ExternalDataTool):
+    """
+    The name of custom type must be unique, keep the same with directory and file name.
+    """
+    name: str = "weather_search"
+
+    @classmethod
+    def validate_config(cls, tenant_id: str, config: dict) -> None:
+        """
+        schema.json validation. It will be called when user save the config.
+
+        Example:
+            .. code-block:: python
+                config = {
+                    "temperature_unit": "centigrade"
+                }
+
+        :param tenant_id: the id of workspace
+        :param config: the variables of form config
+        :return:
+        """
+
+        if not config.get('temperature_unit'):
+            raise ValueError('temperature unit is required')
+
+    def query(self, inputs: dict, query: Optional[str] = None) -> str:
+        """
+        Query the external data tool.
+
+        :param inputs: user inputs
+        :param query: the query of chat app
+        :return: the tool query result
+        """
+        city = inputs.get('city')
+        temperature_unit = self.config.get('temperature_unit')
+
+        if temperature_unit == 'fahrenheit':
+            return f'Weather in {city} is 32°F'
+        else:
+            return f'Weather in {city} is 0°C'

+ 0 - 0
api/core/model_providers/models/moderation/__init__.py → api/core/features/__init__.py


+ 325 - 0
api/core/features/agent_runner.py

@@ -0,0 +1,325 @@
+import logging
+from typing import cast, Optional, List
+
+from langchain import WikipediaAPIWrapper
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.tools import BaseTool, WikipediaQueryRun, Tool
+from pydantic import BaseModel, Field
+
+from core.agent.agent.agent_llm_callback import AgentLLMCallback
+from core.agent.agent_executor import PlanningStrategy, AgentConfiguration, AgentExecutor
+from core.application_queue_manager import ApplicationQueueManager
+from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
+from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
+from core.entities.application_entities import ModelConfigEntity, InvokeFrom, \
+    AgentEntity, AgentToolEntity, AppOrchestrationConfigEntity
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.model_entities import ModelFeature, ModelType
+from core.model_runtime.model_providers import model_provider_factory
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.tool.current_datetime_tool import DatetimeTool
+from core.tool.dataset_retriever_tool import DatasetRetrieverTool
+from core.tool.provider.serpapi_provider import SerpAPIToolProvider
+from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
+from core.tool.web_reader_tool import WebReaderTool
+from extensions.ext_database import db
+from models.dataset import Dataset
+from models.model import Message
+
+logger = logging.getLogger(__name__)
+
+
+class AgentRunnerFeature:
+    def __init__(self, tenant_id: str,
+                 app_orchestration_config: AppOrchestrationConfigEntity,
+                 model_config: ModelConfigEntity,
+                 config: AgentEntity,
+                 queue_manager: ApplicationQueueManager,
+                 message: Message,
+                 user_id: str,
+                 agent_llm_callback: AgentLLMCallback,
+                 callback: AgentLoopGatherCallbackHandler,
+                 memory: Optional[TokenBufferMemory] = None,) -> None:
+        """
+        Agent runner
+        :param tenant_id: tenant id
+        :param app_orchestration_config: app orchestration config
+        :param model_config: model config
+        :param config: dataset config
+        :param queue_manager: queue manager
+        :param message: message
+        :param user_id: user id
+        :param agent_llm_callback: agent llm callback
+        :param callback: callback
+        :param memory: memory
+        """
+        self.tenant_id = tenant_id
+        self.app_orchestration_config = app_orchestration_config
+        self.model_config = model_config
+        self.config = config
+        self.queue_manager = queue_manager
+        self.message = message
+        self.user_id = user_id
+        self.agent_llm_callback = agent_llm_callback
+        self.callback = callback
+        self.memory = memory
+
+    def run(self, query: str,
+            invoke_from: InvokeFrom) -> Optional[str]:
+        """
+        Retrieve agent loop result.
+        :param query: query
+        :param invoke_from: invoke from
+        :return:
+        """
+        provider = self.config.provider
+        model = self.config.model
+        tool_configs = self.config.tools
+
+        # check model is support tool calling
+        provider_instance = model_provider_factory.get_provider_instance(provider=provider)
+        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        # get model schema
+        model_schema = model_type_instance.get_model_schema(
+            model=model,
+            credentials=self.model_config.credentials
+        )
+
+        if not model_schema:
+            return None
+
+        planning_strategy = PlanningStrategy.REACT
+        features = model_schema.features
+        if features:
+            if ModelFeature.TOOL_CALL in features \
+                    or ModelFeature.MULTI_TOOL_CALL in features:
+                planning_strategy = PlanningStrategy.FUNCTION_CALL
+
+        tools = self.to_tools(
+            tool_configs=tool_configs,
+            invoke_from=invoke_from,
+            callbacks=[self.callback, DifyStdOutCallbackHandler()],
+        )
+
+        if len(tools) == 0:
+            return None
+
+        agent_configuration = AgentConfiguration(
+            strategy=planning_strategy,
+            model_config=self.model_config,
+            tools=tools,
+            memory=self.memory,
+            max_iterations=10,
+            max_execution_time=400.0,
+            early_stopping_method="generate",
+            agent_llm_callback=self.agent_llm_callback,
+            callbacks=[self.callback, DifyStdOutCallbackHandler()]
+        )
+
+        agent_executor = AgentExecutor(agent_configuration)
+
+        try:
+            # check if should use agent
+            should_use_agent = agent_executor.should_use_agent(query)
+            if not should_use_agent:
+                return None
+
+            result = agent_executor.run(query)
+            return result.output
+        except Exception as ex:
+            logger.exception("agent_executor run failed")
+            return None
+
+    def to_tools(self, tool_configs: list[AgentToolEntity],
+                 invoke_from: InvokeFrom,
+                 callbacks: list[BaseCallbackHandler]) \
+            -> Optional[List[BaseTool]]:
+        """
+        Convert tool configs to tools
+        :param tool_configs: tool configs
+        :param invoke_from: invoke from
+        :param callbacks: callbacks
+        """
+        tools = []
+        for tool_config in tool_configs:
+            tool = None
+            if tool_config.tool_id == "dataset":
+                tool = self.to_dataset_retriever_tool(
+                    tool_config=tool_config.config,
+                    invoke_from=invoke_from
+                )
+            elif tool_config.tool_id == "web_reader":
+                tool = self.to_web_reader_tool(
+                    tool_config=tool_config.config,
+                    invoke_from=invoke_from
+                )
+            elif tool_config.tool_id == "google_search":
+                tool = self.to_google_search_tool(
+                    tool_config=tool_config.config,
+                    invoke_from=invoke_from
+                )
+            elif tool_config.tool_id == "wikipedia":
+                tool = self.to_wikipedia_tool(
+                    tool_config=tool_config.config,
+                    invoke_from=invoke_from
+                )
+            elif tool_config.tool_id == "current_datetime":
+                tool = self.to_current_datetime_tool(
+                    tool_config=tool_config.config,
+                    invoke_from=invoke_from
+                )
+
+            if tool:
+                if tool.callbacks is not None:
+                    tool.callbacks.extend(callbacks)
+                else:
+                    tool.callbacks = callbacks
+
+                tools.append(tool)
+
+        return tools
+
+    def to_dataset_retriever_tool(self, tool_config: dict,
+                                  invoke_from: InvokeFrom) \
+            -> Optional[BaseTool]:
+        """
+        A dataset tool is a tool that can be used to retrieve information from a dataset
+        :param tool_config: tool config
+        :param invoke_from: invoke from
+        """
+        show_retrieve_source = self.app_orchestration_config.show_retrieve_source
+
+        hit_callback = DatasetIndexToolCallbackHandler(
+            queue_manager=self.queue_manager,
+            app_id=self.message.app_id,
+            message_id=self.message.id,
+            user_id=self.user_id,
+            invoke_from=invoke_from
+        )
+
+        # get dataset from dataset id
+        dataset = db.session.query(Dataset).filter(
+            Dataset.tenant_id == self.tenant_id,
+            Dataset.id == tool_config.get("id")
+        ).first()
+
+        # pass if dataset is not available
+        if not dataset:
+            return None
+
+        # pass if dataset is not available
+        if (dataset and dataset.available_document_count == 0
+                and dataset.available_document_count == 0):
+            return None
+
+        # get retrieval model config
+        default_retrieval_model = {
+            'search_method': 'semantic_search',
+            'reranking_enable': False,
+            'reranking_model': {
+                'reranking_provider_name': '',
+                'reranking_model_name': ''
+            },
+            'top_k': 2,
+            'score_threshold_enabled': False
+        }
+
+        retrieval_model_config = dataset.retrieval_model \
+            if dataset.retrieval_model else default_retrieval_model
+
+        # get top k
+        top_k = retrieval_model_config['top_k']
+
+        # get score threshold
+        score_threshold = None
+        score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
+        if score_threshold_enabled:
+            score_threshold = retrieval_model_config.get("score_threshold")
+
+        tool = DatasetRetrieverTool.from_dataset(
+            dataset=dataset,
+            top_k=top_k,
+            score_threshold=score_threshold,
+            hit_callbacks=[hit_callback],
+            return_resource=show_retrieve_source,
+            retriever_from=invoke_from.to_source()
+        )
+
+        return tool
+
+    def to_web_reader_tool(self, tool_config: dict,
+                           invoke_from: InvokeFrom) -> Optional[BaseTool]:
+        """
+        A tool for reading web pages
+        :param tool_config: tool config
+        :param invoke_from: invoke from
+        :return:
+        """
+        model_parameters = {
+            "temperature": 0,
+            "max_tokens": 500
+        }
+
+        tool = WebReaderTool(
+            model_config=self.model_config,
+            model_parameters=model_parameters,
+            max_chunk_length=4000,
+            continue_reading=True
+        )
+
+        return tool
+
+    def to_google_search_tool(self, tool_config: dict,
+                              invoke_from: InvokeFrom) -> Optional[BaseTool]:
+        """
+        A tool for performing a Google search and extracting snippets and webpages
+        :param tool_config: tool config
+        :param invoke_from: invoke from
+        :return:
+        """
+        tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
+        func_kwargs = tool_provider.credentials_to_func_kwargs()
+        if not func_kwargs:
+            return None
+
+        tool = Tool(
+            name="google_search",
+            description="A tool for performing a Google search and extracting snippets and webpages "
+                        "when you need to search for something you don't know or when your information "
+                        "is not up to date. "
+                        "Input should be a search query.",
+            func=OptimizedSerpAPIWrapper(**func_kwargs).run,
+            args_schema=OptimizedSerpAPIInput
+        )
+
+        return tool
+
+    def to_current_datetime_tool(self, tool_config: dict,
+                                 invoke_from: InvokeFrom) -> Optional[BaseTool]:
+        """
+        A tool for getting the current date and time
+        :param tool_config: tool config
+        :param invoke_from: invoke from
+        :return:
+        """
+        return DatetimeTool()
+
+    def to_wikipedia_tool(self, tool_config: dict,
+                          invoke_from: InvokeFrom) -> Optional[BaseTool]:
+        """
+        A tool for searching Wikipedia
+        :param tool_config: tool config
+        :param invoke_from: invoke from
+        :return:
+        """
+        class WikipediaInput(BaseModel):
+            query: str = Field(..., description="search query.")
+
+        return WikipediaQueryRun(
+            name="wikipedia",
+            api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
+            args_schema=WikipediaInput
+        )

+ 119 - 0
api/core/features/annotation_reply.py

@@ -0,0 +1,119 @@
+import logging
+from typing import Optional
+
+from flask import current_app
+
+from core.embedding.cached_embedding import CacheEmbedding
+from core.entities.application_entities import InvokeFrom
+from core.index.vector_index.vector_index import VectorIndex
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from extensions.ext_database import db
+from models.dataset import Dataset
+from models.model import App, Message, AppAnnotationSetting, MessageAnnotation
+from services.annotation_service import AppAnnotationService
+from services.dataset_service import DatasetCollectionBindingService
+
+logger = logging.getLogger(__name__)
+
+
+class AnnotationReplyFeature:
+    def query(self, app_record: App,
+              message: Message,
+              query: str,
+              user_id: str,
+              invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
+        """
+        Query app annotations to reply
+        :param app_record: app record
+        :param message: message
+        :param query: query
+        :param user_id: user id
+        :param invoke_from: invoke from
+        :return:
+        """
+        annotation_setting = db.session.query(AppAnnotationSetting).filter(
+            AppAnnotationSetting.app_id == app_record.id).first()
+
+        if not annotation_setting:
+            return None
+
+        collection_binding_detail = annotation_setting.collection_binding_detail
+
+        try:
+            score_threshold = annotation_setting.score_threshold or 1
+            embedding_provider_name = collection_binding_detail.provider_name
+            embedding_model_name = collection_binding_detail.model_name
+
+            model_manager = ModelManager()
+            model_instance = model_manager.get_model_instance(
+                tenant_id=app_record.tenant_id,
+                provider=embedding_provider_name,
+                model_type=ModelType.TEXT_EMBEDDING,
+                model=embedding_model_name
+            )
+
+            # get embedding model
+            embeddings = CacheEmbedding(model_instance)
+
+            dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+                embedding_provider_name,
+                embedding_model_name,
+                'annotation'
+            )
+
+            dataset = Dataset(
+                id=app_record.id,
+                tenant_id=app_record.tenant_id,
+                indexing_technique='high_quality',
+                embedding_model_provider=embedding_provider_name,
+                embedding_model=embedding_model_name,
+                collection_binding_id=dataset_collection_binding.id
+            )
+
+            vector_index = VectorIndex(
+                dataset=dataset,
+                config=current_app.config,
+                embeddings=embeddings,
+                attributes=['doc_id', 'annotation_id', 'app_id']
+            )
+
+            documents = vector_index.search(
+                query=query,
+                search_type='similarity_score_threshold',
+                search_kwargs={
+                    'k': 1,
+                    'score_threshold': score_threshold,
+                    'filter': {
+                        'group_id': [dataset.id]
+                    }
+                }
+            )
+
+            if documents:
+                annotation_id = documents[0].metadata['annotation_id']
+                score = documents[0].metadata['score']
+                annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
+                if annotation:
+                    if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]:
+                        from_source = 'api'
+                    else:
+                        from_source = 'console'
+
+                    # insert annotation history
+                    AppAnnotationService.add_annotation_history(annotation.id,
+                                                                app_record.id,
+                                                                annotation.question,
+                                                                annotation.content,
+                                                                query,
+                                                                user_id,
+                                                                message.id,
+                                                                from_source,
+                                                                score)
+
+                    return annotation
+        except Exception as e:
+            logger.warning(f'Query annotation failed, exception: {str(e)}.')
+            return None
+
+        return None

+ 181 - 0
api/core/features/dataset_retrieval.py

@@ -0,0 +1,181 @@
+from typing import cast, Optional, List
+
+from langchain.tools import BaseTool
+
+from core.agent.agent_executor import PlanningStrategy, AgentConfiguration, AgentExecutor
+from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.entities.application_entities import DatasetEntity, ModelConfigEntity, InvokeFrom, DatasetRetrieveConfigEntity
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.model_entities import ModelFeature
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
+from core.tool.dataset_retriever_tool import DatasetRetrieverTool
+from extensions.ext_database import db
+from models.dataset import Dataset
+
+
+class DatasetRetrievalFeature:
+    def retrieve(self, tenant_id: str,
+                 model_config: ModelConfigEntity,
+                 config: DatasetEntity,
+                 query: str,
+                 invoke_from: InvokeFrom,
+                 show_retrieve_source: bool,
+                 hit_callback: DatasetIndexToolCallbackHandler,
+                 memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
+        """
+        Retrieve dataset.
+        :param tenant_id: tenant id
+        :param model_config: model config
+        :param config: dataset config
+        :param query: query
+        :param invoke_from: invoke from
+        :param show_retrieve_source: show retrieve source
+        :param hit_callback: hit callback
+        :param memory: memory
+        :return:
+        """
+        dataset_ids = config.dataset_ids
+        retrieve_config = config.retrieve_config
+
+        # check model is support tool calling
+        model_type_instance = model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        # get model schema
+        model_schema = model_type_instance.get_model_schema(
+            model=model_config.model,
+            credentials=model_config.credentials
+        )
+
+        if not model_schema:
+            return None
+
+        planning_strategy = PlanningStrategy.REACT_ROUTER
+        features = model_schema.features
+        if features:
+            if ModelFeature.TOOL_CALL in features \
+                    or ModelFeature.MULTI_TOOL_CALL in features:
+                planning_strategy = PlanningStrategy.ROUTER
+
+        dataset_retriever_tools = self.to_dataset_retriever_tool(
+            tenant_id=tenant_id,
+            dataset_ids=dataset_ids,
+            retrieve_config=retrieve_config,
+            return_resource=show_retrieve_source,
+            invoke_from=invoke_from,
+            hit_callback=hit_callback
+        )
+
+        if len(dataset_retriever_tools) == 0:
+            return None
+
+        agent_configuration = AgentConfiguration(
+            strategy=planning_strategy,
+            model_config=model_config,
+            tools=dataset_retriever_tools,
+            memory=memory,
+            max_iterations=10,
+            max_execution_time=400.0,
+            early_stopping_method="generate"
+        )
+
+        agent_executor = AgentExecutor(agent_configuration)
+
+        should_use_agent = agent_executor.should_use_agent(query)
+        if not should_use_agent:
+            return None
+
+        result = agent_executor.run(query)
+
+        return result.output
+
+    def to_dataset_retriever_tool(self, tenant_id: str,
+                                  dataset_ids: list[str],
+                                  retrieve_config: DatasetRetrieveConfigEntity,
+                                  return_resource: bool,
+                                  invoke_from: InvokeFrom,
+                                  hit_callback: DatasetIndexToolCallbackHandler) \
+            -> Optional[List[BaseTool]]:
+        """
+        A dataset tool is a tool that can be used to retrieve information from a dataset
+        :param tenant_id: tenant id
+        :param dataset_ids: dataset ids
+        :param retrieve_config: retrieve config
+        :param return_resource: return resource
+        :param invoke_from: invoke from
+        :param hit_callback: hit callback
+        """
+        tools = []
+        available_datasets = []
+        for dataset_id in dataset_ids:
+            # get dataset from dataset id
+            dataset = db.session.query(Dataset).filter(
+                Dataset.tenant_id == tenant_id,
+                Dataset.id == dataset_id
+            ).first()
+
+            # pass if dataset is not available
+            if not dataset:
+                continue
+
+            # pass if dataset is not available
+            if (dataset and dataset.available_document_count == 0
+                    and dataset.available_document_count == 0):
+                continue
+
+            available_datasets.append(dataset)
+
+        if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
+            # get retrieval model config
+            default_retrieval_model = {
+                'search_method': 'semantic_search',
+                'reranking_enable': False,
+                'reranking_model': {
+                    'reranking_provider_name': '',
+                    'reranking_model_name': ''
+                },
+                'top_k': 2,
+                'score_threshold_enabled': False
+            }
+
+            for dataset in available_datasets:
+                retrieval_model_config = dataset.retrieval_model \
+                    if dataset.retrieval_model else default_retrieval_model
+
+                # get top k
+                top_k = retrieval_model_config['top_k']
+
+                # get score threshold
+                score_threshold = None
+                score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
+                if score_threshold_enabled:
+                    score_threshold = retrieval_model_config.get("score_threshold")
+
+                tool = DatasetRetrieverTool.from_dataset(
+                    dataset=dataset,
+                    top_k=top_k,
+                    score_threshold=score_threshold,
+                    hit_callbacks=[hit_callback],
+                    return_resource=return_resource,
+                    retriever_from=invoke_from.to_source()
+                )
+
+                tools.append(tool)
+        elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
+            tool = DatasetMultiRetrieverTool.from_dataset(
+                dataset_ids=[dataset.id for dataset in available_datasets],
+                tenant_id=tenant_id,
+                top_k=retrieve_config.top_k or 2,
+                score_threshold=(retrieve_config.score_threshold or 0.5)
+                if retrieve_config.reranking_model.get('score_threshold_enabled', False) else None,
+                hit_callbacks=[hit_callback],
+                return_resource=return_resource,
+                retriever_from=invoke_from.to_source(),
+                reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'),
+                reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name')
+            )
+
+            tools.append(tool)
+
+        return tools

+ 96 - 0
api/core/features/external_data_fetch.py

@@ -0,0 +1,96 @@
+import concurrent
+import json
+import logging
+
+from concurrent.futures import ThreadPoolExecutor
+from typing import Tuple, Optional
+
+from flask import current_app, Flask
+
+from core.entities.application_entities import ExternalDataVariableEntity
+from core.external_data_tool.factory import ExternalDataToolFactory
+
+logger = logging.getLogger(__name__)
+
+
+class ExternalDataFetchFeature:
+    def fetch(self, tenant_id: str,
+              app_id: str,
+              external_data_tools: list[ExternalDataVariableEntity],
+              inputs: dict,
+              query: str) -> dict:
+        """
+        Fill in variable inputs from external data tools if exists.
+
+        :param tenant_id: workspace id
+        :param app_id: app id
+        :param external_data_tools: external data tools configs
+        :param inputs: the inputs
+        :param query: the query
+        :return: the filled inputs
+        """
+        # Group tools by type and config
+        grouped_tools = {}
+        for tool in external_data_tools:
+            tool_key = (tool.type, json.dumps(tool.config, sort_keys=True))
+            grouped_tools.setdefault(tool_key, []).append(tool)
+
+        results = {}
+        with ThreadPoolExecutor() as executor:
+            futures = {}
+            for tool in external_data_tools:
+                future = executor.submit(
+                    self._query_external_data_tool,
+                    current_app._get_current_object(),
+                    tenant_id,
+                    app_id,
+                    tool,
+                    inputs,
+                    query
+                )
+
+                futures[future] = tool
+
+            for future in concurrent.futures.as_completed(futures):
+                tool_variable, result = future.result()
+                results[tool_variable] = result
+
+        inputs.update(results)
+        return inputs
+
+    def _query_external_data_tool(self, flask_app: Flask,
+                                  tenant_id: str,
+                                  app_id: str,
+                                  external_data_tool: ExternalDataVariableEntity,
+                                  inputs: dict,
+                                  query: str) -> Tuple[Optional[str], Optional[str]]:
+        """
+        Query external data tool.
+        :param flask_app: flask app
+        :param tenant_id: tenant id
+        :param app_id: app id
+        :param external_data_tool: external data tool
+        :param inputs: inputs
+        :param query: query
+        :return:
+        """
+        with flask_app.app_context():
+            tool_variable = external_data_tool.variable
+            tool_type = external_data_tool.type
+            tool_config = external_data_tool.config
+
+            external_data_tool_factory = ExternalDataToolFactory(
+                name=tool_type,
+                tenant_id=tenant_id,
+                app_id=app_id,
+                variable=tool_variable,
+                config=tool_config
+            )
+
+            # query external data tool
+            result = external_data_tool_factory.query(
+                inputs=inputs,
+                query=query
+            )
+
+            return tool_variable, result

+ 32 - 0
api/core/features/hosting_moderation.py

@@ -0,0 +1,32 @@
+import logging
+
+from core.entities.application_entities import ApplicationGenerateEntity
+from core.helper import moderation
+from core.model_runtime.entities.message_entities import PromptMessage
+
+logger = logging.getLogger(__name__)
+
+
+class HostingModerationFeature:
+    def check(self, application_generate_entity: ApplicationGenerateEntity,
+              prompt_messages: list[PromptMessage]) -> bool:
+        """
+        Check hosting moderation
+        :param application_generate_entity: application generate entity
+        :param prompt_messages: prompt messages
+        :return:
+        """
+        app_orchestration_config = application_generate_entity.app_orchestration_config_entity
+        model_config = app_orchestration_config.model_config
+
+        text = ""
+        for prompt_message in prompt_messages:
+            if isinstance(prompt_message.content, str):
+                text += prompt_message.content + "\n"
+
+        moderation_result = moderation.check_moderation(
+            model_config,
+            text
+        )
+
+        return moderation_result

+ 50 - 0
api/core/features/moderation.py

@@ -0,0 +1,50 @@
+import logging
+from typing import Tuple
+
+from core.entities.application_entities import AppOrchestrationConfigEntity
+from core.moderation.base import ModerationAction, ModerationException
+from core.moderation.factory import ModerationFactory
+
+logger = logging.getLogger(__name__)
+
+
+class ModerationFeature:
+    def check(self, app_id: str,
+              tenant_id: str,
+              app_orchestration_config_entity: AppOrchestrationConfigEntity,
+              inputs: dict,
+              query: str) -> Tuple[bool, dict, str]:
+        """
+        Process sensitive_word_avoidance.
+        :param app_id: app id
+        :param tenant_id: tenant id
+        :param app_orchestration_config_entity: app orchestration config entity
+        :param inputs: inputs
+        :param query: query
+        :return:
+        """
+        if not app_orchestration_config_entity.sensitive_word_avoidance:
+            return False, inputs, query
+
+        sensitive_word_avoidance_config = app_orchestration_config_entity.sensitive_word_avoidance
+        moderation_type = sensitive_word_avoidance_config.type
+
+        moderation_factory = ModerationFactory(
+            name=moderation_type,
+            app_id=app_id,
+            tenant_id=tenant_id,
+            config=sensitive_word_avoidance_config.config
+        )
+
+        moderation_result = moderation_factory.moderation_for_inputs(inputs, query)
+
+        if not moderation_result.flagged:
+            return False, inputs, query
+
+        if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
+            raise ModerationException(moderation_result.preset_response)
+        elif moderation_result.action == ModerationAction.OVERRIDED:
+            inputs = moderation_result.inputs
+            query = moderation_result.query
+
+        return True, inputs, query

+ 5 - 5
api/core/file/file_obj.py

@@ -4,7 +4,7 @@ from typing import Optional
 from pydantic import BaseModel
 
 from core.file.upload_file_parser import UploadFileParser
-from core.model_providers.models.entity.message import PromptMessageFile, ImagePromptMessageFile
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent
 from extensions.ext_database import db
 from models.model import UploadFile
 
@@ -50,14 +50,14 @@ class FileObj(BaseModel):
         return self._get_data(force_url=True)
 
     @property
-    def prompt_message_file(self) -> PromptMessageFile:
+    def prompt_message_content(self) -> ImagePromptMessageContent:
         if self.type == FileType.IMAGE:
             image_config = self.file_config.get('image')
 
-            return ImagePromptMessageFile(
+            return ImagePromptMessageContent(
                 data=self.data,
-                detail=ImagePromptMessageFile.DETAIL.HIGH
-                if image_config.get("detail") == "high" else ImagePromptMessageFile.DETAIL.LOW
+                detail=ImagePromptMessageContent.DETAIL.HIGH
+                if image_config.get("detail") == "high" else ImagePromptMessageContent.DETAIL.LOW
             )
 
     def _get_data(self, force_url: bool = False) -> Optional[str]:

+ 63 - 40
api/core/generator/llm_generator.py

@@ -3,10 +3,10 @@ import logging
 
 from langchain.schema import OutputParserException
 
-from core.model_providers.error import LLMError, ProviderTokenNotInitError
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import PromptMessage, MessageType
-from core.model_providers.models.entity.model_params import ModelKwargs
+from core.model_manager import ModelManager
+from core.model_runtime.entities.message_entities import UserPromptMessage, SystemPromptMessage
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
 
 from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
@@ -26,17 +26,22 @@ class LLMGenerator:
 
         prompt += query + "\n"
 
-        model_instance = ModelFactory.get_text_generation_model(
+        model_manager = ModelManager()
+        model_instance = model_manager.get_default_model_instance(
             tenant_id=tenant_id,
-            model_kwargs=ModelKwargs(
-                temperature=1,
-                max_tokens=100
-            )
+            model_type=ModelType.LLM,
         )
 
-        prompts = [PromptMessage(content=prompt)]
-        response = model_instance.run(prompts)
-        answer = response.content
+        prompts = [UserPromptMessage(content=prompt)]
+        response = model_instance.invoke_llm(
+            prompt_messages=prompts,
+            model_parameters={
+                "max_tokens": 100,
+                "temperature": 1
+            },
+            stream=False
+        )
+        answer = response.message.content
 
         result_dict = json.loads(answer)
         answer = result_dict['Your Output']
@@ -62,22 +67,28 @@ class LLMGenerator:
         })
 
         try:
-            model_instance = ModelFactory.get_text_generation_model(
+            model_manager = ModelManager()
+            model_instance = model_manager.get_default_model_instance(
                 tenant_id=tenant_id,
-                model_kwargs=ModelKwargs(
-                    max_tokens=256,
-                    temperature=0
-                )
+                model_type=ModelType.LLM,
             )
-        except ProviderTokenNotInitError:
+        except InvokeAuthorizationError:
             return []
 
-        prompt_messages = [PromptMessage(content=prompt)]
+        prompt_messages = [UserPromptMessage(content=prompt)]
 
         try:
-            output = model_instance.run(prompt_messages)
-            questions = output_parser.parse(output.content)
-        except LLMError:
+            response = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters={
+                    "max_tokens": 256,
+                    "temperature": 0
+                },
+                stream=False
+            )
+
+            questions = output_parser.parse(response.message.content)
+        except InvokeError:
             questions = []
         except Exception as e:
             logging.exception(e)
@@ -105,20 +116,26 @@ class LLMGenerator:
             remove_template_variables=False
         )
 
-        model_instance = ModelFactory.get_text_generation_model(
+        model_manager = ModelManager()
+        model_instance = model_manager.get_default_model_instance(
             tenant_id=tenant_id,
-            model_kwargs=ModelKwargs(
-                max_tokens=512,
-                temperature=0
-            )
+            model_type=ModelType.LLM,
         )
 
-        prompt_messages = [PromptMessage(content=prompt)]
+        prompt_messages = [UserPromptMessage(content=prompt)]
 
         try:
-            output = model_instance.run(prompt_messages)
-            rule_config = output_parser.parse(output.content)
-        except LLMError as e:
+            response = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters={
+                    "max_tokens": 512,
+                    "temperature": 0
+                },
+                stream=False
+            )
+
+            rule_config = output_parser.parse(response.message.content)
+        except InvokeError as e:
             raise e
         except OutputParserException:
             raise ValueError('Please give a valid input for intended audience or hoping to solve problems.')
@@ -136,18 +153,24 @@ class LLMGenerator:
     def generate_qa_document(cls, tenant_id: str, query, document_language: str):
         prompt = GENERATOR_QA_PROMPT.format(language=document_language)
 
-        model_instance = ModelFactory.get_text_generation_model(
+        model_manager = ModelManager()
+        model_instance = model_manager.get_default_model_instance(
             tenant_id=tenant_id,
-            model_kwargs=ModelKwargs(
-                max_tokens=2000
-            )
+            model_type=ModelType.LLM,
         )
 
-        prompts = [
-            PromptMessage(content=prompt, type=MessageType.SYSTEM),
-            PromptMessage(content=query)
+        prompt_messages = [
+            SystemPromptMessage(content=prompt),
+            UserPromptMessage(content=query)
         ]
 
-        response = model_instance.run(prompts)
-        answer = response.content
+        response = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            model_parameters={
+                "max_tokens": 2000
+            },
+            stream=False
+        )
+
+        answer = response.message.content
         return answer.strip()

+ 14 - 0
api/core/helper/encrypter.py

@@ -18,3 +18,17 @@ def encrypt_token(tenant_id: str, token: str):
 
 def decrypt_token(tenant_id: str, token: str):
     return rsa.decrypt(base64.b64decode(token), tenant_id)
+
+
+def batch_decrypt_token(tenant_id: str, tokens: list[str]):
+    rsa_key, cipher_rsa = rsa.get_decrypt_decoding(tenant_id)
+
+    return [rsa.decrypt_token_with_decoding(base64.b64decode(token), rsa_key, cipher_rsa) for token in tokens]
+
+
+def get_decrypt_decoding(tenant_id: str):
+    return rsa.get_decrypt_decoding(tenant_id)
+
+
+def decrypt_token_with_decoding(token: str, rsa_key, cipher_rsa):
+    return rsa.decrypt_token_with_decoding(base64.b64decode(token), rsa_key, cipher_rsa)

+ 22 - 0
api/core/helper/lru_cache.py

@@ -0,0 +1,22 @@
+from collections import OrderedDict
+from typing import Any
+
+
+class LRUCache:
+    def __init__(self, capacity: int):
+        self.cache = OrderedDict()
+        self.capacity = capacity
+
+    def get(self, key: Any) -> Any:
+        if key not in self.cache:
+            return None
+        else:
+            self.cache.move_to_end(key)  # move the key to the end of the OrderedDict
+            return self.cache[key]
+
+    def put(self, key: Any, value: Any) -> None:
+        if key in self.cache:
+            self.cache.move_to_end(key)
+        self.cache[key] = value
+        if len(self.cache) > self.capacity:
+            self.cache.popitem(last=False)  # pop the first item

+ 30 - 18
api/core/helper/moderation.py

@@ -1,18 +1,27 @@
 import logging
 import random
 
-import openai
-
-from core.model_providers.error import LLMBadRequestError
-from core.model_providers.providers.base import BaseModelProvider
-from core.model_providers.providers.hosted import hosted_config, hosted_model_providers
+from core.entities.application_entities import ModelConfigEntity
+from core.model_runtime.errors.invoke import InvokeBadRequestError
+from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
+from extensions.ext_hosting_provider import hosting_configuration
 from models.provider import ProviderType
 
+logger = logging.getLogger(__name__)
+
+
+def check_moderation(model_config: ModelConfigEntity, text: str) -> bool:
+    moderation_config = hosting_configuration.moderation_config
+    if (moderation_config and moderation_config.enabled is True
+            and 'openai' in hosting_configuration.provider_map
+            and hosting_configuration.provider_map['openai'].enabled is True
+    ):
+        using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type
+        provider_name = model_config.provider
+        if using_provider_type == ProviderType.SYSTEM \
+                and provider_name in moderation_config.providers:
+            hosting_openai_config = hosting_configuration.provider_map['openai']
 
-def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
-    if hosted_config.moderation.enabled is True and hosted_model_providers.openai:
-        if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
-                and model_provider.provider_name in hosted_config.moderation.providers:
             # 2000 text per chunk
             length = 2000
             text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
@@ -23,14 +32,17 @@ def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
             text_chunk = random.choice(text_chunks)
 
             try:
-                moderation_result = openai.Moderation.create(input=text_chunk,
-                                                             api_key=hosted_model_providers.openai.api_key)
+                model_type_instance = OpenAIModerationModel()
+                moderation_result = model_type_instance.invoke(
+                    model='text-moderation-stable',
+                    credentials=hosting_openai_config.credentials,
+                    text=text_chunk
+                )
+
+                if moderation_result is True:
+                    return True
             except Exception as ex:
-                logging.exception(ex)
-                raise LLMBadRequestError('Rate limit exceeded, please try again later.')
-
-            for result in moderation_result.results:
-                if result['flagged'] is True:
-                    return False
+                logger.exception(ex)
+                raise InvokeBadRequestError('Rate limit exceeded, please try again later.')
 
-    return True
+    return False

+ 213 - 0
api/core/hosting_configuration.py

@@ -0,0 +1,213 @@
+import os
+from typing import Optional
+
+from flask import Flask
+from pydantic import BaseModel
+
+from core.entities.provider_entities import QuotaUnit
+from models.provider import ProviderQuotaType
+
+
+class HostingQuota(BaseModel):
+    quota_type: ProviderQuotaType
+    restrict_llms: list[str] = []
+
+
+class TrialHostingQuota(HostingQuota):
+    quota_type: ProviderQuotaType = ProviderQuotaType.TRIAL
+    quota_limit: int = 0
+    """Quota limit for the hosting provider models. -1 means unlimited."""
+
+
+class PaidHostingQuota(HostingQuota):
+    quota_type: ProviderQuotaType = ProviderQuotaType.PAID
+    stripe_price_id: str = None
+    increase_quota: int = 1
+    min_quantity: int = 20
+    max_quantity: int = 100
+
+
+class FreeHostingQuota(HostingQuota):
+    quota_type: ProviderQuotaType = ProviderQuotaType.FREE
+
+
+class HostingProvider(BaseModel):
+    enabled: bool = False
+    credentials: Optional[dict] = None
+    quota_unit: Optional[QuotaUnit] = None
+    quotas: list[HostingQuota] = []
+
+
+class HostedModerationConfig(BaseModel):
+    enabled: bool = False
+    providers: list[str] = []
+
+
+class HostingConfiguration:
+    provider_map: dict[str, HostingProvider] = {}
+    moderation_config: HostedModerationConfig = None
+
+    def init_app(self, app: Flask):
+        if app.config.get('EDITION') != 'CLOUD':
+            return
+
+        self.provider_map["openai"] = self.init_openai()
+        self.provider_map["anthropic"] = self.init_anthropic()
+        self.provider_map["minimax"] = self.init_minimax()
+        self.provider_map["spark"] = self.init_spark()
+        self.provider_map["zhipuai"] = self.init_zhipuai()
+
+        self.moderation_config = self.init_moderation_config()
+
+    def init_openai(self) -> HostingProvider:
+        quota_unit = QuotaUnit.TIMES
+        if os.environ.get("HOSTED_OPENAI_ENABLED") and os.environ.get("HOSTED_OPENAI_ENABLED").lower() == 'true':
+            credentials = {
+                "openai_api_key": os.environ.get("HOSTED_OPENAI_API_KEY"),
+            }
+
+            if os.environ.get("HOSTED_OPENAI_API_BASE"):
+                credentials["openai_api_base"] = os.environ.get("HOSTED_OPENAI_API_BASE")
+
+            if os.environ.get("HOSTED_OPENAI_API_ORGANIZATION"):
+                credentials["openai_organization"] = os.environ.get("HOSTED_OPENAI_API_ORGANIZATION")
+
+            quotas = []
+            hosted_quota_limit = int(os.environ.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
+            if hosted_quota_limit != -1 or hosted_quota_limit > 0:
+                trial_quota = TrialHostingQuota(
+                    quota_limit=hosted_quota_limit,
+                    restrict_llms=[
+                        "gpt-3.5-turbo",
+                        "gpt-3.5-turbo-1106",
+                        "gpt-3.5-turbo-instruct",
+                        "gpt-3.5-turbo-16k",
+                        "text-davinci-003"
+                    ]
+                )
+                quotas.append(trial_quota)
+
+            if os.environ.get("HOSTED_OPENAI_PAID_ENABLED") and os.environ.get(
+                    "HOSTED_OPENAI_PAID_ENABLED").lower() == 'true':
+                paid_quota = PaidHostingQuota(
+                    stripe_price_id=os.environ.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"),
+                    increase_quota=int(os.environ.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA", "1")),
+                    min_quantity=int(os.environ.get("HOSTED_OPENAI_PAID_MIN_QUANTITY", "1")),
+                    max_quantity=int(os.environ.get("HOSTED_OPENAI_PAID_MAX_QUANTITY", "1"))
+                )
+                quotas.append(paid_quota)
+
+            return HostingProvider(
+                enabled=True,
+                credentials=credentials,
+                quota_unit=quota_unit,
+                quotas=quotas
+            )
+
+        return HostingProvider(
+            enabled=False,
+            quota_unit=quota_unit,
+        )
+
+    def init_anthropic(self) -> HostingProvider:
+        quota_unit = QuotaUnit.TOKENS
+        if os.environ.get("HOSTED_ANTHROPIC_ENABLED") and os.environ.get("HOSTED_ANTHROPIC_ENABLED").lower() == 'true':
+            credentials = {
+                "anthropic_api_key": os.environ.get("HOSTED_ANTHROPIC_API_KEY"),
+            }
+
+            if os.environ.get("HOSTED_ANTHROPIC_API_BASE"):
+                credentials["anthropic_api_url"] = os.environ.get("HOSTED_ANTHROPIC_API_BASE")
+
+            quotas = []
+            hosted_quota_limit = int(os.environ.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0"))
+            if hosted_quota_limit != -1 or hosted_quota_limit > 0:
+                trial_quota = TrialHostingQuota(
+                    quota_limit=hosted_quota_limit
+                )
+                quotas.append(trial_quota)
+
+            if os.environ.get("HOSTED_ANTHROPIC_PAID_ENABLED") and os.environ.get(
+                    "HOSTED_ANTHROPIC_PAID_ENABLED").lower() == 'true':
+                paid_quota = PaidHostingQuota(
+                    stripe_price_id=os.environ.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
+                    increase_quota=int(os.environ.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA", "1000000")),
+                    min_quantity=int(os.environ.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY", "20")),
+                    max_quantity=int(os.environ.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY", "100"))
+                )
+                quotas.append(paid_quota)
+
+            return HostingProvider(
+                enabled=True,
+                credentials=credentials,
+                quota_unit=quota_unit,
+                quotas=quotas
+            )
+
+        return HostingProvider(
+            enabled=False,
+            quota_unit=quota_unit,
+        )
+
+    def init_minimax(self) -> HostingProvider:
+        quota_unit = QuotaUnit.TOKENS
+        if os.environ.get("HOSTED_MINIMAX_ENABLED") and os.environ.get("HOSTED_MINIMAX_ENABLED").lower() == 'true':
+            quotas = [FreeHostingQuota()]
+
+            return HostingProvider(
+                enabled=True,
+                credentials=None,  # use credentials from the provider
+                quota_unit=quota_unit,
+                quotas=quotas
+            )
+
+        return HostingProvider(
+            enabled=False,
+            quota_unit=quota_unit,
+        )
+
+    def init_spark(self) -> HostingProvider:
+        quota_unit = QuotaUnit.TOKENS
+        if os.environ.get("HOSTED_SPARK_ENABLED") and os.environ.get("HOSTED_SPARK_ENABLED").lower() == 'true':
+            quotas = [FreeHostingQuota()]
+
+            return HostingProvider(
+                enabled=True,
+                credentials=None,  # use credentials from the provider
+                quota_unit=quota_unit,
+                quotas=quotas
+            )
+
+        return HostingProvider(
+            enabled=False,
+            quota_unit=quota_unit,
+        )
+
+    def init_zhipuai(self) -> HostingProvider:
+        quota_unit = QuotaUnit.TOKENS
+        if os.environ.get("HOSTED_ZHIPUAI_ENABLED") and os.environ.get("HOSTED_ZHIPUAI_ENABLED").lower() == 'true':
+            quotas = [FreeHostingQuota()]
+
+            return HostingProvider(
+                enabled=True,
+                credentials=None,  # use credentials from the provider
+                quota_unit=quota_unit,
+                quotas=quotas
+            )
+
+        return HostingProvider(
+            enabled=False,
+            quota_unit=quota_unit,
+        )
+
+    def init_moderation_config(self) -> HostedModerationConfig:
+        if os.environ.get("HOSTED_MODERATION_ENABLED") and os.environ.get("HOSTED_MODERATION_ENABLED").lower() == 'true' \
+                and os.environ.get("HOSTED_MODERATION_PROVIDERS"):
+            return HostedModerationConfig(
+                enabled=True,
+                providers=os.environ.get("HOSTED_MODERATION_PROVIDERS").split(',')
+            )
+
+        return HostedModerationConfig(
+            enabled=False
+        )

+ 7 - 11
api/core/index/index.py

@@ -1,18 +1,12 @@
-import json
-
 from flask import current_app
 from langchain.embeddings import OpenAIEmbeddings
 
 from core.embedding.cached_embedding import CacheEmbedding
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
 from core.index.vector_index.vector_index import VectorIndex
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
-from core.model_providers.models.entity.model_params import ModelKwargs
-from core.model_providers.models.llm.openai_model import OpenAIModel
-from core.model_providers.providers.openai_provider import OpenAIProvider
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
 from models.dataset import Dataset
-from models.provider import Provider, ProviderType
 
 
 class IndexBuilder:
@@ -22,10 +16,12 @@ class IndexBuilder:
             if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
                 return None
 
-            embedding_model = ModelFactory.get_embedding_model(
+            model_manager = ModelManager()
+            embedding_model = model_manager.get_model_instance(
                 tenant_id=dataset.tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
+                model_type=ModelType.TEXT_EMBEDDING,
+                provider=dataset.embedding_model_provider,
+                model=dataset.embedding_model
             )
 
             embeddings = CacheEmbedding(embedding_model)

+ 111 - 41
api/core/indexing_runner.py

@@ -18,9 +18,11 @@ from core.data_loader.loader.notion import NotionLoader
 from core.docstore.dataset_docstore import DatasetDocumentStore
 from core.generator.llm_generator import LLMGenerator
 from core.index.index import IndexBuilder
-from core.model_providers.error import ProviderTokenNotInitError
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import MessageType
+from core.model_manager import ModelManager
+from core.errors.error import ProviderTokenNotInitError
+from core.model_runtime.entities.model_entities import ModelType, PriceType
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
@@ -36,6 +38,7 @@ class IndexingRunner:
 
     def __init__(self):
         self.storage = storage
+        self.model_manager = ModelManager()
 
     def run(self, dataset_documents: List[DatasetDocument]):
         """Run the indexing process."""
@@ -210,7 +213,7 @@ class IndexingRunner:
         """
         Estimate the indexing for the document.
         """
-        embedding_model = None
+        embedding_model_instance = None
         if dataset_id:
             dataset = Dataset.query.filter_by(
                 id=dataset_id
@@ -218,15 +221,17 @@ class IndexingRunner:
             if not dataset:
                 raise ValueError('Dataset not found.')
             if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
-                embedding_model = ModelFactory.get_embedding_model(
-                    tenant_id=dataset.tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                embedding_model_instance = self.model_manager.get_model_instance(
+                    tenant_id=tenant_id,
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                 )
         else:
             if indexing_technique == 'high_quality':
-                embedding_model = ModelFactory.get_embedding_model(
-                    tenant_id=tenant_id
+                embedding_model_instance = self.model_manager.get_default_model_instance(
+                    tenant_id=tenant_id,
+                    model_type=ModelType.TEXT_EMBEDDING,
                 )
         tokens = 0
         preview_texts = []
@@ -255,32 +260,56 @@ class IndexingRunner:
             for document in documents:
                 if len(preview_texts) < 5:
                     preview_texts.append(document.page_content)
-                if indexing_technique == 'high_quality' or embedding_model:
-                    tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
+                if indexing_technique == 'high_quality' or embedding_model_instance:
+                    embedding_model_type_instance = embedding_model_instance.model_type_instance
+                    embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
+                    tokens += embedding_model_type_instance.get_num_tokens(
+                        model=embedding_model_instance.model,
+                        credentials=embedding_model_instance.credentials,
+                        texts=[self.filter_string(document.page_content)]
+                    )
 
         if doc_form and doc_form == 'qa_model':
-            text_generation_model = ModelFactory.get_text_generation_model(
-                tenant_id=tenant_id
+            model_instance = self.model_manager.get_default_model_instance(
+                tenant_id=tenant_id,
+                model_type=ModelType.LLM
             )
+
+            model_type_instance = model_instance.model_type_instance
+            model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
             if len(preview_texts) > 0:
                 # qa model document
                 response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
                                                              doc_language)
                 document_qa_list = self.format_split_text(response)
+                price_info = model_type_instance.get_price(
+                    model=model_instance.model,
+                    credentials=model_instance.credentials,
+                    price_type=PriceType.INPUT,
+                    tokens=total_segments * 2000,
+                )
                 return {
                     "total_segments": total_segments * 20,
                     "tokens": total_segments * 2000,
-                    "total_price": '{:f}'.format(
-                        text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
-                    "currency": embedding_model.get_currency(),
+                    "total_price": '{:f}'.format(price_info.total_amount),
+                    "currency": price_info.currency,
                     "qa_preview": document_qa_list,
                     "preview": preview_texts
                 }
+        if embedding_model_instance:
+            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
+            embedding_price_info = embedding_model_type_instance.get_price(
+                model=embedding_model_instance.model,
+                credentials=embedding_model_instance.credentials,
+                price_type=PriceType.INPUT,
+                tokens=tokens
+            )
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
-            "currency": embedding_model.get_currency() if embedding_model else 'USD',
+            "total_price": '{:f}'.format(embedding_price_info.total_amount) if embedding_model_instance else 0,
+            "currency": embedding_price_info.currency if embedding_model_instance else 'USD',
             "preview": preview_texts
         }
 
@@ -290,7 +319,7 @@ class IndexingRunner:
         """
         Estimate the indexing for the document.
         """
-        embedding_model = None
+        embedding_model_instance = None
         if dataset_id:
             dataset = Dataset.query.filter_by(
                 id=dataset_id
@@ -298,15 +327,17 @@ class IndexingRunner:
             if not dataset:
                 raise ValueError('Dataset not found.')
             if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
-                embedding_model = ModelFactory.get_embedding_model(
-                    tenant_id=dataset.tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                embedding_model_instance = self.model_manager.get_model_instance(
+                    tenant_id=tenant_id,
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                 )
         else:
             if indexing_technique == 'high_quality':
-                embedding_model = ModelFactory.get_embedding_model(
-                    tenant_id=tenant_id
+                embedding_model_instance = self.model_manager.get_default_model_instance(
+                    tenant_id=tenant_id,
+                    model_type=ModelType.TEXT_EMBEDDING
                 )
         # load data from notion
         tokens = 0
@@ -349,35 +380,63 @@ class IndexingRunner:
                     processing_rule=processing_rule
                 )
                 total_segments += len(documents)
+
+                embedding_model_type_instance = embedding_model_instance.model_type_instance
+                embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
+
                 for document in documents:
                     if len(preview_texts) < 5:
                         preview_texts.append(document.page_content)
-                    if indexing_technique == 'high_quality' or embedding_model:
-                        tokens += embedding_model.get_num_tokens(document.page_content)
+                    if indexing_technique == 'high_quality' or embedding_model_instance:
+                        tokens += embedding_model_type_instance.get_num_tokens(
+                            model=embedding_model_instance.model,
+                            credentials=embedding_model_instance.credentials,
+                            texts=[document.page_content]
+                        )
 
         if doc_form and doc_form == 'qa_model':
-            text_generation_model = ModelFactory.get_text_generation_model(
-                tenant_id=tenant_id
+            model_instance = self.model_manager.get_default_model_instance(
+                tenant_id=tenant_id,
+                model_type=ModelType.LLM
             )
+
+            model_type_instance = model_instance.model_type_instance
+            model_type_instance = cast(LargeLanguageModel, model_type_instance)
             if len(preview_texts) > 0:
                 # qa model document
                 response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
                                                              doc_language)
                 document_qa_list = self.format_split_text(response)
+
+                price_info = model_type_instance.get_price(
+                    model=model_instance.model,
+                    credentials=model_instance.credentials,
+                    price_type=PriceType.INPUT,
+                    tokens=total_segments * 2000,
+                )
+
                 return {
                     "total_segments": total_segments * 20,
                     "tokens": total_segments * 2000,
-                    "total_price": '{:f}'.format(
-                        text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
-                    "currency": embedding_model.get_currency(),
+                    "total_price": '{:f}'.format(price_info.total_amount),
+                    "currency": price_info.currency,
                     "qa_preview": document_qa_list,
                     "preview": preview_texts
                 }
+
+        embedding_model_type_instance = embedding_model_instance.model_type_instance
+        embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
+        embedding_price_info = embedding_model_type_instance.get_price(
+            model=embedding_model_instance.model,
+            credentials=embedding_model_instance.credentials,
+            price_type=PriceType.INPUT,
+            tokens=tokens
+        )
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
-            "currency": embedding_model.get_currency() if embedding_model else 'USD',
+            "total_price": '{:f}'.format(embedding_price_info.total_amount) if embedding_model_instance else 0,
+            "currency": embedding_price_info.currency if embedding_model_instance else 'USD',
             "preview": preview_texts
         }
 
@@ -656,25 +715,36 @@ class IndexingRunner:
         """
         vector_index = IndexBuilder.get_index(dataset, 'high_quality')
         keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
-        embedding_model = None
+        embedding_model_instance = None
         if dataset.indexing_technique == 'high_quality':
-            embedding_model = ModelFactory.get_embedding_model(
+            embedding_model_instance = self.model_manager.get_model_instance(
                 tenant_id=dataset.tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
+                provider=dataset.embedding_model_provider,
+                model_type=ModelType.TEXT_EMBEDDING,
+                model=dataset.embedding_model
             )
 
         # chunk nodes by chunk size
         indexing_start_at = time.perf_counter()
         tokens = 0
         chunk_size = 100
+
+        embedding_model_type_instance = None
+        if embedding_model_instance:
+            embedding_model_type_instance = embedding_model_instance.model_type_instance
+            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
+
         for i in range(0, len(documents), chunk_size):
             # check document is paused
             self._check_document_paused_status(dataset_document.id)
             chunk_documents = documents[i:i + chunk_size]
-            if dataset.indexing_technique == 'high_quality' or embedding_model:
+            if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
                 tokens += sum(
-                    embedding_model.get_num_tokens(document.page_content)
+                    embedding_model_type_instance.get_num_tokens(
+                        embedding_model_instance.model,
+                        embedding_model_instance.credentials,
+                        [document.page_content]
+                    )
                     for document in chunk_documents
                 )
 

+ 0 - 95
api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py

@@ -1,95 +0,0 @@
-from typing import Any, List, Dict
-
-from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import get_buffer_string, BaseMessage
-
-from core.file.message_file_parser import MessageFileParser
-from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages
-from core.model_providers.models.llm.base import BaseLLM
-from extensions.ext_database import db
-from models.model import Conversation, Message
-
-
-class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
-    conversation: Conversation
-    human_prefix: str = "Human"
-    ai_prefix: str = "Assistant"
-    model_instance: BaseLLM
-    memory_key: str = "chat_history"
-    max_token_limit: int = 2000
-    message_limit: int = 10
-
-    @property
-    def buffer(self) -> List[BaseMessage]:
-        """String buffer of memory."""
-        app_model = self.conversation.app
-
-        # fetch limited messages desc, and return reversed
-        messages = db.session.query(Message).filter(
-            Message.conversation_id == self.conversation.id,
-            Message.answer != ''
-        ).order_by(Message.created_at.desc()).limit(self.message_limit).all()
-
-        messages = list(reversed(messages))
-        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=self.conversation.app_id)
-
-        chat_messages: List[PromptMessage] = []
-        for message in messages:
-            files = message.message_files
-            if files:
-                file_objs = message_file_parser.transform_message_files(
-                    files, message.app_model_config
-                )
-
-                prompt_message_files = [file_obj.prompt_message_file for file_obj in file_objs]
-                chat_messages.append(PromptMessage(
-                    content=message.query,
-                    type=MessageType.USER,
-                    files=prompt_message_files
-                ))
-            else:
-                chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
-
-            chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
-
-        if not chat_messages:
-            return []
-
-        # prune the chat message if it exceeds the max token limit
-        curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
-        if curr_buffer_length > self.max_token_limit:
-            pruned_memory = []
-            while curr_buffer_length > self.max_token_limit and chat_messages:
-                pruned_memory.append(chat_messages.pop(0))
-                curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
-
-        return to_lc_messages(chat_messages)
-
-    @property
-    def memory_variables(self) -> List[str]:
-        """Will always return list of memory variables.
-
-        :meta private:
-        """
-        return [self.memory_key]
-
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        """Return history buffer."""
-        buffer: Any = self.buffer
-        if self.return_messages:
-            final_buffer: Any = buffer
-        else:
-            final_buffer = get_buffer_string(
-                buffer,
-                human_prefix=self.human_prefix,
-                ai_prefix=self.ai_prefix,
-            )
-        return {self.memory_key: final_buffer}
-
-    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
-        """Nothing should be saved or changed"""
-        pass
-
-    def clear(self) -> None:
-        """Nothing to clear, got a memory like a vault."""
-        pass

+ 0 - 36
api/core/memory/read_only_conversation_token_db_string_buffer_shared_memory.py

@@ -1,36 +0,0 @@
-from typing import Any, List, Dict
-
-from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import get_buffer_string
-
-from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
-    ReadOnlyConversationTokenDBBufferSharedMemory
-
-
-class ReadOnlyConversationTokenDBStringBufferSharedMemory(BaseChatMemory):
-    memory: ReadOnlyConversationTokenDBBufferSharedMemory
-
-    @property
-    def memory_variables(self) -> List[str]:
-        """Return memory variables."""
-        return self.memory.memory_variables
-
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
-        """Load memory variables from memory."""
-        buffer: Any = self.memory.buffer
-
-        final_buffer = get_buffer_string(
-            buffer,
-            human_prefix=self.memory.human_prefix,
-            ai_prefix=self.memory.ai_prefix,
-        )
-
-        return {self.memory.memory_key: final_buffer}
-
-    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
-        """Nothing should be saved or changed"""
-        pass
-
-    def clear(self) -> None:
-        """Nothing to clear, got a memory like a vault."""
-        pass

+ 109 - 0
api/core/memory/token_buffer_memory.py

@@ -0,0 +1,109 @@
+from core.file.message_file_parser import MessageFileParser
+from core.model_manager import ModelInstance
+from core.model_runtime.entities.message_entities import PromptMessage, TextPromptMessageContent, UserPromptMessage, \
+    AssistantPromptMessage, PromptMessageRole
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.model_providers import model_provider_factory
+from extensions.ext_database import db
+from models.model import Conversation, Message
+
+
+class TokenBufferMemory:
+    def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> None:
+        self.conversation = conversation
+        self.model_instance = model_instance
+
+    def get_history_prompt_messages(self, max_token_limit: int = 2000,
+                                    message_limit: int = 10) -> list[PromptMessage]:
+        """
+        Get history prompt messages.
+        :param max_token_limit: max token limit
+        :param message_limit: message limit
+        """
+        app_record = self.conversation.app
+
+        # fetch limited messages, and return reversed
+        messages = db.session.query(Message).filter(
+            Message.conversation_id == self.conversation.id,
+            Message.answer != ''
+        ).order_by(Message.created_at.desc()).limit(message_limit).all()
+
+        messages = list(reversed(messages))
+        message_file_parser = MessageFileParser(
+            tenant_id=app_record.tenant_id,
+            app_id=app_record.id
+        )
+
+        prompt_messages = []
+        for message in messages:
+            files = message.message_files
+            if files:
+                file_objs = message_file_parser.transform_message_files(
+                    files, message.app_model_config
+                )
+
+                prompt_message_contents = [TextPromptMessageContent(data=message.query)]
+                for file_obj in file_objs:
+                    prompt_message_contents.append(file_obj.prompt_message_content)
+
+                prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+            else:
+                prompt_messages.append(UserPromptMessage(content=message.query))
+
+            prompt_messages.append(AssistantPromptMessage(content=message.answer))
+
+        if not prompt_messages:
+            return []
+
+        # prune the chat message if it exceeds the max token limit
+        provider_instance = model_provider_factory.get_provider_instance(self.model_instance.provider)
+        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
+
+        curr_message_tokens = model_type_instance.get_num_tokens(
+            self.model_instance.model,
+            self.model_instance.credentials,
+            prompt_messages
+        )
+
+        if curr_message_tokens > max_token_limit:
+            pruned_memory = []
+            while curr_message_tokens > max_token_limit and prompt_messages:
+                pruned_memory.append(prompt_messages.pop(0))
+                curr_message_tokens = model_type_instance.get_num_tokens(
+                    self.model_instance.model,
+                    self.model_instance.credentials,
+                    prompt_messages
+                )
+
+        return prompt_messages
+
+    def get_history_prompt_text(self, human_prefix: str = "Human",
+                                ai_prefix: str = "Assistant",
+                                max_token_limit: int = 2000,
+                                message_limit: int = 10) -> str:
+        """
+        Get history prompt text.
+        :param human_prefix: human prefix
+        :param ai_prefix: ai prefix
+        :param max_token_limit: max token limit
+        :param message_limit: message limit
+        :return:
+        """
+        prompt_messages = self.get_history_prompt_messages(
+            max_token_limit=max_token_limit,
+            message_limit=message_limit
+        )
+
+        string_messages = []
+        for m in prompt_messages:
+            if m.role == PromptMessageRole.USER:
+                role = human_prefix
+            elif m.role == PromptMessageRole.ASSISTANT:
+                role = ai_prefix
+            else:
+                continue
+
+            message = f"{role}: {m.content}"
+            string_messages.append(message)
+
+        return "\n".join(string_messages)

+ 209 - 0
api/core/model_manager.py

@@ -0,0 +1,209 @@
+from typing import Optional, Union, Generator, cast, List, IO
+
+from core.entities.provider_configuration import ProviderModelBundle
+from core.errors.error import ProviderTokenNotInitError
+from core.model_runtime.callbacks.base_callback import Callback
+from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.entities.rerank_entities import RerankResult
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
+from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.provider_manager import ProviderManager
+
+
+class ModelInstance:
+    """
+    Model instance class
+    """
+
+    def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None:
+        self._provider_model_bundle = provider_model_bundle
+        self.model = model
+        self.provider = provider_model_bundle.configuration.provider.provider
+        self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
+        self.model_type_instance = self._provider_model_bundle.model_type_instance
+
+    def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict:
+        """
+        Fetch credentials from provider model bundle
+        :param provider_model_bundle: provider model bundle
+        :param model: model name
+        :return:
+        """
+        credentials = provider_model_bundle.configuration.get_current_credentials(
+            model_type=provider_model_bundle.model_type_instance.model_type,
+            model=model
+        )
+
+        if credentials is None:
+            raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")
+
+        return credentials
+
+    def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
+                   tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                   stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
+            -> Union[LLMResult, Generator]:
+        """
+        Invoke large language model
+
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :param callbacks: callbacks
+        :return: full response or stream response chunk generator result
+        """
+        if not isinstance(self.model_type_instance, LargeLanguageModel):
+            raise Exception(f"Model type instance is not LargeLanguageModel")
+
+        self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
+        return self.model_type_instance.invoke(
+            model=self.model,
+            credentials=self.credentials,
+            prompt_messages=prompt_messages,
+            model_parameters=model_parameters,
+            tools=tools,
+            stop=stop,
+            stream=stream,
+            user=user,
+            callbacks=callbacks
+        )
+
+    def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
+            -> TextEmbeddingResult:
+        """
+        Invoke large language model
+
+        :param texts: texts to embed
+        :param user: unique user id
+        :return: embeddings result
+        """
+        if not isinstance(self.model_type_instance, TextEmbeddingModel):
+            raise Exception(f"Model type instance is not TextEmbeddingModel")
+
+        self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
+        return self.model_type_instance.invoke(
+            model=self.model,
+            credentials=self.credentials,
+            texts=texts,
+            user=user
+        )
+
+    def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
+                      user: Optional[str] = None) \
+            -> RerankResult:
+        """
+        Invoke rerank model
+
+        :param query: search query
+        :param docs: docs for reranking
+        :param score_threshold: score threshold
+        :param top_n: top n
+        :param user: unique user id
+        :return: rerank result
+        """
+        if not isinstance(self.model_type_instance, RerankModel):
+            raise Exception(f"Model type instance is not RerankModel")
+
+        self.model_type_instance = cast(RerankModel, self.model_type_instance)
+        return self.model_type_instance.invoke(
+            model=self.model,
+            credentials=self.credentials,
+            query=query,
+            docs=docs,
+            score_threshold=score_threshold,
+            top_n=top_n,
+            user=user
+        )
+
+    def invoke_moderation(self, text: str, user: Optional[str] = None) \
+            -> bool:
+        """
+        Invoke moderation model
+
+        :param text: text to moderate
+        :param user: unique user id
+        :return: false if text is safe, true otherwise
+        """
+        if not isinstance(self.model_type_instance, ModerationModel):
+            raise Exception(f"Model type instance is not ModerationModel")
+
+        self.model_type_instance = cast(ModerationModel, self.model_type_instance)
+        return self.model_type_instance.invoke(
+            model=self.model,
+            credentials=self.credentials,
+            text=text,
+            user=user
+        )
+
+    def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
+            -> str:
+        """
+        Invoke large language model
+
+        :param file: audio file
+        :param user: unique user id
+        :return: text for given audio file
+        """
+        if not isinstance(self.model_type_instance, Speech2TextModel):
+            raise Exception(f"Model type instance is not Speech2TextModel")
+
+        self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
+        return self.model_type_instance.invoke(
+            model=self.model,
+            credentials=self.credentials,
+            file=file,
+            user=user
+        )
+
+
+class ModelManager:
+    def __init__(self) -> None:
+        self._provider_manager = ProviderManager()
+
+    def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
+        """
+        Get model instance
+        :param tenant_id: tenant id
+        :param provider: provider name
+        :param model_type: model type
+        :param model: model name
+        :return:
+        """
+        provider_model_bundle = self._provider_manager.get_provider_model_bundle(
+            tenant_id=tenant_id,
+            provider=provider,
+            model_type=model_type
+        )
+
+        return ModelInstance(provider_model_bundle, model)
+
+    def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
+        """
+        Get default model instance
+        :param tenant_id: tenant id
+        :param model_type: model type
+        :return:
+        """
+        default_model_entity = self._provider_manager.get_default_model(
+            tenant_id=tenant_id,
+            model_type=model_type
+        )
+
+        if not default_model_entity:
+            raise ProviderTokenNotInitError(f"Default model not found for {model_type}")
+
+        return self.get_model_instance(
+            tenant_id=tenant_id,
+            provider=default_model_entity.provider.provider,
+            model_type=model_type,
+            model=default_model_entity.model
+        )

+ 0 - 335
api/core/model_providers/model_factory.py

@@ -1,335 +0,0 @@
-from typing import Optional
-
-from langchain.callbacks.base import Callbacks
-
-from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
-from core.model_providers.model_provider_factory import ModelProviderFactory, DEFAULT_MODELS
-from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.embedding.base import BaseEmbedding
-from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
-from core.model_providers.models.llm.base import BaseLLM
-from core.model_providers.models.moderation.base import BaseModeration
-from core.model_providers.models.reranking.base import BaseReranking
-from core.model_providers.models.speech2text.base import BaseSpeech2Text
-from extensions.ext_database import db
-from models.provider import TenantDefaultModel
-
-
-class ModelFactory:
-
-    @classmethod
-    def get_text_generation_model_from_model_config(cls, tenant_id: str,
-                                                    model_config: dict,
-                                                    streaming: bool = False,
-                                                    callbacks: Callbacks = None) -> Optional[BaseLLM]:
-        provider_name = model_config.get("provider")
-        model_name = model_config.get("name")
-        completion_params = model_config.get("completion_params", {})
-
-        return cls.get_text_generation_model(
-            tenant_id=tenant_id,
-            model_provider_name=provider_name,
-            model_name=model_name,
-            model_kwargs=ModelKwargs(
-                temperature=completion_params.get('temperature', 0),
-                max_tokens=completion_params.get('max_tokens', 256),
-                top_p=completion_params.get('top_p', 0),
-                frequency_penalty=completion_params.get('frequency_penalty', 0.1),
-                presence_penalty=completion_params.get('presence_penalty', 0.1)
-            ),
-            streaming=streaming,
-            callbacks=callbacks
-        )
-
-    @classmethod
-    def get_text_generation_model(cls,
-                                  tenant_id: str,
-                                  model_provider_name: Optional[str] = None,
-                                  model_name: Optional[str] = None,
-                                  model_kwargs: Optional[ModelKwargs] = None,
-                                  streaming: bool = False,
-                                  callbacks: Callbacks = None,
-                                  deduct_quota: bool = True) -> Optional[BaseLLM]:
-        """
-        get text generation model.
-
-        :param tenant_id: a string representing the ID of the tenant.
-        :param model_provider_name:
-        :param model_name:
-        :param model_kwargs:
-        :param streaming:
-        :param callbacks:
-        :param deduct_quota:
-        :return:
-        """
-        is_default_model = False
-        if model_provider_name is None and model_name is None:
-            default_model = cls.get_default_model(tenant_id, ModelType.TEXT_GENERATION)
-
-            if not default_model:
-                raise LLMBadRequestError(f"Default model is not available. "
-                                         f"Please configure a Default System Reasoning Model "
-                                         f"in the Settings -> Model Provider.")
-
-            model_provider_name = default_model.provider_name
-            model_name = default_model.model_name
-            is_default_model = True
-
-        # get model provider
-        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
-
-        if not model_provider:
-            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
-
-        # init text generation model
-        model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION)
-
-        try:
-            model_instance = model_class(
-                model_provider=model_provider,
-                name=model_name,
-                model_kwargs=model_kwargs,
-                streaming=streaming,
-                callbacks=callbacks
-            )
-        except LLMBadRequestError as e:
-            if is_default_model:
-                raise LLMBadRequestError(f"Default model {model_name} is not available. "
-                                         f"Please check your model provider credentials.")
-            else:
-                raise e
-
-        if is_default_model or not deduct_quota:
-            model_instance.deduct_quota = False
-
-        return model_instance
-
-    @classmethod
-    def get_embedding_model(cls,
-                            tenant_id: str,
-                            model_provider_name: Optional[str] = None,
-                            model_name: Optional[str] = None) -> Optional[BaseEmbedding]:
-        """
-        get embedding model.
-
-        :param tenant_id: a string representing the ID of the tenant.
-        :param model_provider_name:
-        :param model_name:
-        :return:
-        """
-        if model_provider_name is None and model_name is None:
-            default_model = cls.get_default_model(tenant_id, ModelType.EMBEDDINGS)
-
-            if not default_model:
-                raise LLMBadRequestError(f"Default model is not available. "
-                                         f"Please configure a Default Embedding Model "
-                                         f"in the Settings -> Model Provider.")
-
-            model_provider_name = default_model.provider_name
-            model_name = default_model.model_name
-
-        # get model provider
-        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
-
-        if not model_provider:
-            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
-
-        # init embedding model
-        model_class = model_provider.get_model_class(model_type=ModelType.EMBEDDINGS)
-        return model_class(
-            model_provider=model_provider,
-            name=model_name
-        )
-
-
-    @classmethod
-    def get_reranking_model(cls,
-                            tenant_id: str,
-                            model_provider_name: Optional[str] = None,
-                            model_name: Optional[str] = None) -> Optional[BaseReranking]:
-        """
-        get reranking model.
-
-        :param tenant_id: a string representing the ID of the tenant.
-        :param model_provider_name:
-        :param model_name:
-        :return:
-        """
-        if (model_provider_name is None or len(model_provider_name) == 0) and (model_name is None or len(model_name) == 0):
-            default_model = cls.get_default_model(tenant_id, ModelType.RERANKING)
-
-            if not default_model:
-                raise LLMBadRequestError(f"Default model is not available. "
-                                         f"Please configure a Default Reranking Model "
-                                         f"in the Settings -> Model Provider.")
-
-            model_provider_name = default_model.provider_name
-            model_name = default_model.model_name
-
-        # get model provider
-        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
-
-        if not model_provider:
-            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
-
-        # init reranking model
-        model_class = model_provider.get_model_class(model_type=ModelType.RERANKING)
-        return model_class(
-            model_provider=model_provider,
-            name=model_name
-        )
-
-    @classmethod
-    def get_speech2text_model(cls,
-                              tenant_id: str,
-                              model_provider_name: Optional[str] = None,
-                              model_name: Optional[str] = None) -> Optional[BaseSpeech2Text]:
-        """
-        get speech to text model.
-
-        :param tenant_id: a string representing the ID of the tenant.
-        :param model_provider_name:
-        :param model_name:
-        :return:
-        """
-        if model_provider_name is None and model_name is None:
-            default_model = cls.get_default_model(tenant_id, ModelType.SPEECH_TO_TEXT)
-
-            if not default_model:
-                raise LLMBadRequestError(f"Default model is not available. "
-                                         f"Please configure a Default Speech-to-Text Model "
-                                         f"in the Settings -> Model Provider.")
-
-            model_provider_name = default_model.provider_name
-            model_name = default_model.model_name
-
-        # get model provider
-        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
-
-        if not model_provider:
-            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
-
-        # init speech to text model
-        model_class = model_provider.get_model_class(model_type=ModelType.SPEECH_TO_TEXT)
-        return model_class(
-            model_provider=model_provider,
-            name=model_name
-        )
-
-    @classmethod
-    def get_moderation_model(cls,
-                             tenant_id: str,
-                             model_provider_name: str,
-                             model_name: str) -> Optional[BaseModeration]:
-        """
-        get moderation model.
-
-        :param tenant_id: a string representing the ID of the tenant.
-        :param model_provider_name:
-        :param model_name:
-        :return:
-        """
-        # get model provider
-        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
-
-        if not model_provider:
-            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
-
-        # init moderation model
-        model_class = model_provider.get_model_class(model_type=ModelType.MODERATION)
-        return model_class(
-            model_provider=model_provider,
-            name=model_name
-        )
-
-    @classmethod
-    def get_default_model(cls, tenant_id: str, model_type: ModelType) -> TenantDefaultModel:
-        """
-        get default model of model type.
-
-        :param tenant_id:
-        :param model_type:
-        :return:
-        """
-        # get default model
-        default_model = db.session.query(TenantDefaultModel) \
-            .filter(
-            TenantDefaultModel.tenant_id == tenant_id,
-            TenantDefaultModel.model_type == model_type.value
-        ).first()
-
-        if not default_model:
-            model_provider_rules = ModelProviderFactory.get_provider_rules()
-            for model_provider_name, model_provider_rule in model_provider_rules.items():
-                model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
-                if not model_provider:
-                    continue
-
-                model_list = model_provider.get_supported_model_list(model_type)
-                if model_list:
-                    model_info = model_list[0]
-                    default_model = TenantDefaultModel(
-                        tenant_id=tenant_id,
-                        model_type=model_type.value,
-                        provider_name=model_provider_name,
-                        model_name=model_info['id']
-                    )
-                    db.session.add(default_model)
-                    db.session.commit()
-                    break
-
-        return default_model
-
-    @classmethod
-    def update_default_model(cls,
-                             tenant_id: str,
-                             model_type: ModelType,
-                             provider_name: str,
-                             model_name: str) -> TenantDefaultModel:
-        """
-        update default model of model type.
-
-        :param tenant_id:
-        :param model_type:
-        :param provider_name:
-        :param model_name:
-        :return:
-        """
-        model_provider_name = ModelProviderFactory.get_provider_names()
-        if provider_name not in model_provider_name:
-            raise ValueError(f'Invalid provider name: {provider_name}')
-
-        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name)
-
-        if not model_provider:
-            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
-
-        model_list = model_provider.get_supported_model_list(model_type)
-        model_ids = [model['id'] for model in model_list]
-        if model_name not in model_ids:
-            raise ValueError(f'Invalid model name: {model_name}')
-
-        # get default model
-        default_model = db.session.query(TenantDefaultModel) \
-            .filter(
-            TenantDefaultModel.tenant_id == tenant_id,
-            TenantDefaultModel.model_type == model_type.value
-        ).first()
-
-        if default_model:
-            # update default model
-            default_model.provider_name = provider_name
-            default_model.model_name = model_name
-            db.session.commit()
-        else:
-            # create default model
-            default_model = TenantDefaultModel(
-                tenant_id=tenant_id,
-                model_type=model_type.value,
-                provider_name=provider_name,
-                model_name=model_name,
-            )
-            db.session.add(default_model)
-            db.session.commit()
-
-        return default_model

+ 0 - 276
api/core/model_providers/model_provider_factory.py

@@ -1,276 +0,0 @@
-from typing import Type
-
-from sqlalchemy.exc import IntegrityError
-
-from core.model_providers.models.entity.model_params import ModelType
-from core.model_providers.providers.base import BaseModelProvider
-from core.model_providers.rules import provider_rules
-from extensions.ext_database import db
-from models.provider import TenantPreferredModelProvider, ProviderType, Provider, ProviderQuotaType
-
-DEFAULT_MODELS = {
-    ModelType.TEXT_GENERATION.value: {
-        'provider_name': 'openai',
-        'model_name': 'gpt-3.5-turbo',
-    },
-    ModelType.EMBEDDINGS.value: {
-        'provider_name': 'openai',
-        'model_name': 'text-embedding-ada-002',
-    },
-    ModelType.SPEECH_TO_TEXT.value: {
-        'provider_name': 'openai',
-        'model_name': 'whisper-1',
-    }
-}
-
-
-class ModelProviderFactory:
-    @classmethod
-    def get_model_provider_class(cls, provider_name: str) -> Type[BaseModelProvider]:
-        if provider_name == 'openai':
-            from core.model_providers.providers.openai_provider import OpenAIProvider
-            return OpenAIProvider
-        elif provider_name == 'anthropic':
-            from core.model_providers.providers.anthropic_provider import AnthropicProvider
-            return AnthropicProvider
-        elif provider_name == 'minimax':
-            from core.model_providers.providers.minimax_provider import MinimaxProvider
-            return MinimaxProvider
-        elif provider_name == 'spark':
-            from core.model_providers.providers.spark_provider import SparkProvider
-            return SparkProvider
-        elif provider_name == 'tongyi':
-            from core.model_providers.providers.tongyi_provider import TongyiProvider
-            return TongyiProvider
-        elif provider_name == 'wenxin':
-            from core.model_providers.providers.wenxin_provider import WenxinProvider
-            return WenxinProvider
-        elif provider_name == 'zhipuai':
-            from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
-            return ZhipuAIProvider
-        elif provider_name == 'chatglm':
-            from core.model_providers.providers.chatglm_provider import ChatGLMProvider
-            return ChatGLMProvider
-        elif provider_name == 'baichuan':
-            from core.model_providers.providers.baichuan_provider import BaichuanProvider
-            return BaichuanProvider
-        elif provider_name == 'azure_openai':
-            from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
-            return AzureOpenAIProvider
-        elif provider_name == 'replicate':
-            from core.model_providers.providers.replicate_provider import ReplicateProvider
-            return ReplicateProvider
-        elif provider_name == 'huggingface_hub':
-            from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
-            return HuggingfaceHubProvider
-        elif provider_name == 'xinference':
-            from core.model_providers.providers.xinference_provider import XinferenceProvider
-            return XinferenceProvider
-        elif provider_name == 'openllm':
-            from core.model_providers.providers.openllm_provider import OpenLLMProvider
-            return OpenLLMProvider
-        elif provider_name == 'localai':
-            from core.model_providers.providers.localai_provider import LocalAIProvider
-            return LocalAIProvider
-        elif provider_name == 'cohere':
-            from core.model_providers.providers.cohere_provider import CohereProvider
-            return CohereProvider
-        elif provider_name == 'jina':
-            from core.model_providers.providers.jina_provider import JinaProvider
-            return JinaProvider
-        else:
-            raise NotImplementedError
-
-    @classmethod
-    def get_provider_names(cls):
-        """
-        Returns a list of provider names.
-        """
-        return list(provider_rules.keys())
-
-    @classmethod
-    def get_provider_rules(cls):
-        """
-        Returns a list of provider rules.
-
-        :return:
-        """
-        return provider_rules
-
-    @classmethod
-    def get_provider_rule(cls, provider_name: str):
-        """
-        Returns provider rule.
-        """
-        return provider_rules[provider_name]
-
-    @classmethod
-    def get_preferred_model_provider(cls, tenant_id: str, model_provider_name: str):
-        """
-        get preferred model provider.
-
-        :param tenant_id: a string representing the ID of the tenant.
-        :param model_provider_name:
-        :return:
-        """
-        # get preferred provider
-        preferred_provider = cls._get_preferred_provider(tenant_id, model_provider_name)
-        if not preferred_provider or not preferred_provider.is_valid:
-            return None
-
-        # init model provider
-        model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)
-        return model_provider_class(provider=preferred_provider)
-
-    @classmethod
-    def get_preferred_type_by_preferred_model_provider(cls,
-                                                       tenant_id: str,
-                                                       model_provider_name: str,
-                                                       preferred_model_provider: TenantPreferredModelProvider):
-        """
-        get preferred provider type by preferred model provider.
-
-        :param model_provider_name:
-        :param preferred_model_provider:
-        :return:
-        """
-        if not preferred_model_provider:
-            model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
-            support_provider_types = model_provider_rules['support_provider_types']
-
-            if ProviderType.CUSTOM.value in support_provider_types:
-                custom_provider = db.session.query(Provider) \
-                    .filter(
-                        Provider.tenant_id == tenant_id,
-                        Provider.provider_name == model_provider_name,
-                        Provider.provider_type == ProviderType.CUSTOM.value,
-                        Provider.is_valid == True
-                    ).first()
-
-                if custom_provider:
-                    return ProviderType.CUSTOM.value
-
-            model_provider = cls.get_model_provider_class(model_provider_name)
-
-            if ProviderType.SYSTEM.value in support_provider_types \
-                    and model_provider.is_provider_type_system_supported():
-                return ProviderType.SYSTEM.value
-            elif ProviderType.CUSTOM.value in support_provider_types:
-                return ProviderType.CUSTOM.value
-        else:
-            return preferred_model_provider.preferred_provider_type
-
-    @classmethod
-    def _get_preferred_provider(cls, tenant_id: str, model_provider_name: str):
-        """
-        get preferred provider of tenant.
-
-        :param tenant_id:
-        :param model_provider_name:
-        :return:
-        """
-        # get preferred provider type
-        preferred_provider_type = cls._get_preferred_provider_type(tenant_id, model_provider_name)
-
-        # get providers by preferred provider type
-        providers = db.session.query(Provider) \
-            .filter(
-                Provider.tenant_id == tenant_id,
-                Provider.provider_name == model_provider_name,
-                Provider.provider_type == preferred_provider_type
-            ).all()
-
-        no_system_provider = False
-        if preferred_provider_type == ProviderType.SYSTEM.value:
-            quota_type_to_provider_dict = {}
-            for provider in providers:
-                quota_type_to_provider_dict[provider.quota_type] = provider
-
-            model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
-            for quota_type_enum in ProviderQuotaType:
-                quota_type = quota_type_enum.value
-                if quota_type in model_provider_rules['system_config']['supported_quota_types']:
-                    if quota_type in quota_type_to_provider_dict.keys():
-                        provider = quota_type_to_provider_dict[quota_type]
-                        if provider.is_valid and provider.quota_limit > provider.quota_used:
-                            return provider
-                    elif quota_type == ProviderQuotaType.TRIAL.value:
-                        try:
-                            provider = Provider(
-                                tenant_id=tenant_id,
-                                provider_name=model_provider_name,
-                                provider_type=ProviderType.SYSTEM.value,
-                                is_valid=True,
-                                quota_type=ProviderQuotaType.TRIAL.value,
-                                quota_limit=model_provider_rules['system_config']['quota_limit'],
-                                quota_used=0
-                            )
-                            db.session.add(provider)
-                            db.session.commit()
-                        except IntegrityError:
-                            db.session.rollback()
-                            provider = db.session.query(Provider) \
-                                .filter(
-                                Provider.tenant_id == tenant_id,
-                                Provider.provider_name == model_provider_name,
-                                Provider.provider_type == ProviderType.SYSTEM.value,
-                                Provider.quota_type == ProviderQuotaType.TRIAL.value
-                            ).first()
-
-                        if provider.quota_limit == 0:
-                            return None
-
-                        return provider
-
-            no_system_provider = True
-
-        if no_system_provider:
-            providers = db.session.query(Provider) \
-                .filter(
-                Provider.tenant_id == tenant_id,
-                Provider.provider_name == model_provider_name,
-                Provider.provider_type == ProviderType.CUSTOM.value
-            ).all()
-
-        if preferred_provider_type == ProviderType.CUSTOM.value or no_system_provider:
-            if providers:
-                return providers[0]
-            else:
-                try:
-                    provider = Provider(
-                        tenant_id=tenant_id,
-                        provider_name=model_provider_name,
-                        provider_type=ProviderType.CUSTOM.value,
-                        is_valid=False
-                    )
-                    db.session.add(provider)
-                    db.session.commit()
-                except IntegrityError:
-                    db.session.rollback()
-                    provider = db.session.query(Provider) \
-                        .filter(
-                            Provider.tenant_id == tenant_id,
-                            Provider.provider_name == model_provider_name,
-                            Provider.provider_type == ProviderType.CUSTOM.value
-                        ).first()
-
-                return provider
-
-        return None
-
-    @classmethod
-    def _get_preferred_provider_type(cls, tenant_id: str, model_provider_name: str):
-        """
-        get preferred provider type of tenant.
-
-        :param tenant_id:
-        :param model_provider_name:
-        :return:
-        """
-        preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
-            .filter(
-            TenantPreferredModelProvider.tenant_id == tenant_id,
-            TenantPreferredModelProvider.provider_name == model_provider_name
-        ).first()
-
-        return cls.get_preferred_type_by_preferred_model_provider(tenant_id, model_provider_name, preferred_model_provider)

Vissa filer visades inte eftersom för många filer har ändrats