message.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. # -*- coding:utf-8 -*-
  2. import json
  3. import logging
  4. from typing import Generator, Union
  5. import services
  6. from controllers.web import api
  7. from controllers.web.error import (AppMoreLikeThisDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError,
  8. CompletionRequestError, NotChatAppError, NotCompletionAppError,
  9. ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError,
  10. ProviderQuotaExceededError)
  11. from controllers.web.wraps import WebApiResource
  12. from core.entities.application_entities import InvokeFrom
  13. from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
  14. from core.model_runtime.errors.invoke import InvokeError
  15. from fields.conversation_fields import message_file_fields
  16. from fields.message_fields import agent_thought_fields
  17. from flask import Response, stream_with_context
  18. from flask_restful import fields, marshal_with, reqparse
  19. from flask_restful.inputs import int_range
  20. from libs.helper import TimestampField, uuid_value
  21. from services.completion_service import CompletionService
  22. from services.errors.app import MoreLikeThisDisabledError
  23. from services.errors.conversation import ConversationNotExistsError
  24. from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
  25. from services.message_service import MessageService
  26. from werkzeug.exceptions import InternalServerError, NotFound
  27. class MessageListApi(WebApiResource):
  28. feedback_fields = {
  29. 'rating': fields.String
  30. }
  31. retriever_resource_fields = {
  32. 'id': fields.String,
  33. 'message_id': fields.String,
  34. 'position': fields.Integer,
  35. 'dataset_id': fields.String,
  36. 'dataset_name': fields.String,
  37. 'document_id': fields.String,
  38. 'document_name': fields.String,
  39. 'data_source_type': fields.String,
  40. 'segment_id': fields.String,
  41. 'score': fields.Float,
  42. 'hit_count': fields.Integer,
  43. 'word_count': fields.Integer,
  44. 'segment_position': fields.Integer,
  45. 'index_node_hash': fields.String,
  46. 'content': fields.String,
  47. 'created_at': TimestampField
  48. }
  49. message_fields = {
  50. 'id': fields.String,
  51. 'conversation_id': fields.String,
  52. 'inputs': fields.Raw,
  53. 'query': fields.String,
  54. 'answer': fields.String,
  55. 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
  56. 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
  57. 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
  58. 'created_at': TimestampField,
  59. 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
  60. }
  61. message_infinite_scroll_pagination_fields = {
  62. 'limit': fields.Integer,
  63. 'has_more': fields.Boolean,
  64. 'data': fields.List(fields.Nested(message_fields))
  65. }
  66. @marshal_with(message_infinite_scroll_pagination_fields)
  67. def get(self, app_model, end_user):
  68. if app_model.mode != 'chat':
  69. raise NotChatAppError()
  70. parser = reqparse.RequestParser()
  71. parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
  72. parser.add_argument('first_id', type=uuid_value, location='args')
  73. parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
  74. args = parser.parse_args()
  75. try:
  76. return MessageService.pagination_by_first_id(app_model, end_user,
  77. args['conversation_id'], args['first_id'], args['limit'])
  78. except services.errors.conversation.ConversationNotExistsError:
  79. raise NotFound("Conversation Not Exists.")
  80. except services.errors.message.FirstMessageNotExistsError:
  81. raise NotFound("First Message Not Exists.")
  82. class MessageFeedbackApi(WebApiResource):
  83. def post(self, app_model, end_user, message_id):
  84. message_id = str(message_id)
  85. parser = reqparse.RequestParser()
  86. parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
  87. args = parser.parse_args()
  88. try:
  89. MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
  90. except services.errors.message.MessageNotExistsError:
  91. raise NotFound("Message Not Exists.")
  92. return {'result': 'success'}
  93. class MessageMoreLikeThisApi(WebApiResource):
  94. def get(self, app_model, end_user, message_id):
  95. if app_model.mode != 'completion':
  96. raise NotCompletionAppError()
  97. message_id = str(message_id)
  98. parser = reqparse.RequestParser()
  99. parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
  100. args = parser.parse_args()
  101. streaming = args['response_mode'] == 'streaming'
  102. try:
  103. response = CompletionService.generate_more_like_this(
  104. app_model=app_model,
  105. user=end_user,
  106. message_id=message_id,
  107. invoke_from=InvokeFrom.WEB_APP,
  108. streaming=streaming
  109. )
  110. return compact_response(response)
  111. except MessageNotExistsError:
  112. raise NotFound("Message Not Exists.")
  113. except MoreLikeThisDisabledError:
  114. raise AppMoreLikeThisDisabledError()
  115. except ProviderTokenNotInitError as ex:
  116. raise ProviderNotInitializeError(ex.description)
  117. except QuotaExceededError:
  118. raise ProviderQuotaExceededError()
  119. except ModelCurrentlyNotSupportError:
  120. raise ProviderModelCurrentlyNotSupportError()
  121. except InvokeError as e:
  122. raise CompletionRequestError(e.description)
  123. except ValueError as e:
  124. raise e
  125. except Exception:
  126. logging.exception("internal server error.")
  127. raise InternalServerError()
  128. def compact_response(response: Union[dict, Generator]) -> Response:
  129. if isinstance(response, dict):
  130. return Response(response=json.dumps(response), status=200, mimetype='application/json')
  131. else:
  132. def generate() -> Generator:
  133. for chunk in response:
  134. yield chunk
  135. return Response(stream_with_context(generate()), status=200,
  136. mimetype='text/event-stream')
  137. class MessageSuggestedQuestionApi(WebApiResource):
  138. def get(self, app_model, end_user, message_id):
  139. if app_model.mode != 'chat':
  140. raise NotCompletionAppError()
  141. message_id = str(message_id)
  142. try:
  143. questions = MessageService.get_suggested_questions_after_answer(
  144. app_model=app_model,
  145. user=end_user,
  146. message_id=message_id
  147. )
  148. except MessageNotExistsError:
  149. raise NotFound("Message not found")
  150. except ConversationNotExistsError:
  151. raise NotFound("Conversation not found")
  152. except SuggestedQuestionsAfterAnswerDisabledError:
  153. raise AppSuggestedQuestionsAfterAnswerDisabledError()
  154. except ProviderTokenNotInitError as ex:
  155. raise ProviderNotInitializeError(ex.description)
  156. except QuotaExceededError:
  157. raise ProviderQuotaExceededError()
  158. except ModelCurrentlyNotSupportError:
  159. raise ProviderModelCurrentlyNotSupportError()
  160. except InvokeError as e:
  161. raise CompletionRequestError(e.description)
  162. except Exception:
  163. logging.exception("internal server error.")
  164. raise InternalServerError()
  165. return {'data': questions}
  166. api.add_resource(MessageListApi, '/messages')
  167. api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks')
  168. api.add_resource(MessageMoreLikeThisApi, '/messages/<uuid:message_id>/more-like-this')
  169. api.add_resource(MessageSuggestedQuestionApi, '/messages/<uuid:message_id>/suggested-questions')