# hit_testing.py (3.1 KB)

  1. import logging
  2. from flask_login import current_user
  3. from libs.login import login_required
  4. from flask_restful import Resource, reqparse, marshal
  5. from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
  6. import services
  7. from controllers.console import api
  8. from controllers.console.app.error import ProviderNotInitializeError, ProviderQuotaExceededError, \
  9. ProviderModelCurrentlyNotSupportError
  10. from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
  11. from controllers.console.setup import setup_required
  12. from controllers.console.wraps import account_initialization_required
  13. from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
  14. LLMBadRequestError
  15. from fields.hit_testing_fields import hit_testing_record_fields
  16. from services.dataset_service import DatasetService
  17. from services.hit_testing_service import HitTestingService
  18. class HitTestingApi(Resource):
  19. @setup_required
  20. @login_required
  21. @account_initialization_required
  22. def post(self, dataset_id):
  23. dataset_id_str = str(dataset_id)
  24. dataset = DatasetService.get_dataset(dataset_id_str)
  25. if dataset is None:
  26. raise NotFound("Dataset not found.")
  27. try:
  28. DatasetService.check_dataset_permission(dataset, current_user)
  29. except services.errors.account.NoPermissionError as e:
  30. raise Forbidden(str(e))
  31. # only high quality dataset can be used for hit testing
  32. if dataset.indexing_technique != 'high_quality':
  33. raise HighQualityDatasetOnlyError()
  34. parser = reqparse.RequestParser()
  35. parser.add_argument('query', type=str, location='json')
  36. args = parser.parse_args()
  37. query = args['query']
  38. if not query or len(query) > 250:
  39. raise ValueError('Query is required and cannot exceed 250 characters')
  40. try:
  41. response = HitTestingService.retrieve(
  42. dataset=dataset,
  43. query=query,
  44. account=current_user,
  45. limit=10,
  46. )
  47. return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
  48. except services.errors.index.IndexNotInitializedError:
  49. raise DatasetNotInitializedError()
  50. except ProviderTokenNotInitError as ex:
  51. raise ProviderNotInitializeError(ex.description)
  52. except QuotaExceededError:
  53. raise ProviderQuotaExceededError()
  54. except ModelCurrentlyNotSupportError:
  55. raise ProviderModelCurrentlyNotSupportError()
  56. except LLMBadRequestError:
  57. raise ProviderNotInitializeError(
  58. f"No Embedding Model available. Please configure a valid provider "
  59. f"in the Settings -> Model Provider.")
  60. except ValueError as e:
  61. raise ValueError(str(e))
  62. except Exception as e:
  63. logging.exception("Hit testing failed.")
  64. raise InternalServerError(str(e))
  65. api.add_resource(HitTestingApi, '/datasets/<uuid:dataset_id>/hit-testing')