|
@@ -79,8 +79,6 @@ class TiDBVector(BaseVector):
|
|
|
return
|
|
|
with Session(self._engine) as session:
|
|
|
session.begin()
|
|
|
- drop_statement = sql_text(f"""DROP TABLE IF EXISTS {self._collection_name}; """)
|
|
|
- session.execute(drop_statement)
|
|
|
create_statement = sql_text(f"""
|
|
|
CREATE TABLE IF NOT EXISTS {self._collection_name} (
|
|
|
id CHAR(36) PRIMARY KEY,
|
|
@@ -123,7 +121,7 @@ class TiDBVector(BaseVector):
|
|
|
|
|
|
def text_exists(self, id: str) -> bool:
|
|
|
result = self.get_ids_by_metadata_field('doc_id', id)
|
|
|
- return len(result) > 0
|
|
|
+ return bool(result)
|
|
|
|
|
|
def delete_by_ids(self, ids: list[str]) -> None:
|
|
|
with Session(self._engine) as session:
|
|
@@ -184,14 +182,14 @@ class TiDBVector(BaseVector):
|
|
|
docs = []
|
|
|
if self._distance_func == 'l2':
|
|
|
tidb_func = 'Vec_l2_distance'
|
|
|
- elif self._distance_func == 'l2':
|
|
|
+ elif self._distance_func == 'cosine':
|
|
|
tidb_func = 'Vec_Cosine_distance'
|
|
|
else:
|
|
|
tidb_func = 'Vec_Cosine_distance'
|
|
|
|
|
|
with Session(self._engine) as session:
|
|
|
select_statement = sql_text(
|
|
|
- f"""SELECT meta, text FROM (
|
|
|
+ f"""SELECT meta, text, distance FROM (
|
|
|
SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
|
|
|
FROM {self._collection_name}
|
|
|
ORDER BY distance
|
|
@@ -199,9 +197,11 @@ class TiDBVector(BaseVector):
|
|
|
) t WHERE distance < {distance};"""
|
|
|
)
|
|
|
res = session.execute(select_statement)
|
|
|
- results = [(row[0], row[1]) for row in res]
|
|
|
- for meta, text in results:
|
|
|
- docs.append(Document(page_content=text, metadata=json.loads(meta)))
|
|
|
+ results = [(row[0], row[1], row[2]) for row in res]
|
|
|
+ for meta, text, distance in results:
|
|
|
+ metadata = json.loads(meta)
|
|
|
+ metadata['score'] = 1 - distance
|
|
|
+ docs.append(Document(page_content=text, metadata=metadata))
|
|
|
return docs
|
|
|
|
|
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|