extensible.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  import enum
  import importlib.util
  import json
  import logging
  import os
  import sys
  from typing import Any, Optional

  from pydantic import BaseModel

  from core.helper.position_helper import sort_to_dict_by_position_map


  class ExtensionModule(enum.Enum):
      """Categories of pluggable extension modules supported by ``Extensible``."""

      MODERATION = "moderation"
      EXTERNAL_DATA_TOOL = "external_data_tool"


  class ModuleExtension(BaseModel):
      """Metadata record for one extension found by ``Extensible.scan_extensions``."""

      # Concrete Extensible subclass loaded from the extension's ``<name>.py`` file.
      extension_class: Any = None
      # Extension identifier; scan_extensions uses the extension's sub-directory name.
      name: str
      # ``"label"`` entry of the extension's schema.json (left None for built-in extensions).
      label: Optional[dict] = None
      # ``"form_schema"`` entry of the extension's schema.json (left None for built-in extensions).
      form_schema: Optional[list] = None
      # True when the extension directory contains a ``__builtin__`` marker file.
      builtin: bool = True
      # Sort position read from the ``__builtin__`` file; None for non-built-in extensions.
      position: Optional[int] = None


  class Extensible:
      """Base class for extensions that are discovered dynamically from the file system.

      An instance is bound to a tenant and an optional configuration dict. Concrete
      implementations of a subclass are located by :meth:`scan_extensions`.
      """

      # Declared but not assigned here; presumably set by each concrete subclass.
      module: ExtensionModule
      name: str
      tenant_id: str
      config: Optional[dict] = None

      def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:
          """Bind the extension to ``tenant_id`` with an optional ``config`` dict."""
          self.tenant_id = tenant_id
          self.config = config

      @classmethod
      def scan_extensions(cls):
          """Discover the implementations of ``cls`` that live next to its module.

          Every sub-directory ``<name>/`` of the directory containing the module that
          defines ``cls`` is a candidate (entries whose name starts with ``__`` and
          plain files are ignored):

          * ``<name>/<name>.py`` is executed and the subclass of ``cls`` found in it
            becomes the extension class. Directories lacking that file, or whose file
            has no such subclass, are skipped with a warning.
          * If ``<name>/__builtin__`` exists, the extension is built-in and the file's
            content is parsed as its integer sort position.
          * Non-built-in extensions must provide ``<name>/schema.json``; its
            ``"label"`` and ``"form_schema"`` entries are copied into the record.

          Returns:
              The ``ModuleExtension`` records as arranged by
              ``sort_to_dict_by_position_map`` using ``position_map`` and each
              record's ``name`` (presumably a dict keyed by name in position
              order -- verify against that helper).

          Raises:
              Exception: if an import spec cannot be built for an extension file.
              ValueError: if a ``__builtin__`` file does not contain an integer.
              Any exception raised while executing an extension module or parsing
              its ``schema.json`` propagates unchanged.
          """
          extensions: list[ModuleExtension] = []
          position_map: dict[str, Optional[int]] = {}

          current_dir_path = cls._resolve_extensions_dir()

          # os.listdir() returns entries in arbitrary order; sort so extensions
          # without an explicit position come out in a stable order between runs.
          for subdir_name in sorted(os.listdir(current_dir_path)):
              if subdir_name.startswith("__"):
                  continue
              subdir_path = os.path.join(current_dir_path, subdir_name)
              if not os.path.isdir(subdir_path):
                  continue

              extension_name = subdir_name
              file_names = os.listdir(subdir_path)

              # Built-in extensions get special treatment in the front-end page and in
              # business logic; they are marked by a "__builtin__" file holding their
              # sort position.
              builtin = "__builtin__" in file_names
              position = None
              if builtin:
                  builtin_file_path = os.path.join(subdir_path, "__builtin__")
                  if os.path.exists(builtin_file_path):
                      position = cls._read_builtin_position(builtin_file_path)
                  position_map[extension_name] = position

              if (extension_name + ".py") not in file_names:
                  logging.warning("Missing %s.py file in %s, Skip.", extension_name, subdir_path)
                  continue

              py_path = os.path.join(subdir_path, extension_name + ".py")
              extension_class = cls._load_extension_class(extension_name, py_path)
              if extension_class is None:
                  logging.warning("Missing subclass of %s in %s, Skip.", cls.__name__, py_path)
                  continue

              json_data: dict = {}
              if not builtin:
                  if "schema.json" not in file_names:
                      logging.warning("Missing schema.json file in %s, Skip.", subdir_path)
                      continue
                  json_path = os.path.join(subdir_path, "schema.json")
                  if os.path.exists(json_path):
                      with open(json_path, encoding="utf-8") as f:
                          json_data = json.load(f)

              extensions.append(
                  ModuleExtension(
                      extension_class=extension_class,
                      name=extension_name,
                      label=json_data.get("label"),
                      form_schema=json_data.get("form_schema"),
                      builtin=builtin,
                      position=position,
                  )
              )

          return sort_to_dict_by_position_map(
              position_map=position_map, data=extensions, name_func=lambda x: x.name
          )

      @classmethod
      def _resolve_extensions_dir(cls) -> str:
          """Return the directory containing the module that defines ``cls``."""
          # Prefer the location the module was actually imported from, which does
          # not depend on the process's current working directory.
          module_file = getattr(sys.modules.get(cls.__module__), "__file__", None)
          if module_file:
              return os.path.dirname(os.path.abspath(module_file))
          # Fallback: derive the path from the dotted module name; this only resolves
          # correctly when the working directory is the package root.
          return os.path.dirname(os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py"))

      @staticmethod
      def _read_builtin_position(builtin_file_path: str) -> int:
          """Parse the integer sort position stored in a ``__builtin__`` marker file."""
          with open(builtin_file_path, encoding="utf-8") as f:
              raw = f.read().strip()
          try:
              return int(raw)
          except ValueError as e:
              raise ValueError(f"Invalid position {raw!r} in {builtin_file_path}") from e

      @classmethod
      def _load_extension_class(cls, extension_name: str, py_path: str) -> Optional[type]:
          """Execute ``py_path`` as module ``extension_name`` and return its subclass of ``cls``.

          Returns None when the module contains no subclass of ``cls``.
          """
          spec = importlib.util.spec_from_file_location(extension_name, py_path)
          if not spec or not spec.loader:
              raise Exception(f"Failed to load module {extension_name} from {py_path}")
          mod = importlib.util.module_from_spec(spec)
          spec.loader.exec_module(mod)

          candidates = [
              obj
              for obj in vars(mod).values()
              if isinstance(obj, type) and issubclass(obj, cls) and obj is not cls
          ]
          # Prefer a class defined in the extension file itself over one it merely
          # imported (e.g. an intermediate base class); otherwise take the first match.
          for candidate in candidates:
              if candidate.__module__ == mod.__name__:
                  return candidate
          return candidates[0] if candidates else None