|
@@ -4,10 +4,6 @@ import time
|
|
|
import numpy as np
|
|
|
from sklearn.manifold import TSNE
|
|
|
|
|
|
-from core.embedding.cached_embedding import CacheEmbedding
|
|
|
-from core.model_manager import ModelManager
|
|
|
-from core.model_runtime.entities.model_entities import ModelType
|
|
|
-from core.rag.datasource.entity.embedding import Embeddings
|
|
|
from core.rag.datasource.retrieval_service import RetrievalService
|
|
|
from core.rag.models.document import Document
|
|
|
from core.rag.retrieval.retrival_methods import RetrievalMethod
|
|
@@ -45,17 +41,6 @@ class HitTestingService:
|
|
|
if not retrieval_model:
|
|
|
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
|
|
|
|
|
- # get embedding model
|
|
|
- model_manager = ModelManager()
|
|
|
- embedding_model = model_manager.get_model_instance(
|
|
|
- tenant_id=dataset.tenant_id,
|
|
|
- model_type=ModelType.TEXT_EMBEDDING,
|
|
|
- provider=dataset.embedding_model_provider,
|
|
|
- model=dataset.embedding_model
|
|
|
- )
|
|
|
-
|
|
|
- embeddings = CacheEmbedding(embedding_model)
|
|
|
-
|
|
|
all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
|
|
|
dataset_id=dataset.id,
|
|
|
query=query,
|
|
@@ -80,20 +65,10 @@ class HitTestingService:
|
|
|
db.session.add(dataset_query)
|
|
|
db.session.commit()
|
|
|
|
|
|
- return cls.compact_retrieve_response(dataset, embeddings, query, all_documents)
|
|
|
+ return cls.compact_retrieve_response(dataset, query, all_documents)
|
|
|
|
|
|
@classmethod
|
|
|
- def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: list[Document]):
|
|
|
- text_embeddings = [
|
|
|
- embeddings.embed_query(query)
|
|
|
- ]
|
|
|
-
|
|
|
- text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents]))
|
|
|
-
|
|
|
- tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings)
|
|
|
-
|
|
|
- query_position = tsne_position_data.pop(0)
|
|
|
-
|
|
|
+ def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]):
|
|
|
i = 0
|
|
|
records = []
|
|
|
for document in documents:
|
|
@@ -113,7 +88,6 @@ class HitTestingService:
|
|
|
record = {
|
|
|
"segment": segment,
|
|
|
"score": document.metadata.get('score', None),
|
|
|
- "tsne_position": tsne_position_data[i]
|
|
|
}
|
|
|
|
|
|
records.append(record)
|
|
@@ -123,7 +97,6 @@ class HitTestingService:
|
|
|
return {
|
|
|
"query": {
|
|
|
"content": query,
|
|
|
- "tsne_position": query_position,
|
|
|
},
|
|
|
"records": records
|
|
|
}
|