csv.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
import csv
import logging
from typing import Optional, Dict, List

from langchain.document_loaders import CSVLoader as LCCSVLoader
from langchain.document_loaders.helpers import detect_file_encodings

from models.dataset import Document
  6. logger = logging.getLogger(__name__)
  7. class CSVLoader(LCCSVLoader):
  8. def __init__(
  9. self,
  10. file_path: str,
  11. source_column: Optional[str] = None,
  12. csv_args: Optional[Dict] = None,
  13. encoding: Optional[str] = None,
  14. autodetect_encoding: bool = True,
  15. ):
  16. self.file_path = file_path
  17. self.source_column = source_column
  18. self.encoding = encoding
  19. self.csv_args = csv_args or {}
  20. self.autodetect_encoding = autodetect_encoding
  21. def load(self) -> List[Document]:
  22. """Load data into document objects."""
  23. try:
  24. with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
  25. docs = self._read_from_file(csvfile)
  26. except UnicodeDecodeError as e:
  27. if self.autodetect_encoding:
  28. detected_encodings = detect_file_encodings(self.file_path)
  29. for encoding in detected_encodings:
  30. logger.debug("Trying encoding: ", encoding.encoding)
  31. try:
  32. with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile:
  33. docs = self._read_from_file(csvfile)
  34. break
  35. except UnicodeDecodeError:
  36. continue
  37. else:
  38. raise RuntimeError(f"Error loading {self.file_path}") from e
  39. return docs
  40. def _read_from_file(self, csvfile):
  41. docs = []
  42. csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
  43. for i, row in enumerate(csv_reader):
  44. content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
  45. try:
  46. source = (
  47. row[self.source_column]
  48. if self.source_column is not None
  49. else ''
  50. )
  51. except KeyError:
  52. raise ValueError(
  53. f"Source column '{self.source_column}' not found in CSV file."
  54. )
  55. metadata = {"source": source, "row": i}
  56. doc = Document(page_content=content, metadata=metadata)
  57. docs.append(doc)
  58. return docs