Forráskód Böngészése

refactor: extract vdb configs into pydantic-setting based dify configs (#5426)

Bowen Liang 10 hónapja
szülő
commit
65d34ebb96

+ 0 - 84
api/config.py

@@ -17,11 +17,6 @@ DEFAULTS = {
     'SQLALCHEMY_POOL_RECYCLE': 3600,
     'SQLALCHEMY_POOL_PRE_PING': 'False',
     'SQLALCHEMY_ECHO': 'False',
-    'WEAVIATE_GRPC_ENABLED': 'True',
-    'WEAVIATE_BATCH_SIZE': 100,
-    'QDRANT_CLIENT_TIMEOUT': 20,
-    'QDRANT_GRPC_ENABLED': 'False',
-    'QDRANT_GRPC_PORT': '6334',
     'CELERY_BACKEND': 'database',
     'HOSTED_OPENAI_QUOTA_LIMIT': 200,
     'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
@@ -37,7 +32,6 @@ DEFAULTS = {
     'HOSTED_MODERATION_PROVIDERS': '',
     'HOSTED_FETCH_APP_TEMPLATES_MODE': 'remote',
     'HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN': 'https://tmpl.dify.ai',
-    'MILVUS_DATABASE': 'default',
 }
 
 
@@ -141,84 +135,6 @@ class Config:
         self.TENCENT_COS_SECRET_KEY = get_env('TENCENT_COS_SECRET_KEY')
         self.TENCENT_COS_SCHEME = get_env('TENCENT_COS_SCHEME')
 
-        # ------------------------
-        # Vector Store Configurations.
-        # Currently, only support: qdrant, milvus, zilliz, weaviate, relyt, pgvector
-        # ------------------------
-
-        # qdrant settings
-        self.QDRANT_URL = get_env('QDRANT_URL')
-        self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
-        self.QDRANT_CLIENT_TIMEOUT = get_env('QDRANT_CLIENT_TIMEOUT')
-        self.QDRANT_GRPC_ENABLED = get_env('QDRANT_GRPC_ENABLED')
-        self.QDRANT_GRPC_PORT = get_env('QDRANT_GRPC_PORT')
-
-        # milvus / zilliz setting
-        self.MILVUS_HOST = get_env('MILVUS_HOST')
-        self.MILVUS_PORT = get_env('MILVUS_PORT')
-        self.MILVUS_USER = get_env('MILVUS_USER')
-        self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
-        self.MILVUS_SECURE = get_env('MILVUS_SECURE')
-        self.MILVUS_DATABASE = get_env('MILVUS_DATABASE')
-
-        # OpenSearch settings
-        self.OPENSEARCH_HOST = get_env('OPENSEARCH_HOST')
-        self.OPENSEARCH_PORT = get_env('OPENSEARCH_PORT')
-        self.OPENSEARCH_USER = get_env('OPENSEARCH_USER')
-        self.OPENSEARCH_PASSWORD = get_env('OPENSEARCH_PASSWORD')
-        self.OPENSEARCH_SECURE = get_bool_env('OPENSEARCH_SECURE')
-
-        # weaviate settings
-        self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
-        self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY')
-        self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED')
-        self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE'))
-
-        # relyt settings
-        self.RELYT_HOST = get_env('RELYT_HOST')
-        self.RELYT_PORT = get_env('RELYT_PORT')
-        self.RELYT_USER = get_env('RELYT_USER')
-        self.RELYT_PASSWORD = get_env('RELYT_PASSWORD')
-        self.RELYT_DATABASE = get_env('RELYT_DATABASE')
-
-        # tencent settings
-        self.TENCENT_VECTOR_DB_URL = get_env('TENCENT_VECTOR_DB_URL')
-        self.TENCENT_VECTOR_DB_API_KEY = get_env('TENCENT_VECTOR_DB_API_KEY')
-        self.TENCENT_VECTOR_DB_TIMEOUT = get_env('TENCENT_VECTOR_DB_TIMEOUT')
-        self.TENCENT_VECTOR_DB_USERNAME = get_env('TENCENT_VECTOR_DB_USERNAME')
-        self.TENCENT_VECTOR_DB_DATABASE = get_env('TENCENT_VECTOR_DB_DATABASE')
-        self.TENCENT_VECTOR_DB_SHARD = get_env('TENCENT_VECTOR_DB_SHARD')
-        self.TENCENT_VECTOR_DB_REPLICAS = get_env('TENCENT_VECTOR_DB_REPLICAS')
-
-        # pgvecto rs settings
-        self.PGVECTO_RS_HOST = get_env('PGVECTO_RS_HOST')
-        self.PGVECTO_RS_PORT = get_env('PGVECTO_RS_PORT')
-        self.PGVECTO_RS_USER = get_env('PGVECTO_RS_USER')
-        self.PGVECTO_RS_PASSWORD = get_env('PGVECTO_RS_PASSWORD')
-        self.PGVECTO_RS_DATABASE = get_env('PGVECTO_RS_DATABASE')
-
-        # pgvector settings
-        self.PGVECTOR_HOST = get_env('PGVECTOR_HOST')
-        self.PGVECTOR_PORT = get_env('PGVECTOR_PORT')
-        self.PGVECTOR_USER = get_env('PGVECTOR_USER')
-        self.PGVECTOR_PASSWORD = get_env('PGVECTOR_PASSWORD')
-        self.PGVECTOR_DATABASE = get_env('PGVECTOR_DATABASE')
-
-        # tidb-vector settings
-        self.TIDB_VECTOR_HOST = get_env('TIDB_VECTOR_HOST')
-        self.TIDB_VECTOR_PORT = get_env('TIDB_VECTOR_PORT')
-        self.TIDB_VECTOR_USER = get_env('TIDB_VECTOR_USER')
-        self.TIDB_VECTOR_PASSWORD = get_env('TIDB_VECTOR_PASSWORD')
-        self.TIDB_VECTOR_DATABASE = get_env('TIDB_VECTOR_DATABASE')
-
-        # chroma settings
-        self.CHROMA_HOST = get_env('CHROMA_HOST')
-        self.CHROMA_PORT = get_env('CHROMA_PORT')
-        self.CHROMA_TENANT = get_env('CHROMA_TENANT')
-        self.CHROMA_DATABASE = get_env('CHROMA_DATABASE')
-        self.CHROMA_AUTH_PROVIDER = get_env('CHROMA_AUTH_PROVIDER')
-        self.CHROMA_AUTH_CREDENTIALS = get_env('CHROMA_AUTH_CREDENTIALS')
-
         # ------------------------
         # Platform Configurations.
         # ------------------------

+ 22 - 0
api/configs/middleware/__init__.py

@@ -3,6 +3,16 @@ from typing import Optional
 from pydantic import BaseModel, Field
 
 from configs.middleware.redis_configs import RedisConfigs
+from configs.middleware.vdb.chroma_configs import ChromaConfigs
+from configs.middleware.vdb.milvus_configs import MilvusConfigs
+from configs.middleware.vdb.opensearch_configs import OpenSearchConfigs
+from configs.middleware.vdb.pgvector_configs import PGVectorConfigs
+from configs.middleware.vdb.pgvectors_configs import PGVectoRSConfigs
+from configs.middleware.vdb.qdrant_configs import QdrantConfigs
+from configs.middleware.vdb.relyt_configs import RelytConfigs
+from configs.middleware.vdb.tencent_vector_configs import TencentVectorDBConfigs
+from configs.middleware.vdb.tidb_vector_configs import TiDBVectorConfigs
+from configs.middleware.vdb.weaviate_configs import WeaviateConfigs
 
 
 class StorageConfigs(BaseModel):
@@ -38,6 +48,18 @@ class MiddlewareConfigs(
     KeywordStoreConfigs,
     RedisConfigs,
     StorageConfigs,
+
+    # configs of vdb and vdb providers
     VectorStoreConfigs,
+    ChromaConfigs,
+    MilvusConfigs,
+    OpenSearchConfigs,
+    PGVectorConfigs,
+    PGVectoRSConfigs,
+    QdrantConfigs,
+    RelytConfigs,
+    TencentVectorDBConfigs,
+    TiDBVectorConfigs,
+    WeaviateConfigs,
 ):
     pass

+ 39 - 0
api/configs/middleware/vdb/chroma_configs.py

@@ -0,0 +1,39 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field, PositiveInt
+
+
+class ChromaConfigs(BaseModel):
+    """
+    Chroma configs
+    """
+
+    CHROMA_HOST: Optional[str] = Field(
+        description='Chroma host',
+        default=None,
+    )
+
+    CHROMA_PORT: PositiveInt = Field(
+        description='Chroma port',
+        default=8000,
+    )
+
+    CHROMA_TENANT: Optional[str] = Field(
+        description='Chroma database',
+        default=None,
+    )
+
+    CHROMA_DATABASE: Optional[str] = Field(
+        description='Chroma database',
+        default=None,
+    )
+
+    CHROMA_AUTH_PROVIDER: Optional[str] = Field(
+        description='Chroma authentication provider',
+        default=None,
+    )
+
+    CHROMA_AUTH_CREDENTIALS: Optional[str] = Field(
+        description='Chroma authentication credentials',
+        default=None,
+    )

+ 39 - 0
api/configs/middleware/vdb/milvus_configs.py

@@ -0,0 +1,39 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field, PositiveInt
+
+
+class MilvusConfigs(BaseModel):
+    """
+    Milvus configs
+    """
+
+    MILVUS_HOST: Optional[str] = Field(
+        description='Milvus host',
+        default=None,
+    )
+
+    MILVUS_PORT: PositiveInt = Field(
+        description='Milvus RestFul API port',
+        default=9091,
+    )
+
+    MILVUS_USER: Optional[str] = Field(
+        description='Milvus user',
+        default=None,
+    )
+
+    MILVUS_PASSWORD: Optional[str] = Field(
+        description='Milvus password',
+        default=None,
+    )
+
+    MILVUS_SECURE: bool = Field(
+        description='wheter to use SSL connection for Milvus',
+        default=False,
+    )
+
+    MILVUS_DATABASE: str = Field(
+        description='Milvus database',
+        default='default',
+    )

+ 34 - 0
api/configs/middleware/vdb/opensearch_configs.py

@@ -0,0 +1,34 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field, PositiveInt
+
+
+class OpenSearchConfigs(BaseModel):
+    """
+    OpenSearch configs
+    """
+
+    OPENSEARCH_HOST: Optional[str] = Field(
+        description='OpenSearch host',
+        default=None,
+    )
+
+    OPENSEARCH_PORT: PositiveInt = Field(
+        description='OpenSearch port',
+        default=9200,
+    )
+
+    OPENSEARCH_USER: Optional[str] = Field(
+        description='OpenSearch user',
+        default=None,
+    )
+
+    OPENSEARCH_PASSWORD: Optional[str] = Field(
+        description='OpenSearch password',
+        default=None,
+    )
+
+    OPENSEARCH_SECURE: bool = Field(
+        description='whether to use SSL connection for OpenSearch',
+        default=False,
+    )

+ 34 - 0
api/configs/middleware/vdb/pgvector_configs.py

@@ -0,0 +1,34 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field, PositiveInt
+
+
+class PGVectorConfigs(BaseModel):
+    """
+    PGVector configs
+    """
+
+    PGVECTOR_HOST: Optional[str] = Field(
+        description='PGVector host',
+        default=None,
+    )
+
+    PGVECTOR_PORT: Optional[PositiveInt] = Field(
+        description='PGVector port',
+        default=None,
+    )
+
+    PGVECTOR_USER: Optional[str] = Field(
+        description='PGVector user',
+        default=None,
+    )
+
+    PGVECTOR_PASSWORD: Optional[str] = Field(
+        description='PGVector password',
+        default=None,
+    )
+
+    PGVECTOR_DATABASE: Optional[str] = Field(
+        description='PGVector database',
+        default=None,
+    )

+ 34 - 0
api/configs/middleware/vdb/pgvectors_configs.py

@@ -0,0 +1,34 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field, PositiveInt
+
+
+class PGVectoRSConfigs(BaseModel):
+    """
+    PGVectoRS configs
+    """
+
+    PGVECTO_RS_HOST: Optional[str] = Field(
+        description='PGVectoRS host',
+        default=None,
+    )
+
+    PGVECTO_RS_PORT: Optional[PositiveInt] = Field(
+        description='PGVectoRS port',
+        default=None,
+    )
+
+    PGVECTO_RS_USER: Optional[str] = Field(
+        description='PGVectoRS user',
+        default=None,
+    )
+
+    PGVECTO_RS_PASSWORD: Optional[str] = Field(
+        description='PGVectoRS password',
+        default=None,
+    )
+
+    PGVECTO_RS_DATABASE: Optional[str] = Field(
+        description='PGVectoRS database',
+        default=None,
+    )

+ 34 - 0
api/configs/middleware/vdb/qdrant_configs.py

@@ -0,0 +1,34 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field, NonNegativeInt, PositiveInt
+
+
+class QdrantConfigs(BaseModel):
+    """
+    Qdrant configs
+    """
+
+    QDRANT_URL: Optional[str] = Field(
+        description='Qdrant url',
+        default=None,
+    )
+
+    QDRANT_API_KEY: Optional[str] = Field(
+        description='Qdrant api key',
+        default=None,
+    )
+
+    QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field(
+        description='Qdrant client timeout in seconds',
+        default=20,
+    )
+
+    QDRANT_GRPC_ENABLED: bool = Field(
+        description='whether enable grpc support for Qdrant connection',
+        default=False,
+    )
+
+    QDRANT_GRPC_PORT: PositiveInt = Field(
+        description='Qdrant grpc port',
+        default=6334,
+    )

+ 34 - 0
api/configs/middleware/vdb/relyt_configs.py

@@ -0,0 +1,34 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field, PositiveInt
+
+
+class RelytConfigs(BaseModel):
+    """
+    Relyt configs
+    """
+
+    RELYT_HOST: Optional[str] = Field(
+        description='Relyt host',
+        default=None,
+    )
+
+    RELYT_PORT: PositiveInt = Field(
+        description='Relyt port',
+        default=9200,
+    )
+
+    RELYT_USER: Optional[str] = Field(
+        description='Relyt user',
+        default=None,
+    )
+
+    RELYT_PASSWORD: Optional[str] = Field(
+        description='Relyt password',
+        default=None,
+    )
+
+    RELYT_DATABASE: Optional[str] = Field(
+        description='Relyt database',
+        default='default',
+    )

+ 44 - 0
api/configs/middleware/vdb/tencent_vector_configs.py

@@ -0,0 +1,44 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field, PositiveInt
+
+
+class TencentVectorDBConfigs(BaseModel):
+    """
+    Tencent Vector configs
+    """
+
+    TENCENT_VECTOR_DB_URL: Optional[str] = Field(
+        description='Tencent Vector URL',
+        default=None,
+    )
+
+    TENCENT_VECTOR_DB_API_KEY: Optional[str] = Field(
+        description='Tencent Vector api key',
+        default=None,
+    )
+
+    TENCENT_VECTOR_DB_TIMEOUT: PositiveInt = Field(
+        description='Tencent Vector timeout',
+        default=30,
+    )
+
+    TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field(
+        description='Tencent Vector password',
+        default=None,
+    )
+
+    TENCENT_VECTOR_DB_PASSWORD: Optional[str] = Field(
+        description='Tencent Vector password',
+        default=None,
+    )
+
+    TENCENT_VECTOR_DB_SHARD: PositiveInt = Field(
+        description='Tencent Vector sharding number',
+        default=1,
+    )
+
+    TENCENT_VECTOR_DB_REPLICAS: PositiveInt = Field(
+        description='Tencent Vector replicas',
+        default=2,
+    )

+ 34 - 0
api/configs/middleware/vdb/tidb_vector_configs.py

@@ -0,0 +1,34 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field, PositiveInt
+
+
+class TiDBVectorConfigs(BaseModel):
+    """
+    TiDB Vector configs
+    """
+
+    TIDB_VECTOR_HOST: Optional[str] = Field(
+        description='TiDB Vector host',
+        default=None,
+    )
+
+    TIDB_VECTOR_PORT: Optional[PositiveInt] = Field(
+        description='TiDB Vector port',
+        default=None,
+    )
+
+    TIDB_VECTOR_USER: Optional[str] = Field(
+        description='TiDB Vector user',
+        default=None,
+    )
+
+    TIDB_VECTOR_PASSWORD: Optional[str] = Field(
+        description='TiDB Vector password',
+        default=None,
+    )
+
+    TIDB_VECTOR_DATABASE: Optional[str] = Field(
+        description='TiDB Vector database',
+        default=None,
+    )

+ 29 - 0
api/configs/middleware/vdb/weaviate_configs.py

@@ -0,0 +1,29 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field, PositiveInt
+
+
+class WeaviateConfigs(BaseModel):
+    """
+    Weaviate configs
+    """
+
+    WEAVIATE_ENDPOINT: Optional[str] = Field(
+        description='Weaviate endpoint URL',
+        default=None,
+    )
+
+    WEAVIATE_API_KEY: Optional[str] = Field(
+        description='Weaviate API key',
+        default=None,
+    )
+
+    WEAVIATE_GRPC_ENABLED: bool = Field(
+        description='whether to enable gRPC for Weaviate connection',
+        default=True,
+    )
+
+    WEAVIATE_BATCH_SIZE: PositiveInt = Field(
+        description='Weaviate batch size',
+        default=100,
+    )