encoders.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. import dataclasses
  2. import datetime
  3. from collections import defaultdict, deque
  4. from collections.abc import Callable
  5. from decimal import Decimal
  6. from enum import Enum
  7. from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
  8. from pathlib import Path, PurePath
  9. from re import Pattern
  10. from types import GeneratorType
  11. from typing import Any, Literal, Optional, Union
  12. from uuid import UUID
  13. from pydantic import BaseModel
  14. from pydantic.networks import AnyUrl, NameEmail
  15. from pydantic.types import SecretBytes, SecretStr
  16. from pydantic_core import Url
  17. from pydantic_extra_types.color import Color
  18. def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any:
  19. return model.model_dump(mode=mode, **kwargs)
  20. # Taken from Pydantic v1 as is
  21. def isoformat(o: Union[datetime.date, datetime.time]) -> str:
  22. return o.isoformat()
  23. # Taken from Pydantic v1 as is
  24. # TODO: pv2 should this return strings instead?
  25. def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
  26. """
  27. Encodes a Decimal as int of there's no exponent, otherwise float
  28. This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
  29. where a integer (but not int typed) is used. Encoding this as a float
  30. results in failed round-tripping between encode and parse.
  31. Our Id type is a prime example of this.
  32. >>> decimal_encoder(Decimal("1.0"))
  33. 1.0
  34. >>> decimal_encoder(Decimal("1"))
  35. 1
  36. """
  37. if dec_value.as_tuple().exponent >= 0: # type: ignore[operator]
  38. return int(dec_value)
  39. else:
  40. return float(dec_value)
# Fallback encoders keyed by concrete type. jsonable_encoder first looks up
# ``type(obj)`` here for an exact match, and otherwise tries the isinstance()
# groups that generate_encoders_by_class_tuples builds from this table.
# Those groups are iterated in this table's insertion order, so the order of
# the entries below decides which encoder a subclass receives.
#
# NOTE(review): jsonable_encoder intercepts Enum, PurePath (and thus Path),
# Decimal, set, frozenset, deque and GeneratorType in earlier branches, so
# those entries are not reached from it. They are presumably kept for other
# consumers of this table; confirm before removing.
ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = {
    # Decoded with bytes.decode()'s default (strict UTF-8); invalid bytes raise.
    bytes: lambda o: o.decode(),
    Color: str,
    datetime.date: isoformat,
    datetime.datetime: isoformat,
    datetime.time: isoformat,
    # Durations become a float count of seconds.
    datetime.timedelta: lambda td: td.total_seconds(),
    Decimal: decimal_encoder,
    # Enum members are encoded by their underlying value.
    Enum: lambda o: o.value,
    # Container/iterator types are converted to lists.
    frozenset: list,
    deque: list,
    GeneratorType: list,
    IPv4Address: str,
    IPv4Interface: str,
    IPv4Network: str,
    IPv6Address: str,
    IPv6Interface: str,
    IPv6Network: str,
    NameEmail: str,
    Path: str,
    # Compiled regular expressions are encoded as their source pattern string.
    Pattern: lambda o: o.pattern,
    # NOTE(review): str() of pydantic secret types is presumably a masked
    # placeholder rather than the secret itself; confirm against pydantic docs.
    SecretBytes: str,
    SecretStr: str,
    set: list,
    UUID: str,
    Url: str,
    AnyUrl: str,
}
  69. def generate_encoders_by_class_tuples(
  70. type_encoder_map: dict[Any, Callable[[Any], Any]],
  71. ) -> dict[Callable[[Any], Any], tuple[Any, ...]]:
  72. encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(tuple)
  73. for type_, encoder in type_encoder_map.items():
  74. encoders_by_class_tuples[encoder] += (type_,)
  75. return encoders_by_class_tuples
  76. encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
  77. def jsonable_encoder(
  78. obj: Any,
  79. by_alias: bool = True,
  80. exclude_unset: bool = False,
  81. exclude_defaults: bool = False,
  82. exclude_none: bool = False,
  83. custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None,
  84. sqlalchemy_safe: bool = True,
  85. ) -> Any:
  86. custom_encoder = custom_encoder or {}
  87. if custom_encoder:
  88. if type(obj) in custom_encoder:
  89. return custom_encoder[type(obj)](obj)
  90. else:
  91. for encoder_type, encoder_instance in custom_encoder.items():
  92. if isinstance(obj, encoder_type):
  93. return encoder_instance(obj)
  94. if isinstance(obj, BaseModel):
  95. obj_dict = _model_dump(
  96. obj,
  97. mode="json",
  98. include=None,
  99. exclude=None,
  100. by_alias=by_alias,
  101. exclude_unset=exclude_unset,
  102. exclude_none=exclude_none,
  103. exclude_defaults=exclude_defaults,
  104. )
  105. if "__root__" in obj_dict:
  106. obj_dict = obj_dict["__root__"]
  107. return jsonable_encoder(
  108. obj_dict,
  109. exclude_none=exclude_none,
  110. exclude_defaults=exclude_defaults,
  111. sqlalchemy_safe=sqlalchemy_safe,
  112. )
  113. if dataclasses.is_dataclass(obj):
  114. # FIXME: mypy error, try to fix it instead of using type: ignore
  115. obj_dict = dataclasses.asdict(obj) # type: ignore
  116. return jsonable_encoder(
  117. obj_dict,
  118. by_alias=by_alias,
  119. exclude_unset=exclude_unset,
  120. exclude_defaults=exclude_defaults,
  121. exclude_none=exclude_none,
  122. custom_encoder=custom_encoder,
  123. sqlalchemy_safe=sqlalchemy_safe,
  124. )
  125. if isinstance(obj, Enum):
  126. return obj.value
  127. if isinstance(obj, PurePath):
  128. return str(obj)
  129. if isinstance(obj, str | int | float | type(None)):
  130. return obj
  131. if isinstance(obj, Decimal):
  132. return format(obj, "f")
  133. if isinstance(obj, dict):
  134. encoded_dict = {}
  135. allowed_keys = set(obj.keys())
  136. for key, value in obj.items():
  137. if (
  138. (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa")))
  139. and (value is not None or not exclude_none)
  140. and key in allowed_keys
  141. ):
  142. encoded_key = jsonable_encoder(
  143. key,
  144. by_alias=by_alias,
  145. exclude_unset=exclude_unset,
  146. exclude_none=exclude_none,
  147. custom_encoder=custom_encoder,
  148. sqlalchemy_safe=sqlalchemy_safe,
  149. )
  150. encoded_value = jsonable_encoder(
  151. value,
  152. by_alias=by_alias,
  153. exclude_unset=exclude_unset,
  154. exclude_none=exclude_none,
  155. custom_encoder=custom_encoder,
  156. sqlalchemy_safe=sqlalchemy_safe,
  157. )
  158. encoded_dict[encoded_key] = encoded_value
  159. return encoded_dict
  160. if isinstance(obj, list | set | frozenset | GeneratorType | tuple | deque):
  161. encoded_list = []
  162. for item in obj:
  163. encoded_list.append(
  164. jsonable_encoder(
  165. item,
  166. by_alias=by_alias,
  167. exclude_unset=exclude_unset,
  168. exclude_defaults=exclude_defaults,
  169. exclude_none=exclude_none,
  170. custom_encoder=custom_encoder,
  171. sqlalchemy_safe=sqlalchemy_safe,
  172. )
  173. )
  174. return encoded_list
  175. if type(obj) in ENCODERS_BY_TYPE:
  176. return ENCODERS_BY_TYPE[type(obj)](obj)
  177. for encoder, classes_tuple in encoders_by_class_tuples.items():
  178. if isinstance(obj, classes_tuple):
  179. return encoder(obj)
  180. try:
  181. data = dict(obj)
  182. except Exception as e:
  183. errors: list[Exception] = []
  184. errors.append(e)
  185. try:
  186. data = vars(obj)
  187. except Exception as e:
  188. errors.append(e)
  189. raise ValueError(errors) from e
  190. return jsonable_encoder(
  191. data,
  192. by_alias=by_alias,
  193. exclude_unset=exclude_unset,
  194. exclude_defaults=exclude_defaults,
  195. exclude_none=exclude_none,
  196. custom_encoder=custom_encoder,
  197. sqlalchemy_safe=sqlalchemy_safe,
  198. )