# web_conversation_service.py (3.8 KB)
  1. from typing import Optional, Union
  2. from sqlalchemy import select
  3. from sqlalchemy.orm import Session
  4. from core.app.entities.app_invoke_entities import InvokeFrom
  5. from extensions.ext_database import db
  6. from libs.infinite_scroll_pagination import InfiniteScrollPagination
  7. from models.account import Account
  8. from models.model import App, EndUser
  9. from models.web import PinnedConversation
  10. from services.conversation_service import ConversationService
  11. class WebConversationService:
  12. @classmethod
  13. def pagination_by_last_id(
  14. cls,
  15. *,
  16. session: Session,
  17. app_model: App,
  18. user: Optional[Union[Account, EndUser]],
  19. last_id: Optional[str],
  20. limit: int,
  21. invoke_from: InvokeFrom,
  22. pinned: Optional[bool] = None,
  23. sort_by="-updated_at",
  24. ) -> InfiniteScrollPagination:
  25. if not user:
  26. raise ValueError("User is required")
  27. include_ids = None
  28. exclude_ids = None
  29. if pinned is not None and user:
  30. stmt = (
  31. select(PinnedConversation.conversation_id)
  32. .where(
  33. PinnedConversation.app_id == app_model.id,
  34. PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
  35. PinnedConversation.created_by == user.id,
  36. )
  37. .order_by(PinnedConversation.created_at.desc())
  38. )
  39. pinned_conversation_ids = session.scalars(stmt).all()
  40. if pinned:
  41. include_ids = pinned_conversation_ids
  42. else:
  43. exclude_ids = pinned_conversation_ids
  44. return ConversationService.pagination_by_last_id(
  45. session=session,
  46. app_model=app_model,
  47. user=user,
  48. last_id=last_id,
  49. limit=limit,
  50. invoke_from=invoke_from,
  51. include_ids=include_ids,
  52. exclude_ids=exclude_ids,
  53. sort_by=sort_by,
  54. )
  55. @classmethod
  56. def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
  57. if not user:
  58. return
  59. pinned_conversation = (
  60. db.session.query(PinnedConversation)
  61. .filter(
  62. PinnedConversation.app_id == app_model.id,
  63. PinnedConversation.conversation_id == conversation_id,
  64. PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
  65. PinnedConversation.created_by == user.id,
  66. )
  67. .first()
  68. )
  69. if pinned_conversation:
  70. return
  71. conversation = ConversationService.get_conversation(
  72. app_model=app_model, conversation_id=conversation_id, user=user
  73. )
  74. pinned_conversation = PinnedConversation(
  75. app_id=app_model.id,
  76. conversation_id=conversation.id,
  77. created_by_role="account" if isinstance(user, Account) else "end_user",
  78. created_by=user.id,
  79. )
  80. db.session.add(pinned_conversation)
  81. db.session.commit()
  82. @classmethod
  83. def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
  84. if not user:
  85. return
  86. pinned_conversation = (
  87. db.session.query(PinnedConversation)
  88. .filter(
  89. PinnedConversation.app_id == app_model.id,
  90. PinnedConversation.conversation_id == conversation_id,
  91. PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
  92. PinnedConversation.created_by == user.id,
  93. )
  94. .first()
  95. )
  96. if not pinned_conversation:
  97. return
  98. db.session.delete(pinned_conversation)
  99. db.session.commit()