Prechádzať zdrojové kódy

refactor(api): Switch to `dify_config` (#6750)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- 8 mesiacov pred
rodič
commit
a98284b1ef

+ 9 - 9
api/commands.py

@@ -249,8 +249,7 @@ def migrate_knowledge_vector_database():
     create_count = 0
     skipped_count = 0
     total_count = 0
-    config = current_app.config
-    vector_type = config.get('VECTOR_STORE')
+    vector_type = dify_config.VECTOR_STORE
     page = 1
     while True:
         try:
@@ -484,8 +483,7 @@ def convert_to_agent_apps():
 @click.option('--field', default='metadata.doc_id', prompt=False, help='index field , default is metadata.doc_id.')
 def add_qdrant_doc_id_index(field: str):
     click.echo(click.style('Start add qdrant doc_id index.', fg='green'))
-    config = current_app.config
-    vector_type = config.get('VECTOR_STORE')
+    vector_type = dify_config.VECTOR_STORE
     if vector_type != "qdrant":
         click.echo(click.style('Sorry, only support qdrant vector store.', fg='red'))
         return
@@ -502,13 +500,15 @@ def add_qdrant_doc_id_index(field: str):
 
         from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig
         for binding in bindings:
+            if dify_config.QDRANT_URL is None:
+                raise ValueError('Qdrant url is required.')
             qdrant_config = QdrantConfig(
-                endpoint=config.get('QDRANT_URL'),
-                api_key=config.get('QDRANT_API_KEY'),
+                endpoint=dify_config.QDRANT_URL,
+                api_key=dify_config.QDRANT_API_KEY,
                 root_path=current_app.root_path,
-                timeout=config.get('QDRANT_CLIENT_TIMEOUT'),
-                grpc_port=config.get('QDRANT_GRPC_PORT'),
-                prefer_grpc=config.get('QDRANT_GRPC_ENABLED')
+                timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
+                grpc_port=dify_config.QDRANT_GRPC_PORT,
+                prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
             )
             try:
                 client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params())

+ 2 - 2
api/controllers/console/auth/login.py

@@ -71,7 +71,7 @@ class ResetPasswordApi(Resource):
         # AccountService.update_password(account, new_password)
 
         # todo: Send email
-        # MAILCHIMP_API_KEY = current_app.config['MAILCHIMP_TRANSACTIONAL_API_KEY']
+        # MAILCHIMP_API_KEY = dify_config.MAILCHIMP_TRANSACTIONAL_API_KEY
         # mailchimp = MailchimpTransactional(MAILCHIMP_API_KEY)
 
         # message = {
@@ -92,7 +92,7 @@ class ResetPasswordApi(Resource):
         #     'message': message,
         #     # required for transactional email
         #     ' settings': {
-        #         'sandbox_mode': current_app.config['MAILCHIMP_SANDBOX_MODE'],
+        #         'sandbox_mode': dify_config.MAILCHIMP_SANDBOX_MODE,
         #     },
         # })
 

+ 3 - 2
api/core/indexing_runner.py

@@ -12,6 +12,7 @@ from flask import Flask, current_app
 from flask_login import current_user
 from sqlalchemy.orm.exc import ObjectDeletedError
 
+from configs import dify_config
 from core.errors.error import ProviderTokenNotInitError
 from core.llm_generator.llm_generator import LLMGenerator
 from core.model_manager import ModelInstance, ModelManager
@@ -224,7 +225,7 @@ class IndexingRunner:
         features = FeatureService.get_features(tenant_id)
         if features.billing.enabled:
             count = len(extract_settings)
-            batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
+            batch_upload_limit = dify_config.BATCH_UPLOAD_LIMIT
             if count > batch_upload_limit:
                 raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
 
@@ -427,7 +428,7 @@ class IndexingRunner:
             # The user-defined segmentation rule
             rules = json.loads(processing_rule.rules)
             segmentation = rules["segmentation"]
-            max_segmentation_tokens_length = int(current_app.config['INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH'])
+            max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
             if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
                 raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")
 

+ 2 - 2
api/core/rag/datasource/keyword/jieba/jieba.py

@@ -2,9 +2,9 @@ import json
 from collections import defaultdict
 from typing import Any, Optional
 
-from flask import current_app
 from pydantic import BaseModel
 
+from configs import dify_config
 from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
 from core.rag.datasource.keyword.keyword_base import BaseKeyword
 from core.rag.models.document import Document
@@ -139,7 +139,7 @@ class Jieba(BaseKeyword):
             if keyword_table_dict:
                 return keyword_table_dict['__data__']['table']
         else:
-            keyword_data_source_type = current_app.config['KEYWORD_DATA_SOURCE_TYPE']
+            keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
             dataset_keyword_table = DatasetKeywordTable(
                 dataset_id=self.dataset.id,
                 keyword_table='',

+ 5 - 4
api/libs/passport.py

@@ -1,15 +1,16 @@
 import jwt
-from flask import current_app
 from werkzeug.exceptions import Unauthorized
 
+from configs import dify_config
+
 
 class PassportService:
     def __init__(self):
-        self.sk = current_app.config.get('SECRET_KEY')
-    
+        self.sk = dify_config.SECRET_KEY
+
     def issue(self, payload):
         return jwt.encode(payload, self.sk, algorithm='HS256')
-    
+
     def verify(self, token):
         try:
             return jwt.decode(token, self.sk, algorithms=['HS256'])