Browse Source

feat/enhance the multi-modal support (#8818)

-LAN- 6 months ago
parent
commit
e61752bd3a
100 changed files with 1729 additions and 1088 deletions
  1. 3 0
      api/.env.example
  2. 11 4
      api/.vscode/launch.json.example
  3. 6 6
      api/commands.py
  4. 16 5
      api/configs/feature/__init__.py
  5. 20 0
      api/constants/__init__.py
  6. 4 2
      api/contexts/__init__.py
  7. 2 1
      api/controllers/console/app/conversation.py
  8. 1 1
      api/controllers/console/app/site.py
  9. 17 15
      api/controllers/console/app/workflow.py
  10. 2 1
      api/controllers/console/app/workflow_app_log.py
  11. 2 1
      api/controllers/console/app/workflow_run.py
  12. 1 1
      api/controllers/console/app/workflow_statistic.py
  13. 2 1
      api/controllers/console/app/wraps.py
  14. 2 1
      api/controllers/console/auth/oauth.py
  15. 1 2
      api/controllers/console/datasets/data_source.py
  16. 2 2
      api/controllers/console/datasets/datasets.py
  17. 1 2
      api/controllers/console/datasets/datasets_document.py
  18. 1 1
      api/controllers/console/datasets/datasets_segments.py
  19. 23 6
      api/controllers/console/datasets/file.py
  20. 1 1
      api/controllers/console/explore/installed_app.py
  21. 1 1
      api/controllers/console/explore/saved_message.py
  22. 1 1
      api/controllers/console/explore/wraps.py
  23. 1 1
      api/controllers/console/workspace/account.py
  24. 9 10
      api/controllers/console/workspace/tool_providers.py
  25. 1 1
      api/controllers/console/workspace/workspace.py
  26. 35 1
      api/controllers/files/image_preview.py
  27. 15 5
      api/controllers/files/tool_files.py
  28. 2 2
      api/controllers/service_api/app/message.py
  29. 19 1
      api/controllers/web/file.py
  30. 3 2
      api/controllers/web/message.py
  31. 1 1
      api/controllers/web/saved_message.py
  32. 15 33
      api/core/agent/base_agent_runner.py
  33. 6 3
      api/core/agent/cot_chat_agent_runner.py
  34. 10 4
      api/core/agent/fc_agent_runner.py
  35. 2 3
      api/core/app/app_config/easy_ui_based_app/variables/manager.py
  36. 13 9
      api/core/app/app_config/entities.py
  37. 15 37
      api/core/app/app_config/features/file_upload/manager.py
  38. 1 1
      api/core/app/app_config/workflow_ui_based_app/variables/manager.py
  39. 16 6
      api/core/app/apps/advanced_chat/app_generator.py
  40. 5 38
      api/core/app/apps/advanced_chat/app_runner.py
  41. 28 8
      api/core/app/apps/advanced_chat/generate_task_pipeline.py
  42. 23 10
      api/core/app/apps/agent_chat/app_generator.py
  43. 93 24
      api/core/app/apps/base_app_generator.py
  44. 3 3
      api/core/app/apps/base_app_runner.py
  45. 16 6
      api/core/app/apps/chat/app_generator.py
  46. 23 9
      api/core/app/apps/completion/app_generator.py
  47. 5 5
      api/core/app/apps/message_based_app_generator.py
  48. 25 29
      api/core/app/apps/workflow/app_generator.py
  49. 4 5
      api/core/app/apps/workflow/app_runner.py
  50. 1 4
      api/core/app/apps/workflow/generate_task_pipeline.py
  51. 5 7
      api/core/app/apps/workflow_app_runner.py
  52. 4 4
      api/core/app/entities/app_invoke_entities.py
  53. 3 2
      api/core/app/entities/queue_entities.py
  54. 4 2
      api/core/app/entities/task_entities.py
  55. 0 18
      api/core/app/segments/parser.py
  56. 33 30
      api/core/app/task_pipeline/workflow_cycle_manage.py
  57. 0 29
      api/core/entities/message_entities.py
  58. 19 0
      api/core/file/__init__.py
  59. 1 0
      api/core/file/constants.py
  60. 55 0
      api/core/file/enums.py
  61. 156 0
      api/core/file/file_manager.py
  62. 0 145
      api/core/file/file_obj.py
  63. 32 0
      api/core/file/file_repository.py
  64. 48 0
      api/core/file/helpers.py
  65. 0 243
      api/core/file/message_file_parser.py
  66. 140 0
      api/core/file/models.py
  67. 6 1
      api/core/file/tool_file_parser.py
  68. 0 79
      api/core/file/upload_file_parser.py
  69. 12 6
      api/core/helper/ssrf_proxy.py
  70. 13 8
      api/core/memory/token_buffer_memory.py
  71. 4 4
      api/core/model_manager.py
  72. 38 0
      api/core/model_runtime/entities/__init__.py
  73. 9 2
      api/core/model_runtime/entities/message_entities.py
  74. 12 3
      api/core/model_runtime/model_providers/__base/large_language_model.py
  75. 44 25
      api/core/model_runtime/model_providers/__base/tts_model.py
  76. 1 0
      api/core/model_runtime/model_providers/openai/llm/_position.yaml
  77. 44 0
      api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml
  78. 19 10
      api/core/model_runtime/model_providers/openai/llm/llm.py
  79. 2 2
      api/core/ops/ops_trace_manager.py
  80. 62 59
      api/core/prompt/advanced_prompt_transform.py
  81. 11 8
      api/core/prompt/simple_prompt_transform.py
  82. 3 1
      api/core/prompt/utils/extract_thread_messages.py
  83. 13 6
      api/core/prompt/utils/prompt_message_util.py
  84. 1 1
      api/core/rag/extractor/word_extractor.py
  85. 1 1
      api/core/rag/retrieval/router/multi_dataset_react_route.py
  86. 3 3
      api/core/tools/entities/api_entities.py
  87. 64 2
      api/core/tools/entities/tool_entities.py
  88. 1 1
      api/core/tools/provider/builtin/dalle/tools/dalle3.py
  89. 1 1
      api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py
  90. 24 0
      api/core/tools/provider/builtin/podcast_generator/_assets/icon.svg
  91. 33 0
      api/core/tools/provider/builtin/podcast_generator/podcast_generator.py
  92. 34 0
      api/core/tools/provider/builtin/podcast_generator/podcast_generator.yaml
  93. 100 0
      api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py
  94. 95 0
      api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.yaml
  95. 1 4
      api/core/tools/provider/builtin_tool_provider.py
  96. 1 4
      api/core/tools/provider/tool_provider.py
  97. 10 9
      api/core/tools/provider/workflow_tool_provider.py
  98. 9 10
      api/core/tools/tool/tool.py
  99. 24 26
      api/core/tools/tool/workflow_tool.py
  100. 24 16
      api/core/tools/tool_engine.py

+ 3 - 0
api/.env.example

@@ -233,6 +233,8 @@ VIKINGDB_SOCKET_TIMEOUT=30
 UPLOAD_FILE_SIZE_LIMIT=15
 UPLOAD_FILE_BATCH_LIMIT=5
 UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
+UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
+UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
 
 # Model Configuration
 MULTIMODAL_SEND_IMAGE_FORMAT=base64
@@ -310,6 +312,7 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000
 WORKFLOW_MAX_EXECUTION_STEPS=500
 WORKFLOW_MAX_EXECUTION_TIME=1200
 WORKFLOW_CALL_MAX_DEPTH=5
+MAX_VARIABLE_SIZE=204800
 
 # App configuration
 APP_MAX_EXECUTION_TIME=1200

+ 11 - 4
api/.vscode/launch.json.example

@@ -1,8 +1,15 @@
 {
     "version": "0.2.0",
+    "compounds": [
+        {
+            "name": "Launch Flask and Celery",
+            "configurations": ["Python: Flask", "Python: Celery"]
+        }
+    ],
     "configurations": [
         {
             "name": "Python: Flask",
+            "consoleName": "Flask",
             "type": "debugpy",
             "request": "launch",
             "python": "${workspaceFolder}/.venv/bin/python",
@@ -17,12 +24,12 @@
             },
             "args": [
                 "run",
-                "--host=0.0.0.0",
                 "--port=5001"
             ]
         },
         {
             "name": "Python: Celery",
+            "consoleName": "Celery",
             "type": "debugpy",
             "request": "launch",
             "python": "${workspaceFolder}/.venv/bin/python",
@@ -45,10 +52,10 @@
                 "-c",
                 "1",
                 "--loglevel",
-                "info",
+                "DEBUG",
                 "-Q",
                 "dataset,generation,mail,ops_trace,app_deletion"
             ]
-        },
+        }
     ]
-}
+}

+ 6 - 6
api/commands.py

@@ -19,7 +19,7 @@ from extensions.ext_redis import redis_client
 from libs.helper import email as email_validate
 from libs.password import hash_password, password_pattern, valid_password
 from libs.rsa import generate_key_pair
-from models.account import Tenant
+from models import Tenant
 from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
 from models.dataset import Document as DatasetDocument
 from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
@@ -426,14 +426,14 @@ def convert_to_agent_apps():
         # fetch first 1000 apps
         sql_query = """SELECT a.id AS id FROM apps a
             INNER JOIN app_model_configs am ON a.app_model_config_id=am.id
-            WHERE a.mode = 'chat' 
-            AND am.agent_mode is not null 
+            WHERE a.mode = 'chat'
+            AND am.agent_mode is not null
             AND (
-				am.agent_mode like '%"strategy": "function_call"%' 
+				am.agent_mode like '%"strategy": "function_call"%'
                 OR am.agent_mode  like '%"strategy": "react"%'
-			) 
+			)
             AND (
-				am.agent_mode like '{"enabled": true%' 
+				am.agent_mode like '{"enabled": true%'
                 OR am.agent_mode like '{"max_iteration": %'
 			) ORDER BY a.created_at DESC LIMIT 1000
         """

+ 16 - 5
api/configs/feature/__init__.py

@@ -20,11 +20,11 @@ class SecurityConfig(BaseSettings):
     Security-related configurations for the application
     """
 
-    SECRET_KEY: Optional[str] = Field(
+    SECRET_KEY: str = Field(
         description="Secret key for secure session cookie signing."
         "Make sure you are changing this key for your deployment with a strong key."
         "Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.",
-        default=None,
+        default="",
     )
 
     RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
@@ -186,6 +186,16 @@ class FileUploadConfig(BaseSettings):
         default=10,
     )
 
+    UPLOAD_VIDEO_FILE_SIZE_LIMIT: NonNegativeInt = Field(
+        description="video file size limit in Megabytes for uploading files",
+        default=100,
+    )
+
+    UPLOAD_AUDIO_FILE_SIZE_LIMIT: NonNegativeInt = Field(
+        description="audio file size limit in Megabytes for uploading files",
+        default=50,
+    )
+
     BATCH_UPLOAD_LIMIT: NonNegativeInt = Field(
         description="Maximum number of files allowed in a batch upload operation",
         default=20,
@@ -364,8 +374,8 @@ class WorkflowConfig(BaseSettings):
     )
 
     MAX_VARIABLE_SIZE: PositiveInt = Field(
-        description="Maximum size in bytes for a single variable in workflows. Default to 5KB.",
-        default=5 * 1024,
+        description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
+        default=200 * 1024,
     )
 
 
@@ -493,6 +503,7 @@ class RagEtlConfig(BaseSettings):
     Configuration for RAG ETL processes
     """
 
+    # TODO: This config is not only for rag etl, it is also for file upload, we should move it to file upload config
     ETL_TYPE: str = Field(
         description="RAG ETL type ('dify' or 'Unstructured'), default to 'dify'",
         default="dify",
@@ -559,7 +570,7 @@ class IndexingConfig(BaseSettings):
 
 
 class ImageFormatConfig(BaseSettings):
-    MULTIMODAL_SEND_IMAGE_FORMAT: str = Field(
+    MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
         description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
         default="base64",
     )

+ 20 - 0
api/constants/__init__.py

@@ -1,2 +1,22 @@
+from configs import dify_config
+
 HIDDEN_VALUE = "[__HIDDEN__]"
 UUID_NIL = "00000000-0000-0000-0000-000000000000"
+
+IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
+IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
+
+VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"]
+VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
+
+AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"]
+AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
+
+
+if dify_config.ETL_TYPE == "Unstructured":
+    DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"]
+    DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "ppt", "xml", "epub"))
+    DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
+else:
+    DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
+    DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])

+ 4 - 2
api/contexts/__init__.py

@@ -1,7 +1,9 @@
 from contextvars import ContextVar
+from typing import TYPE_CHECKING
 
-from core.workflow.entities.variable_pool import VariablePool
+if TYPE_CHECKING:
+    from core.workflow.entities.variable_pool import VariablePool
 
 tenant_id: ContextVar[str] = ContextVar("tenant_id")
 
-workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool")
+workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")

+ 2 - 1
api/controllers/console/app/conversation.py

@@ -22,7 +22,8 @@ from fields.conversation_fields import (
 )
 from libs.helper import DatetimeString
 from libs.login import login_required
-from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation
+from models import Conversation, EndUser, Message, MessageAnnotation
+from models.model import AppMode
 
 
 class CompletionConversationApi(Resource):

+ 1 - 1
api/controllers/console/app/site.py

@@ -12,7 +12,7 @@ from controllers.console.wraps import account_initialization_required
 from extensions.ext_database import db
 from fields.app_fields import app_site_fields
 from libs.login import login_required
-from models.model import Site
+from models import Site
 
 
 def parse_app_site_args():

+ 17 - 15
api/controllers/console/app/workflow.py

@@ -13,14 +13,14 @@ from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.app.segments import factory
-from core.errors.error import AppInvokeQuotaExceededError
+from factories import variable_factory
 from fields.workflow_fields import workflow_fields
 from fields.workflow_run_fields import workflow_run_node_execution_fields
 from libs import helper
 from libs.helper import TimestampField, uuid_value
 from libs.login import current_user, login_required
-from models.model import App, AppMode
+from models import App
+from models.model import AppMode
 from services.app_dsl_service import AppDslService
 from services.app_generate_service import AppGenerateService
 from services.errors.app import WorkflowHashNotEqualError
@@ -101,9 +101,13 @@ class DraftWorkflowApi(Resource):
 
         try:
             environment_variables_list = args.get("environment_variables") or []
-            environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
+            environment_variables = [
+                variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list
+            ]
             conversation_variables_list = args.get("conversation_variables") or []
-            conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
+            conversation_variables = [
+                variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list
+            ]
             workflow = workflow_service.sync_draft_workflow(
                 app_model=app_model,
                 graph=args["graph"],
@@ -273,17 +277,15 @@ class DraftWorkflowRunApi(Resource):
         parser.add_argument("files", type=list, required=False, location="json")
         args = parser.parse_args()
 
-        try:
-            response = AppGenerateService.generate(
-                app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
-            )
+        response = AppGenerateService.generate(
+            app_model=app_model,
+            user=current_user,
+            args=args,
+            invoke_from=InvokeFrom.DEBUGGER,
+            streaming=True,
+        )
 
-            return helper.compact_generate_response(response)
-        except (ValueError, AppInvokeQuotaExceededError) as e:
-            raise e
-        except Exception as e:
-            logging.exception("internal server error.")
-            raise InternalServerError()
+        return helper.compact_generate_response(response)
 
 
 class WorkflowTaskStopApi(Resource):

+ 2 - 1
api/controllers/console/app/workflow_app_log.py

@@ -7,7 +7,8 @@ from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
 from libs.login import login_required
-from models.model import App, AppMode
+from models import App
+from models.model import AppMode
 from services.workflow_app_service import WorkflowAppService
 
 

+ 2 - 1
api/controllers/console/app/workflow_run.py

@@ -13,7 +13,8 @@ from fields.workflow_run_fields import (
 )
 from libs.helper import uuid_value
 from libs.login import login_required
-from models.model import App, AppMode
+from models import App
+from models.model import AppMode
 from services.workflow_run_service import WorkflowRunService
 
 

+ 1 - 1
api/controllers/console/app/workflow_statistic.py

@@ -13,8 +13,8 @@ from controllers.console.wraps import account_initialization_required
 from extensions.ext_database import db
 from libs.helper import DatetimeString
 from libs.login import login_required
+from models.enums import WorkflowRunTriggeredFrom
 from models.model import AppMode
-from models.workflow import WorkflowRunTriggeredFrom
 
 
 class WorkflowDailyRunsStatistic(Resource):

+ 2 - 1
api/controllers/console/app/wraps.py

@@ -5,7 +5,8 @@ from typing import Optional, Union
 from controllers.console.app.error import AppNotFoundError
 from extensions.ext_database import db
 from libs.login import current_user
-from models.model import App, AppMode
+from models import App
+from models.model import AppMode
 
 
 def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):

+ 2 - 1
api/controllers/console/auth/oauth.py

@@ -13,7 +13,8 @@ from events.tenant_event import tenant_was_created
 from extensions.ext_database import db
 from libs.helper import extract_remote_ip
 from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
-from models.account import Account, AccountStatus
+from models import Account
+from models.account import AccountStatus
 from services.account_service import AccountService, RegisterService, TenantService
 from services.errors.account import AccountNotFoundError
 from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError

+ 1 - 2
api/controllers/console/datasets/data_source.py

@@ -15,8 +15,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
 from extensions.ext_database import db
 from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
 from libs.login import login_required
-from models.dataset import Document
-from models.source import DataSourceOauthBinding
+from models import DataSourceOauthBinding, Document
 from services.dataset_service import DatasetService, DocumentService
 from tasks.document_indexing_sync_task import document_indexing_sync_task
 

+ 2 - 2
api/controllers/console/datasets/datasets.py

@@ -24,8 +24,8 @@ from fields.app_fields import related_app_list
 from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
 from fields.document_fields import document_status_fields
 from libs.login import login_required
-from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
-from models.model import ApiToken, UploadFile
+from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
+from models.dataset import DatasetPermissionEnum
 from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
 
 

+ 1 - 2
api/controllers/console/datasets/datasets_document.py

@@ -46,8 +46,7 @@ from fields.document_fields import (
     document_with_segments_fields,
 )
 from libs.login import login_required
-from models.dataset import Dataset, DatasetProcessRule, Document, DocumentSegment
-from models.model import UploadFile
+from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
 from services.dataset_service import DatasetService, DocumentService
 from tasks.add_document_to_index_task import add_document_to_index_task
 from tasks.remove_document_from_index_task import remove_document_from_index_task

+ 1 - 1
api/controllers/console/datasets/datasets_segments.py

@@ -24,7 +24,7 @@ from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from fields.segment_fields import segment_fields
 from libs.login import login_required
-from models.dataset import DocumentSegment
+from models import DocumentSegment
 from services.dataset_service import DatasetService, DocumentService, SegmentService
 from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
 from tasks.disable_segment_from_index_task import disable_segment_from_index_task

+ 23 - 6
api/controllers/console/datasets/file.py

@@ -1,9 +1,12 @@
+import urllib.parse
+
 from flask import request
 from flask_login import current_user
 from flask_restful import Resource, marshal_with
 
 import services
 from configs import dify_config
+from constants import DOCUMENT_EXTENSIONS
 from controllers.console import api
 from controllers.console.datasets.error import (
     FileTooLargeError,
@@ -13,9 +16,10 @@ from controllers.console.datasets.error import (
 )
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
-from fields.file_fields import file_fields, upload_config_fields
+from core.helper import ssrf_proxy
+from fields.file_fields import file_fields, remote_file_info_fields, upload_config_fields
 from libs.login import login_required
-from services.file_service import ALLOWED_EXTENSIONS, UNSTRUCTURED_ALLOWED_EXTENSIONS, FileService
+from services.file_service import FileService
 
 PREVIEW_WORDS_LIMIT = 3000
 
@@ -51,7 +55,7 @@ class FileApi(Resource):
         if len(request.files) > 1:
             raise TooManyFilesError()
         try:
-            upload_file = FileService.upload_file(file, current_user)
+            upload_file = FileService.upload_file(file=file, user=current_user)
         except services.errors.file.FileTooLargeError as file_too_large_error:
             raise FileTooLargeError(file_too_large_error.description)
         except services.errors.file.UnsupportedFileTypeError:
@@ -75,11 +79,24 @@ class FileSupportTypeApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        etl_type = dify_config.ETL_TYPE
-        allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS
-        return {"allowed_extensions": allowed_extensions}
+        return {"allowed_extensions": DOCUMENT_EXTENSIONS}
+
+
+class RemoteFileInfoApi(Resource):
+    @marshal_with(remote_file_info_fields)
+    def get(self, url):
+        decoded_url = urllib.parse.unquote(url)
+        try:
+            response = ssrf_proxy.head(decoded_url)
+            return {
+                "file_type": response.headers.get("Content-Type", "application/octet-stream"),
+                "file_length": int(response.headers.get("Content-Length", 0)),
+            }
+        except Exception as e:
+            return {"error": str(e)}, 400
 
 
 api.add_resource(FileApi, "/files/upload")
 api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
 api.add_resource(FileSupportTypeApi, "/files/support-type")
+api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")

+ 1 - 1
api/controllers/console/explore/installed_app.py

@@ -11,7 +11,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
 from extensions.ext_database import db
 from fields.installed_app_fields import installed_app_list_fields
 from libs.login import login_required
-from models.model import App, InstalledApp, RecommendedApp
+from models import App, InstalledApp, RecommendedApp
 from services.account_service import TenantService
 
 

+ 1 - 1
api/controllers/console/explore/saved_message.py

@@ -18,7 +18,7 @@ message_fields = {
     "inputs": fields.Raw,
     "query": fields.String,
     "answer": fields.String,
-    "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
+    "message_files": fields.List(fields.Nested(message_file_fields)),
     "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
     "created_at": TimestampField,
 }

+ 1 - 1
api/controllers/console/explore/wraps.py

@@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound
 from controllers.console.wraps import account_initialization_required
 from extensions.ext_database import db
 from libs.login import login_required
-from models.model import InstalledApp
+from models import InstalledApp
 
 
 def installed_app_required(view=None):

+ 1 - 1
api/controllers/console/workspace/account.py

@@ -20,7 +20,7 @@ from extensions.ext_database import db
 from fields.member_fields import account_fields
 from libs.helper import TimestampField, timezone
 from libs.login import login_required
-from models.account import AccountIntegrate, InvitationCode
+from models import AccountIntegrate, InvitationCode
 from services.account_service import AccountService
 from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
 

+ 9 - 10
api/controllers/console/workspace/tool_providers.py

@@ -360,16 +360,15 @@ class ToolWorkflowProviderCreateApi(Resource):
         args = reqparser.parse_args()
 
         return WorkflowToolManageService.create_workflow_tool(
-            user_id,
-            tenant_id,
-            args["workflow_app_id"],
-            args["name"],
-            args["label"],
-            args["icon"],
-            args["description"],
-            args["parameters"],
-            args["privacy_policy"],
-            args.get("labels", []),
+            user_id=user_id,
+            tenant_id=tenant_id,
+            workflow_app_id=args["workflow_app_id"],
+            name=args["name"],
+            label=args["label"],
+            icon=args["icon"],
+            description=args["description"],
+            parameters=args["parameters"],
+            privacy_policy=args["privacy_policy"],
         )
 
 

+ 1 - 1
api/controllers/console/workspace/workspace.py

@@ -198,7 +198,7 @@ class WebappLogoWorkspaceApi(Resource):
             raise UnsupportedFileTypeError()
 
         try:
-            upload_file = FileService.upload_file(file, current_user, True)
+            upload_file = FileService.upload_file(file=file, user=current_user)
 
         except services.errors.file.FileTooLargeError as file_too_large_error:
             raise FileTooLargeError(file_too_large_error.description)

+ 35 - 1
api/controllers/files/image_preview.py

@@ -10,6 +10,34 @@ from services.file_service import FileService
 
 
 class ImagePreviewApi(Resource):
+    """
+    Deprecated
+    """
+
+    def get(self, file_id):
+        file_id = str(file_id)
+
+        timestamp = request.args.get("timestamp")
+        nonce = request.args.get("nonce")
+        sign = request.args.get("sign")
+
+        if not timestamp or not nonce or not sign:
+            return {"content": "Invalid request."}, 400
+
+        try:
+            generator, mimetype = FileService.get_image_preview(
+                file_id=file_id,
+                timestamp=timestamp,
+                nonce=nonce,
+                sign=sign,
+            )
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()
+
+        return Response(generator, mimetype=mimetype)
+
+
+class FilePreviewApi(Resource):
     def get(self, file_id):
         file_id = str(file_id)
 
@@ -21,7 +49,12 @@ class ImagePreviewApi(Resource):
             return {"content": "Invalid request."}, 400
 
         try:
-            generator, mimetype = FileService.get_image_preview(file_id, timestamp, nonce, sign)
+            generator, mimetype = FileService.get_signed_file_preview(
+                file_id=file_id,
+                timestamp=timestamp,
+                nonce=nonce,
+                sign=sign,
+            )
         except services.errors.file.UnsupportedFileTypeError:
             raise UnsupportedFileTypeError()
 
@@ -49,4 +82,5 @@ class WorkspaceWebappLogoApi(Resource):
 
 
 api.add_resource(ImagePreviewApi, "/files/<uuid:file_id>/image-preview")
+api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/file-preview")
 api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces/<uuid:workspace_id>/webapp-logo")

+ 15 - 5
api/controllers/files/tool_files.py

@@ -16,6 +16,7 @@ class ToolFilePreviewApi(Resource):
         parser.add_argument("timestamp", type=str, required=True, location="args")
         parser.add_argument("nonce", type=str, required=True, location="args")
         parser.add_argument("sign", type=str, required=True, location="args")
+        parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
 
         args = parser.parse_args()
 
@@ -28,18 +29,27 @@ class ToolFilePreviewApi(Resource):
             raise Forbidden("Invalid request.")
 
         try:
-            result = ToolFileManager.get_file_generator_by_tool_file_id(
+            stream, tool_file = ToolFileManager.get_file_generator_by_tool_file_id(
                 file_id,
             )
 
-            if not result:
+            if not stream or not tool_file:
                 raise NotFound("file is not found")
-
-            generator, mimetype = result
         except Exception:
             raise UnsupportedFileTypeError()
 
-        return Response(generator, mimetype=mimetype)
+        response = Response(
+            stream,
+            mimetype=tool_file.mimetype,
+            direct_passthrough=True,
+            headers={
+                "Content-Length": str(tool_file.size),
+            },
+        )
+        if args["as_attachment"]:
+            response.headers["Content-Disposition"] = f"attachment; filename={tool_file.name}"
+
+        return response
 
 
 api.add_resource(ToolFilePreviewApi, "/files/tools/<uuid:file_id>.<string:extension>")

+ 2 - 2
api/controllers/service_api/app/message.py

@@ -48,7 +48,7 @@ class MessageListApi(Resource):
         "tool_input": fields.String,
         "created_at": TimestampField,
         "observation": fields.String,
-        "message_files": fields.List(fields.String, attribute="files"),
+        "message_files": fields.List(fields.String),
     }
 
     message_fields = {
@@ -58,7 +58,7 @@ class MessageListApi(Resource):
         "inputs": fields.Raw,
         "query": fields.String,
         "answer": fields.String(attribute="re_sign_file_url_answer"),
-        "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
+        "message_files": fields.List(fields.Nested(message_file_fields)),
         "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
         "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
         "created_at": TimestampField,

+ 19 - 1
api/controllers/web/file.py

@@ -1,3 +1,5 @@
+import urllib.parse
+
 from flask import request
 from flask_restful import marshal_with
 
@@ -5,7 +7,8 @@ import services
 from controllers.web import api
 from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
 from controllers.web.wraps import WebApiResource
-from fields.file_fields import file_fields
+from core.helper import ssrf_proxy
+from fields.file_fields import file_fields, remote_file_info_fields
 from services.file_service import FileService
 
 
@@ -31,4 +34,19 @@ class FileApi(WebApiResource):
         return upload_file, 201
 
 
+class RemoteFileInfoApi(WebApiResource):
+    @marshal_with(remote_file_info_fields)
+    def get(self, url):
+        decoded_url = urllib.parse.unquote(url)
+        try:
+            response = ssrf_proxy.head(decoded_url)
+            return {
+                "file_type": response.headers.get("Content-Type", "application/octet-stream"),
+                "file_length": int(response.headers.get("Content-Length", 0)),
+            }
+        except Exception as e:
+            return {"error": str(e)}, 400
+
+
 api.add_resource(FileApi, "/files/upload")
+api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")

+ 3 - 2
api/controllers/web/message.py

@@ -22,6 +22,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
 from core.model_runtime.errors.invoke import InvokeError
 from fields.conversation_fields import message_file_fields
 from fields.message_fields import agent_thought_fields
+from fields.raws import FilesContainedField
 from libs import helper
 from libs.helper import TimestampField, uuid_value
 from models.model import AppMode
@@ -58,10 +59,10 @@ class MessageListApi(WebApiResource):
         "id": fields.String,
         "conversation_id": fields.String,
         "parent_message_id": fields.String,
-        "inputs": fields.Raw,
+        "inputs": FilesContainedField,
         "query": fields.String,
         "answer": fields.String(attribute="re_sign_file_url_answer"),
-        "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
+        "message_files": fields.List(fields.Nested(message_file_fields)),
         "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
         "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
         "created_at": TimestampField,

+ 1 - 1
api/controllers/web/saved_message.py

@@ -17,7 +17,7 @@ message_fields = {
     "inputs": fields.Raw,
     "query": fields.String,
     "answer": fields.String,
-    "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
+    "message_files": fields.List(fields.Nested(message_file_fields)),
     "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
     "created_at": TimestampField,
 }

+ 15 - 33
api/core/agent/base_agent_runner.py

@@ -16,13 +16,14 @@ from core.app.entities.app_invoke_entities import (
 )
 from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
-from core.file.message_file_parser import MessageFileParser
+from core.file import file_manager
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
-from core.model_runtime.entities.llm_entities import LLMUsage
-from core.model_runtime.entities.message_entities import (
+from core.model_runtime.entities import (
     AssistantPromptMessage,
+    LLMUsage,
     PromptMessage,
+    PromptMessageContent,
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
@@ -40,9 +41,9 @@ from core.tools.entities.tool_entities import (
 from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
 from core.tools.tool.tool import Tool
 from core.tools.tool_manager import ToolManager
-from core.tools.utils.tool_parameter_converter import ToolParameterConverter
 from extensions.ext_database import db
-from models.model import Conversation, Message, MessageAgentThought
+from factories import file_factory
+from models.model import Conversation, Message, MessageAgentThought, MessageFile
 from models.tools import ToolConversationVariables
 
 logger = logging.getLogger(__name__)
@@ -66,23 +67,6 @@ class BaseAgentRunner(AppRunner):
         db_variables: Optional[ToolConversationVariables] = None,
         model_instance: ModelInstance = None,
     ) -> None:
-        """
-        Agent runner
-        :param tenant_id: tenant id
-        :param application_generate_entity: application generate entity
-        :param conversation: conversation
-        :param app_config: app generate entity
-        :param model_config: model config
-        :param config: dataset config
-        :param queue_manager: queue manager
-        :param message: message
-        :param user_id: user id
-        :param memory: memory
-        :param prompt_messages: prompt messages
-        :param variables_pool: variables pool
-        :param db_variables: db variables
-        :param model_instance: model instance
-        """
         self.tenant_id = tenant_id
         self.application_generate_entity = application_generate_entity
         self.conversation = conversation
@@ -180,7 +164,7 @@ class BaseAgentRunner(AppRunner):
             if parameter.form != ToolParameter.ToolParameterForm.LLM:
                 continue
 
-            parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
+            parameter_type = parameter.type.as_normal_type()
             enum = []
             if parameter.type == ToolParameter.ToolParameterType.SELECT:
                 enum = [option.value for option in parameter.options]
@@ -265,7 +249,7 @@ class BaseAgentRunner(AppRunner):
             if parameter.form != ToolParameter.ToolParameterForm.LLM:
                 continue
 
-            parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
+            parameter_type = parameter.type.as_normal_type()
             enum = []
             if parameter.type == ToolParameter.ToolParameterType.SELECT:
                 enum = [option.value for option in parameter.options]
@@ -511,26 +495,24 @@ class BaseAgentRunner(AppRunner):
         return result
 
     def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
-        message_file_parser = MessageFileParser(
-            tenant_id=self.tenant_id,
-            app_id=self.app_config.app_id,
-        )
-
-        files = message.message_files
+        files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
         if files:
             file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
 
             if file_extra_config:
-                file_objs = message_file_parser.transform_message_files(files, file_extra_config)
+                file_objs = file_factory.build_from_message_files(
+                    message_files=files, tenant_id=self.tenant_id, config=file_extra_config
+                )
             else:
                 file_objs = []
 
             if not file_objs:
                 return UserPromptMessage(content=message.query)
             else:
-                prompt_message_contents = [TextPromptMessageContent(data=message.query)]
+                prompt_message_contents: list[PromptMessageContent] = []
+                prompt_message_contents.append(TextPromptMessageContent(data=message.query))
                 for file_obj in file_objs:
-                    prompt_message_contents.append(file_obj.prompt_message_content)
+                    prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
 
                 return UserPromptMessage(content=prompt_message_contents)
         else:

+ 6 - 3
api/core/agent/cot_chat_agent_runner.py

@@ -1,9 +1,11 @@
 import json
 
 from core.agent.cot_agent_runner import CotAgentRunner
-from core.model_runtime.entities.message_entities import (
+from core.file import file_manager
+from core.model_runtime.entities import (
     AssistantPromptMessage,
     PromptMessage,
+    PromptMessageContent,
     SystemPromptMessage,
     TextPromptMessageContent,
     UserPromptMessage,
@@ -32,9 +34,10 @@ class CotChatAgentRunner(CotAgentRunner):
         Organize user query
         """
         if self.files:
-            prompt_message_contents = [TextPromptMessageContent(data=query)]
+            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents.append(TextPromptMessageContent(data=query))
             for file_obj in self.files:
-                prompt_message_contents.append(file_obj.prompt_message_content)
+                prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
 
             prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
         else:

+ 10 - 4
api/core/agent/fc_agent_runner.py

@@ -7,10 +7,15 @@ from typing import Any, Optional, Union
 from core.agent.base_agent_runner import BaseAgentRunner
 from core.app.apps.base_app_queue_manager import PublishFrom
 from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
-from core.model_runtime.entities.message_entities import (
+from core.file import file_manager
+from core.model_runtime.entities import (
     AssistantPromptMessage,
+    LLMResult,
+    LLMResultChunk,
+    LLMResultChunkDelta,
+    LLMUsage,
     PromptMessage,
+    PromptMessageContent,
     PromptMessageContentType,
     SystemPromptMessage,
     TextPromptMessageContent,
@@ -390,9 +395,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         Organize user query
         """
         if self.files:
-            prompt_message_contents = [TextPromptMessageContent(data=query)]
+            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents.append(TextPromptMessageContent(data=query))
             for file_obj in self.files:
-                prompt_message_contents.append(file_obj.prompt_message_content)
+                prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
 
             prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
         else:

+ 2 - 3
api/core/app/app_config/easy_ui_based_app/variables/manager.py

@@ -53,12 +53,11 @@ class BasicVariablesConfigManager:
                     VariableEntity(
                         type=variable_type,
                         variable=variable.get("variable"),
-                        description=variable.get("description"),
+                        description=variable.get("description", ""),
                         label=variable.get("label"),
                         required=variable.get("required", False),
                         max_length=variable.get("max_length"),
-                        options=variable.get("options"),
-                        default=variable.get("default"),
+                        options=variable.get("options", []),
                     )
                 )
 

+ 13 - 9
api/core/app/app_config/entities.py

@@ -1,11 +1,12 @@
+from collections.abc import Sequence
 from enum import Enum
 from typing import Any, Optional
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
-from core.file.file_obj import FileExtraConfig
+from core.file import FileExtraConfig, FileTransferMethod, FileType
 from core.model_runtime.entities.message_entities import PromptMessageRole
-from models import AppMode
+from models.model import AppMode
 
 
 class ModelConfigEntity(BaseModel):
@@ -69,7 +70,7 @@ class PromptTemplateEntity(BaseModel):
         ADVANCED = "advanced"
 
         @classmethod
-        def value_of(cls, value: str) -> "PromptType":
+        def value_of(cls, value: str):
             """
             Get value of given mode.
 
@@ -93,6 +94,8 @@ class VariableEntityType(str, Enum):
     PARAGRAPH = "paragraph"
     NUMBER = "number"
     EXTERNAL_DATA_TOOL = "external_data_tool"
+    FILE = "file"
+    FILE_LIST = "file-list"
 
 
 class VariableEntity(BaseModel):
@@ -102,13 +105,14 @@ class VariableEntity(BaseModel):
 
     variable: str
     label: str
-    description: Optional[str] = None
+    description: str = ""
     type: VariableEntityType
     required: bool = False
     max_length: Optional[int] = None
-    options: Optional[list[str]] = None
-    default: Optional[str] = None
-    hint: Optional[str] = None
+    options: Sequence[str] = Field(default_factory=list)
+    allowed_file_types: Sequence[FileType] = Field(default_factory=list)
+    allowed_file_extensions: Sequence[str] = Field(default_factory=list)
+    allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
 
 
 class ExternalDataVariableEntity(BaseModel):
@@ -136,7 +140,7 @@ class DatasetRetrieveConfigEntity(BaseModel):
         MULTIPLE = "multiple"
 
         @classmethod
-        def value_of(cls, value: str) -> "RetrieveStrategy":
+        def value_of(cls, value: str):
             """
             Get value of given mode.
 

+ 15 - 37
api/core/app/app_config/features/file_upload/manager.py

@@ -1,12 +1,13 @@
 from collections.abc import Mapping
-from typing import Any, Optional
+from typing import Any
 
-from core.file.file_obj import FileExtraConfig
+from core.file.models import FileExtraConfig
+from models import FileUploadConfig
 
 
 class FileUploadConfigManager:
     @classmethod
-    def convert(cls, config: Mapping[str, Any], is_vision: bool = True) -> Optional[FileExtraConfig]:
+    def convert(cls, config: Mapping[str, Any], is_vision: bool = True):
         """
         Convert model config to model config
 
@@ -15,19 +16,18 @@ class FileUploadConfigManager:
         """
         file_upload_dict = config.get("file_upload")
         if file_upload_dict:
-            if file_upload_dict.get("image"):
-                if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]:
-                    image_config = {
-                        "number_limits": file_upload_dict["image"]["number_limits"],
-                        "transfer_methods": file_upload_dict["image"]["transfer_methods"],
+            if file_upload_dict.get("enabled"):
+                data = {
+                    "image_config": {
+                        "number_limits": file_upload_dict["number_limits"],
+                        "transfer_methods": file_upload_dict["allowed_file_upload_methods"],
                     }
+                }
 
-                    if is_vision:
-                        image_config["detail"] = file_upload_dict["image"]["detail"]
+                if is_vision:
+                    data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low")
 
-                    return FileExtraConfig(image_config=image_config)
-
-        return None
+                return FileExtraConfig.model_validate(data)
 
     @classmethod
     def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]:
@@ -39,29 +39,7 @@ class FileUploadConfigManager:
         """
         if not config.get("file_upload"):
             config["file_upload"] = {}
-
-        if not isinstance(config["file_upload"], dict):
-            raise ValueError("file_upload must be of dict type")
-
-        # check image config
-        if not config["file_upload"].get("image"):
-            config["file_upload"]["image"] = {"enabled": False}
-
-        if config["file_upload"]["image"]["enabled"]:
-            number_limits = config["file_upload"]["image"]["number_limits"]
-            if number_limits < 1 or number_limits > 6:
-                raise ValueError("number_limits must be in [1, 6]")
-
-            if is_vision:
-                detail = config["file_upload"]["image"]["detail"]
-                if detail not in {"high", "low"}:
-                    raise ValueError("detail must be in ['high', 'low']")
-
-            transfer_methods = config["file_upload"]["image"]["transfer_methods"]
-            if not isinstance(transfer_methods, list):
-                raise ValueError("transfer_methods must be of list type")
-            for method in transfer_methods:
-                if method not in {"remote_url", "local_file"}:
-                    raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
+        else:
+            FileUploadConfig.model_validate(config["file_upload"])
 
         return config, ["file_upload"]

+ 1 - 1
api/core/app/app_config/workflow_ui_based_app/variables/manager.py

@@ -17,6 +17,6 @@ class WorkflowVariablesConfigManager:
 
         # variables
         for variable in user_input_form:
-            variables.append(VariableEntity(**variable))
+            variables.append(VariableEntity.model_validate(variable))
 
         return variables

+ 16 - 6
api/core/app/apps/advanced_chat/app_generator.py

@@ -21,11 +21,12 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
 from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
-from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
+from factories import file_factory
 from models.account import Account
+from models.enums import CreatedByRole
 from models.model import App, Conversation, EndUser, Message
 from models.workflow import Workflow
 
@@ -96,10 +97,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
 
         # parse files
         files = args["files"] if args.get("files") else []
-        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
+        role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
         if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
+            file_objs = file_factory.build_from_mappings(
+                mappings=files,
+                tenant_id=app_model.tenant_id,
+                user_id=user.id,
+                role=role,
+                config=file_extra_config,
+            )
         else:
             file_objs = []
 
@@ -107,8 +114,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
 
         # get tracing instance
-        user_id = user.id if isinstance(user, Account) else user.session_id
-        trace_manager = TraceQueueManager(app_model.id, user_id)
+        trace_manager = TraceQueueManager(
+            app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
+        )
 
         if invoke_from == InvokeFrom.DEBUGGER:
             # always enable retriever resource in debugger mode
@@ -120,7 +128,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             task_id=str(uuid.uuid4()),
             app_config=app_config,
             conversation_id=conversation.id if conversation else None,
-            inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
+            inputs=conversation.inputs
+            if conversation
+            else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
             query=query,
             files=file_objs,
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

+ 5 - 38
api/core/app/apps/advanced_chat/app_runner.py

@@ -1,31 +1,27 @@
 import logging
-import os
 from collections.abc import Mapping
 from typing import Any, cast
 
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
+from configs import dify_config
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
-from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
-from core.app.entities.app_invoke_entities import (
-    AdvancedChatAppGenerateEntity,
-    InvokeFrom,
-)
+from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
 from core.app.entities.queue_entities import (
     QueueAnnotationReplyEvent,
     QueueStopEvent,
     QueueTextChunkEvent,
 )
 from core.moderation.base import ModerationError
-from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
-from core.workflow.entities.node_entities import UserFrom
+from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
+from models.enums import UserFrom
 from models.model import App, Conversation, EndUser, Message
 from models.workflow import ConversationVariable, WorkflowType
 
@@ -44,12 +40,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         conversation: Conversation,
         message: Message,
     ) -> None:
-        """
-        :param application_generate_entity: application generate entity
-        :param queue_manager: application queue manager
-        :param conversation: conversation
-        :param message: message
-        """
         super().__init__(queue_manager)
 
         self.application_generate_entity = application_generate_entity
@@ -57,10 +47,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         self.message = message
 
     def run(self) -> None:
-        """
-        Run application
-        :return:
-        """
         app_config = self.application_generate_entity.app_config
         app_config = cast(AdvancedChatAppConfig, app_config)
 
@@ -81,7 +67,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             user_id = self.application_generate_entity.user_id
 
         workflow_callbacks: list[WorkflowCallback] = []
-        if bool(os.environ.get("DEBUG", "False").lower() == "true"):
+        if dify_config.DEBUG:
             workflow_callbacks.append(WorkflowLoggingCallback())
 
         if self.application_generate_entity.single_iteration_run:
@@ -201,15 +187,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         query: str,
         message_id: str,
     ) -> bool:
-        """
-        Handle input moderation
-        :param app_record: app record
-        :param app_generate_entity: application generate entity
-        :param inputs: inputs
-        :param query: query
-        :param message_id: message id
-        :return:
-        """
         try:
             # process sensitive_word_avoidance
             _, inputs, query = self.moderation_for_inputs(
@@ -229,14 +206,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
     def handle_annotation_reply(
         self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
     ) -> bool:
-        """
-        Handle annotation reply
-        :param app_record: app record
-        :param message: message
-        :param query: query
-        :param app_generate_entity: application generate entity
-        """
-        # annotation reply
         annotation_reply = self.query_app_annotations_to_reply(
             app_record=app_record,
             message=message,
@@ -258,8 +227,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
     def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
         """
         Direct output
-        :param text: text
-        :return:
         """
         self._publish_event(QueueTextChunkEvent(text=text))
 

+ 28 - 8
api/core/app/apps/advanced_chat/generate_task_pipeline.py

@@ -1,7 +1,7 @@
 import json
 import logging
 import time
-from collections.abc import Generator
+from collections.abc import Generator, Mapping
 from typing import Any, Optional, Union
 
 from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@@ -9,6 +9,7 @@ from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGenerator
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.app_invoke_entities import (
     AdvancedChatAppGenerateEntity,
+    InvokeFrom,
 )
 from core.app.entities.queue_entities import (
     QueueAdvancedChatMessageEndEvent,
@@ -50,10 +51,12 @@ from core.model_runtime.utils.encoders import jsonable_encoder
 from core.ops.ops_trace_manager import TraceQueueManager
 from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.nodes import NodeType
 from events.message_event import message_was_created
 from extensions.ext_database import db
+from models import Conversation, EndUser, Message, MessageFile
 from models.account import Account
-from models.model import Conversation, EndUser, Message
+from models.enums import CreatedByRole
 from models.workflow import (
     Workflow,
     WorkflowNodeExecution,
@@ -120,6 +123,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         self._wip_workflow_node_executions = {}
 
         self._conversation_name_generate_thread = None
+        self._recorded_files: list[Mapping[str, Any]] = []
 
     def process(self):
         """
@@ -298,6 +302,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             elif isinstance(event, QueueNodeSucceededEvent):
                 workflow_node_execution = self._handle_workflow_node_execution_success(event)
 
+                # Record files if it's an answer node or end node
+                if event.node_type in [NodeType.ANSWER, NodeType.END]:
+                    self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
+
                 response = self._workflow_node_finish_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
@@ -364,7 +372,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     start_at=graph_runtime_state.start_at,
                     total_tokens=graph_runtime_state.total_tokens,
                     total_steps=graph_runtime_state.node_run_steps,
-                    outputs=json.dumps(event.outputs) if event.outputs else None,
+                    outputs=event.outputs,
                     conversation_id=self._conversation.id,
                     trace_manager=trace_manager,
                 )
@@ -490,10 +498,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             self._conversation_name_generate_thread.join()
 
     def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
-        """
-        Save message.
-        :return:
-        """
         self._refetch_message()
 
         self._message.answer = self._task_state.answer
@@ -501,6 +505,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         self._message.message_metadata = (
             json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
         )
+        message_files = [
+            MessageFile(
+                message_id=self._message.id,
+                type=file["type"],
+                transfer_method=file["transfer_method"],
+                url=file["remote_url"],
+                belongs_to="assistant",
+                upload_file_id=file["related_id"],
+                created_by_role=CreatedByRole.ACCOUNT
+                if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
+                else CreatedByRole.END_USER,
+                created_by=self._message.from_account_id or self._message.from_end_user_id or "",
+            )
+            for file in self._recorded_files
+        ]
+        db.session.add_all(message_files)
 
         if graph_runtime_state and graph_runtime_state.llm_usage:
             usage = graph_runtime_state.llm_usage
@@ -540,7 +560,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 del extras["metadata"]["annotation_reply"]
 
         return MessageEndStreamResponse(
-            task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
+            task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras
         )
 
     def _handle_output_moderation_chunk(self, text: str) -> bool:

+ 23 - 10
api/core/app/apps/agent_chat/app_generator.py

@@ -18,12 +18,12 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
-from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
-from models.account import Account
-from models.model import App, EndUser
+from factories import file_factory
+from models import Account, App, EndUser
+from models.enums import CreatedByRole
 
 logger = logging.getLogger(__name__)
 
@@ -50,7 +50,12 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
     ) -> dict: ...
 
     def generate(
-        self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
+        self,
+        app_model: App,
+        user: Union[Account, EndUser],
+        args: Any,
+        invoke_from: InvokeFrom,
+        stream: bool = True,
     ) -> Union[dict, Generator[dict, None, None]]:
         """
         Generate App response.
@@ -98,12 +103,19 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             # always enable retriever resource in debugger mode
             override_model_config_dict["retriever_resource"] = {"enabled": True}
 
+        role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
+
         # parse files
-        files = args["files"] if args.get("files") else []
-        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
+        files = args.get("files") or []
         file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
         if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
+            file_objs = file_factory.build_from_mappings(
+                mappings=files,
+                tenant_id=app_model.tenant_id,
+                user_id=user.id,
+                role=role,
+                config=file_extra_config,
+            )
         else:
             file_objs = []
 
@@ -116,8 +128,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
         )
 
         # get tracing instance
-        user_id = user.id if isinstance(user, Account) else user.session_id
-        trace_manager = TraceQueueManager(app_model.id, user_id)
+        trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)
 
         # init application generate entity
         application_generate_entity = AgentChatAppGenerateEntity(
@@ -125,7 +136,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             app_config=app_config,
             model_conf=ModelConfigConverter.convert(app_config),
             conversation_id=conversation.id if conversation else None,
-            inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
+            inputs=conversation.inputs
+            if conversation
+            else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
             query=query,
             files=file_objs,
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

+ 93 - 24
api/core/app/apps/base_app_generator.py

@@ -1,35 +1,92 @@
 from collections.abc import Mapping
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
-from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType
+from core.app.app_config.entities import VariableEntityType
+from core.file import File, FileExtraConfig
+from factories import file_factory
+
+if TYPE_CHECKING:
+    from core.app.app_config.entities import AppConfig, VariableEntity
+    from models.enums import CreatedByRole
 
 
 class BaseAppGenerator:
-    def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]:
+    def _prepare_user_inputs(
+        self,
+        *,
+        user_inputs: Optional[Mapping[str, Any]],
+        app_config: "AppConfig",
+        user_id: str,
+        role: "CreatedByRole",
+    ) -> Mapping[str, Any]:
         user_inputs = user_inputs or {}
         # Filter input variables from form configuration, handle required fields, default values, and option values
         variables = app_config.variables
-        filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
-        filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
-        return filtered_inputs
+        user_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
+        user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
+        # Convert files in inputs to File
+        entity_dictionary = {item.variable: item for item in app_config.variables}
+        # Convert single file to File
+        files_inputs = {
+            k: file_factory.build_from_mapping(
+                mapping=v,
+                tenant_id=app_config.tenant_id,
+                user_id=user_id,
+                role=role,
+                config=FileExtraConfig(
+                    allowed_file_types=entity_dictionary[k].allowed_file_types,
+                    allowed_extensions=entity_dictionary[k].allowed_file_extensions,
+                    allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
+                ),
+            )
+            for k, v in user_inputs.items()
+            if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE
+        }
+        # Convert list of files to File
+        file_list_inputs = {
+            k: file_factory.build_from_mappings(
+                mappings=v,
+                tenant_id=app_config.tenant_id,
+                user_id=user_id,
+                role=role,
+                config=FileExtraConfig(
+                    allowed_file_types=entity_dictionary[k].allowed_file_types,
+                    allowed_extensions=entity_dictionary[k].allowed_file_extensions,
+                    allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
+                ),
+            )
+            for k, v in user_inputs.items()
+            if isinstance(v, list)
+            # skip lists that already contain File objects
+            and all(isinstance(item, dict) for item in v)
+            and entity_dictionary[k].type == VariableEntityType.FILE_LIST
+        }
+        # Merge all inputs
+        user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
 
-    def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
-        user_input_value = inputs.get(var.variable)
-        if var.required and not user_input_value:
-            raise ValueError(f"{var.variable} is required in input form")
-        if not var.required and not user_input_value:
-            # TODO: should we return None here if the default value is None?
-            return var.default or ""
-        if (
-            var.type
-            in {
-                VariableEntityType.TEXT_INPUT,
-                VariableEntityType.SELECT,
-                VariableEntityType.PARAGRAPH,
-            }
-            and user_input_value
-            and not isinstance(user_input_value, str)
+        # Check if all files are converted to File
+        if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
+            raise ValueError("Invalid input type")
+        if any(
+            isinstance(item, dict) for v in user_inputs.values() if isinstance(v, list) for item in v
         ):
+            raise ValueError("Invalid input type")
+
+        return user_inputs
+
+    def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"):
+        user_input_value = inputs.get(var.variable)
+        if user_input_value is None or user_input_value == "":
+            if var.required:
+                raise ValueError(f"{var.variable} is required in input form")
+            else:
+                return None
+
+        if var.type in {
+            VariableEntityType.TEXT_INPUT,
+            VariableEntityType.SELECT,
+            VariableEntityType.PARAGRAPH,
+        } and not isinstance(user_input_value, str):
             raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
         if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
             # may raise ValueError if user_input_value is not a valid number
@@ -41,12 +98,24 @@ class BaseAppGenerator:
             except ValueError:
                 raise ValueError(f"{var.variable} in input form must be a valid number")
         if var.type == VariableEntityType.SELECT:
-            options = var.options or []
+            options = var.options
             if user_input_value not in options:
                 raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
         elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
-            if var.max_length and user_input_value and len(user_input_value) > var.max_length:
+            if var.max_length and len(user_input_value) > var.max_length:
                 raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
+        elif var.type == VariableEntityType.FILE:
+            if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File):
+                raise ValueError(f"{var.variable} in input form must be a file")
+        elif var.type == VariableEntityType.FILE_LIST:
+            if not (
+                isinstance(user_input_value, list)
+                and (
+                    all(isinstance(item, dict) for item in user_input_value)
+                    or all(isinstance(item, File) for item in user_input_value)
+                )
+            ):
+                raise ValueError(f"{var.variable} in input form must be a list of files")
 
         return user_input_value
 

+ 3 - 3
api/core/app/apps/base_app_runner.py

@@ -27,7 +27,7 @@ from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
 from models.model import App, AppMode, Message, MessageAnnotation
 
 if TYPE_CHECKING:
-    from core.file.file_obj import FileVar
+    from core.file.models import File
 
 
 class AppRunner:
@@ -37,7 +37,7 @@ class AppRunner:
         model_config: ModelConfigWithCredentialsEntity,
         prompt_template_entity: PromptTemplateEntity,
         inputs: dict[str, str],
-        files: list["FileVar"],
+        files: list["File"],
         query: Optional[str] = None,
     ) -> int:
         """
@@ -137,7 +137,7 @@ class AppRunner:
         model_config: ModelConfigWithCredentialsEntity,
         prompt_template_entity: PromptTemplateEntity,
         inputs: dict[str, str],
-        files: list["FileVar"],
+        files: list["File"],
         query: Optional[str] = None,
         context: Optional[str] = None,
         memory: Optional[TokenBufferMemory] = None,

+ 16 - 6
api/core/app/apps/chat/app_generator.py

@@ -18,11 +18,12 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
-from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
+from factories import file_factory
 from models.account import Account
+from models.enums import CreatedByRole
 from models.model import App, EndUser
 
 logger = logging.getLogger(__name__)
@@ -100,12 +101,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             # always enable retriever resource in debugger mode
             override_model_config_dict["retriever_resource"] = {"enabled": True}
 
+        role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
+
         # parse files
         files = args["files"] if args.get("files") else []
-        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
         if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
+            file_objs = file_factory.build_from_mappings(
+                mappings=files,
+                tenant_id=app_model.tenant_id,
+                user_id=user.id,
+                role=role,
+                config=file_extra_config,
+            )
         else:
             file_objs = []
 
@@ -118,7 +126,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
         )
 
         # get tracing instance
-        trace_manager = TraceQueueManager(app_model.id)
+        trace_manager = TraceQueueManager(app_id=app_model.id)
 
         # init application generate entity
         application_generate_entity = ChatAppGenerateEntity(
@@ -126,15 +134,17 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             app_config=app_config,
             model_conf=ModelConfigConverter.convert(app_config),
             conversation_id=conversation.id if conversation else None,
-            inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
+            inputs=conversation.inputs
+            if conversation
+            else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
             query=query,
             files=file_objs,
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
             user_id=user.id,
-            stream=stream,
             invoke_from=invoke_from,
             extras=extras,
             trace_manager=trace_manager,
+            stream=stream,
         )
 
         # init generate records

+ 23 - 9
api/core/app/apps/completion/app_generator.py

@@ -17,12 +17,12 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
-from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
-from models.account import Account
-from models.model import App, EndUser, Message
+from factories import file_factory
+from models import Account, App, EndUser, Message
+from models.enums import CreatedByRole
 from services.errors.app import MoreLikeThisDisabledError
 from services.errors.message import MessageNotExistsError
 
@@ -88,12 +88,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
                 tenant_id=app_model.tenant_id, config=args.get("model_config")
             )
 
+        role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
+
         # parse files
         files = args["files"] if args.get("files") else []
-        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
         if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
+            file_objs = file_factory.build_from_mappings(
+                mappings=files,
+                tenant_id=app_model.tenant_id,
+                user_id=user.id,
+                role=role,
+                config=file_extra_config,
+            )
         else:
             file_objs = []
 
@@ -103,6 +110,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         )
 
         # get tracing instance
+        user_id = user.id if isinstance(user, Account) else user.session_id
         trace_manager = TraceQueueManager(app_model.id)
 
         # init application generate entity
@@ -110,7 +118,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             task_id=str(uuid.uuid4()),
             app_config=app_config,
             model_conf=ModelConfigConverter.convert(app_config),
-            inputs=self._get_cleaned_inputs(inputs, app_config),
+            inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
             query=query,
             files=file_objs,
             user_id=user.id,
@@ -251,10 +259,16 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         override_model_config_dict["model"] = model_dict
 
         # parse files
-        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
-        file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
+        role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
+        file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
         if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user)
+            file_objs = file_factory.build_from_mappings(
+                mappings=message.message_files,
+                tenant_id=app_model.tenant_id,
+                user_id=user.id,
+                role=role,
+                config=file_extra_config,
+            )
         else:
             file_objs = []
 

+ 5 - 5
api/core/app/apps/message_based_app_generator.py

@@ -26,7 +26,7 @@ from core.app.entities.task_entities import (
 from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from extensions.ext_database import db
-from models.account import Account
+from models import Account
 from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
 from services.errors.app_model_config import AppModelConfigBrokenError
 from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
@@ -235,13 +235,13 @@ class MessageBasedAppGenerator(BaseAppGenerator):
         for file in application_generate_entity.files:
             message_file = MessageFile(
                 message_id=message.id,
-                type=file.type.value,
-                transfer_method=file.transfer_method.value,
+                type=file.type,
+                transfer_method=file.transfer_method,
                 belongs_to="user",
-                url=file.url,
+                url=file.remote_url,
                 upload_file_id=file.related_id,
                 created_by_role=("account" if account_id else "end_user"),
-                created_by=account_id or end_user_id,
+                created_by=account_id or end_user_id or "",
             )
             db.session.add(message_file)
             db.session.commit()

+ 25 - 29
api/core/app/apps/workflow/app_generator.py

@@ -3,7 +3,7 @@ import logging
 import os
 import threading
 import uuid
-from collections.abc import Generator
+from collections.abc import Generator, Mapping, Sequence
 from typing import Any, Literal, Optional, Union, overload
 
 from flask import Flask, current_app
@@ -20,13 +20,12 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
 from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
 from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
 from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
-from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
-from models.account import Account
-from models.model import App, EndUser
-from models.workflow import Workflow
+from factories import file_factory
+from models import Account, App, EndUser, Workflow
+from models.enums import CreatedByRole
 
 logger = logging.getLogger(__name__)
 
@@ -63,49 +62,46 @@ class WorkflowAppGenerator(BaseAppGenerator):
         app_model: App,
         workflow: Workflow,
         user: Union[Account, EndUser],
-        args: dict,
+        args: Mapping[str, Any],
         invoke_from: InvokeFrom,
         stream: bool = True,
         call_depth: int = 0,
         workflow_thread_pool_id: Optional[str] = None,
     ):
-        """
-        Generate App response.
+        files: Sequence[Mapping[str, Any]] = args.get("files") or []
 
-        :param app_model: App
-        :param workflow: Workflow
-        :param user: account or end user
-        :param args: request args
-        :param invoke_from: invoke from source
-        :param stream: is stream
-        :param call_depth: call depth
-        :param workflow_thread_pool_id: workflow thread pool id
-        """
-        inputs = args["inputs"]
+        role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
 
         # parse files
-        files = args["files"] if args.get("files") else []
-        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
-        if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
-        else:
-            file_objs = []
+        system_files = file_factory.build_from_mappings(
+            mappings=files,
+            tenant_id=app_model.tenant_id,
+            user_id=user.id,
+            role=role,
+            config=file_extra_config,
+        )
 
         # convert to app config
-        app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
+        app_config = WorkflowAppConfigManager.get_app_config(
+            app_model=app_model,
+            workflow=workflow,
+        )
 
         # get tracing instance
-        user_id = user.id if isinstance(user, Account) else user.session_id
-        trace_manager = TraceQueueManager(app_model.id, user_id)
+        trace_manager = TraceQueueManager(
+            app_id=app_model.id,
+            user_id=user.id if isinstance(user, Account) else user.session_id,
+        )
 
+        inputs: Mapping[str, Any] = args["inputs"]
         workflow_run_id = str(uuid.uuid4())
         # init application generate entity
         application_generate_entity = WorkflowAppGenerateEntity(
             task_id=str(uuid.uuid4()),
             app_config=app_config,
-            inputs=self._get_cleaned_inputs(inputs, app_config),
-            files=file_objs,
+            inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
+            files=system_files,
             user_id=user.id,
             stream=stream,
             invoke_from=invoke_from,

+ 4 - 5
api/core/app/apps/workflow/app_runner.py

@@ -1,21 +1,20 @@
 import logging
-import os
 from typing import Optional, cast
 
+from configs import dify_config
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
 from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
-from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
 from core.app.entities.app_invoke_entities import (
     InvokeFrom,
     WorkflowAppGenerateEntity,
 )
-from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
-from core.workflow.entities.node_entities import UserFrom
+from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
+from models.enums import UserFrom
 from models.model import App, EndUser
 from models.workflow import WorkflowType
 
@@ -71,7 +70,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
         db.session.close()
 
         workflow_callbacks: list[WorkflowCallback] = []
-        if bool(os.environ.get("DEBUG", "False").lower() == "true"):
+        if dify_config.DEBUG:
             workflow_callbacks.append(WorkflowLoggingCallback())
 
         # if only single iteration run is requested

+ 1 - 4
api/core/app/apps/workflow/generate_task_pipeline.py

@@ -1,4 +1,3 @@
-import json
 import logging
 import time
 from collections.abc import Generator
@@ -334,9 +333,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                     start_at=graph_runtime_state.start_at,
                     total_tokens=graph_runtime_state.total_tokens,
                     total_steps=graph_runtime_state.node_run_steps,
-                    outputs=json.dumps(event.outputs)
-                    if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs
-                    else None,
+                    outputs=event.outputs,
                     conversation_id=None,
                     trace_manager=trace_manager,
                 )

+ 5 - 7
api/core/app/apps/workflow_app_runner.py

@@ -20,7 +20,6 @@ from core.app.entities.queue_entities import (
     QueueWorkflowStartedEvent,
     QueueWorkflowSucceededEvent,
 )
-from core.workflow.entities.node_entities import NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine.entities.event import (
     GraphEngineEvent,
@@ -41,9 +40,9 @@ from core.workflow.graph_engine.entities.event import (
     ParallelBranchRunSucceededEvent,
 )
 from core.workflow.graph_engine.entities.graph import Graph
-from core.workflow.nodes.base_node import BaseNode
-from core.workflow.nodes.iteration.entities import IterationNodeData
-from core.workflow.nodes.node_mapping import node_classes
+from core.workflow.nodes import NodeType
+from core.workflow.nodes.iteration import IterationNodeData
+from core.workflow.nodes.node_mapping import node_type_classes_mapping
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from models.model import App
@@ -137,9 +136,8 @@ class WorkflowBasedAppRunner(AppRunner):
             raise ValueError("iteration node id not found in workflow graph")
 
         # Get node class
-        node_type = NodeType.value_of(iteration_node_config.get("data", {}).get("type"))
-        node_cls = node_classes.get(node_type)
-        node_cls = cast(type[BaseNode], node_cls)
+        node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
+        node_cls = node_type_classes_mapping[node_type]
 
         # init variable pool
         variable_pool = VariablePool(

+ 4 - 4
api/core/app/entities/app_invoke_entities.py

@@ -1,4 +1,4 @@
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from enum import Enum
 from typing import Any, Optional
 
@@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat
 from constants import UUID_NIL
 from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
 from core.entities.provider_configuration import ProviderModelBundle
-from core.file.file_obj import FileVar
+from core.file.models import File
 from core.model_runtime.entities.model_entities import AIModelEntity
 from core.ops.ops_trace_manager import TraceQueueManager
 
@@ -23,7 +23,7 @@ class InvokeFrom(Enum):
     DEBUGGER = "debugger"
 
     @classmethod
-    def value_of(cls, value: str) -> "InvokeFrom":
+    def value_of(cls, value: str):
         """
         Get value of given mode.
 
@@ -82,7 +82,7 @@ class AppGenerateEntity(BaseModel):
     app_config: AppConfig
 
     inputs: Mapping[str, Any]
-    files: list[FileVar] = []
+    files: Sequence[File]
     user_id: str
 
     # extras

+ 3 - 2
api/core/app/entities/queue_entities.py

@@ -5,9 +5,10 @@ from typing import Any, Optional
 from pydantic import BaseModel, field_validator
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
-from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
+from core.workflow.entities.node_entities import NodeRunMetadataKey
 from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.nodes import NodeType
+from core.workflow.nodes.base import BaseNodeData
 
 
 class QueueEvent(str, Enum):

+ 4 - 2
api/core/app/entities/task_entities.py

@@ -1,3 +1,4 @@
+from collections.abc import Mapping, Sequence
 from enum import Enum
 from typing import Any, Optional
 
@@ -119,6 +120,7 @@ class MessageEndStreamResponse(StreamResponse):
     event: StreamEvent = StreamEvent.MESSAGE_END
     id: str
     metadata: dict = {}
+    files: Optional[Sequence[Mapping[str, Any]]] = None
 
 
 class MessageFileStreamResponse(StreamResponse):
@@ -211,7 +213,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
         created_by: Optional[dict] = None
         created_at: int
         finished_at: int
-        files: Optional[list[dict]] = []
+        files: Optional[Sequence[Mapping[str, Any]]] = []
 
     event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
     workflow_run_id: str
@@ -296,7 +298,7 @@ class NodeFinishStreamResponse(StreamResponse):
         execution_metadata: Optional[dict] = None
         created_at: int
         finished_at: int
-        files: Optional[list[dict]] = []
+        files: Optional[Sequence[Mapping[str, Any]]] = []
         parallel_id: Optional[str] = None
         parallel_start_node_id: Optional[str] = None
         parent_parallel_id: Optional[str] = None

+ 0 - 18
api/core/app/segments/parser.py

@@ -1,18 +0,0 @@
-import re
-
-from core.workflow.entities.variable_pool import VariablePool
-
-from . import SegmentGroup, factory
-
-VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
-
-
-def convert_template(*, template: str, variable_pool: VariablePool):
-    parts = re.split(VARIABLE_PATTERN, template)
-    segments = []
-    for part in filter(lambda x: x, parts):
-        if "." in part and (value := variable_pool.get(part.split("."))):
-            segments.append(value)
-        else:
-            segments.append(factory.build_segment(part))
-    return SegmentGroup(value=segments)

+ 33 - 30
api/core/app/task_pipeline/workflow_cycle_manage.py

@@ -1,5 +1,6 @@
 import json
 import time
+from collections.abc import Mapping, Sequence
 from datetime import datetime, timezone
 from typing import Any, Optional, Union, cast
 
@@ -27,27 +28,26 @@ from core.app.entities.task_entities import (
     WorkflowStartStreamResponse,
     WorkflowTaskState,
 )
-from core.file.file_obj import FileVar
+from core.file import FILE_MODEL_IDENTITY, File
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.ops.entities.trace_entity import TraceTaskName
 from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.tools.tool_manager import ToolManager
-from core.workflow.entities.node_entities import NodeType
 from core.workflow.enums import SystemVariableKey
+from core.workflow.nodes import NodeType
 from core.workflow.nodes.tool.entities import ToolNodeData
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from models.account import Account
+from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
 from models.model import EndUser
 from models.workflow import (
-    CreatedByRole,
     Workflow,
     WorkflowNodeExecution,
     WorkflowNodeExecutionStatus,
     WorkflowNodeExecutionTriggeredFrom,
     WorkflowRun,
     WorkflowRunStatus,
-    WorkflowRunTriggeredFrom,
 )
 
 
@@ -117,7 +117,7 @@ class WorkflowCycleManage:
         start_at: float,
         total_tokens: int,
         total_steps: int,
-        outputs: Optional[str] = None,
+        outputs: Mapping[str, Any] | None = None,
         conversation_id: Optional[str] = None,
         trace_manager: Optional[TraceQueueManager] = None,
     ) -> WorkflowRun:
@@ -133,8 +133,10 @@ class WorkflowCycleManage:
         """
         workflow_run = self._refetch_workflow_run(workflow_run.id)
 
+        outputs = WorkflowEntry.handle_special_values(outputs)
+
         workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
-        workflow_run.outputs = outputs
+        workflow_run.outputs = json.dumps(outputs or {})
         workflow_run.elapsed_time = time.perf_counter() - start_at
         workflow_run.total_tokens = total_tokens
         workflow_run.total_steps = total_steps
@@ -265,6 +267,7 @@ class WorkflowCycleManage:
         workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
 
         inputs = WorkflowEntry.handle_special_values(event.inputs)
+        process_data = WorkflowEntry.handle_special_values(event.process_data)
         outputs = WorkflowEntry.handle_special_values(event.outputs)
         execution_metadata = (
             json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
@@ -276,7 +279,7 @@ class WorkflowCycleManage:
             {
                 WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value,
                 WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
-                WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
+                WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
                 WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
                 WorkflowNodeExecution.execution_metadata: execution_metadata,
                 WorkflowNodeExecution.finished_at: finished_at,
@@ -286,10 +289,11 @@ class WorkflowCycleManage:
 
         db.session.commit()
         db.session.close()
+        process_data = WorkflowEntry.handle_special_values(event.process_data)
 
         workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
         workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
-        workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
+        workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
         workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
         workflow_node_execution.execution_metadata = execution_metadata
         workflow_node_execution.finished_at = finished_at
@@ -308,6 +312,7 @@ class WorkflowCycleManage:
         workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
 
         inputs = WorkflowEntry.handle_special_values(event.inputs)
+        process_data = WorkflowEntry.handle_special_values(event.process_data)
         outputs = WorkflowEntry.handle_special_values(event.outputs)
         finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
         elapsed_time = (finished_at - event.start_at).total_seconds()
@@ -317,7 +322,7 @@ class WorkflowCycleManage:
                 WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
                 WorkflowNodeExecution.error: event.error,
                 WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
-                WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
+                WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
                 WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
                 WorkflowNodeExecution.finished_at: finished_at,
                 WorkflowNodeExecution.elapsed_time: elapsed_time,
@@ -326,11 +331,12 @@ class WorkflowCycleManage:
 
         db.session.commit()
         db.session.close()
+        process_data = WorkflowEntry.handle_special_values(event.process_data)
 
         workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
         workflow_node_execution.error = event.error
         workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
-        workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
+        workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
         workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
         workflow_node_execution.finished_at = finished_at
         workflow_node_execution.elapsed_time = elapsed_time
@@ -637,7 +643,7 @@ class WorkflowCycleManage:
             ),
         )
 
-    def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
+    def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]:
         """
         Fetch files from node outputs
         :param outputs_dict: node outputs dict
@@ -646,15 +652,15 @@ class WorkflowCycleManage:
         if not outputs_dict:
             return []
 
-        files = []
-        for output_var, output_value in outputs_dict.items():
-            file_vars = self._fetch_files_from_variable_value(output_value)
-            if file_vars:
-                files.extend(file_vars)
+        files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
+        # Remove None
+        files = [file for file in files if file]
+        # Flatten list
+        files = [file for sublist in files for file in sublist]
 
         return files
 
-    def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> list[dict]:
+    def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
         """
         Fetch files from variable value
         :param value: variable value
@@ -666,17 +672,17 @@ class WorkflowCycleManage:
         files = []
         if isinstance(value, list):
             for item in value:
-                file_var = self._get_file_var_from_value(item)
-                if file_var:
-                    files.append(file_var)
+                file = self._get_file_var_from_value(item)
+                if file:
+                    files.append(file)
         elif isinstance(value, dict):
-            file_var = self._get_file_var_from_value(value)
-            if file_var:
-                files.append(file_var)
+            file = self._get_file_var_from_value(value)
+            if file:
+                files.append(file)
 
         return files
 
-    def _get_file_var_from_value(self, value: Union[dict, list]) -> Optional[dict]:
+    def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None:
         """
         Get file var from value
         :param value: variable value
@@ -685,14 +691,11 @@ class WorkflowCycleManage:
         if not value:
             return None
 
-        if isinstance(value, dict):
-            if "__variant" in value and value["__variant"] == FileVar.__name__:
-                return value
-        elif isinstance(value, FileVar):
+        if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
+            return value
+        elif isinstance(value, File):
             return value.to_dict()
 
-        return None
-
     def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
         """
         Refetch workflow run

+ 0 - 29
api/core/entities/message_entities.py

@@ -1,29 +0,0 @@
-import enum
-from typing import Any
-
-from pydantic import BaseModel
-
-
-class PromptMessageFileType(enum.Enum):
-    IMAGE = "image"
-
-    @staticmethod
-    def value_of(value):
-        for member in PromptMessageFileType:
-            if member.value == value:
-                return member
-        raise ValueError(f"No matching enum found for value '{value}'")
-
-
-class PromptMessageFile(BaseModel):
-    type: PromptMessageFileType
-    data: Any = None
-
-
-class ImagePromptMessageFile(PromptMessageFile):
-    class DETAIL(enum.Enum):
-        LOW = "low"
-        HIGH = "high"
-
-    type: PromptMessageFileType = PromptMessageFileType.IMAGE
-    detail: DETAIL = DETAIL.LOW

+ 19 - 0
api/core/file/__init__.py

@@ -0,0 +1,19 @@
from .constants import FILE_MODEL_IDENTITY
from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
from .models import (
    File,
    FileExtraConfig,
    ImageConfig,
)

# Public API of the `core.file` package: import file primitives from here
# rather than from the individual submodules.
__all__ = [
    "FileType",
    "FileExtraConfig",
    "FileTransferMethod",
    "FileBelongsTo",
    "File",
    "ImageConfig",
    "FileAttribute",
    "ArrayFileAttribute",
    "FILE_MODEL_IDENTITY",
]

+ 1 - 0
api/core/file/constants.py

@@ -0,0 +1 @@
# Marker value stored under the "dify_model_identity" key of serialized File
# dicts, so File payloads can be recognized when scanning arbitrary mappings.
FILE_MODEL_IDENTITY = "__dify__file__"

+ 55 - 0
api/core/file/enums.py

@@ -0,0 +1,55 @@
+from enum import Enum
+
+
class FileType(str, Enum):
    """Category of a file handled by the platform."""

    IMAGE = "image"
    DOCUMENT = "document"
    AUDIO = "audio"
    VIDEO = "video"
    CUSTOM = "custom"

    @staticmethod
    def value_of(value):
        """Return the member whose value equals ``value``; raise ValueError otherwise."""
        member = next((candidate for candidate in FileType if candidate.value == value), None)
        if member is None:
            raise ValueError(f"No matching enum found for value '{value}'")
        return member
+
+
class FileTransferMethod(str, Enum):
    """How a file's content reaches the platform."""

    REMOTE_URL = "remote_url"
    LOCAL_FILE = "local_file"
    TOOL_FILE = "tool_file"

    @staticmethod
    def value_of(value):
        """Return the member whose value equals ``value``; raise ValueError otherwise."""
        try:
            return FileTransferMethod(value)
        except ValueError:
            raise ValueError(f"No matching enum found for value '{value}'") from None
+
+
class FileBelongsTo(str, Enum):
    """Which side of a conversation a message file is attached to."""

    USER = "user"
    ASSISTANT = "assistant"

    @staticmethod
    def value_of(value):
        """Return the member whose value equals ``value``; raise ValueError otherwise."""
        matches = [m for m in FileBelongsTo if m.value == value]
        if not matches:
            raise ValueError(f"No matching enum found for value '{value}'")
        return matches[0]
+
+
class FileAttribute(str, Enum):
    """Addressable attributes of a File, used for generic attribute lookup."""

    TYPE = "type"
    SIZE = "size"
    NAME = "name"
    MIME_TYPE = "mime_type"
    TRANSFER_METHOD = "transfer_method"
    URL = "url"
    EXTENSION = "extension"
+
+
class ArrayFileAttribute(str, Enum):
    """Addressable attributes of an array of files."""

    LENGTH = "length"

+ 156 - 0
api/core/file/file_manager.py

@@ -0,0 +1,156 @@
+import base64
+
+from configs import dify_config
+from core.file import file_repository
+from core.helper import ssrf_proxy
+from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent
+from extensions.ext_database import db
+from extensions.ext_storage import storage
+
+from . import helpers
+from .enums import FileAttribute
+from .models import File, FileTransferMethod, FileType
+from .tool_file_parser import ToolFileParser
+
+
def get_attr(*, file: File, attr: FileAttribute):
    """Return a single attribute of ``file`` selected by ``attr``.

    Enum-valued attributes (type, transfer method) are returned as their
    plain string values.

    Raises:
        ValueError: If ``attr`` is not a recognized FileAttribute.
    """
    if attr == FileAttribute.TYPE:
        return file.type.value
    if attr == FileAttribute.SIZE:
        return file.size
    if attr == FileAttribute.NAME:
        return file.filename
    if attr == FileAttribute.MIME_TYPE:
        return file.mime_type
    if attr == FileAttribute.TRANSFER_METHOD:
        return file.transfer_method.value
    if attr == FileAttribute.URL:
        return file.remote_url
    if attr == FileAttribute.EXTENSION:
        return file.extension
    raise ValueError(f"Invalid file attribute: {attr}")
+
+
def to_prompt_message_content(f: File, /):
    """
    Convert a File object to a prompt message content object.

    Image files become an ImagePromptMessageContent whose data is either a
    (signed) URL or a base64 data string, depending on the
    MULTIMODAL_SEND_IMAGE_FORMAT configuration. Audio files become an
    AudioPromptMessageContent with base64-encoded data.

    Args:
        f (File): The file to convert. Must be of type FileType.IMAGE or
            FileType.AUDIO.

    Returns:
        ImagePromptMessageContent | AudioPromptMessageContent: Prompt content
        carrying the file data.

    Raises:
        ValueError: If the file type is unsupported, or an audio file is
            missing its extension.

    Note:
        The image detail level is taken from the file's extra config when
        present; otherwise it defaults to ImagePromptMessageContent.DETAIL.LOW.
    """
    match f.type:
        case FileType.IMAGE:
            if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
                data = _to_url(f)
            else:
                data = _to_base64_data_string(f)

            if f._extra_config and f._extra_config.image_config and f._extra_config.image_config.detail:
                detail = f._extra_config.image_config.detail
            else:
                detail = ImagePromptMessageContent.DETAIL.LOW

            return ImagePromptMessageContent(data=data, detail=detail)
        case FileType.AUDIO:
            encoded_string = _file_to_encoded_string(f)
            if f.extension is None:
                raise ValueError("Missing file extension")
            return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
        case _:
            raise ValueError(f"file type {f.type} is not supported")
+
+
def download(f: File, /):
    """Return the raw bytes of ``f`` loaded from storage.

    Resolves the UploadFile record backing the file (raises ValueError when
    the record cannot be found) and downloads its content by storage key.
    """
    upload_file = file_repository.get_upload_file(session=db.session(), file=f)
    return _download_file_content(upload_file.key)
+
+
def _download_file_content(path: str, /):
    """
    Download and return the contents of a file as bytes.

    Loads the file at ``path`` from storage (non-streaming) and guarantees
    the result is a bytes object.

    Args:
        path (str): The path to the file in storage.

    Returns:
        bytes: The contents of the file.

    Raises:
        ValueError: If the loaded content is not a bytes object.
    """
    content = storage.load(path, stream=False)
    if isinstance(content, bytes):
        return content
    raise ValueError(f"file {path} is not a bytes object")
+
+
def _get_encoded_string(f: File, /):
    """Fetch the file's raw content and return it base64-encoded as a utf-8 str.

    Remote files are fetched through the SSRF-safe proxy; local uploads and
    tool files are loaded from storage via their repository records.
    """
    if f.transfer_method == FileTransferMethod.REMOTE_URL:
        response = ssrf_proxy.get(f.remote_url)
        response.raise_for_status()
        raw = response.content
    elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
        upload_file = file_repository.get_upload_file(session=db.session(), file=f)
        raw = _download_file_content(upload_file.key)
    elif f.transfer_method == FileTransferMethod.TOOL_FILE:
        tool_file = file_repository.get_tool_file(session=db.session(), file=f)
        raw = _download_file_content(tool_file.file_key)
    else:
        raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
    return base64.b64encode(raw).decode("utf-8")
+
+
def _to_base64_data_string(f: File, /):
    # Wrap the base64-encoded file content in a data URI carrying the MIME type.
    encoded_string = _get_encoded_string(f)
    return f"data:{f.mime_type};base64,{encoded_string}"
+
+
def _file_to_encoded_string(f: File, /):
    """Encode the file for model consumption.

    Images are returned as a base64 data URI; audio as a bare base64 string.
    Any other file type raises ValueError.
    """
    if f.type == FileType.IMAGE:
        return _to_base64_data_string(f)
    if f.type == FileType.AUDIO:
        return _get_encoded_string(f)
    raise ValueError(f"file type {f.type} is not supported")
+
+
+def _to_url(f: File, /):
+    if f.transfer_method == FileTransferMethod.REMOTE_URL:
+        if f.remote_url is None:
+            raise ValueError("Missing file remote_url")
+        return f.remote_url
+    elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
+        if f.related_id is None:
+            raise ValueError("Missing file related_id")
+        return helpers.get_signed_file_url(upload_file_id=f.related_id)
+    elif f.transfer_method == FileTransferMethod.TOOL_FILE:
+        # add sign url
+        if f.related_id is None or f.extension is None:
+            raise ValueError("Missing file related_id or extension")
+        return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension)
+    else:
+        raise ValueError(f"Unsupported transfer method: {f.transfer_method}")

+ 0 - 145
api/core/file/file_obj.py

@@ -1,145 +0,0 @@
-import enum
-from typing import Any, Optional
-
-from pydantic import BaseModel
-
-from core.file.tool_file_parser import ToolFileParser
-from core.file.upload_file_parser import UploadFileParser
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
-from extensions.ext_database import db
-
-
-class FileExtraConfig(BaseModel):
-    """
-    File Upload Entity.
-    """
-
-    image_config: Optional[dict[str, Any]] = None
-
-
-class FileType(enum.Enum):
-    IMAGE = "image"
-
-    @staticmethod
-    def value_of(value):
-        for member in FileType:
-            if member.value == value:
-                return member
-        raise ValueError(f"No matching enum found for value '{value}'")
-
-
-class FileTransferMethod(enum.Enum):
-    REMOTE_URL = "remote_url"
-    LOCAL_FILE = "local_file"
-    TOOL_FILE = "tool_file"
-
-    @staticmethod
-    def value_of(value):
-        for member in FileTransferMethod:
-            if member.value == value:
-                return member
-        raise ValueError(f"No matching enum found for value '{value}'")
-
-
-class FileBelongsTo(enum.Enum):
-    USER = "user"
-    ASSISTANT = "assistant"
-
-    @staticmethod
-    def value_of(value):
-        for member in FileBelongsTo:
-            if member.value == value:
-                return member
-        raise ValueError(f"No matching enum found for value '{value}'")
-
-
-class FileVar(BaseModel):
-    id: Optional[str] = None  # message file id
-    tenant_id: str
-    type: FileType
-    transfer_method: FileTransferMethod
-    url: Optional[str] = None  # remote url
-    related_id: Optional[str] = None
-    extra_config: Optional[FileExtraConfig] = None
-    filename: Optional[str] = None
-    extension: Optional[str] = None
-    mime_type: Optional[str] = None
-
-    def to_dict(self) -> dict:
-        return {
-            "__variant": self.__class__.__name__,
-            "tenant_id": self.tenant_id,
-            "type": self.type.value,
-            "transfer_method": self.transfer_method.value,
-            "url": self.preview_url,
-            "remote_url": self.url,
-            "related_id": self.related_id,
-            "filename": self.filename,
-            "extension": self.extension,
-            "mime_type": self.mime_type,
-        }
-
-    def to_markdown(self) -> str:
-        """
-        Convert file to markdown
-        :return:
-        """
-        preview_url = self.preview_url
-        if self.type == FileType.IMAGE:
-            text = f'![{self.filename or ""}]({preview_url})'
-        else:
-            text = f"[{self.filename or preview_url}]({preview_url})"
-
-        return text
-
-    @property
-    def data(self) -> Optional[str]:
-        """
-        Get image data, file signed url or base64 data
-        depending on config MULTIMODAL_SEND_IMAGE_FORMAT
-        :return:
-        """
-        return self._get_data()
-
-    @property
-    def preview_url(self) -> Optional[str]:
-        """
-        Get signed preview url
-        :return:
-        """
-        return self._get_data(force_url=True)
-
-    @property
-    def prompt_message_content(self) -> ImagePromptMessageContent:
-        if self.type == FileType.IMAGE:
-            image_config = self.extra_config.image_config
-
-            return ImagePromptMessageContent(
-                data=self.data,
-                detail=ImagePromptMessageContent.DETAIL.HIGH
-                if image_config.get("detail") == "high"
-                else ImagePromptMessageContent.DETAIL.LOW,
-            )
-
-    def _get_data(self, force_url: bool = False) -> Optional[str]:
-        from models.model import UploadFile
-
-        if self.type == FileType.IMAGE:
-            if self.transfer_method == FileTransferMethod.REMOTE_URL:
-                return self.url
-            elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
-                upload_file = (
-                    db.session.query(UploadFile)
-                    .filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id)
-                    .first()
-                )
-
-                return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url)
-            elif self.transfer_method == FileTransferMethod.TOOL_FILE:
-                extension = self.extension
-                # add sign url
-                return ToolFileParser.get_tool_file_manager().sign_file(
-                    tool_file_id=self.related_id, extension=extension
-                )
-
-        return None

+ 32 - 0
api/core/file/file_repository.py

@@ -0,0 +1,32 @@
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from models import ToolFile, UploadFile
+
+from .models import File
+
+
def get_upload_file(*, session: Session, file: File):
    """Load the UploadFile row backing ``file``.

    Raises:
        ValueError: If the file has no related_id, or no matching row exists
            for the file's tenant.
    """
    if file.related_id is None:
        raise ValueError("Missing file related_id")
    query = select(UploadFile).filter(
        UploadFile.id == file.related_id,
        UploadFile.tenant_id == file.tenant_id,
    )
    upload_file = session.scalar(query)
    if upload_file is None:
        raise ValueError(f"upload file {file.related_id} not found")
    return upload_file
+
+
def get_tool_file(*, session: Session, file: File):
    """Load the ToolFile row backing ``file``.

    Raises:
        ValueError: If the file has no related_id, or no matching row exists
            for the file's tenant.
    """
    if file.related_id is None:
        raise ValueError("Missing file related_id")
    query = select(ToolFile).filter(
        ToolFile.id == file.related_id,
        ToolFile.tenant_id == file.tenant_id,
    )
    tool_file = session.scalar(query)
    if tool_file is None:
        raise ValueError(f"tool file {file.related_id} not found")
    return tool_file

+ 48 - 0
api/core/file/helpers.py

@@ -0,0 +1,48 @@
+import base64
+import hashlib
+import hmac
+import os
+import time
+
+from configs import dify_config
+
+
def get_signed_file_url(upload_file_id: str) -> str:
    """Build a time-limited, HMAC-signed URL for previewing an uploaded file.

    The signature covers "file-preview|<id>|<timestamp>|<nonce>" using the
    application SECRET_KEY, and is appended as query parameters.
    """
    base_url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"

    timestamp = str(int(time.time()))
    nonce = os.urandom(16).hex()
    message = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
    digest = hmac.new(dify_config.SECRET_KEY.encode(), message.encode(), hashlib.sha256).digest()
    signature = base64.urlsafe_b64encode(digest).decode()

    return f"{base_url}?timestamp={timestamp}&nonce={nonce}&sign={signature}"
+
+
def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
    """Verify a signed image-preview URL.

    Recomputes the HMAC-SHA256 signature over
    "image-preview|<id>|<timestamp>|<nonce>" with the application SECRET_KEY
    and checks it against ``sign``, then checks the link has not expired.

    Returns:
        bool: True if the signature matches and the timestamp is within
        FILES_ACCESS_TIMEOUT seconds of now.
    """
    data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
    secret_key = dify_config.SECRET_KEY.encode()
    recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
    recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()

    # Constant-time comparison: a plain `!=` leaks how many leading bytes of
    # the signature match through response timing.
    if not hmac.compare_digest(sign, recalculated_encoded_sign):
        return False

    current_time = int(time.time())
    return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
+
+
def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
    """Verify a signed file-preview URL.

    Recomputes the HMAC-SHA256 signature over
    "file-preview|<id>|<timestamp>|<nonce>" with the application SECRET_KEY
    and checks it against ``sign``, then checks the link has not expired.

    Returns:
        bool: True if the signature matches and the timestamp is within
        FILES_ACCESS_TIMEOUT seconds of now.
    """
    data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
    secret_key = dify_config.SECRET_KEY.encode()
    recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
    recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()

    # Constant-time comparison: a plain `!=` leaks how many leading bytes of
    # the signature match through response timing.
    if not hmac.compare_digest(sign, recalculated_encoded_sign):
        return False

    current_time = int(time.time())
    return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT

+ 0 - 243
api/core/file/message_file_parser.py

@@ -1,243 +0,0 @@
-import re
-from collections.abc import Mapping, Sequence
-from typing import Any, Union
-from urllib.parse import parse_qs, urlparse
-
-import requests
-
-from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar
-from extensions.ext_database import db
-from models.account import Account
-from models.model import EndUser, MessageFile, UploadFile
-from services.file_service import IMAGE_EXTENSIONS
-
-
-class MessageFileParser:
-    def __init__(self, tenant_id: str, app_id: str) -> None:
-        self.tenant_id = tenant_id
-        self.app_id = app_id
-
-    def validate_and_transform_files_arg(
-        self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser]
-    ) -> list[FileVar]:
-        """
-        validate and transform files arg
-
-        :param files:
-        :param file_extra_config:
-        :param user:
-        :return:
-        """
-        for file in files:
-            if not isinstance(file, dict):
-                raise ValueError("Invalid file format, must be dict")
-            if not file.get("type"):
-                raise ValueError("Missing file type")
-            FileType.value_of(file.get("type"))
-            if not file.get("transfer_method"):
-                raise ValueError("Missing file transfer method")
-            FileTransferMethod.value_of(file.get("transfer_method"))
-            if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value:
-                if not file.get("url"):
-                    raise ValueError("Missing file url")
-                if not file.get("url").startswith("http"):
-                    raise ValueError("Invalid file url")
-            if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"):
-                raise ValueError("Missing file upload_file_id")
-            if file.get("transform_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"):
-                raise ValueError("Missing file tool_file_id")
-
-        # transform files to file objs
-        type_file_objs = self._to_file_objs(files, file_extra_config)
-
-        # validate files
-        new_files = []
-        for file_type, file_objs in type_file_objs.items():
-            if file_type == FileType.IMAGE:
-                # parse and validate files
-                image_config = file_extra_config.image_config
-
-                # check if image file feature is enabled
-                if not image_config:
-                    continue
-
-                # Validate number of files
-                if len(files) > image_config["number_limits"]:
-                    raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")
-
-                for file_obj in file_objs:
-                    # Validate transfer method
-                    if file_obj.transfer_method.value not in image_config["transfer_methods"]:
-                        raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}")
-
-                    # Validate file type
-                    if file_obj.type != FileType.IMAGE:
-                        raise ValueError(f"Invalid file type: {file_obj.type}")
-
-                    if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
-                        # check remote url valid and is image
-                        result, error = self._check_image_remote_url(file_obj.url)
-                        if result is False:
-                            raise ValueError(error)
-                    elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
-                        # get upload file from upload_file_id
-                        upload_file = (
-                            db.session.query(UploadFile)
-                            .filter(
-                                UploadFile.id == file_obj.related_id,
-                                UploadFile.tenant_id == self.tenant_id,
-                                UploadFile.created_by == user.id,
-                                UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
-                                UploadFile.extension.in_(IMAGE_EXTENSIONS),
-                            )
-                            .first()
-                        )
-
-                        # check upload file is belong to tenant and user
-                        if not upload_file:
-                            raise ValueError("Invalid upload file")
-
-                    new_files.append(file_obj)
-
-        # return all file objs
-        return new_files
-
-    def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig):
-        """
-        transform message files
-
-        :param files:
-        :param file_extra_config:
-        :return:
-        """
-        # transform files to file objs
-        type_file_objs = self._to_file_objs(files, file_extra_config)
-
-        # return all file objs
-        return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
-
-    def _to_file_objs(
-        self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig
-    ) -> dict[FileType, list[FileVar]]:
-        """
-        transform files to file objs
-
-        :param files:
-        :param file_extra_config:
-        :return:
-        """
-        type_file_objs: dict[FileType, list[FileVar]] = {
-            # Currently only support image
-            FileType.IMAGE: []
-        }
-
-        if not files:
-            return type_file_objs
-
-        # group by file type and convert file args or message files to FileObj
-        for file in files:
-            if isinstance(file, MessageFile):
-                if file.belongs_to == FileBelongsTo.ASSISTANT.value:
-                    continue
-
-            file_obj = self._to_file_obj(file, file_extra_config)
-            if file_obj.type not in type_file_objs:
-                continue
-
-            type_file_objs[file_obj.type].append(file_obj)
-
-        return type_file_objs
-
-    def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig):
-        """
-        transform file to file obj
-
-        :param file:
-        :return:
-        """
-        if isinstance(file, dict):
-            transfer_method = FileTransferMethod.value_of(file.get("transfer_method"))
-            if transfer_method != FileTransferMethod.TOOL_FILE:
-                return FileVar(
-                    tenant_id=self.tenant_id,
-                    type=FileType.value_of(file.get("type")),
-                    transfer_method=transfer_method,
-                    url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None,
-                    related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None,
-                    extra_config=file_extra_config,
-                )
-            return FileVar(
-                tenant_id=self.tenant_id,
-                type=FileType.value_of(file.get("type")),
-                transfer_method=transfer_method,
-                url=None,
-                related_id=file.get("tool_file_id"),
-                extra_config=file_extra_config,
-            )
-        else:
-            return FileVar(
-                id=file.id,
-                tenant_id=self.tenant_id,
-                type=FileType.value_of(file.type),
-                transfer_method=FileTransferMethod.value_of(file.transfer_method),
-                url=file.url,
-                related_id=file.upload_file_id or None,
-                extra_config=file_extra_config,
-            )
-
-    def _check_image_remote_url(self, url):
-        try:
-            headers = {
-                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
-                " Chrome/91.0.4472.124 Safari/537.36"
-            }
-
-            def is_s3_presigned_url(url):
-                try:
-                    parsed_url = urlparse(url)
-                    if "amazonaws.com" not in parsed_url.netloc:
-                        return False
-                    query_params = parse_qs(parsed_url.query)
-
-                    def check_presign_v2(query_params):
-                        required_params = ["Signature", "Expires"]
-                        for param in required_params:
-                            if param not in query_params:
-                                return False
-                        if not query_params["Expires"][0].isdigit():
-                            return False
-                        signature = query_params["Signature"][0]
-                        if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
-                            return False
-
-                        return True
-
-                    def check_presign_v4(query_params):
-                        required_params = ["X-Amz-Signature", "X-Amz-Expires"]
-                        for param in required_params:
-                            if param not in query_params:
-                                return False
-                        if not query_params["X-Amz-Expires"][0].isdigit():
-                            return False
-                        signature = query_params["X-Amz-Signature"][0]
-                        if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
-                            return False
-
-                        return True
-
-                    return check_presign_v4(query_params) or check_presign_v2(query_params)
-                except Exception:
-                    return False
-
-            if is_s3_presigned_url(url):
-                response = requests.get(url, headers=headers, allow_redirects=True)
-                if response.status_code in {200, 304}:
-                    return True, ""
-
-            response = requests.head(url, headers=headers, allow_redirects=True)
-            if response.status_code in {200, 304}:
-                return True, ""
-            else:
-                return False, "URL does not exist."
-        except requests.RequestException as e:
-            return False, f"Error checking URL: {e}"

+ 140 - 0
api/core/file/models.py

@@ -0,0 +1,140 @@
+from collections.abc import Mapping, Sequence
+from typing import Optional
+
+from pydantic import BaseModel, Field, model_validator
+
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+
+from . import helpers
+from .constants import FILE_MODEL_IDENTITY
+from .enums import FileTransferMethod, FileType
+from .tool_file_parser import ToolFileParser
+
+
+class ImageConfig(BaseModel):
+    """
+    NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
+    """
+
+    number_limits: int = 0
+    transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
+    detail: ImagePromptMessageContent.DETAIL | None = None
+
+
+class FileExtraConfig(BaseModel):
+    """
+    File Upload Entity.
+    """
+
+    image_config: Optional[ImageConfig] = None
+    allowed_file_types: Sequence[FileType] = Field(default_factory=list)
+    allowed_extensions: Sequence[str] = Field(default_factory=list)
+    allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
+    number_limits: int = 0
+
+
+class File(BaseModel):
+    dify_model_identity: str = FILE_MODEL_IDENTITY
+
+    id: Optional[str] = None  # message file id
+    tenant_id: str
+    type: FileType
+    transfer_method: FileTransferMethod
+    remote_url: Optional[str] = None  # remote url
+    related_id: Optional[str] = None
+    filename: Optional[str] = None
+    extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
+    mime_type: Optional[str] = None
+    size: int = -1
+    _extra_config: FileExtraConfig | None = None
+
+    def to_dict(self) -> Mapping[str, str | int | None]:
+        data = self.model_dump(mode="json")
+        return {
+            **data,
+            "url": self.generate_url(),
+        }
+
+    @property
+    def markdown(self) -> str:
+        url = self.generate_url()
+        if self.type == FileType.IMAGE:
+            text = f'![{self.filename or ""}]({url})'
+        else:
+            text = f"[{self.filename or url}]({url})"
+
+        return text
+
+    def generate_url(self) -> Optional[str]:
+        if self.type == FileType.IMAGE:
+            if self.transfer_method == FileTransferMethod.REMOTE_URL:
+                return self.remote_url
+            elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
+                if self.related_id is None:
+                    raise ValueError("Missing file related_id")
+                return helpers.get_signed_file_url(upload_file_id=self.related_id)
+            elif self.transfer_method == FileTransferMethod.TOOL_FILE:
+                assert self.related_id is not None
+                assert self.extension is not None
+                return ToolFileParser.get_tool_file_manager().sign_file(
+                    tool_file_id=self.related_id, extension=self.extension
+                )
+        else:
+            if self.transfer_method == FileTransferMethod.REMOTE_URL:
+                return self.remote_url
+            elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
+                if self.related_id is None:
+                    raise ValueError("Missing file related_id")
+                return helpers.get_signed_file_url(upload_file_id=self.related_id)
+            elif self.transfer_method == FileTransferMethod.TOOL_FILE:
+                assert self.related_id is not None
+                assert self.extension is not None
+                return ToolFileParser.get_tool_file_manager().sign_file(
+                    tool_file_id=self.related_id, extension=self.extension
+                )
+
+    @model_validator(mode="after")
+    def validate_after(self):
+        match self.transfer_method:
+            case FileTransferMethod.REMOTE_URL:
+                if not self.remote_url:
+                    raise ValueError("Missing file url")
+                if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"):
+                    raise ValueError("Invalid file url")
+            case FileTransferMethod.LOCAL_FILE:
+                if not self.related_id:
+                    raise ValueError("Missing file related_id")
+            case FileTransferMethod.TOOL_FILE:
+                if not self.related_id:
+                    raise ValueError("Missing file related_id")
+
+        # Validate the extra config.
+        if not self._extra_config:
+            return self
+
+        if self._extra_config.allowed_file_types:
+            if self.type not in self._extra_config.allowed_file_types and self.type != FileType.CUSTOM:
+                raise ValueError(f"Invalid file type: {self.type}")
+
+        if self._extra_config.allowed_extensions and self.extension not in self._extra_config.allowed_extensions:
+            raise ValueError(f"Invalid file extension: {self.extension}")
+
+        if (
+            self._extra_config.allowed_upload_methods
+            and self.transfer_method not in self._extra_config.allowed_upload_methods
+        ):
+            raise ValueError(f"Invalid transfer method: {self.transfer_method}")
+
+        match self.type:
+            case FileType.IMAGE:
+                # NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
+                if not self._extra_config.image_config:
+                    return self
+                # TODO: skip check if transfer_methods is empty, because many test cases are not setting this field
+                if (
+                    self._extra_config.image_config.transfer_methods
+                    and self.transfer_method not in self._extra_config.image_config.transfer_methods
+                ):
+                    raise ValueError(f"Invalid transfer method: {self.transfer_method}")
+
+        return self

+ 6 - 1
api/core/file/tool_file_parser.py

@@ -1,4 +1,9 @@
-tool_file_manager = {"manager": None}
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from core.tools.tool_file_manager import ToolFileManager
+
+tool_file_manager: dict[str, Any] = {"manager": None}
 
 
 class ToolFileParser:

+ 0 - 79
api/core/file/upload_file_parser.py

@@ -1,79 +0,0 @@
-import base64
-import hashlib
-import hmac
-import logging
-import os
-import time
-from typing import Optional
-
-from configs import dify_config
-from extensions.ext_storage import storage
-
-IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
-IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
-
-
-class UploadFileParser:
-    @classmethod
-    def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
-        if not upload_file:
-            return None
-
-        if upload_file.extension not in IMAGE_EXTENSIONS:
-            return None
-
-        if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url:
-            return cls.get_signed_temp_image_url(upload_file.id)
-        else:
-            # get image file base64
-            try:
-                data = storage.load(upload_file.key)
-            except FileNotFoundError:
-                logging.error(f"File not found: {upload_file.key}")
-                return None
-
-            encoded_string = base64.b64encode(data).decode("utf-8")
-            return f"data:{upload_file.mime_type};base64,{encoded_string}"
-
-    @classmethod
-    def get_signed_temp_image_url(cls, upload_file_id) -> str:
-        """
-        get signed url from upload file
-
-        :param upload_file: UploadFile object
-        :return:
-        """
-        base_url = dify_config.FILES_URL
-        image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
-
-        timestamp = str(int(time.time()))
-        nonce = os.urandom(16).hex()
-        data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
-        secret_key = dify_config.SECRET_KEY.encode()
-        sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
-        encoded_sign = base64.urlsafe_b64encode(sign).decode()
-
-        return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
-
-    @classmethod
-    def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
-        """
-        verify signature
-
-        :param upload_file_id: file id
-        :param timestamp: timestamp
-        :param nonce: nonce
-        :param sign: signature
-        :return:
-        """
-        data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
-        secret_key = dify_config.SECRET_KEY.encode()
-        recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
-        recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
-
-        # verify signature
-        if sign != recalculated_encoded_sign:
-            return False
-
-        current_time = int(time.time())
-        return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT

+ 12 - 6
api/core/helper/ssrf_proxy.py

@@ -13,8 +13,11 @@ SSRF_PROXY_HTTP_URL = os.getenv("SSRF_PROXY_HTTP_URL", "")
 SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "")
 SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3"))
 
-proxies = (
-    {"http://": SSRF_PROXY_HTTP_URL, "https://": SSRF_PROXY_HTTPS_URL}
+proxy_mounts = (
+    {
+        "http://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTP_URL),
+        "https://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTPS_URL),
+    }
     if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL
     else None
 )
@@ -33,11 +36,14 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
     while retries <= max_retries:
         try:
             if SSRF_PROXY_ALL_URL:
-                response = httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs)
-            elif proxies:
-                response = httpx.request(method=method, url=url, proxies=proxies, **kwargs)
+                with httpx.Client(proxy=SSRF_PROXY_ALL_URL) as client:
+                    response = client.request(method=method, url=url, **kwargs)
+            elif proxy_mounts:
+                with httpx.Client(mounts=proxy_mounts) as client:
+                    response = client.request(method=method, url=url, **kwargs)
             else:
-                response = httpx.request(method=method, url=url, **kwargs)
+                with httpx.Client() as client:
+                    response = client.request(method=method, url=url, **kwargs)
 
             if response.status_code not in STATUS_FORCELIST:
                 return response

+ 13 - 8
api/core/memory/token_buffer_memory.py

@@ -1,18 +1,20 @@
 from typing import Optional
 
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
-from core.file.message_file_parser import MessageFileParser
+from core.file import file_manager
 from core.model_manager import ModelInstance
-from core.model_runtime.entities.message_entities import (
+from core.model_runtime.entities import (
     AssistantPromptMessage,
     ImagePromptMessageContent,
     PromptMessage,
+    PromptMessageContent,
     PromptMessageRole,
     TextPromptMessageContent,
     UserPromptMessage,
 )
 from core.prompt.utils.extract_thread_messages import extract_thread_messages
 from extensions.ext_database import db
+from factories import file_factory
 from models.model import AppMode, Conversation, Message, MessageFile
 from models.workflow import WorkflowRun
 
@@ -65,13 +67,12 @@ class TokenBufferMemory:
 
         messages = list(reversed(thread_messages))
 
-        message_file_parser = MessageFileParser(tenant_id=app_record.tenant_id, app_id=app_record.id)
         prompt_messages = []
         for message in messages:
             files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
             if files:
                 file_extra_config = None
-                if self.conversation.mode not in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
+                if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
                     file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
                 else:
                     if message.workflow_run_id:
@@ -84,17 +85,21 @@ class TokenBufferMemory:
                                 workflow_run.workflow.features_dict, is_vision=False
                             )
 
-                if file_extra_config:
-                    file_objs = message_file_parser.transform_message_files(files, file_extra_config)
+                if file_extra_config and app_record:
+                    file_objs = file_factory.build_from_message_files(
+                        message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
+                    )
                 else:
                     file_objs = []
 
                 if not file_objs:
                     prompt_messages.append(UserPromptMessage(content=message.query))
                 else:
-                    prompt_message_contents = [TextPromptMessageContent(data=message.query)]
+                    prompt_message_contents: list[PromptMessageContent] = []
+                    prompt_message_contents.append(TextPromptMessageContent(data=message.query))
                     for file_obj in file_objs:
-                        prompt_message_contents.append(file_obj.prompt_message_content)
+                        prompt_message = file_manager.to_prompt_message_content(file_obj)
+                        prompt_message_contents.append(prompt_message)
 
                     prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
             else:

+ 4 - 4
api/core/model_manager.py

@@ -1,7 +1,7 @@
 import logging
 import os
-from collections.abc import Callable, Generator, Sequence
-from typing import IO, Optional, Union, cast
+from collections.abc import Callable, Generator, Iterable, Sequence
+from typing import IO, Any, Optional, Union, cast
 
 from core.entities.embedding_type import EmbeddingInputType
 from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
@@ -274,7 +274,7 @@ class ModelInstance:
             user=user,
         )
 
-    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> str:
+    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]:
         """
         Invoke large language tts model
 
@@ -298,7 +298,7 @@ class ModelInstance:
             voice=voice,
         )
 
-    def _round_robin_invoke(self, function: Callable, *args, **kwargs):
+    def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
         """
         Round-robin invoke
         :param function: function to invoke

+ 38 - 0
api/core/model_runtime/entities/__init__.py

@@ -0,0 +1,38 @@
+from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from .message_entities import (
+    AssistantPromptMessage,
+    AudioPromptMessageContent,
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageContent,
+    PromptMessageContentType,
+    PromptMessageRole,
+    PromptMessageTool,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from .model_entities import ModelPropertyKey
+
+__all__ = [
+    "ImagePromptMessageContent",
+    "PromptMessage",
+    "PromptMessageRole",
+    "LLMUsage",
+    "ModelPropertyKey",
+    "AssistantPromptMessage",
+    "PromptMessage",
+    "PromptMessageContent",
+    "PromptMessageRole",
+    "SystemPromptMessage",
+    "TextPromptMessageContent",
+    "UserPromptMessage",
+    "PromptMessageTool",
+    "ToolPromptMessage",
+    "PromptMessageContentType",
+    "LLMResult",
+    "LLMResultChunk",
+    "LLMResultChunkDelta",
+    "AudioPromptMessageContent",
+]

+ 9 - 2
api/core/model_runtime/entities/message_entities.py

@@ -2,7 +2,7 @@ from abc import ABC
 from enum import Enum
 from typing import Optional
 
-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel, Field, field_validator
 
 
 class PromptMessageRole(Enum):
@@ -55,6 +55,7 @@ class PromptMessageContentType(Enum):
 
     TEXT = "text"
     IMAGE = "image"
+    AUDIO = "audio"
 
 
 class PromptMessageContent(BaseModel):
@@ -74,12 +75,18 @@ class TextPromptMessageContent(PromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.TEXT
 
 
+class AudioPromptMessageContent(PromptMessageContent):
+    type: PromptMessageContentType = PromptMessageContentType.AUDIO
+    data: str = Field(..., description="Base64 encoded audio data")
+    format: str = Field(..., description="Audio format")
+
+
 class ImagePromptMessageContent(PromptMessageContent):
     """
     Model class for image prompt message content.
     """
 
-    class DETAIL(Enum):
+    class DETAIL(str, Enum):
         LOW = "low"
         HIGH = "high"
 

+ 12 - 3
api/core/model_runtime/model_providers/__base/large_language_model.py

@@ -1,5 +1,4 @@
 import logging
-import os
 import re
 import time
 from abc import abstractmethod
@@ -8,6 +7,7 @@ from typing import Optional, Union
 
 from pydantic import ConfigDict
 
+from configs import dify_config
 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.callbacks.logging_callback import LoggingCallback
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -77,7 +77,7 @@ class LargeLanguageModel(AIModel):
 
         callbacks = callbacks or []
 
-        if bool(os.environ.get("DEBUG", "False").lower() == "true"):
+        if dify_config.DEBUG:
             callbacks.append(LoggingCallback())
 
         # trigger before invoke callbacks
@@ -107,7 +107,16 @@ class LargeLanguageModel(AIModel):
                     callbacks=callbacks,
                 )
             else:
-                result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+                result = self._invoke(
+                    model=model,
+                    credentials=credentials,
+                    prompt_messages=prompt_messages,
+                    model_parameters=model_parameters,
+                    tools=tools,
+                    stop=stop,
+                    stream=stream,
+                    user=user,
+                )
         except Exception as e:
             self._trigger_invoke_error_callbacks(
                 model=model,

+ 44 - 25
api/core/model_runtime/model_providers/__base/tts_model.py

@@ -1,6 +1,7 @@
 import logging
 import re
 from abc import abstractmethod
+from collections.abc import Iterable
 from typing import Any, Optional
 
 from pydantic import ConfigDict
@@ -22,8 +23,14 @@ class TTSModel(AIModel):
     model_config = ConfigDict(protected_namespaces=())
 
     def invoke(
-        self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
-    ):
+        self,
+        model: str,
+        tenant_id: str,
+        credentials: dict,
+        content_text: str,
+        voice: str,
+        user: Optional[str] = None,
+    ) -> Iterable[bytes]:
         """
         Invoke large language model
 
@@ -50,8 +57,14 @@ class TTSModel(AIModel):
 
     @abstractmethod
     def _invoke(
-        self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
-    ):
+        self,
+        model: str,
+        tenant_id: str,
+        credentials: dict,
+        content_text: str,
+        voice: str,
+        user: Optional[str] = None,
+    ) -> Iterable[bytes]:
         """
         Invoke large language model
 
@@ -68,25 +81,25 @@ class TTSModel(AIModel):
 
     def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
         """
-        Get voice for given tts model voices
+        Retrieves the list of voices supported by a given text-to-speech (TTS) model.
 
-        :param language: tts language
-        :param model: model name
-        :param credentials: model credentials
-        :return: voices lists
+        :param language: The language for which the voices are requested.
+        :param model: The name of the TTS model.
+        :param credentials: The credentials required to access the TTS model.
+        :return: A list of voices supported by the TTS model.
         """
         model_schema = self.get_model_schema(model, credentials)
 
-        if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties:
-            voices = model_schema.model_properties[ModelPropertyKey.VOICES]
-            if language:
-                return [
-                    {"name": d["name"], "value": d["mode"]}
-                    for d in voices
-                    if language and language in d.get("language")
-                ]
-            else:
-                return [{"name": d["name"], "value": d["mode"]} for d in voices]
+        if not model_schema or ModelPropertyKey.VOICES not in model_schema.model_properties:
+            raise ValueError("this model does not support voice")
+
+        voices = model_schema.model_properties[ModelPropertyKey.VOICES]
+        if language:
+            return [
+                {"name": d["name"], "value": d["mode"]} for d in voices if language and language in d.get("language")
+            ]
+        else:
+            return [{"name": d["name"], "value": d["mode"]} for d in voices]
 
     def _get_model_default_voice(self, model: str, credentials: dict) -> Any:
         """
@@ -111,8 +124,10 @@ class TTSModel(AIModel):
         """
         model_schema = self.get_model_schema(model, credentials)
 
-        if model_schema and ModelPropertyKey.AUDIO_TYPE in model_schema.model_properties:
-            return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
+        if not model_schema or ModelPropertyKey.AUDIO_TYPE not in model_schema.model_properties:
+            raise ValueError("this model does not support audio type")
+
+        return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
 
     def _get_model_word_limit(self, model: str, credentials: dict) -> int:
         """
@@ -121,8 +136,10 @@ class TTSModel(AIModel):
         """
         model_schema = self.get_model_schema(model, credentials)
 
-        if model_schema and ModelPropertyKey.WORD_LIMIT in model_schema.model_properties:
-            return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
+        if not model_schema or ModelPropertyKey.WORD_LIMIT not in model_schema.model_properties:
+            raise ValueError("this model does not support word limit")
+
+        return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
 
     def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
         """
@@ -131,8 +148,10 @@ class TTSModel(AIModel):
         """
         model_schema = self.get_model_schema(model, credentials)
 
-        if model_schema and ModelPropertyKey.MAX_WORKERS in model_schema.model_properties:
-            return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
+        if not model_schema or ModelPropertyKey.MAX_WORKERS not in model_schema.model_properties:
+            raise ValueError("this model does not support max workers")
+
+        return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
 
     @staticmethod
     def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"):

+ 1 - 0
api/core/model_runtime/model_providers/openai/llm/_position.yaml

@@ -1,3 +1,4 @@
+- gpt-4o-audio-preview
 - gpt-4
 - gpt-4o
 - gpt-4o-2024-05-13

+ 44 - 0
api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml

@@ -0,0 +1,44 @@
+model: gpt-4o-audio-preview
+label:
+  zh_Hans: gpt-4o-audio-preview
+  en_US: gpt-4o-audio-preview
+model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
+  - vision
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 512
+    min: 1
+    max: 4096
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: Response Format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
+pricing:
+  input: '5.00'
+  output: '15.00'
+  unit: '0.000001'
+  currency: USD

+ 19 - 10
api/core/model_runtime/model_providers/openai/llm/llm.py

@@ -1,7 +1,7 @@
 import json
 import logging
 from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import Any, Optional, Union, cast
 
 import tiktoken
 from openai import OpenAI, Stream
@@ -11,9 +11,9 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, Cho
 from openai.types.chat.chat_completion_message import FunctionCall
 
 from core.model_runtime.callbacks.base_callback import Callback
-from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
-from core.model_runtime.entities.message_entities import (
+from core.model_runtime.entities import (
     AssistantPromptMessage,
+    AudioPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
@@ -23,6 +23,7 @@ from core.model_runtime.entities.message_entities import (
     ToolPromptMessage,
     UserPromptMessage,
 )
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType, PriceConfig
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -613,6 +614,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         # clear illegal prompt messages
         prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
 
+        # o1 compatibility
         block_as_stream = False
         if model.startswith("o1"):
             if stream:
@@ -626,8 +628,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                 del extra_model_kwargs["stop"]
 
         # chat model
+        messages: Any = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
         response = client.chat.completions.create(
-            messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
+            messages=messages,
             model=model,
             stream=stream,
             **model_parameters,
@@ -946,23 +949,29 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         Convert PromptMessage to dict for OpenAI API
         """
         if isinstance(message, UserPromptMessage):
-            message = cast(UserPromptMessage, message)
             if isinstance(message.content, str):
                 message_dict = {"role": "user", "content": message.content}
-            else:
+            elif isinstance(message.content, list):
                 sub_messages = []
                 for message_content in message.content:
-                    if message_content.type == PromptMessageContentType.TEXT:
-                        message_content = cast(TextPromptMessageContent, message_content)
+                    if isinstance(message_content, TextPromptMessageContent):
                         sub_message_dict = {"type": "text", "text": message_content.data}
                         sub_messages.append(sub_message_dict)
-                    elif message_content.type == PromptMessageContentType.IMAGE:
-                        message_content = cast(ImagePromptMessageContent, message_content)
+                    elif isinstance(message_content, ImagePromptMessageContent):
                         sub_message_dict = {
                             "type": "image_url",
                             "image_url": {"url": message_content.data, "detail": message_content.detail.value},
                         }
                         sub_messages.append(sub_message_dict)
+                    elif isinstance(message_content, AudioPromptMessageContent):
+                        sub_message_dict = {
+                            "type": "input_audio",
+                            "input_audio": {
+                                "data": message_content.data,
+                                "format": message_content.format,
+                            },
+                        }
+                        sub_messages.append(sub_message_dict)
 
                 message_dict = {"role": "user", "content": sub_messages}
         elif isinstance(message, AssistantPromptMessage):

+ 2 - 2
api/core/ops/ops_trace_manager.py

@@ -358,8 +358,8 @@ class TraceTask:
         workflow_run_id = workflow_run.id
         workflow_run_elapsed_time = workflow_run.elapsed_time
         workflow_run_status = workflow_run.status
-        workflow_run_inputs = json.loads(workflow_run.inputs) if workflow_run.inputs else {}
-        workflow_run_outputs = json.loads(workflow_run.outputs) if workflow_run.outputs else {}
+        workflow_run_inputs = workflow_run.inputs_dict
+        workflow_run_outputs = workflow_run.outputs_dict
         workflow_run_version = workflow_run.version
         error = workflow_run.error or ""
 

+ 62 - 59
api/core/prompt/advanced_prompt_transform.py

@@ -1,12 +1,15 @@
-from typing import Optional, Union
+from collections.abc import Sequence
+from typing import Optional
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.file.file_obj import FileVar
+from core.file import file_manager
+from core.file.models import File
 from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
 from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_runtime.entities.message_entities import (
+from core.model_runtime.entities import (
     AssistantPromptMessage,
     PromptMessage,
+    PromptMessageContent,
     PromptMessageRole,
     SystemPromptMessage,
     TextPromptMessageContent,
@@ -14,8 +17,8 @@ from core.model_runtime.entities.message_entities import (
 )
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.prompt_transform import PromptTransform
-from core.prompt.simple_prompt_transform import ModelMode
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
+from core.workflow.entities.variable_pool import VariablePool
 
 
 class AdvancedPromptTransform(PromptTransform):
@@ -28,22 +31,19 @@ class AdvancedPromptTransform(PromptTransform):
 
     def get_prompt(
         self,
-        prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate],
-        inputs: dict,
+        *,
+        prompt_template: Sequence[ChatModelMessage] | CompletionModelPromptTemplate,
+        inputs: dict[str, str],
         query: str,
-        files: list[FileVar],
+        files: Sequence[File],
         context: Optional[str],
         memory_config: Optional[MemoryConfig],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
-        query_prompt_template: Optional[str] = None,
     ) -> list[PromptMessage]:
-        inputs = {key: str(value) for key, value in inputs.items()}
-
         prompt_messages = []
 
-        model_mode = ModelMode.value_of(model_config.mode)
-        if model_mode == ModelMode.COMPLETION:
+        if isinstance(prompt_template, CompletionModelPromptTemplate):
             prompt_messages = self._get_completion_model_prompt_messages(
                 prompt_template=prompt_template,
                 inputs=inputs,
@@ -54,12 +54,11 @@ class AdvancedPromptTransform(PromptTransform):
                 memory=memory,
                 model_config=model_config,
             )
-        elif model_mode == ModelMode.CHAT:
+        elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template):
             prompt_messages = self._get_chat_model_prompt_messages(
                 prompt_template=prompt_template,
                 inputs=inputs,
                 query=query,
-                query_prompt_template=query_prompt_template,
                 files=files,
                 context=context,
                 memory_config=memory_config,
@@ -74,7 +73,7 @@ class AdvancedPromptTransform(PromptTransform):
         prompt_template: CompletionModelPromptTemplate,
         inputs: dict,
         query: Optional[str],
-        files: list[FileVar],
+        files: Sequence[File],
         context: Optional[str],
         memory_config: Optional[MemoryConfig],
         memory: Optional[TokenBufferMemory],
@@ -88,10 +87,10 @@ class AdvancedPromptTransform(PromptTransform):
         prompt_messages = []
 
         if prompt_template.edition_type == "basic" or not prompt_template.edition_type:
-            prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
-            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+            parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+            prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
 
-            prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+            prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)
 
             if memory and memory_config:
                 role_prefix = memory_config.role_prefix
@@ -100,15 +99,15 @@ class AdvancedPromptTransform(PromptTransform):
                     memory_config=memory_config,
                     raw_prompt=raw_prompt,
                     role_prefix=role_prefix,
-                    prompt_template=prompt_template,
+                    parser=parser,
                     prompt_inputs=prompt_inputs,
                     model_config=model_config,
                 )
 
             if query:
-                prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
+                prompt_inputs = self._set_query_variable(query, parser, prompt_inputs)
 
-            prompt = prompt_template.format(prompt_inputs)
+            prompt = parser.format(prompt_inputs)
         else:
             prompt = raw_prompt
             prompt_inputs = inputs
@@ -116,9 +115,10 @@ class AdvancedPromptTransform(PromptTransform):
             prompt = Jinja2Formatter.format(prompt, prompt_inputs)
 
         if files:
-            prompt_message_contents = [TextPromptMessageContent(data=prompt)]
+            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents.append(TextPromptMessageContent(data=prompt))
             for file in files:
-                prompt_message_contents.append(file.prompt_message_content)
+                prompt_message_contents.append(file_manager.to_prompt_message_content(file))
 
             prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
         else:
@@ -131,35 +131,38 @@ class AdvancedPromptTransform(PromptTransform):
         prompt_template: list[ChatModelMessage],
         inputs: dict,
         query: Optional[str],
-        files: list[FileVar],
+        files: Sequence[File],
         context: Optional[str],
         memory_config: Optional[MemoryConfig],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
-        query_prompt_template: Optional[str] = None,
     ) -> list[PromptMessage]:
         """
         Get chat model prompt messages.
         """
-        raw_prompt_list = prompt_template
-
         prompt_messages = []
-
-        for prompt_item in raw_prompt_list:
+        for prompt_item in prompt_template:
             raw_prompt = prompt_item.text
 
             if prompt_item.edition_type == "basic" or not prompt_item.edition_type:
-                prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
-                prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
-
-                prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
-
-                prompt = prompt_template.format(prompt_inputs)
+                if self.with_variable_tmpl:
+                    vp = VariablePool()
+                    for k, v in inputs.items():
+                        if k.startswith("#"):
+                            vp.add(k[1:-1].split("."), v)
+                    raw_prompt = raw_prompt.replace("{{#context#}}", context or "")
+                    prompt = vp.convert_template(raw_prompt).text
+                else:
+                    parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+                    prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
+                    prompt_inputs = self._set_context_variable(
+                        context=context, parser=parser, prompt_inputs=prompt_inputs
+                    )
+                    prompt = parser.format(prompt_inputs)
             elif prompt_item.edition_type == "jinja2":
                 prompt = raw_prompt
                 prompt_inputs = inputs
-
-                prompt = Jinja2Formatter.format(prompt, prompt_inputs)
+                prompt = Jinja2Formatter.format(template=prompt, inputs=prompt_inputs)
             else:
                 raise ValueError(f"Invalid edition type: {prompt_item.edition_type}")
 
@@ -170,25 +173,25 @@ class AdvancedPromptTransform(PromptTransform):
             elif prompt_item.role == PromptMessageRole.ASSISTANT:
                 prompt_messages.append(AssistantPromptMessage(content=prompt))
 
-        if query and query_prompt_template:
-            prompt_template = PromptTemplateParser(
-                template=query_prompt_template, with_variable_tmpl=self.with_variable_tmpl
+        if query and memory_config and memory_config.query_prompt_template:
+            parser = PromptTemplateParser(
+                template=memory_config.query_prompt_template, with_variable_tmpl=self.with_variable_tmpl
             )
-            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+            prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
             prompt_inputs["#sys.query#"] = query
 
-            prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+            prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)
 
-            query = prompt_template.format(prompt_inputs)
+            query = parser.format(prompt_inputs)
 
         if memory and memory_config:
             prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
 
-            if files:
-                prompt_message_contents = [TextPromptMessageContent(data=query)]
+            if files and query is not None:
+                prompt_message_contents: list[PromptMessageContent] = []
+                prompt_message_contents.append(TextPromptMessageContent(data=query))
                 for file in files:
-                    prompt_message_contents.append(file.prompt_message_content)
-
+                    prompt_message_contents.append(file_manager.to_prompt_message_content(file))
                 prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
             else:
                 prompt_messages.append(UserPromptMessage(content=query))
@@ -200,19 +203,19 @@ class AdvancedPromptTransform(PromptTransform):
                     # get last user message content and add files
                     prompt_message_contents = [TextPromptMessageContent(data=last_message.content)]
                     for file in files:
-                        prompt_message_contents.append(file.prompt_message_content)
+                        prompt_message_contents.append(file_manager.to_prompt_message_content(file))
 
                     last_message.content = prompt_message_contents
                 else:
                     prompt_message_contents = [TextPromptMessageContent(data="")]  # not for query
                     for file in files:
-                        prompt_message_contents.append(file.prompt_message_content)
+                        prompt_message_contents.append(file_manager.to_prompt_message_content(file))
 
                     prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
             else:
                 prompt_message_contents = [TextPromptMessageContent(data=query)]
                 for file in files:
-                    prompt_message_contents.append(file.prompt_message_content)
+                    prompt_message_contents.append(file_manager.to_prompt_message_content(file))
 
                 prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
         elif query:
@@ -220,8 +223,8 @@ class AdvancedPromptTransform(PromptTransform):
 
         return prompt_messages
 
-    def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict:
-        if "#context#" in prompt_template.variable_keys:
+    def _set_context_variable(self, context: str | None, parser: PromptTemplateParser, prompt_inputs: dict) -> dict:
+        if "#context#" in parser.variable_keys:
             if context:
                 prompt_inputs["#context#"] = context
             else:
@@ -229,8 +232,8 @@ class AdvancedPromptTransform(PromptTransform):
 
         return prompt_inputs
 
-    def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict:
-        if "#query#" in prompt_template.variable_keys:
+    def _set_query_variable(self, query: str, parser: PromptTemplateParser, prompt_inputs: dict) -> dict:
+        if "#query#" in parser.variable_keys:
             if query:
                 prompt_inputs["#query#"] = query
             else:
@@ -244,16 +247,16 @@ class AdvancedPromptTransform(PromptTransform):
         memory_config: MemoryConfig,
         raw_prompt: str,
         role_prefix: MemoryConfig.RolePrefix,
-        prompt_template: PromptTemplateParser,
+        parser: PromptTemplateParser,
         prompt_inputs: dict,
         model_config: ModelConfigWithCredentialsEntity,
     ) -> dict:
-        if "#histories#" in prompt_template.variable_keys:
+        if "#histories#" in parser.variable_keys:
             if memory:
                 inputs = {"#histories#": "", **prompt_inputs}
-                prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
-                prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
-                tmp_human_message = UserPromptMessage(content=prompt_template.format(prompt_inputs))
+                parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+                prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
+                tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs))
 
                 rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
 

+ 11 - 8
api/core/prompt/simple_prompt_transform.py

@@ -5,9 +5,11 @@ from typing import TYPE_CHECKING, Optional
 
 from core.app.app_config.entities import PromptTemplateEntity
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.file import file_manager
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import (
     PromptMessage,
+    PromptMessageContent,
     SystemPromptMessage,
     TextPromptMessageContent,
     UserPromptMessage,
@@ -18,10 +20,10 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from models.model import AppMode
 
 if TYPE_CHECKING:
-    from core.file.file_obj import FileVar
+    from core.file.models import File
 
 
-class ModelMode(enum.Enum):
+class ModelMode(str, enum.Enum):
     COMPLETION = "completion"
     CHAT = "chat"
 
@@ -53,7 +55,7 @@ class SimplePromptTransform(PromptTransform):
         prompt_template_entity: PromptTemplateEntity,
         inputs: dict,
         query: str,
-        files: list["FileVar"],
+        files: list["File"],
         context: Optional[str],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
@@ -169,7 +171,7 @@ class SimplePromptTransform(PromptTransform):
         inputs: dict,
         query: str,
         context: Optional[str],
-        files: list["FileVar"],
+        files: list["File"],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
     ) -> tuple[list[PromptMessage], Optional[list[str]]]:
@@ -214,7 +216,7 @@ class SimplePromptTransform(PromptTransform):
         inputs: dict,
         query: str,
         context: Optional[str],
-        files: list["FileVar"],
+        files: list["File"],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
     ) -> tuple[list[PromptMessage], Optional[list[str]]]:
@@ -261,11 +263,12 @@ class SimplePromptTransform(PromptTransform):
 
         return [self.get_last_user_message(prompt, files)], stops
 
-    def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage:
+    def get_last_user_message(self, prompt: str, files: list["File"]) -> UserPromptMessage:
         if files:
-            prompt_message_contents = [TextPromptMessageContent(data=prompt)]
+            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents.append(TextPromptMessageContent(data=prompt))
             for file in files:
-                prompt_message_contents.append(file.prompt_message_content)
+                prompt_message_contents.append(file_manager.to_prompt_message_content(file))
 
             prompt_message = UserPromptMessage(content=prompt_message_contents)
         else:

+ 3 - 1
api/core/prompt/utils/extract_thread_messages.py

@@ -1,7 +1,9 @@
+from typing import Any
+
 from constants import UUID_NIL
 
 
-def extract_thread_messages(messages: list[dict]) -> list[dict]:
+def extract_thread_messages(messages: list[Any]):
     thread_messages = []
     next_message = None
 

+ 13 - 6
api/core/prompt/utils/prompt_message_util.py

@@ -1,7 +1,8 @@
 from typing import cast
 
-from core.model_runtime.entities.message_entities import (
+from core.model_runtime.entities import (
     AssistantPromptMessage,
+    AudioPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
@@ -21,7 +22,7 @@ class PromptMessageUtil:
         :return:
         """
         prompts = []
-        if model_mode == ModelMode.CHAT.value:
+        if model_mode == ModelMode.CHAT:
             tool_calls = []
             for prompt_message in prompt_messages:
                 if prompt_message.role == PromptMessageRole.USER:
@@ -51,11 +52,9 @@ class PromptMessageUtil:
                 files = []
                 if isinstance(prompt_message.content, list):
                     for content in prompt_message.content:
-                        if content.type == PromptMessageContentType.TEXT:
-                            content = cast(TextPromptMessageContent, content)
+                        if isinstance(content, TextPromptMessageContent):
                             text += content.data
-                        else:
-                            content = cast(ImagePromptMessageContent, content)
+                        elif isinstance(content, ImagePromptMessageContent):
                             files.append(
                                 {
                                     "type": "image",
@@ -63,6 +62,14 @@ class PromptMessageUtil:
                                     "detail": content.detail.value,
                                 }
                             )
+                        elif isinstance(content, AudioPromptMessageContent):
+                            files.append(
+                                {
+                                    "type": "audio",
+                                    "data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
+                                    "format": content.format,
+                                }
+                            )
                 else:
                     text = prompt_message.content
 

+ 1 - 1
api/core/rag/extractor/word_extractor.py

@@ -121,7 +121,7 @@ class WordExtractor(BaseExtractor):
                 db.session.add(upload_file)
                 db.session.commit()
                 image_map[rel.target_part] = (
-                    f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)"
+                    f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/file-preview)"
                 )
 
         return image_map

+ 1 - 1
api/core/rag/retrieval/router/multi_dataset_react_route.py

@@ -9,7 +9,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
 from core.rag.retrieval.output_parser.react_output import ReactAction
 from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
-from core.workflow.nodes.llm.llm_node import LLMNode
+from core.workflow.nodes.llm import LLMNode
 
 PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
 

+ 3 - 3
api/core/tools/entities/api_entities.py

@@ -32,8 +32,8 @@ class UserToolProvider(BaseModel):
     original_credentials: Optional[dict] = None
     is_team_authorization: bool = False
     allow_delete: bool = True
-    tools: list[UserTool] = None
-    labels: list[str] = None
+    tools: list[UserTool] | None = None
+    labels: list[str] | None = None
 
     def to_dict(self) -> dict:
         # -------------
@@ -42,7 +42,7 @@ class UserToolProvider(BaseModel):
         for tool in tools:
             if tool.get("parameters"):
                 for parameter in tool.get("parameters"):
-                    if parameter.get("type") == ToolParameter.ToolParameterType.FILE.value:
+                    if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value:
                         parameter["type"] = "files"
         # -------------
 

+ 64 - 2
api/core/tools/entities/tool_entities.py

@@ -104,14 +104,15 @@ class ToolInvokeMessage(BaseModel):
         BLOB = "blob"
         JSON = "json"
         IMAGE_LINK = "image_link"
-        FILE_VAR = "file_var"
+        FILE = "file"
 
     type: MessageType = MessageType.TEXT
     """
         plain text, image url or link url
     """
     message: str | bytes | dict | None = None
-    meta: dict[str, Any] | None = None
+    # TODO: Use a BaseModel for meta
+    meta: dict[str, Any] = Field(default_factory=dict)
     save_as: str = ""
 
 
@@ -143,6 +144,67 @@ class ToolParameter(BaseModel):
         SELECT = "select"
         SECRET_INPUT = "secret-input"
         FILE = "file"
+        FILES = "files"
+
+        # deprecated, should not use.
+        SYSTEM_FILES = "systme-files"
+
+        def as_normal_type(self):
+            if self in {
+                ToolParameter.ToolParameterType.SECRET_INPUT,
+                ToolParameter.ToolParameterType.SELECT,
+            }:
+                return "string"
+            return self.value
+
+        def cast_value(self, value: Any, /):
+            try:
+                match self:
+                    case (
+                        ToolParameter.ToolParameterType.STRING
+                        | ToolParameter.ToolParameterType.SECRET_INPUT
+                        | ToolParameter.ToolParameterType.SELECT
+                    ):
+                        if value is None:
+                            return ""
+                        else:
+                            return value if isinstance(value, str) else str(value)
+
+                    case ToolParameter.ToolParameterType.BOOLEAN:
+                        if value is None:
+                            return False
+                        elif isinstance(value, str):
+                            # Allowed YAML boolean value strings: https://yaml.org/type/bool.html
+                            # and also '0' for False and '1' for True
+                            match value.lower():
+                                case "true" | "yes" | "y" | "1":
+                                    return True
+                                case "false" | "no" | "n" | "0":
+                                    return False
+                                case _:
+                                    return bool(value)
+                        else:
+                            return value if isinstance(value, bool) else bool(value)
+
+                    case ToolParameter.ToolParameterType.NUMBER:
+                        if isinstance(value, int | float):
+                            return value
+                        elif isinstance(value, str) and value:
+                            if "." in value:
+                                return float(value)
+                            else:
+                                return int(value)
+                    case (
+                        ToolParameter.ToolParameterType.SYSTEM_FILES
+                        | ToolParameter.ToolParameterType.FILE
+                        | ToolParameter.ToolParameterType.FILES
+                    ):
+                        return value
+                    case _:
+                        return str(value)
+
+            except Exception:
+                raise ValueError(f"The tool parameter value {value} is not in correct type of {self.as_normal_type()}.")
 
     class ToolParameterForm(Enum):
         SCHEMA = "schema"  # should be set while adding tool

+ 1 - 1
api/core/tools/provider/builtin/dalle/tools/dalle3.py

@@ -66,7 +66,7 @@ class DallE3Tool(BuiltinTool):
         for image in response.data:
             mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
             blob_message = self.create_blob_message(
-                blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value
+                blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE
             )
             result.append(blob_message)
         return result

+ 1 - 1
api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py

@@ -2,7 +2,7 @@ from typing import Any
 
 from duckduckgo_search import DDGS
 
-from core.file.file_obj import FileTransferMethod
+from core.file.models import FileTransferMethod
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
 

+ 24 - 0
api/core/tools/provider/builtin/podcast_generator/_assets/icon.svg

@@ -0,0 +1,24 @@
+<svg width="100" height="100" viewBox="0 0 100 100" fill="none" xmlns="http://www.w3.org/2000/svg">
+    <rect width="100" height="100" rx="20" fill="#4A90E2" />
+    <path
+        d="M50 25C40.6 25 33 32.6 33 42V58C33 67.4 40.6 75 50 75C59.4 75 67 67.4 67 58V42C67 32.6 59.4 25 50 25ZM61 58C61 64.1 56.1 69 50 69C43.9 69 39 64.1 39 58V42C39 35.9 43.9 31 50 31C56.1 31 61 35.9 61 42V58Z"
+        fill="white" />
+    <path d="M50 37C47.2 37 45 39.2 45 42V58C45 60.8 47.2 63 50 63C52.8 63 55 60.8 55 58V42C55 39.2 52.8 37 50 37Z"
+        fill="white" />
+    <path
+        d="M73 49H69V58C69 68.5 60.5 77 50 77C39.5 77 31 68.5 31 58V49H27V58C27 70.7 37.3 81 50 81C62.7 81 73 70.7 73 58V49Z"
+        fill="white" />
+    <path d="M50 85C51.1 85 52 84.1 52 83V81H48V83C48 84.1 48.9 85 50 85Z" fill="white" />
+    <path
+        d="M35 45C36.1046 45 37 44.1046 37 43C37 41.8954 36.1046 41 35 41C33.8954 41 33 41.8954 33 43C33 44.1046 33.8954 45 35 45Z"
+        fill="white" />
+    <path
+        d="M35 55C36.1046 55 37 54.1046 37 53C37 51.8954 36.1046 51 35 51C33.8954 51 33 51.8954 33 53C33 54.1046 33.8954 55 35 55Z"
+        fill="white" />
+    <path
+        d="M65 45C66.1046 45 67 44.1046 67 43C67 41.8954 66.1046 41 65 41C63.8954 41 63 41.8954 63 43C63 44.1046 63.8954 45 65 45Z"
+        fill="white" />
+    <path
+        d="M65 55C66.1046 55 67 54.1046 67 53C67 51.8954 66.1046 51 65 51C63.8954 51 63 51.8954 63 53C63 54.1046 63.8954 55 65 55Z"
+        fill="white" />
+</svg>

+ 33 - 0
api/core/tools/provider/builtin/podcast_generator/podcast_generator.py

@@ -0,0 +1,33 @@
+from typing import Any
+
+import openai
+
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+class PodcastGeneratorProvider(BuiltinToolProviderController):
+    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
+        tts_service = credentials.get("tts_service")
+        api_key = credentials.get("api_key")
+
+        if not tts_service:
+            raise ToolProviderCredentialValidationError("TTS service is not specified")
+
+        if not api_key:
+            raise ToolProviderCredentialValidationError("API key is missing")
+
+        if tts_service == "openai":
+            self._validate_openai_credentials(api_key)
+        else:
+            raise ToolProviderCredentialValidationError(f"Unsupported TTS service: {tts_service}")
+
+    def _validate_openai_credentials(self, api_key: str) -> None:
+        client = openai.OpenAI(api_key=api_key)
+        try:
+            # We're using a simple API call to validate the credentials
+            client.models.list()
+        except openai.AuthenticationError:
+            raise ToolProviderCredentialValidationError("Invalid OpenAI API key")
+        except Exception as e:
+            raise ToolProviderCredentialValidationError(f"Error validating OpenAI API key: {str(e)}")

+ 34 - 0
api/core/tools/provider/builtin/podcast_generator/podcast_generator.yaml

@@ -0,0 +1,34 @@
+identity:
+  author: Dify
+  name: podcast_generator
+  label:
+    en_US: Podcast Generator
+    zh_Hans: 播客生成器
+  description:
+    en_US: Generate podcast audio using Text-to-Speech services
+    zh_Hans: 使用文字转语音服务生成播客音频
+  icon: icon.svg
+credentials_for_provider:
+  tts_service:
+    type: select
+    required: true
+    label:
+      en_US: TTS Service
+      zh_Hans: TTS 服务
+    placeholder:
+      en_US: Select a TTS service
+      zh_Hans: 选择一个 TTS 服务
+    options:
+      - label:
+          en_US: OpenAI TTS
+          zh_Hans: OpenAI TTS
+        value: openai
+  api_key:
+    type: secret-input
+    required: true
+    label:
+      en_US: API Key
+      zh_Hans: API 密钥
+    placeholder:
+      en_US: Enter your TTS service API key
+      zh_Hans: 输入您的 TTS 服务 API 密钥

+ 100 - 0
api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py

@@ -0,0 +1,100 @@
+import concurrent.futures
+import io
+import random
+from typing import Any, Literal, Optional, Union
+
+import openai
+from pydub import AudioSegment
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.errors import ToolParameterValidationError, ToolProviderCredentialValidationError
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class PodcastAudioGeneratorTool(BuiltinTool):
+    @staticmethod
+    def _generate_silence(duration: float):
+        # Generate silent WAV data using pydub
+        silence = AudioSegment.silent(duration=int(duration * 1000))  # pydub uses milliseconds
+        return silence
+
+    @staticmethod
+    def _generate_audio_segment(
+        client: openai.OpenAI,
+        line: str,
+        voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"],
+        index: int,
+    ) -> tuple[int, Union[AudioSegment, str], Optional[AudioSegment]]:
+        try:
+            response = client.audio.speech.create(model="tts-1", voice=voice, input=line.strip(), response_format="wav")
+            audio = AudioSegment.from_wav(io.BytesIO(response.content))
+            silence_duration = random.uniform(0.1, 1.5)
+            silence = PodcastAudioGeneratorTool._generate_silence(silence_duration)
+            return index, audio, silence
+        except Exception as e:
+            return index, f"Error generating audio: {str(e)}", None
+
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        # Extract parameters
+        script = tool_parameters.get("script", "")
+        host1_voice = tool_parameters.get("host1_voice")
+        host2_voice = tool_parameters.get("host2_voice")
+
+        # Split the script into lines
+        script_lines = [line for line in script.split("\n") if line.strip()]
+
+        # Ensure voices are provided
+        if not host1_voice or not host2_voice:
+            raise ToolParameterValidationError("Host voices are required")
+
+        # Get OpenAI API key from credentials
+        if not self.runtime or not self.runtime.credentials:
+            raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing")
+        api_key = self.runtime.credentials.get("api_key")
+        if not api_key:
+            raise ToolProviderCredentialValidationError("OpenAI API key is missing")
+
+        # Initialize OpenAI client
+        client = openai.OpenAI(api_key=api_key)
+
+        # Create a thread pool
+        max_workers = 5
+        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+            futures = []
+            for i, line in enumerate(script_lines):
+                voice = host1_voice if i % 2 == 0 else host2_voice
+                future = executor.submit(self._generate_audio_segment, client, line, voice, i)
+                futures.append(future)
+
+            # Collect results
+            audio_segments: list[Any] = [None] * len(script_lines)
+            for future in concurrent.futures.as_completed(futures):
+                index, audio, silence = future.result()
+                if isinstance(audio, str):  # Error occurred
+                    return self.create_text_message(audio)
+                audio_segments[index] = (audio, silence)
+
+        # Combine audio segments in the correct order
+        combined_audio = AudioSegment.empty()
+        for i, (audio, silence) in enumerate(audio_segments):
+            if audio:
+                combined_audio += audio
+                if i < len(audio_segments) - 1 and silence:
+                    combined_audio += silence
+
+        # Export the combined audio to a WAV file in memory
+        buffer = io.BytesIO()
+        combined_audio.export(buffer, format="wav")
+        wav_bytes = buffer.getvalue()
+
+        # Create a blob message with the combined audio
+        return [
+            self.create_text_message("Audio generated successfully"),
+            self.create_blob_message(
+                blob=wav_bytes,
+                meta={"mime_type": "audio/x-wav"},
+                save_as=self.VariableKey.AUDIO,
+            ),
+        ]

+ 95 - 0
api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.yaml

@@ -0,0 +1,95 @@
+identity:
+  name: podcast_audio_generator
+  author: Dify
+  label:
+    en_US: Podcast Audio Generator
+    zh_Hans: 播客音频生成器
+description:
+  human:
+    en_US: Generate a podcast audio file from a script with two alternating voices using OpenAI's TTS service.
+    zh_Hans: 使用 OpenAI 的 TTS 服务,从包含两个交替声音的脚本生成播客音频文件。
+  llm: This tool converts a prepared podcast script into an audio file using OpenAI's Text-to-Speech service, with two specified voices for alternating hosts.
+parameters:
+  - name: script
+    type: string
+    required: true
+    label:
+      en_US: Podcast Script
+      zh_Hans: 播客脚本
+    human_description:
+      en_US: A string containing alternating lines for two hosts, separated by newline characters.
+      zh_Hans: 包含两位主持人交替台词的字符串,每行用换行符分隔。
+    llm_description: A string representing the script, with alternating lines for two hosts separated by newline characters.
+    form: llm
+  - name: host1_voice
+    type: select
+    required: true
+    label:
+      en_US: Host 1 Voice
+      zh_Hans: 主持人1 音色
+    human_description:
+      en_US: The voice for the first host.
+      zh_Hans: 第一位主持人的音色。
+    llm_description: The voice identifier for the first host's voice.
+    options:
+      - label:
+          en_US: Alloy
+          zh_Hans: Alloy
+        value: alloy
+      - label:
+          en_US: Echo
+          zh_Hans: Echo
+        value: echo
+      - label:
+          en_US: Fable
+          zh_Hans: Fable
+        value: fable
+      - label:
+          en_US: Onyx
+          zh_Hans: Onyx
+        value: onyx
+      - label:
+          en_US: Nova
+          zh_Hans: Nova
+        value: nova
+      - label:
+          en_US: Shimmer
+          zh_Hans: Shimmer
+        value: shimmer
+    form: form
+  - name: host2_voice
+    type: select
+    required: true
+    label:
+      en_US: Host 2 Voice
+      zh_Hans: 主持人2 音色
+    human_description:
+      en_US: The voice for the second host.
+      zh_Hans: 第二位主持人的音色。
+    llm_description: The voice identifier for the second host's voice.
+    options:
+      - label:
+          en_US: Alloy
+          zh_Hans: Alloy
+        value: alloy
+      - label:
+          en_US: Echo
+          zh_Hans: Echo
+        value: echo
+      - label:
+          en_US: Fable
+          zh_Hans: Fable
+        value: fable
+      - label:
+          en_US: Onyx
+          zh_Hans: Onyx
+        value: onyx
+      - label:
+          en_US: Nova
+          zh_Hans: Nova
+        value: nova
+      - label:
+          en_US: Shimmer
+          zh_Hans: Shimmer
+        value: shimmer
+    form: form

+ 1 - 4
api/core/tools/provider/builtin_tool_provider.py

@@ -13,7 +13,6 @@ from core.tools.errors import (
 from core.tools.provider.tool_provider import ToolProviderController
 from core.tools.tool.builtin_tool import BuiltinTool
 from core.tools.tool.tool import Tool
-from core.tools.utils.tool_parameter_converter import ToolParameterConverter
 from core.tools.utils.yaml_utils import load_yaml_file
 
 
@@ -208,9 +207,7 @@ class BuiltinToolProviderController(ToolProviderController):
 
             # the parameter is not set currently, set the default value if needed
             if parameter_schema.default is not None:
-                default_value = ToolParameterConverter.cast_parameter_by_type(
-                    parameter_schema.default, parameter_schema.type
-                )
+                default_value = parameter_schema.type.cast_value(parameter_schema.default)
                 tool_parameters[parameter] = default_value
 
     def validate_credentials(self, credentials: dict[str, Any]) -> None:

+ 1 - 4
api/core/tools/provider/tool_provider.py

@@ -11,7 +11,6 @@ from core.tools.entities.tool_entities import (
 )
 from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError
 from core.tools.tool.tool import Tool
-from core.tools.utils.tool_parameter_converter import ToolParameterConverter
 
 
 class ToolProviderController(BaseModel, ABC):
@@ -127,9 +126,7 @@ class ToolProviderController(BaseModel, ABC):
 
             # the parameter is not set currently, set the default value if needed
             if parameter_schema.default is not None:
-                tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(
-                    parameter_schema.default, parameter_schema.type
-                )
+                tool_parameters[parameter] = parameter_schema.type.cast_value(parameter_schema.default)
 
     def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
         """

+ 10 - 9
api/core/tools/provider/workflow_tool_provider.py

@@ -1,6 +1,6 @@
 from typing import Optional
 
-from core.app.app_config.entities import VariableEntity, VariableEntityType
+from core.app.app_config.entities import VariableEntityType
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import (
@@ -23,6 +23,8 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
     VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
     VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
     VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
+    VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE,
+    VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES,
 }
 
 
@@ -36,8 +38,8 @@ class WorkflowToolProviderController(ToolProviderController):
         if not app:
             raise ValueError("app not found")
 
-        controller = WorkflowToolProviderController(
-            **{
+        controller = WorkflowToolProviderController.model_validate(
+            {
                 "identity": {
                     "author": db_provider.user.name if db_provider.user_id and db_provider.user else "",
                     "name": db_provider.label,
@@ -67,7 +69,7 @@ class WorkflowToolProviderController(ToolProviderController):
         :param app: the app
         :return: the tool
         """
-        workflow: Workflow = (
+        workflow = (
             db.session.query(Workflow)
             .filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
             .first()
@@ -76,14 +78,14 @@ class WorkflowToolProviderController(ToolProviderController):
             raise ValueError("workflow not found")
 
         # fetch start node
-        graph: dict = workflow.graph_dict
-        features_dict: dict = workflow.features_dict
+        graph = workflow.graph_dict
+        features_dict = workflow.features_dict
         features = WorkflowAppConfigManager.convert_features(config_dict=features_dict, app_mode=AppMode.WORKFLOW)
 
         parameters = db_provider.parameter_configurations
         variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
 
-        def fetch_workflow_variable(variable_name: str) -> VariableEntity:
+        def fetch_workflow_variable(variable_name: str):
             return next(filter(lambda x: x.variable == variable_name, variables), None)
 
         user = db_provider.user
@@ -114,7 +116,6 @@ class WorkflowToolProviderController(ToolProviderController):
                         llm_description=parameter.description,
                         required=variable.required,
                         options=options,
-                        default=variable.default,
                     )
                 )
             elif features.file_upload:
@@ -123,7 +124,7 @@ class WorkflowToolProviderController(ToolProviderController):
                         name=parameter.name,
                         label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name),
                         human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description),
-                        type=ToolParameter.ToolParameterType.FILE,
+                        type=ToolParameter.ToolParameterType.SYSTEM_FILES,
                         llm_description=parameter.description,
                         required=False,
                         form=parameter.form,

+ 9 - 10
api/core/tools/tool/tool.py

@@ -20,10 +20,9 @@ from core.tools.entities.tool_entities import (
     ToolRuntimeVariablePool,
 )
 from core.tools.tool_file_manager import ToolFileManager
-from core.tools.utils.tool_parameter_converter import ToolParameterConverter
 
 if TYPE_CHECKING:
-    from core.file.file_obj import FileVar
+    from core.file.models import File
 
 
 class Tool(BaseModel, ABC):
@@ -63,8 +62,12 @@ class Tool(BaseModel, ABC):
     def __init__(self, **data: Any):
         super().__init__(**data)
 
-    class VariableKey(Enum):
+    class VariableKey(str, Enum):
         IMAGE = "image"
+        DOCUMENT = "document"
+        VIDEO = "video"
+        AUDIO = "audio"
+        CUSTOM = "custom"
 
     def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
         """
@@ -221,9 +224,7 @@ class Tool(BaseModel, ABC):
         result = deepcopy(tool_parameters)
         for parameter in self.parameters or []:
             if parameter.name in tool_parameters:
-                result[parameter.name] = ToolParameterConverter.cast_parameter_by_type(
-                    tool_parameters[parameter.name], parameter.type
-                )
+                result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
 
         return result
 
@@ -295,10 +296,8 @@ class Tool(BaseModel, ABC):
         """
         return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as)
 
-    def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage:
-        return ToolInvokeMessage(
-            type=ToolInvokeMessage.MessageType.FILE_VAR, message="", meta={"file_var": file_var}, save_as=""
-        )
+    def create_file_message(self, file: "File") -> ToolInvokeMessage:
+        return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE, message="", meta={"file": file}, save_as="")
 
     def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage:
         """

+ 24 - 26
api/core/tools/tool/workflow_tool.py

@@ -3,7 +3,7 @@ import logging
 from copy import deepcopy
 from typing import Any, Optional, Union
 
-from core.file.file_obj import FileTransferMethod, FileVar
+from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
 from core.tools.tool.tool import Tool
 from extensions.ext_database import db
@@ -45,11 +45,13 @@ class WorkflowTool(Tool):
         workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)
 
         # transform the tool parameters
-        tool_parameters, files = self._transform_args(tool_parameters)
+        tool_parameters, files = self._transform_args(tool_parameters=tool_parameters)
 
         from core.app.apps.workflow.app_generator import WorkflowAppGenerator
 
         generator = WorkflowAppGenerator()
+        assert self.runtime is not None
+        assert self.runtime.invoke_from is not None
         result = generator.generate(
             app_model=app,
             workflow=workflow,
@@ -74,7 +76,7 @@ class WorkflowTool(Tool):
         else:
             outputs, files = self._extract_files(outputs)
             for file in files:
-                result.append(self.create_file_var_message(file))
+                result.append(self.create_file_message(file))
 
         result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
         result.append(self.create_json_message(outputs))
@@ -154,22 +156,22 @@ class WorkflowTool(Tool):
         parameters_result = {}
         files = []
         for parameter in parameter_rules:
-            if parameter.type == ToolParameter.ToolParameterType.FILE:
+            if parameter.type == ToolParameter.ToolParameterType.SYSTEM_FILES:
                 file = tool_parameters.get(parameter.name)
                 if file:
                     try:
-                        file_var_list = [FileVar(**f) for f in file]
-                        for file_var in file_var_list:
-                            file_dict = {
-                                "transfer_method": file_var.transfer_method.value,
-                                "type": file_var.type.value,
+                        file_var_list = [File.model_validate(f) for f in file]
+                        for file in file_var_list:
+                            file_dict: dict[str, str | None] = {
+                                "transfer_method": file.transfer_method.value,
+                                "type": file.type.value,
                             }
-                            if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
-                                file_dict["tool_file_id"] = file_var.related_id
-                            elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE:
-                                file_dict["upload_file_id"] = file_var.related_id
-                            elif file_var.transfer_method == FileTransferMethod.REMOTE_URL:
-                                file_dict["url"] = file_var.preview_url
+                            if file.transfer_method == FileTransferMethod.TOOL_FILE:
+                                file_dict["tool_file_id"] = file.related_id
+                            elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
+                                file_dict["upload_file_id"] = file.related_id
+                            elif file.transfer_method == FileTransferMethod.REMOTE_URL:
+                                file_dict["url"] = file.generate_url()
 
                             files.append(file_dict)
                     except Exception as e:
@@ -179,7 +181,7 @@ class WorkflowTool(Tool):
 
         return parameters_result, files
 
-    def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]:
+    def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]:
         """
         extract files from the result
 
@@ -190,17 +192,13 @@ class WorkflowTool(Tool):
         result = {}
         for key, value in outputs.items():
             if isinstance(value, list):
-                has_file = False
                 for item in value:
-                    if isinstance(item, dict) and item.get("__variant") == "FileVar":
-                        try:
-                            files.append(FileVar(**item))
-                            has_file = True
-                        except Exception as e:
-                            pass
-                if has_file:
-                    continue
+                    if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY:
+                        file = File.model_validate(item)
+                        files.append(file)
+            elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
+                file = File.model_validate(value)
+                files.append(file)
 
             result[key] = value
-
         return result, files

+ 24 - 16
api/core/tools/tool_engine.py

@@ -10,7 +10,8 @@ from yarl import URL
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
 from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
-from core.file.file_obj import FileTransferMethod
+from core.file import FileType
+from core.file.models import FileTransferMethod
 from core.ops.ops_trace_manager import TraceQueueManager
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter
 from core.tools.errors import (
@@ -26,6 +27,7 @@ from core.tools.tool.tool import Tool
 from core.tools.tool.workflow_tool import WorkflowTool
 from core.tools.utils.message_transformer import ToolFileMessageTransformer
 from extensions.ext_database import db
+from models.enums import CreatedByRole
 from models.model import Message, MessageFile
 
 
@@ -128,6 +130,7 @@ class ToolEngine:
         """
         try:
             # hit the callback handler
+            assert tool.identity is not None
             workflow_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters)
 
             if isinstance(tool, WorkflowTool):
@@ -258,7 +261,10 @@ class ToolEngine:
 
     @staticmethod
     def _create_message_files(
-        tool_messages: list[ToolInvokeMessageBinary], agent_message: Message, invoke_from: InvokeFrom, user_id: str
+        tool_messages: list[ToolInvokeMessageBinary],
+        agent_message: Message,
+        invoke_from: InvokeFrom,
+        user_id: str,
     ) -> list[tuple[Any, str]]:
         """
         Create message file
@@ -269,29 +275,31 @@ class ToolEngine:
         result = []
 
         for message in tool_messages:
-            file_type = "bin"
             if "image" in message.mimetype:
-                file_type = "image"
+                file_type = FileType.IMAGE
             elif "video" in message.mimetype:
-                file_type = "video"
+                file_type = FileType.VIDEO
             elif "audio" in message.mimetype:
-                file_type = "audio"
-            elif "text" in message.mimetype:
-                file_type = "text"
-            elif "pdf" in message.mimetype:
-                file_type = "pdf"
-            elif "zip" in message.mimetype:
-                file_type = "archive"
-            # ...
+                file_type = FileType.AUDIO
+            elif "text" in message.mimetype or "pdf" in message.mimetype:
+                file_type = FileType.DOCUMENT
+            else:
+                file_type = FileType.CUSTOM
 
+            # extract tool file id from url
+            tool_file_id = message.url.split("/")[-1].split(".")[0]
             message_file = MessageFile(
                 message_id=agent_message.id,
                 type=file_type,
-                transfer_method=FileTransferMethod.TOOL_FILE.value,
+                transfer_method=FileTransferMethod.TOOL_FILE,
                 belongs_to="assistant",
                 url=message.url,
-                upload_file_id=None,
-                created_by_role=("account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"),
+                upload_file_id=tool_file_id,
+                created_by_role=(
+                    CreatedByRole.ACCOUNT
+                    if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
+                    else CreatedByRole.END_USER
+                ),
                 created_by=user_id,
             )
 

Niektóre pliki nie zostały wyświetlone z powodu dużej ilości zmienionych plików