hit_testing.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import logging
  2. from flask_login import current_user
  3. from libs.login import login_required
  4. from flask_restful import Resource, reqparse, marshal
  5. from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
  6. import services
  7. from controllers.console import api
  8. from controllers.console.app.error import ProviderNotInitializeError, ProviderQuotaExceededError, \
  9. ProviderModelCurrentlyNotSupportError
  10. from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
  11. from controllers.console.setup import setup_required
  12. from controllers.console.wraps import account_initialization_required
  13. from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
  14. LLMBadRequestError
  15. from fields.hit_testing_fields import hit_testing_record_fields
  16. from services.dataset_service import DatasetService
  17. from services.hit_testing_service import HitTestingService
  class HitTestingApi(Resource):
      """Console resource that runs a hit-testing (retrieval) query on a dataset."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, dataset_id):
          """Run a retrieval query against one dataset and return the hits.

          The JSON body supplies ``query`` and, optionally, ``retrieval_model``
          (a dict). The arguments are validated by
          ``HitTestingService.hit_testing_args_check`` before retrieval.

          Args:
              dataset_id: Identifier of the dataset; converted with ``str()``
                  before the lookup.

          Returns:
              A dict with the ``query`` echoed from the service response and
              ``records`` marshalled with ``hit_testing_record_fields``.

          Raises:
              NotFound: The dataset does not exist.
              Forbidden: The current user may not access the dataset.
              HighQualityDatasetOnlyError: The dataset's indexing technique is
                  not ``'high_quality'``.
              DatasetNotInitializedError: The dataset index is not initialized.
              ProviderNotInitializeError: The provider token is missing or the
                  model request was rejected.
              ProviderQuotaExceededError: The provider quota is exhausted.
              ProviderModelCurrentlyNotSupportError: The model is unsupported.
              ValueError: Retrieval raised a ``ValueError``; it is re-raised
                  with the same message.
              InternalServerError: Any other failure during retrieval.
          """
          dataset_id_str = str(dataset_id)

          dataset = DatasetService.get_dataset(dataset_id_str)
          if dataset is None:
              raise NotFound("Dataset not found.")

          try:
              DatasetService.check_dataset_permission(dataset, current_user)
          except services.errors.account.NoPermissionError as e:
              raise Forbidden(str(e)) from e

          # only high quality dataset can be used for hit testing
          if dataset.indexing_technique != 'high_quality':
              raise HighQualityDatasetOnlyError()

          parser = reqparse.RequestParser()
          parser.add_argument('query', type=str, location='json')
          parser.add_argument('retrieval_model', type=dict, required=False, location='json')
          args = parser.parse_args()

          HitTestingService.hit_testing_args_check(args)

          try:
              response = HitTestingService.retrieve(
                  dataset=dataset,
                  query=args['query'],
                  account=current_user,
                  retrieval_model=args['retrieval_model'],
                  limit=10
              )

              # Marshalling stays inside the try so that a failure here is also
              # logged and reported through the generic handler below.
              return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
          except services.errors.index.IndexNotInitializedError as ex:
              raise DatasetNotInitializedError() from ex
          except ProviderTokenNotInitError as ex:
              raise ProviderNotInitializeError(ex.description) from ex
          except QuotaExceededError as ex:
              raise ProviderQuotaExceededError() from ex
          except ModelCurrentlyNotSupportError as ex:
              raise ProviderModelCurrentlyNotSupportError() from ex
          except LLMBadRequestError as ex:
              raise ProviderNotInitializeError(
                  "No Embedding Model or Reranking Model available. Please configure a valid provider "
                  "in the Settings -> Model Provider.") from ex
          except ValueError as e:
              # Handled before the generic clause so a ValueError is not turned
              # into an InternalServerError; type and message are kept as-is.
              raise ValueError(str(e)) from e
          except Exception as e:
              logging.exception("Hit testing failed.")
              # NOTE(review): the raw exception text is sent to the client here;
              # confirm that this is intended.
              raise InternalServerError(str(e)) from e
  # Register the hit-testing resource; the route captures ``dataset_id``
  # through the ``uuid`` URL converter.
  api.add_resource(
      HitTestingApi,
      '/datasets/<uuid:dataset_id>/hit-testing',
  )