|
@@ -144,6 +144,10 @@ class TidbOnQdrantVector(BaseVector):
|
|
|
self._client.create_payload_index(
|
|
|
collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
|
|
|
)
|
|
|
+ # create document_id payload index
|
|
|
+ self._client.create_payload_index(
|
|
|
+ collection_name, Field.DOCUMENT_ID.value, field_schema=PayloadSchemaType.KEYWORD
|
|
|
+ )
|
|
|
# create full text index
|
|
|
text_index_params = TextIndexParams(
|
|
|
type=TextIndexType.TEXT,
|
|
@@ -318,23 +322,17 @@ class TidbOnQdrantVector(BaseVector):
|
|
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
|
|
from qdrant_client.http import models
|
|
|
|
|
|
- filter = models.Filter(
|
|
|
- must=[
|
|
|
- models.FieldCondition(
|
|
|
- key="group_id",
|
|
|
- match=models.MatchValue(value=self._group_id),
|
|
|
- ),
|
|
|
- ],
|
|
|
- )
|
|
|
+ filter = None
|
|
|
document_ids_filter = kwargs.get("document_ids_filter")
|
|
|
if document_ids_filter:
|
|
|
- if filter.must:
|
|
|
- filter.must.append(
|
|
|
+ filter = models.Filter(
|
|
|
+ must=[
|
|
|
models.FieldCondition(
|
|
|
key="metadata.document_id",
|
|
|
match=models.MatchAny(any=document_ids_filter),
|
|
|
)
|
|
|
- )
|
|
|
+ ],
|
|
|
+ )
|
|
|
results = self._client.search(
|
|
|
collection_name=self._collection_name,
|
|
|
query_vector=query_vector,
|
|
@@ -369,23 +367,17 @@ class TidbOnQdrantVector(BaseVector):
|
|
|
"""
|
|
|
from qdrant_client.http import models
|
|
|
|
|
|
- scroll_filter = models.Filter(
|
|
|
- must=[
|
|
|
- models.FieldCondition(
|
|
|
- key="page_content",
|
|
|
- match=models.MatchText(text=query),
|
|
|
- )
|
|
|
- ]
|
|
|
- )
|
|
|
+ scroll_filter = None
|
|
|
document_ids_filter = kwargs.get("document_ids_filter")
|
|
|
if document_ids_filter:
|
|
|
- if scroll_filter.must:
|
|
|
- scroll_filter.must.append(
|
|
|
+ scroll_filter = models.Filter(
|
|
|
+ must=[
|
|
|
models.FieldCondition(
|
|
|
key="metadata.document_id",
|
|
|
match=models.MatchAny(any=document_ids_filter),
|
|
|
)
|
|
|
- )
|
|
|
+ ]
|
|
|
+ )
|
|
|
response = self._client.scroll(
|
|
|
collection_name=self._collection_name,
|
|
|
scroll_filter=scroll_filter,
|