Browse Source

feat: [backend] vision support (#1510)

Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
takatost 1 year ago
parent
commit
41d0a8b295
61 changed files with 1551 additions and 288 deletions
  1. 11 0
      api/.env.example
  2. 7 0
      api/app.py
  3. 120 83
      api/config.py
  4. 4 0
      api/controllers/console/app/completion.py
  5. 2 4
      api/controllers/console/app/conversation.py
  6. 0 3
      api/controllers/console/datasets/data_source.py
  7. 5 6
      api/controllers/console/datasets/file.py
  8. 12 0
      api/controllers/console/explore/completion.py
  9. 11 3
      api/controllers/console/explore/conversation.py
  10. 3 2
      api/controllers/console/explore/installed_app.py
  11. 13 2
      api/controllers/console/explore/parameter.py
  12. 2 0
      api/controllers/console/explore/saved_message.py
  13. 3 0
      api/controllers/console/universal_chat/chat.py
  14. 9 2
      api/controllers/console/universal_chat/conversation.py
  15. 10 0
      api/controllers/files/__init__.py
  16. 40 0
      api/controllers/files/image_preview.py
  17. 1 1
      api/controllers/service_api/__init__.py
  18. 13 2
      api/controllers/service_api/app/app.py
  19. 6 1
      api/controllers/service_api/app/completion.py
  20. 9 2
      api/controllers/service_api/app/conversation.py
  21. 23 0
      api/controllers/service_api/app/error.py
  22. 42 0
      api/controllers/service_api/app/file.py
  23. 2 1
      api/controllers/service_api/app/message.py
  24. 3 2
      api/controllers/service_api/dataset/document.py
  25. 1 1
      api/controllers/web/__init__.py
  26. 13 2
      api/controllers/web/app.py
  27. 4 0
      api/controllers/web/completion.py
  28. 9 2
      api/controllers/web/conversation.py
  29. 25 1
      api/controllers/web/error.py
  30. 36 0
      api/controllers/web/file.py
  31. 2 0
      api/controllers/web/message.py
  32. 3 0
      api/controllers/web/saved_message.py
  33. 8 2
      api/core/callback_handler/llm_callback_handler.py
  34. 24 9
      api/core/completion.py
  35. 24 6
      api/core/conversation_message_task.py
  36. 0 0
      api/core/file/__init__.py
  37. 79 0
      api/core/file/file_obj.py
  38. 180 0
      api/core/file/message_file_parser.py
  39. 79 0
      api/core/file/upload_file_parser.py
  40. 6 2
      api/core/generator/llm_generator.py
  41. 19 1
      api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py
  42. 46 2
      api/core/model_providers/models/entity/message.py
  43. 0 2
      api/core/model_providers/models/llm/openai_model.py
  44. 153 76
      api/core/prompt/prompt_transform.py
  45. 103 1
      api/core/third_party/langchain/llms/chat_open_ai.py
  46. 4 10
      api/events/event_handlers/generate_conversation_name_when_first_message_created.py
  47. 36 1
      api/extensions/ext_storage.py
  48. 3 2
      api/fields/app_fields.py
  49. 9 7
      api/fields/conversation_fields.py
  50. 2 1
      api/fields/file_fields.py
  51. 2 0
      api/fields/message_fields.py
  52. 59 0
      api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py
  53. 63 4
      api/models/model.py
  54. 33 1
      api/services/app_model_config_service.py
  55. 40 10
      api/services/completion_service.py
  56. 36 4
      api/services/conversation_service.py
  57. 54 21
      api/services/file_service.py
  58. 4 2
      api/services/web_conversation_service.py
  59. 29 1
      api/tests/integration_tests/models/llm/test_openai_model.py
  60. 7 3
      docker/docker-compose.yaml
  61. 5 0
      docker/nginx/conf.d/default.conf

+ 11 - 0
api/.env.example

@@ -18,6 +18,9 @@ SERVICE_API_URL=http://127.0.0.1:5001
 APP_API_URL=http://127.0.0.1:5001
 APP_API_URL=http://127.0.0.1:5001
 APP_WEB_URL=http://127.0.0.1:3000
 APP_WEB_URL=http://127.0.0.1:3000
 
 
+# Files URL
+FILES_URL=http://127.0.0.1:5001
+
 # celery configuration
 # celery configuration
 CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
 CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
 
 
@@ -70,6 +73,14 @@ MILVUS_USER=root
 MILVUS_PASSWORD=Milvus
 MILVUS_PASSWORD=Milvus
 MILVUS_SECURE=false
 MILVUS_SECURE=false
 
 
+# Upload configuration
+UPLOAD_FILE_SIZE_LIMIT=15
+UPLOAD_FILE_BATCH_LIMIT=5
+UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
+
+# Model Configuration
+MULTIMODAL_SEND_IMAGE_FORMAT=base64
+
 # Mail configuration, support: resend
 # Mail configuration, support: resend
 MAIL_TYPE=
 MAIL_TYPE=
 MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
 MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>

+ 7 - 0
api/app.py

@@ -126,6 +126,7 @@ def register_blueprints(app):
     from controllers.service_api import bp as service_api_bp
     from controllers.service_api import bp as service_api_bp
     from controllers.web import bp as web_bp
     from controllers.web import bp as web_bp
     from controllers.console import bp as console_app_bp
     from controllers.console import bp as console_app_bp
+    from controllers.files import bp as files_bp
 
 
     CORS(service_api_bp,
     CORS(service_api_bp,
          allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
          allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
@@ -155,6 +156,12 @@ def register_blueprints(app):
 
 
     app.register_blueprint(console_app_bp)
     app.register_blueprint(console_app_bp)
 
 
+    CORS(files_bp,
+         allow_headers=['Content-Type'],
+         methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
+         )
+    app.register_blueprint(files_bp)
+
 
 
 # create app
 # create app
 app = create_app()
 app = create_app()

+ 120 - 83
api/config.py

@@ -26,6 +26,7 @@ DEFAULTS = {
     'SERVICE_API_URL': 'https://api.dify.ai',
     'SERVICE_API_URL': 'https://api.dify.ai',
     'APP_WEB_URL': 'https://udify.app',
     'APP_WEB_URL': 'https://udify.app',
     'APP_API_URL': 'https://udify.app',
     'APP_API_URL': 'https://udify.app',
+    'FILES_URL': '',
     'STORAGE_TYPE': 'local',
     'STORAGE_TYPE': 'local',
     'STORAGE_LOCAL_PATH': 'storage',
     'STORAGE_LOCAL_PATH': 'storage',
     'CHECK_UPDATE_URL': 'https://updates.dify.ai',
     'CHECK_UPDATE_URL': 'https://updates.dify.ai',
@@ -57,7 +58,9 @@ DEFAULTS = {
     'CLEAN_DAY_SETTING': 30,
     'CLEAN_DAY_SETTING': 30,
     'UPLOAD_FILE_SIZE_LIMIT': 15,
     'UPLOAD_FILE_SIZE_LIMIT': 15,
     'UPLOAD_FILE_BATCH_LIMIT': 5,
     'UPLOAD_FILE_BATCH_LIMIT': 5,
-    'OUTPUT_MODERATION_BUFFER_SIZE': 300
+    'UPLOAD_IMAGE_FILE_SIZE_LIMIT': 10,
+    'OUTPUT_MODERATION_BUFFER_SIZE': 300,
+    'MULTIMODAL_SEND_IMAGE_FORMAT': 'base64'
 }
 }
 
 
 
 
@@ -84,15 +87,9 @@ class Config:
     """Application configuration class."""
     """Application configuration class."""
 
 
     def __init__(self):
     def __init__(self):
-        # app settings
-        self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL')
-        self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL')
-        self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL')
-        self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL')
-        self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL')
-        self.CONSOLE_URL = get_env('CONSOLE_URL')
-        self.API_URL = get_env('API_URL')
-        self.APP_URL = get_env('APP_URL')
+        # ------------------------
+        # General Configurations.
+        # ------------------------
         self.CURRENT_VERSION = "0.3.29"
         self.CURRENT_VERSION = "0.3.29"
         self.COMMIT_SHA = get_env('COMMIT_SHA')
         self.COMMIT_SHA = get_env('COMMIT_SHA')
         self.EDITION = "SELF_HOSTED"
         self.EDITION = "SELF_HOSTED"
@@ -100,13 +97,71 @@ class Config:
         self.TESTING = False
         self.TESTING = False
         self.LOG_LEVEL = get_env('LOG_LEVEL')
         self.LOG_LEVEL = get_env('LOG_LEVEL')
 
 
+        # The backend URL prefix of the console API.
+        # used to concatenate the login authorization callback or notion integration callback.
+        self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL')
+
+        # The front-end URL prefix of the console web.
+        # used to concatenate some front-end addresses and for CORS configuration use.
+        self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL')
+
+        # WebApp API backend Url prefix.
+        # used to declare the back-end URL for the front-end API.
+        self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL')
+
+        # WebApp Url prefix.
+        # used to display WebAPP API Base Url to the front-end.
+        self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL')
+
+        # Service API Url prefix.
+        # used to display Service API Base Url to the front-end.
+        self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL')
+
+        # File preview or download Url prefix.
+        # used to display File preview or download Url to the front-end or as Multi-model inputs;
+        # Url is signed and has expiration time.
+        self.FILES_URL = get_env('FILES_URL') if get_env('FILES_URL') else self.CONSOLE_API_URL
+
+        # Fallback Url prefix.
+        # Will be deprecated in the future.
+        self.CONSOLE_URL = get_env('CONSOLE_URL')
+        self.API_URL = get_env('API_URL')
+        self.APP_URL = get_env('APP_URL')
+
         # Your App secret key will be used for securely signing the session cookie
         # Your App secret key will be used for securely signing the session cookie
         # Make sure you are changing this key for your deployment with a strong key.
         # Make sure you are changing this key for your deployment with a strong key.
         # You can generate a strong key using `openssl rand -base64 42`.
         # You can generate a strong key using `openssl rand -base64 42`.
         # Alternatively you can set it with `SECRET_KEY` environment variable.
         # Alternatively you can set it with `SECRET_KEY` environment variable.
         self.SECRET_KEY = get_env('SECRET_KEY')
         self.SECRET_KEY = get_env('SECRET_KEY')
 
 
-        # redis settings
+        # cors settings
+        self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
+            'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
+        self.WEB_API_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
+            'WEB_API_CORS_ALLOW_ORIGINS', '*')
+
+        # check update url
+        self.CHECK_UPDATE_URL = get_env('CHECK_UPDATE_URL')
+
+        # ------------------------
+        # Database Configurations.
+        # ------------------------
+        db_credentials = {
+            key: get_env(key) for key in
+            ['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT', 'DB_DATABASE']
+        }
+
+        self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}"
+        self.SQLALCHEMY_ENGINE_OPTIONS = {
+            'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')),
+            'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE'))
+        }
+
+        self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO')
+
+        # ------------------------
+        # Redis Configurations.
+        # ------------------------
         self.REDIS_HOST = get_env('REDIS_HOST')
         self.REDIS_HOST = get_env('REDIS_HOST')
         self.REDIS_PORT = get_env('REDIS_PORT')
         self.REDIS_PORT = get_env('REDIS_PORT')
         self.REDIS_USERNAME = get_env('REDIS_USERNAME')
         self.REDIS_USERNAME = get_env('REDIS_USERNAME')
@@ -114,7 +169,18 @@ class Config:
         self.REDIS_DB = get_env('REDIS_DB')
         self.REDIS_DB = get_env('REDIS_DB')
         self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')
         self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')
 
 
-        # storage settings
+        # ------------------------
+        # Celery worker Configurations.
+        # ------------------------
+        self.CELERY_BROKER_URL = get_env('CELERY_BROKER_URL')
+        self.CELERY_BACKEND = get_env('CELERY_BACKEND')
+        self.CELERY_RESULT_BACKEND = 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \
+            if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
+        self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')
+
+        # ------------------------
+        # File Storage Configurations.
+        # ------------------------
         self.STORAGE_TYPE = get_env('STORAGE_TYPE')
         self.STORAGE_TYPE = get_env('STORAGE_TYPE')
         self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
         self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
         self.S3_ENDPOINT = get_env('S3_ENDPOINT')
         self.S3_ENDPOINT = get_env('S3_ENDPOINT')
@@ -123,68 +189,72 @@ class Config:
         self.S3_SECRET_KEY = get_env('S3_SECRET_KEY')
         self.S3_SECRET_KEY = get_env('S3_SECRET_KEY')
         self.S3_REGION = get_env('S3_REGION')
         self.S3_REGION = get_env('S3_REGION')
 
 
-        # vector store settings, only support weaviate, qdrant
+        # ------------------------
+        # Vector Store Configurations.
+        # Currently, only support: qdrant, milvus, zilliz, weaviate
+        # ------------------------
         self.VECTOR_STORE = get_env('VECTOR_STORE')
         self.VECTOR_STORE = get_env('VECTOR_STORE')
 
 
-        # weaviate settings
-        self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
-        self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY')
-        self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED')
-        self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE'))
-
         # qdrant settings
         # qdrant settings
         self.QDRANT_URL = get_env('QDRANT_URL')
         self.QDRANT_URL = get_env('QDRANT_URL')
         self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
         self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
 
 
-        # milvus setting
+        # milvus / zilliz setting
         self.MILVUS_HOST = get_env('MILVUS_HOST')
         self.MILVUS_HOST = get_env('MILVUS_HOST')
         self.MILVUS_PORT = get_env('MILVUS_PORT')
         self.MILVUS_PORT = get_env('MILVUS_PORT')
         self.MILVUS_USER = get_env('MILVUS_USER')
         self.MILVUS_USER = get_env('MILVUS_USER')
         self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
         self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
         self.MILVUS_SECURE = get_env('MILVUS_SECURE')
         self.MILVUS_SECURE = get_env('MILVUS_SECURE')
 
 
+        # weaviate settings
+        self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
+        self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY')
+        self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED')
+        self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE'))
 
 
-        # cors settings
-        self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
-            'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
-        self.WEB_API_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
-            'WEB_API_CORS_ALLOW_ORIGINS', '*')
-
-        # mail settings
+        # ------------------------
+        # Mail Configurations.
+        # ------------------------
         self.MAIL_TYPE = get_env('MAIL_TYPE')
         self.MAIL_TYPE = get_env('MAIL_TYPE')
         self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
         self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
         self.RESEND_API_KEY = get_env('RESEND_API_KEY')
         self.RESEND_API_KEY = get_env('RESEND_API_KEY')
 
 
-        # sentry settings
+        # ------------------------
+        # Sentry Configurations.
+        # ------------------------
         self.SENTRY_DSN = get_env('SENTRY_DSN')
         self.SENTRY_DSN = get_env('SENTRY_DSN')
         self.SENTRY_TRACES_SAMPLE_RATE = float(get_env('SENTRY_TRACES_SAMPLE_RATE'))
         self.SENTRY_TRACES_SAMPLE_RATE = float(get_env('SENTRY_TRACES_SAMPLE_RATE'))
         self.SENTRY_PROFILES_SAMPLE_RATE = float(get_env('SENTRY_PROFILES_SAMPLE_RATE'))
         self.SENTRY_PROFILES_SAMPLE_RATE = float(get_env('SENTRY_PROFILES_SAMPLE_RATE'))
 
 
-        # check update url
-        self.CHECK_UPDATE_URL = get_env('CHECK_UPDATE_URL')
+        # ------------------------
+        # Business Configurations.
+        # ------------------------
 
 
-        # database settings
-        db_credentials = {
-            key: get_env(key) for key in
-            ['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT', 'DB_DATABASE']
-        }
+        # multi model send image format, support base64, url, default is base64
+        self.MULTIMODAL_SEND_IMAGE_FORMAT = get_env('MULTIMODAL_SEND_IMAGE_FORMAT')
 
 
-        self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}"
-        self.SQLALCHEMY_ENGINE_OPTIONS = {
-            'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')),
-            'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE'))
-        }
+        # Dataset Configurations.
+        self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
+        self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')
 
 
-        self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO')
+        # File upload Configurations.
+        self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
+        self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
+        self.UPLOAD_IMAGE_FILE_SIZE_LIMIT = int(get_env('UPLOAD_IMAGE_FILE_SIZE_LIMIT'))
 
 
-        # celery settings
-        self.CELERY_BROKER_URL = get_env('CELERY_BROKER_URL')
-        self.CELERY_BACKEND = get_env('CELERY_BACKEND')
-        self.CELERY_RESULT_BACKEND = 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \
-            if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
-        self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')
+        # Moderation in app Configurations.
+        self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE'))
+
+        # Notion integration setting
+        self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
+        self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
+        self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
+        self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
+        self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
 
 
-        # hosted provider credentials
+        # ------------------------
+        # Platform Configurations.
+        # ------------------------
         self.HOSTED_OPENAI_ENABLED = get_bool_env('HOSTED_OPENAI_ENABLED')
         self.HOSTED_OPENAI_ENABLED = get_bool_env('HOSTED_OPENAI_ENABLED')
         self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
         self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
         self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
         self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
@@ -212,26 +282,6 @@ class Config:
         self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
         self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
         self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
         self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
 
 
-        self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
-        self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
-
-        # notion import setting
-        self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
-        self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
-        self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
-        self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
-        self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
-
-        self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
-        self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')
-
-        # uploading settings
-        self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
-        self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
-
-        # moderation settings
-        self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE'))
-
 
 
 class CloudEditionConfig(Config):
 class CloudEditionConfig(Config):
 
 
@@ -246,18 +296,5 @@ class CloudEditionConfig(Config):
         self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
         self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
         self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')
         self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')
 
 
-
-class TestConfig(Config):
-
-    def __init__(self):
-        super().__init__()
-
-        self.EDITION = "SELF_HOSTED"
-        self.TESTING = True
-
-        db_credentials = {
-            key: get_env(key) for key in ['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT']
-        }
-
-        # use a different database for testing: dify_test
-        self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/dify_test"
+        self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
+        self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')

+ 4 - 0
api/controllers/console/app/completion.py

@@ -40,12 +40,14 @@ class CompletionMessageApi(Resource):
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, location='json', default='')
         parser.add_argument('query', type=str, location='json', default='')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('model_config', type=dict, required=True, location='json')
         parser.add_argument('model_config', type=dict, required=True, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
         args = parser.parse_args()
         args = parser.parse_args()
 
 
         streaming = args['response_mode'] != 'blocking'
         streaming = args['response_mode'] != 'blocking'
+        args['auto_generate_name'] = False
 
 
         account = flask_login.current_user
         account = flask_login.current_user
 
 
@@ -113,6 +115,7 @@ class ChatMessageApi(Resource):
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, required=True, location='json')
         parser.add_argument('query', type=str, required=True, location='json')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('model_config', type=dict, required=True, location='json')
         parser.add_argument('model_config', type=dict, required=True, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
@@ -120,6 +123,7 @@ class ChatMessageApi(Resource):
         args = parser.parse_args()
         args = parser.parse_args()
 
 
         streaming = args['response_mode'] != 'blocking'
         streaming = args['response_mode'] != 'blocking'
+        args['auto_generate_name'] = False
 
 
         account = flask_login.current_user
         account = flask_login.current_user
 
 

+ 2 - 4
api/controllers/console/app/conversation.py

@@ -108,7 +108,7 @@ class CompletionConversationDetailApi(Resource):
         conversation_id = str(conversation_id)
         conversation_id = str(conversation_id)
 
 
         return _get_conversation(app_id, conversation_id, 'completion')
         return _get_conversation(app_id, conversation_id, 'completion')
-    
+
     @setup_required
     @setup_required
     @login_required
     @login_required
     @account_initialization_required
     @account_initialization_required
@@ -230,7 +230,7 @@ class ChatConversationDetailApi(Resource):
         conversation_id = str(conversation_id)
         conversation_id = str(conversation_id)
 
 
         return _get_conversation(app_id, conversation_id, 'chat')
         return _get_conversation(app_id, conversation_id, 'chat')
-    
+
     @setup_required
     @setup_required
     @login_required
     @login_required
     @account_initialization_required
     @account_initialization_required
@@ -253,8 +253,6 @@ class ChatConversationDetailApi(Resource):
         return {'result': 'success'}, 204
         return {'result': 'success'}, 204
 
 
 
 
-
-
 api.add_resource(CompletionConversationApi, '/apps/<uuid:app_id>/completion-conversations')
 api.add_resource(CompletionConversationApi, '/apps/<uuid:app_id>/completion-conversations')
 api.add_resource(CompletionConversationDetailApi, '/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>')
 api.add_resource(CompletionConversationDetailApi, '/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>')
 api.add_resource(ChatConversationApi, '/apps/<uuid:app_id>/chat-conversations')
 api.add_resource(ChatConversationApi, '/apps/<uuid:app_id>/chat-conversations')

+ 0 - 3
api/controllers/console/datasets/data_source.py

@@ -1,7 +1,6 @@
 import datetime
 import datetime
 import json
 import json
 
 
-from cachetools import TTLCache
 from flask import request
 from flask import request
 from flask_login import current_user
 from flask_login import current_user
 from libs.login import login_required
 from libs.login import login_required
@@ -20,8 +19,6 @@ from models.source import DataSourceBinding
 from services.dataset_service import DatasetService, DocumentService
 from services.dataset_service import DatasetService, DocumentService
 from tasks.document_indexing_sync_task import document_indexing_sync_task
 from tasks.document_indexing_sync_task import document_indexing_sync_task
 
 
-cache = TTLCache(maxsize=None, ttl=30)
-
 
 
 class DataSourceApi(Resource):
 class DataSourceApi(Resource):
 
 

+ 5 - 6
api/controllers/console/datasets/file.py

@@ -1,5 +1,5 @@
-from cachetools import TTLCache
 from flask import request, current_app
 from flask import request, current_app
+from flask_login import current_user
 
 
 import services
 import services
 from libs.login import login_required
 from libs.login import login_required
@@ -15,9 +15,6 @@ from fields.file_fields import upload_config_fields, file_fields
 
 
 from services.file_service import FileService
 from services.file_service import FileService
 
 
-cache = TTLCache(maxsize=None, ttl=30)
-
-ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
 PREVIEW_WORDS_LIMIT = 3000
 PREVIEW_WORDS_LIMIT = 3000
 
 
 
 
@@ -30,9 +27,11 @@ class FileApi(Resource):
     def get(self):
     def get(self):
         file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT")
         file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT")
         batch_count_limit = current_app.config.get("UPLOAD_FILE_BATCH_LIMIT")
         batch_count_limit = current_app.config.get("UPLOAD_FILE_BATCH_LIMIT")
+        image_file_size_limit = current_app.config.get("UPLOAD_IMAGE_FILE_SIZE_LIMIT")
         return {
         return {
             'file_size_limit': file_size_limit,
             'file_size_limit': file_size_limit,
-            'batch_count_limit': batch_count_limit
+            'batch_count_limit': batch_count_limit,
+            'image_file_size_limit': image_file_size_limit
         }, 200
         }, 200
 
 
     @setup_required
     @setup_required
@@ -51,7 +50,7 @@ class FileApi(Resource):
         if len(request.files) > 1:
         if len(request.files) > 1:
             raise TooManyFilesError()
             raise TooManyFilesError()
         try:
         try:
-            upload_file = FileService.upload_file(file)
+            upload_file = FileService.upload_file(file, current_user)
         except services.errors.file.FileTooLargeError as file_too_large_error:
         except services.errors.file.FileTooLargeError as file_too_large_error:
             raise FileTooLargeError(file_too_large_error.description)
             raise FileTooLargeError(file_too_large_error.description)
         except services.errors.file.UnsupportedFileTypeError:
         except services.errors.file.UnsupportedFileTypeError:

+ 12 - 0
api/controllers/console/explore/completion.py

@@ -1,6 +1,7 @@
 # -*- coding:utf-8 -*-
 # -*- coding:utf-8 -*-
 import json
 import json
 import logging
 import logging
+from datetime import datetime
 from typing import Generator, Union
 from typing import Generator, Union
 
 
 from flask import Response, stream_with_context
 from flask import Response, stream_with_context
@@ -17,6 +18,7 @@ from controllers.console.explore.wraps import InstalledAppResource
 from core.conversation_message_task import PubHandler
 from core.conversation_message_task import PubHandler
 from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
 from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from extensions.ext_database import db
 from libs.helper import uuid_value
 from libs.helper import uuid_value
 from services.completion_service import CompletionService
 from services.completion_service import CompletionService
 
 
@@ -32,11 +34,16 @@ class CompletionApi(InstalledAppResource):
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, location='json', default='')
         parser.add_argument('query', type=str, location='json', default='')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
         args = parser.parse_args()
         args = parser.parse_args()
 
 
         streaming = args['response_mode'] == 'streaming'
         streaming = args['response_mode'] == 'streaming'
+        args['auto_generate_name'] = False
+
+        installed_app.last_used_at = datetime.utcnow()
+        db.session.commit()
 
 
         try:
         try:
             response = CompletionService.completion(
             response = CompletionService.completion(
@@ -91,12 +98,17 @@ class ChatApi(InstalledAppResource):
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, required=True, location='json')
         parser.add_argument('query', type=str, required=True, location='json')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
         args = parser.parse_args()
         args = parser.parse_args()
 
 
         streaming = args['response_mode'] == 'streaming'
         streaming = args['response_mode'] == 'streaming'
+        args['auto_generate_name'] = False
+
+        installed_app.last_used_at = datetime.utcnow()
+        db.session.commit()
 
 
         try:
         try:
             response = CompletionService.completion(
             response = CompletionService.completion(

+ 11 - 3
api/controllers/console/explore/conversation.py

@@ -38,7 +38,8 @@ class ConversationListApi(InstalledAppResource):
                 user=current_user,
                 user=current_user,
                 last_id=args['last_id'],
                 last_id=args['last_id'],
                 limit=args['limit'],
                 limit=args['limit'],
-                pinned=pinned
+                pinned=pinned,
+                exclude_debug_conversation=True
             )
             )
         except LastConversationNotExistsError:
         except LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")
             raise NotFound("Last Conversation Not Exists.")
@@ -71,11 +72,18 @@ class ConversationRenameApi(InstalledAppResource):
         conversation_id = str(c_id)
         conversation_id = str(c_id)
 
 
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument('name', type=str, required=False, location='json')
+        parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
         args = parser.parse_args()
         args = parser.parse_args()
 
 
         try:
         try:
-            return ConversationService.rename(app_model, conversation_id, current_user, args['name'])
+            return ConversationService.rename(
+                app_model,
+                conversation_id,
+                current_user,
+                args['name'],
+                args['auto_generate']
+            )
         except ConversationNotExistsError:
         except ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
             raise NotFound("Conversation Not Exists.")
 
 

+ 3 - 2
api/controllers/console/explore/installed_app.py

@@ -39,8 +39,9 @@ class InstalledAppsListApi(Resource):
             }
             }
             for installed_app in installed_apps
             for installed_app in installed_apps
         ]
         ]
-        installed_apps.sort(key=lambda app: (-app['is_pinned'], app['last_used_at']
-                            if app['last_used_at'] is not None else datetime.min))
+        installed_apps.sort(key=lambda app: (-app['is_pinned'],
+                                             app['last_used_at'] is None,
+                                             -app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0))
 
 
         return {'installed_apps': installed_apps}
         return {'installed_apps': installed_apps}
 
 

+ 13 - 2
api/controllers/console/explore/parameter.py

@@ -1,5 +1,6 @@
 # -*- coding:utf-8 -*-
 # -*- coding:utf-8 -*-
 from flask_restful import marshal_with, fields
 from flask_restful import marshal_with, fields
+from flask import current_app
 
 
 from controllers.console import api
 from controllers.console import api
 from controllers.console.explore.wraps import InstalledAppResource
 from controllers.console.explore.wraps import InstalledAppResource
@@ -19,6 +20,10 @@ class AppParameterApi(InstalledAppResource):
         'options': fields.List(fields.String)
         'options': fields.List(fields.String)
     }
     }
 
 
+    system_parameters_fields = {
+        'image_file_size_limit': fields.String
+    }
+
     parameters_fields = {
     parameters_fields = {
         'opening_statement': fields.String,
         'opening_statement': fields.String,
         'suggested_questions': fields.Raw,
         'suggested_questions': fields.Raw,
@@ -27,7 +32,9 @@ class AppParameterApi(InstalledAppResource):
         'retriever_resource': fields.Raw,
         'retriever_resource': fields.Raw,
         'more_like_this': fields.Raw,
         'more_like_this': fields.Raw,
         'user_input_form': fields.Raw,
         'user_input_form': fields.Raw,
-        'sensitive_word_avoidance': fields.Raw
+        'sensitive_word_avoidance': fields.Raw,
+        'file_upload': fields.Raw,
+        'system_parameters': fields.Nested(system_parameters_fields)
     }
     }
 
 
     @marshal_with(parameters_fields)
     @marshal_with(parameters_fields)
@@ -44,7 +51,11 @@ class AppParameterApi(InstalledAppResource):
             'retriever_resource': app_model_config.retriever_resource_dict,
             'retriever_resource': app_model_config.retriever_resource_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list,
             'user_input_form': app_model_config.user_input_form_list,
-            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
+            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
+            'file_upload': app_model_config.file_upload_dict,
+            'system_parameters': {
+                'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
+            }
         }
         }
 
 
 
 

+ 2 - 0
api/controllers/console/explore/saved_message.py

@@ -9,6 +9,7 @@ from controllers.console.explore.wraps import InstalledAppResource
 from libs.helper import uuid_value, TimestampField
 from libs.helper import uuid_value, TimestampField
 from services.errors.message import MessageNotExistsError
 from services.errors.message import MessageNotExistsError
 from services.saved_message_service import SavedMessageService
 from services.saved_message_service import SavedMessageService
+from fields.conversation_fields import message_file_fields
 
 
 feedback_fields = {
 feedback_fields = {
     'rating': fields.String
     'rating': fields.String
@@ -19,6 +20,7 @@ message_fields = {
     'inputs': fields.Raw,
     'inputs': fields.Raw,
     'query': fields.String,
     'query': fields.String,
     'answer': fields.String,
     'answer': fields.String,
+    'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
     'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
     'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
     'created_at': TimestampField
     'created_at': TimestampField
 }
 }

+ 3 - 0
api/controllers/console/universal_chat/chat.py

@@ -25,6 +25,7 @@ class UniversalChatApi(UniversalChatResource):
 
 
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
         parser.add_argument('query', type=str, required=True, location='json')
         parser.add_argument('query', type=str, required=True, location='json')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('provider', type=str, required=True, location='json')
         parser.add_argument('provider', type=str, required=True, location='json')
         parser.add_argument('model', type=str, required=True, location='json')
         parser.add_argument('model', type=str, required=True, location='json')
@@ -60,6 +61,8 @@ class UniversalChatApi(UniversalChatResource):
         del args['model']
         del args['model']
         del args['tools']
         del args['tools']
 
 
+        args['auto_generate_name'] = False
+
         try:
         try:
             response = CompletionService.completion(
             response = CompletionService.completion(
                 app_model=app_model,
                 app_model=app_model,

+ 9 - 2
api/controllers/console/universal_chat/conversation.py

@@ -65,11 +65,18 @@ class UniversalChatConversationRenameApi(UniversalChatResource):
         conversation_id = str(c_id)
         conversation_id = str(c_id)
 
 
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument('name', type=str, required=False, location='json')
+        parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
         args = parser.parse_args()
         args = parser.parse_args()
 
 
         try:
         try:
-            return ConversationService.rename(app_model, conversation_id, current_user, args['name'])
+            return ConversationService.rename(
+                app_model,
+                conversation_id,
+                current_user,
+                args['name'],
+                args['auto_generate']
+            )
         except ConversationNotExistsError:
         except ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
             raise NotFound("Conversation Not Exists.")
 
 

+ 10 - 0
api/controllers/files/__init__.py

@@ -0,0 +1,10 @@
+# -*- coding:utf-8 -*-
+from flask import Blueprint
+
+from libs.external_api import ExternalApi
+
+bp = Blueprint('files', __name__)
+api = ExternalApi(bp)
+
+
+from . import image_preview

+ 40 - 0
api/controllers/files/image_preview.py

@@ -0,0 +1,40 @@
+from flask import request, Response
+from flask_restful import Resource
+
+import services
+from controllers.files import api
+from libs.exception import BaseHTTPException
+from services.file_service import FileService
+
+
+class ImagePreviewApi(Resource):
+    def get(self, file_id):
+        file_id = str(file_id)
+
+        timestamp = request.args.get('timestamp')
+        nonce = request.args.get('nonce')
+        sign = request.args.get('sign')
+
+        if not timestamp or not nonce or not sign:
+            return {'content': 'Invalid request.'}, 400
+
+        try:
+            generator, mimetype = FileService.get_image_preview(
+                file_id,
+                timestamp,
+                nonce,
+                sign
+            )
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()
+
+        return Response(generator, mimetype=mimetype)
+
+
+api.add_resource(ImagePreviewApi, '/files/<uuid:file_id>/image-preview')
+
+
+class UnsupportedFileTypeError(BaseHTTPException):
+    error_code = 'unsupported_file_type'
+    description = "File type not allowed."
+    code = 415

+ 1 - 1
api/controllers/service_api/__init__.py

@@ -7,6 +7,6 @@ bp = Blueprint('service_api', __name__, url_prefix='/v1')
 api = ExternalApi(bp)
 api = ExternalApi(bp)
 
 
 
 
-from .app import completion, app, conversation, message, audio
+from .app import completion, app, conversation, message, audio, file
 
 
 from .dataset import document, segment, dataset
 from .dataset import document, segment, dataset

+ 13 - 2
api/controllers/service_api/app/app.py

@@ -1,5 +1,6 @@
 # -*- coding:utf-8 -*-
 # -*- coding:utf-8 -*-
 from flask_restful import fields, marshal_with
 from flask_restful import fields, marshal_with
+from flask import current_app
 
 
 from controllers.service_api import api
 from controllers.service_api import api
 from controllers.service_api.wraps import AppApiResource
 from controllers.service_api.wraps import AppApiResource
@@ -20,6 +21,10 @@ class AppParameterApi(AppApiResource):
         'options': fields.List(fields.String)
         'options': fields.List(fields.String)
     }
     }
 
 
+    system_parameters_fields = {
+        'image_file_size_limit': fields.String
+    }
+
     parameters_fields = {
     parameters_fields = {
         'opening_statement': fields.String,
         'opening_statement': fields.String,
         'suggested_questions': fields.Raw,
         'suggested_questions': fields.Raw,
@@ -28,7 +33,9 @@ class AppParameterApi(AppApiResource):
         'retriever_resource': fields.Raw,
         'retriever_resource': fields.Raw,
         'more_like_this': fields.Raw,
         'more_like_this': fields.Raw,
         'user_input_form': fields.Raw,
         'user_input_form': fields.Raw,
-        'sensitive_word_avoidance': fields.Raw
+        'sensitive_word_avoidance': fields.Raw,
+        'file_upload': fields.Raw,
+        'system_parameters': fields.Nested(system_parameters_fields)
     }
     }
 
 
     @marshal_with(parameters_fields)
     @marshal_with(parameters_fields)
@@ -44,7 +51,11 @@ class AppParameterApi(AppApiResource):
             'retriever_resource': app_model_config.retriever_resource_dict,
             'retriever_resource': app_model_config.retriever_resource_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list,
             'user_input_form': app_model_config.user_input_form_list,
-            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
+            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
+            'file_upload': app_model_config.file_upload_dict,
+            'system_parameters': {
+                'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
+            }
         }
         }
 
 
 
 

+ 6 - 1
api/controllers/service_api/app/completion.py

@@ -28,6 +28,7 @@ class CompletionApi(AppApiResource):
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, location='json', default='')
         parser.add_argument('query', type=str, location='json', default='')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('user', type=str, location='json')
         parser.add_argument('user', type=str, location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
@@ -39,13 +40,15 @@ class CompletionApi(AppApiResource):
         if end_user is None and args['user'] is not None:
         if end_user is None and args['user'] is not None:
             end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
             end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
 
 
+        args['auto_generate_name'] = False
+
         try:
         try:
             response = CompletionService.completion(
             response = CompletionService.completion(
                 app_model=app_model,
                 app_model=app_model,
                 user=end_user,
                 user=end_user,
                 args=args,
                 args=args,
                 from_source='api',
                 from_source='api',
-                streaming=streaming
+                streaming=streaming,
             )
             )
 
 
             return compact_response(response)
             return compact_response(response)
@@ -90,10 +93,12 @@ class ChatApi(AppApiResource):
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, required=True, location='json')
         parser.add_argument('query', type=str, required=True, location='json')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('user', type=str, location='json')
         parser.add_argument('user', type=str, location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
+        parser.add_argument('auto_generate_name', type=bool, required=False, default='True', location='json')
 
 
         args = parser.parse_args()
         args = parser.parse_args()
 
 

+ 9 - 2
api/controllers/service_api/app/conversation.py

@@ -65,15 +65,22 @@ class ConversationRenameApi(AppApiResource):
         conversation_id = str(c_id)
         conversation_id = str(c_id)
 
 
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument('name', type=str, required=False, location='json')
         parser.add_argument('user', type=str, location='json')
         parser.add_argument('user', type=str, location='json')
+        parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
         args = parser.parse_args()
         args = parser.parse_args()
 
 
         if end_user is None and args['user'] is not None:
         if end_user is None and args['user'] is not None:
             end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
             end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
 
 
         try:
         try:
-            return ConversationService.rename(app_model, conversation_id, end_user, args['name'])
+            return ConversationService.rename(
+                app_model,
+                conversation_id,
+                end_user,
+                args['name'],
+                args['auto_generate']
+            )
         except services.errors.conversation.ConversationNotExistsError:
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
             raise NotFound("Conversation Not Exists.")
 
 

+ 23 - 0
api/controllers/service_api/app/error.py

@@ -75,3 +75,26 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
     description = "Provider not support speech to text."
     description = "Provider not support speech to text."
     code = 400
     code = 400
 
 
+
+class NoFileUploadedError(BaseHTTPException):
+    error_code = 'no_file_uploaded'
+    description = "Please upload your file."
+    code = 400
+
+
+class TooManyFilesError(BaseHTTPException):
+    error_code = 'too_many_files'
+    description = "Only one file is allowed."
+    code = 400
+
+
+class FileTooLargeError(BaseHTTPException):
+    error_code = 'file_too_large'
+    description = "File size exceeded. {message}"
+    code = 413
+
+
+class UnsupportedFileTypeError(BaseHTTPException):
+    error_code = 'unsupported_file_type'
+    description = "File type not allowed."
+    code = 415

+ 42 - 0
api/controllers/service_api/app/file.py

@@ -0,0 +1,42 @@
+from flask import request
+from flask_restful import marshal_with
+
+from controllers.service_api import api
+from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.app import create_or_update_end_user_for_user_id
+from controllers.service_api.app.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
+    UnsupportedFileTypeError
+import services
+from services.file_service import FileService
+from fields.file_fields import file_fields
+
+
+class FileApi(AppApiResource):
+
+    @marshal_with(file_fields)
+    def post(self, app_model, end_user):
+
+        file = request.files['file']
+        user_args = request.form.get('user')
+
+        if end_user is None and user_args is not None:
+            end_user = create_or_update_end_user_for_user_id(app_model, user_args)
+
+        # check file
+        if 'file' not in request.files:
+            raise NoFileUploadedError()
+
+        if len(request.files) > 1:
+            raise TooManyFilesError()
+
+        try:
+            upload_file = FileService.upload_file(file, end_user)
+        except services.errors.file.FileTooLargeError as file_too_large_error:
+            raise FileTooLargeError(file_too_large_error.description)
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()
+
+        return upload_file, 201
+
+
+api.add_resource(FileApi, '/files/upload')

+ 2 - 1
api/controllers/service_api/app/message.py

@@ -12,7 +12,7 @@ from libs.helper import TimestampField, uuid_value
 from services.message_service import MessageService
 from services.message_service import MessageService
 from extensions.ext_database import db
 from extensions.ext_database import db
 from models.model import Message, EndUser
 from models.model import Message, EndUser
-
+from fields.conversation_fields import message_file_fields
 
 
 class MessageListApi(AppApiResource):
 class MessageListApi(AppApiResource):
     feedback_fields = {
     feedback_fields = {
@@ -43,6 +43,7 @@ class MessageListApi(AppApiResource):
         'inputs': fields.Raw,
         'inputs': fields.Raw,
         'query': fields.String,
         'query': fields.String,
         'answer': fields.String,
         'answer': fields.String,
+        'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
         'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
         'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
         'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
         'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
         'created_at': TimestampField
         'created_at': TimestampField

+ 3 - 2
api/controllers/service_api/dataset/document.py

@@ -2,6 +2,7 @@ import json
 
 
 from flask import request
 from flask import request
 from flask_restful import reqparse, marshal
 from flask_restful import reqparse, marshal
+from flask_login import current_user
 from sqlalchemy import desc
 from sqlalchemy import desc
 from werkzeug.exceptions import NotFound
 from werkzeug.exceptions import NotFound
 
 
@@ -173,7 +174,7 @@ class DocumentAddByFileApi(DatasetApiResource):
         if len(request.files) > 1:
         if len(request.files) > 1:
             raise TooManyFilesError()
             raise TooManyFilesError()
 
 
-        upload_file = FileService.upload_file(file)
+        upload_file = FileService.upload_file(file, current_user)
         data_source = {
         data_source = {
             'type': 'upload_file',
             'type': 'upload_file',
             'info_list': {
             'info_list': {
@@ -235,7 +236,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
             if len(request.files) > 1:
             if len(request.files) > 1:
                 raise TooManyFilesError()
                 raise TooManyFilesError()
 
 
-            upload_file = FileService.upload_file(file)
+            upload_file = FileService.upload_file(file, current_user)
             data_source = {
             data_source = {
                 'type': 'upload_file',
                 'type': 'upload_file',
                 'info_list': {
                 'info_list': {

+ 1 - 1
api/controllers/web/__init__.py

@@ -7,4 +7,4 @@ bp = Blueprint('web', __name__, url_prefix='/api')
 api = ExternalApi(bp)
 api = ExternalApi(bp)
 
 
 
 
-from . import completion, app, conversation, message, site, saved_message, audio, passport
+from . import completion, app, conversation, message, site, saved_message, audio, passport, file

+ 13 - 2
api/controllers/web/app.py

@@ -1,5 +1,6 @@
 # -*- coding:utf-8 -*-
 # -*- coding:utf-8 -*-
 from flask_restful import marshal_with, fields
 from flask_restful import marshal_with, fields
+from flask import current_app
 
 
 from controllers.web import api
 from controllers.web import api
 from controllers.web.wraps import WebApiResource
 from controllers.web.wraps import WebApiResource
@@ -19,6 +20,10 @@ class AppParameterApi(WebApiResource):
         'options': fields.List(fields.String)
         'options': fields.List(fields.String)
     }
     }
 
 
+    system_parameters_fields = {
+        'image_file_size_limit': fields.String
+    }
+
     parameters_fields = {
     parameters_fields = {
         'opening_statement': fields.String,
         'opening_statement': fields.String,
         'suggested_questions': fields.Raw,
         'suggested_questions': fields.Raw,
@@ -27,7 +32,9 @@ class AppParameterApi(WebApiResource):
         'retriever_resource': fields.Raw,
         'retriever_resource': fields.Raw,
         'more_like_this': fields.Raw,
         'more_like_this': fields.Raw,
         'user_input_form': fields.Raw,
         'user_input_form': fields.Raw,
-        'sensitive_word_avoidance': fields.Raw
+        'sensitive_word_avoidance': fields.Raw,
+        'file_upload': fields.Raw,
+        'system_parameters': fields.Nested(system_parameters_fields)
     }
     }
 
 
     @marshal_with(parameters_fields)
     @marshal_with(parameters_fields)
@@ -43,7 +50,11 @@ class AppParameterApi(WebApiResource):
             'retriever_resource': app_model_config.retriever_resource_dict,
             'retriever_resource': app_model_config.retriever_resource_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list,
             'user_input_form': app_model_config.user_input_form_list,
-            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
+            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
+            'file_upload': app_model_config.file_upload_dict,
+            'system_parameters': {
+                'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
+            }
         }
         }
 
 
 
 

+ 4 - 0
api/controllers/web/completion.py

@@ -30,12 +30,14 @@ class CompletionApi(WebApiResource):
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, location='json', default='')
         parser.add_argument('query', type=str, location='json', default='')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
 
 
         args = parser.parse_args()
         args = parser.parse_args()
 
 
         streaming = args['response_mode'] == 'streaming'
         streaming = args['response_mode'] == 'streaming'
+        args['auto_generate_name'] = False
 
 
         try:
         try:
             response = CompletionService.completion(
             response = CompletionService.completion(
@@ -88,6 +90,7 @@ class ChatApi(WebApiResource):
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, required=True, location='json')
         parser.add_argument('query', type=str, required=True, location='json')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
@@ -95,6 +98,7 @@ class ChatApi(WebApiResource):
         args = parser.parse_args()
         args = parser.parse_args()
 
 
         streaming = args['response_mode'] == 'streaming'
         streaming = args['response_mode'] == 'streaming'
+        args['auto_generate_name'] = False
 
 
         try:
         try:
             response = CompletionService.completion(
             response = CompletionService.completion(

+ 9 - 2
api/controllers/web/conversation.py

@@ -67,11 +67,18 @@ class ConversationRenameApi(WebApiResource):
         conversation_id = str(c_id)
         conversation_id = str(c_id)
 
 
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument('name', type=str, required=False, location='json')
+        parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
         args = parser.parse_args()
         args = parser.parse_args()
 
 
         try:
         try:
-            return ConversationService.rename(app_model, conversation_id, end_user, args['name'])
+            return ConversationService.rename(
+                app_model,
+                conversation_id,
+                end_user,
+                args['name'],
+                args['auto_generate']
+            )
         except ConversationNotExistsError:
         except ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
             raise NotFound("Conversation Not Exists.")
 
 

+ 25 - 1
api/controllers/web/error.py

@@ -85,4 +85,28 @@ class UnsupportedAudioTypeError(BaseHTTPException):
 class ProviderNotSupportSpeechToTextError(BaseHTTPException):
 class ProviderNotSupportSpeechToTextError(BaseHTTPException):
     error_code = 'provider_not_support_speech_to_text'
     error_code = 'provider_not_support_speech_to_text'
     description = "Provider not support speech to text."
     description = "Provider not support speech to text."
-    code = 400
+    code = 400
+
+
+class NoFileUploadedError(BaseHTTPException):
+    error_code = 'no_file_uploaded'
+    description = "Please upload your file."
+    code = 400
+
+
+class TooManyFilesError(BaseHTTPException):
+    error_code = 'too_many_files'
+    description = "Only one file is allowed."
+    code = 400
+
+
+class FileTooLargeError(BaseHTTPException):
+    error_code = 'file_too_large'
+    description = "File size exceeded. {message}"
+    code = 413
+
+
+class UnsupportedFileTypeError(BaseHTTPException):
+    error_code = 'unsupported_file_type'
+    description = "File type not allowed."
+    code = 415

+ 36 - 0
api/controllers/web/file.py

@@ -0,0 +1,36 @@
+from flask import request
+from flask_restful import marshal_with
+
+from controllers.web import api
+from controllers.web.wraps import WebApiResource
+from controllers.web.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
+    UnsupportedFileTypeError
+import services
+from services.file_service import FileService
+from fields.file_fields import file_fields
+
+
class FileApi(WebApiResource):
    """Web API endpoint for uploading a single file on behalf of an end user."""

    @marshal_with(file_fields)
    def post(self, app_model, end_user):
        """Upload exactly one file from the multipart request.

        :param app_model: the app the upload is scoped to (unused here, required by WebApiResource)
        :param end_user: the end user the uploaded file is attributed to
        :return: the created upload file record, HTTP 201
        :raises NoFileUploadedError: when the request has no 'file' part
        :raises TooManyFilesError: when more than one file is attached
        :raises FileTooLargeError: when the file exceeds the size limit
        :raises UnsupportedFileTypeError: when the file extension is not allowed
        """
        # Validate the request BEFORE touching request.files['file']: the
        # original accessed the key first, so a missing file raised a generic
        # werkzeug 400 instead of the intended NoFileUploadedError.
        if 'file' not in request.files:
            raise NoFileUploadedError()

        if len(request.files) > 1:
            raise TooManyFilesError()

        file = request.files['file']

        try:
            upload_file = FileService.upload_file(file, end_user)
        except services.errors.file.FileTooLargeError as file_too_large_error:
            raise FileTooLargeError(file_too_large_error.description)
        except services.errors.file.UnsupportedFileTypeError:
            raise UnsupportedFileTypeError()

        return upload_file, 201


api.add_resource(FileApi, '/files/upload')

+ 2 - 0
api/controllers/web/message.py

@@ -22,6 +22,7 @@ from services.errors.app import MoreLikeThisDisabledError
 from services.errors.conversation import ConversationNotExistsError
 from services.errors.conversation import ConversationNotExistsError
 from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
 from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
 from services.message_service import MessageService
 from services.message_service import MessageService
+from fields.conversation_fields import message_file_fields
 
 
 
 
 class MessageListApi(WebApiResource):
 class MessageListApi(WebApiResource):
@@ -54,6 +55,7 @@ class MessageListApi(WebApiResource):
         'inputs': fields.Raw,
         'inputs': fields.Raw,
         'query': fields.String,
         'query': fields.String,
         'answer': fields.String,
         'answer': fields.String,
+        'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
         'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
         'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
         'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
         'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
         'created_at': TimestampField
         'created_at': TimestampField

+ 3 - 0
api/controllers/web/saved_message.py

@@ -8,6 +8,8 @@ from controllers.web.wraps import WebApiResource
 from libs.helper import uuid_value, TimestampField
 from libs.helper import uuid_value, TimestampField
 from services.errors.message import MessageNotExistsError
 from services.errors.message import MessageNotExistsError
 from services.saved_message_service import SavedMessageService
 from services.saved_message_service import SavedMessageService
+from fields.conversation_fields import message_file_fields
+
 
 
 feedback_fields = {
 feedback_fields = {
     'rating': fields.String
     'rating': fields.String
@@ -18,6 +20,7 @@ message_fields = {
     'inputs': fields.Raw,
     'inputs': fields.Raw,
     'query': fields.String,
     'query': fields.String,
     'answer': fields.String,
     'answer': fields.String,
+    'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
     'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
     'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
     'created_at': TimestampField
     'created_at': TimestampField
 }
 }

+ 8 - 2
api/core/callback_handler/llm_callback_handler.py

@@ -11,7 +11,8 @@ from pydantic import BaseModel
 from core.callback_handler.entity.llm_message import LLMMessage
 from core.callback_handler.entity.llm_message import LLMMessage
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
     ConversationTaskInterruptException
     ConversationTaskInterruptException
-from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
+from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage, LCHumanMessageWithFiles, \
+    ImagePromptMessageFile
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.llm.base import BaseLLM
 from core.moderation.base import ModerationOutputsResult, ModerationAction
 from core.moderation.base import ModerationOutputsResult, ModerationAction
 from core.moderation.factory import ModerationFactory
 from core.moderation.factory import ModerationFactory
@@ -72,7 +73,12 @@ class LLMCallbackHandler(BaseCallbackHandler):
 
 
             real_prompts.append({
             real_prompts.append({
                 "role": role,
                 "role": role,
-                "text": message.content
+                "text": message.content,
+                "files": [{
+                    "type": file.type.value,
+                    "data": file.data[:10] + '...[TRUNCATED]...' + file.data[-10:],
+                    "detail": file.detail.value if isinstance(file, ImagePromptMessageFile) else None,
+                } for file in (message.files if isinstance(message, LCHumanMessageWithFiles) else [])]
             })
             })
 
 
         self.llm_message.prompt = real_prompts
         self.llm_message.prompt = real_prompts

+ 24 - 9
api/core/completion.py

@@ -13,11 +13,12 @@ from core.callback_handler.llm_callback_handler import LLMCallbackHandler
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
     ConversationTaskInterruptException
     ConversationTaskInterruptException
 from core.external_data_tool.factory import ExternalDataToolFactory
 from core.external_data_tool.factory import ExternalDataToolFactory
+from core.file.file_obj import FileObj
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.error import LLMBadRequestError
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
     ReadOnlyConversationTokenDBBufferSharedMemory
 from core.model_providers.model_factory import ModelFactory
 from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import PromptMessage
+from core.model_providers.models.entity.message import PromptMessage, PromptMessageFile
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.llm.base import BaseLLM
 from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_template import PromptTemplateParser
 from core.prompt.prompt_template import PromptTemplateParser
@@ -30,8 +31,9 @@ from core.moderation.factory import ModerationFactory
 class Completion:
 class Completion:
     @classmethod
     @classmethod
     def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
     def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
-                 user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool,
-                 is_override: bool = False, retriever_from: str = 'dev'):
+                 files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
+                 streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
+                 auto_generate_name: bool = True):
         """
         """
         errors: ProviderTokenNotInitError
         errors: ProviderTokenNotInitError
         """
         """
@@ -64,16 +66,21 @@ class Completion:
             is_override=is_override,
             is_override=is_override,
             inputs=inputs,
             inputs=inputs,
             query=query,
             query=query,
+            files=files,
             streaming=streaming,
             streaming=streaming,
-            model_instance=final_model_instance
+            model_instance=final_model_instance,
+            auto_generate_name=auto_generate_name
         )
         )
 
 
+        prompt_message_files = [file.prompt_message_file for file in files]
+
         rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
         rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
             mode=app.mode,
             mode=app.mode,
             model_instance=final_model_instance,
             model_instance=final_model_instance,
             app_model_config=app_model_config,
             app_model_config=app_model_config,
             query=query,
             query=query,
-            inputs=inputs
+            inputs=inputs,
+            files=prompt_message_files
         )
         )
 
 
         # init orchestrator rule parser
         # init orchestrator rule parser
@@ -95,6 +102,7 @@ class Completion:
                     app_model_config=app_model_config,
                     app_model_config=app_model_config,
                     query=query,
                     query=query,
                     inputs=inputs,
                     inputs=inputs,
+                    files=prompt_message_files,
                     agent_execute_result=None,
                     agent_execute_result=None,
                     conversation_message_task=conversation_message_task,
                     conversation_message_task=conversation_message_task,
                     memory=memory,
                     memory=memory,
@@ -146,6 +154,7 @@ class Completion:
                 app_model_config=app_model_config,
                 app_model_config=app_model_config,
                 query=query,
                 query=query,
                 inputs=inputs,
                 inputs=inputs,
+                files=prompt_message_files,
                 agent_execute_result=agent_execute_result,
                 agent_execute_result=agent_execute_result,
                 conversation_message_task=conversation_message_task,
                 conversation_message_task=conversation_message_task,
                 memory=memory,
                 memory=memory,
@@ -257,6 +266,7 @@ class Completion:
     @classmethod
     @classmethod
     def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
     def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
                       inputs: dict,
                       inputs: dict,
+                      files: List[PromptMessageFile],
                       agent_execute_result: Optional[AgentExecuteResult],
                       agent_execute_result: Optional[AgentExecuteResult],
                       conversation_message_task: ConversationMessageTask,
                       conversation_message_task: ConversationMessageTask,
                       memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
                       memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
@@ -266,10 +276,12 @@ class Completion:
         # get llm prompt
         # get llm prompt
         if app_model_config.prompt_type == 'simple':
         if app_model_config.prompt_type == 'simple':
             prompt_messages, stop_words = prompt_transform.get_prompt(
             prompt_messages, stop_words = prompt_transform.get_prompt(
-                mode=mode,
+                app_mode=mode,
+                app_model_config=app_model_config,
                 pre_prompt=app_model_config.pre_prompt,
                 pre_prompt=app_model_config.pre_prompt,
                 inputs=inputs,
                 inputs=inputs,
                 query=query,
                 query=query,
+                files=files,
                 context=agent_execute_result.output if agent_execute_result else None,
                 context=agent_execute_result.output if agent_execute_result else None,
                 memory=memory,
                 memory=memory,
                 model_instance=model_instance
                 model_instance=model_instance
@@ -280,6 +292,7 @@ class Completion:
                 app_model_config=app_model_config,
                 app_model_config=app_model_config,
                 inputs=inputs,
                 inputs=inputs,
                 query=query,
                 query=query,
+                files=files,
                 context=agent_execute_result.output if agent_execute_result else None,
                 context=agent_execute_result.output if agent_execute_result else None,
                 memory=memory,
                 memory=memory,
                 model_instance=model_instance
                 model_instance=model_instance
@@ -337,7 +350,7 @@ class Completion:
 
 
     @classmethod
     @classmethod
     def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
     def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
-                                 query: str, inputs: dict) -> int:
+                                 query: str, inputs: dict, files: List[PromptMessageFile]) -> int:
         model_limited_tokens = model_instance.model_rules.max_tokens.max
         model_limited_tokens = model_instance.model_rules.max_tokens.max
         max_tokens = model_instance.get_model_kwargs().max_tokens
         max_tokens = model_instance.get_model_kwargs().max_tokens
 
 
@@ -348,15 +361,16 @@ class Completion:
             max_tokens = 0
             max_tokens = 0
 
 
         prompt_transform = PromptTransform()
         prompt_transform = PromptTransform()
-        prompt_messages = []
 
 
         # get prompt without memory and context
         # get prompt without memory and context
         if app_model_config.prompt_type == 'simple':
         if app_model_config.prompt_type == 'simple':
             prompt_messages, _ = prompt_transform.get_prompt(
             prompt_messages, _ = prompt_transform.get_prompt(
-                mode=mode,
+                app_mode=mode,
+                app_model_config=app_model_config,
                 pre_prompt=app_model_config.pre_prompt,
                 pre_prompt=app_model_config.pre_prompt,
                 inputs=inputs,
                 inputs=inputs,
                 query=query,
                 query=query,
+                files=files,
                 context=None,
                 context=None,
                 memory=None,
                 memory=None,
                 model_instance=model_instance
                 model_instance=model_instance
@@ -367,6 +381,7 @@ class Completion:
                 app_model_config=app_model_config,
                 app_model_config=app_model_config,
                 inputs=inputs,
                 inputs=inputs,
                 query=query,
                 query=query,
+                files=files,
                 context=None,
                 context=None,
                 memory=None,
                 memory=None,
                 model_instance=model_instance
                 model_instance=model_instance

+ 24 - 6
api/core/conversation_message_task.py

@@ -6,8 +6,9 @@ from core.callback_handler.entity.agent_loop import AgentLoop
 from core.callback_handler.entity.dataset_query import DatasetQueryObj
 from core.callback_handler.entity.dataset_query import DatasetQueryObj
 from core.callback_handler.entity.llm_message import LLMMessage
 from core.callback_handler.entity.llm_message import LLMMessage
 from core.callback_handler.entity.chain_result import ChainResult
 from core.callback_handler.entity.chain_result import ChainResult
+from core.file.file_obj import FileObj
 from core.model_providers.model_factory import ModelFactory
 from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import to_prompt_messages, MessageType
+from core.model_providers.models.entity.message import to_prompt_messages, MessageType, PromptMessageFile
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.llm.base import BaseLLM
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_template import PromptTemplateParser
 from core.prompt.prompt_template import PromptTemplateParser
@@ -16,13 +17,14 @@ from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_redis import redis_client
 from models.dataset import DatasetQuery
 from models.dataset import DatasetQuery
 from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
 from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
-    MessageChain, DatasetRetrieverResource
+    MessageChain, DatasetRetrieverResource, MessageFile
 
 
 
 
 class ConversationMessageTask:
 class ConversationMessageTask:
     def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
     def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
-                 inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
-                 conversation: Optional[Conversation] = None, is_override: bool = False):
+                 inputs: dict, query: str, files: List[FileObj], streaming: bool,
+                 model_instance: BaseLLM, conversation: Optional[Conversation] = None, is_override: bool = False,
+                 auto_generate_name: bool = True):
         self.start_at = time.perf_counter()
         self.start_at = time.perf_counter()
 
 
         self.task_id = task_id
         self.task_id = task_id
@@ -35,6 +37,7 @@ class ConversationMessageTask:
         self.user = user
         self.user = user
         self.inputs = inputs
         self.inputs = inputs
         self.query = query
         self.query = query
+        self.files = files
         self.streaming = streaming
         self.streaming = streaming
 
 
         self.conversation = conversation
         self.conversation = conversation
@@ -45,6 +48,7 @@ class ConversationMessageTask:
         self.message = None
         self.message = None
 
 
         self.retriever_resource = None
         self.retriever_resource = None
+        self.auto_generate_name = auto_generate_name
 
 
         self.model_dict = self.app_model_config.model_dict
         self.model_dict = self.app_model_config.model_dict
         self.provider_name = self.model_dict.get('provider')
         self.provider_name = self.model_dict.get('provider')
@@ -100,7 +104,7 @@ class ConversationMessageTask:
                 model_id=self.model_name,
                 model_id=self.model_name,
                 override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
                 override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
                 mode=self.mode,
                 mode=self.mode,
-                name='',
+                name='New conversation',
                 inputs=self.inputs,
                 inputs=self.inputs,
                 introduction=introduction,
                 introduction=introduction,
                 system_instruction=system_instruction,
                 system_instruction=system_instruction,
@@ -142,6 +146,19 @@ class ConversationMessageTask:
         db.session.add(self.message)
         db.session.add(self.message)
         db.session.commit()
         db.session.commit()
 
 
+        for file in self.files:
+            message_file = MessageFile(
+                message_id=self.message.id,
+                type=file.type.value,
+                transfer_method=file.transfer_method.value,
+                url=file.url,
+                upload_file_id=file.upload_file_id,
+                created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
+                created_by=self.user.id
+            )
+            db.session.add(message_file)
+            db.session.commit()
+
     def append_message_text(self, text: str):
     def append_message_text(self, text: str):
         if text is not None:
         if text is not None:
             self._pub_handler.pub_text(text)
             self._pub_handler.pub_text(text)
@@ -176,7 +193,8 @@ class ConversationMessageTask:
         message_was_created.send(
         message_was_created.send(
             self.message,
             self.message,
             conversation=self.conversation,
             conversation=self.conversation,
-            is_first_message=self.is_new_conversation
+            is_first_message=self.is_new_conversation,
+            auto_generate_name=self.auto_generate_name
         )
         )
 
 
         if not by_stopped:
         if not by_stopped:

+ 0 - 0
api/core/file/__init__.py


+ 79 - 0
api/core/file/file_obj.py

@@ -0,0 +1,79 @@
+import enum
+from typing import Optional
+
+from pydantic import BaseModel
+
+from core.file.upload_file_parser import UploadFileParser
+from core.model_providers.models.entity.message import PromptMessageFile, ImagePromptMessageFile
+from extensions.ext_database import db
+from models.model import UploadFile
+
+
class FileType(enum.Enum):
    """Kinds of files a message can carry; currently images only."""

    IMAGE = 'image'

    @staticmethod
    def value_of(value):
        """Return the member whose value equals *value*.

        :raises ValueError: when no member matches.
        """
        member = next((item for item in FileType if item.value == value), None)
        if member is None:
            raise ValueError(f"No matching enum found for value '{value}'")
        return member
+
+
class FileTransferMethod(enum.Enum):
    """How a message file reaches the backend: by URL or as a prior upload."""

    REMOTE_URL = 'remote_url'
    LOCAL_FILE = 'local_file'

    @staticmethod
    def value_of(value):
        """Return the member whose value equals *value*.

        :raises ValueError: when no member matches.
        """
        member = next((item for item in FileTransferMethod if item.value == value), None)
        if member is None:
            raise ValueError(f"No matching enum found for value '{value}'")
        return member
+
+
class FileObj(BaseModel):
    """A file attached to a message, with accessors that resolve its content."""

    id: Optional[str]
    tenant_id: str
    type: FileType
    transfer_method: FileTransferMethod
    url: Optional[str]
    upload_file_id: Optional[str]
    file_config: dict

    @property
    def data(self) -> Optional[str]:
        """File content per app config: a URL or an inline base64 data URI."""
        return self._get_data()

    @property
    def preview_url(self) -> Optional[str]:
        """Always a URL form of the file, never inline base64."""
        return self._get_data(force_url=True)

    @property
    def prompt_message_file(self) -> PromptMessageFile:
        # Only images are convertible; other types yield None (as before).
        if self.type != FileType.IMAGE:
            return None

        image_config = self.file_config.get('image')
        detail = (
            ImagePromptMessageFile.DETAIL.HIGH
            if image_config.get("detail") == "high"
            else ImagePromptMessageFile.DETAIL.LOW
        )
        return ImagePromptMessageFile(data=self.data, detail=detail)

    def _get_data(self, force_url: bool = False) -> Optional[str]:
        # Non-image types are not resolvable yet.
        if self.type != FileType.IMAGE:
            return None

        if self.transfer_method == FileTransferMethod.REMOTE_URL:
            return self.url

        if self.transfer_method == FileTransferMethod.LOCAL_FILE:
            # Look up the upload record scoped to this tenant.
            upload_file = (db.session.query(UploadFile)
                           .filter(UploadFile.id == self.upload_file_id,
                                   UploadFile.tenant_id == self.tenant_id)
                           .first())

            return UploadFileParser.get_image_data(
                upload_file=upload_file,
                force_url=force_url
            )

        return None

+ 180 - 0
api/core/file/message_file_parser.py

@@ -0,0 +1,180 @@
+from typing import List, Union, Optional, Dict
+
+import requests
+
+from core.file.file_obj import FileObj, FileType, FileTransferMethod
+from core.file.upload_file_parser import SUPPORT_EXTENSIONS
+from extensions.ext_database import db
+from models.account import Account
+from models.model import MessageFile, EndUser, AppModelConfig, UploadFile
+
+
class MessageFileParser:
    """Validates raw file arguments and converts them (or stored MessageFile
    rows) into FileObj instances, applying the app's file-upload config."""

    def __init__(self, tenant_id: str, app_id: str) -> None:
        self.tenant_id = tenant_id
        self.app_id = app_id

    def validate_and_transform_files_arg(self, files: List[dict], app_model_config: AppModelConfig,
                                         user: Union[Account, EndUser]) -> List[FileObj]:
        """
        validate and transform files arg

        :param files: raw file dicts from the request payload
        :param app_model_config: app config supplying file-upload settings
        :param user: account or end user the uploads must belong to
        :return: validated file objs (only types enabled in config)
        :raises ValueError: on any malformed or disallowed file
        """
        file_upload_config = app_model_config.file_upload_dict

        # shape-check each raw dict before conversion
        for file in files:
            self._validate_file_arg(file)

        # transform files to file objs, grouped by file type
        type_file_objs = self._to_file_objs(files, file_upload_config)

        # validate files per type
        new_files = []
        for file_type, file_objs in type_file_objs.items():
            if file_type == FileType.IMAGE:
                image_config = file_upload_config.get('image')

                # skip all image files if the image feature is disabled
                if not image_config['enabled']:
                    continue

                # Validate number of files: count the grouped image file objs,
                # not the raw args list (which may contain non-image entries).
                if len(file_objs) > image_config['number_limits']:
                    raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")

                for file_obj in file_objs:
                    # Validate transfer method against the allowed set
                    if file_obj.transfer_method.value not in image_config['transfer_methods']:
                        raise ValueError(f'Invalid transfer method: {file_obj.transfer_method.value}')

                    # Validate file type
                    if file_obj.type != FileType.IMAGE:
                        raise ValueError(f'Invalid file type: {file_obj.type}')

                    if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
                        # check remote url valid and is image
                        result, error = self._check_image_remote_url(file_obj.url)
                        if result is False:
                            raise ValueError(error)
                    elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
                        # the upload file must exist, belong to this tenant and
                        # user, and have a supported image extension
                        upload_file = (db.session.query(UploadFile)
                                       .filter(
                            UploadFile.id == file_obj.upload_file_id,
                            UploadFile.tenant_id == self.tenant_id,
                            UploadFile.created_by == user.id,
                            UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
                            UploadFile.extension.in_(SUPPORT_EXTENSIONS)
                        ).first())

                        # check upload file is belong to tenant and user
                        if not upload_file:
                            raise ValueError('Invalid upload file')

                    new_files.append(file_obj)

        # return all file objs
        return new_files

    @staticmethod
    def _validate_file_arg(file: dict) -> None:
        """Check one raw file dict has a valid shape; raises ValueError otherwise."""
        if not isinstance(file, dict):
            raise ValueError('Invalid file format, must be dict')
        if not file.get('type'):
            raise ValueError('Missing file type')
        FileType.value_of(file.get('type'))
        if not file.get('transfer_method'):
            raise ValueError('Missing file transfer method')
        FileTransferMethod.value_of(file.get('transfer_method'))
        if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value:
            if not file.get('url'):
                raise ValueError('Missing file url')
            if not file.get('url').startswith('http'):
                raise ValueError('Invalid file url')
        if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'):
            raise ValueError('Missing file upload_file_id')

    def transform_message_files(self, files: List[MessageFile], app_model_config: Optional[AppModelConfig]) -> List[FileObj]:
        """
        transform message files

        :param files: stored MessageFile rows
        :param app_model_config: may be None (the type says Optional); falls
            back to an empty upload config instead of crashing
        :return: file objs for all files, in grouped order
        """
        # guard the Optional instead of dereferencing it unconditionally
        file_upload_config = app_model_config.file_upload_dict if app_model_config else {}

        # transform files to file objs
        type_file_objs = self._to_file_objs(files, file_upload_config)

        # return all file objs
        return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]

    def _to_file_objs(self, files: List[Union[Dict, MessageFile]],
                      file_upload_config: dict) -> Dict[FileType, List[FileObj]]:
        """
        transform files to file objs, grouped by file type

        :param files: raw dicts or MessageFile rows
        :param file_upload_config: app file-upload settings
        :return: mapping of FileType -> file objs (unknown types are dropped)
        """
        type_file_objs: Dict[FileType, List[FileObj]] = {
            # Currently only support image
            FileType.IMAGE: []
        }

        if not files:
            return type_file_objs

        # group by file type and convert file args or message files to FileObj
        for file in files:
            file_obj = self._to_file_obj(file, file_upload_config)
            if file_obj.type not in type_file_objs:
                continue

            type_file_objs[file_obj.type].append(file_obj)

        return type_file_objs

    def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_config: dict) -> FileObj:
        """
        transform a single file arg dict or MessageFile row to a FileObj

        :param file: raw dict (request arg) or MessageFile (db row)
        :return: the FileObj
        """
        if isinstance(file, dict):
            transfer_method = FileTransferMethod.value_of(file.get('transfer_method'))
            return FileObj(
                tenant_id=self.tenant_id,
                type=FileType.value_of(file.get('type')),
                transfer_method=transfer_method,
                url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
                upload_file_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
                file_config=file_upload_config
            )
        else:
            return FileObj(
                id=file.id,
                tenant_id=self.tenant_id,
                type=FileType.value_of(file.type),
                transfer_method=FileTransferMethod.value_of(file.transfer_method),
                url=file.url,
                upload_file_id=file.upload_file_id or None,
                file_config=file_upload_config
            )

    def _check_image_remote_url(self, url):
        """HEAD the remote url to confirm it exists; returns (ok, error_msg)."""
        try:
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
            }

            # timeout added: without it a slow/unresponsive host can hang the
            # request worker indefinitely
            response = requests.head(url, headers=headers, allow_redirects=True, timeout=10)
            if response.status_code == 200:
                return True, ""
            else:
                return False, "URL does not exist."
        except requests.RequestException as e:
            return False, f"Error checking URL: {e}"

+ 79 - 0
api/core/file/upload_file_parser.py

@@ -0,0 +1,79 @@
+import base64
+import hashlib
+import hmac
+import logging
+import os
+import time
+from typing import Optional
+
+from flask import current_app
+
+from extensions.ext_storage import storage
+
+SUPPORT_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif']
+
+
+class UploadFileParser:
+    @classmethod
+    def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
+        if not upload_file:
+            return None
+
+        if upload_file.extension not in SUPPORT_EXTENSIONS:
+            return None
+
+        if current_app.config['MULTIMODAL_SEND_IMAGE_FORMAT'] == 'url' or force_url:
+            return cls.get_signed_temp_image_url(upload_file)
+        else:
+            # get image file base64
+            try:
+                data = storage.load(upload_file.key)
+            except FileNotFoundError:
+                logging.error(f'File not found: {upload_file.key}')
+                return None
+
+            encoded_string = base64.b64encode(data).decode('utf-8')
+            return f'data:{upload_file.mime_type};base64,{encoded_string}'
+
+    @classmethod
+    def get_signed_temp_image_url(cls, upload_file) -> str:
+        """
+        get signed url from upload file
+
+        :param upload_file: UploadFile object
+        :return:
+        """
+        base_url = current_app.config.get('FILES_URL')
+        image_preview_url = f'{base_url}/files/{upload_file.id}/image-preview'
+
+        timestamp = str(int(time.time()))
+        nonce = os.urandom(16).hex()
+        data_to_sign = f"image-preview|{upload_file.id}|{timestamp}|{nonce}"
+        secret_key = current_app.config['SECRET_KEY'].encode()
+        sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+        encoded_sign = base64.urlsafe_b64encode(sign).decode()
+
+        return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
+
+    @classmethod
+    def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
+        """
+        verify signature
+
+        :param upload_file_id: file id
+        :param timestamp: timestamp
+        :param nonce: nonce
+        :param sign: signature
+        :return:
+        """
+        data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
+        secret_key = current_app.config['SECRET_KEY'].encode()
+        recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+        recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
+
+        # verify signature
+        if sign != recalculated_encoded_sign:
+            return False
+
+        current_time = int(time.time())
+        return current_time - int(timestamp) <= 300  # expired after 5 minutes

+ 6 - 2
api/core/generator/llm_generator.py

@@ -16,7 +16,7 @@ from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT
 
 
 class LLMGenerator:
 class LLMGenerator:
     @classmethod
     @classmethod
-    def generate_conversation_name(cls, tenant_id: str, query, answer):
+    def generate_conversation_name(cls, tenant_id: str, query):
         prompt = CONVERSATION_TITLE_PROMPT
         prompt = CONVERSATION_TITLE_PROMPT
 
 
         if len(query) > 2000:
         if len(query) > 2000:
@@ -40,8 +40,12 @@ class LLMGenerator:
 
 
         result_dict = json.loads(answer)
         result_dict = json.loads(answer)
         answer = result_dict['Your Output']
         answer = result_dict['Your Output']
+        name = answer.strip()
 
 
-        return answer.strip()
+        if len(name) > 75:
+            name = name[:75] + '...'
+
+        return name
 
 
     @classmethod
     @classmethod
     def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
     def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):

+ 19 - 1
api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py

@@ -3,6 +3,7 @@ from typing import Any, List, Dict
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.schema import get_buffer_string, BaseMessage
 from langchain.schema import get_buffer_string, BaseMessage
 
 
+from core.file.message_file_parser import MessageFileParser
 from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages
 from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.llm.base import BaseLLM
 from extensions.ext_database import db
 from extensions.ext_database import db
@@ -21,6 +22,8 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
     @property
     @property
     def buffer(self) -> List[BaseMessage]:
     def buffer(self) -> List[BaseMessage]:
         """String buffer of memory."""
         """String buffer of memory."""
+        app_model = self.conversation.app
+
         # fetch limited messages desc, and return reversed
         # fetch limited messages desc, and return reversed
         messages = db.session.query(Message).filter(
         messages = db.session.query(Message).filter(
             Message.conversation_id == self.conversation.id,
             Message.conversation_id == self.conversation.id,
@@ -28,10 +31,25 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
         ).order_by(Message.created_at.desc()).limit(self.message_limit).all()
         ).order_by(Message.created_at.desc()).limit(self.message_limit).all()
 
 
         messages = list(reversed(messages))
         messages = list(reversed(messages))
+        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=self.conversation.app_id)
 
 
         chat_messages: List[PromptMessage] = []
         chat_messages: List[PromptMessage] = []
         for message in messages:
         for message in messages:
-            chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
+            files = message.message_files
+            if files:
+                file_objs = message_file_parser.transform_message_files(
+                    files, message.app_model_config
+                )
+
+                prompt_message_files = [file_obj.prompt_message_file for file_obj in file_objs]
+                chat_messages.append(PromptMessage(
+                    content=message.query,
+                    type=MessageType.USER,
+                    files=prompt_message_files
+                ))
+            else:
+                chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
+
             chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
             chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
 
 
         if not chat_messages:
         if not chat_messages:

+ 46 - 2
api/core/model_providers/models/entity/message.py

@@ -1,4 +1,5 @@
 import enum
 import enum
+from typing import Any, cast, Union, List, Dict
 
 
 from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
 from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
 from pydantic import BaseModel
 from pydantic import BaseModel
@@ -18,17 +19,53 @@ class MessageType(enum.Enum):
     SYSTEM = 'system'
     SYSTEM = 'system'
 
 
 
 
+class PromptMessageFileType(enum.Enum):
+    IMAGE = 'image'
+
+    @staticmethod
+    def value_of(value):
+        for member in PromptMessageFileType:
+            if member.value == value:
+                return member
+        raise ValueError(f"No matching enum found for value '{value}'")
+
+
+
+class PromptMessageFile(BaseModel):
+    type: PromptMessageFileType
+    data: Any
+
+
+class ImagePromptMessageFile(PromptMessageFile):
+    class DETAIL(enum.Enum):
+        LOW = 'low'
+        HIGH = 'high'
+
+    type: PromptMessageFileType = PromptMessageFileType.IMAGE
+    detail: DETAIL = DETAIL.LOW
+
+
 class PromptMessage(BaseModel):
 class PromptMessage(BaseModel):
     type: MessageType = MessageType.USER
     type: MessageType = MessageType.USER
     content: str = ''
     content: str = ''
+    files: list[PromptMessageFile] = []
     function_call: dict = None
     function_call: dict = None
 
 
 
 
+class LCHumanMessageWithFiles(HumanMessage):
+    # content: Union[str, List[Union[str, Dict]]]
+    content: str
+    files: list[PromptMessageFile]
+
+
 def to_lc_messages(messages: list[PromptMessage]):
 def to_lc_messages(messages: list[PromptMessage]):
     lc_messages = []
     lc_messages = []
     for message in messages:
     for message in messages:
         if message.type == MessageType.USER:
         if message.type == MessageType.USER:
-            lc_messages.append(HumanMessage(content=message.content))
+            if not message.files:
+                lc_messages.append(HumanMessage(content=message.content))
+            else:
+                lc_messages.append(LCHumanMessageWithFiles(content=message.content, files=message.files))
         elif message.type == MessageType.ASSISTANT:
         elif message.type == MessageType.ASSISTANT:
             additional_kwargs = {}
             additional_kwargs = {}
             if message.function_call:
             if message.function_call:
@@ -44,7 +81,14 @@ def to_prompt_messages(messages: list[BaseMessage]):
     prompt_messages = []
     prompt_messages = []
     for message in messages:
     for message in messages:
         if isinstance(message, HumanMessage):
         if isinstance(message, HumanMessage):
-            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
+            if isinstance(message, LCHumanMessageWithFiles):
+                prompt_messages.append(PromptMessage(
+                    content=message.content,
+                    type=MessageType.USER,
+                    files=message.files
+                ))
+            else:
+                prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
         elif isinstance(message, AIMessage):
         elif isinstance(message, AIMessage):
             message_kwargs = {
             message_kwargs = {
                 'content': message.content,
                 'content': message.content,

+ 0 - 2
api/core/model_providers/models/llm/openai_model.py

@@ -1,11 +1,9 @@
-import decimal
 import logging
 import logging
 from typing import List, Optional, Any
 from typing import List, Optional, Any
 
 
 import openai
 import openai
 from langchain.callbacks.manager import Callbacks
 from langchain.callbacks.manager import Callbacks
 from langchain.schema import LLMResult
 from langchain.schema import LLMResult
-from openai import api_requestor
 
 
 from core.model_providers.providers.base import BaseModelProvider
 from core.model_providers.providers.base import BaseModelProvider
 from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
 from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI

+ 153 - 76
api/core/prompt/prompt_transform.py

@@ -8,7 +8,7 @@ from langchain.memory.chat_memory import BaseChatMemory
 from langchain.schema import BaseMessage
 from langchain.schema import BaseMessage
 
 
 from core.model_providers.models.entity.model_params import ModelMode
 from core.model_providers.models.entity.model_params import ModelMode
-from core.model_providers.models.entity.message import PromptMessage, MessageType, to_prompt_messages
+from core.model_providers.models.entity.message import PromptMessage, MessageType, to_prompt_messages, PromptMessageFile
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.llm.baichuan_model import BaichuanModel
 from core.model_providers.models.llm.baichuan_model import BaichuanModel
 from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
 from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
@@ -16,32 +16,59 @@ from core.model_providers.models.llm.openllm_model import OpenLLMModel
 from core.model_providers.models.llm.xinference_model import XinferenceModel
 from core.model_providers.models.llm.xinference_model import XinferenceModel
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_template import PromptTemplateParser
 from core.prompt.prompt_template import PromptTemplateParser
+from models.model import AppModelConfig
+
 
 
 class AppMode(enum.Enum):
 class AppMode(enum.Enum):
     COMPLETION = 'completion'
     COMPLETION = 'completion'
     CHAT = 'chat'
     CHAT = 'chat'
 
 
+
 class PromptTransform:
 class PromptTransform:
-    def get_prompt(self, mode: str,
-                   pre_prompt: str, inputs: dict,
+    def get_prompt(self,
+                   app_mode: str,
+                   app_model_config: AppModelConfig,
+                   pre_prompt: str,
+                   inputs: dict,
                    query: str,
                    query: str,
+                   files: List[PromptMessageFile],
                    context: Optional[str],
                    context: Optional[str],
                    memory: Optional[BaseChatMemory],
                    memory: Optional[BaseChatMemory],
                    model_instance: BaseLLM) -> \
                    model_instance: BaseLLM) -> \
             Tuple[List[PromptMessage], Optional[List[str]]]:
             Tuple[List[PromptMessage], Optional[List[str]]]:
-        prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name(mode, model_instance))
-        prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory, model_instance)
-        return [PromptMessage(content=prompt)], stops
-
-    def get_advanced_prompt(self, 
-            app_mode: str,
-            app_model_config: str, 
-            inputs: dict,
-            query: str,
-            context: Optional[str],
-            memory: Optional[BaseChatMemory],
-            model_instance: BaseLLM) -> List[PromptMessage]:
-        
+        model_mode = app_model_config.model_dict['mode']
+
+        app_mode_enum = AppMode(app_mode)
+        model_mode_enum = ModelMode(model_mode)
+
+        prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name(app_mode, model_instance))
+
+        if app_mode_enum == AppMode.CHAT and model_mode_enum == ModelMode.CHAT:
+            stops = None
+
+            prompt_messages = self._get_simple_chat_app_chat_model_prompt_messages(prompt_rules, pre_prompt, inputs,
+                                                                                   query, context, memory,
+                                                                                   model_instance, files)
+        else:
+            stops = prompt_rules.get('stops')
+            if stops is not None and len(stops) == 0:
+                stops = None
+
+            prompt_messages = self._get_simple_others_prompt_messages(prompt_rules, pre_prompt, inputs, query, context,
+                                                                      memory,
+                                                                      model_instance, files)
+        return prompt_messages, stops
+
+    def get_advanced_prompt(self,
+                            app_mode: str,
+                            app_model_config: AppModelConfig,
+                            inputs: dict,
+                            query: str,
+                            files: List[PromptMessageFile],
+                            context: Optional[str],
+                            memory: Optional[BaseChatMemory],
+                            model_instance: BaseLLM) -> List[PromptMessage]:
+
         model_mode = app_model_config.model_dict['mode']
         model_mode = app_model_config.model_dict['mode']
 
 
         app_mode_enum = AppMode(app_mode)
         app_mode_enum = AppMode(app_mode)
@@ -51,15 +78,20 @@ class PromptTransform:
 
 
         if app_mode_enum == AppMode.CHAT:
         if app_mode_enum == AppMode.CHAT:
             if model_mode_enum == ModelMode.COMPLETION:
             if model_mode_enum == ModelMode.COMPLETION:
-                prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance)
+                prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query,
+                                                                                      files, context, memory,
+                                                                                      model_instance)
             elif model_mode_enum == ModelMode.CHAT:
             elif model_mode_enum == ModelMode.CHAT:
-                prompt_messages =  self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance)
+                prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, files,
+                                                                                context, memory, model_instance)
         elif app_mode_enum == AppMode.COMPLETION:
         elif app_mode_enum == AppMode.COMPLETION:
             if model_mode_enum == ModelMode.CHAT:
             if model_mode_enum == ModelMode.CHAT:
-                prompt_messages =  self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs, context)
+                prompt_messages = self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs,
+                                                                                      files, context)
             elif model_mode_enum == ModelMode.COMPLETION:
             elif model_mode_enum == ModelMode.COMPLETION:
-                prompt_messages =  self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs, context)
-            
+                prompt_messages = self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs,
+                                                                                            files, context)
+
         return prompt_messages
         return prompt_messages
 
 
     def _get_history_messages_from_memory(self, memory: BaseChatMemory,
     def _get_history_messages_from_memory(self, memory: BaseChatMemory,
@@ -71,7 +103,7 @@ class PromptTransform:
         return external_context[memory_key]
         return external_context[memory_key]
 
 
     def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
     def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
-                                          max_token_limit: int) -> List[PromptMessage]:
+                                               max_token_limit: int) -> List[PromptMessage]:
         """Get memory messages."""
         """Get memory messages."""
         memory.max_token_limit = max_token_limit
         memory.max_token_limit = max_token_limit
         memory.return_messages = True
         memory.return_messages = True
@@ -79,7 +111,7 @@ class PromptTransform:
         external_context = memory.load_memory_variables({})
         external_context = memory.load_memory_variables({})
         memory.return_messages = False
         memory.return_messages = False
         return to_prompt_messages(external_context[memory_key])
         return to_prompt_messages(external_context[memory_key])
-    
+
     def _prompt_file_name(self, mode: str, model_instance: BaseLLM) -> str:
     def _prompt_file_name(self, mode: str, model_instance: BaseLLM) -> str:
         # baichuan
         # baichuan
         if isinstance(model_instance, BaichuanModel):
         if isinstance(model_instance, BaichuanModel):
@@ -94,13 +126,13 @@ class PromptTransform:
             return 'common_completion'
             return 'common_completion'
         else:
         else:
             return 'common_chat'
             return 'common_chat'
-        
+
     def _prompt_file_name_for_baichuan(self, mode: str) -> str:
     def _prompt_file_name_for_baichuan(self, mode: str) -> str:
         if mode == 'completion':
         if mode == 'completion':
             return 'baichuan_completion'
             return 'baichuan_completion'
         else:
         else:
             return 'baichuan_chat'
             return 'baichuan_chat'
-    
+
     def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
     def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
         # Get the absolute path of the subdirectory
         # Get the absolute path of the subdirectory
         prompt_path = os.path.join(
         prompt_path = os.path.join(
@@ -111,12 +143,53 @@ class PromptTransform:
         # Open the JSON file and read its content
         # Open the JSON file and read its content
         with open(json_file_path, 'r') as json_file:
         with open(json_file_path, 'r') as json_file:
             return json.load(json_file)
             return json.load(json_file)
-        
-    def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
-                             query: str,
-                             context: Optional[str],
-                             memory: Optional[BaseChatMemory],
-                             model_instance: BaseLLM) -> Tuple[str, Optional[list]]:
+
+    def _get_simple_chat_app_chat_model_prompt_messages(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
+                                                        query: str,
+                                                        context: Optional[str],
+                                                        memory: Optional[BaseChatMemory],
+                                                        model_instance: BaseLLM,
+                                                        files: List[PromptMessageFile]) -> List[PromptMessage]:
+        prompt_messages = []
+
+        context_prompt_content = ''
+        if context and 'context_prompt' in prompt_rules:
+            prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
+            context_prompt_content = prompt_template.format(
+                {'context': context}
+            )
+
+        pre_prompt_content = ''
+        if pre_prompt:
+            prompt_template = PromptTemplateParser(template=pre_prompt)
+            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+            pre_prompt_content = prompt_template.format(
+                prompt_inputs
+            )
+
+        prompt = ''
+        for order in prompt_rules['system_prompt_orders']:
+            if order == 'context_prompt':
+                prompt += context_prompt_content
+            elif order == 'pre_prompt':
+                prompt += pre_prompt_content
+
+        prompt = re.sub(r'<\|.*?\|>', '', prompt)
+
+        prompt_messages.append(PromptMessage(type=MessageType.SYSTEM, content=prompt))
+
+        self._append_chat_histories(memory, prompt_messages, model_instance)
+
+        prompt_messages.append(PromptMessage(type=MessageType.USER, content=query, files=files))
+
+        return prompt_messages
+
+    def _get_simple_others_prompt_messages(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
+                                           query: str,
+                                           context: Optional[str],
+                                           memory: Optional[BaseChatMemory],
+                                           model_instance: BaseLLM,
+                                           files: List[PromptMessageFile]) -> List[PromptMessage]:
         context_prompt_content = ''
         context_prompt_content = ''
         if context and 'context_prompt' in prompt_rules:
         if context and 'context_prompt' in prompt_rules:
             prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
             prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
@@ -175,16 +248,12 @@ class PromptTransform:
 
 
         prompt = re.sub(r'<\|.*?\|>', '', prompt)
         prompt = re.sub(r'<\|.*?\|>', '', prompt)
 
 
-        stops = prompt_rules.get('stops')
-        if stops is not None and len(stops) == 0:
-            stops = None
+        return [PromptMessage(content=prompt, files=files)]
 
 
-        return prompt, stops
-    
     def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
     def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
         if '#context#' in prompt_template.variable_keys:
         if '#context#' in prompt_template.variable_keys:
             if context:
             if context:
-                prompt_inputs['#context#'] = context    
+                prompt_inputs['#context#'] = context
             else:
             else:
                 prompt_inputs['#context#'] = ''
                 prompt_inputs['#context#'] = ''
 
 
@@ -195,17 +264,18 @@ class PromptTransform:
             else:
             else:
                 prompt_inputs['#query#'] = ''
                 prompt_inputs['#query#'] = ''
 
 
-    def _set_histories_variable(self, memory: BaseChatMemory, raw_prompt: str, conversation_histories_role: dict, 
-                                prompt_template: PromptTemplateParser, prompt_inputs: dict, model_instance: BaseLLM) -> None:
+    def _set_histories_variable(self, memory: BaseChatMemory, raw_prompt: str, conversation_histories_role: dict,
+                                prompt_template: PromptTemplateParser, prompt_inputs: dict,
+                                model_instance: BaseLLM) -> None:
         if '#histories#' in prompt_template.variable_keys:
         if '#histories#' in prompt_template.variable_keys:
             if memory:
             if memory:
                 tmp_human_message = PromptBuilder.to_human_message(
                 tmp_human_message = PromptBuilder.to_human_message(
                     prompt_content=raw_prompt,
                     prompt_content=raw_prompt,
-                    inputs={ '#histories#': '', **prompt_inputs }
+                    inputs={'#histories#': '', **prompt_inputs}
                 )
                 )
 
 
                 rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance)
                 rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance)
-                
+
                 memory.human_prefix = conversation_histories_role['user_prefix']
                 memory.human_prefix = conversation_histories_role['user_prefix']
                 memory.ai_prefix = conversation_histories_role['assistant_prefix']
                 memory.ai_prefix = conversation_histories_role['assistant_prefix']
                 histories = self._get_history_messages_from_memory(memory, rest_tokens)
                 histories = self._get_history_messages_from_memory(memory, rest_tokens)
@@ -213,7 +283,8 @@ class PromptTransform:
             else:
             else:
                 prompt_inputs['#histories#'] = ''
                 prompt_inputs['#histories#'] = ''
 
 
-    def _append_chat_histories(self, memory: BaseChatMemory, prompt_messages: list[PromptMessage], model_instance: BaseLLM) -> None:
+    def _append_chat_histories(self, memory: BaseChatMemory, prompt_messages: list[PromptMessage],
+                               model_instance: BaseLLM) -> None:
         if memory:
         if memory:
             rest_tokens = self._calculate_rest_token(prompt_messages, model_instance)
             rest_tokens = self._calculate_rest_token(prompt_messages, model_instance)
 
 
@@ -242,19 +313,19 @@ class PromptTransform:
         return prompt
         return prompt
 
 
     def _get_chat_app_completion_model_prompt_messages(self,
     def _get_chat_app_completion_model_prompt_messages(self,
-            app_model_config: str,
-            inputs: dict,
-            query: str,
-            context: Optional[str],
-            memory: Optional[BaseChatMemory],
-            model_instance: BaseLLM) -> List[PromptMessage]:
-        
+                                                       app_model_config: AppModelConfig,
+                                                       inputs: dict,
+                                                       query: str,
+                                                       files: List[PromptMessageFile],
+                                                       context: Optional[str],
+                                                       memory: Optional[BaseChatMemory],
+                                                       model_instance: BaseLLM) -> List[PromptMessage]:
+
         raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
         raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
         conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
         conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
 
 
         prompt_messages = []
         prompt_messages = []
-        prompt = ''
-        
+
         prompt_template = PromptTemplateParser(template=raw_prompt)
         prompt_template = PromptTemplateParser(template=raw_prompt)
         prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
         prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
 
 
@@ -262,28 +333,29 @@ class PromptTransform:
 
 
         self._set_query_variable(query, prompt_template, prompt_inputs)
         self._set_query_variable(query, prompt_template, prompt_inputs)
 
 
-        self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs, model_instance)
+        self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs,
+                                     model_instance)
 
 
         prompt = self._format_prompt(prompt_template, prompt_inputs)
         prompt = self._format_prompt(prompt_template, prompt_inputs)
 
 
-        prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt))
+        prompt_messages.append(PromptMessage(type=MessageType.USER, content=prompt, files=files))
 
 
         return prompt_messages
         return prompt_messages
 
 
     def _get_chat_app_chat_model_prompt_messages(self,
     def _get_chat_app_chat_model_prompt_messages(self,
-            app_model_config: str,
-            inputs: dict,
-            query: str,
-            context: Optional[str],
-            memory: Optional[BaseChatMemory],
-            model_instance: BaseLLM) -> List[PromptMessage]:
+                                                 app_model_config: AppModelConfig,
+                                                 inputs: dict,
+                                                 query: str,
+                                                 files: List[PromptMessageFile],
+                                                 context: Optional[str],
+                                                 memory: Optional[BaseChatMemory],
+                                                 model_instance: BaseLLM) -> List[PromptMessage]:
         raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
         raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
 
 
         prompt_messages = []
         prompt_messages = []
 
 
         for prompt_item in raw_prompt_list:
         for prompt_item in raw_prompt_list:
             raw_prompt = prompt_item['text']
             raw_prompt = prompt_item['text']
-            prompt = ''
 
 
             prompt_template = PromptTemplateParser(template=raw_prompt)
             prompt_template = PromptTemplateParser(template=raw_prompt)
             prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
             prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
@@ -292,23 +364,23 @@ class PromptTransform:
 
 
             prompt = self._format_prompt(prompt_template, prompt_inputs)
             prompt = self._format_prompt(prompt_template, prompt_inputs)
 
 
-            prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
-        
+            prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt))
+
         self._append_chat_histories(memory, prompt_messages, model_instance)
         self._append_chat_histories(memory, prompt_messages, model_instance)
 
 
-        prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))
+        prompt_messages.append(PromptMessage(type=MessageType.USER, content=query, files=files))
 
 
         return prompt_messages
         return prompt_messages
 
 
     def _get_completion_app_completion_model_prompt_messages(self,
     def _get_completion_app_completion_model_prompt_messages(self,
-                   app_model_config: str,
-                   inputs: dict,
-                   context: Optional[str]) -> List[PromptMessage]:
+                                                             app_model_config: AppModelConfig,
+                                                             inputs: dict,
+                                                             files: List[PromptMessageFile],
+                                                             context: Optional[str]) -> List[PromptMessage]:
         raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
         raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
 
 
         prompt_messages = []
         prompt_messages = []
-        prompt = ''
-        
+
         prompt_template = PromptTemplateParser(template=raw_prompt)
         prompt_template = PromptTemplateParser(template=raw_prompt)
         prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
         prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
 
 
@@ -316,21 +388,21 @@ class PromptTransform:
 
 
         prompt = self._format_prompt(prompt_template, prompt_inputs)
         prompt = self._format_prompt(prompt_template, prompt_inputs)
 
 
-        prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt))
+        prompt_messages.append(PromptMessage(type=MessageType(MessageType.USER), content=prompt, files=files))
 
 
         return prompt_messages
         return prompt_messages
 
 
     def _get_completion_app_chat_model_prompt_messages(self,
     def _get_completion_app_chat_model_prompt_messages(self,
-                   app_model_config: str,
-                   inputs: dict,
-                   context: Optional[str]) -> List[PromptMessage]:
+                                                       app_model_config: AppModelConfig,
+                                                       inputs: dict,
+                                                       files: List[PromptMessageFile],
+                                                       context: Optional[str]) -> List[PromptMessage]:
         raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
         raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
 
 
         prompt_messages = []
         prompt_messages = []
 
 
         for prompt_item in raw_prompt_list:
         for prompt_item in raw_prompt_list:
             raw_prompt = prompt_item['text']
             raw_prompt = prompt_item['text']
-            prompt = ''
 
 
             prompt_template = PromptTemplateParser(template=raw_prompt)
             prompt_template = PromptTemplateParser(template=raw_prompt)
             prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
             prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
@@ -339,6 +411,11 @@ class PromptTransform:
 
 
             prompt = self._format_prompt(prompt_template, prompt_inputs)
             prompt = self._format_prompt(prompt_template, prompt_inputs)
 
 
-            prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
-        
-        return prompt_messages
+            prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt))
+
+        for prompt_message in prompt_messages[::-1]:
+            if prompt_message.type == MessageType.USER:
+                prompt_message.files = files
+                break
+
+        return prompt_messages

+ 103 - 1
api/core/third_party/langchain/llms/chat_open_ai.py

@@ -1,10 +1,13 @@
 import os
 import os
 
 
-from typing import Dict, Any, Optional, Union, Tuple
+from typing import Dict, Any, Optional, Union, Tuple, List, cast
 
 
 from langchain.chat_models import ChatOpenAI
 from langchain.chat_models import ChatOpenAI
+from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage, FunctionMessage
 from pydantic import root_validator
 from pydantic import root_validator
 
 
+from core.model_providers.models.entity.message import LCHumanMessageWithFiles, PromptMessageFileType, ImagePromptMessageFile
+
 
 
 class EnhanceChatOpenAI(ChatOpenAI):
 class EnhanceChatOpenAI(ChatOpenAI):
     request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
     request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
@@ -48,3 +51,102 @@ class EnhanceChatOpenAI(ChatOpenAI):
             "api_key": self.openai_api_key,
             "api_key": self.openai_api_key,
             "organization": self.openai_organization if self.openai_organization else None,
             "organization": self.openai_organization if self.openai_organization else None,
         }
         }
+
+    def _create_message_dicts(
+        self, messages: List[BaseMessage], stop: Optional[List[str]]
+    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+        params = self._client_params
+        if stop is not None:
+            if "stop" in params:
+                raise ValueError("`stop` found in both the input and default params.")
+            params["stop"] = stop
+        message_dicts = [self._convert_message_to_dict(m) for m in messages]
+        return message_dicts, params
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
+
+        Official documentation: https://github.com/openai/openai-cookbook/blob/
+        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
+        model, encoding = self._get_encoding_model()
+        if model.startswith("gpt-3.5-turbo-0301"):
+            # every message follows <im_start>{role/name}\n{content}<im_end>\n
+            tokens_per_message = 4
+            # if there's a name, the role is omitted
+            tokens_per_name = -1
+        elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
+            tokens_per_message = 3
+            tokens_per_name = 1
+        else:
+            raise NotImplementedError(
+                f"get_num_tokens_from_messages() is not presently implemented "
+                f"for model {model}."
+                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
+                "information on how messages are converted to tokens."
+            )
+        num_tokens = 0
+        messages_dict = [self._convert_message_to_dict(m) for m in messages]
+        for message in messages_dict:
+            num_tokens += tokens_per_message
+            for key, value in message.items():
+                # Cast str(value) in case the message value is not a string
+                # This occurs with function messages
+                # TODO: The current token calculation method for the image type is not implemented,
+                #  which need to download the image and then get the resolution for calculation,
+                #  and will increase the request delay
+                if isinstance(value, list):
+                    text = ''
+                    for item in value:
+                        if isinstance(item, dict) and item['type'] == 'text':
+                            text += item['text']
+
+                    value = text
+                num_tokens += len(encoding.encode(str(value)))
+                if key == "name":
+                    num_tokens += tokens_per_name
+        # every reply is primed with <im_start>assistant
+        num_tokens += 3
+        return num_tokens
+
+    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
+        if isinstance(message, ChatMessage):
+            message_dict = {"role": message.role, "content": message.content}
+        elif isinstance(message, LCHumanMessageWithFiles):
+            content = [
+                {
+                    "type": "text",
+                    "text": message.content
+                }
+            ]
+
+            for file in message.files:
+                if file.type == PromptMessageFileType.IMAGE:
+                    file = cast(ImagePromptMessageFile, file)
+                    content.append({
+                        "type": "image_url",
+                        "image_url": {
+                            "url": file.data,
+                            "detail": file.detail.value
+                        }
+                    })
+
+            message_dict = {"role": "user", "content": content}
+        elif isinstance(message, HumanMessage):
+            message_dict = {"role": "user", "content": message.content}
+        elif isinstance(message, AIMessage):
+            message_dict = {"role": "assistant", "content": message.content}
+            if "function_call" in message.additional_kwargs:
+                message_dict["function_call"] = message.additional_kwargs["function_call"]
+        elif isinstance(message, SystemMessage):
+            message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, FunctionMessage):
+            message_dict = {
+                "role": "function",
+                "content": message.content,
+                "name": message.name,
+            }
+        else:
+            raise ValueError(f"Got unknown type {message}")
+        if "name" in message.additional_kwargs:
+            message_dict["name"] = message.additional_kwargs["name"]
+        return message_dict

+ 4 - 10
api/events/event_handlers/generate_conversation_name_when_first_message_created.py

@@ -1,5 +1,3 @@
-import logging
-
 from core.generator.llm_generator import LLMGenerator
 from core.generator.llm_generator import LLMGenerator
 from events.message_event import message_was_created
 from events.message_event import message_was_created
 from extensions.ext_database import db
 from extensions.ext_database import db
@@ -10,8 +8,9 @@ def handle(sender, **kwargs):
     message = sender
     message = sender
     conversation = kwargs.get('conversation')
     conversation = kwargs.get('conversation')
     is_first_message = kwargs.get('is_first_message')
     is_first_message = kwargs.get('is_first_message')
+    auto_generate_name = kwargs.get('auto_generate_name', True)
 
 
-    if is_first_message:
+    if auto_generate_name and is_first_message:
         if conversation.mode == 'chat':
         if conversation.mode == 'chat':
             app_model = conversation.app
             app_model = conversation.app
             if not app_model:
             if not app_model:
@@ -19,14 +18,9 @@ def handle(sender, **kwargs):
 
 
             # generate conversation name
             # generate conversation name
             try:
             try:
-                name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query, message.answer)
-
-                if len(name) > 75:
-                    name = name[:75] + '...'
-
+                name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query)
                 conversation.name = name
                 conversation.name = name
             except:
             except:
-                conversation.name = 'New conversation'
+                pass
 
 
-            db.session.add(conversation)
             db.session.commit()
             db.session.commit()

+ 36 - 1
api/extensions/ext_storage.py

@@ -1,6 +1,7 @@
 import os
 import os
 import shutil
 import shutil
 from contextlib import closing
 from contextlib import closing
+from typing import Union, Generator
 
 
 import boto3
 import boto3
 from botocore.exceptions import ClientError
 from botocore.exceptions import ClientError
@@ -45,7 +46,13 @@ class Storage:
             with open(os.path.join(os.getcwd(), filename), "wb") as f:
             with open(os.path.join(os.getcwd(), filename), "wb") as f:
                 f.write(data)
                 f.write(data)
 
 
-    def load(self, filename):
+    def load(self, filename: str, stream: bool = False) -> Union[bytes, Generator]:
+        if stream:
+            return self.load_stream(filename)
+        else:
+            return self.load_once(filename)
+
+    def load_once(self, filename: str) -> bytes:
         if self.storage_type == 's3':
         if self.storage_type == 's3':
             try:
             try:
                 with closing(self.client) as client:
                 with closing(self.client) as client:
@@ -69,6 +76,34 @@ class Storage:
 
 
         return data
         return data
 
 
+    def load_stream(self, filename: str) -> Generator:
+        def generate(filename: str = filename) -> Generator:
+            if self.storage_type == 's3':
+                try:
+                    with closing(self.client) as client:
+                        response = client.get_object(Bucket=self.bucket_name, Key=filename)
+                        for chunk in response['Body'].iter_chunks():
+                            yield chunk
+                except ClientError as ex:
+                    if ex.response['Error']['Code'] == 'NoSuchKey':
+                        raise FileNotFoundError("File not found")
+                    else:
+                        raise
+            else:
+                if not self.folder or self.folder.endswith('/'):
+                    filename = self.folder + filename
+                else:
+                    filename = self.folder + '/' + filename
+
+                if not os.path.exists(filename):
+                    raise FileNotFoundError("File not found")
+
+                with open(filename, "rb") as f:
+                    while chunk := f.read(4096):  # Read in chunks of 4KB
+                        yield chunk
+
+        return generate()
+
     def download(self, filename, target_filepath):
     def download(self, filename, target_filepath):
         if self.storage_type == 's3':
         if self.storage_type == 's3':
             with closing(self.client) as client:
             with closing(self.client) as client:

+ 3 - 2
api/fields/app_fields.py

@@ -32,7 +32,8 @@ model_config_fields = {
     'prompt_type': fields.String,
     'prompt_type': fields.String,
     'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'),
     'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'),
     'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'),
     'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'),
-    'dataset_configs': fields.Raw(attribute='dataset_configs_dict')
+    'dataset_configs': fields.Raw(attribute='dataset_configs_dict'),
+    'file_upload': fields.Raw(attribute='file_upload_dict'),
 }
 }
 
 
 app_detail_fields = {
 app_detail_fields = {
@@ -140,4 +141,4 @@ app_site_fields = {
     'privacy_policy': fields.String,
     'privacy_policy': fields.String,
     'customize_token_strategy': fields.String,
     'customize_token_strategy': fields.String,
     'prompt_public': fields.Boolean
     'prompt_public': fields.Boolean
-}
+}

+ 9 - 7
api/fields/conversation_fields.py

@@ -28,6 +28,12 @@ annotation_fields = {
     'created_at': TimestampField
     'created_at': TimestampField
 }
 }
 
 
+message_file_fields = {
+    'id': fields.String,
+    'type': fields.String,
+    'url': fields.String,
+}
+
 message_detail_fields = {
 message_detail_fields = {
     'id': fields.String,
     'id': fields.String,
     'conversation_id': fields.String,
     'conversation_id': fields.String,
@@ -43,7 +49,8 @@ message_detail_fields = {
     'from_account_id': fields.String,
     'from_account_id': fields.String,
     'feedbacks': fields.List(fields.Nested(feedback_fields)),
     'feedbacks': fields.List(fields.Nested(feedback_fields)),
     'annotation': fields.Nested(annotation_fields, allow_null=True),
     'annotation': fields.Nested(annotation_fields, allow_null=True),
-    'created_at': TimestampField
+    'created_at': TimestampField,
+    'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
 }
 }
 
 
 feedback_stat_fields = {
 feedback_stat_fields = {
@@ -111,11 +118,6 @@ conversation_message_detail_fields = {
     'message': fields.Nested(message_detail_fields, attribute='first_message'),
     'message': fields.Nested(message_detail_fields, attribute='first_message'),
 }
 }
 
 
-simple_model_config_fields = {
-    'model': fields.Raw(attribute='model_dict'),
-    'pre_prompt': fields.String,
-}
-
 conversation_with_summary_fields = {
 conversation_with_summary_fields = {
     'id': fields.String,
     'id': fields.String,
     'status': fields.String,
     'status': fields.String,
@@ -180,4 +182,4 @@ conversation_with_model_config_infinite_scroll_pagination_fields = {
     'limit': fields.Integer,
     'limit': fields.Integer,
     'has_more': fields.Boolean,
     'has_more': fields.Boolean,
     'data': fields.List(fields.Nested(conversation_with_model_config_fields))
     'data': fields.List(fields.Nested(conversation_with_model_config_fields))
-}
+}

+ 2 - 1
api/fields/file_fields.py

@@ -4,7 +4,8 @@ from libs.helper import TimestampField
 
 
 upload_config_fields = {
 upload_config_fields = {
     'file_size_limit': fields.Integer,
     'file_size_limit': fields.Integer,
-    'batch_count_limit': fields.Integer
+    'batch_count_limit': fields.Integer,
+    'image_file_size_limit': fields.Integer,
 }
 }
 
 
 file_fields = {
 file_fields = {

+ 2 - 0
api/fields/message_fields.py

@@ -1,6 +1,7 @@
 from flask_restful import fields
 from flask_restful import fields
 
 
 from libs.helper import TimestampField
 from libs.helper import TimestampField
+from fields.conversation_fields import message_file_fields
 
 
 feedback_fields = {
 feedback_fields = {
     'rating': fields.String
     'rating': fields.String
@@ -31,6 +32,7 @@ message_fields = {
     'inputs': fields.Raw,
     'inputs': fields.Raw,
     'query': fields.String,
     'query': fields.String,
     'answer': fields.String,
     'answer': fields.String,
+    'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
     'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
     'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
     'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
     'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
     'created_at': TimestampField
     'created_at': TimestampField

+ 59 - 0
api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py

@@ -0,0 +1,59 @@
+"""add gpt4v supports
+
+Revision ID: 8fe468ba0ca5
+Revises: a9836e3baeee
+Create Date: 2023-11-09 11:39:00.006432
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '8fe468ba0ca5'
+down_revision = 'a9836e3baeee'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('message_files',
+    sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('message_id', postgresql.UUID(), nullable=False),
+    sa.Column('type', sa.String(length=255), nullable=False),
+    sa.Column('transfer_method', sa.String(length=255), nullable=False),
+    sa.Column('url', sa.Text(), nullable=True),
+    sa.Column('upload_file_id', postgresql.UUID(), nullable=True),
+    sa.Column('created_by_role', sa.String(length=255), nullable=False),
+    sa.Column('created_by', postgresql.UUID(), nullable=False),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.PrimaryKeyConstraint('id', name='message_file_pkey')
+    )
+    with op.batch_alter_table('message_files', schema=None) as batch_op:
+        batch_op.create_index('message_file_created_by_idx', ['created_by'], unique=False)
+        batch_op.create_index('message_file_message_idx', ['message_id'], unique=False)
+
+    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('file_upload', sa.Text(), nullable=True))
+
+    with op.batch_alter_table('upload_files', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'account'::character varying"), nullable=False))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('upload_files', schema=None) as batch_op:
+        batch_op.drop_column('created_by_role')
+
+    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+        batch_op.drop_column('file_upload')
+
+    with op.batch_alter_table('message_files', schema=None) as batch_op:
+        batch_op.drop_index('message_file_message_idx')
+        batch_op.drop_index('message_file_created_by_idx')
+
+    op.drop_table('message_files')
+    # ### end Alembic commands ###

+ 63 - 4
api/models/model.py

@@ -1,10 +1,10 @@
 import json
 import json
-from json import JSONDecodeError
 
 
 from flask import current_app, request
 from flask import current_app, request
 from flask_login import UserMixin
 from flask_login import UserMixin
 from sqlalchemy.dialects.postgresql import UUID
 from sqlalchemy.dialects.postgresql import UUID
 
 
+from core.file.upload_file_parser import UploadFileParser
 from libs.helper import generate_string
 from libs.helper import generate_string
 from extensions.ext_database import db
 from extensions.ext_database import db
 from .account import Account, Tenant
 from .account import Account, Tenant
@@ -98,6 +98,7 @@ class AppModelConfig(db.Model):
     completion_prompt_config = db.Column(db.Text)
     completion_prompt_config = db.Column(db.Text)
     dataset_configs = db.Column(db.Text)
     dataset_configs = db.Column(db.Text)
     external_data_tools = db.Column(db.Text)
     external_data_tools = db.Column(db.Text)
+    file_upload = db.Column(db.Text)
 
 
     @property
     @property
     def app(self):
     def app(self):
@@ -161,6 +162,10 @@ class AppModelConfig(db.Model):
     def dataset_configs_dict(self) -> dict:
     def dataset_configs_dict(self) -> dict:
         return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}}
         return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}}
 
 
+    @property
+    def file_upload_dict(self) -> dict:
+        return json.loads(self.file_upload) if self.file_upload else {"image": {"enabled": False, "number_limits": 3, "detail": "high", "transfer_methods": ["remote_url", "local_file"]}}
+
     def to_dict(self) -> dict:
     def to_dict(self) -> dict:
         return {
         return {
             "provider": "",
             "provider": "",
@@ -182,7 +187,8 @@ class AppModelConfig(db.Model):
             "prompt_type": self.prompt_type,
             "prompt_type": self.prompt_type,
             "chat_prompt_config": self.chat_prompt_config_dict,
             "chat_prompt_config": self.chat_prompt_config_dict,
             "completion_prompt_config": self.completion_prompt_config_dict,
             "completion_prompt_config": self.completion_prompt_config_dict,
-            "dataset_configs": self.dataset_configs_dict
+            "dataset_configs": self.dataset_configs_dict,
+            "file_upload": self.file_upload_dict
         }
         }
 
 
     def from_model_config_dict(self, model_config: dict):
     def from_model_config_dict(self, model_config: dict):
@@ -213,6 +219,8 @@ class AppModelConfig(db.Model):
             if model_config.get('completion_prompt_config') else None
             if model_config.get('completion_prompt_config') else None
         self.dataset_configs = json.dumps(model_config.get('dataset_configs')) \
         self.dataset_configs = json.dumps(model_config.get('dataset_configs')) \
             if model_config.get('dataset_configs') else None
             if model_config.get('dataset_configs') else None
+        self.file_upload = json.dumps(model_config.get('file_upload')) \
+            if model_config.get('file_upload') else None
         return self
         return self
 
 
     def copy(self):
     def copy(self):
@@ -238,7 +246,8 @@ class AppModelConfig(db.Model):
             prompt_type=self.prompt_type,
             prompt_type=self.prompt_type,
             chat_prompt_config=self.chat_prompt_config,
             chat_prompt_config=self.chat_prompt_config,
             completion_prompt_config=self.completion_prompt_config,
             completion_prompt_config=self.completion_prompt_config,
-            dataset_configs=self.dataset_configs
+            dataset_configs=self.dataset_configs,
+            file_upload=self.file_upload
         )
         )
 
 
         return new_app_model_config
         return new_app_model_config
@@ -512,6 +521,37 @@ class Message(db.Model):
         return db.session.query(DatasetRetrieverResource).filter(DatasetRetrieverResource.message_id == self.id) \
         return db.session.query(DatasetRetrieverResource).filter(DatasetRetrieverResource.message_id == self.id) \
             .order_by(DatasetRetrieverResource.position.asc()).all()
             .order_by(DatasetRetrieverResource.position.asc()).all()
 
 
+    @property
+    def message_files(self):
+        return db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all()
+
+    @property
+    def files(self):
+        message_files = self.message_files
+
+        files = []
+        for message_file in message_files:
+            url = message_file.url
+            if message_file.type == 'image':
+                if message_file.transfer_method == 'local_file':
+                    upload_file = (db.session.query(UploadFile)
+                                   .filter(
+                        UploadFile.id == message_file.upload_file_id
+                    ).first())
+
+                    url = UploadFileParser.get_image_data(
+                        upload_file=upload_file,
+                        force_url=True
+                    )
+
+            files.append({
+                'id': message_file.id,
+                'type': message_file.type,
+                'url': url
+            })
+
+        return files
+
 
 
 class MessageFeedback(db.Model):
 class MessageFeedback(db.Model):
     __tablename__ = 'message_feedbacks'
     __tablename__ = 'message_feedbacks'
@@ -540,6 +580,25 @@ class MessageFeedback(db.Model):
         return account
         return account
 
 
 
 
+class MessageFile(db.Model):
+    __tablename__ = 'message_files'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='message_file_pkey'),
+        db.Index('message_file_message_idx', 'message_id'),
+        db.Index('message_file_created_by_idx', 'created_by')
+    )
+
+    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    message_id = db.Column(UUID, nullable=False)
+    type = db.Column(db.String(255), nullable=False)
+    transfer_method = db.Column(db.String(255), nullable=False)
+    url = db.Column(db.Text, nullable=True)
+    upload_file_id = db.Column(UUID, nullable=True)
+    created_by_role = db.Column(db.String(255), nullable=False)
+    created_by = db.Column(UUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+
+
 class MessageAnnotation(db.Model):
 class MessageAnnotation(db.Model):
     __tablename__ = 'message_annotations'
     __tablename__ = 'message_annotations'
     __table_args__ = (
     __table_args__ = (
@@ -683,6 +742,7 @@ class UploadFile(db.Model):
     size = db.Column(db.Integer, nullable=False)
     size = db.Column(db.Integer, nullable=False)
     extension = db.Column(db.String(255), nullable=False)
     extension = db.Column(db.String(255), nullable=False)
     mime_type = db.Column(db.String(255), nullable=True)
     mime_type = db.Column(db.String(255), nullable=True)
+    created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying"))
     created_by = db.Column(UUID, nullable=False)
     created_by = db.Column(UUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     used = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
     used = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
@@ -783,4 +843,3 @@ class DatasetRetrieverResource(db.Model):
     retriever_from = db.Column(db.Text, nullable=False)
     retriever_from = db.Column(db.Text, nullable=False)
     created_by = db.Column(UUID, nullable=False)
     created_by = db.Column(UUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
-

+ 33 - 1
api/services/app_model_config_service.py

@@ -315,6 +315,9 @@ class AppModelConfigService:
         # moderation validation
         # moderation validation
         cls.is_moderation_valid(tenant_id, config)
         cls.is_moderation_valid(tenant_id, config)
 
 
+        # file upload validation
+        cls.is_file_upload_valid(config)
+
         # Filter out extra parameters
         # Filter out extra parameters
         filtered_config = {
         filtered_config = {
             "opening_statement": config["opening_statement"],
             "opening_statement": config["opening_statement"],
@@ -338,7 +341,8 @@ class AppModelConfigService:
             "prompt_type": config["prompt_type"],
             "prompt_type": config["prompt_type"],
             "chat_prompt_config": config["chat_prompt_config"],
             "chat_prompt_config": config["chat_prompt_config"],
             "completion_prompt_config": config["completion_prompt_config"],
             "completion_prompt_config": config["completion_prompt_config"],
-            "dataset_configs": config["dataset_configs"]
+            "dataset_configs": config["dataset_configs"],
+            "file_upload": config["file_upload"]
         }
         }
 
 
         return filtered_config
         return filtered_config
@@ -371,6 +375,34 @@ class AppModelConfigService:
             config=config
             config=config
         )
         )
 
 
+    @classmethod
+    def is_file_upload_valid(cls, config: dict):
+        if 'file_upload' not in config or not config["file_upload"]:
+            config["file_upload"] = {}
+
+        if not isinstance(config["file_upload"], dict):
+            raise ValueError("file_upload must be of dict type")
+
+        # check image config
+        if 'image' not in config["file_upload"] or not config["file_upload"]["image"]:
+            config["file_upload"]["image"] = {"enabled": False}
+
+        if config['file_upload']['image']['enabled']:
+            number_limits = config['file_upload']['image']['number_limits']
+            if number_limits < 1 or number_limits > 6:
+                raise ValueError("number_limits must be in [1, 6]")
+
+            detail = config['file_upload']['image']['detail']
+            if detail not in ['high', 'low']:
+                raise ValueError("detail must be in ['high', 'low']")
+
+            transfer_methods = config['file_upload']['image']['transfer_methods']
+            if not isinstance(transfer_methods, list):
+                raise ValueError("transfer_methods must be of list type")
+            for method in transfer_methods:
+                if method not in ['remote_url', 'local_file']:
+                    raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
+
     @classmethod
     @classmethod
     def is_external_data_tools_valid(cls, tenant_id: str, config: dict):
     def is_external_data_tools_valid(cls, tenant_id: str, config: dict):
         if 'external_data_tools' not in config or not config["external_data_tools"]:
         if 'external_data_tools' not in config or not config["external_data_tools"]:

+ 40 - 10
api/services/completion_service.py

@@ -3,7 +3,7 @@ import logging
 import threading
 import threading
 import time
 import time
 import uuid
 import uuid
-from typing import Generator, Union, Any, Optional
+from typing import Generator, Union, Any, Optional, List
 
 
 from flask import current_app, Flask
 from flask import current_app, Flask
 from redis.client import PubSub
 from redis.client import PubSub
@@ -12,9 +12,11 @@ from sqlalchemy import and_
 from core.completion import Completion
 from core.completion import Completion
 from core.conversation_message_task import PubHandler, ConversationTaskStoppedException, \
 from core.conversation_message_task import PubHandler, ConversationTaskStoppedException, \
     ConversationTaskInterruptException
     ConversationTaskInterruptException
+from core.file.message_file_parser import MessageFileParser
 from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
 from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
     LLMRateLimitError, \
     LLMRateLimitError, \
     LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
     LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.models.entity.message import PromptMessageFile
 from extensions.ext_database import db
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_redis import redis_client
 from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message
 from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message
@@ -35,6 +37,9 @@ class CompletionService:
         # is streaming mode
         # is streaming mode
         inputs = args['inputs']
         inputs = args['inputs']
         query = args['query']
         query = args['query']
+        files = args['files'] if 'files' in args and args['files'] else []
+        auto_generate_name = args['auto_generate_name'] \
+            if 'auto_generate_name' in args else True
 
 
         if app_model.mode != 'completion' and not query:
         if app_model.mode != 'completion' and not query:
             raise ValueError('query is required')
             raise ValueError('query is required')
@@ -132,6 +137,14 @@ class CompletionService:
         # clean input by app_model_config form rules
         # clean input by app_model_config form rules
         inputs = cls.get_cleaned_inputs(inputs, app_model_config)
         inputs = cls.get_cleaned_inputs(inputs, app_model_config)
 
 
+        # parse files
+        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
+        file_objs = message_file_parser.validate_and_transform_files_arg(
+            files,
+            app_model_config,
+            user
+        )
+
         generate_task_id = str(uuid.uuid4())
         generate_task_id = str(uuid.uuid4())
 
 
         pubsub = redis_client.pubsub()
         pubsub = redis_client.pubsub()
@@ -146,17 +159,20 @@ class CompletionService:
             'app_model_config': app_model_config.copy(),
             'app_model_config': app_model_config.copy(),
             'query': query,
             'query': query,
             'inputs': inputs,
             'inputs': inputs,
+            'files': file_objs,
             'detached_user': user,
             'detached_user': user,
             'detached_conversation': conversation,
             'detached_conversation': conversation,
             'streaming': streaming,
             'streaming': streaming,
             'is_model_config_override': is_model_config_override,
             'is_model_config_override': is_model_config_override,
-            'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev'
+            'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev',
+            'auto_generate_name': auto_generate_name
         })
         })
 
 
         generate_worker_thread.start()
         generate_worker_thread.start()
 
 
         # wait for 10 minutes to close the thread
         # wait for 10 minutes to close the thread
-        cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
+        cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
+                                generate_task_id)
 
 
         return cls.compact_response(pubsub, streaming)
         return cls.compact_response(pubsub, streaming)
 
 
@@ -172,10 +188,12 @@ class CompletionService:
         return user
         return user
 
 
     @classmethod
     @classmethod
-    def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App, app_model_config: AppModelConfig,
-                        query: str, inputs: dict, detached_user: Union[Account, EndUser],
+    def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
+                        app_model_config: AppModelConfig,
+                        query: str, inputs: dict, files: List[PromptMessageFile],
+                        detached_user: Union[Account, EndUser],
                         detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
                         detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
-                        retriever_from: str = 'dev'):
+                        retriever_from: str = 'dev', auto_generate_name: bool = True):
         with flask_app.app_context():
         with flask_app.app_context():
             # fixed the state of the model object when it detached from the original session
             # fixed the state of the model object when it detached from the original session
             user = db.session.merge(detached_user)
             user = db.session.merge(detached_user)
@@ -195,10 +213,12 @@ class CompletionService:
                     query=query,
                     query=query,
                     inputs=inputs,
                     inputs=inputs,
                     user=user,
                     user=user,
+                    files=files,
                     conversation=conversation,
                     conversation=conversation,
                     streaming=streaming,
                     streaming=streaming,
                     is_override=is_model_config_override,
                     is_override=is_model_config_override,
-                    retriever_from=retriever_from
+                    retriever_from=retriever_from,
+                    auto_generate_name=auto_generate_name
                 )
                 )
             except (ConversationTaskInterruptException, ConversationTaskStoppedException):
             except (ConversationTaskInterruptException, ConversationTaskStoppedException):
                 pass
                 pass
@@ -215,7 +235,8 @@ class CompletionService:
                 db.session.commit()
                 db.session.commit()
 
 
     @classmethod
     @classmethod
-    def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user, generate_task_id) -> threading.Thread:
+    def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user,
+                            generate_task_id) -> threading.Thread:
         # wait for 10 minutes to close the thread
         # wait for 10 minutes to close the thread
         timeout = 600
         timeout = 600
 
 
@@ -274,6 +295,12 @@ class CompletionService:
         model_dict['completion_params'] = completion_params
         model_dict['completion_params'] = completion_params
         app_model_config.model = json.dumps(model_dict)
         app_model_config.model = json.dumps(model_dict)
 
 
+        # parse files
+        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
+        file_objs = message_file_parser.transform_message_files(
+            message.files, app_model_config
+        )
+
         generate_task_id = str(uuid.uuid4())
         generate_task_id = str(uuid.uuid4())
 
 
         pubsub = redis_client.pubsub()
         pubsub = redis_client.pubsub()
@@ -288,11 +315,13 @@ class CompletionService:
             'app_model_config': app_model_config.copy(),
             'app_model_config': app_model_config.copy(),
             'query': message.query,
             'query': message.query,
             'inputs': message.inputs,
             'inputs': message.inputs,
+            'files': file_objs,
             'detached_user': user,
             'detached_user': user,
             'detached_conversation': None,
             'detached_conversation': None,
             'streaming': streaming,
             'streaming': streaming,
             'is_model_config_override': True,
             'is_model_config_override': True,
-            'retriever_from': retriever_from
+            'retriever_from': retriever_from,
+            'auto_generate_name': False
         })
         })
 
 
         generate_worker_thread.start()
         generate_worker_thread.start()
@@ -388,7 +417,8 @@ class CompletionService:
                             if event == 'message':
                             if event == 'message':
                                 yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
                                 yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
                             elif event == 'message_replace':
                             elif event == 'message_replace':
-                                yield "data: " + json.dumps(cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
+                                yield "data: " + json.dumps(
+                                    cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
                             elif event == 'chain':
                             elif event == 'chain':
                                 yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
                                 yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
                             elif event == 'agent_thought':
                             elif event == 'agent_thought':

+ 36 - 4
api/services/conversation_service.py

@@ -1,17 +1,20 @@
 from typing import Union, Optional
 from typing import Union, Optional
 
 
+from core.generator.llm_generator import LLMGenerator
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from extensions.ext_database import db
 from extensions.ext_database import db
 from models.account import Account
 from models.account import Account
-from models.model import Conversation, App, EndUser
+from models.model import Conversation, App, EndUser, Message
 from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
 from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
+from services.errors.message import MessageNotExistsError
 
 
 
 
 class ConversationService:
 class ConversationService:
     @classmethod
     @classmethod
     def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | EndUser]],
     def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | EndUser]],
                               last_id: Optional[str], limit: int,
                               last_id: Optional[str], limit: int,
-                              include_ids: Optional[list] = None, exclude_ids: Optional[list] = None) -> InfiniteScrollPagination:
+                              include_ids: Optional[list] = None, exclude_ids: Optional[list] = None,
+                              exclude_debug_conversation: bool = False) -> InfiniteScrollPagination:
         if not user:
         if not user:
             return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
             return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
 
 
@@ -29,6 +32,9 @@ class ConversationService:
         if exclude_ids is not None:
         if exclude_ids is not None:
             base_query = base_query.filter(~Conversation.id.in_(exclude_ids))
             base_query = base_query.filter(~Conversation.id.in_(exclude_ids))
 
 
+        if exclude_debug_conversation:
+            base_query = base_query.filter(Conversation.override_model_configs == None)
+
         if last_id:
         if last_id:
             last_conversation = base_query.filter(
             last_conversation = base_query.filter(
                 Conversation.id == last_id,
                 Conversation.id == last_id,
@@ -63,10 +69,36 @@ class ConversationService:
 
 
     @classmethod
     @classmethod
     def rename(cls, app_model: App, conversation_id: str,
     def rename(cls, app_model: App, conversation_id: str,
-               user: Optional[Union[Account | EndUser]], name: str):
+               user: Optional[Union[Account | EndUser]], name: str, auto_generate: bool):
         conversation = cls.get_conversation(app_model, conversation_id, user)
         conversation = cls.get_conversation(app_model, conversation_id, user)
 
 
-        conversation.name = name
+        if auto_generate:
+            return cls.auto_generate_name(app_model, conversation)
+        else:
+            conversation.name = name
+            db.session.commit()
+
+        return conversation
+
+    @classmethod
+    def auto_generate_name(cls, app_model: App, conversation: Conversation):
+        # get conversation first message
+        message = db.session.query(Message) \
+            .filter(
+                Message.app_id == app_model.id,
+                Message.conversation_id == conversation.id
+            ).order_by(Message.created_at.asc()).first()
+
+        if not message:
+            raise MessageNotExistsError()
+
+        # generate conversation name
+        try:
+            name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query)
+            conversation.name = name
+        except:
+            pass
+
         db.session.commit()
         db.session.commit()
 
 
         return conversation
         return conversation

+ 54 - 21
api/services/file_service.py

@@ -1,46 +1,62 @@
 import datetime
 import datetime
 import hashlib
 import hashlib
-import time
 import uuid
 import uuid
+from typing import Generator, Tuple, Union
 
 
-from cachetools import TTLCache
-from flask import request, current_app
+from flask import current_app
 from flask_login import current_user
 from flask_login import current_user
 from werkzeug.datastructures import FileStorage
 from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import NotFound
 from werkzeug.exceptions import NotFound
 
 
 from core.data_loader.file_extractor import FileExtractor
 from core.data_loader.file_extractor import FileExtractor
+from core.file.upload_file_parser import UploadFileParser
 from extensions.ext_storage import storage
 from extensions.ext_storage import storage
 from extensions.ext_database import db
 from extensions.ext_database import db
-from models.model import UploadFile
+from models.account import Account
+from models.model import UploadFile, EndUser
 from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
 from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
 
 
-ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
+ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv',
+                      'jpg', 'jpeg', 'png', 'webp', 'gif']
+IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif']
 PREVIEW_WORDS_LIMIT = 3000
 PREVIEW_WORDS_LIMIT = 3000
-cache = TTLCache(maxsize=None, ttl=30)
 
 
 
 
 class FileService:
 class FileService:
 
 
     @staticmethod
     @staticmethod
-    def upload_file(file: FileStorage) -> UploadFile:
+    def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile:
+        extension = file.filename.split('.')[-1]
+        if extension.lower() not in ALLOWED_EXTENSIONS:
+            raise UnsupportedFileTypeError()
+        elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
+            raise UnsupportedFileTypeError()
+
         # read file content
         # read file content
         file_content = file.read()
         file_content = file.read()
+
         # get file size
         # get file size
         file_size = len(file_content)
         file_size = len(file_content)
 
 
-        file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
+        if extension.lower() in IMAGE_EXTENSIONS:
+            file_size_limit = current_app.config.get("UPLOAD_IMAGE_FILE_SIZE_LIMIT") * 1024 * 1024
+        else:
+            file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
+
         if file_size > file_size_limit:
         if file_size > file_size_limit:
             message = f'File size exceeded. {file_size} > {file_size_limit}'
             message = f'File size exceeded. {file_size} > {file_size_limit}'
             raise FileTooLargeError(message)
             raise FileTooLargeError(message)
 
 
-        extension = file.filename.split('.')[-1]
-        if extension.lower() not in ALLOWED_EXTENSIONS:
-            raise UnsupportedFileTypeError()
-
         # user uuid as file name
         # user uuid as file name
         file_uuid = str(uuid.uuid4())
         file_uuid = str(uuid.uuid4())
-        file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension
+
+        if isinstance(user, Account):
+            current_tenant_id = user.current_tenant_id
+        else:
+            # end_user
+            current_tenant_id = user.tenant_id
+
+        file_key = 'upload_files/' + current_tenant_id + '/' + file_uuid + '.' + extension
 
 
         # save file to storage
         # save file to storage
         storage.save(file_key, file_content)
         storage.save(file_key, file_content)
@@ -48,14 +64,15 @@ class FileService:
         # save file to db
         # save file to db
         config = current_app.config
         config = current_app.config
         upload_file = UploadFile(
         upload_file = UploadFile(
-            tenant_id=current_user.current_tenant_id,
+            tenant_id=current_tenant_id,
             storage_type=config['STORAGE_TYPE'],
             storage_type=config['STORAGE_TYPE'],
             key=file_key,
             key=file_key,
             name=file.filename,
             name=file.filename,
             size=file_size,
             size=file_size,
             extension=extension,
             extension=extension,
             mime_type=file.mimetype,
             mime_type=file.mimetype,
-            created_by=current_user.id,
+            created_by_role=('account' if isinstance(user, Account) else 'end_user'),
+            created_by=user.id,
             created_at=datetime.datetime.utcnow(),
             created_at=datetime.datetime.utcnow(),
             used=False,
             used=False,
             hash=hashlib.sha3_256(file_content).hexdigest()
             hash=hashlib.sha3_256(file_content).hexdigest()
@@ -99,12 +116,6 @@ class FileService:
 
 
     @staticmethod
     @staticmethod
     def get_file_preview(file_id: str) -> str:
     def get_file_preview(file_id: str) -> str:
-        # get file storage key
-        key = file_id + request.path
-        cached_response = cache.get(key)
-        if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
-            return cached_response['response']
-
         upload_file = db.session.query(UploadFile) \
         upload_file = db.session.query(UploadFile) \
             .filter(UploadFile.id == file_id) \
             .filter(UploadFile.id == file_id) \
             .first()
             .first()
@@ -121,3 +132,25 @@ class FileService:
         text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
         text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
 
 
         return text
         return text
+
+    @staticmethod
+    def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str) -> Tuple[Generator, str]:
+        result = UploadFileParser.verify_image_file_signature(file_id, timestamp, nonce, sign)
+        if not result:
+            raise NotFound("File not found or signature is invalid")
+
+        upload_file = db.session.query(UploadFile) \
+            .filter(UploadFile.id == file_id) \
+            .first()
+
+        if not upload_file:
+            raise NotFound("File not found or signature is invalid")
+
+        # check that the stored file has a supported image extension
+        extension = upload_file.extension
+        if extension.lower() not in IMAGE_EXTENSIONS:
+            raise UnsupportedFileTypeError()
+
+        generator = storage.load(upload_file.key, stream=True)
+
+        return generator, upload_file.mime_type

+ 4 - 2
api/services/web_conversation_service.py

@@ -11,7 +11,8 @@ from services.conversation_service import ConversationService
 class WebConversationService:
 class WebConversationService:
     @classmethod
     @classmethod
     def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | EndUser]],
     def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | EndUser]],
-                              last_id: Optional[str], limit: int, pinned: Optional[bool] = None) -> InfiniteScrollPagination:
+                              last_id: Optional[str], limit: int, pinned: Optional[bool] = None,
+                              exclude_debug_conversation: bool = False) -> InfiniteScrollPagination:
         include_ids = None
         include_ids = None
         exclude_ids = None
         exclude_ids = None
         if pinned is not None:
         if pinned is not None:
@@ -32,7 +33,8 @@ class WebConversationService:
             last_id=last_id,
             last_id=last_id,
             limit=limit,
             limit=limit,
             include_ids=include_ids,
             include_ids=include_ids,
-            exclude_ids=exclude_ids
+            exclude_ids=exclude_ids,
+            exclude_debug_conversation=exclude_debug_conversation
         )
         )
 
 
     @classmethod
     @classmethod

+ 29 - 1
api/tests/integration_tests/models/llm/test_openai_model.py

@@ -5,7 +5,7 @@ from unittest.mock import patch
 from langchain.schema import Generation, ChatGeneration, AIMessage
 from langchain.schema import Generation, ChatGeneration, AIMessage
 
 
 from core.model_providers.providers.openai_provider import OpenAIProvider
 from core.model_providers.providers.openai_provider import OpenAIProvider
-from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.message import PromptMessage, MessageType, ImageMessageFile
 from core.model_providers.models.entity.model_params import ModelKwargs
 from core.model_providers.models.entity.model_params import ModelKwargs
 from core.model_providers.models.llm.openai_model import OpenAIModel
 from core.model_providers.models.llm.openai_model import OpenAIModel
 from models.provider import Provider, ProviderType
 from models.provider import Provider, ProviderType
@@ -57,6 +57,18 @@ def test_chat_get_num_tokens(mock_decrypt):
     assert rst == 22
     assert rst == 22
 
 
 
 
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_vision_chat_get_num_tokens(mock_decrypt):
+    openai_model = get_mock_openai_model('gpt-4-vision-preview')
+    messages = [
+        PromptMessage(content='What’s in first image?', files=[
+            ImageMessageFile(
+                data='https://upload.wikimedia.org/wikipedia/commons/0/00/1890s_Carlisle_Boarding_School_Graduates_PA.jpg')
+        ])
+    ]
+    rst = openai_model.get_num_tokens(messages)
+    assert rst == 77
+
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
 def test_run(mock_decrypt, mocker):
 def test_run(mock_decrypt, mocker):
     mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
     mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
@@ -80,4 +92,20 @@ def test_chat_run(mock_decrypt, mocker):
         messages,
         messages,
         stop=['\nHuman:'],
         stop=['\nHuman:'],
     )
     )
+    assert (len(rst.content) > 0)
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_vision_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
+    openai_model = get_mock_openai_model('gpt-4-vision-preview')
+    messages = [
+        PromptMessage(content='What’s in first image?', files=[
+            ImageMessageFile(data='https://upload.wikimedia.org/wikipedia/commons/0/00/1890s_Carlisle_Boarding_School_Graduates_PA.jpg')
+        ])
+    ]
+    rst = openai_model.run(
+        messages,
+    )
     assert len(rst.content) > 0
     assert len(rst.content) > 0

+ 7 - 3
docker/docker-compose.yaml

@@ -19,18 +19,22 @@ services:
       # different from api or web app domain.
       # different from api or web app domain.
       # example: http://cloud.dify.ai
       # example: http://cloud.dify.ai
       CONSOLE_API_URL: ''
       CONSOLE_API_URL: ''
-      # The URL for Service API endpoints, refers to the base URL of the current API service if api domain is
+      # The URL prefix for Service API endpoints, refers to the base URL of the current API service if api domain is
       # different from console domain.
       # different from console domain.
       # example: http://api.dify.ai
       # example: http://api.dify.ai
       SERVICE_API_URL: ''
       SERVICE_API_URL: ''
-      # The URL for Web APP api server, refers to the Web App base URL of WEB service if web app domain is different from
+      # The URL prefix for Web APP api server, refers to the Web App base URL of WEB service if web app domain is different from
       # console or api domain.
       # console or api domain.
       # example: http://udify.app
       # example: http://udify.app
       APP_API_URL: ''
       APP_API_URL: ''
-      # The URL for Web APP frontend, refers to the Web App base URL of WEB service if web app domain is different from
+      # The URL prefix for Web APP frontend, refers to the Web App base URL of WEB service if web app domain is different from
       # console or api domain.
       # console or api domain.
       # example: http://udify.app
       # example: http://udify.app
       APP_WEB_URL: ''
       APP_WEB_URL: ''
+      # File preview or download URL prefix.
+      # Used to display file preview or download URLs to the front-end, or as multi-modal model inputs;
+      # the URL is signed and has an expiration time.
+      FILES_URL: ''
       # When enabled, migrations will be executed prior to application startup and the application will start after the migrations have completed.
       # When enabled, migrations will be executed prior to application startup and the application will start after the migrations have completed.
       MIGRATION_ENABLED: 'true'
       MIGRATION_ENABLED: 'true'
       # The configurations of postgres database connection.
       # The configurations of postgres database connection.

+ 5 - 0
docker/nginx/conf.d/default.conf

@@ -17,6 +17,11 @@ server {
       include proxy.conf;
       include proxy.conf;
     }
     }
 
 
+    location /files {
+      proxy_pass http://api:5001;
+      include proxy.conf;
+    }
+
     location / {
     location / {
       proxy_pass http://web:3000;
       proxy_pass http://web:3000;
       include proxy.conf;
       include proxy.conf;