Forráskód Böngészése

Add a Redis lock around collection creation in multi-threaded mode (#3054)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 1 éve
szülő
commit
84d118de07

+ 22 - 19
api/core/rag/datasource/keyword/jieba/jieba.py

@@ -8,6 +8,7 @@ from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaK
 from core.rag.datasource.keyword.keyword_base import BaseKeyword
 from core.rag.models.document import Document
 from extensions.ext_database import db
+from extensions.ext_redis import redis_client
 from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
 
 
@@ -121,26 +122,28 @@ class Jieba(BaseKeyword):
         db.session.commit()
 
     def _get_dataset_keyword_table(self) -> Optional[dict]:
-        dataset_keyword_table = self.dataset.dataset_keyword_table
-        if dataset_keyword_table:
-            if dataset_keyword_table.keyword_table_dict:
-                return dataset_keyword_table.keyword_table_dict['__data__']['table']
-        else:
-            dataset_keyword_table = DatasetKeywordTable(
-                dataset_id=self.dataset.id,
-                keyword_table=json.dumps({
-                    '__type__': 'keyword_table',
-                    '__data__': {
-                        "index_id": self.dataset.id,
-                        "summary": None,
-                        "table": {}
-                    }
-                }, cls=SetEncoder)
-            )
-            db.session.add(dataset_keyword_table)
-            db.session.commit()
+        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        with redis_client.lock(lock_name, timeout=20):
+            dataset_keyword_table = self.dataset.dataset_keyword_table
+            if dataset_keyword_table:
+                if dataset_keyword_table.keyword_table_dict:
+                    return dataset_keyword_table.keyword_table_dict['__data__']['table']
+            else:
+                dataset_keyword_table = DatasetKeywordTable(
+                    dataset_id=self.dataset.id,
+                    keyword_table=json.dumps({
+                        '__type__': 'keyword_table',
+                        '__data__': {
+                            "index_id": self.dataset.id,
+                            "summary": None,
+                            "table": {}
+                        }
+                    }, cls=SetEncoder)
+                )
+                db.session.add(dataset_keyword_table)
+                db.session.commit()
 
-        return {}
+            return {}
 
     def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
         for keyword in keywords:

+ 55 - 50
api/core/rag/datasource/vdb/milvus/milvus_vector.py

@@ -8,6 +8,7 @@ from pymilvus import MilvusClient, MilvusException, connections
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
 
 logger = logging.getLogger(__name__)
 
@@ -61,17 +62,7 @@ class MilvusVector(BaseVector):
             'params': {"M": 8, "efConstruction": 64}
         }
         metadatas = [d.metadata for d in texts]
-
-        # Grab the existing collection if it exists
-        from pymilvus import utility
-        alias = uuid4().hex
-        if self._client_config.secure:
-            uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
-        else:
-            uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
-        connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
-        if not utility.has_collection(self._collection_name, using=alias):
-            self.create_collection(embeddings, metadatas, index_params)
+        self.create_collection(embeddings, metadatas, index_params)
         self.add_texts(texts, embeddings)
 
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@@ -187,46 +178,60 @@ class MilvusVector(BaseVector):
 
     def create_collection(
             self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
-    ) -> str:
-        from pymilvus import CollectionSchema, DataType, FieldSchema
-        from pymilvus.orm.types import infer_dtype_bydata
-
-        # Determine embedding dim
-        dim = len(embeddings[0])
-        fields = []
-        if metadatas:
-            fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
-
-        # Create the text field
-        fields.append(
-            FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
-        )
-        # Create the primary key field
-        fields.append(
-            FieldSchema(
-                Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
-            )
-        )
-        # Create the vector field, supports binary or float vectors
-        fields.append(
-            FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
-        )
-
-        # Create the schema for the collection
-        schema = CollectionSchema(fields)
-
-        for x in schema.fields:
-            self._fields.append(x.name)
-        # Since primary field is auto-id, no need to track it
-        self._fields.remove(Field.PRIMARY_KEY.value)
-
-        # Create the collection
-        collection_name = self._collection_name
-        self._client.create_collection_with_schema(collection_name=collection_name,
-                                                   schema=schema, index_param=index_params,
-                                                   consistency_level=self._consistency_level)
-        return collection_name
+    ):
+        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
+            if redis_client.get(collection_exist_cache_key):
+                return
+            # Grab the existing collection if it exists
+            from pymilvus import utility
+            alias = uuid4().hex
+            if self._client_config.secure:
+                uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+            else:
+                uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+            connections.connect(alias=alias, uri=uri, user=self._client_config.user,
+                                password=self._client_config.password)
+            if not utility.has_collection(self._collection_name, using=alias):
+                from pymilvus import CollectionSchema, DataType, FieldSchema
+                from pymilvus.orm.types import infer_dtype_bydata
+
+                # Determine embedding dim
+                dim = len(embeddings[0])
+                fields = []
+                if metadatas:
+                    fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
+
+                # Create the text field
+                fields.append(
+                    FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
+                )
+                # Create the primary key field
+                fields.append(
+                    FieldSchema(
+                        Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
+                    )
+                )
+                # Create the vector field, supports binary or float vectors
+                fields.append(
+                    FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
+                )
+
+                # Create the schema for the collection
+                schema = CollectionSchema(fields)
+
+                for x in schema.fields:
+                    self._fields.append(x.name)
+                # Since primary field is auto-id, no need to track it
+                self._fields.remove(Field.PRIMARY_KEY.value)
 
+                # Create the collection
+                collection_name = self._collection_name
+                self._client.create_collection_with_schema(collection_name=collection_name,
+                                                           schema=schema, index_param=index_params,
+                                                           consistency_level=self._consistency_level)
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
     def _init_client(self, config) -> MilvusClient:
         if config.secure:
             uri = "https://" + str(config.host) + ":" + str(config.port)

+ 40 - 33
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py

@@ -20,6 +20,7 @@ from qdrant_client.local.qdrant_local import QdrantLocal
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
 
 if TYPE_CHECKING:
     from qdrant_client import grpc  # noqa
@@ -77,6 +78,17 @@ class QdrantVector(BaseVector):
             vector_size = len(embeddings[0])
             # get collection name
             collection_name = self._collection_name
+            # create collection
+            self.create_collection(collection_name, vector_size)
+
+            self.add_texts(texts, embeddings, **kwargs)
+
+    def create_collection(self, collection_name: str, vector_size: int):
+        lock_name = 'vector_indexing_lock_{}'.format(collection_name)
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
+            if redis_client.get(collection_exist_cache_key):
+                return
             collection_name = collection_name or uuid.uuid4().hex
             all_collection_name = []
             collections_response = self._client.get_collections()
@@ -84,40 +96,35 @@ class QdrantVector(BaseVector):
             for collection in collection_list:
                 all_collection_name.append(collection.name)
             if collection_name not in all_collection_name:
-                # create collection
-                self.create_collection(collection_name, vector_size)
-
-            self.add_texts(texts, embeddings, **kwargs)
-
-    def create_collection(self, collection_name: str, vector_size: int):
-        from qdrant_client.http import models as rest
-        vectors_config = rest.VectorParams(
-            size=vector_size,
-            distance=rest.Distance[self._distance_func],
-        )
-        hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
-                                     max_indexing_threads=0, on_disk=False)
-        self._client.recreate_collection(
-            collection_name=collection_name,
-            vectors_config=vectors_config,
-            hnsw_config=hnsw_config,
-            timeout=int(self._client_config.timeout),
-        )
+                from qdrant_client.http import models as rest
+                vectors_config = rest.VectorParams(
+                    size=vector_size,
+                    distance=rest.Distance[self._distance_func],
+                )
+                hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
+                                             max_indexing_threads=0, on_disk=False)
+                self._client.recreate_collection(
+                    collection_name=collection_name,
+                    vectors_config=vectors_config,
+                    hnsw_config=hnsw_config,
+                    timeout=int(self._client_config.timeout),
+                )
 
-        # create payload index
-        self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
-                                          field_schema=PayloadSchemaType.KEYWORD,
-                                          field_type=PayloadSchemaType.KEYWORD)
-        # creat full text index
-        text_index_params = TextIndexParams(
-            type=TextIndexType.TEXT,
-            tokenizer=TokenizerType.MULTILINGUAL,
-            min_token_len=2,
-            max_token_len=20,
-            lowercase=True
-        )
-        self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value,
-                                          field_schema=text_index_params)
+                # create payload index
+                self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
+                                                  field_schema=PayloadSchemaType.KEYWORD,
+                                                  field_type=PayloadSchemaType.KEYWORD)
+                # creat full text index
+                text_index_params = TextIndexParams(
+                    type=TextIndexType.TEXT,
+                    tokenizer=TokenizerType.MULTILINGUAL,
+                    min_token_len=2,
+                    max_token_len=20,
+                    lowercase=True
+                )
+                self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value,
+                                                  field_schema=text_index_params)
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
 
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         uuids = self._get_uuids(documents)

+ 15 - 7
api/core/rag/datasource/vdb/weaviate/weaviate_vector.py

@@ -8,6 +8,7 @@ from pydantic import BaseModel, root_validator
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
 from models.dataset import Dataset
 
 
@@ -79,16 +80,23 @@ class WeaviateVector(BaseVector):
         }
 
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
-
-        schema = self._default_schema(self._collection_name)
-
-        # check whether the index already exists
-        if not self._client.schema.contains(schema):
-            # create collection
-            self._client.schema.create_class(schema)
+        # create collection
+        self._create_collection()
         # create vector
         self.add_texts(texts, embeddings)
 
+    def _create_collection(self):
+        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
+            if redis_client.get(collection_exist_cache_key):
+                return
+            schema = self._default_schema(self._collection_name)
+            if not self._client.schema.contains(schema):
+                # create collection
+                self._client.schema.create_class(schema)
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         uuids = self._get_uuids(documents)
         texts = [d.page_content for d in documents]