app.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. import os
  2. from configs.app_configs import DifyConfigs
  3. if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
  4. from gevent import monkey
  5. monkey.patch_all()
  6. import grpc.experimental.gevent
  7. grpc.experimental.gevent.init_gevent()
  8. import json
  9. import logging
  10. import sys
  11. import threading
  12. import time
  13. import warnings
  14. from logging.handlers import RotatingFileHandler
  15. from flask import Flask, Response, request
  16. from flask_cors import CORS
  17. from werkzeug.exceptions import Unauthorized
  18. from commands import register_commands
  19. from config import Config
  20. # DO NOT REMOVE BELOW
  21. from events import event_handlers
  22. from extensions import (
  23. ext_celery,
  24. ext_code_based_extension,
  25. ext_compress,
  26. ext_database,
  27. ext_hosting_provider,
  28. ext_login,
  29. ext_mail,
  30. ext_migrate,
  31. ext_redis,
  32. ext_sentry,
  33. ext_storage,
  34. )
  35. from extensions.ext_database import db
  36. from extensions.ext_login import login_manager
  37. from libs.passport import PassportService
  38. from models import account, dataset, model, source, task, tool, tools, web
  39. from services.account_service import AccountService
  40. # DO NOT REMOVE ABOVE
  41. warnings.simplefilter("ignore", ResourceWarning)
  42. # fix windows platform
  43. if os.name == "nt":
  44. os.system('tzutil /s "UTC"')
  45. else:
  46. os.environ['TZ'] = 'UTC'
  47. time.tzset()
  48. class DifyApp(Flask):
  49. pass
  50. # -------------
  51. # Configuration
  52. # -------------
  53. config_type = os.getenv('EDITION', default='SELF_HOSTED') # ce edition first
  54. # ----------------------------
  55. # Application Factory Function
  56. # ----------------------------
  57. def create_flask_app_with_configs() -> Flask:
  58. """
  59. create a raw flask app
  60. with configs loaded from .env file
  61. """
  62. dify_app = DifyApp(__name__)
  63. dify_app.config.from_object(Config())
  64. dify_app.config.from_mapping(DifyConfigs().model_dump())
  65. return dify_app
  66. def create_app() -> Flask:
  67. app = create_flask_app_with_configs()
  68. app.secret_key = app.config['SECRET_KEY']
  69. log_handlers = None
  70. log_file = app.config.get('LOG_FILE')
  71. if log_file:
  72. log_dir = os.path.dirname(log_file)
  73. os.makedirs(log_dir, exist_ok=True)
  74. log_handlers = [
  75. RotatingFileHandler(
  76. filename=log_file,
  77. maxBytes=1024 * 1024 * 1024,
  78. backupCount=5
  79. ),
  80. logging.StreamHandler(sys.stdout)
  81. ]
  82. logging.basicConfig(
  83. level=app.config.get('LOG_LEVEL'),
  84. format=app.config.get('LOG_FORMAT'),
  85. datefmt=app.config.get('LOG_DATEFORMAT'),
  86. handlers=log_handlers
  87. )
  88. initialize_extensions(app)
  89. register_blueprints(app)
  90. register_commands(app)
  91. return app
  92. def initialize_extensions(app):
  93. # Since the application instance is now created, pass it to each Flask
  94. # extension instance to bind it to the Flask application instance (app)
  95. ext_compress.init_app(app)
  96. ext_code_based_extension.init()
  97. ext_database.init_app(app)
  98. ext_migrate.init(app, db)
  99. ext_redis.init_app(app)
  100. ext_storage.init_app(app)
  101. ext_celery.init_app(app)
  102. ext_login.init_app(app)
  103. ext_mail.init_app(app)
  104. ext_hosting_provider.init_app(app)
  105. ext_sentry.init_app(app)
  106. # Flask-Login configuration
  107. @login_manager.request_loader
  108. def load_user_from_request(request_from_flask_login):
  109. """Load user based on the request."""
  110. if request.blueprint in ['console', 'inner_api']:
  111. # Check if the user_id contains a dot, indicating the old format
  112. auth_header = request.headers.get('Authorization', '')
  113. if not auth_header:
  114. auth_token = request.args.get('_token')
  115. if not auth_token:
  116. raise Unauthorized('Invalid Authorization token.')
  117. else:
  118. if ' ' not in auth_header:
  119. raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
  120. auth_scheme, auth_token = auth_header.split(None, 1)
  121. auth_scheme = auth_scheme.lower()
  122. if auth_scheme != 'bearer':
  123. raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
  124. decoded = PassportService().verify(auth_token)
  125. user_id = decoded.get('user_id')
  126. return AccountService.load_user(user_id)
  127. else:
  128. return None
  129. @login_manager.unauthorized_handler
  130. def unauthorized_handler():
  131. """Handle unauthorized requests."""
  132. return Response(json.dumps({
  133. 'code': 'unauthorized',
  134. 'message': "Unauthorized."
  135. }), status=401, content_type="application/json")
  136. # register blueprint routers
  137. def register_blueprints(app):
  138. from controllers.console import bp as console_app_bp
  139. from controllers.files import bp as files_bp
  140. from controllers.inner_api import bp as inner_api_bp
  141. from controllers.service_api import bp as service_api_bp
  142. from controllers.web import bp as web_bp
  143. CORS(service_api_bp,
  144. allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
  145. methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
  146. )
  147. app.register_blueprint(service_api_bp)
  148. CORS(web_bp,
  149. resources={
  150. r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
  151. supports_credentials=True,
  152. allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
  153. methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
  154. expose_headers=['X-Version', 'X-Env']
  155. )
  156. app.register_blueprint(web_bp)
  157. CORS(console_app_bp,
  158. resources={
  159. r"/*": {"origins": app.config['CONSOLE_CORS_ALLOW_ORIGINS']}},
  160. supports_credentials=True,
  161. allow_headers=['Content-Type', 'Authorization'],
  162. methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
  163. expose_headers=['X-Version', 'X-Env']
  164. )
  165. app.register_blueprint(console_app_bp)
  166. CORS(files_bp,
  167. allow_headers=['Content-Type'],
  168. methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
  169. )
  170. app.register_blueprint(files_bp)
  171. app.register_blueprint(inner_api_bp)
  172. # create app
  173. app = create_app()
  174. celery = app.extensions["celery"]
  175. if app.config['TESTING']:
  176. print("App is running in TESTING mode")
  177. @app.after_request
  178. def after_request(response):
  179. """Add Version headers to the response."""
  180. response.set_cookie('remember_token', '', expires=0)
  181. response.headers.add('X-Version', app.config['CURRENT_VERSION'])
  182. response.headers.add('X-Env', app.config['DEPLOY_ENV'])
  183. return response
  184. @app.route('/health')
  185. def health():
  186. return Response(json.dumps({
  187. 'status': 'ok',
  188. 'version': app.config['CURRENT_VERSION']
  189. }), status=200, content_type="application/json")
  190. @app.route('/threads')
  191. def threads():
  192. num_threads = threading.active_count()
  193. threads = threading.enumerate()
  194. thread_list = []
  195. for thread in threads:
  196. thread_name = thread.name
  197. thread_id = thread.ident
  198. is_alive = thread.is_alive()
  199. thread_list.append({
  200. 'name': thread_name,
  201. 'id': thread_id,
  202. 'is_alive': is_alive
  203. })
  204. return {
  205. 'thread_num': num_threads,
  206. 'threads': thread_list
  207. }
  208. @app.route('/db-pool-stat')
  209. def pool_stat():
  210. engine = db.engine
  211. return {
  212. 'pool_size': engine.pool.size(),
  213. 'checked_in_connections': engine.pool.checkedin(),
  214. 'checked_out_connections': engine.pool.checkedout(),
  215. 'overflow_connections': engine.pool.overflow(),
  216. 'connection_timeout': engine.pool.timeout(),
  217. 'recycle_time': db.engine.pool._recycle
  218. }
  219. if __name__ == '__main__':
  220. app.run(host='0.0.0.0', port=5001)