Selaa lähdekoodia

Feat/milvus support (#671)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Jyong 1 vuosi sitten
vanhempi
commit
082f8b17ab

+ 0 - 1
api/controllers/console/datasets/datasets_segments.py

@@ -292,4 +292,3 @@ api.add_resource(DatasetDocumentSegmentAddApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
 api.add_resource(DatasetDocumentSegmentUpdateApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
-

+ 0 - 123
api/core/index/vector_index/test-embedding.py

@@ -1,123 +0,0 @@
-import numpy as np
-import sklearn.decomposition
-import pickle
-import time
-
-
-# Apply 'Algorithm 1' to the ada-002 embeddings to make them isotropic, taken from the paper:
-# ALL-BUT-THE-TOP: SIMPLE AND EFFECTIVE POST- PROCESSING FOR WORD REPRESENTATIONS
-# Jiaqi Mu, Pramod Viswanath
-
-# This uses Principal Component Analysis (PCA) to 'evenly distribute' the embedding vectors (make them isotropic)
-# For more information on PCA, see https://jamesmccaffrey.wordpress.com/2021/07/16/computing-pca-using-numpy-without-scikit/
-
-
-# get the file pointer of the pickle containing the embeddings
-fp = open('/path/to/your/data/Embedding-Latest.pkl', 'rb')
-
-
-# the embedding data here is a dict consisting of key / value pairs
-# the key is the hash of the message (SHA3-256), the value is the embedding from ada-002 (array of dimension 1536)
-# the hash can be used to lookup the orignal text in a database
-E = pickle.load(fp) # load the data into memory
-
-# seperate the keys (hashes) and values (embeddings) into seperate vectors
-K = list(E.keys()) # vector of all the hash values
-X = np.array(list(E.values())) # vector of all the embeddings, converted to numpy arrays
-
-
-# list the total number of embeddings
-# this can be truncated if there are too many embeddings to do PCA on
-print(f"Total number of embeddings: {len(X)}")
-
-# get dimension of embeddings, used later
-Dim = len(X[0])
-
-# flash out the first few embeddings
-print("First two embeddings are: ")
-print(X[0])
-print(f"First embedding length: {len(X[0])}")
-print(X[1])
-print(f"Second embedding length: {len(X[1])}")
-
-
-# compute the mean of all the embeddings, and flash the result
-mu = np.mean(X, axis=0) # same as mu in paper
-print(f"Mean embedding vector: {mu}")
-print(f"Mean embedding vector length: {len(mu)}")
-
-
-# subtract the mean vector from each embedding vector ... vectorized in numpy
-X_tilde = X - mu # same as v_tilde(w) in paper
-
-
-
-# do the heavy lifting of extracting the principal components
-# note that this is a function of the embeddings you currently have here, and this set may grow over time
-# therefore the PCA basis vectors may change over time, and your final isotropic embeddings may drift over time
-# but the drift should stabilize after you have extracted enough embedding data to characterize the nature of the embedding engine
-print(f"Performing PCA on the normalized embeddings ...")
-pca = sklearn.decomposition.PCA()  # new object
-TICK = time.time() # start timer
-pca.fit(X_tilde) # do the heavy lifting!
-TOCK = time.time() # end timer
-DELTA = TOCK - TICK
-
-print(f"PCA finished in {DELTA} seconds ...")
-
-# dimensional reduction stage (the only hyperparameter)
-# pick max dimension of PCA components to express embddings
-# in general this is some integer less than or equal to the dimension of your embeddings
-# it could be set as a high percentile, say 95th percentile of pca.explained_variance_ratio_
-# but just hardcoding a constant here
-D = 15 # hyperparameter on dimension (out of 1536 for ada-002), paper recommeds D = Dim/100
-
-
-# form the set of v_prime(w), which is the final embedding
-# this could be vectorized in numpy to speed it up, but coding it directly here in a double for-loop to avoid errors and to be transparent
-E_prime = dict() # output dict of the new embeddings
-N = len(X_tilde)
-N10 = round(N/10)
-U = pca.components_ # set of PCA basis vectors, sorted by most significant to least significant
-print(f"Shape of full set of PCA componenents {U.shape}")
-U = U[0:D,:] # take the top D dimensions (or take them all if D is the size of the embedding vector)
-print(f"Shape of downselected PCA componenents {U.shape}")
-for ii in range(N):
-    v_tilde = X_tilde[ii]
-    v = X[ii]
-    v_projection = np.zeros(Dim) # start to build the projection
-    # project the original embedding onto the PCA basis vectors, use only first D dimensions
-    for jj in range(D):
-        u_jj = U[jj,:] # vector
-        v_jj = np.dot(u_jj,v) # scaler
-        v_projection += v_jj*u_jj # vector
-    v_prime = v_tilde - v_projection # final embedding vector
-    v_prime = v_prime/np.linalg.norm(v_prime) # create unit vector
-    E_prime[K[ii]] = v_prime
-
-    if (ii%N10 == 0) or (ii == N-1):
-        print(f"Finished with {ii+1} embeddings out of {N} ({round(100*ii/N)}% done)")
-
-
-# save as new pickle
-print("Saving new pickle ...")
-embeddingName = '/path/to/your/data/Embedding-Latest-Isotropic.pkl'
-with open(embeddingName, 'wb') as f:  # Python 3: open(..., 'wb')
-    pickle.dump([E_prime,mu,U], f)
-    print(embeddingName)
-
-print("Done!")
-
-# When working with live data with a new embedding from ada-002, be sure to tranform it first with this function before comparing it
-#
-def projectEmbedding(v,mu,U):
-    v = np.array(v)
-    v_tilde = v - mu
-    v_projection = np.zeros(len(v)) # start to build the projection
-    # project the original embedding onto the PCA basis vectors, use only first D dimensions
-    for u in U:
-        v_jj = np.dot(u,v) # scaler
-        v_projection += v_jj*u # vector
-    v_prime = v_tilde - v_projection # final embedding vector
-    v_prime = v_prime/np.linalg.norm(v_prime) # create unit vector
-    return v_prime

+ 42 - 34
api/core/indexing_runner.py

@@ -7,6 +7,7 @@ import re
 import threading
 import time
 import uuid
+from concurrent.futures import ThreadPoolExecutor
 from multiprocessing import Process
 from typing import Optional, List, cast
 
@@ -14,7 +15,6 @@ import openai
 from billiard.pool import Pool
 from flask import current_app, Flask
 from flask_login import current_user
-from gevent.threadpool import ThreadPoolExecutor
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
@@ -516,43 +516,51 @@ class IndexingRunner:
                 model_name='gpt-3.5-turbo',
                 max_tokens=2000
             )
-            self.format_document(llm, documents, split_documents, document_form)
+            threads = []
+            for doc in documents:
+                document_format_thread = threading.Thread(target=self.format_document, kwargs={
+                    'llm': llm, 'document_node': doc, 'split_documents': split_documents, 'document_form': document_form})
+                threads.append(document_format_thread)
+                document_format_thread.start()
+            for thread in threads:
+                thread.join()
             all_documents.extend(split_documents)
 
         return all_documents
 
-    def format_document(self, llm: StreamableOpenAI, documents: List[Document], split_documents: List, document_form: str):
-        for document_node in documents:
-            format_documents = []
-            if document_node.page_content is None or not document_node.page_content.strip():
-                return format_documents
-            if document_form == 'text_model':
-                # text model document
-                doc_id = str(uuid.uuid4())
-                hash = helper.generate_text_hash(document_node.page_content)
-
-                document_node.metadata['doc_id'] = doc_id
-                document_node.metadata['doc_hash'] = hash
-
-                format_documents.append(document_node)
-            elif document_form == 'qa_model':
-                try:
-                    # qa model document
-                    response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content)
-                    document_qa_list = self.format_split_text(response)
-                    qa_documents = []
-                    for result in document_qa_list:
-                        qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
-                        doc_id = str(uuid.uuid4())
-                        hash = helper.generate_text_hash(result['question'])
-                        qa_document.metadata['answer'] = result['answer']
-                        qa_document.metadata['doc_id'] = doc_id
-                        qa_document.metadata['doc_hash'] = hash
-                        qa_documents.append(qa_document)
-                    format_documents.extend(qa_documents)
-                except Exception:
-                    continue
-            split_documents.extend(format_documents)
+    def format_document(self, llm: StreamableOpenAI, document_node, split_documents: List, document_form: str):
+        print(document_node.page_content)
+        format_documents = []
+        if document_node.page_content is None or not document_node.page_content.strip():
+            return format_documents
+        if document_form == 'text_model':
+            # text model document
+            doc_id = str(uuid.uuid4())
+            hash = helper.generate_text_hash(document_node.page_content)
+
+            document_node.metadata['doc_id'] = doc_id
+            document_node.metadata['doc_hash'] = hash
+
+            format_documents.append(document_node)
+        elif document_form == 'qa_model':
+            try:
+                # qa model document
+                response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content)
+                document_qa_list = self.format_split_text(response)
+                qa_documents = []
+                for result in document_qa_list:
+                    qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
+                    doc_id = str(uuid.uuid4())
+                    hash = helper.generate_text_hash(result['question'])
+                    qa_document.metadata['answer'] = result['answer']
+                    qa_document.metadata['doc_id'] = doc_id
+                    qa_document.metadata['doc_hash'] = hash
+                    qa_documents.append(qa_document)
+                format_documents.extend(qa_documents)
+            except Exception:
+                logging.error("sss")
+        split_documents.extend(format_documents)
+
 
     def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
                                          processing_rule: DatasetProcessRule) -> List[Document]:

+ 1 - 0
api/events/event_handlers/__init__.py

@@ -7,3 +7,4 @@ from .clean_when_dataset_deleted import handle
 from .update_app_dataset_join_when_app_model_config_updated import handle
 from .generate_conversation_name_when_first_message_created import handle
 from .generate_conversation_summary_when_few_message_created import handle
+from .create_document_index import handle

+ 48 - 0
api/events/event_handlers/create_document_index.py

@@ -0,0 +1,48 @@
+from events.dataset_event import dataset_was_deleted
+from events.event_handlers.document_index_event import document_index_created
+from tasks.clean_dataset_task import clean_dataset_task
+import datetime
+import logging
+import time
+
+import click
+from celery import shared_task
+from werkzeug.exceptions import NotFound
+
+from core.indexing_runner import IndexingRunner, DocumentIsPausedException
+from extensions.ext_database import db
+from models.dataset import Document
+
+
+@document_index_created.connect
+def handle(sender, **kwargs):
+    dataset_id = sender
+    document_ids = kwargs.get('document_ids', None)
+    documents = []
+    start_at = time.perf_counter()
+    for document_id in document_ids:
+        logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
+
+        document = db.session.query(Document).filter(
+            Document.id == document_id,
+            Document.dataset_id == dataset_id
+        ).first()
+
+        if not document:
+            raise NotFound('Document not found')
+
+        document.indexing_status = 'parsing'
+        document.processing_started_at = datetime.datetime.utcnow()
+        documents.append(document)
+        db.session.add(document)
+    db.session.commit()
+
+    try:
+        indexing_runner = IndexingRunner()
+        indexing_runner.run(documents)
+        end_at = time.perf_counter()
+        logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass

+ 4 - 0
api/events/event_handlers/document_index_event.py

@@ -0,0 +1,4 @@
+from blinker import signal
+
+# sender: document
+document_index_created = signal('document-index-created')

+ 2 - 0
api/services/dataset_service.py

@@ -10,6 +10,7 @@ from flask import current_app
 from sqlalchemy import func
 
 from core.llm.token_calculator import TokenCalculator
+from events.event_handlers.document_index_event import document_index_created
 from extensions.ext_redis import redis_client
 from flask_login import current_user
 
@@ -520,6 +521,7 @@ class DocumentService:
             db.session.commit()
 
             # trigger async task
+            #document_index_created.send(dataset.id, document_ids=document_ids)
             document_indexing_task.delay(dataset.id, document_ids)
 
         return documents, batch

+ 0 - 24
api/tasks/generate_test_task.py

@@ -1,24 +0,0 @@
-import logging
-import time
-
-import click
-import requests
-from celery import shared_task
-
-from core.generator.llm_generator import LLMGenerator
-
-
-@shared_task
-def generate_test_task():
-    logging.info(click.style('Start generate test', fg='green'))
-    start_at = time.perf_counter()
-
-    try:
-        #res = requests.post('https://api.openai.com/v1/chat/completions')
-        answer = LLMGenerator.generate_conversation_name('84b2202c-c359-46b7-a810-bce50feaa4d1', 'avb', 'ccc')
-        print(f'answer: {answer}')
-
-        end_at = time.perf_counter()
-        logging.info(click.style('Conversation test, latency: {}'.format(end_at - start_at), fg='green'))
-    except Exception:
-        logging.exception("generate test failed")