extensible.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import enum
  2. import importlib.util
  3. import json
  4. import logging
  5. import os
  6. from pathlib import Path
  7. from typing import Any, Optional
  8. from pydantic import BaseModel
  9. from core.helper.position_helper import sort_to_dict_by_position_map
  10. class ExtensionModule(enum.Enum):
  11. MODERATION = "moderation"
  12. EXTERNAL_DATA_TOOL = "external_data_tool"
  13. class ModuleExtension(BaseModel):
  14. extension_class: Any = None
  15. name: str
  16. label: Optional[dict] = None
  17. form_schema: Optional[list] = None
  18. builtin: bool = True
  19. position: Optional[int] = None
  20. class Extensible:
  21. module: ExtensionModule
  22. name: str
  23. tenant_id: str
  24. config: Optional[dict] = None
  25. def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:
  26. self.tenant_id = tenant_id
  27. self.config = config
  28. @classmethod
  29. def scan_extensions(cls):
  30. extensions = []
  31. position_map: dict[str, int] = {}
  32. # get the path of the current class
  33. current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
  34. current_dir_path = os.path.dirname(current_path)
  35. # traverse subdirectories
  36. for subdir_name in os.listdir(current_dir_path):
  37. if subdir_name.startswith("__"):
  38. continue
  39. subdir_path = os.path.join(current_dir_path, subdir_name)
  40. extension_name = subdir_name
  41. if os.path.isdir(subdir_path):
  42. file_names = os.listdir(subdir_path)
  43. # is builtin extension, builtin extension
  44. # in the front-end page and business logic, there are special treatments.
  45. builtin = False
  46. # default position is 0 can not be None for sort_to_dict_by_position_map
  47. position = 0
  48. if "__builtin__" in file_names:
  49. builtin = True
  50. builtin_file_path = os.path.join(subdir_path, "__builtin__")
  51. if os.path.exists(builtin_file_path):
  52. position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
  53. position_map[extension_name] = position
  54. if (extension_name + ".py") not in file_names:
  55. logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
  56. continue
  57. # Dynamic loading {subdir_name}.py file and find the subclass of Extensible
  58. py_path = os.path.join(subdir_path, extension_name + ".py")
  59. spec = importlib.util.spec_from_file_location(extension_name, py_path)
  60. if not spec or not spec.loader:
  61. raise Exception(f"Failed to load module {extension_name} from {py_path}")
  62. mod = importlib.util.module_from_spec(spec)
  63. spec.loader.exec_module(mod)
  64. extension_class = None
  65. for name, obj in vars(mod).items():
  66. if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
  67. extension_class = obj
  68. break
  69. if not extension_class:
  70. logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
  71. continue
  72. json_data: dict[str, Any] = {}
  73. if not builtin:
  74. if "schema.json" not in file_names:
  75. logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
  76. continue
  77. json_path = os.path.join(subdir_path, "schema.json")
  78. json_data = {}
  79. if os.path.exists(json_path):
  80. with open(json_path, encoding="utf-8") as f:
  81. json_data = json.load(f)
  82. extensions.append(
  83. ModuleExtension(
  84. extension_class=extension_class,
  85. name=extension_name,
  86. label=json_data.get("label"),
  87. form_schema=json_data.get("form_schema"),
  88. builtin=builtin,
  89. position=position,
  90. )
  91. )
  92. sorted_extensions = sort_to_dict_by_position_map(
  93. position_map=position_map, data=extensions, name_func=lambda x: x.name
  94. )
  95. return sorted_extensions