|
@@ -278,6 +278,7 @@ class DatasetRetrieval:
|
|
|
query=query,
|
|
|
top_k=top_k, score_threshold=score_threshold,
|
|
|
reranking_model=reranking_model,
|
|
|
+ reranking_mode=retrieval_model_config.get('reranking_mode', 'reranking_model'),
|
|
|
weights=retrieval_model_config.get('weights', None),
|
|
|
)
|
|
|
self._on_query(query, [dataset_id], app_id, user_from, user_id)
|
|
@@ -431,10 +432,12 @@ class DatasetRetrieval:
|
|
|
dataset_id=dataset.id,
|
|
|
query=query,
|
|
|
top_k=top_k,
|
|
|
- score_threshold=retrieval_model['score_threshold']
|
|
|
+ score_threshold=retrieval_model.get('score_threshold', .0)
|
|
|
if retrieval_model['score_threshold_enabled'] else None,
|
|
|
- reranking_model=retrieval_model['reranking_model']
|
|
|
+ reranking_model=retrieval_model.get('reranking_model', None)
|
|
|
if retrieval_model['reranking_enable'] else None,
|
|
|
+ reranking_mode=retrieval_model.get('reranking_mode')
|
|
|
+ if retrieval_model.get('reranking_mode') else 'reranking_model',
|
|
|
weights=retrieval_model.get('weights', None),
|
|
|
)
|
|
|
|