batch_import_annotations_task.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import json
  2. import logging
  3. import time
  4. import click
  5. from celery import shared_task
  6. from langchain.schema import Document
  7. from werkzeug.exceptions import NotFound
  8. from core.index.index import IndexBuilder
  9. from extensions.ext_database import db
  10. from extensions.ext_redis import redis_client
  11. from models.dataset import Dataset
  12. from models.model import MessageAnnotation, App, AppAnnotationSetting
  13. from services.dataset_service import DatasetCollectionBindingService
  14. @shared_task(queue='dataset')
  15. def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: str, tenant_id: str,
  16. user_id: str):
  17. """
  18. Add annotation to index.
  19. :param job_id: job_id
  20. :param content_list: content list
  21. :param tenant_id: tenant id
  22. :param app_id: app id
  23. :param user_id: user_id
  24. """
  25. logging.info(click.style('Start batch import annotation: {}'.format(job_id), fg='green'))
  26. start_at = time.perf_counter()
  27. indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
  28. # get app info
  29. app = db.session.query(App).filter(
  30. App.id == app_id,
  31. App.tenant_id == tenant_id,
  32. App.status == 'normal'
  33. ).first()
  34. if app:
  35. try:
  36. documents = []
  37. for content in content_list:
  38. annotation = MessageAnnotation(
  39. app_id=app.id,
  40. content=content['answer'],
  41. question=content['question'],
  42. account_id=user_id
  43. )
  44. db.session.add(annotation)
  45. db.session.flush()
  46. document = Document(
  47. page_content=content['question'],
  48. metadata={
  49. "annotation_id": annotation.id,
  50. "app_id": app_id,
  51. "doc_id": annotation.id
  52. }
  53. )
  54. documents.append(document)
  55. # if annotation reply is enabled , batch add annotations' index
  56. app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
  57. AppAnnotationSetting.app_id == app_id
  58. ).first()
  59. if app_annotation_setting:
  60. dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
  61. app_annotation_setting.collection_binding_id,
  62. 'annotation'
  63. )
  64. if not dataset_collection_binding:
  65. raise NotFound("App annotation setting not found")
  66. dataset = Dataset(
  67. id=app_id,
  68. tenant_id=tenant_id,
  69. indexing_technique='high_quality',
  70. embedding_model_provider=dataset_collection_binding.provider_name,
  71. embedding_model=dataset_collection_binding.model_name,
  72. collection_binding_id=dataset_collection_binding.id
  73. )
  74. index = IndexBuilder.get_index(dataset, 'high_quality')
  75. if index:
  76. index.add_texts(documents)
  77. db.session.commit()
  78. redis_client.setex(indexing_cache_key, 600, 'completed')
  79. end_at = time.perf_counter()
  80. logging.info(
  81. click.style(
  82. 'Build index successful for batch import annotation: {} latency: {}'.format(job_id, end_at - start_at),
  83. fg='green'))
  84. except Exception as e:
  85. db.session.rollback()
  86. redis_client.setex(indexing_cache_key, 600, 'error')
  87. indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
  88. redis_client.setex(indexing_error_msg_key, 600, str(e))
  89. logging.exception("Build index for batch import annotations failed")