|
@@ -1,14 +1,14 @@
|
|
|
from typing import Optional
|
|
|
|
|
|
-from core.model_manager import ModelManager
|
|
|
+from core.model_manager import ModelInstance, ModelManager
|
|
|
from core.model_runtime.entities.model_entities import ModelType
|
|
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
|
|
from core.rag.data_post_processor.reorder import ReorderRunner
|
|
|
from core.rag.models.document import Document
|
|
|
-from core.rag.rerank.constants.rerank_mode import RerankMode
|
|
|
from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights
|
|
|
-from core.rag.rerank.rerank_model import RerankModelRunner
|
|
|
-from core.rag.rerank.weight_rerank import WeightRerankRunner
|
|
|
+from core.rag.rerank.rerank_base import BaseRerankRunner
|
|
|
+from core.rag.rerank.rerank_factory import RerankRunnerFactory
|
|
|
+from core.rag.rerank.rerank_type import RerankMode
|
|
|
|
|
|
|
|
|
class DataPostProcessor:
|
|
@@ -47,11 +47,12 @@ class DataPostProcessor:
|
|
|
tenant_id: str,
|
|
|
reranking_model: Optional[dict] = None,
|
|
|
weights: Optional[dict] = None,
|
|
|
- ) -> Optional[RerankModelRunner | WeightRerankRunner]:
|
|
|
+ ) -> Optional[BaseRerankRunner]:
|
|
|
if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights:
|
|
|
- return WeightRerankRunner(
|
|
|
- tenant_id,
|
|
|
- Weights(
|
|
|
+ runner = RerankRunnerFactory.create_rerank_runner(
|
|
|
+ runner_type=reranking_mode,
|
|
|
+ tenant_id=tenant_id,
|
|
|
+ weights=Weights(
|
|
|
vector_setting=VectorSetting(
|
|
|
vector_weight=weights["vector_setting"]["vector_weight"],
|
|
|
embedding_provider_name=weights["vector_setting"]["embedding_provider_name"],
|
|
@@ -62,23 +63,33 @@ class DataPostProcessor:
|
|
|
),
|
|
|
),
|
|
|
)
|
|
|
+ return runner
|
|
|
elif reranking_mode == RerankMode.RERANKING_MODEL.value:
|
|
|
- if reranking_model:
|
|
|
- try:
|
|
|
- model_manager = ModelManager()
|
|
|
- rerank_model_instance = model_manager.get_model_instance(
|
|
|
- tenant_id=tenant_id,
|
|
|
- provider=reranking_model["reranking_provider_name"],
|
|
|
- model_type=ModelType.RERANK,
|
|
|
- model=reranking_model["reranking_model_name"],
|
|
|
- )
|
|
|
- except InvokeAuthorizationError:
|
|
|
- return None
|
|
|
- return RerankModelRunner(rerank_model_instance)
|
|
|
- return None
|
|
|
+ rerank_model_instance = self._get_rerank_model_instance(tenant_id, reranking_model)
|
|
|
+ if rerank_model_instance is None:
|
|
|
+ return None
|
|
|
+ runner = RerankRunnerFactory.create_rerank_runner(
|
|
|
+ runner_type=reranking_mode, rerank_model_instance=rerank_model_instance
|
|
|
+ )
|
|
|
+ return runner
|
|
|
return None
|
|
|
|
|
|
def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]:
|
|
|
if reorder_enabled:
|
|
|
return ReorderRunner()
|
|
|
return None
|
|
|
+
|
|
|
+ def _get_rerank_model_instance(self, tenant_id: str, reranking_model: Optional[dict]) -> ModelInstance | None:
|
|
|
+ if reranking_model:
|
|
|
+ try:
|
|
|
+ model_manager = ModelManager()
|
|
|
+ rerank_model_instance = model_manager.get_model_instance(
|
|
|
+ tenant_id=tenant_id,
|
|
|
+ provider=reranking_model["reranking_provider_name"],
|
|
|
+ model_type=ModelType.RERANK,
|
|
|
+ model=reranking_model["reranking_model_name"],
|
|
|
+ )
|
|
|
+ return rerank_model_instance
|
|
|
+ except InvokeAuthorizationError:
|
|
|
+ return None
|
|
|
+ return None
|