dataset_retriever_tool.py

from typing import Type

from flask import current_app
from langchain.tools import BaseTool
from pydantic import Field, BaseModel

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment


class DatasetRetrieverToolInput(BaseModel):
    query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")


class DatasetRetrieverTool(BaseTool):
    """Tool for querying a Dataset."""

    name: str = "dataset"
    args_schema: Type[BaseModel] = DatasetRetrieverToolInput
    description: str = "use this to retrieve a dataset. "

    tenant_id: str
    dataset_id: str
    k: int = 3

    @classmethod
    def from_dataset(cls, dataset: Dataset, **kwargs):
        description = dataset.description
        if not description:
            # Fall back to a generic description when the dataset has none.
            description = 'useful for when you want to answer queries about the ' + dataset.name

        # Strip newlines so the description stays on a single line.
        description = description.replace('\n', '').replace('\r', '')
        return cls(
            name=f'dataset-{dataset.id}',
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
            description=description,
            **kwargs
        )

    def _run(self, query: str) -> str:
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == self.dataset_id
        ).first()

        if not dataset:
            return f'[{self.name} failed to find dataset with id {self.dataset_id}.]'

        if dataset.indexing_technique == "economy":
            # Economy datasets are queried through the keyword table index instead of embeddings.
            kw_table_index = KeywordTableIndex(
                dataset=dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=5
                )
            )

            documents = kw_table_index.search(query, search_kwargs={'k': self.k})
            return str("\n".join([document.page_content for document in documents]))
        else:
            # High-quality datasets are queried through the vector index.
            try:
                embedding_model = ModelFactory.get_embedding_model(
                    tenant_id=dataset.tenant_id,
                    model_provider_name=dataset.embedding_model_provider,
                    model_name=dataset.embedding_model
                )
            except LLMBadRequestError:
                return ''
            except ProviderTokenNotInitError:
                return ''

            embeddings = CacheEmbedding(embedding_model)

            vector_index = VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            if self.k > 0:
                documents = vector_index.search(
                    query,
                    search_type='similarity',
                    search_kwargs={
                        'k': self.k
                    }
                )
            else:
                documents = []

            hit_callback = DatasetIndexToolCallbackHandler(dataset.id)
            hit_callback.on_tool_end(documents)
            document_context_list = []
            index_node_ids = [document.metadata['doc_id'] for document in documents]
            segments = DocumentSegment.query.filter(
                DocumentSegment.completed_at.isnot(None),
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True,
                DocumentSegment.index_node_id.in_(index_node_ids)
            ).all()

            if segments:
                # Preserve the ranking returned by the index when rebuilding the context.
                index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
                sorted_segments = sorted(segments,
                                         key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
                                                                                           float('inf')))
                for segment in sorted_segments:
                    if segment.answer:
                        document_context_list.append(f'question:{segment.content} \nanswer:{segment.answer}')
                    else:
                        document_context_list.append(segment.content)

            return str("\n".join(document_context_list))

    async def _arun(self, tool_input: str) -> str:
        raise NotImplementedError()
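

# Usage sketch (not part of the original module): assuming a Dataset row is
# available via the SQLAlchemy session and the code runs inside the Flask app
# context, the tool is typically built with from_dataset() and then invoked
# like any other LangChain tool. The dataset id below is a hypothetical
# placeholder, not something defined in this file.
#
#     dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
#     tool = DatasetRetrieverTool.from_dataset(dataset, k=5)
#     context = tool.run("How do I reset my password?")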