hit_testing.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import logging
  2. from flask_login import login_required, current_user
  3. from flask_restful import Resource, reqparse, marshal, fields
  4. from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
  5. import services
  6. from controllers.console import api
  7. from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
  8. from controllers.console.setup import setup_required
  9. from controllers.console.wraps import account_initialization_required
  10. from libs.helper import TimestampField
  11. from services.dataset_service import DatasetService
  12. from services.hit_testing_service import HitTestingService
  13. document_fields = {
  14. 'id': fields.String,
  15. 'data_source_type': fields.String,
  16. 'name': fields.String,
  17. 'doc_type': fields.String,
  18. }
  19. segment_fields = {
  20. 'id': fields.String,
  21. 'position': fields.Integer,
  22. 'document_id': fields.String,
  23. 'content': fields.String,
  24. 'word_count': fields.Integer,
  25. 'tokens': fields.Integer,
  26. 'keywords': fields.List(fields.String),
  27. 'index_node_id': fields.String,
  28. 'index_node_hash': fields.String,
  29. 'hit_count': fields.Integer,
  30. 'enabled': fields.Boolean,
  31. 'disabled_at': TimestampField,
  32. 'disabled_by': fields.String,
  33. 'status': fields.String,
  34. 'created_by': fields.String,
  35. 'created_at': TimestampField,
  36. 'indexing_at': TimestampField,
  37. 'completed_at': TimestampField,
  38. 'error': fields.String,
  39. 'stopped_at': TimestampField,
  40. 'document': fields.Nested(document_fields),
  41. }
  42. hit_testing_record_fields = {
  43. 'segment': fields.Nested(segment_fields),
  44. 'score': fields.Float,
  45. 'tsne_position': fields.Raw
  46. }
  47. class HitTestingApi(Resource):
  48. @setup_required
  49. @login_required
  50. @account_initialization_required
  51. def post(self, dataset_id):
  52. dataset_id_str = str(dataset_id)
  53. dataset = DatasetService.get_dataset(dataset_id_str)
  54. if dataset is None:
  55. raise NotFound("Dataset not found.")
  56. try:
  57. DatasetService.check_dataset_permission(dataset, current_user)
  58. except services.errors.account.NoPermissionError as e:
  59. raise Forbidden(str(e))
  60. # only high quality dataset can be used for hit testing
  61. if dataset.indexing_technique != 'high_quality':
  62. raise HighQualityDatasetOnlyError()
  63. parser = reqparse.RequestParser()
  64. parser.add_argument('query', type=str, location='json')
  65. args = parser.parse_args()
  66. query = args['query']
  67. if not query or len(query) > 250:
  68. raise ValueError('Query is required and cannot exceed 250 characters')
  69. try:
  70. response = HitTestingService.retrieve(
  71. dataset=dataset,
  72. query=query,
  73. account=current_user,
  74. limit=10,
  75. )
  76. return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
  77. except services.errors.index.IndexNotInitializedError:
  78. raise DatasetNotInitializedError()
  79. except Exception as e:
  80. logging.exception("Hit testing failed.")
  81. raise InternalServerError(str(e))
  82. api.add_resource(HitTestingApi, '/datasets/<uuid:dataset_id>/hit-testing')