Selaa lähdekoodia

Add vector field to Document for other vector databases (#7051)

Jyong 8 kuukautta sitten
vanhempi
commit
80c94f02e9

+ 1 - 1
api/core/app/apps/advanced_chat/generate_task_pipeline.py

@@ -244,7 +244,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         :return:
         """
         for message in self._queue_manager.listen():
-            if hasattr(message.event, 'metadata') and message.event.metadata.get('is_answer_previous_node', False) and publisher:
+            if message.event and hasattr(message.event, 'metadata') and message.event.metadata.get('is_answer_previous_node', False) and publisher:
                 publisher.publish(message=message)
             elif (hasattr(message.event, 'execution_metadata')
                   and message.event.execution_metadata

+ 3 - 1
api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py

@@ -285,9 +285,11 @@ class AnalyticdbVector(BaseVector):
         documents = []
         for match in response.body.matches.match:
             if match.score > score_threshold:
+                metadata = json.loads(match.metadata.get("metadata_"))
                 doc = Document(
                     page_content=match.metadata.get("page_content"),
-                    metadata=json.loads(match.metadata.get("metadata_")),
+                    vector=match.metadata.get("vector"),
+                    metadata=metadata,
                 )
                 documents.append(doc)
         return documents

+ 2 - 1
api/core/rag/datasource/vdb/myscale/myscale_vector.py

@@ -126,13 +126,14 @@ class MyScaleVector(BaseVector):
         where_str = f"WHERE dist < {1 - score_threshold}" if \
             self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else ""
         sql = f"""
-            SELECT text, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
+            SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
             {where_str} ORDER BY dist {order.value} LIMIT {top_k}
         """
         try:
             return [
                 Document(
                     page_content=r["text"],
+                    vector=r['vector'],
                     metadata=r["metadata"],
                 )
                 for r in self._client.query(sql).named_results()

+ 3 - 1
api/core/rag/datasource/vdb/opensearch/opensearch_vector.py

@@ -192,7 +192,9 @@ class OpenSearchVector(BaseVector):
         docs = []
         for hit in response['hits']['hits']:
             metadata = hit['_source'].get(Field.METADATA_KEY.value)
-            doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata)
+            vector = hit['_source'].get(Field.VECTOR.value)
+            page_content = hit['_source'].get(Field.CONTENT_KEY.value)
+            doc = Document(page_content=page_content, vector=vector, metadata=metadata)
             docs.append(doc)
 
         return docs

+ 4 - 4
api/core/rag/datasource/vdb/oracle/oraclevector.py

@@ -234,16 +234,16 @@ class OracleVector(BaseVector):
                         entities.append(token)
             with self._get_cursor() as cur:
                 cur.execute(
-                    f"select meta, text FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
+                    f"select meta, text, embedding FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
                     [" ACCUM ".join(entities)]
                 )
                 docs = []
                 for record in cur:
-                    metadata, text = record
-                    docs.append(Document(page_content=text, metadata=metadata))
+                    metadata, text, embedding = record
+                    docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
             return docs
         else:
-            return [Document(page_content="", metadata="")]
+            return [Document(page_content="", metadata={})]
         return []
 
     def delete(self) -> None:

+ 1 - 1
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py

@@ -399,7 +399,6 @@ class QdrantVector(BaseVector):
                 document = self._document_from_scored_point(
                     result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
                 )
-                document.metadata['vector'] = result.vector
                 documents.append(document)
 
         return documents
@@ -418,6 +417,7 @@ class QdrantVector(BaseVector):
     ) -> Document:
         return Document(
             page_content=scored_point.payload.get(content_payload_key),
+            vector=scored_point.vector,
             metadata=scored_point.payload.get(metadata_payload_key) or {},
         )
 

+ 3 - 3
api/core/rag/datasource/vdb/weaviate/weaviate_vector.py

@@ -239,8 +239,7 @@ class WeaviateVector(BaseVector):
         query_obj = self._client.query.get(collection_name, properties)
         if kwargs.get("where_filter"):
             query_obj = query_obj.with_where(kwargs.get("where_filter"))
-        if kwargs.get("additional"):
-            query_obj = query_obj.with_additional(kwargs.get("additional"))
+        query_obj = query_obj.with_additional(["vector"])
         properties = ['text']
         result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do()
         if "errors" in result:
@@ -248,7 +247,8 @@ class WeaviateVector(BaseVector):
         docs = []
         for res in result["data"]["Get"][collection_name]:
             text = res.pop(Field.TEXT_KEY.value)
-            docs.append(Document(page_content=text, metadata=res))
+            additional = res.pop('_additional')
+            docs.append(Document(page_content=text, vector=additional['vector'], metadata=res))
         return docs
 
     def _default_schema(self, index_name: str) -> dict:

+ 2 - 0
api/core/rag/models/document.py

@@ -10,6 +10,8 @@ class Document(BaseModel):
 
     page_content: str
 
+    vector: Optional[list[float]] = None
+
     """Arbitrary metadata about the page content (e.g., source, relationships to other
         documents, etc.).
     """

+ 2 - 1
api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py

@@ -78,7 +78,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
                                                       top_k=self.top_k,
                                                       score_threshold=retrieval_model.get('score_threshold', .0)
                                                       if retrieval_model['score_threshold_enabled'] else None,
-                                                      reranking_model=retrieval_model.get('reranking_model', None),
+                                                      reranking_model=retrieval_model.get('reranking_model', None)
+                                                      if retrieval_model['reranking_enable'] else None,
                                                       reranking_mode=retrieval_model.get('reranking_mode')
                                                       if retrieval_model.get('reranking_mode') else 'reranking_model',
                                                       weights=retrieval_model.get('weights', None),

+ 2 - 1
api/services/hit_testing_service.py

@@ -44,7 +44,8 @@ class HitTestingService:
                                                   top_k=retrieval_model.get('top_k', 2),
                                                   score_threshold=retrieval_model.get('score_threshold', .0)
                                                   if retrieval_model['score_threshold_enabled'] else None,
-                                                  reranking_model=retrieval_model.get('reranking_model', None),
+                                                  reranking_model=retrieval_model.get('reranking_model', None)
+                                                  if retrieval_model['reranking_enable'] else None,
                                                   reranking_mode=retrieval_model.get('reranking_mode')
                                                   if retrieval_model.get('reranking_mode') else 'reranking_model',
                                                   weights=retrieval_model.get('weights', None),