Преглед на файлове

add segment with keyword issue (#3351)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Jyong преди 1 година
родител
ревизия
1f302990c6
променени са 3 файла, в които са добавени 89 реда и са изтрити 80 реда
  1. 3 0
      api/core/rag/datasource/keyword/jieba/jieba.py
  2. 83 79
      api/services/dataset_service.py
  3. 3 1
      web/app/components/base/tag-input/index.tsx

+ 3 - 0
api/core/rag/datasource/keyword/jieba/jieba.py

@@ -48,6 +48,9 @@ class Jieba(BaseKeyword):
                 text = texts[i]
                 if keywords_list:
                     keywords = keywords_list[i]
+                    if not keywords:
+                        keywords = keyword_table_handler.extract_keywords(text.page_content,
+                                                                          self._config.max_keywords_per_chunk)
                 else:
                     keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
                 self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))

+ 83 - 79
api/services/dataset_service.py

@@ -1046,73 +1046,11 @@ class SegmentService:
                 credentials=embedding_model.credentials,
                 texts=[content]
             )
-        max_position = db.session.query(func.max(DocumentSegment.position)).filter(
-            DocumentSegment.document_id == document.id
-        ).scalar()
-        segment_document = DocumentSegment(
-            tenant_id=current_user.current_tenant_id,
-            dataset_id=document.dataset_id,
-            document_id=document.id,
-            index_node_id=doc_id,
-            index_node_hash=segment_hash,
-            position=max_position + 1 if max_position else 1,
-            content=content,
-            word_count=len(content),
-            tokens=tokens,
-            status='completed',
-            indexing_at=datetime.datetime.utcnow(),
-            completed_at=datetime.datetime.utcnow(),
-            created_by=current_user.id
-        )
-        if document.doc_form == 'qa_model':
-            segment_document.answer = args['answer']
-
-        db.session.add(segment_document)
-        db.session.commit()
-
-        # save vector index
-        try:
-            VectorService.create_segments_vector([args['keywords']], [segment_document], dataset)
-        except Exception as e:
-            logging.exception("create segment index failed")
-            segment_document.enabled = False
-            segment_document.disabled_at = datetime.datetime.utcnow()
-            segment_document.status = 'error'
-            segment_document.error = str(e)
-            db.session.commit()
-        segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first()
-        return segment
-
-    @classmethod
-    def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
-        embedding_model = None
-        if dataset.indexing_technique == 'high_quality':
-            model_manager = ModelManager()
-            embedding_model = model_manager.get_model_instance(
-                tenant_id=current_user.current_tenant_id,
-                provider=dataset.embedding_model_provider,
-                model_type=ModelType.TEXT_EMBEDDING,
-                model=dataset.embedding_model
-            )
-        max_position = db.session.query(func.max(DocumentSegment.position)).filter(
-            DocumentSegment.document_id == document.id
-        ).scalar()
-        pre_segment_data_list = []
-        segment_data_list = []
-        keywords_list = []
-        for segment_item in segments:
-            content = segment_item['content']
-            doc_id = str(uuid.uuid4())
-            segment_hash = helper.generate_text_hash(content)
-            tokens = 0
-            if dataset.indexing_technique == 'high_quality' and embedding_model:
-                # calc embedding use tokens
-                model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
-                tokens = model_type_instance.get_num_tokens(
-                    model=embedding_model.model,
-                    credentials=embedding_model.credentials,
-                    texts=[content]
-                )
+        lock_name = 'add_segment_lock_document_id_{}'.format(document.id)
+        with redis_client.lock(lock_name, timeout=600):
+            max_position = db.session.query(func.max(DocumentSegment.position)).filter(
+                DocumentSegment.document_id == document.id
+            ).scalar()
             segment_document = DocumentSegment(
                 tenant_id=current_user.current_tenant_id,
                 dataset_id=document.dataset_id,
@@ -1129,25 +1067,91 @@ class SegmentService:
                 created_by=current_user.id
             )
             if document.doc_form == 'qa_model':
-                segment_document.answer = segment_item['answer']
-            db.session.add(segment_document)
-            segment_data_list.append(segment_document)
+                segment_document.answer = args['answer']
 
-            pre_segment_data_list.append(segment_document)
-            keywords_list.append(segment_item['keywords'])
+            db.session.add(segment_document)
+            db.session.commit()
 
-        try:
             # save vector index
-            VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset)
-        except Exception as e:
-            logging.exception("create segment index failed")
-            for segment_document in segment_data_list:
+            try:
+                VectorService.create_segments_vector([args['keywords']], [segment_document], dataset)
+            except Exception as e:
+                logging.exception("create segment index failed")
                 segment_document.enabled = False
                 segment_document.disabled_at = datetime.datetime.utcnow()
                 segment_document.status = 'error'
                 segment_document.error = str(e)
-        db.session.commit()
-        return segment_data_list
+                db.session.commit()
+            segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first()
+            return segment
+
+    @classmethod
+    def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
+        lock_name = 'multi_add_segment_lock_document_id_{}'.format(document.id)
+        with redis_client.lock(lock_name, timeout=600):
+            embedding_model = None
+            if dataset.indexing_technique == 'high_quality':
+                model_manager = ModelManager()
+                embedding_model = model_manager.get_model_instance(
+                    tenant_id=current_user.current_tenant_id,
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
+                )
+            max_position = db.session.query(func.max(DocumentSegment.position)).filter(
+                DocumentSegment.document_id == document.id
+            ).scalar()
+            pre_segment_data_list = []
+            segment_data_list = []
+            keywords_list = []
+            for segment_item in segments:
+                content = segment_item['content']
+                doc_id = str(uuid.uuid4())
+                segment_hash = helper.generate_text_hash(content)
+                tokens = 0
+                if dataset.indexing_technique == 'high_quality' and embedding_model:
+                    # calc embedding use tokens
+                    model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
+                    tokens = model_type_instance.get_num_tokens(
+                        model=embedding_model.model,
+                        credentials=embedding_model.credentials,
+                        texts=[content]
+                    )
+                segment_document = DocumentSegment(
+                    tenant_id=current_user.current_tenant_id,
+                    dataset_id=document.dataset_id,
+                    document_id=document.id,
+                    index_node_id=doc_id,
+                    index_node_hash=segment_hash,
+                    position=max_position + 1 if max_position else 1,
+                    content=content,
+                    word_count=len(content),
+                    tokens=tokens,
+                    status='completed',
+                    indexing_at=datetime.datetime.utcnow(),
+                    completed_at=datetime.datetime.utcnow(),
+                    created_by=current_user.id
+                )
+                if document.doc_form == 'qa_model':
+                    segment_document.answer = segment_item['answer']
+                db.session.add(segment_document)
+                segment_data_list.append(segment_document)
+
+                pre_segment_data_list.append(segment_document)
+                keywords_list.append(segment_item['keywords'])
+
+            try:
+                # save vector index
+                VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset)
+            except Exception as e:
+                logging.exception("create segment index failed")
+                for segment_document in segment_data_list:
+                    segment_document.enabled = False
+                    segment_document.disabled_at = datetime.datetime.utcnow()
+                    segment_document.status = 'error'
+                    segment_document.error = str(e)
+            db.session.commit()
+            return segment_data_list
 
     @classmethod
     def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):

+ 3 - 1
web/app/components/base/tag-input/index.tsx

@@ -56,7 +56,9 @@ const TagInput: FC<TagInputProps> = ({
       }
 
       onChange([...items, valueTrimed])
-      setValue('')
+      setTimeout(() => {
+        setValue('')
+      })
     }
   }