Преглед изворног кода

Compatible with unique index conflicts (#3183)

Jyong пре 1 године
родитељ
комит
2e4dec365d
1 измењен фајл са 16 додатих и 12 уклоњених линија
  1. 16 12
      api/core/embedding/cached_embedding.py

+ 16 - 12
api/core/embedding/cached_embedding.py

@@ -41,7 +41,8 @@ class CacheEmbedding(Embeddings):
             embedding_queue_embeddings = []
             try:
                 model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
-                model_schema = model_type_instance.get_model_schema(self._model_instance.model, self._model_instance.credentials)
+                model_schema = model_type_instance.get_model_schema(self._model_instance.model,
+                                                                    self._model_instance.credentials)
                 max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
                     if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
                 for i in range(0, len(embedding_queue_texts), max_chunks):
@@ -61,17 +62,20 @@ class CacheEmbedding(Embeddings):
                         except Exception as e:
                             logging.exception('Failed transform embedding: ', e)
                 cache_embeddings = []
-                for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
-                    text_embeddings[i] = embedding
-                    hash = helper.generate_text_hash(texts[i])
-                    if hash not in cache_embeddings:
-                        embedding_cache = Embedding(model_name=self._model_instance.model,
-                                              hash=hash,
-                                              provider_name=self._model_instance.provider)
-                        embedding_cache.set_embedding(embedding)
-                        db.session.add(embedding_cache)
-                        cache_embeddings.append(hash)
-                db.session.commit()
+                try:
+                    for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
+                        text_embeddings[i] = embedding
+                        hash = helper.generate_text_hash(texts[i])
+                        if hash not in cache_embeddings:
+                            embedding_cache = Embedding(model_name=self._model_instance.model,
+                                                        hash=hash,
+                                                        provider_name=self._model_instance.provider)
+                            embedding_cache.set_embedding(embedding)
+                            db.session.add(embedding_cache)
+                            cache_embeddings.append(hash)
+                    db.session.commit()
+                except IntegrityError:
+                    db.session.rollback()
             except Exception as ex:
                 db.session.rollback()
                 logger.error('Failed to embed documents: ', ex)