Jyong 10 mesiacov pred
rodič
commit
02e4de5166

+ 8 - 8
api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py

@@ -79,8 +79,6 @@ class TiDBVector(BaseVector):
                 return
             with Session(self._engine) as session:
                 session.begin()
-                drop_statement = sql_text(f"""DROP TABLE IF EXISTS {self._collection_name}; """)
-                session.execute(drop_statement)
                 create_statement = sql_text(f"""
                     CREATE TABLE IF NOT EXISTS {self._collection_name} (
                         id CHAR(36) PRIMARY KEY,
@@ -123,7 +121,7 @@ class TiDBVector(BaseVector):
 
     def text_exists(self, id: str) -> bool:
         result = self.get_ids_by_metadata_field('doc_id', id)
-        return len(result) > 0
+        return bool(result)
 
     def delete_by_ids(self, ids: list[str]) -> None:
         with Session(self._engine) as session:
@@ -184,14 +182,14 @@ class TiDBVector(BaseVector):
         docs = []
         if self._distance_func == 'l2':
             tidb_func = 'Vec_l2_distance'
-        elif self._distance_func == 'l2':
+        elif self._distance_func == 'cosine':
             tidb_func = 'Vec_Cosine_distance'
         else:
             tidb_func = 'Vec_Cosine_distance'
 
         with Session(self._engine) as session:
             select_statement = sql_text(
-                f"""SELECT meta, text FROM (
+                f"""SELECT meta, text, distance FROM (
                         SELECT meta, text, {tidb_func}(vector, "{query_vector_str}")  as distance 
                         FROM {self._collection_name} 
                         ORDER BY distance
@@ -199,9 +197,11 @@ class TiDBVector(BaseVector):
                     ) t WHERE distance < {distance};"""
             )
             res = session.execute(select_statement)
-            results = [(row[0], row[1]) for row in res]
-            for meta, text in results:
-                docs.append(Document(page_content=text, metadata=json.loads(meta)))
+            results = [(row[0], row[1], row[2]) for row in res]
+            for meta, text, distance in results:
+                metadata = json.loads(meta)
+                metadata['score'] = 1 - distance
+                docs.append(Document(page_content=text, metadata=metadata))
         return docs
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: