ソースを参照

fix: embedding init err (#956)

takatost 1 年間 前
コミット
78d3aa5fcd

+ 3 - 2
api/core/model_providers/models/embedding/xinference_embedding.py

@@ -1,4 +1,4 @@
-from langchain.embeddings import XinferenceEmbeddings
+from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings
 from replicate.exceptions import ModelError, ReplicateError
 
 from core.model_providers.error import LLMBadRequestError
@@ -14,7 +14,8 @@ class XinferenceEmbedding(BaseEmbedding):
         )
 
         client = XinferenceEmbeddings(
-            **credentials,
+            server_url=credentials['server_url'],
+            model_uid=credentials['model_uid'],
         )
 
         super().__init__(model_provider, client, name)

+ 21 - 0
api/core/third_party/langchain/embeddings/xinference_embedding.py

@@ -0,0 +1,21 @@
+from typing import List
+
+import numpy as np
+from langchain.embeddings import XinferenceEmbeddings
+
+
+class XinferenceEmbedding(XinferenceEmbeddings):
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        vectors = super().embed_documents(texts)
+
+        normalized_vectors = [(vector / np.linalg.norm(vector)).tolist() for vector in vectors]
+
+        return normalized_vectors
+
+    def embed_query(self, text: str) -> List[float]:
+        vector = super().embed_query(text)
+
+        normalized_vector = (vector / np.linalg.norm(vector)).tolist()
+
+        return normalized_vector