浏览代码

normalize embedding (#974)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 1 年之前
父节点
当前提交
1fc57d7358
共有 1 个文件被更改,包括 8 次插入4 次删除
  1. 8 4
      api/core/embedding/cached_embedding.py

+ 8 - 4
api/core/embedding/cached_embedding.py

@@ -1,6 +1,7 @@
 import logging
 from typing import List
 
+import numpy as np
 from langchain.embeddings.base import Embeddings
 from sqlalchemy.exc import IntegrityError
 
@@ -32,14 +33,17 @@ class CacheEmbedding(Embeddings):
                 embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
             except Exception as ex:
                 raise self._embeddings.handle_exceptions(ex)
-
             i = 0
+            normalized_embedding_results = []
             for text in embedding_queue_texts:
                 hash = helper.generate_text_hash(text)
 
                 try:
                     embedding = Embedding(model_name=self._embeddings.name, hash=hash)
-                    embedding.set_embedding(embedding_results[i])
+                    vector = embedding_results[i]
+                    normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
+                    normalized_embedding_results.append(normalized_embedding)
+                    embedding.set_embedding(normalized_embedding)
                     db.session.add(embedding)
                     db.session.commit()
                 except IntegrityError:
@@ -51,7 +55,7 @@ class CacheEmbedding(Embeddings):
                 finally:
                     i += 1
 
-            text_embeddings.extend(embedding_results)
+            text_embeddings.extend(normalized_embedding_results)
         return text_embeddings
 
     def embed_query(self, text: str) -> List[float]:
@@ -64,6 +68,7 @@ class CacheEmbedding(Embeddings):
 
         try:
             embedding_results = self._embeddings.client.embed_query(text)
+            embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
         except Exception as ex:
             raise self._embeddings.handle_exceptions(ex)
 
@@ -79,4 +84,3 @@ class CacheEmbedding(Embeddings):
 
         return embedding_results
 
-