Browse Source

refactor(rag): switch to dify_config. (#6410)

Co-authored-by: -LAN- <laipz8200@outlook.com>
Poorandy 9 months ago
parent
commit
c8f5dfcf17

+ 19 - 18
api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py

@@ -7,8 +7,8 @@ _import_err_msg = (
     "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
     "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
     "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
     "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
 )
 )
-from flask import current_app
 
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -36,7 +36,7 @@ class AnalyticdbConfig(BaseModel):
             "region_id": self.region_id,
             "region_id": self.region_id,
             "read_timeout": self.read_timeout,
             "read_timeout": self.read_timeout,
         }
         }
-    
+
 class AnalyticdbVector(BaseVector):
 class AnalyticdbVector(BaseVector):
     _instance = None
     _instance = None
     _init = False
     _init = False
@@ -45,7 +45,7 @@ class AnalyticdbVector(BaseVector):
         if cls._instance is None:
         if cls._instance is None:
             cls._instance = super().__new__(cls)
             cls._instance = super().__new__(cls)
         return cls._instance
         return cls._instance
-    
+
     def __init__(self, collection_name: str, config: AnalyticdbConfig):
     def __init__(self, collection_name: str, config: AnalyticdbConfig):
         # collection_name must be updated every time
         # collection_name must be updated every time
         self._collection_name = collection_name.lower()
         self._collection_name = collection_name.lower()
@@ -105,7 +105,7 @@ class AnalyticdbVector(BaseVector):
                 raise ValueError(
                 raise ValueError(
                     f"failed to create namespace {self.config.namespace}: {e}"
                     f"failed to create namespace {self.config.namespace}: {e}"
                 )
                 )
-            
+
     def _create_collection_if_not_exists(self, embedding_dimension: int):
     def _create_collection_if_not_exists(self, embedding_dimension: int):
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
         from Tea.exceptions import TeaException
         from Tea.exceptions import TeaException
@@ -149,7 +149,7 @@ class AnalyticdbVector(BaseVector):
 
 
     def get_type(self) -> str:
     def get_type(self) -> str:
         return VectorType.ANALYTICDB
         return VectorType.ANALYTICDB
-    
+
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
         dimension = len(embeddings[0])
         dimension = len(embeddings[0])
         self._create_collection_if_not_exists(dimension)
         self._create_collection_if_not_exists(dimension)
@@ -199,7 +199,7 @@ class AnalyticdbVector(BaseVector):
         )
         )
         response = self._client.query_collection_data(request)
         response = self._client.query_collection_data(request)
         return len(response.body.matches.match) > 0
         return len(response.body.matches.match) > 0
-    
+
     def delete_by_ids(self, ids: list[str]) -> None:
     def delete_by_ids(self, ids: list[str]) -> None:
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
         ids_str = ",".join(f"'{id}'" for id in ids)
         ids_str = ",".join(f"'{id}'" for id in ids)
@@ -260,7 +260,7 @@ class AnalyticdbVector(BaseVector):
                 )
                 )
                 documents.append(doc)
                 documents.append(doc)
         return documents
         return documents
-    
+
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
         score_threshold = (
         score_threshold = (
@@ -291,7 +291,7 @@ class AnalyticdbVector(BaseVector):
                 )
                 )
                 documents.append(doc)
                 documents.append(doc)
         return documents
         return documents
-    
+
     def delete(self) -> None:
     def delete(self) -> None:
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
         request = gpdb_20160503_models.DeleteCollectionRequest(
         request = gpdb_20160503_models.DeleteCollectionRequest(
@@ -316,17 +316,18 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
             dataset.index_struct = json.dumps(
             dataset.index_struct = json.dumps(
                 self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)
                 self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)
             )
             )
-        config = current_app.config
+
+        # TODO: handle optional config params
         return AnalyticdbVector(
         return AnalyticdbVector(
             collection_name,
             collection_name,
             AnalyticdbConfig(
             AnalyticdbConfig(
-                access_key_id=config.get("ANALYTICDB_KEY_ID"),
-                access_key_secret=config.get("ANALYTICDB_KEY_SECRET"),
-                region_id=config.get("ANALYTICDB_REGION_ID"),
-                instance_id=config.get("ANALYTICDB_INSTANCE_ID"),
-                account=config.get("ANALYTICDB_ACCOUNT"),
-                account_password=config.get("ANALYTICDB_PASSWORD"),
-                namespace=config.get("ANALYTICDB_NAMESPACE"),
-                namespace_password=config.get("ANALYTICDB_NAMESPACE_PASSWORD"),
+                access_key_id=dify_config.ANALYTICDB_KEY_ID,
+                access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
+                region_id=dify_config.ANALYTICDB_REGION_ID,
+                instance_id=dify_config.ANALYTICDB_INSTANCE_ID,
+                account=dify_config.ANALYTICDB_ACCOUNT,
+                account_password=dify_config.ANALYTICDB_PASSWORD,
+                namespace=dify_config.ANALYTICDB_NAMESPACE,
+                namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
             ),
             ),
-        )
+        )

+ 7 - 8
api/core/rag/datasource/vdb/chroma/chroma_vector.py

@@ -3,9 +3,9 @@ from typing import Any, Optional
 
 
 import chromadb
 import chromadb
 from chromadb import QueryResult, Settings
 from chromadb import QueryResult, Settings
-from flask import current_app
 from pydantic import BaseModel
 from pydantic import BaseModel
 
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -133,15 +133,14 @@ class ChromaVectorFactory(AbstractVectorFactory):
             }
             }
             dataset.index_struct = json.dumps(index_struct_dict)
             dataset.index_struct = json.dumps(index_struct_dict)
 
 
-        config = current_app.config
         return ChromaVector(
         return ChromaVector(
             collection_name=collection_name,
             collection_name=collection_name,
             config=ChromaConfig(
             config=ChromaConfig(
-                host=config.get('CHROMA_HOST'),
-                port=int(config.get('CHROMA_PORT')),
-                tenant=config.get('CHROMA_TENANT', chromadb.DEFAULT_TENANT),
-                database=config.get('CHROMA_DATABASE', chromadb.DEFAULT_DATABASE),
-                auth_provider=config.get('CHROMA_AUTH_PROVIDER'),
-                auth_credentials=config.get('CHROMA_AUTH_CREDENTIALS'),
+                host=dify_config.CHROMA_HOST,
+                port=dify_config.CHROMA_PORT,
+                tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
+                database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
+                auth_provider=dify_config.CHROMA_AUTH_PROVIDER,
+                auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS,
             ),
             ),
         )
         )

+ 7 - 8
api/core/rag/datasource/vdb/milvus/milvus_vector.py

@@ -3,10 +3,10 @@ import logging
 from typing import Any, Optional
 from typing import Any, Optional
 from uuid import uuid4
 from uuid import uuid4
 
 
-from flask import current_app
 from pydantic import BaseModel, model_validator
 from pydantic import BaseModel, model_validator
 from pymilvus import MilvusClient, MilvusException, connections
 from pymilvus import MilvusClient, MilvusException, connections
 
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_base import BaseVector
@@ -275,15 +275,14 @@ class MilvusVectorFactory(AbstractVectorFactory):
             dataset.index_struct = json.dumps(
             dataset.index_struct = json.dumps(
                 self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
                 self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
 
 
-        config = current_app.config
         return MilvusVector(
         return MilvusVector(
             collection_name=collection_name,
             collection_name=collection_name,
             config=MilvusConfig(
             config=MilvusConfig(
-                host=config.get('MILVUS_HOST'),
-                port=config.get('MILVUS_PORT'),
-                user=config.get('MILVUS_USER'),
-                password=config.get('MILVUS_PASSWORD'),
-                secure=config.get('MILVUS_SECURE'),
-                database=config.get('MILVUS_DATABASE'),
+                host=dify_config.MILVUS_HOST,
+                port=dify_config.MILVUS_PORT,
+                user=dify_config.MILVUS_USER,
+                password=dify_config.MILVUS_PASSWORD,
+                secure=dify_config.MILVUS_SECURE,
+                database=dify_config.MILVUS_DATABASE,
             )
             )
         )
         )

+ 8 - 8
api/core/rag/datasource/vdb/myscale/myscale_vector.py

@@ -5,9 +5,9 @@ from enum import Enum
 from typing import Any
 from typing import Any
 
 
 from clickhouse_connect import get_client
 from clickhouse_connect import get_client
-from flask import current_app
 from pydantic import BaseModel
 from pydantic import BaseModel
 
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -156,15 +156,15 @@ class MyScaleVectorFactory(AbstractVectorFactory):
             dataset.index_struct = json.dumps(
             dataset.index_struct = json.dumps(
                 self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
                 self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
 
 
-        config = current_app.config
         return MyScaleVector(
         return MyScaleVector(
             collection_name=collection_name,
             collection_name=collection_name,
             config=MyScaleConfig(
             config=MyScaleConfig(
-                host=config.get("MYSCALE_HOST", "localhost"),
-                port=int(config.get("MYSCALE_PORT", 8123)),
-                user=config.get("MYSCALE_USER", "default"),
-                password=config.get("MYSCALE_PASSWORD", ""),
-                database=config.get("MYSCALE_DATABASE", "default"),
-                fts_params=config.get("MYSCALE_FTS_PARAMS", ""),
+                # TODO: consider moving these fallback values into the default config instead of applying them here.
+                host=dify_config.MYSCALE_HOST or "localhost",
+                port=dify_config.MYSCALE_PORT or 8123,
+                user=dify_config.MYSCALE_USER or "default",
+                password=dify_config.MYSCALE_PASSWORD or "",
+                database=dify_config.MYSCALE_DATABASE or "default",
+                fts_params=dify_config.MYSCALE_FTS_PARAMS or "",
             ),
             ),
         )
         )

+ 6 - 7
api/core/rag/datasource/vdb/opensearch/opensearch_vector.py

@@ -4,11 +4,11 @@ import ssl
 from typing import Any, Optional
 from typing import Any, Optional
 from uuid import uuid4
 from uuid import uuid4
 
 
-from flask import current_app
 from opensearchpy import OpenSearch, helpers
 from opensearchpy import OpenSearch, helpers
 from opensearchpy.helpers import BulkIndexError
 from opensearchpy.helpers import BulkIndexError
 from pydantic import BaseModel, model_validator
 from pydantic import BaseModel, model_validator
 
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_base import BaseVector
@@ -257,14 +257,13 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
             dataset.index_struct = json.dumps(
             dataset.index_struct = json.dumps(
                 self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
                 self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
 
 
-        config = current_app.config
 
 
         open_search_config = OpenSearchConfig(
         open_search_config = OpenSearchConfig(
-            host=config.get('OPENSEARCH_HOST'),
-            port=config.get('OPENSEARCH_PORT'),
-            user=config.get('OPENSEARCH_USER'),
-            password=config.get('OPENSEARCH_PASSWORD'),
-            secure=config.get('OPENSEARCH_SECURE'),
+            host=dify_config.OPENSEARCH_HOST,
+            port=dify_config.OPENSEARCH_PORT,
+            user=dify_config.OPENSEARCH_USER,
+            password=dify_config.OPENSEARCH_PASSWORD,
+            secure=dify_config.OPENSEARCH_SECURE,
         )
         )
 
 
         return OpenSearchVector(
         return OpenSearchVector(

+ 8 - 9
api/core/rag/datasource/vdb/oracle/oraclevector.py

@@ -6,9 +6,9 @@ from typing import Any
 
 
 import numpy
 import numpy
 import oracledb
 import oracledb
-from flask import current_app
 from pydantic import BaseModel, model_validator
 from pydantic import BaseModel, model_validator
 
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -44,11 +44,11 @@ class OracleVectorConfig(BaseModel):
 
 
 SQL_CREATE_TABLE = """
 SQL_CREATE_TABLE = """
 CREATE TABLE IF NOT EXISTS {table_name} (
 CREATE TABLE IF NOT EXISTS {table_name} (
-    id varchar2(100) 
+    id varchar2(100)
     ,text CLOB NOT NULL
     ,text CLOB NOT NULL
     ,meta JSON
     ,meta JSON
     ,embedding vector NOT NULL
     ,embedding vector NOT NULL
-) 
+)
 """
 """
 
 
 
 
@@ -219,14 +219,13 @@ class OracleVectorFactory(AbstractVectorFactory):
             dataset.index_struct = json.dumps(
             dataset.index_struct = json.dumps(
                 self.gen_index_struct_dict(VectorType.ORACLE, collection_name))
                 self.gen_index_struct_dict(VectorType.ORACLE, collection_name))
 
 
-        config = current_app.config
         return OracleVector(
         return OracleVector(
             collection_name=collection_name,
             collection_name=collection_name,
             config=OracleVectorConfig(
             config=OracleVectorConfig(
-                host=config.get("ORACLE_HOST"),
-                port=config.get("ORACLE_PORT"),
-                user=config.get("ORACLE_USER"),
-                password=config.get("ORACLE_PASSWORD"),
-                database=config.get("ORACLE_DATABASE"),
+                host=dify_config.ORACLE_HOST,
+                port=dify_config.ORACLE_PORT,
+                user=dify_config.ORACLE_USER,
+                password=dify_config.ORACLE_PASSWORD,
+                database=dify_config.ORACLE_DATABASE,
             ),
             ),
         )
         )

+ 9 - 9
api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py

@@ -3,7 +3,6 @@ import logging
 from typing import Any
 from typing import Any
 from uuid import UUID, uuid4
 from uuid import UUID, uuid4
 
 
-from flask import current_app
 from numpy import ndarray
 from numpy import ndarray
 from pgvecto_rs.sqlalchemy import Vector
 from pgvecto_rs.sqlalchemy import Vector
 from pydantic import BaseModel, model_validator
 from pydantic import BaseModel, model_validator
@@ -12,6 +11,7 @@ from sqlalchemy import text as sql_text
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.orm import Mapped, Session, mapped_column
 from sqlalchemy.orm import Mapped, Session, mapped_column
 
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
 from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_base import BaseVector
@@ -93,7 +93,7 @@ class PGVectoRS(BaseVector):
                         text TEXT NOT NULL,
                         text TEXT NOT NULL,
                         meta JSONB NOT NULL,
                         meta JSONB NOT NULL,
                         vector vector({dimension}) NOT NULL
                         vector vector({dimension}) NOT NULL
-                    ) using heap; 
+                    ) using heap;
                 """)
                 """)
                 session.execute(create_statement)
                 session.execute(create_statement)
                 index_statement = sql_text(f"""
                 index_statement = sql_text(f"""
@@ -233,15 +233,15 @@ class PGVectoRSFactory(AbstractVectorFactory):
             dataset.index_struct = json.dumps(
             dataset.index_struct = json.dumps(
                 self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
                 self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
         dim = len(embeddings.embed_query("pgvecto_rs"))
         dim = len(embeddings.embed_query("pgvecto_rs"))
-        config = current_app.config
+
         return PGVectoRS(
         return PGVectoRS(
             collection_name=collection_name,
             collection_name=collection_name,
             config=PgvectoRSConfig(
             config=PgvectoRSConfig(
-                host=config.get('PGVECTO_RS_HOST'),
-                port=config.get('PGVECTO_RS_PORT'),
-                user=config.get('PGVECTO_RS_USER'),
-                password=config.get('PGVECTO_RS_PASSWORD'),
-                database=config.get('PGVECTO_RS_DATABASE'),
+                host=dify_config.PGVECTO_RS_HOST,
+                port=dify_config.PGVECTO_RS_PORT,
+                user=dify_config.PGVECTO_RS_USER,
+                password=dify_config.PGVECTO_RS_PASSWORD,
+                database=dify_config.PGVECTO_RS_DATABASE,
             ),
             ),
             dim=dim
             dim=dim
-        )
+        )

+ 8 - 9
api/core/rag/datasource/vdb/pgvector/pgvector.py

@@ -5,9 +5,9 @@ from typing import Any
 
 
 import psycopg2.extras
 import psycopg2.extras
 import psycopg2.pool
 import psycopg2.pool
-from flask import current_app
 from pydantic import BaseModel, model_validator
 from pydantic import BaseModel, model_validator
 
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -45,7 +45,7 @@ CREATE TABLE IF NOT EXISTS {table_name} (
     text TEXT NOT NULL,
     text TEXT NOT NULL,
     meta JSONB NOT NULL,
     meta JSONB NOT NULL,
     embedding vector({dimension}) NOT NULL
     embedding vector({dimension}) NOT NULL
-) using heap; 
+) using heap;
 """
 """
 
 
 
 
@@ -185,14 +185,13 @@ class PGVectorFactory(AbstractVectorFactory):
             dataset.index_struct = json.dumps(
             dataset.index_struct = json.dumps(
                 self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))
                 self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))
 
 
-        config = current_app.config
         return PGVector(
         return PGVector(
             collection_name=collection_name,
             collection_name=collection_name,
             config=PGVectorConfig(
             config=PGVectorConfig(
-                host=config.get("PGVECTOR_HOST"),
-                port=config.get("PGVECTOR_PORT"),
-                user=config.get("PGVECTOR_USER"),
-                password=config.get("PGVECTOR_PASSWORD"),
-                database=config.get("PGVECTOR_DATABASE"),
+                host=dify_config.PGVECTOR_HOST,
+                port=dify_config.PGVECTOR_PORT,
+                user=dify_config.PGVECTOR_USER,
+                password=dify_config.PGVECTOR_PASSWORD,
+                database=dify_config.PGVECTOR_DATABASE,
             ),
             ),
-        )
+        )

+ 6 - 5
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py

@@ -19,6 +19,7 @@ from qdrant_client.http.models import (
 )
 )
 from qdrant_client.local.qdrant_local import QdrantLocal
 from qdrant_client.local.qdrant_local import QdrantLocal
 
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_base import BaseVector
@@ -444,11 +445,11 @@ class QdrantVectorFactory(AbstractVectorFactory):
             collection_name=collection_name,
             collection_name=collection_name,
             group_id=dataset.id,
             group_id=dataset.id,
             config=QdrantConfig(
             config=QdrantConfig(
-                endpoint=config.get('QDRANT_URL'),
-                api_key=config.get('QDRANT_API_KEY'),
+                endpoint=dify_config.QDRANT_URL,
+                api_key=dify_config.QDRANT_API_KEY,
                 root_path=config.root_path,
                 root_path=config.root_path,
-                timeout=config.get('QDRANT_CLIENT_TIMEOUT'),
-                grpc_port=config.get('QDRANT_GRPC_PORT'),
-                prefer_grpc=config.get('QDRANT_GRPC_ENABLED')
+                timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
+                grpc_port=dify_config.QDRANT_GRPC_PORT,
+                prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
             )
             )
         )
         )

+ 7 - 8
api/core/rag/datasource/vdb/relyt/relyt_vector.py

@@ -2,7 +2,6 @@ import json
 import uuid
 import uuid
 from typing import Any, Optional
 from typing import Any, Optional
 
 
-from flask import current_app
 from pydantic import BaseModel, model_validator
 from pydantic import BaseModel, model_validator
 from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
 from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
 from sqlalchemy import text as sql_text
 from sqlalchemy import text as sql_text
@@ -19,6 +18,7 @@ try:
 except ImportError:
 except ImportError:
     from sqlalchemy.ext.declarative import declarative_base
     from sqlalchemy.ext.declarative import declarative_base
 
 
+from configs import dify_config
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.models.document import Document
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from extensions.ext_redis import redis_client
@@ -85,7 +85,7 @@ class RelytVector(BaseVector):
                         document TEXT NOT NULL,
                         document TEXT NOT NULL,
                         metadata JSON NOT NULL,
                         metadata JSON NOT NULL,
                         embedding vector({dimension}) NOT NULL
                         embedding vector({dimension}) NOT NULL
-                    ) using heap; 
+                    ) using heap;
                 """)
                 """)
                 session.execute(create_statement)
                 session.execute(create_statement)
                 index_statement = sql_text(f"""
                 index_statement = sql_text(f"""
@@ -313,15 +313,14 @@ class RelytVectorFactory(AbstractVectorFactory):
             dataset.index_struct = json.dumps(
             dataset.index_struct = json.dumps(
                 self.gen_index_struct_dict(VectorType.RELYT, collection_name))
                 self.gen_index_struct_dict(VectorType.RELYT, collection_name))
 
 
-        config = current_app.config
         return RelytVector(
         return RelytVector(
             collection_name=collection_name,
             collection_name=collection_name,
             config=RelytConfig(
             config=RelytConfig(
-                host=config.get('RELYT_HOST'),
-                port=config.get('RELYT_PORT'),
-                user=config.get('RELYT_USER'),
-                password=config.get('RELYT_PASSWORD'),
-                database=config.get('RELYT_DATABASE'),
+                host=dify_config.RELYT_HOST,
+                port=dify_config.RELYT_PORT,
+                user=dify_config.RELYT_USER,
+                password=dify_config.RELYT_PASSWORD,
+                database=dify_config.RELYT_DATABASE,
             ),
             ),
             group_id=dataset.id
             group_id=dataset.id
         )
         )

+ 9 - 10
api/core/rag/datasource/vdb/tencent/tencent_vector.py

@@ -1,13 +1,13 @@
 import json
 import json
 from typing import Any, Optional
 from typing import Any, Optional
 
 
-from flask import current_app
 from pydantic import BaseModel
 from pydantic import BaseModel
 from tcvectordb import VectorDBClient
 from tcvectordb import VectorDBClient
 from tcvectordb.model import document, enum
 from tcvectordb.model import document, enum
 from tcvectordb.model import index as vdb_index
 from tcvectordb.model import index as vdb_index
 from tcvectordb.model.document import Filter
 from tcvectordb.model.document import Filter
 
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -212,16 +212,15 @@ class TencentVectorFactory(AbstractVectorFactory):
             dataset.index_struct = json.dumps(
             dataset.index_struct = json.dumps(
                 self.gen_index_struct_dict(VectorType.TENCENT, collection_name))
                 self.gen_index_struct_dict(VectorType.TENCENT, collection_name))
 
 
-        config = current_app.config
         return TencentVector(
         return TencentVector(
             collection_name=collection_name,
             collection_name=collection_name,
             config=TencentConfig(
             config=TencentConfig(
-                url=config.get('TENCENT_VECTOR_DB_URL'),
-                api_key=config.get('TENCENT_VECTOR_DB_API_KEY'),
-                timeout=config.get('TENCENT_VECTOR_DB_TIMEOUT'),
-                username=config.get('TENCENT_VECTOR_DB_USERNAME'),
-                database=config.get('TENCENT_VECTOR_DB_DATABASE'),
-                shard=config.get('TENCENT_VECTOR_DB_SHARD'),
-                replicas=config.get('TENCENT_VECTOR_DB_REPLICAS'),
+                url=dify_config.TENCENT_VECTOR_DB_URL,
+                api_key=dify_config.TENCENT_VECTOR_DB_API_KEY,
+                timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT,
+                username=dify_config.TENCENT_VECTOR_DB_USERNAME,
+                database=dify_config.TENCENT_VECTOR_DB_DATABASE,
+                shard=dify_config.TENCENT_VECTOR_DB_SHARD,
+                replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS,
             )
             )
-        )
+        )

+ 10 - 11
api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py

@@ -3,12 +3,12 @@ import logging
 from typing import Any
 from typing import Any
 
 
 import sqlalchemy
 import sqlalchemy
-from flask import current_app
 from pydantic import BaseModel, model_validator
 from pydantic import BaseModel, model_validator
 from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
 from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
 from sqlalchemy import text as sql_text
 from sqlalchemy import text as sql_text
 from sqlalchemy.orm import Session, declarative_base
 from sqlalchemy.orm import Session, declarative_base
 
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -198,8 +198,8 @@ class TiDBVector(BaseVector):
         with Session(self._engine) as session:
             select_statement = sql_text(
                 f"""SELECT meta, text, distance FROM (
-                        SELECT meta, text, {tidb_func}(vector, "{query_vector_str}")  as distance 
-                        FROM {self._collection_name} 
+                        SELECT meta, text, {tidb_func}(vector, "{query_vector_str}")  as distance
+                        FROM {self._collection_name}
                         ORDER BY distance
                         LIMIT {top_k}
                     ) t WHERE distance < {distance};"""
@@ -234,15 +234,14 @@ class TiDBVectorFactory(AbstractVectorFactory):
             dataset.index_struct = json.dumps(
                 self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))
 
-        config = current_app.config
         return TiDBVector(
             collection_name=collection_name,
             config=TiDBVectorConfig(
-                host=config.get('TIDB_VECTOR_HOST'),
-                port=config.get('TIDB_VECTOR_PORT'),
-                user=config.get('TIDB_VECTOR_USER'),
-                password=config.get('TIDB_VECTOR_PASSWORD'),
-                database=config.get('TIDB_VECTOR_DATABASE'),
-                program_name=config.get('APPLICATION_NAME'),
+                host=dify_config.TIDB_VECTOR_HOST,
+                port=dify_config.TIDB_VECTOR_PORT,
+                user=dify_config.TIDB_VECTOR_USER,
+                password=dify_config.TIDB_VECTOR_PASSWORD,
+                database=dify_config.TIDB_VECTOR_DATABASE,
+                program_name=dify_config.APPLICATION_NAME,
             ),
-        )
+        )

+ 2 - 4
api/core/rag/datasource/vdb/vector_factory.py

@@ -1,8 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Any
 
-from flask import current_app
-
+from configs import dify_config
 from core.embedding.cached_embedding import CacheEmbedding
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
@@ -37,8 +36,7 @@ class Vector:
         self._vector_processor = self._init_vector()
 
     def _init_vector(self) -> BaseVector:
-        config = current_app.config
-        vector_type = config.get('VECTOR_STORE')
+        vector_type = dify_config.VECTOR_STORE
         if self._dataset.index_struct_dict:
             vector_type = self._dataset.index_struct_dict['type']
 

+ 4 - 4
api/core/rag/datasource/vdb/weaviate/weaviate_vector.py

@@ -4,9 +4,9 @@ from typing import Any, Optional
 
 import requests
 import weaviate
-from flask import current_app
 from pydantic import BaseModel, model_validator
 
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
@@ -281,9 +281,9 @@ class WeaviateVectorFactory(AbstractVectorFactory):
         return WeaviateVector(
             collection_name=collection_name,
             config=WeaviateConfig(
-                endpoint=current_app.config.get('WEAVIATE_ENDPOINT'),
-                api_key=current_app.config.get('WEAVIATE_API_KEY'),
-                batch_size=int(current_app.config.get('WEAVIATE_BATCH_SIZE'))
+                endpoint=dify_config.WEAVIATE_ENDPOINT,
+                api_key=dify_config.WEAVIATE_API_KEY,
+                batch_size=dify_config.WEAVIATE_BATCH_SIZE
             ),
             attributes=attributes
         )

+ 4 - 4
api/core/rag/extractor/extract_processor.py

@@ -5,8 +5,8 @@ from typing import Union
 from urllib.parse import unquote
 
 import requests
-from flask import current_app
 
+from configs import dify_config
 from core.rag.extractor.csv_extractor import CSVExtractor
 from core.rag.extractor.entity.datasource_type import DatasourceType
 from core.rag.extractor.entity.extract_setting import ExtractSetting
@@ -94,9 +94,9 @@ class ExtractProcessor:
                     storage.download(upload_file.key, file_path)
                 input_file = Path(file_path)
                 file_extension = input_file.suffix.lower()
-                etl_type = current_app.config['ETL_TYPE']
-                unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
-                unstructured_api_key = current_app.config['UNSTRUCTURED_API_KEY']
+                etl_type = dify_config.ETL_TYPE
+                unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
+                unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
                 if etl_type == 'Unstructured':
                     if file_extension == '.xlsx' or file_extension == '.xls':
                         extractor = ExcelExtractor(file_path)

+ 2 - 2
api/core/rag/extractor/notion_extractor.py

@@ -3,8 +3,8 @@ import logging
 from typing import Any, Optional
 
 import requests
-from flask import current_app
 
+from configs import dify_config
 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
 from extensions.ext_database import db
@@ -49,7 +49,7 @@ class NotionExtractor(BaseExtractor):
             self._notion_access_token = self._get_access_token(tenant_id,
                                                                self._notion_workspace_id)
             if not self._notion_access_token:
-                integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
+                integration_token = dify_config.NOTION_INTEGRATION_TOKEN
                 if integration_token is None:
                     raise ValueError(
                         "Must specify `integration_token` or set environment "

+ 3 - 4
api/core/rag/extractor/word_extractor.py

@@ -8,8 +8,8 @@ from urllib.parse import urlparse
 
 import requests
 from docx import Document as DocxDocument
-from flask import current_app
 
+from configs import dify_config
 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
 from extensions.ext_database import db
@@ -96,10 +96,9 @@ class WordExtractor(BaseExtractor):
 
                     storage.save(file_key, rel.target_part.blob)
                 # save file to db
-                config = current_app.config
                 upload_file = UploadFile(
                     tenant_id=self.tenant_id,
-                    storage_type=config['STORAGE_TYPE'],
+                    storage_type=dify_config.STORAGE_TYPE,
                     key=file_key,
                     name=file_key,
                     size=0,
@@ -114,7 +113,7 @@ class WordExtractor(BaseExtractor):
 
                 db.session.add(upload_file)
                 db.session.commit()
-                image_map[rel.target_part] = f"![image]({current_app.config.get('CONSOLE_API_URL')}/files/{upload_file.id}/image-preview)"
+                image_map[rel.target_part] = f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)"
 
         return image_map
 

+ 2 - 3
api/core/rag/index_processor/index_processor_base.py

@@ -2,8 +2,7 @@
 from abc import ABC, abstractmethod
 from typing import Optional
 
-from flask import current_app
-
+from configs import dify_config
 from core.model_manager import ModelInstance
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.models.document import Document
@@ -48,7 +47,7 @@ class BaseIndexProcessor(ABC):
             # The user-defined segmentation rule
             rules = processing_rule['rules']
             segmentation = rules["segmentation"]
-            max_segmentation_tokens_length = int(current_app.config['INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH'])
+            max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
             if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
                 raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")