Przeglądaj źródła

refactor: move the embedding to the rag module and abstract the rerank runner for extension (#9423)

zhuhao 6 miesięcy temu
rodzic
commit
b90ad587c2
61 zmienionych plików z 135 dodań i 78 usunięć
  1. 0 0
      api/core/entities/embedding_type.py
  2. 1 1
      api/core/model_manager.py
  3. 1 1
      api/core/model_runtime/model_providers/__base/text_embedding_model.py
  4. 1 1
      api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py
  5. 1 1
      api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py
  6. 1 1
      api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py
  7. 1 1
      api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py
  8. 1 1
      api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py
  9. 1 1
      api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py
  10. 1 1
      api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py
  11. 1 1
      api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py
  12. 1 1
      api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py
  13. 1 1
      api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py
  14. 1 1
      api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py
  15. 1 1
      api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py
  16. 1 1
      api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py
  17. 1 1
      api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py
  18. 1 1
      api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py
  19. 1 1
      api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py
  20. 1 1
      api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py
  21. 1 1
      api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py
  22. 1 1
      api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py
  23. 1 1
      api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py
  24. 1 1
      api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py
  25. 1 1
      api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py
  26. 1 1
      api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py
  27. 1 1
      api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py
  28. 1 1
      api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py
  29. 1 1
      api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py
  30. 1 1
      api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py
  31. 1 1
      api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py
  32. 1 1
      api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py
  33. 1 1
      api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py
  34. 1 1
      api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py
  35. 32 21
      api/core/rag/data_post_processor/data_post_processor.py
  36. 1 1
      api/core/rag/datasource/retrieval_service.py
  37. 1 1
      api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
  38. 1 1
      api/core/rag/datasource/vdb/baidu/baidu_vector.py
  39. 1 1
      api/core/rag/datasource/vdb/chroma/chroma_vector.py
  40. 1 1
      api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
  41. 1 1
      api/core/rag/datasource/vdb/milvus/milvus_vector.py
  42. 1 1
      api/core/rag/datasource/vdb/myscale/myscale_vector.py
  43. 1 1
      api/core/rag/datasource/vdb/opensearch/opensearch_vector.py
  44. 1 1
      api/core/rag/datasource/vdb/oracle/oraclevector.py
  45. 1 1
      api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
  46. 1 1
      api/core/rag/datasource/vdb/pgvector/pgvector.py
  47. 1 1
      api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
  48. 1 1
      api/core/rag/datasource/vdb/relyt/relyt_vector.py
  49. 1 1
      api/core/rag/datasource/vdb/tencent/tencent_vector.py
  50. 1 1
      api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
  51. 2 2
      api/core/rag/datasource/vdb/vector_factory.py
  52. 1 1
      api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
  53. 1 1
      api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
  54. 0 0
      api/core/rag/embedding/__init__.py
  55. 2 2
      api/core/rag/embedding/cached_embedding.py
  56. 2 0
      api/core/rag/embedding/embedding_base.py
  57. 26 0
      api/core/rag/rerank/rerank_base.py
  58. 16 0
      api/core/rag/rerank/rerank_factory.py
  59. 2 1
      api/core/rag/rerank/rerank_model.py
  60. 0 0
      api/core/rag/rerank/rerank_type.py
  61. 3 2
      api/core/rag/rerank/weight_rerank.py

+ 0 - 0
api/core/embedding/embedding_constant.py → api/core/entities/embedding_type.py


+ 1 - 1
api/core/model_manager.py

@@ -3,7 +3,7 @@ import os
 from collections.abc import Callable, Generator, Sequence
 from typing import IO, Optional, Union, cast
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
 from core.entities.provider_entities import ModelLoadBalancingConfiguration
 from core.errors.error import ProviderTokenNotInitError

+ 1 - 1
api/core/model_runtime/model_providers/__base/text_embedding_model.py

@@ -4,7 +4,7 @@ from typing import Optional
 
 from pydantic import ConfigDict
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
 from core.model_runtime.model_providers.__base.ai_model import AIModel

+ 1 - 1
api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py

@@ -7,7 +7,7 @@ import numpy as np
 import tiktoken
 from openai import AzureOpenAI
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import AIModelEntity, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError

+ 1 - 1
api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py

@@ -4,7 +4,7 @@ from typing import Optional
 
 from requests import post
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (

+ 1 - 1
api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py

@@ -13,7 +13,7 @@ from botocore.exceptions import (
     UnknownServiceError,
 )
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (

+ 1 - 1
api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py

@@ -5,7 +5,7 @@ import cohere
 import numpy as np
 from cohere.core import RequestOptions
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (

+ 1 - 1
api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py

@@ -5,7 +5,7 @@ from typing import Optional, Union
 import numpy as np
 from openai import OpenAI
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError

+ 1 - 1
api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py

@@ -6,7 +6,7 @@ import numpy as np
 import requests
 from huggingface_hub import HfApi, InferenceClient
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

+ 1 - 1
api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py

@@ -1,7 +1,7 @@
 import time
 from typing import Optional
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

+ 1 - 1
api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py

@@ -9,7 +9,7 @@ from tencentcloud.common.profile.client_profile import ClientProfile
 from tencentcloud.common.profile.http_profile import HttpProfile
 from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (

+ 1 - 1
api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py

@@ -4,7 +4,7 @@ from typing import Optional
 
 from requests import post
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

+ 1 - 1
api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py

@@ -5,7 +5,7 @@ from typing import Optional
 from requests import post
 from yarl import URL
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

+ 1 - 1
api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py

@@ -4,7 +4,7 @@ from typing import Optional
 
 from requests import post
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (

+ 1 - 1
api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py

@@ -4,7 +4,7 @@ from typing import Optional
 
 import requests
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

+ 1 - 1
api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py

@@ -5,7 +5,7 @@ from typing import Optional
 from nomic import embed
 from nomic import login as nomic_login
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import (
     EmbeddingUsage,

+ 1 - 1
api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py

@@ -4,7 +4,7 @@ from typing import Optional
 
 from requests import post
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (

+ 1 - 1
api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py

@@ -6,7 +6,7 @@ from typing import Optional
 import numpy as np
 import oci
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (

+ 1 - 1
api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py

@@ -8,7 +8,7 @@ from urllib.parse import urljoin
 import numpy as np
 import requests
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import (
     AIModelEntity,

+ 1 - 1
api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py

@@ -6,7 +6,7 @@ import numpy as np
 import tiktoken
 from openai import OpenAI
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError

+ 1 - 1
api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py

@@ -7,7 +7,7 @@ from urllib.parse import urljoin
 import numpy as np
 import requests
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import (
     AIModelEntity,

+ 1 - 1
api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py

@@ -5,7 +5,7 @@ from typing import Optional
 from requests import post
 from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (

+ 1 - 1
api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py

@@ -7,7 +7,7 @@ from urllib.parse import urljoin
 import numpy as np
 import requests
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import (
     AIModelEntity,

+ 1 - 1
api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py

@@ -4,7 +4,7 @@ from typing import Optional
 
 from replicate import Client as ReplicateClient
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

+ 1 - 1
api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py

@@ -6,7 +6,7 @@ from typing import Any, Optional
 
 import boto3
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

+ 1 - 1
api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py

@@ -1,6 +1,6 @@
 from typing import Optional
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
 from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
     OAICompatEmbeddingModel,

+ 1 - 1
api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py

@@ -4,7 +4,7 @@ from typing import Optional
 import dashscope
 import numpy as np
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import (
     EmbeddingUsage,

+ 1 - 1
api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py

@@ -7,7 +7,7 @@ import numpy as np
 from openai import OpenAI
 from tokenizers import Tokenizer
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError

+ 1 - 1
api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py

@@ -9,7 +9,7 @@ from google.cloud import aiplatform
 from google.oauth2 import service_account
 from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import (
     AIModelEntity,

+ 1 - 1
api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py

@@ -2,7 +2,7 @@ import time
 from decimal import Decimal
 from typing import Optional
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import (
     AIModelEntity,

+ 1 - 1
api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py

@@ -4,7 +4,7 @@ from typing import Optional
 
 import requests
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

+ 1 - 1
api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py

@@ -7,7 +7,7 @@ from typing import Any, Optional
 import numpy as np
 from requests import Response, post
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import InvokeError

+ 1 - 1
api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py

@@ -3,7 +3,7 @@ from typing import Optional
 
 from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

+ 1 - 1
api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py

@@ -3,7 +3,7 @@ from typing import Optional
 
 from zhipuai import ZhipuAI
 
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError

+ 32 - 21
api/core/rag/data_post_processor/data_post_processor.py

@@ -1,14 +1,14 @@
 from typing import Optional
 
-from core.model_manager import ModelManager
+from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.rag.data_post_processor.reorder import ReorderRunner
 from core.rag.models.document import Document
-from core.rag.rerank.constants.rerank_mode import RerankMode
 from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights
-from core.rag.rerank.rerank_model import RerankModelRunner
-from core.rag.rerank.weight_rerank import WeightRerankRunner
+from core.rag.rerank.rerank_base import BaseRerankRunner
+from core.rag.rerank.rerank_factory import RerankRunnerFactory
+from core.rag.rerank.rerank_type import RerankMode
 
 
 class DataPostProcessor:
@@ -47,11 +47,12 @@ class DataPostProcessor:
         tenant_id: str,
         reranking_model: Optional[dict] = None,
         weights: Optional[dict] = None,
-    ) -> Optional[RerankModelRunner | WeightRerankRunner]:
+    ) -> Optional[BaseRerankRunner]:
         if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights:
-            return WeightRerankRunner(
-                tenant_id,
-                Weights(
+            runner = RerankRunnerFactory.create_rerank_runner(
+                runner_type=reranking_mode,
+                tenant_id=tenant_id,
+                weights=Weights(
                     vector_setting=VectorSetting(
                         vector_weight=weights["vector_setting"]["vector_weight"],
                         embedding_provider_name=weights["vector_setting"]["embedding_provider_name"],
@@ -62,23 +63,33 @@ class DataPostProcessor:
                     ),
                 ),
             )
+            return runner
         elif reranking_mode == RerankMode.RERANKING_MODEL.value:
-            if reranking_model:
-                try:
-                    model_manager = ModelManager()
-                    rerank_model_instance = model_manager.get_model_instance(
-                        tenant_id=tenant_id,
-                        provider=reranking_model["reranking_provider_name"],
-                        model_type=ModelType.RERANK,
-                        model=reranking_model["reranking_model_name"],
-                    )
-                except InvokeAuthorizationError:
-                    return None
-                return RerankModelRunner(rerank_model_instance)
-            return None
+            rerank_model_instance = self._get_rerank_model_instance(tenant_id, reranking_model)
+            if rerank_model_instance is None:
+                return None
+            runner = RerankRunnerFactory.create_rerank_runner(
+                runner_type=reranking_mode, rerank_model_instance=rerank_model_instance
+            )
+            return runner
         return None
 
     def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]:
         if reorder_enabled:
             return ReorderRunner()
         return None
+
+    def _get_rerank_model_instance(self, tenant_id: str, reranking_model: Optional[dict]) -> ModelInstance | None:
+        if reranking_model:
+            try:
+                model_manager = ModelManager()
+                rerank_model_instance = model_manager.get_model_instance(
+                    tenant_id=tenant_id,
+                    provider=reranking_model["reranking_provider_name"],
+                    model_type=ModelType.RERANK,
+                    model=reranking_model["reranking_model_name"],
+                )
+                return rerank_model_instance
+            except InvokeAuthorizationError:
+                return None
+        return None

+ 1 - 1
api/core/rag/datasource/retrieval_service.py

@@ -6,7 +6,7 @@ from flask import Flask, current_app
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.vdb.vector_factory import Vector
-from core.rag.rerank.constants.rerank_mode import RerankMode
+from core.rag.rerank.rerank_type import RerankMode
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from extensions.ext_database import db
 from models.dataset import Dataset

+ 1 - 1
api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py

@@ -9,10 +9,10 @@ _import_err_msg = (
 )
 
 from configs import dify_config
-from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset

+ 1 - 1
api/core/rag/datasource/vdb/baidu/baidu_vector.py

@@ -12,10 +12,10 @@ from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
 from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
 
 from configs import dify_config
-from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset

+ 1 - 1
api/core/rag/datasource/vdb/chroma/chroma_vector.py

@@ -6,10 +6,10 @@ from chromadb import QueryResult, Settings
 from pydantic import BaseModel
 
 from configs import dify_config
-from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset

+ 1 - 1
api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py

@@ -9,11 +9,11 @@ from elasticsearch import Elasticsearch
 from flask import current_app
 from pydantic import BaseModel, model_validator
 
-from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset

+ 1 - 1
api/core/rag/datasource/vdb/milvus/milvus_vector.py

@@ -7,11 +7,11 @@ from pymilvus import MilvusClient, MilvusException
 from pymilvus.milvus_client import IndexParams
 
 from configs import dify_config
-from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset

+ 1 - 1
api/core/rag/datasource/vdb/myscale/myscale_vector.py

@@ -8,10 +8,10 @@ from clickhouse_connect import get_client
 from pydantic import BaseModel
 
 from configs import dify_config
-from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
 from models.dataset import Dataset
 

+ 1 - 1
api/core/rag/datasource/vdb/opensearch/opensearch_vector.py

@@ -9,11 +9,11 @@ from opensearchpy.helpers import BulkIndexError
 from pydantic import BaseModel, model_validator
 
 from configs import dify_config
-from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset

+ 1 - 1
api/core/rag/datasource/vdb/oracle/oraclevector.py

@@ -13,10 +13,10 @@ from nltk.corpus import stopwords
 from pydantic import BaseModel, model_validator
 
 from configs import dify_config
-from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset

+ 1 - 1
api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py

@@ -12,11 +12,11 @@ from sqlalchemy.dialects import postgresql
 from sqlalchemy.orm import Mapped, Session, mapped_column
 
 from configs import dify_config
-from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset

+ 1 - 1
api/core/rag/datasource/vdb/pgvector/pgvector.py

@@ -8,10 +8,10 @@ import psycopg2.pool
 from pydantic import BaseModel, model_validator
 
 from configs import dify_config
-from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset

+ 1 - 1
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py

@@ -20,11 +20,11 @@ from qdrant_client.http.models import (
 from qdrant_client.local.qdrant_local import QdrantLocal
 
 from configs import dify_config
-from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client

+ 1 - 1
api/core/rag/datasource/vdb/relyt/relyt_vector.py

@@ -8,9 +8,9 @@ from sqlalchemy import text as sql_text
 from sqlalchemy.dialects.postgresql import JSON, TEXT
 from sqlalchemy.orm import Session
 
-from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
 from models.dataset import Dataset
 
 try:

+ 1 - 1
api/core/rag/datasource/vdb/tencent/tencent_vector.py

@@ -8,10 +8,10 @@ from tcvectordb.model import index as vdb_index
 from tcvectordb.model.document import Filter
 
 from configs import dify_config
-from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset

+ 1 - 1
api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py

@@ -9,10 +9,10 @@ from sqlalchemy import text as sql_text
 from sqlalchemy.orm import Session, declarative_base
 
 from configs import dify_config
-from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset

+ 2 - 2
api/core/rag/datasource/vdb/vector_factory.py

@@ -2,12 +2,12 @@ from abc import ABC, abstractmethod
 from typing import Any, Optional
 
 from configs import dify_config
-from core.embedding.cached_embedding import CacheEmbedding
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
-from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.cached_embedding import CacheEmbedding
+from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset

+ 1 - 1
api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py

@@ -14,11 +14,11 @@ from volcengine.viking_db import (
 )
 
 from configs import dify_config
-from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field as vdb_Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset

+ 1 - 1
api/core/rag/datasource/vdb/weaviate/weaviate_vector.py

@@ -7,11 +7,11 @@ import weaviate
 from pydantic import BaseModel, model_validator
 
 from configs import dify_config
-from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset

+ 0 - 0
api/core/rag/embedding/__init__.py


+ 2 - 2
api/core/embedding/cached_embedding.py → api/core/rag/embedding/cached_embedding.py

@@ -6,11 +6,11 @@ import numpy as np
 from sqlalchemy.exc import IntegrityError
 
 from configs import dify_config
-from core.embedding.embedding_constant import EmbeddingInputType
+from core.entities.embedding_type import EmbeddingInputType
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.embedding.embedding_base import Embeddings
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs import helper

+ 2 - 0
api/core/rag/datasource/entity/embedding.py → api/core/rag/embedding/embedding_base.py

@@ -7,10 +7,12 @@ class Embeddings(ABC):
     @abstractmethod
     def embed_documents(self, texts: list[str]) -> list[list[float]]:
         """Embed search docs."""
+        raise NotImplementedError
 
     @abstractmethod
     def embed_query(self, text: str) -> list[float]:
         """Embed query text."""
+        raise NotImplementedError
 
     async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
         """Asynchronous Embed search docs."""

+ 26 - 0
api/core/rag/rerank/rerank_base.py

@@ -0,0 +1,26 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+from core.rag.models.document import Document
+
+
+class BaseRerankRunner(ABC):  # abstract interface; concrete rerank runners override run()
+    @abstractmethod
+    def run(
+        self,
+        query: str,
+        documents: list[Document],
+        score_threshold: Optional[float] = None,
+        top_n: Optional[int] = None,
+        user: Optional[str] = None,
+    ) -> list[Document]:
+        """
+        Run rerank model
+        :param query: search query
+        :param documents: documents for reranking
+        :param score_threshold: score threshold
+        :param top_n: top n
+        :param user: unique user id if needed
+        :return: list of documents produced by the concrete runner
+        """
+        raise NotImplementedError  # raised if a subclass calls this base implementation

+ 16 - 0
api/core/rag/rerank/rerank_factory.py

@@ -0,0 +1,16 @@
+from core.rag.rerank.rerank_base import BaseRerankRunner
+from core.rag.rerank.rerank_model import RerankModelRunner
+from core.rag.rerank.rerank_type import RerankMode
+from core.rag.rerank.weight_rerank import WeightRerankRunner
+
+
+class RerankRunnerFactory:  # maps a runner-type value to a BaseRerankRunner implementation
+    @staticmethod
+    def create_rerank_runner(runner_type: str, *args, **kwargs) -> BaseRerankRunner:
+        match runner_type:  # matched against RerankMode member values
+            case RerankMode.RERANKING_MODEL.value:
+                return RerankModelRunner(*args, **kwargs)  # args/kwargs forwarded unchanged
+            case RerankMode.WEIGHTED_SCORE.value:
+                return WeightRerankRunner(*args, **kwargs)  # args/kwargs forwarded unchanged
+            case _:
+                raise ValueError(f"Unknown runner type: {runner_type}")  # any other value is rejected

+ 2 - 1
api/core/rag/rerank/rerank_model.py

@@ -2,9 +2,10 @@ from typing import Optional
 
 from core.model_manager import ModelInstance
 from core.rag.models.document import Document
+from core.rag.rerank.rerank_base import BaseRerankRunner
 
 
-class RerankModelRunner:
+class RerankModelRunner(BaseRerankRunner):
     def __init__(self, rerank_model_instance: ModelInstance) -> None:
         self.rerank_model_instance = rerank_model_instance
 

+ 0 - 0
api/core/rag/rerank/constants/rerank_mode.py → api/core/rag/rerank/rerank_type.py


+ 3 - 2
api/core/rag/rerank/weight_rerank.py

@@ -4,15 +4,16 @@ from typing import Optional
 
 import numpy as np
 
-from core.embedding.cached_embedding import CacheEmbedding
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
+from core.rag.embedding.cached_embedding import CacheEmbedding
 from core.rag.models.document import Document
 from core.rag.rerank.entity.weight import VectorSetting, Weights
+from core.rag.rerank.rerank_base import BaseRerankRunner
 
 
-class WeightRerankRunner:
+class WeightRerankRunner(BaseRerankRunner):
     def __init__(self, tenant_id: str, weights: Weights) -> None:
         self.tenant_id = tenant_id
         self.weights = weights