|
@@ -8,6 +8,7 @@ from pymilvus import MilvusClient, MilvusException, connections
|
|
|
from core.rag.datasource.vdb.field import Field
|
|
|
from core.rag.datasource.vdb.vector_base import BaseVector
|
|
|
from core.rag.models.document import Document
|
|
|
+from extensions.ext_redis import redis_client
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
@@ -61,17 +62,7 @@ class MilvusVector(BaseVector):
|
|
|
'params': {"M": 8, "efConstruction": 64}
|
|
|
}
|
|
|
metadatas = [d.metadata for d in texts]
|
|
|
-
|
|
|
- # Grab the existing collection if it exists
|
|
|
- from pymilvus import utility
|
|
|
- alias = uuid4().hex
|
|
|
- if self._client_config.secure:
|
|
|
- uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
|
|
- else:
|
|
|
- uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
|
|
- connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
|
|
|
- if not utility.has_collection(self._collection_name, using=alias):
|
|
|
- self.create_collection(embeddings, metadatas, index_params)
|
|
|
+ self.create_collection(embeddings, metadatas, index_params)
|
|
|
self.add_texts(texts, embeddings)
|
|
|
|
|
|
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
|
@@ -187,46 +178,60 @@ class MilvusVector(BaseVector):
|
|
|
|
|
|
def create_collection(
|
|
|
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
|
|
- ) -> str:
|
|
|
- from pymilvus import CollectionSchema, DataType, FieldSchema
|
|
|
- from pymilvus.orm.types import infer_dtype_bydata
|
|
|
-
|
|
|
- # Determine embedding dim
|
|
|
- dim = len(embeddings[0])
|
|
|
- fields = []
|
|
|
- if metadatas:
|
|
|
- fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
|
|
|
-
|
|
|
- # Create the text field
|
|
|
- fields.append(
|
|
|
- FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
|
|
|
- )
|
|
|
- # Create the primary key field
|
|
|
- fields.append(
|
|
|
- FieldSchema(
|
|
|
- Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
|
|
|
- )
|
|
|
- )
|
|
|
- # Create the vector field, supports binary or float vectors
|
|
|
- fields.append(
|
|
|
- FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
|
|
|
- )
|
|
|
-
|
|
|
- # Create the schema for the collection
|
|
|
- schema = CollectionSchema(fields)
|
|
|
-
|
|
|
- for x in schema.fields:
|
|
|
- self._fields.append(x.name)
|
|
|
- # Since primary field is auto-id, no need to track it
|
|
|
- self._fields.remove(Field.PRIMARY_KEY.value)
|
|
|
-
|
|
|
- # Create the collection
|
|
|
- collection_name = self._collection_name
|
|
|
- self._client.create_collection_with_schema(collection_name=collection_name,
|
|
|
- schema=schema, index_param=index_params,
|
|
|
- consistency_level=self._consistency_level)
|
|
|
- return collection_name
|
|
|
+ ):
|
|
|
+ lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
|
|
|
+ with redis_client.lock(lock_name, timeout=20):
|
|
|
+ collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
|
|
|
+ if redis_client.get(collection_exist_cache_key):
|
|
|
+ return
|
|
|
+ # Grab the existing collection if it exists
|
|
|
+ from pymilvus import utility
|
|
|
+ alias = uuid4().hex
|
|
|
+ if self._client_config.secure:
|
|
|
+ uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
|
|
+ else:
|
|
|
+ uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
|
|
+ connections.connect(alias=alias, uri=uri, user=self._client_config.user,
|
|
|
+ password=self._client_config.password)
|
|
|
+ if not utility.has_collection(self._collection_name, using=alias):
|
|
|
+ from pymilvus import CollectionSchema, DataType, FieldSchema
|
|
|
+ from pymilvus.orm.types import infer_dtype_bydata
|
|
|
+
|
|
|
+ # Determine embedding dim
|
|
|
+ dim = len(embeddings[0])
|
|
|
+ fields = []
|
|
|
+ if metadatas:
|
|
|
+ fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
|
|
|
+
|
|
|
+ # Create the text field
|
|
|
+ fields.append(
|
|
|
+ FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
|
|
|
+ )
|
|
|
+ # Create the primary key field
|
|
|
+ fields.append(
|
|
|
+ FieldSchema(
|
|
|
+ Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
|
|
|
+ )
|
|
|
+ )
|
|
|
+ # Create the vector field, supports binary or float vectors
|
|
|
+ fields.append(
|
|
|
+ FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
|
|
|
+ )
|
|
|
+
|
|
|
+ # Create the schema for the collection
|
|
|
+ schema = CollectionSchema(fields)
|
|
|
+
|
|
|
+ for x in schema.fields:
|
|
|
+ self._fields.append(x.name)
|
|
|
+ # Since primary field is auto-id, no need to track it
|
|
|
+ self._fields.remove(Field.PRIMARY_KEY.value)
|
|
|
|
|
|
+ # Create the collection
|
|
|
+ collection_name = self._collection_name
|
|
|
+ self._client.create_collection_with_schema(collection_name=collection_name,
|
|
|
+ schema=schema, index_param=index_params,
|
|
|
+ consistency_level=self._consistency_level)
|
|
|
+ redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
|
|
def _init_client(self, config) -> MilvusClient:
|
|
|
if config.secure:
|
|
|
uri = "https://" + str(config.host) + ":" + str(config.port)
|