Browse Source

fix: missing default user for APP service api (#2606)

takatost 1 năm trước cách đây
mục cha
commit
0828873b52

+ 0 - 27
api/controllers/service_api/app/__init__.py

@@ -1,27 +0,0 @@
-from extensions.ext_database import db
-from models.model import EndUser
-
-
-def create_or_update_end_user_for_user_id(app_model, user_id):
-    """
-    Create or update session terminal based on user ID.
-    """
-    end_user = db.session.query(EndUser) \
-        .filter(
-        EndUser.tenant_id == app_model.tenant_id,
-        EndUser.session_id == user_id,
-        EndUser.type == 'service_api'
-    ).first()
-
-    if end_user is None:
-        end_user = EndUser(
-            tenant_id=app_model.tenant_id,
-            app_id=app_model.id,
-            type='service_api',
-            is_anonymous=True,
-            session_id=user_id
-        )
-        db.session.add(end_user)
-        db.session.commit()
-
-    return end_user

+ 8 - 6
api/controllers/service_api/app/app.py

@@ -1,16 +1,16 @@
 import json
 
 from flask import current_app
-from flask_restful import fields, marshal_with
+from flask_restful import fields, marshal_with, Resource
 
 from controllers.service_api import api
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import validate_app_token
 from extensions.ext_database import db
 from models.model import App, AppModelConfig
 from models.tools import ApiToolProvider
 
 
-class AppParameterApi(AppApiResource):
+class AppParameterApi(Resource):
     """Resource for app variables."""
 
     variable_fields = {
@@ -42,8 +42,9 @@ class AppParameterApi(AppApiResource):
         'system_parameters': fields.Nested(system_parameters_fields)
     }
 
+    @validate_app_token
     @marshal_with(parameters_fields)
-    def get(self, app_model: App, end_user):
+    def get(self, app_model: App):
         """Retrieve app parameters."""
         app_model_config = app_model.app_model_config
 
@@ -64,8 +65,9 @@ class AppParameterApi(AppApiResource):
             }
         }
 
-class AppMetaApi(AppApiResource):
-    def get(self, app_model: App, end_user):
+class AppMetaApi(Resource):
+    @validate_app_token
+    def get(self, app_model: App):
         """Get app meta"""
         app_model_config: AppModelConfig = app_model.app_model_config
 

+ 10 - 9
api/controllers/service_api/app/audio.py

@@ -1,7 +1,7 @@
 import logging
 
 from flask import request
-from flask_restful import reqparse
+from flask_restful import Resource, reqparse
 from werkzeug.exceptions import InternalServerError
 
 import services
@@ -17,10 +17,10 @@ from controllers.service_api.app.error import (
     ProviderQuotaExceededError,
     UnsupportedAudioTypeError,
 )
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
-from models.model import App, AppModelConfig
+from models.model import App, AppModelConfig, EndUser
 from services.audio_service import AudioService
 from services.errors.audio import (
     AudioTooLargeServiceError,
@@ -30,8 +30,9 @@ from services.errors.audio import (
 )
 
 
-class AudioApi(AppApiResource):
-    def post(self, app_model: App, end_user):
+class AudioApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
+    def post(self, app_model: App, end_user: EndUser):
         app_model_config: AppModelConfig = app_model.app_model_config
 
         if not app_model_config.speech_to_text_dict['enabled']:
@@ -73,11 +74,11 @@ class AudioApi(AppApiResource):
             raise InternalServerError()
 
 
-class TextApi(AppApiResource):
-    def post(self, app_model: App, end_user):
+class TextApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser):
         parser = reqparse.RequestParser()
         parser.add_argument('text', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('user', type=str, required=True, nullable=False, location='json')
         parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
         args = parser.parse_args()
 
@@ -85,7 +86,7 @@ class TextApi(AppApiResource):
             response = AudioService.transcript_tts(
                 tenant_id=app_model.tenant_id,
                 text=args['text'],
-                end_user=args['user'],
+                end_user=end_user,
                 voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
                 streaming=args['streaming']
             )

+ 15 - 41
api/controllers/service_api/app/completion.py

@@ -4,12 +4,11 @@ from collections.abc import Generator
 from typing import Union
 
 from flask import Response, stream_with_context
-from flask_restful import reqparse
+from flask_restful import Resource, reqparse
 from werkzeug.exceptions import InternalServerError, NotFound
 
 import services
 from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
 from controllers.service_api.app.error import (
     AppUnavailableError,
     CompletionRequestError,
@@ -19,17 +18,19 @@ from controllers.service_api.app.error import (
     ProviderNotInitializeError,
     ProviderQuotaExceededError,
 )
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.application_queue_manager import ApplicationQueueManager
 from core.entities.application_entities import InvokeFrom
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
 from libs.helper import uuid_value
+from models.model import App, EndUser
 from services.completion_service import CompletionService
 
 
-class CompletionApi(AppApiResource):
-    def post(self, app_model, end_user):
+class CompletionApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser):
         if app_model.mode != 'completion':
             raise AppUnavailableError()
 
@@ -38,16 +39,12 @@ class CompletionApi(AppApiResource):
         parser.add_argument('query', type=str, location='json', default='')
         parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
-        parser.add_argument('user', required=True, nullable=False, type=str, location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
 
         args = parser.parse_args()
 
         streaming = args['response_mode'] == 'streaming'
 
-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         args['auto_generate_name'] = False
 
         try:
@@ -82,29 +79,20 @@ class CompletionApi(AppApiResource):
             raise InternalServerError()
 
 
-class CompletionStopApi(AppApiResource):
-    def post(self, app_model, end_user, task_id):
+class CompletionStopApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser, task_id):
         if app_model.mode != 'completion':
             raise AppUnavailableError()
 
-        if end_user is None:
-            parser = reqparse.RequestParser()
-            parser.add_argument('user', required=True, nullable=False, type=str, location='json')
-            args = parser.parse_args()
-
-            user = args.get('user')
-            if user is not None:
-                end_user = create_or_update_end_user_for_user_id(app_model, user)
-            else:
-                raise ValueError("arg user muse be input.")
-
         ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
 
         return {'result': 'success'}, 200
 
 
-class ChatApi(AppApiResource):
-    def post(self, app_model, end_user):
+class ChatApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser):
         if app_model.mode != 'chat':
             raise NotChatAppError()
 
@@ -114,7 +102,6 @@ class ChatApi(AppApiResource):
         parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
-        parser.add_argument('user', type=str, required=True, nullable=False, location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
         parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')
 
@@ -122,9 +109,6 @@ class ChatApi(AppApiResource):
 
         streaming = args['response_mode'] == 'streaming'
 
-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             response = CompletionService.completion(
                 app_model=app_model,
@@ -157,22 +141,12 @@ class ChatApi(AppApiResource):
             raise InternalServerError()
 
 
-class ChatStopApi(AppApiResource):
-    def post(self, app_model, end_user, task_id):
+class ChatStopApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser, task_id):
         if app_model.mode != 'chat':
             raise NotChatAppError()
 
-        if end_user is None:
-            parser = reqparse.RequestParser()
-            parser.add_argument('user', required=True, nullable=False, type=str, location='json')
-            args = parser.parse_args()
-
-            user = args.get('user')
-            if user is not None:
-                end_user = create_or_update_end_user_for_user_id(app_model, user)
-            else:
-                raise ValueError("arg user muse be input.")
-
         ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
 
         return {'result': 'success'}, 200

+ 12 - 23
api/controllers/service_api/app/conversation.py

@@ -1,52 +1,44 @@
-from flask import request
-from flask_restful import marshal_with, reqparse
+from flask_restful import Resource, marshal_with, reqparse
 from flask_restful.inputs import int_range
 from werkzeug.exceptions import NotFound
 
 import services
 from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
 from controllers.service_api.app.error import NotChatAppError
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
 from libs.helper import uuid_value
+from models.model import App, EndUser
 from services.conversation_service import ConversationService
 
 
-class ConversationApi(AppApiResource):
+class ConversationApi(Resource):
 
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
     @marshal_with(conversation_infinite_scroll_pagination_fields)
-    def get(self, app_model, end_user):
+    def get(self, app_model: App, end_user: EndUser):
         if app_model.mode != 'chat':
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
         parser.add_argument('last_id', type=uuid_value, location='args')
         parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
-        parser.add_argument('user', type=str, location='args')
         args = parser.parse_args()
 
-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit'])
         except services.errors.conversation.LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")
 
-class ConversationDetailApi(AppApiResource):
+class ConversationDetailApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
     @marshal_with(simple_conversation_fields)
-    def delete(self, app_model, end_user, c_id):
+    def delete(self, app_model: App, end_user: EndUser, c_id):
         if app_model.mode != 'chat':
             raise NotChatAppError()
 
         conversation_id = str(c_id)
 
-        user = request.get_json().get('user')
-
-        if end_user is None and user is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, user)
-
         try:
             ConversationService.delete(app_model, conversation_id, end_user)
         except services.errors.conversation.ConversationNotExistsError:
@@ -54,10 +46,11 @@ class ConversationDetailApi(AppApiResource):
         return {"result": "success"}, 204
 
 
-class ConversationRenameApi(AppApiResource):
+class ConversationRenameApi(Resource):
 
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
     @marshal_with(simple_conversation_fields)
-    def post(self, app_model, end_user, c_id):
+    def post(self, app_model: App, end_user: EndUser, c_id):
         if app_model.mode != 'chat':
             raise NotChatAppError()
 
@@ -65,13 +58,9 @@ class ConversationRenameApi(AppApiResource):
 
         parser = reqparse.RequestParser()
         parser.add_argument('name', type=str, required=False, location='json')
-        parser.add_argument('user', type=str, location='json')
         parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
         args = parser.parse_args()
 
-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             return ConversationService.rename(
                 app_model,

+ 6 - 9
api/controllers/service_api/app/file.py

@@ -1,30 +1,27 @@
 from flask import request
-from flask_restful import marshal_with
+from flask_restful import Resource, marshal_with
 
 import services
 from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
 from controllers.service_api.app.error import (
     FileTooLargeError,
     NoFileUploadedError,
     TooManyFilesError,
     UnsupportedFileTypeError,
 )
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from fields.file_fields import file_fields
+from models.model import App, EndUser
 from services.file_service import FileService
 
 
-class FileApi(AppApiResource):
+class FileApi(Resource):
 
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
     @marshal_with(file_fields)
-    def post(self, app_model, end_user):
+    def post(self, app_model: App, end_user: EndUser):
 
         file = request.files['file']
-        user_args = request.form.get('user')
-
-        if end_user is None and user_args is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, user_args)
 
         # check file
         if 'file' not in request.files:

+ 14 - 34
api/controllers/service_api/app/message.py

@@ -1,20 +1,18 @@
-from flask_restful import fields, marshal_with, reqparse
+from flask_restful import Resource, fields, marshal_with, reqparse
 from flask_restful.inputs import int_range
 from werkzeug.exceptions import NotFound
 
 import services
 from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
 from controllers.service_api.app.error import NotChatAppError
-from controllers.service_api.wraps import AppApiResource
-from extensions.ext_database import db
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from fields.conversation_fields import message_file_fields
 from libs.helper import TimestampField, uuid_value
-from models.model import EndUser, Message
+from models.model import App, EndUser
 from services.message_service import MessageService
 
 
-class MessageListApi(AppApiResource):
+class MessageListApi(Resource):
     feedback_fields = {
         'rating': fields.String
     }
@@ -70,8 +68,9 @@ class MessageListApi(AppApiResource):
         'data': fields.List(fields.Nested(message_fields))
     }
 
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
     @marshal_with(message_infinite_scroll_pagination_fields)
-    def get(self, app_model, end_user):
+    def get(self, app_model: App, end_user: EndUser):
         if app_model.mode != 'chat':
             raise NotChatAppError()
 
@@ -79,12 +78,8 @@ class MessageListApi(AppApiResource):
         parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
         parser.add_argument('first_id', type=uuid_value, location='args')
         parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
-        parser.add_argument('user', type=str, location='args')
         args = parser.parse_args()
 
-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             return MessageService.pagination_by_first_id(app_model, end_user,
                                                          args['conversation_id'], args['first_id'], args['limit'])
@@ -94,18 +89,15 @@ class MessageListApi(AppApiResource):
             raise NotFound("First Message Not Exists.")
 
 
-class MessageFeedbackApi(AppApiResource):
-    def post(self, app_model, end_user, message_id):
+class MessageFeedbackApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
+    def post(self, app_model: App, end_user: EndUser, message_id):
         message_id = str(message_id)
 
         parser = reqparse.RequestParser()
         parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
-        parser.add_argument('user', type=str, location='json')
         args = parser.parse_args()
 
-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
         except services.errors.message.MessageNotExistsError:
@@ -114,29 +106,17 @@ class MessageFeedbackApi(AppApiResource):
         return {'result': 'success'}
 
 
-class MessageSuggestedApi(AppApiResource):
-    def get(self, app_model, end_user, message_id):
+class MessageSuggestedApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
+    def get(self, app_model: App, end_user: EndUser, message_id):
         message_id = str(message_id)
         if app_model.mode != 'chat':
             raise NotChatAppError()
+
         try:
-            message = db.session.query(Message).filter(
-                Message.id == message_id,
-                Message.app_id == app_model.id,
-            ).first()
-
-            if end_user is None and message.from_end_user_id is not None:
-                user = db.session.query(EndUser) \
-                    .filter(
-                        EndUser.tenant_id == app_model.tenant_id,
-                        EndUser.id == message.from_end_user_id,
-                        EndUser.type == 'service_api'
-                    ).first()
-            else:
-                user = end_user
             questions = MessageService.get_suggested_questions_after_answer(
                 app_model=app_model,
-                user=user,
+                user=end_user,
                 message_id=message_id,
                 check_enabled=False
             )

+ 76 - 14
api/controllers/service_api/wraps.py

@@ -1,22 +1,40 @@
+from collections.abc import Callable
 from datetime import datetime
+from enum import Enum
 from functools import wraps
+from typing import Optional
 
 from flask import current_app, request
 from flask_login import user_logged_in
 from flask_restful import Resource
+from pydantic import BaseModel
 from werkzeug.exceptions import NotFound, Unauthorized
 
 from extensions.ext_database import db
 from libs.login import _get_user
 from models.account import Account, Tenant, TenantAccountJoin
-from models.model import ApiToken, App
+from models.model import ApiToken, App, EndUser
 from services.feature_service import FeatureService
 
 
-def validate_app_token(view=None):
-    def decorator(view):
-        @wraps(view)
-        def decorated(*args, **kwargs):
+class WhereisUserArg(Enum):
+    """
+    Enum for whereis_user_arg.
+    """
+    QUERY = 'query'
+    JSON = 'json'
+    FORM = 'form'
+
+
+class FetchUserArg(BaseModel):
+    fetch_from: WhereisUserArg
+    required: bool = False
+
+
+def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
+    def decorator(view_func):
+        @wraps(view_func)
+        def decorated_view(*args, **kwargs):
             api_token = validate_and_get_api_token('app')
 
             app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
@@ -29,15 +47,34 @@ def validate_app_token(view=None):
             if not app_model.enable_api:
                 raise NotFound()
 
-            return view(app_model, None, *args, **kwargs)
-        return decorated
+            kwargs['app_model'] = app_model
 
-    if view:
-        return decorator(view)
+            if not fetch_user_arg:
+                # use default-user
+                user_id = None
+            else:
+                if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
+                    user_id = request.args.get('user')
+                elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
+                    user_id = request.get_json().get('user')
+                elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
+                    user_id = request.form.get('user')
+                else:
+                    # use default-user
+                    user_id = None
 
-    # if view is None, it means that the decorator is used without parentheses
-    # use the decorator as a function for method_decorators
-    return decorator
+                if not user_id and fetch_user_arg.required:
+                    raise ValueError("Arg user must be provided.")
+
+            kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id)
+
+            return view_func(*args, **kwargs)
+        return decorated_view
+
+    if view is None:
+        return decorator
+    else:
+        return decorator(view)
 
 
 def cloud_edition_billing_resource_check(resource: str,
@@ -128,8 +165,33 @@ def validate_and_get_api_token(scope=None):
     return api_token
 
 
-class AppApiResource(Resource):
-    method_decorators = [validate_app_token]
+def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser:
+    """
+    Create or update session terminal based on user ID.
+    """
+    if not user_id:
+        user_id = 'DEFAULT-USER'
+
+    end_user = db.session.query(EndUser) \
+        .filter(
+        EndUser.tenant_id == app_model.tenant_id,
+        EndUser.app_id == app_model.id,
+        EndUser.session_id == user_id,
+        EndUser.type == 'service_api'
+    ).first()
+
+    if end_user is None:
+        end_user = EndUser(
+            tenant_id=app_model.tenant_id,
+            app_id=app_model.id,
+            type='service_api',
+            is_anonymous=True if user_id == 'DEFAULT-USER' else False,
+            session_id=user_id
+        )
+        db.session.add(end_user)
+        db.session.commit()
+
+    return end_user
 
 
 class DatasetApiResource(Resource):