# hit_testing.py (3.3 KB)
  1. import logging
  2. from flask_login import current_user
  3. from core.model_runtime.errors.invoke import InvokeError
  4. from libs.login import login_required
  5. from flask_restful import Resource, reqparse, marshal
  6. from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
  7. import services
  8. from controllers.console import api
  9. from controllers.console.app.error import ProviderNotInitializeError, ProviderQuotaExceededError, \
  10. ProviderModelCurrentlyNotSupportError, CompletionRequestError
  11. from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
  12. from controllers.console.setup import setup_required
  13. from controllers.console.wraps import account_initialization_required
  14. from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
  15. LLMBadRequestError
  16. from fields.hit_testing_fields import hit_testing_record_fields
  17. from services.dataset_service import DatasetService
  18. from services.hit_testing_service import HitTestingService
  class HitTestingApi(Resource):
      """REST resource that runs a hit-testing retrieval query against one dataset."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, dataset_id):
          """Run a hit-testing query against the dataset identified by ``dataset_id``.

          The JSON body is parsed for ``query`` (string) and an optional
          ``retrieval_model`` (dict). The parsed arguments are handed to
          ``HitTestingService.hit_testing_args_check`` before retrieval.

          Returns:
              A dict with the ``query`` value taken from the service response and
              ``records`` marshalled through ``hit_testing_record_fields``.

          Raises:
              NotFound: no dataset exists for ``dataset_id``.
              Forbidden: the current user fails the dataset permission check.
              HighQualityDatasetOnlyError: the dataset's indexing technique is not
                  ``'high_quality'``.
              DatasetNotInitializedError: the service reports an uninitialized index.
              ProviderNotInitializeError: the provider token is missing, or the
                  service raised ``LLMBadRequestError``.
              ProviderQuotaExceededError: the provider quota is exhausted.
              ProviderModelCurrentlyNotSupportError: the model is not supported.
              CompletionRequestError: the model runtime raised ``InvokeError``.
              ValueError: the service raised ``ValueError`` during retrieval.
              InternalServerError: any other failure during retrieval.
          """
          dataset = DatasetService.get_dataset(str(dataset_id))
          if dataset is None:
              raise NotFound("Dataset not found.")

          try:
              DatasetService.check_dataset_permission(dataset, current_user)
          except services.errors.account.NoPermissionError as e:
              raise Forbidden(str(e)) from e

          # only high quality dataset can be used for hit testing
          if dataset.indexing_technique != 'high_quality':
              raise HighQualityDatasetOnlyError()

          parser = reqparse.RequestParser()
          parser.add_argument('query', type=str, location='json')
          parser.add_argument('retrieval_model', type=dict, required=False, location='json')
          args = parser.parse_args()

          HitTestingService.hit_testing_args_check(args)

          # Marshalling stays inside the try block so that a failure while
          # serialising records is still reported through the handlers below.
          try:
              response = HitTestingService.retrieve(
                  dataset=dataset,
                  query=args['query'],
                  account=current_user,
                  retrieval_model=args['retrieval_model'],
                  limit=10
              )
              return {
                  "query": response['query'],
                  'records': marshal(response['records'], hit_testing_record_fields),
              }
          except services.errors.index.IndexNotInitializedError as e:
              raise DatasetNotInitializedError() from e
          except ProviderTokenNotInitError as ex:
              raise ProviderNotInitializeError(ex.description) from ex
          except QuotaExceededError as e:
              raise ProviderQuotaExceededError() from e
          except ModelCurrentlyNotSupportError as e:
              raise ProviderModelCurrentlyNotSupportError() from e
          except LLMBadRequestError as e:
              raise ProviderNotInitializeError(
                  "No Embedding Model or Reranking Model available. Please configure a valid provider "
                  "in the Settings -> Model Provider.") from e
          except InvokeError as e:
              raise CompletionRequestError(e.description) from e
          except ValueError as e:
              # Handled before the generic clause so a ValueError is not turned
              # into an InternalServerError; type and message are kept as before.
              raise ValueError(str(e)) from e
          except Exception as e:
              logging.exception("Hit testing failed.")
              raise InternalServerError(str(e)) from e
  # Register the hit-testing resource on the imported `api` object; the route
  # converts the `dataset_id` path segment with the `uuid` converter.
  api.add_resource(
      HitTestingApi,
      '/datasets/<uuid:dataset_id>/hit-testing',
  )