Browse Source

fix: refactor conversation pagination to use SQLAlchemy session manag… (#11956)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- 4 months ago
parent
commit
3d07a94bd7

+ 12 - 8
api/controllers/console/explore/conversation.py

@@ -1,12 +1,14 @@
 from flask_login import current_user
 from flask_login import current_user
 from flask_restful import marshal_with, reqparse
 from flask_restful import marshal_with, reqparse
 from flask_restful.inputs import int_range
 from flask_restful.inputs import int_range
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound
 from werkzeug.exceptions import NotFound
 
 
 from controllers.console import api
 from controllers.console import api
 from controllers.console.explore.error import NotChatAppError
 from controllers.console.explore.error import NotChatAppError
 from controllers.console.explore.wraps import InstalledAppResource
 from controllers.console.explore.wraps import InstalledAppResource
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.app_invoke_entities import InvokeFrom
+from extensions.ext_database import db
 from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
 from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
 from libs.helper import uuid_value
 from libs.helper import uuid_value
 from models.model import AppMode
 from models.model import AppMode
@@ -34,14 +36,16 @@ class ConversationListApi(InstalledAppResource):
             pinned = True if args["pinned"] == "true" else False
             pinned = True if args["pinned"] == "true" else False
 
 
         try:
         try:
-            return WebConversationService.pagination_by_last_id(
-                app_model=app_model,
-                user=current_user,
-                last_id=args["last_id"],
-                limit=args["limit"],
-                invoke_from=InvokeFrom.EXPLORE,
-                pinned=pinned,
-            )
+            with Session(db.engine) as session:
+                return WebConversationService.pagination_by_last_id(
+                    session=session,
+                    app_model=app_model,
+                    user=current_user,
+                    last_id=args["last_id"],
+                    limit=args["limit"],
+                    invoke_from=InvokeFrom.EXPLORE,
+                    pinned=pinned,
+                )
         except LastConversationNotExistsError:
         except LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")
             raise NotFound("Last Conversation Not Exists.")
 
 

+ 12 - 8
api/controllers/service_api/app/conversation.py

@@ -1,5 +1,6 @@
 from flask_restful import Resource, marshal_with, reqparse
 from flask_restful import Resource, marshal_with, reqparse
 from flask_restful.inputs import int_range
 from flask_restful.inputs import int_range
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound
 from werkzeug.exceptions import NotFound
 
 
 import services
 import services
@@ -7,6 +8,7 @@ from controllers.service_api import api
 from controllers.service_api.app.error import NotChatAppError
 from controllers.service_api.app.error import NotChatAppError
 from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.app_invoke_entities import InvokeFrom
+from extensions.ext_database import db
 from fields.conversation_fields import (
 from fields.conversation_fields import (
     conversation_delete_fields,
     conversation_delete_fields,
     conversation_infinite_scroll_pagination_fields,
     conversation_infinite_scroll_pagination_fields,
@@ -39,14 +41,16 @@ class ConversationApi(Resource):
         args = parser.parse_args()
         args = parser.parse_args()
 
 
         try:
         try:
-            return ConversationService.pagination_by_last_id(
-                app_model=app_model,
-                user=end_user,
-                last_id=args["last_id"],
-                limit=args["limit"],
-                invoke_from=InvokeFrom.SERVICE_API,
-                sort_by=args["sort_by"],
-            )
+            with Session(db.engine) as session:
+                return ConversationService.pagination_by_last_id(
+                    session=session,
+                    app_model=app_model,
+                    user=end_user,
+                    last_id=args["last_id"],
+                    limit=args["limit"],
+                    invoke_from=InvokeFrom.SERVICE_API,
+                    sort_by=args["sort_by"],
+                )
         except services.errors.conversation.LastConversationNotExistsError:
         except services.errors.conversation.LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")
             raise NotFound("Last Conversation Not Exists.")
 
 

+ 13 - 9
api/controllers/web/conversation.py

@@ -1,11 +1,13 @@
 from flask_restful import marshal_with, reqparse
 from flask_restful import marshal_with, reqparse
 from flask_restful.inputs import int_range
 from flask_restful.inputs import int_range
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound
 from werkzeug.exceptions import NotFound
 
 
 from controllers.web import api
 from controllers.web import api
 from controllers.web.error import NotChatAppError
 from controllers.web.error import NotChatAppError
 from controllers.web.wraps import WebApiResource
 from controllers.web.wraps import WebApiResource
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.app_invoke_entities import InvokeFrom
+from extensions.ext_database import db
 from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
 from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
 from libs.helper import uuid_value
 from libs.helper import uuid_value
 from models.model import AppMode
 from models.model import AppMode
@@ -40,15 +42,17 @@ class ConversationListApi(WebApiResource):
             pinned = True if args["pinned"] == "true" else False
             pinned = True if args["pinned"] == "true" else False
 
 
         try:
         try:
-            return WebConversationService.pagination_by_last_id(
-                app_model=app_model,
-                user=end_user,
-                last_id=args["last_id"],
-                limit=args["limit"],
-                invoke_from=InvokeFrom.WEB_APP,
-                pinned=pinned,
-                sort_by=args["sort_by"],
-            )
+            with Session(db.engine) as session:
+                return WebConversationService.pagination_by_last_id(
+                    session=session,
+                    app_model=app_model,
+                    user=end_user,
+                    last_id=args["last_id"],
+                    limit=args["limit"],
+                    invoke_from=InvokeFrom.WEB_APP,
+                    pinned=pinned,
+                    sort_by=args["sort_by"],
+                )
         except LastConversationNotExistsError:
         except LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")
             raise NotFound("Last Conversation Not Exists.")
 
 

+ 2 - 1
api/models/web.py

@@ -1,4 +1,5 @@
 from sqlalchemy import func
 from sqlalchemy import func
+from sqlalchemy.orm import Mapped, mapped_column
 
 
 from .engine import db
 from .engine import db
 from .model import Message
 from .model import Message
@@ -33,7 +34,7 @@ class PinnedConversation(db.Model):
 
 
     id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     app_id = db.Column(StringUUID, nullable=False)
     app_id = db.Column(StringUUID, nullable=False)
-    conversation_id = db.Column(StringUUID, nullable=False)
+    conversation_id: Mapped[str] = mapped_column(StringUUID)
     created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
     created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
     created_by = db.Column(StringUUID, nullable=False)
     created_by = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

+ 27 - 23
api/services/conversation_service.py

@@ -1,8 +1,9 @@
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from datetime import UTC, datetime
 from datetime import UTC, datetime
 from typing import Optional, Union
 from typing import Optional, Union
 
 
-from sqlalchemy import asc, desc, or_
+from sqlalchemy import asc, desc, func, or_, select
+from sqlalchemy.orm import Session
 
 
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.llm_generator.llm_generator import LLMGenerator
 from core.llm_generator.llm_generator import LLMGenerator
@@ -18,19 +19,21 @@ class ConversationService:
     @classmethod
     @classmethod
     def pagination_by_last_id(
     def pagination_by_last_id(
         cls,
         cls,
+        *,
+        session: Session,
         app_model: App,
         app_model: App,
         user: Optional[Union[Account, EndUser]],
         user: Optional[Union[Account, EndUser]],
         last_id: Optional[str],
         last_id: Optional[str],
         limit: int,
         limit: int,
         invoke_from: InvokeFrom,
         invoke_from: InvokeFrom,
-        include_ids: Optional[list] = None,
-        exclude_ids: Optional[list] = None,
+        include_ids: Optional[Sequence[str]] = None,
+        exclude_ids: Optional[Sequence[str]] = None,
         sort_by: str = "-updated_at",
         sort_by: str = "-updated_at",
     ) -> InfiniteScrollPagination:
     ) -> InfiniteScrollPagination:
         if not user:
         if not user:
             return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
             return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
 
 
-        base_query = db.session.query(Conversation).filter(
+        stmt = select(Conversation).where(
             Conversation.is_deleted == False,
             Conversation.is_deleted == False,
             Conversation.app_id == app_model.id,
             Conversation.app_id == app_model.id,
             Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
             Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
@@ -38,37 +41,40 @@ class ConversationService:
             Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
             Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
             or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value),
             or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value),
         )
         )
-
         if include_ids is not None:
         if include_ids is not None:
-            base_query = base_query.filter(Conversation.id.in_(include_ids))
-
+            stmt = stmt.where(Conversation.id.in_(include_ids))
         if exclude_ids is not None:
         if exclude_ids is not None:
-            base_query = base_query.filter(~Conversation.id.in_(exclude_ids))
+            stmt = stmt.where(~Conversation.id.in_(exclude_ids))
 
 
         # define sort fields and directions
         # define sort fields and directions
         sort_field, sort_direction = cls._get_sort_params(sort_by)
         sort_field, sort_direction = cls._get_sort_params(sort_by)
 
 
         if last_id:
         if last_id:
-            last_conversation = base_query.filter(Conversation.id == last_id).first()
+            last_conversation = session.scalar(stmt.where(Conversation.id == last_id))
             if not last_conversation:
             if not last_conversation:
                 raise LastConversationNotExistsError()
                 raise LastConversationNotExistsError()
 
 
             # build filters based on sorting
             # build filters based on sorting
-            filter_condition = cls._build_filter_condition(sort_field, sort_direction, last_conversation)
-            base_query = base_query.filter(filter_condition)
-
-        base_query = base_query.order_by(sort_direction(getattr(Conversation, sort_field)))
-
-        conversations = base_query.limit(limit).all()
+            filter_condition = cls._build_filter_condition(
+                sort_field=sort_field,
+                sort_direction=sort_direction,
+                reference_conversation=last_conversation,
+            )
+            stmt = stmt.where(filter_condition)
+        query_stmt = stmt.order_by(sort_direction(getattr(Conversation, sort_field))).limit(limit)
+        conversations = session.scalars(query_stmt).all()
 
 
         has_more = False
         has_more = False
         if len(conversations) == limit:
         if len(conversations) == limit:
             current_page_last_conversation = conversations[-1]
             current_page_last_conversation = conversations[-1]
             rest_filter_condition = cls._build_filter_condition(
             rest_filter_condition = cls._build_filter_condition(
-                sort_field, sort_direction, current_page_last_conversation, is_next_page=True
+                sort_field=sort_field,
+                sort_direction=sort_direction,
+                reference_conversation=current_page_last_conversation,
             )
             )
-            rest_count = base_query.filter(rest_filter_condition).count()
-
+            count_stmt = stmt.where(rest_filter_condition)
+            count_stmt = select(func.count()).select_from(count_stmt.subquery())
+            rest_count = session.scalar(count_stmt) or 0
             if rest_count > 0:
             if rest_count > 0:
                 has_more = True
                 has_more = True
 
 
@@ -81,11 +87,9 @@ class ConversationService:
         return sort_by, asc
         return sort_by, asc
 
 
     @classmethod
     @classmethod
-    def _build_filter_condition(
-        cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation, is_next_page: bool = False
-    ):
+    def _build_filter_condition(cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation):
         field_value = getattr(reference_conversation, sort_field)
         field_value = getattr(reference_conversation, sort_field)
-        if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page):
+        if sort_direction == desc:
             return getattr(Conversation, sort_field) < field_value
             return getattr(Conversation, sort_field) < field_value
         else:
         else:
             return getattr(Conversation, sort_field) > field_value
             return getattr(Conversation, sort_field) > field_value

+ 12 - 6
api/services/web_conversation_service.py

@@ -1,5 +1,8 @@
 from typing import Optional, Union
 from typing import Optional, Union
 
 
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.app_invoke_entities import InvokeFrom
 from extensions.ext_database import db
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
@@ -13,6 +16,8 @@ class WebConversationService:
     @classmethod
     @classmethod
     def pagination_by_last_id(
     def pagination_by_last_id(
         cls,
         cls,
+        *,
+        session: Session,
         app_model: App,
         app_model: App,
         user: Optional[Union[Account, EndUser]],
         user: Optional[Union[Account, EndUser]],
         last_id: Optional[str],
         last_id: Optional[str],
@@ -23,24 +28,25 @@ class WebConversationService:
     ) -> InfiniteScrollPagination:
     ) -> InfiniteScrollPagination:
         include_ids = None
         include_ids = None
         exclude_ids = None
         exclude_ids = None
-        if pinned is not None:
-            pinned_conversations = (
-                db.session.query(PinnedConversation)
-                .filter(
+        if pinned is not None and user:
+            stmt = (
+                select(PinnedConversation.conversation_id)
+                .where(
                     PinnedConversation.app_id == app_model.id,
                     PinnedConversation.app_id == app_model.id,
                     PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
                     PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
                     PinnedConversation.created_by == user.id,
                     PinnedConversation.created_by == user.id,
                 )
                 )
                 .order_by(PinnedConversation.created_at.desc())
                 .order_by(PinnedConversation.created_at.desc())
-                .all()
             )
             )
-            pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations]
+            pinned_conversation_ids = session.scalars(stmt).all()
+
             if pinned:
             if pinned:
                 include_ids = pinned_conversation_ids
                 include_ids = pinned_conversation_ids
             else:
             else:
                 exclude_ids = pinned_conversation_ids
                 exclude_ids = pinned_conversation_ids
 
 
         return ConversationService.pagination_by_last_id(
         return ConversationService.pagination_by_last_id(
+            session=session,
             app_model=app_model,
             app_model=app_model,
             user=user,
             user=user,
             last_id=last_id,
             last_id=last_id,