Explorar el Código

refactor(rag): switch to dify_config. (#6410)

Co-authored-by: -LAN- <laipz8200@outlook.com>
Poorandy hace 9 meses
padre
commit
c8f5dfcf17

+ 19 - 18
api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py

@@ -7,8 +7,8 @@ _import_err_msg = (
     "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
     "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
 )
-from flask import current_app
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -36,7 +36,7 @@ class AnalyticdbConfig(BaseModel):
             "region_id": self.region_id,
             "read_timeout": self.read_timeout,
         }
-    
+
 class AnalyticdbVector(BaseVector):
     _instance = None
     _init = False
@@ -45,7 +45,7 @@ class AnalyticdbVector(BaseVector):
         if cls._instance is None:
             cls._instance = super().__new__(cls)
         return cls._instance
-    
+
     def __init__(self, collection_name: str, config: AnalyticdbConfig):
         # collection_name must be updated every time
         self._collection_name = collection_name.lower()
@@ -105,7 +105,7 @@ class AnalyticdbVector(BaseVector):
                 raise ValueError(
                     f"failed to create namespace {self.config.namespace}: {e}"
                 )
-            
+
     def _create_collection_if_not_exists(self, embedding_dimension: int):
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
         from Tea.exceptions import TeaException
@@ -149,7 +149,7 @@ class AnalyticdbVector(BaseVector):
 
     def get_type(self) -> str:
         return VectorType.ANALYTICDB
-    
+
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
         dimension = len(embeddings[0])
         self._create_collection_if_not_exists(dimension)
@@ -199,7 +199,7 @@ class AnalyticdbVector(BaseVector):
         )
         response = self._client.query_collection_data(request)
         return len(response.body.matches.match) > 0
-    
+
     def delete_by_ids(self, ids: list[str]) -> None:
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
         ids_str = ",".join(f"'{id}'" for id in ids)
@@ -260,7 +260,7 @@ class AnalyticdbVector(BaseVector):
                 )
                 documents.append(doc)
         return documents
-    
+
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
         score_threshold = (
@@ -291,7 +291,7 @@ class AnalyticdbVector(BaseVector):
                 )
                 documents.append(doc)
         return documents
-    
+
     def delete(self) -> None:
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
         request = gpdb_20160503_models.DeleteCollectionRequest(
@@ -316,17 +316,18 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
             dataset.index_struct = json.dumps(
                 self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)
             )
-        config = current_app.config
+
+        # TODO handle optional params
         return AnalyticdbVector(
             collection_name,
             AnalyticdbConfig(
-                access_key_id=config.get("ANALYTICDB_KEY_ID"),
-                access_key_secret=config.get("ANALYTICDB_KEY_SECRET"),
-                region_id=config.get("ANALYTICDB_REGION_ID"),
-                instance_id=config.get("ANALYTICDB_INSTANCE_ID"),
-                account=config.get("ANALYTICDB_ACCOUNT"),
-                account_password=config.get("ANALYTICDB_PASSWORD"),
-                namespace=config.get("ANALYTICDB_NAMESPACE"),
-                namespace_password=config.get("ANALYTICDB_NAMESPACE_PASSWORD"),
+                access_key_id=dify_config.ANALYTICDB_KEY_ID,
+                access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
+                region_id=dify_config.ANALYTICDB_REGION_ID,
+                instance_id=dify_config.ANALYTICDB_INSTANCE_ID,
+                account=dify_config.ANALYTICDB_ACCOUNT,
+                account_password=dify_config.ANALYTICDB_PASSWORD,
+                namespace=dify_config.ANALYTICDB_NAMESPACE,
+                namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
             ),
-        )
+        )

+ 7 - 8
api/core/rag/datasource/vdb/chroma/chroma_vector.py

@@ -3,9 +3,9 @@ from typing import Any, Optional
 
 import chromadb
 from chromadb import QueryResult, Settings
-from flask import current_app
 from pydantic import BaseModel
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -133,15 +133,14 @@ class ChromaVectorFactory(AbstractVectorFactory):
             }
             dataset.index_struct = json.dumps(index_struct_dict)
 
-        config = current_app.config
         return ChromaVector(
             collection_name=collection_name,
             config=ChromaConfig(
-                host=config.get('CHROMA_HOST'),
-                port=int(config.get('CHROMA_PORT')),
-                tenant=config.get('CHROMA_TENANT', chromadb.DEFAULT_TENANT),
-                database=config.get('CHROMA_DATABASE', chromadb.DEFAULT_DATABASE),
-                auth_provider=config.get('CHROMA_AUTH_PROVIDER'),
-                auth_credentials=config.get('CHROMA_AUTH_CREDENTIALS'),
+                host=dify_config.CHROMA_HOST,
+                port=dify_config.CHROMA_PORT,
+                tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
+                database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
+                auth_provider=dify_config.CHROMA_AUTH_PROVIDER,
+                auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS,
             ),
         )

+ 7 - 8
api/core/rag/datasource/vdb/milvus/milvus_vector.py

@@ -3,10 +3,10 @@ import logging
 from typing import Any, Optional
 from uuid import uuid4
 
-from flask import current_app
 from pydantic import BaseModel, model_validator
 from pymilvus import MilvusClient, MilvusException, connections
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
@@ -275,15 +275,14 @@ class MilvusVectorFactory(AbstractVectorFactory):
             dataset.index_struct = json.dumps(
                 self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
 
-        config = current_app.config
         return MilvusVector(
             collection_name=collection_name,
             config=MilvusConfig(
-                host=config.get('MILVUS_HOST'),
-                port=config.get('MILVUS_PORT'),
-                user=config.get('MILVUS_USER'),
-                password=config.get('MILVUS_PASSWORD'),
-                secure=config.get('MILVUS_SECURE'),
-                database=config.get('MILVUS_DATABASE'),
+                host=dify_config.MILVUS_HOST,
+                port=dify_config.MILVUS_PORT,
+                user=dify_config.MILVUS_USER,
+                password=dify_config.MILVUS_PASSWORD,
+                secure=dify_config.MILVUS_SECURE,
+                database=dify_config.MILVUS_DATABASE,
             )
         )

+ 8 - 8
api/core/rag/datasource/vdb/myscale/myscale_vector.py

@@ -5,9 +5,9 @@ from enum import Enum
 from typing import Any
 
 from clickhouse_connect import get_client
-from flask import current_app
 from pydantic import BaseModel
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -156,15 +156,15 @@ class MyScaleVectorFactory(AbstractVectorFactory):
             dataset.index_struct = json.dumps(
                 self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
 
-        config = current_app.config
         return MyScaleVector(
             collection_name=collection_name,
             config=MyScaleConfig(
-                host=config.get("MYSCALE_HOST", "localhost"),
-                port=int(config.get("MYSCALE_PORT", 8123)),
-                user=config.get("MYSCALE_USER", "default"),
-                password=config.get("MYSCALE_PASSWORD", ""),
-                database=config.get("MYSCALE_DATABASE", "default"),
-                fts_params=config.get("MYSCALE_FTS_PARAMS", ""),
+                # TODO: I think setting those values as the default config would be a better option.
+                host=dify_config.MYSCALE_HOST or "localhost",
+                port=dify_config.MYSCALE_PORT or 8123,
+                user=dify_config.MYSCALE_USER or "default",
+                password=dify_config.MYSCALE_PASSWORD or "",
+                database=dify_config.MYSCALE_DATABASE or "default",
+                fts_params=dify_config.MYSCALE_FTS_PARAMS or "",
             ),
         )

+ 6 - 7
api/core/rag/datasource/vdb/opensearch/opensearch_vector.py

@@ -4,11 +4,11 @@ import ssl
 from typing import Any, Optional
 from uuid import uuid4
 
-from flask import current_app
 from opensearchpy import OpenSearch, helpers
 from opensearchpy.helpers import BulkIndexError
 from pydantic import BaseModel, model_validator
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
@@ -257,14 +257,13 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
             dataset.index_struct = json.dumps(
                 self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
 
-        config = current_app.config
 
         open_search_config = OpenSearchConfig(
-            host=config.get('OPENSEARCH_HOST'),
-            port=config.get('OPENSEARCH_PORT'),
-            user=config.get('OPENSEARCH_USER'),
-            password=config.get('OPENSEARCH_PASSWORD'),
-            secure=config.get('OPENSEARCH_SECURE'),
+            host=dify_config.OPENSEARCH_HOST,
+            port=dify_config.OPENSEARCH_PORT,
+            user=dify_config.OPENSEARCH_USER,
+            password=dify_config.OPENSEARCH_PASSWORD,
+            secure=dify_config.OPENSEARCH_SECURE,
         )
 
         return OpenSearchVector(

+ 8 - 9
api/core/rag/datasource/vdb/oracle/oraclevector.py

@@ -6,9 +6,9 @@ from typing import Any
 
 import numpy
 import oracledb
-from flask import current_app
 from pydantic import BaseModel, model_validator
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -44,11 +44,11 @@ class OracleVectorConfig(BaseModel):
 
 SQL_CREATE_TABLE = """
 CREATE TABLE IF NOT EXISTS {table_name} (
-    id varchar2(100) 
+    id varchar2(100)
     ,text CLOB NOT NULL
     ,meta JSON
     ,embedding vector NOT NULL
-) 
+)
 """
 
 
@@ -219,14 +219,13 @@ class OracleVectorFactory(AbstractVectorFactory):
             dataset.index_struct = json.dumps(
                 self.gen_index_struct_dict(VectorType.ORACLE, collection_name))
 
-        config = current_app.config
         return OracleVector(
             collection_name=collection_name,
             config=OracleVectorConfig(
-                host=config.get("ORACLE_HOST"),
-                port=config.get("ORACLE_PORT"),
-                user=config.get("ORACLE_USER"),
-                password=config.get("ORACLE_PASSWORD"),
-                database=config.get("ORACLE_DATABASE"),
+                host=dify_config.ORACLE_HOST,
+                port=dify_config.ORACLE_PORT,
+                user=dify_config.ORACLE_USER,
+                password=dify_config.ORACLE_PASSWORD,
+                database=dify_config.ORACLE_DATABASE,
             ),
         )

+ 9 - 9
api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py

@@ -3,7 +3,6 @@ import logging
 from typing import Any
 from uuid import UUID, uuid4
 
-from flask import current_app
 from numpy import ndarray
 from pgvecto_rs.sqlalchemy import Vector
 from pydantic import BaseModel, model_validator
@@ -12,6 +11,7 @@ from sqlalchemy import text as sql_text
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.orm import Mapped, Session, mapped_column
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
 from core.rag.datasource.vdb.vector_base import BaseVector
@@ -93,7 +93,7 @@ class PGVectoRS(BaseVector):
                         text TEXT NOT NULL,
                         meta JSONB NOT NULL,
                         vector vector({dimension}) NOT NULL
-                    ) using heap; 
+                    ) using heap;
                 """)
                 session.execute(create_statement)
                 index_statement = sql_text(f"""
@@ -233,15 +233,15 @@ class PGVectoRSFactory(AbstractVectorFactory):
             dataset.index_struct = json.dumps(
                 self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
         dim = len(embeddings.embed_query("pgvecto_rs"))
-        config = current_app.config
+
         return PGVectoRS(
             collection_name=collection_name,
             config=PgvectoRSConfig(
-                host=config.get('PGVECTO_RS_HOST'),
-                port=config.get('PGVECTO_RS_PORT'),
-                user=config.get('PGVECTO_RS_USER'),
-                password=config.get('PGVECTO_RS_PASSWORD'),
-                database=config.get('PGVECTO_RS_DATABASE'),
+                host=dify_config.PGVECTO_RS_HOST,
+                port=dify_config.PGVECTO_RS_PORT,
+                user=dify_config.PGVECTO_RS_USER,
+                password=dify_config.PGVECTO_RS_PASSWORD,
+                database=dify_config.PGVECTO_RS_DATABASE,
             ),
             dim=dim
-        )
+        )

+ 8 - 9
api/core/rag/datasource/vdb/pgvector/pgvector.py

@@ -5,9 +5,9 @@ from typing import Any
 
 import psycopg2.extras
 import psycopg2.pool
-from flask import current_app
 from pydantic import BaseModel, model_validator
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -45,7 +45,7 @@ CREATE TABLE IF NOT EXISTS {table_name} (
     text TEXT NOT NULL,
     meta JSONB NOT NULL,
     embedding vector({dimension}) NOT NULL
-) using heap; 
+) using heap;
 """
 
 
@@ -185,14 +185,13 @@ class PGVectorFactory(AbstractVectorFactory):
             dataset.index_struct = json.dumps(
                 self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))
 
-        config = current_app.config
         return PGVector(
             collection_name=collection_name,
             config=PGVectorConfig(
-                host=config.get("PGVECTOR_HOST"),
-                port=config.get("PGVECTOR_PORT"),
-                user=config.get("PGVECTOR_USER"),
-                password=config.get("PGVECTOR_PASSWORD"),
-                database=config.get("PGVECTOR_DATABASE"),
+                host=dify_config.PGVECTOR_HOST,
+                port=dify_config.PGVECTOR_PORT,
+                user=dify_config.PGVECTOR_USER,
+                password=dify_config.PGVECTOR_PASSWORD,
+                database=dify_config.PGVECTOR_DATABASE,
             ),
-        )
+        )

+ 6 - 5
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py

@@ -19,6 +19,7 @@ from qdrant_client.http.models import (
 )
 from qdrant_client.local.qdrant_local import QdrantLocal
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
@@ -444,11 +445,11 @@ class QdrantVectorFactory(AbstractVectorFactory):
             collection_name=collection_name,
             group_id=dataset.id,
             config=QdrantConfig(
-                endpoint=config.get('QDRANT_URL'),
-                api_key=config.get('QDRANT_API_KEY'),
+                endpoint=dify_config.QDRANT_URL,
+                api_key=dify_config.QDRANT_API_KEY,
                 root_path=config.root_path,
-                timeout=config.get('QDRANT_CLIENT_TIMEOUT'),
-                grpc_port=config.get('QDRANT_GRPC_PORT'),
-                prefer_grpc=config.get('QDRANT_GRPC_ENABLED')
+                timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
+                grpc_port=dify_config.QDRANT_GRPC_PORT,
+                prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
             )
         )

+ 7 - 8
api/core/rag/datasource/vdb/relyt/relyt_vector.py

@@ -2,7 +2,6 @@ import json
 import uuid
 from typing import Any, Optional
 
-from flask import current_app
 from pydantic import BaseModel, model_validator
 from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
 from sqlalchemy import text as sql_text
@@ -19,6 +18,7 @@ try:
 except ImportError:
     from sqlalchemy.ext.declarative import declarative_base
 
+from configs import dify_config
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
@@ -85,7 +85,7 @@ class RelytVector(BaseVector):
                         document TEXT NOT NULL,
                         metadata JSON NOT NULL,
                         embedding vector({dimension}) NOT NULL
-                    ) using heap; 
+                    ) using heap;
                 """)
                 session.execute(create_statement)
                 index_statement = sql_text(f"""
@@ -313,15 +313,14 @@ class RelytVectorFactory(AbstractVectorFactory):
             dataset.index_struct = json.dumps(
                 self.gen_index_struct_dict(VectorType.RELYT, collection_name))
 
-        config = current_app.config
         return RelytVector(
             collection_name=collection_name,
             config=RelytConfig(
-                host=config.get('RELYT_HOST'),
-                port=config.get('RELYT_PORT'),
-                user=config.get('RELYT_USER'),
-                password=config.get('RELYT_PASSWORD'),
-                database=config.get('RELYT_DATABASE'),
+                host=dify_config.RELYT_HOST,
+                port=dify_config.RELYT_PORT,
+                user=dify_config.RELYT_USER,
+                password=dify_config.RELYT_PASSWORD,
+                database=dify_config.RELYT_DATABASE,
             ),
             group_id=dataset.id
         )

+ 9 - 10
api/core/rag/datasource/vdb/tencent/tencent_vector.py

@@ -1,13 +1,13 @@
 import json
 from typing import Any, Optional
 
-from flask import current_app
 from pydantic import BaseModel
 from tcvectordb import VectorDBClient
 from tcvectordb.model import document, enum
 from tcvectordb.model import index as vdb_index
 from tcvectordb.model.document import Filter
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -212,16 +212,15 @@ class TencentVectorFactory(AbstractVectorFactory):
             dataset.index_struct = json.dumps(
                 self.gen_index_struct_dict(VectorType.TENCENT, collection_name))
 
-        config = current_app.config
         return TencentVector(
             collection_name=collection_name,
             config=TencentConfig(
-                url=config.get('TENCENT_VECTOR_DB_URL'),
-                api_key=config.get('TENCENT_VECTOR_DB_API_KEY'),
-                timeout=config.get('TENCENT_VECTOR_DB_TIMEOUT'),
-                username=config.get('TENCENT_VECTOR_DB_USERNAME'),
-                database=config.get('TENCENT_VECTOR_DB_DATABASE'),
-                shard=config.get('TENCENT_VECTOR_DB_SHARD'),
-                replicas=config.get('TENCENT_VECTOR_DB_REPLICAS'),
+                url=dify_config.TENCENT_VECTOR_DB_URL,
+                api_key=dify_config.TENCENT_VECTOR_DB_API_KEY,
+                timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT,
+                username=dify_config.TENCENT_VECTOR_DB_USERNAME,
+                database=dify_config.TENCENT_VECTOR_DB_DATABASE,
+                shard=dify_config.TENCENT_VECTOR_DB_SHARD,
+                replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS,
             )
-        )
+        )

+ 10 - 11
api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py

@@ -3,12 +3,12 @@ import logging
 from typing import Any
 
 import sqlalchemy
-from flask import current_app
 from pydantic import BaseModel, model_validator
 from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
 from sqlalchemy import text as sql_text
 from sqlalchemy.orm import Session, declarative_base
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -198,8 +198,8 @@ class TiDBVector(BaseVector):
         with Session(self._engine) as session:
             select_statement = sql_text(
                 f"""SELECT meta, text, distance FROM (
-                        SELECT meta, text, {tidb_func}(vector, "{query_vector_str}")  as distance 
-                        FROM {self._collection_name} 
+                        SELECT meta, text, {tidb_func}(vector, "{query_vector_str}")  as distance
+                        FROM {self._collection_name}
                         ORDER BY distance
                         LIMIT {top_k}
                     ) t WHERE distance < {distance};"""
@@ -234,15 +234,14 @@ class TiDBVectorFactory(AbstractVectorFactory):
             dataset.index_struct = json.dumps(
                 self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))
 
-        config = current_app.config
         return TiDBVector(
             collection_name=collection_name,
             config=TiDBVectorConfig(
-                host=config.get('TIDB_VECTOR_HOST'),
-                port=config.get('TIDB_VECTOR_PORT'),
-                user=config.get('TIDB_VECTOR_USER'),
-                password=config.get('TIDB_VECTOR_PASSWORD'),
-                database=config.get('TIDB_VECTOR_DATABASE'),
-                program_name=config.get('APPLICATION_NAME'),
+                host=dify_config.TIDB_VECTOR_HOST,
+                port=dify_config.TIDB_VECTOR_PORT,
+                user=dify_config.TIDB_VECTOR_USER,
+                password=dify_config.TIDB_VECTOR_PASSWORD,
+                database=dify_config.TIDB_VECTOR_DATABASE,
+                program_name=dify_config.APPLICATION_NAME,
             ),
-        )
+        )

+ 2 - 4
api/core/rag/datasource/vdb/vector_factory.py

@@ -1,8 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Any
 
-from flask import current_app
-
+from configs import dify_config
 from core.embedding.cached_embedding import CacheEmbedding
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
@@ -37,8 +36,7 @@ class Vector:
         self._vector_processor = self._init_vector()
 
     def _init_vector(self) -> BaseVector:
-        config = current_app.config
-        vector_type = config.get('VECTOR_STORE')
+        vector_type = dify_config.VECTOR_STORE
         if self._dataset.index_struct_dict:
             vector_type = self._dataset.index_struct_dict['type']
 

+ 4 - 4
api/core/rag/datasource/vdb/weaviate/weaviate_vector.py

@@ -4,9 +4,9 @@ from typing import Any, Optional
 
 import requests
 import weaviate
-from flask import current_app
 from pydantic import BaseModel, model_validator
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
@@ -281,9 +281,9 @@ class WeaviateVectorFactory(AbstractVectorFactory):
         return WeaviateVector(
             collection_name=collection_name,
             config=WeaviateConfig(
-                endpoint=current_app.config.get('WEAVIATE_ENDPOINT'),
-                api_key=current_app.config.get('WEAVIATE_API_KEY'),
-                batch_size=int(current_app.config.get('WEAVIATE_BATCH_SIZE'))
+                endpoint=dify_config.WEAVIATE_ENDPOINT,
+                api_key=dify_config.WEAVIATE_API_KEY,
+                batch_size=dify_config.WEAVIATE_BATCH_SIZE
             ),
             attributes=attributes
         )

+ 4 - 4
api/core/rag/extractor/extract_processor.py

@@ -5,8 +5,8 @@ from typing import Union
 from urllib.parse import unquote
 
 import requests
-from flask import current_app
 
+from configs import dify_config
 from core.rag.extractor.csv_extractor import CSVExtractor
 from core.rag.extractor.entity.datasource_type import DatasourceType
 from core.rag.extractor.entity.extract_setting import ExtractSetting
@@ -94,9 +94,9 @@ class ExtractProcessor:
                     storage.download(upload_file.key, file_path)
                 input_file = Path(file_path)
                 file_extension = input_file.suffix.lower()
-                etl_type = current_app.config['ETL_TYPE']
-                unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
-                unstructured_api_key = current_app.config['UNSTRUCTURED_API_KEY']
+                etl_type = dify_config.ETL_TYPE
+                unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
+                unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
                 if etl_type == 'Unstructured':
                     if file_extension == '.xlsx' or file_extension == '.xls':
                         extractor = ExcelExtractor(file_path)

+ 2 - 2
api/core/rag/extractor/notion_extractor.py

@@ -3,8 +3,8 @@ import logging
 from typing import Any, Optional
 
 import requests
-from flask import current_app
 
+from configs import dify_config
 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
 from extensions.ext_database import db
@@ -49,7 +49,7 @@ class NotionExtractor(BaseExtractor):
             self._notion_access_token = self._get_access_token(tenant_id,
                                                                self._notion_workspace_id)
             if not self._notion_access_token:
-                integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
+                integration_token = dify_config.NOTION_INTEGRATION_TOKEN
                 if integration_token is None:
                     raise ValueError(
                         "Must specify `integration_token` or set environment "

+ 3 - 4
api/core/rag/extractor/word_extractor.py

@@ -8,8 +8,8 @@ from urllib.parse import urlparse
 
 import requests
 from docx import Document as DocxDocument
-from flask import current_app
 
+from configs import dify_config
 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
 from extensions.ext_database import db
@@ -96,10 +96,9 @@ class WordExtractor(BaseExtractor):
 
                     storage.save(file_key, rel.target_part.blob)
                 # save file to db
-                config = current_app.config
                 upload_file = UploadFile(
                     tenant_id=self.tenant_id,
-                    storage_type=config['STORAGE_TYPE'],
+                    storage_type=dify_config.STORAGE_TYPE,
                     key=file_key,
                     name=file_key,
                     size=0,
@@ -114,7 +113,7 @@ class WordExtractor(BaseExtractor):
 
                 db.session.add(upload_file)
                 db.session.commit()
-                image_map[rel.target_part] = f"![image]({current_app.config.get('CONSOLE_API_URL')}/files/{upload_file.id}/image-preview)"
+                image_map[rel.target_part] = f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)"
 
         return image_map
 

+ 2 - 3
api/core/rag/index_processor/index_processor_base.py

@@ -2,8 +2,7 @@
 from abc import ABC, abstractmethod
 from typing import Optional
 
-from flask import current_app
-
+from configs import dify_config
 from core.model_manager import ModelInstance
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.models.document import Document
@@ -48,7 +47,7 @@ class BaseIndexProcessor(ABC):
             # The user-defined segmentation rule
             rules = processing_rule['rules']
             segmentation = rules["segmentation"]
-            max_segmentation_tokens_length = int(current_app.config['INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH'])
+            max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
             if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
                 raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")