소스 검색

refactor(services/tasks): Swtich to dify_config witch Pydantic (#6203)

Waffle 9 달 전
부모
커밋
7b225a5ab0

+ 3 - 4
api/core/rag/datasource/keyword/keyword_factory.py

@@ -1,7 +1,6 @@
 from typing import Any
 
-from flask import current_app
-
+from configs import dify_config
 from core.rag.datasource.keyword.jieba.jieba import Jieba
 from core.rag.datasource.keyword.keyword_base import BaseKeyword
 from core.rag.models.document import Document
@@ -14,8 +13,8 @@ class Keyword:
         self._keyword_processor = self._init_keyword()
 
     def _init_keyword(self) -> BaseKeyword:
-        config = current_app.config
-        keyword_type = config.get('KEYWORD_STORE')
+        config = dify_config
+        keyword_type = config.KEYWORD_STORE
 
         if not keyword_type:
             raise ValueError("Keyword store must be specified.")

+ 5 - 5
api/services/account_service.py

@@ -6,10 +6,10 @@ from datetime import datetime, timedelta, timezone
 from hashlib import sha256
 from typing import Any, Optional
 
-from flask import current_app
 from sqlalchemy import func
 from werkzeug.exceptions import Unauthorized
 
+from configs import dify_config
 from constants.languages import language_timezone_mapping, languages
 from events.tenant_event import tenant_was_created
 from extensions.ext_redis import redis_client
@@ -80,7 +80,7 @@ class AccountService:
         payload = {
             "user_id": account.id,
             "exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp,
-            "iss": current_app.config['EDITION'],
+            "iss": dify_config.EDITION,
             "sub": 'Console API Passport',
         }
 
@@ -524,7 +524,7 @@ class RegisterService:
             TenantService.create_owner_tenant_if_not_exist(account)
 
             dify_setup = DifySetup(
-                version=current_app.config['CURRENT_VERSION']
+                version=dify_config.CURRENT_VERSION
             )
             db.session.add(dify_setup)
             db.session.commit()
@@ -559,7 +559,7 @@ class RegisterService:
 
             if open_id is not None or provider is not None:
                 AccountService.link_account_integrate(provider, open_id, account)
-            if current_app.config['EDITION'] != 'SELF_HOSTED':
+            if dify_config.EDITION != 'SELF_HOSTED':
                 tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
 
                 TenantService.create_tenant_member(tenant, account, role='owner')
@@ -623,7 +623,7 @@ class RegisterService:
             'email': account.email,
             'workspace_id': tenant.id,
         }
-        expiryHours = current_app.config['INVITE_EXPIRY_HOURS']
+        expiryHours = dify_config.INVITE_EXPIRY_HOURS
         redis_client.setex(
             cls._get_invitation_token_key(token),
             expiryHours * 60 * 60,

+ 2 - 2
api/services/app_generate_service.py

@@ -1,6 +1,7 @@
 from collections.abc import Generator
 from typing import Any, Union
 
+from configs import dify_config
 from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
 from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
 from core.app.apps.chat.app_generator import ChatAppGenerator
@@ -89,8 +90,7 @@ class AppGenerateService:
     def _get_max_active_requests(app_model: App) -> int:
         max_active_requests = app_model.max_active_requests
         if app_model.max_active_requests is None:
-            from flask import current_app
-            max_active_requests = int(current_app.config['APP_MAX_ACTIVE_REQUESTS'])
+            max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS)
         return max_active_requests
 
     @classmethod

+ 2 - 2
api/services/app_service.py

@@ -4,10 +4,10 @@ from datetime import datetime, timezone
 from typing import cast
 
 import yaml
-from flask import current_app
 from flask_login import current_user
 from flask_sqlalchemy.pagination import Pagination
 
+from configs import dify_config
 from constants.model_template import default_app_templates
 from core.agent.entities import AgentToolEntity
 from core.app.features.rate_limiting import RateLimit
@@ -446,7 +446,7 @@ class AppService:
             # get all tools
             tools = agent_config.get('tools', [])
 
-        url_prefix = (current_app.config.get("CONSOLE_API_URL")
+        url_prefix = (dify_config.CONSOLE_API_URL
                       + "/console/api/workspaces/current/tool-provider/builtin/")
 
         for tool in tools:

+ 3 - 3
api/services/dataset_service.py

@@ -6,10 +6,10 @@ import time
 import uuid
 from typing import Optional
 
-from flask import current_app
 from flask_login import current_user
 from sqlalchemy import func
 
+from configs import dify_config
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
@@ -650,7 +650,7 @@ class DocumentService:
                 elif document_data["data_source"]["type"] == "website_crawl":
                     website_info = document_data["data_source"]['info_list']['website_info_list']
                     count = len(website_info['urls'])
-                batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
+                batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
                 if count > batch_upload_limit:
                     raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
 
@@ -1028,7 +1028,7 @@ class DocumentService:
             elif document_data["data_source"]["type"] == "website_crawl":
                 website_info = document_data["data_source"]['info_list']['website_info_list']
                 count = len(website_info['urls'])
-            batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
+            batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
             if count > batch_upload_limit:
                 raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
 

+ 4 - 4
api/services/entities/model_provider_entities.py

@@ -1,9 +1,9 @@
 from enum import Enum
 from typing import Optional
 
-from flask import current_app
 from pydantic import BaseModel, ConfigDict
 
+from configs import dify_config
 from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity
 from core.entities.provider_entities import QuotaConfiguration
 from core.model_runtime.entities.common_entities import I18nObject
@@ -67,7 +67,7 @@ class ProviderResponse(BaseModel):
     def __init__(self, **data) -> None:
         super().__init__(**data)
 
-        url_prefix = (current_app.config.get("CONSOLE_API_URL")
+        url_prefix = (dify_config.CONSOLE_API_URL
                       + f"/console/api/workspaces/current/model-providers/{self.provider}")
         if self.icon_small is not None:
             self.icon_small = I18nObject(
@@ -96,7 +96,7 @@ class ProviderWithModelsResponse(BaseModel):
     def __init__(self, **data) -> None:
         super().__init__(**data)
 
-        url_prefix = (current_app.config.get("CONSOLE_API_URL")
+        url_prefix = (dify_config.CONSOLE_API_URL
                       + f"/console/api/workspaces/current/model-providers/{self.provider}")
         if self.icon_small is not None:
             self.icon_small = I18nObject(
@@ -119,7 +119,7 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
     def __init__(self, **data) -> None:
         super().__init__(**data)
 
-        url_prefix = (current_app.config.get("CONSOLE_API_URL")
+        url_prefix = (dify_config.CONSOLE_API_URL
                       + f"/console/api/workspaces/current/model-providers/{self.provider}")
         if self.icon_small is not None:
             self.icon_small = I18nObject(

+ 6 - 6
api/services/feature_service.py

@@ -1,6 +1,6 @@
-from flask import current_app
 from pydantic import BaseModel, ConfigDict
 
+from configs import dify_config
 from services.billing_service import BillingService
 from services.enterprise.enterprise_service import EnterpriseService
 
@@ -51,7 +51,7 @@ class FeatureService:
 
         cls._fulfill_params_from_env(features)
 
-        if current_app.config['BILLING_ENABLED']:
+        if dify_config.BILLING_ENABLED:
             cls._fulfill_params_from_billing_api(features, tenant_id)
 
         return features
@@ -60,16 +60,16 @@ class FeatureService:
     def get_system_features(cls) -> SystemFeatureModel:
         system_features = SystemFeatureModel()
 
-        if current_app.config['ENTERPRISE_ENABLED']:
+        if dify_config.ENTERPRISE_ENABLED:
             cls._fulfill_params_from_enterprise(system_features)
 
         return system_features
 
     @classmethod
     def _fulfill_params_from_env(cls, features: FeatureModel):
-        features.can_replace_logo = current_app.config['CAN_REPLACE_LOGO']
-        features.model_load_balancing_enabled = current_app.config['MODEL_LB_ENABLED']
-        features.dataset_operator_enabled = current_app.config['DATASET_OPERATOR_ENABLED']
+        features.can_replace_logo = dify_config.CAN_REPLACE_LOGO
+        features.model_load_balancing_enabled = dify_config.MODEL_LB_ENABLED
+        features.dataset_operator_enabled = dify_config.DATASET_OPERATOR_ENABLED
 
     @classmethod
     def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):

+ 7 - 9
api/services/file_service.py

@@ -4,11 +4,11 @@ import uuid
 from collections.abc import Generator
 from typing import Union
 
-from flask import current_app
 from flask_login import current_user
 from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import NotFound
 
+from configs import dify_config
 from core.file.upload_file_parser import UploadFileParser
 from core.rag.extractor.extract_processor import ExtractProcessor
 from extensions.ext_database import db
@@ -35,7 +35,7 @@ class FileService:
         extension = file.filename.split('.')[-1]
         if len(filename) > 200:
             filename = filename.split('.')[0][:200] + '.' + extension
-        etl_type = current_app.config['ETL_TYPE']
+        etl_type = dify_config.ETL_TYPE
         allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \
             else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS
         if extension.lower() not in allowed_extensions:
@@ -50,9 +50,9 @@ class FileService:
         file_size = len(file_content)
 
         if extension.lower() in IMAGE_EXTENSIONS:
-            file_size_limit = current_app.config.get("UPLOAD_IMAGE_FILE_SIZE_LIMIT") * 1024 * 1024
+            file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024
         else:
-            file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
+            file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024
 
         if file_size > file_size_limit:
             message = f'File size exceeded. {file_size} > {file_size_limit}'
@@ -73,10 +73,9 @@ class FileService:
         storage.save(file_key, file_content)
 
         # save file to db
-        config = current_app.config
         upload_file = UploadFile(
             tenant_id=current_tenant_id,
-            storage_type=config['STORAGE_TYPE'],
+            storage_type=dify_config.STORAGE_TYPE,
             key=file_key,
             name=filename,
             size=file_size,
@@ -106,10 +105,9 @@ class FileService:
         storage.save(file_key, text.encode('utf-8'))
 
         # save file to db
-        config = current_app.config
         upload_file = UploadFile(
             tenant_id=current_user.current_tenant_id,
-            storage_type=config['STORAGE_TYPE'],
+            storage_type=dify_config.STORAGE_TYPE,
             key=file_key,
             name=text_name + '.txt',
             size=len(text),
@@ -138,7 +136,7 @@ class FileService:
 
         # extract text from file
         extension = upload_file.extension
-        etl_type = current_app.config['ETL_TYPE']
+        etl_type = dify_config.ETL_TYPE
         allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
         if extension.lower() not in allowed_extensions:
             raise UnsupportedFileTypeError()

+ 5 - 5
api/services/recommended_app_service.py

@@ -4,8 +4,8 @@ from os import path
 from typing import Optional
 
 import requests
-from flask import current_app
 
+from configs import dify_config
 from constants.languages import languages
 from extensions.ext_database import db
 from models.model import App, RecommendedApp
@@ -25,7 +25,7 @@ class RecommendedAppService:
         :param language: language
         :return:
         """
-        mode = current_app.config.get('HOSTED_FETCH_APP_TEMPLATES_MODE', 'remote')
+        mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
         if mode == 'remote':
             try:
                 result = cls._fetch_recommended_apps_from_dify_official(language)
@@ -104,7 +104,7 @@ class RecommendedAppService:
         :param language: language
         :return:
         """
-        domain = current_app.config.get('HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN', 'https://tmpl.dify.ai')
+        domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
         url = f'{domain}/apps?language={language}'
         response = requests.get(url, timeout=(3, 10))
         if response.status_code != 200:
@@ -134,7 +134,7 @@ class RecommendedAppService:
         :param app_id: app id
         :return:
         """
-        mode = current_app.config.get('HOSTED_FETCH_APP_TEMPLATES_MODE', 'remote')
+        mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
         if mode == 'remote':
             try:
                 result = cls._fetch_recommended_app_detail_from_dify_official(app_id)
@@ -157,7 +157,7 @@ class RecommendedAppService:
         :param app_id: App ID
         :return:
         """
-        domain = current_app.config.get('HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN', 'https://tmpl.dify.ai')
+        domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
         url = f'{domain}/apps/{app_id}'
         response = requests.get(url, timeout=(3, 10))
         if response.status_code != 200:

+ 2 - 3
api/services/tools/tools_transform_service.py

@@ -2,8 +2,7 @@ import json
 import logging
 from typing import Optional, Union
 
-from flask import current_app
-
+from configs import dify_config
 from core.tools.entities.api_entities import UserTool, UserToolProvider
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_bundle import ApiToolBundle
@@ -29,7 +28,7 @@ class ToolTransformService:
         """
             get tool provider icon url
         """
-        url_prefix = (current_app.config.get("CONSOLE_API_URL")
+        url_prefix = (dify_config.CONSOLE_API_URL
                       + "/console/api/workspaces/current/tool-provider/")
         
         if provider_type == ToolProviderType.BUILT_IN.value:

+ 2 - 2
api/services/workspace_service.py

@@ -1,7 +1,7 @@
 
-from flask import current_app
 from flask_login import current_user
 
+from configs import dify_config
 from extensions.ext_database import db
 from models.account import Tenant, TenantAccountJoin, TenantAccountJoinRole
 from services.account_service import TenantService
@@ -35,7 +35,7 @@ class WorkspaceService:
 
         if can_replace_logo and TenantService.has_roles(tenant, 
         [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]):
-            base_url = current_app.config.get('FILES_URL')
+            base_url = dify_config.FILES_URL
             replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None
             remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False)
 

+ 2 - 2
api/tasks/document_indexing_task.py

@@ -4,8 +4,8 @@ import time
 
 import click
 from celery import shared_task
-from flask import current_app
 
+from configs import dify_config
 from core.indexing_runner import DocumentIsPausedException, IndexingRunner
 from extensions.ext_database import db
 from models.dataset import Dataset, Document
@@ -32,7 +32,7 @@ def document_indexing_task(dataset_id: str, document_ids: list):
         if features.billing.enabled:
             vector_space = features.vector_space
             count = len(document_ids)
-            batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
+            batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
             if count > batch_upload_limit:
                 raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
             if 0 < vector_space.limit <= vector_space.size:

+ 2 - 2
api/tasks/duplicate_document_indexing_task.py

@@ -4,8 +4,8 @@ import time
 
 import click
 from celery import shared_task
-from flask import current_app
 
+from configs import dify_config
 from core.indexing_runner import DocumentIsPausedException, IndexingRunner
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db
@@ -33,7 +33,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
         if features.billing.enabled:
             vector_space = features.vector_space
             count = len(document_ids)
-            batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
+            batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
             if count > batch_upload_limit:
                 raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
             if 0 < vector_space.limit <= vector_space.size:

+ 3 - 2
api/tasks/mail_invite_member_task.py

@@ -3,8 +3,9 @@ import time
 
 import click
 from celery import shared_task
-from flask import current_app, render_template
+from flask import render_template
 
+from configs import dify_config
 from extensions.ext_mail import mail
 
 
@@ -29,7 +30,7 @@ def send_invite_member_mail_task(language: str, to: str, token: str, inviter_nam
 
     # send invite member mail using different languages
     try:
-        url = f'{current_app.config.get("CONSOLE_WEB_URL")}/activate?token={token}'
+        url = f'{dify_config.CONSOLE_WEB_URL}/activate?token={token}'
         if language == 'zh-Hans':
             html_content = render_template('invite_member_mail_template_zh-CN.html',
                                            to=to,

+ 3 - 2
api/tasks/mail_reset_password_task.py

@@ -3,8 +3,9 @@ import time
 
 import click
 from celery import shared_task
-from flask import current_app, render_template
+from flask import render_template
 
+from configs import dify_config
 from extensions.ext_mail import mail
 
 
@@ -24,7 +25,7 @@ def send_reset_password_mail_task(language: str, to: str, token: str):
 
     # send reset password mail using different languages
     try:
-        url = f'{current_app.config.get("CONSOLE_WEB_URL")}/forgot-password?token={token}'
+        url = f'{dify_config.CONSOLE_WEB_URL}/forgot-password?token={token}'
         if language == 'zh-Hans':
             html_content = render_template('reset_password_mail_template_zh-CN.html',
                                            to=to,