Browse Source

refactor(models&tools): switch to dify_config in models and tools. (#6394)

Co-authored-by: Poorandy <andymonicamua1@gmail.com>
Poorandy 9 tháng trước cách đây
mục cha
commit
8a80af39c9

+ 75 - 46
api/core/tools/tool_file_manager.py

@@ -9,9 +9,9 @@ from mimetypes import guess_extension, guess_type
 from typing import Optional, Union
 from uuid import uuid4
 
-from flask import current_app
 from httpx import get
 
+from configs import dify_config
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 from models.model import MessageFile
@@ -26,25 +26,25 @@ class ToolFileManager:
         """
         sign file to get a temporary url
         """
-        base_url = current_app.config.get('FILES_URL')
+        base_url = dify_config.FILES_URL
         file_preview_url = f'{base_url}/files/tools/{tool_file_id}{extension}'
 
         timestamp = str(int(time.time()))
         nonce = os.urandom(16).hex()
-        data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
-        secret_key = current_app.config['SECRET_KEY'].encode()
+        data_to_sign = f'file-preview|{tool_file_id}|{timestamp}|{nonce}'
+        secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b''
         sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
         encoded_sign = base64.urlsafe_b64encode(sign).decode()
 
-        return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
+        return f'{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}'
 
     @staticmethod
     def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
         """
         verify signature
         """
-        data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
-        secret_key = current_app.config['SECRET_KEY'].encode()
+        data_to_sign = f'file-preview|{file_id}|{timestamp}|{nonce}'
+        secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b''
         recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
         recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
 
@@ -53,23 +53,23 @@ class ToolFileManager:
             return False
 
         current_time = int(time.time())
-        return current_time - int(timestamp) <= current_app.config.get('FILES_ACCESS_TIMEOUT')
+        return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
 
     @staticmethod
-    def create_file_by_raw(user_id: str, tenant_id: str,
-                           conversation_id: Optional[str], file_binary: bytes,
-                           mimetype: str
-                           ) -> ToolFile:
+    def create_file_by_raw(
+        user_id: str, tenant_id: str, conversation_id: Optional[str], file_binary: bytes, mimetype: str
+    ) -> ToolFile:
         """
         create file
         """
         extension = guess_extension(mimetype) or '.bin'
         unique_name = uuid4().hex
-        filename = f"tools/{tenant_id}/{unique_name}{extension}"
+        filename = f'tools/{tenant_id}/{unique_name}{extension}'
         storage.save(filename, file_binary)
 
-        tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
-                             conversation_id=conversation_id, file_key=filename, mimetype=mimetype)
+        tool_file = ToolFile(
+            user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=filename, mimetype=mimetype
+        )
 
         db.session.add(tool_file)
         db.session.commit()
@@ -77,9 +77,12 @@ class ToolFileManager:
         return tool_file
 
     @staticmethod
-    def create_file_by_url(user_id: str, tenant_id: str,
-                           conversation_id: str, file_url: str,
-                           ) -> ToolFile:
+    def create_file_by_url(
+        user_id: str,
+        tenant_id: str,
+        conversation_id: str,
+        file_url: str,
+    ) -> ToolFile:
         """
         create file
         """
@@ -90,12 +93,17 @@ class ToolFileManager:
         mimetype = guess_type(file_url)[0] or 'octet/stream'
         extension = guess_extension(mimetype) or '.bin'
         unique_name = uuid4().hex
-        filename = f"tools/{tenant_id}/{unique_name}{extension}"
+        filename = f'tools/{tenant_id}/{unique_name}{extension}'
         storage.save(filename, blob)
 
-        tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
-                             conversation_id=conversation_id, file_key=filename,
-                             mimetype=mimetype, original_url=file_url)
+        tool_file = ToolFile(
+            user_id=user_id,
+            tenant_id=tenant_id,
+            conversation_id=conversation_id,
+            file_key=filename,
+            mimetype=mimetype,
+            original_url=file_url,
+        )
 
         db.session.add(tool_file)
         db.session.commit()
@@ -103,15 +111,15 @@ class ToolFileManager:
         return tool_file
 
     @staticmethod
-    def create_file_by_key(user_id: str, tenant_id: str,
-                           conversation_id: str, file_key: str,
-                           mimetype: str
-                           ) -> ToolFile:
+    def create_file_by_key(
+        user_id: str, tenant_id: str, conversation_id: str, file_key: str, mimetype: str
+    ) -> ToolFile:
         """
         create file
         """
-        tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
-                             conversation_id=conversation_id, file_key=file_key, mimetype=mimetype)
+        tool_file = ToolFile(
+            user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=file_key, mimetype=mimetype
+        )
         return tool_file
 
     @staticmethod
@@ -123,9 +131,13 @@ class ToolFileManager:
 
         :return: the binary of the file, mime type
         """
-        tool_file: ToolFile = db.session.query(ToolFile).filter(
-            ToolFile.id == id,
-        ).first()
+        tool_file: ToolFile = (
+            db.session.query(ToolFile)
+            .filter(
+                ToolFile.id == id,
+            )
+            .first()
+        )
 
         if not tool_file:
             return None
@@ -143,18 +155,31 @@ class ToolFileManager:
 
         :return: the binary of the file, mime type
         """
-        message_file: MessageFile = db.session.query(MessageFile).filter(
-            MessageFile.id == id,
-        ).first()
-
-        # get tool file id
-        tool_file_id = message_file.url.split('/')[-1]
-        # trim extension
-        tool_file_id = tool_file_id.split('.')[0]
-
-        tool_file: ToolFile = db.session.query(ToolFile).filter(
-            ToolFile.id == tool_file_id,
-        ).first()
+        message_file: MessageFile = (
+            db.session.query(MessageFile)
+            .filter(
+                MessageFile.id == id,
+            )
+            .first()
+        )
+
+        # Check if message_file is not None
+        if message_file is not None:
+            # get tool file id
+            tool_file_id = message_file.url.split('/')[-1]
+            # trim extension
+            tool_file_id = tool_file_id.split('.')[0]
+        else:
+            tool_file_id = None
+
+
+        tool_file: ToolFile = (
+            db.session.query(ToolFile)
+            .filter(
+                ToolFile.id == tool_file_id,
+            )
+            .first()
+        )
 
         if not tool_file:
             return None
@@ -172,9 +197,13 @@ class ToolFileManager:
 
         :return: the binary of the file, mime type
         """
-        tool_file: ToolFile = db.session.query(ToolFile).filter(
-            ToolFile.id == tool_file_id,
-        ).first()
+        tool_file: ToolFile = (
+            db.session.query(ToolFile)
+            .filter(
+                ToolFile.id == tool_file_id,
+            )
+            .first()
+        )
 
         if not tool_file:
             return None

+ 3 - 4
api/core/tools/tool_manager.py

@@ -6,8 +6,7 @@ from os import listdir, path
 from threading import Lock
 from typing import Any, Union
 
-from flask import current_app
-
+from configs import dify_config
 from core.agent.entities import AgentToolEntity
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.helper.module_import_helper import load_single_subclass_from_source
@@ -566,7 +565,7 @@ class ToolManager:
         provider_type = provider_type
         provider_id = provider_id
         if provider_type == 'builtin':
-            return (current_app.config.get("CONSOLE_API_URL")
+            return (dify_config.CONSOLE_API_URL
                     + "/console/api/workspaces/current/tool-provider/builtin/"
                     + provider_id
                     + "/icon")
@@ -594,4 +593,4 @@ class ToolManager:
         else:
             raise ValueError(f"provider type {provider_type} not found")
 
-ToolManager.load_builtin_providers_cache()
+ToolManager.load_builtin_providers_cache()

+ 21 - 22
api/core/workflow/workflow_engine_manager.py

@@ -2,8 +2,7 @@ import logging
 import time
 from typing import Optional, cast
 
-from flask import current_app
-
+from configs import dify_config
 from core.app.app_config.entities import FileExtraConfig
 from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
 from core.app.entities.app_invoke_entities import InvokeFrom
@@ -118,7 +117,7 @@ class WorkflowEngineManager:
 
         if not isinstance(graph.get('edges'), list):
             raise ValueError('edges in workflow graph must be a list')
-        
+
         # init variable pool
         if not variable_pool:
             variable_pool = VariablePool(
@@ -126,7 +125,7 @@ class WorkflowEngineManager:
                 user_inputs=user_inputs
             )
 
-        workflow_call_max_depth = current_app.config.get("WORKFLOW_CALL_MAX_DEPTH")
+        workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
         if call_depth > workflow_call_max_depth:
             raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
 
@@ -177,8 +176,8 @@ class WorkflowEngineManager:
             predecessor_node: BaseNode = None
             current_iteration_node: BaseIterationNode = None
             has_entry_node = False
-            max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
-            max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
+            max_execution_steps = dify_config.WORKFLOW_MAX_EXECUTION_STEPS
+            max_execution_time = dify_config.WORKFLOW_MAX_EXECUTION_TIME
             while True:
                 # get next node, multiple target nodes in the future
                 next_node = self._get_next_overall_node(
@@ -237,7 +236,7 @@ class WorkflowEngineManager:
                             next_node_id = next_iteration
                             # get next id
                             next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks)
-                
+
                 if not next_node:
                     break
 
@@ -398,7 +397,7 @@ class WorkflowEngineManager:
                 tenant_id=workflow.tenant_id,
                 node_instance=node_instance
             )
-            
+
             # run node
             node_run_result = node_instance.run(
                 variable_pool=variable_pool
@@ -443,7 +442,7 @@ class WorkflowEngineManager:
                     node_config = node
                 else:
                     raise ValueError('node id is not an iteration node')
-        
+
         # init variable pool
         variable_pool = VariablePool(
             system_variables={},
@@ -452,7 +451,7 @@ class WorkflowEngineManager:
 
         # variable selector to variable mapping
         iteration_nested_nodes = [
-            node for node in nodes 
+            node for node in nodes
             if node.get('data', {}).get('iteration_id') == node_id or node.get('id') == node_id
         ]
         iteration_nested_node_ids = [node.get('id') for node in iteration_nested_nodes]
@@ -475,13 +474,13 @@ class WorkflowEngineManager:
 
             # remove iteration variables
             variable_mapping = {
-                f'{node_config.get("id")}.{key}': value for key, value in variable_mapping.items() 
+                f'{node_config.get("id")}.{key}': value for key, value in variable_mapping.items()
                 if value[0] != node_id
             }
 
             # remove variable out from iteration
             variable_mapping = {
-                key: value for key, value in variable_mapping.items() 
+                key: value for key, value in variable_mapping.items()
                 if value[0] not in iteration_nested_node_ids
             }
 
@@ -561,7 +560,7 @@ class WorkflowEngineManager:
                     error=error
                 )
 
-    def _workflow_iteration_started(self, graph: dict, 
+    def _workflow_iteration_started(self, graph: dict,
                                     current_iteration_node: BaseIterationNode,
                                     workflow_run_state: WorkflowRunState,
                                     predecessor_node_id: Optional[str] = None,
@@ -600,7 +599,7 @@ class WorkflowEngineManager:
 
     def _workflow_iteration_next(self, graph: dict,
                                  current_iteration_node: BaseIterationNode,
-                                 workflow_run_state: WorkflowRunState, 
+                                 workflow_run_state: WorkflowRunState,
                                  callbacks: list[BaseWorkflowCallback] = None) -> None:
         """
         Workflow iteration next
@@ -629,9 +628,9 @@ class WorkflowEngineManager:
 
         for node in nodes:
             workflow_run_state.variable_pool.clear_node_variables(node_id=node.get('id'))
-    
+
     def _workflow_iteration_completed(self, current_iteration_node: BaseIterationNode,
-                                        workflow_run_state: WorkflowRunState, 
+                                        workflow_run_state: WorkflowRunState,
                                         callbacks: list[BaseWorkflowCallback] = None) -> None:
         if callbacks:
             if isinstance(workflow_run_state.current_iteration_state, IterationState):
@@ -684,7 +683,7 @@ class WorkflowEngineManager:
                         callbacks=callbacks,
                         workflow_call_depth=workflow_run_state.workflow_call_depth
                     )
-                
+
         else:
             edges = graph.get('edges')
             source_node_id = predecessor_node.node_id
@@ -738,9 +737,9 @@ class WorkflowEngineManager:
                 callbacks=callbacks,
                 workflow_call_depth=workflow_run_state.workflow_call_depth
             )
-        
-    def _get_node(self, workflow_run_state: WorkflowRunState, 
-                  graph: dict, 
+
+    def _get_node(self, workflow_run_state: WorkflowRunState,
+                  graph: dict,
                   node_id: str,
                   callbacks: list[BaseWorkflowCallback]) -> Optional[BaseNode]:
         """
@@ -940,7 +939,7 @@ class WorkflowEngineManager:
 
         return new_value
 
-    def _mapping_user_inputs_to_variable_pool(self, 
+    def _mapping_user_inputs_to_variable_pool(self,
                                               variable_mapping: dict,
                                               user_inputs: dict,
                                               variable_pool: VariablePool,
@@ -988,4 +987,4 @@ class WorkflowEngineManager:
                 node_id=variable_node_id,
                 variable_key_list=variable_key_list,
                 value=value
-            )
+            )

+ 2 - 2
api/models/dataset.py

@@ -9,10 +9,10 @@ import re
 import time
 from json import JSONDecodeError
 
-from flask import current_app
 from sqlalchemy import func
 from sqlalchemy.dialects.postgresql import JSONB
 
+from configs import dify_config
 from core.rag.retrieval.retrival_methods import RetrievalMethod
 from extensions.ext_database import db
 from extensions.ext_storage import storage
@@ -528,7 +528,7 @@ class DocumentSegment(db.Model):
             nonce = os.urandom(16).hex()
             timestamp = str(int(time.time()))
             data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
-            secret_key = current_app.config['SECRET_KEY'].encode()
+            secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b''
             sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
             encoded_sign = base64.urlsafe_b64encode(sign).decode()
 

+ 4 - 3
api/models/model.py

@@ -4,10 +4,11 @@ import uuid
 from enum import Enum
 from typing import Optional
 
-from flask import current_app, request
+from flask import request
 from flask_login import UserMixin
 from sqlalchemy import Float, func, text
 
+from configs import dify_config
 from core.file.tool_file_parser import ToolFileParser
 from core.file.upload_file_parser import UploadFileParser
 from extensions.ext_database import db
@@ -111,7 +112,7 @@ class App(db.Model):
 
     @property
     def api_base_url(self):
-        return (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL']
+        return (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL
                 else request.host_url.rstrip('/')) + '/v1'
 
     @property
@@ -1113,7 +1114,7 @@ class Site(db.Model):
     @property
     def app_base_url(self):
         return (
-            current_app.config['APP_WEB_URL'] if current_app.config['APP_WEB_URL'] else request.host_url.rstrip('/'))
+            dify_config.APP_WEB_URL if  dify_config.APP_WEB_URL else request.host_url.rstrip('/'))
 
 
 class ApiToken(db.Model):