# app.py (8.4 KB)
  1. import os
# Unless DEBUG is set to "true" (case-insensitive), switch the process to gevent
# before the remaining modules are imported.
# NOTE(review): gevent monkey-patching is generally required to happen as early as
# possible, which is presumably why this sits above every other import — keep it first.
if os.environ.get("DEBUG", "false").lower() != 'true':
    from gevent import monkey
    # Monkey-patch the standard library for gevent.
    monkey.patch_all()
    import grpc.experimental.gevent
    # Enable grpc's experimental gevent integration.
    grpc.experimental.gevent.init_gevent()
  7. import json
  8. import logging
  9. import sys
  10. import threading
  11. import time
  12. import warnings
  13. from logging.handlers import RotatingFileHandler
  14. from flask import Flask, Response, request
  15. from flask_cors import CORS
  16. from werkzeug.exceptions import Unauthorized
  17. import contexts
  18. from commands import register_commands
  19. from configs import dify_config
  20. # DO NOT REMOVE BELOW
  21. from events import event_handlers
  22. from extensions import (
  23. ext_celery,
  24. ext_code_based_extension,
  25. ext_compress,
  26. ext_database,
  27. ext_hosting_provider,
  28. ext_login,
  29. ext_mail,
  30. ext_migrate,
  31. ext_redis,
  32. ext_sentry,
  33. ext_storage,
  34. )
  35. from extensions.ext_database import db
  36. from extensions.ext_login import login_manager
  37. from libs.passport import PassportService
  38. # TODO: Find a way to avoid importing models here
  39. from models import account, dataset, model, source, task, tool, tools, web
  40. from services.account_service import AccountService
  41. # DO NOT REMOVE ABOVE
  42. warnings.simplefilter("ignore", ResourceWarning)
  43. # fix windows platform
  44. if os.name == "nt":
  45. os.system('tzutil /s "UTC"')
  46. else:
  47. os.environ['TZ'] = 'UTC'
  48. time.tzset()
  49. class DifyApp(Flask):
  50. pass
  51. # -------------
  52. # Configuration
  53. # -------------
  54. config_type = os.getenv('EDITION', default='SELF_HOSTED') # ce edition first
  55. # ----------------------------
  56. # Application Factory Function
  57. # ----------------------------
  58. def create_flask_app_with_configs() -> Flask:
  59. """
  60. create a raw flask app
  61. with configs loaded from .env file
  62. """
  63. dify_app = DifyApp(__name__)
  64. dify_app.config.from_mapping(dify_config.model_dump())
  65. # populate configs into system environment variables
  66. for key, value in dify_app.config.items():
  67. if isinstance(value, str):
  68. os.environ[key] = value
  69. elif isinstance(value, int | float | bool):
  70. os.environ[key] = str(value)
  71. elif value is None:
  72. os.environ[key] = ''
  73. return dify_app
  74. def create_app() -> Flask:
  75. app = create_flask_app_with_configs()
  76. app.secret_key = app.config['SECRET_KEY']
  77. log_handlers = None
  78. log_file = app.config.get('LOG_FILE')
  79. if log_file:
  80. log_dir = os.path.dirname(log_file)
  81. os.makedirs(log_dir, exist_ok=True)
  82. log_handlers = [
  83. RotatingFileHandler(
  84. filename=log_file,
  85. maxBytes=1024 * 1024 * 1024,
  86. backupCount=5
  87. ),
  88. logging.StreamHandler(sys.stdout)
  89. ]
  90. logging.basicConfig(
  91. level=app.config.get('LOG_LEVEL'),
  92. format=app.config.get('LOG_FORMAT'),
  93. datefmt=app.config.get('LOG_DATEFORMAT'),
  94. handlers=log_handlers,
  95. force=True
  96. )
  97. log_tz = app.config.get('LOG_TZ')
  98. if log_tz:
  99. from datetime import datetime
  100. import pytz
  101. timezone = pytz.timezone(log_tz)
  102. def time_converter(seconds):
  103. return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
  104. for handler in logging.root.handlers:
  105. handler.formatter.converter = time_converter
  106. initialize_extensions(app)
  107. register_blueprints(app)
  108. register_commands(app)
  109. return app
def initialize_extensions(app: Flask) -> None:
    """Bind each Flask extension used by the API to the freshly created *app*.

    Extensions are initialized one by one in the fixed order below.
    NOTE(review): the order is probably significant (e.g. ``ext_migrate`` is
    handed ``db`` only after ``ext_database`` has been initialized) — confirm
    before reordering.
    """
    # Since the application instance is now created, pass it to each Flask
    # extension instance to bind it to the Flask application instance (app)
    ext_compress.init_app(app)
    # Unlike the others, this initializer takes no app argument.
    ext_code_based_extension.init()
    ext_database.init_app(app)
    ext_migrate.init(app, db)
    ext_redis.init_app(app)
    ext_storage.init_app(app)
    ext_celery.init_app(app)
    ext_login.init_app(app)
    ext_mail.init_app(app)
    ext_hosting_provider.init_app(app)
    ext_sentry.init_app(app)
  124. # Flask-Login configuration
  125. @login_manager.request_loader
  126. def load_user_from_request(request_from_flask_login):
  127. """Load user based on the request."""
  128. if request.blueprint not in ['console', 'inner_api']:
  129. return None
  130. # Check if the user_id contains a dot, indicating the old format
  131. auth_header = request.headers.get('Authorization', '')
  132. if not auth_header:
  133. auth_token = request.args.get('_token')
  134. if not auth_token:
  135. raise Unauthorized('Invalid Authorization token.')
  136. else:
  137. if ' ' not in auth_header:
  138. raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
  139. auth_scheme, auth_token = auth_header.split(None, 1)
  140. auth_scheme = auth_scheme.lower()
  141. if auth_scheme != 'bearer':
  142. raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
  143. decoded = PassportService().verify(auth_token)
  144. user_id = decoded.get('user_id')
  145. account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
  146. if account:
  147. contexts.tenant_id.set(account.current_tenant_id)
  148. return account
  149. @login_manager.unauthorized_handler
  150. def unauthorized_handler():
  151. """Handle unauthorized requests."""
  152. return Response(json.dumps({
  153. 'code': 'unauthorized',
  154. 'message': "Unauthorized."
  155. }), status=401, content_type="application/json")
  156. # register blueprint routers
  157. def register_blueprints(app):
  158. from controllers.console import bp as console_app_bp
  159. from controllers.files import bp as files_bp
  160. from controllers.inner_api import bp as inner_api_bp
  161. from controllers.service_api import bp as service_api_bp
  162. from controllers.web import bp as web_bp
  163. CORS(service_api_bp,
  164. allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
  165. methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
  166. )
  167. app.register_blueprint(service_api_bp)
  168. CORS(web_bp,
  169. resources={
  170. r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
  171. supports_credentials=True,
  172. allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
  173. methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
  174. expose_headers=['X-Version', 'X-Env']
  175. )
  176. app.register_blueprint(web_bp)
  177. CORS(console_app_bp,
  178. resources={
  179. r"/*": {"origins": app.config['CONSOLE_CORS_ALLOW_ORIGINS']}},
  180. supports_credentials=True,
  181. allow_headers=['Content-Type', 'Authorization'],
  182. methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
  183. expose_headers=['X-Version', 'X-Env']
  184. )
  185. app.register_blueprint(console_app_bp)
  186. CORS(files_bp,
  187. allow_headers=['Content-Type'],
  188. methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
  189. )
  190. app.register_blueprint(files_bp)
  191. app.register_blueprint(inner_api_bp)
# create app
# Built at import time: the route/after_request decorators below need `app` to exist.
app = create_app()
# Module-level handle to the app's "celery" extension entry (presumably registered
# by ext_celery.init_app during create_app — confirm).
celery = app.extensions["celery"]

if app.config.get('TESTING'):
    print("App is running in TESTING mode")
  197. @app.after_request
  198. def after_request(response):
  199. """Add Version headers to the response."""
  200. response.set_cookie('remember_token', '', expires=0)
  201. response.headers.add('X-Version', app.config['CURRENT_VERSION'])
  202. response.headers.add('X-Env', app.config['DEPLOY_ENV'])
  203. return response
  204. @app.route('/health')
  205. def health():
  206. return Response(json.dumps({
  207. 'pid': os.getpid(),
  208. 'status': 'ok',
  209. 'version': app.config['CURRENT_VERSION']
  210. }), status=200, content_type="application/json")
  211. @app.route('/threads')
  212. def threads():
  213. num_threads = threading.active_count()
  214. threads = threading.enumerate()
  215. thread_list = []
  216. for thread in threads:
  217. thread_name = thread.name
  218. thread_id = thread.ident
  219. is_alive = thread.is_alive()
  220. thread_list.append({
  221. 'name': thread_name,
  222. 'id': thread_id,
  223. 'is_alive': is_alive
  224. })
  225. return {
  226. 'pid': os.getpid(),
  227. 'thread_num': num_threads,
  228. 'threads': thread_list
  229. }
  230. @app.route('/db-pool-stat')
  231. def pool_stat():
  232. engine = db.engine
  233. return {
  234. 'pid': os.getpid(),
  235. 'pool_size': engine.pool.size(),
  236. 'checked_in_connections': engine.pool.checkedin(),
  237. 'checked_out_connections': engine.pool.checkedout(),
  238. 'overflow_connections': engine.pool.overflow(),
  239. 'connection_timeout': engine.pool.timeout(),
  240. 'recycle_time': db.engine.pool._recycle
  241. }
  242. if __name__ == '__main__':
  243. app.run(host='0.0.0.0', port=5001)