Bladeren bron

delete document cache embedding (#2101)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 1 jaar geleden
bovenliggende
commit
ee9c7e204f
1 gewijzigde bestanden met toevoegingen van 28 en 49 verwijderingen
  1. 28 49
      api/core/embedding/cached_embedding.py

+ 28 - 49
api/core/embedding/cached_embedding.py

@@ -1,10 +1,12 @@
 import base64
 import json
 import logging
-from typing import List, Optional
+from typing import List, Optional, cast
 
 import numpy as np
 from core.model_manager import ModelInstance
+from core.model_runtime.entities.model_entities import ModelPropertyKey
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from extensions.ext_database import db
 from langchain.embeddings.base import Embeddings
 
@@ -22,56 +24,33 @@ class CacheEmbedding(Embeddings):
         self._user = user
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        """Embed search docs."""
-        # use doc embedding cache or store if not exists
-        text_embeddings = [None for _ in range(len(texts))]
-        embedding_queue_indices = []
-        for i, text in enumerate(texts):
-            hash = helper.generate_text_hash(text)
-            embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
-            embedding = redis_client.get(embedding_cache_key)
-            if embedding:
-                redis_client.expire(embedding_cache_key, 3600)
-                text_embeddings[i] = list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
-
-            else:
-                embedding_queue_indices.append(i)
-
-        if embedding_queue_indices:
-            try:
+        """Embed search docs in batches of 10."""
+        text_embeddings = []
+        try:
+            model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
+            model_schema = model_type_instance.get_model_schema(self._model_instance.model, self._model_instance.credentials)
+            max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
+                if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
+            for i in range(0, len(texts), max_chunks):
+                batch_texts = texts[i:i + max_chunks]
+
                 embedding_result = self._model_instance.invoke_text_embedding(
-                    texts=[texts[i] for i in embedding_queue_indices],
+                    texts=batch_texts,
                     user=self._user
                 )
 
-                embedding_results = embedding_result.embeddings
-            except Exception as ex:
-                logger.error('Failed to embed documents: ', ex)
-                raise ex
-
-            for i, indice in enumerate(embedding_queue_indices):
-                hash = helper.generate_text_hash(texts[indice])
-
-                try:
-                    embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
-                    vector = embedding_results[i]
-                    normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
-                    text_embeddings[indice] = normalized_embedding
-                    # encode embedding to base64
-                    embedding_vector = np.array(normalized_embedding)
-                    vector_bytes = embedding_vector.tobytes()
-                    # Transform to Base64
-                    encoded_vector = base64.b64encode(vector_bytes)
-                    # Transform to string
-                    encoded_str = encoded_vector.decode("utf-8")
-                    redis_client.setex(embedding_cache_key, 3600, encoded_str)
-
-                except IntegrityError:
-                    db.session.rollback()
-                    continue
-                except:
-                    logging.exception('Failed to add embedding to redis')
-                    continue
+                for vector in embedding_result.embeddings:
+                    try:
+                        normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
+                        text_embeddings.append(normalized_embedding)
+                    except IntegrityError:
+                        db.session.rollback()
+                    except Exception as e:
+                        logging.exception('Failed to add embedding to redis')
+
+        except Exception as ex:
+            logger.error('Failed to embed documents: ', ex)
+            raise ex
 
         return text_embeddings
 
@@ -82,7 +61,7 @@ class CacheEmbedding(Embeddings):
         embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
         embedding = redis_client.get(embedding_cache_key)
         if embedding:
-            redis_client.expire(embedding_cache_key, 3600)
+            redis_client.expire(embedding_cache_key, 600)
             return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
 
 
@@ -105,7 +84,7 @@ class CacheEmbedding(Embeddings):
             encoded_vector = base64.b64encode(vector_bytes)
             # Transform to string
             encoded_str = encoded_vector.decode("utf-8")
-            redis_client.setex(embedding_cache_key, 3600, encoded_str)
+            redis_client.setex(embedding_cache_key, 600, encoded_str)
 
         except IntegrityError:
             db.session.rollback()