# plugin_migration.py
  1. import datetime
  2. import json
  3. import logging
  4. import time
  5. from collections.abc import Mapping, Sequence
  6. from concurrent.futures import ThreadPoolExecutor
  7. from pathlib import Path
  8. from typing import Any, Optional
  9. from uuid import uuid4
  10. import click
  11. import tqdm
  12. from flask import Flask, current_app
  13. from sqlalchemy.orm import Session
  14. from core.agent.entities import AgentToolEntity
  15. from core.helper import marketplace
  16. from core.plugin.entities.plugin import ModelProviderID, PluginInstallationSource, ToolProviderID
  17. from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus
  18. from core.plugin.manager.plugin import PluginInstallationManager
  19. from core.tools.entities.tool_entities import ToolProviderType
  20. from models.account import Tenant
  21. from models.engine import db
  22. from models.model import App, AppMode, AppModelConfig
  23. from models.tools import BuiltinToolProvider
  24. from models.workflow import Workflow
  25. logger = logging.getLogger(__name__)
  26. excluded_providers = ["time", "audio", "code", "webscraper"]
  27. class PluginMigration:
  28. @classmethod
  29. def extract_plugins(cls, filepath: str, workers: int) -> None:
  30. """
  31. Migrate plugin.
  32. """
  33. from threading import Lock
  34. click.echo(click.style("Migrating models/tools to new plugin Mechanism", fg="white"))
  35. ended_at = datetime.datetime.now()
  36. started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
  37. current_time = started_at
  38. with Session(db.engine) as session:
  39. total_tenant_count = session.query(Tenant.id).count()
  40. click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
  41. handled_tenant_count = 0
  42. file_lock = Lock()
  43. counter_lock = Lock()
  44. thread_pool = ThreadPoolExecutor(max_workers=workers)
  45. def process_tenant(flask_app: Flask, tenant_id: str) -> None:
  46. with flask_app.app_context():
  47. nonlocal handled_tenant_count
  48. try:
  49. plugins = cls.extract_installed_plugin_ids(tenant_id)
  50. # Use lock when writing to file
  51. with file_lock:
  52. with open(filepath, "a") as f:
  53. f.write(json.dumps({"tenant_id": tenant_id, "plugins": plugins}) + "\n")
  54. # Use lock when updating counter
  55. with counter_lock:
  56. nonlocal handled_tenant_count
  57. handled_tenant_count += 1
  58. click.echo(
  59. click.style(
  60. f"[{datetime.datetime.now()}] "
  61. f"Processed {handled_tenant_count} tenants "
  62. f"({(handled_tenant_count / total_tenant_count) * 100:.1f}%), "
  63. f"{handled_tenant_count}/{total_tenant_count}",
  64. fg="green",
  65. )
  66. )
  67. except Exception:
  68. logger.exception(f"Failed to process tenant {tenant_id}")
  69. futures = []
  70. while current_time < ended_at:
  71. click.echo(click.style(f"Current time: {current_time}, Started at: {datetime.datetime.now()}", fg="white"))
  72. # Initial interval of 1 day, will be dynamically adjusted based on tenant count
  73. interval = datetime.timedelta(days=1)
  74. # Process tenants in this batch
  75. with Session(db.engine) as session:
  76. # Calculate tenant count in next batch with current interval
  77. # Try different intervals until we find one with a reasonable tenant count
  78. test_intervals = [
  79. datetime.timedelta(days=1),
  80. datetime.timedelta(hours=12),
  81. datetime.timedelta(hours=6),
  82. datetime.timedelta(hours=3),
  83. datetime.timedelta(hours=1),
  84. ]
  85. for test_interval in test_intervals:
  86. tenant_count = (
  87. session.query(Tenant.id)
  88. .filter(Tenant.created_at.between(current_time, current_time + test_interval))
  89. .count()
  90. )
  91. if tenant_count <= 100:
  92. interval = test_interval
  93. break
  94. else:
  95. # If all intervals have too many tenants, use minimum interval
  96. interval = datetime.timedelta(hours=1)
  97. # Adjust interval to target ~100 tenants per batch
  98. if tenant_count > 0:
  99. # Scale interval based on ratio to target count
  100. interval = min(
  101. datetime.timedelta(days=1), # Max 1 day
  102. max(
  103. datetime.timedelta(hours=1), # Min 1 hour
  104. interval * (100 / tenant_count), # Scale to target 100
  105. ),
  106. )
  107. batch_end = min(current_time + interval, ended_at)
  108. rs = (
  109. session.query(Tenant.id)
  110. .filter(Tenant.created_at.between(current_time, batch_end))
  111. .order_by(Tenant.created_at)
  112. )
  113. tenants = []
  114. for row in rs:
  115. tenant_id = str(row.id)
  116. try:
  117. tenants.append(tenant_id)
  118. except Exception:
  119. logger.exception(f"Failed to process tenant {tenant_id}")
  120. continue
  121. futures.append(
  122. thread_pool.submit(
  123. process_tenant,
  124. current_app._get_current_object(), # type: ignore[attr-defined]
  125. tenant_id,
  126. )
  127. )
  128. current_time = batch_end
  129. # wait for all threads to finish
  130. for future in futures:
  131. future.result()
  132. @classmethod
  133. def extract_installed_plugin_ids(cls, tenant_id: str) -> Sequence[str]:
  134. """
  135. Extract installed plugin ids.
  136. """
  137. tools = cls.extract_tool_tables(tenant_id)
  138. models = cls.extract_model_tables(tenant_id)
  139. workflows = cls.extract_workflow_tables(tenant_id)
  140. apps = cls.extract_app_tables(tenant_id)
  141. return list({*tools, *models, *workflows, *apps})
  142. @classmethod
  143. def extract_model_tables(cls, tenant_id: str) -> Sequence[str]:
  144. """
  145. Extract model tables.
  146. """
  147. models: list[str] = []
  148. table_pairs = [
  149. ("providers", "provider_name"),
  150. ("provider_models", "provider_name"),
  151. ("provider_orders", "provider_name"),
  152. ("tenant_default_models", "provider_name"),
  153. ("tenant_preferred_model_providers", "provider_name"),
  154. ("provider_model_settings", "provider_name"),
  155. ("load_balancing_model_configs", "provider_name"),
  156. ]
  157. for table, column in table_pairs:
  158. models.extend(cls.extract_model_table(tenant_id, table, column))
  159. # duplicate models
  160. models = list(set(models))
  161. return models
  162. @classmethod
  163. def extract_model_table(cls, tenant_id: str, table: str, column: str) -> Sequence[str]:
  164. """
  165. Extract model table.
  166. """
  167. with Session(db.engine) as session:
  168. rs = session.execute(
  169. db.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id}
  170. )
  171. result = []
  172. for row in rs:
  173. provider_name = str(row[0])
  174. result.append(ModelProviderID(provider_name).plugin_id)
  175. return result
  176. @classmethod
  177. def extract_tool_tables(cls, tenant_id: str) -> Sequence[str]:
  178. """
  179. Extract tool tables.
  180. """
  181. with Session(db.engine) as session:
  182. rs = session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
  183. result = []
  184. for row in rs:
  185. result.append(ToolProviderID(row.provider).plugin_id)
  186. return result
  187. @classmethod
  188. def extract_workflow_tables(cls, tenant_id: str) -> Sequence[str]:
  189. """
  190. Extract workflow tables, only ToolNode is required.
  191. """
  192. with Session(db.engine) as session:
  193. rs = session.query(Workflow).filter(Workflow.tenant_id == tenant_id).all()
  194. result = []
  195. for row in rs:
  196. graph = row.graph_dict
  197. # get nodes
  198. nodes = graph.get("nodes", [])
  199. for node in nodes:
  200. data = node.get("data", {})
  201. if data.get("type") == "tool":
  202. provider_name = data.get("provider_name")
  203. provider_type = data.get("provider_type")
  204. if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN.value:
  205. result.append(ToolProviderID(provider_name).plugin_id)
  206. return result
  207. @classmethod
  208. def extract_app_tables(cls, tenant_id: str) -> Sequence[str]:
  209. """
  210. Extract app tables.
  211. """
  212. with Session(db.engine) as session:
  213. apps = session.query(App).filter(App.tenant_id == tenant_id).all()
  214. if not apps:
  215. return []
  216. agent_app_model_config_ids = [
  217. app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value
  218. ]
  219. rs = session.query(AppModelConfig).filter(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
  220. result = []
  221. for row in rs:
  222. agent_config = row.agent_mode_dict
  223. if "tools" in agent_config and isinstance(agent_config["tools"], list):
  224. for tool in agent_config["tools"]:
  225. if isinstance(tool, dict):
  226. try:
  227. tool_entity = AgentToolEntity(**tool)
  228. if (
  229. tool_entity.provider_type == ToolProviderType.BUILT_IN.value
  230. and tool_entity.provider_id not in excluded_providers
  231. ):
  232. result.append(ToolProviderID(tool_entity.provider_id).plugin_id)
  233. except Exception:
  234. logger.exception(f"Failed to process tool {tool}")
  235. continue
  236. return result
  237. @classmethod
  238. def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> Optional[str]:
  239. """
  240. Fetch plugin unique identifier using plugin id.
  241. """
  242. plugin_manifest = marketplace.batch_fetch_plugin_manifests([plugin_id])
  243. if not plugin_manifest:
  244. return None
  245. return plugin_manifest[0].latest_package_identifier
  246. @classmethod
  247. def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str) -> None:
  248. """
  249. Extract unique plugins.
  250. """
  251. Path(output_file).write_text(json.dumps(cls.extract_unique_plugins(extracted_plugins)))
  252. @classmethod
  253. def extract_unique_plugins(cls, extracted_plugins: str) -> Mapping[str, Any]:
  254. plugins: dict[str, str] = {}
  255. plugin_ids = []
  256. plugin_not_exist = []
  257. logger.info(f"Extracting unique plugins from {extracted_plugins}")
  258. with open(extracted_plugins) as f:
  259. for line in f:
  260. data = json.loads(line)
  261. new_plugin_ids = data.get("plugins", [])
  262. for plugin_id in new_plugin_ids:
  263. if plugin_id not in plugin_ids:
  264. plugin_ids.append(plugin_id)
  265. def fetch_plugin(plugin_id):
  266. try:
  267. unique_identifier = cls._fetch_plugin_unique_identifier(plugin_id)
  268. if unique_identifier:
  269. plugins[plugin_id] = unique_identifier
  270. else:
  271. plugin_not_exist.append(plugin_id)
  272. except Exception:
  273. logger.exception(f"Failed to fetch plugin unique identifier for {plugin_id}")
  274. plugin_not_exist.append(plugin_id)
  275. with ThreadPoolExecutor(max_workers=10) as executor:
  276. list(tqdm.tqdm(executor.map(fetch_plugin, plugin_ids), total=len(plugin_ids)))
  277. return {"plugins": plugins, "plugin_not_exist": plugin_not_exist}
  278. @classmethod
  279. def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100) -> None:
  280. """
  281. Install plugins.
  282. """
  283. manager = PluginInstallationManager()
  284. plugins = cls.extract_unique_plugins(extracted_plugins)
  285. not_installed = []
  286. plugin_install_failed = []
  287. # use a fake tenant id to install all the plugins
  288. fake_tenant_id = uuid4().hex
  289. logger.info(f"Installing {len(plugins['plugins'])} plugin instances for fake tenant {fake_tenant_id}")
  290. thread_pool = ThreadPoolExecutor(max_workers=workers)
  291. response = cls.handle_plugin_instance_install(fake_tenant_id, plugins["plugins"])
  292. if response.get("failed"):
  293. plugin_install_failed.extend(response.get("failed", []))
  294. def install(tenant_id: str, plugin_ids: list[str]) -> None:
  295. logger.info(f"Installing {len(plugin_ids)} plugins for tenant {tenant_id}")
  296. # fetch plugin already installed
  297. installed_plugins = manager.list_plugins(tenant_id)
  298. installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
  299. # at most 64 plugins one batch
  300. for i in range(0, len(plugin_ids), 64):
  301. batch_plugin_ids = plugin_ids[i : i + 64]
  302. batch_plugin_identifiers = [
  303. plugins["plugins"][plugin_id]
  304. for plugin_id in batch_plugin_ids
  305. if plugin_id not in installed_plugins_ids and plugin_id in plugins["plugins"]
  306. ]
  307. manager.install_from_identifiers(
  308. tenant_id,
  309. batch_plugin_identifiers,
  310. PluginInstallationSource.Marketplace,
  311. metas=[
  312. {
  313. "plugin_unique_identifier": identifier,
  314. }
  315. for identifier in batch_plugin_identifiers
  316. ],
  317. )
  318. with open(extracted_plugins) as f:
  319. """
  320. Read line by line, and install plugins for each tenant.
  321. """
  322. for line in f:
  323. data = json.loads(line)
  324. tenant_id = data.get("tenant_id")
  325. plugin_ids = data.get("plugins", [])
  326. current_not_installed = {
  327. "tenant_id": tenant_id,
  328. "plugin_not_exist": [],
  329. }
  330. # get plugin unique identifier
  331. for plugin_id in plugin_ids:
  332. unique_identifier = plugins.get(plugin_id)
  333. if unique_identifier:
  334. current_not_installed["plugin_not_exist"].append(plugin_id)
  335. if current_not_installed["plugin_not_exist"]:
  336. not_installed.append(current_not_installed)
  337. thread_pool.submit(install, tenant_id, plugin_ids)
  338. thread_pool.shutdown(wait=True)
  339. logger.info("Uninstall plugins")
  340. # get installation
  341. try:
  342. installation = manager.list_plugins(fake_tenant_id)
  343. while installation:
  344. for plugin in installation:
  345. manager.uninstall(fake_tenant_id, plugin.installation_id)
  346. installation = manager.list_plugins(fake_tenant_id)
  347. except Exception:
  348. logger.exception(f"Failed to get installation for tenant {fake_tenant_id}")
  349. Path(output_file).write_text(
  350. json.dumps(
  351. {
  352. "not_installed": not_installed,
  353. "plugin_install_failed": plugin_install_failed,
  354. }
  355. )
  356. )
  357. @classmethod
  358. def handle_plugin_instance_install(
  359. cls, tenant_id: str, plugin_identifiers_map: Mapping[str, str]
  360. ) -> Mapping[str, Any]:
  361. """
  362. Install plugins for a tenant.
  363. """
  364. manager = PluginInstallationManager()
  365. # download all the plugins and upload
  366. thread_pool = ThreadPoolExecutor(max_workers=10)
  367. futures = []
  368. for plugin_id, plugin_identifier in plugin_identifiers_map.items():
  369. def download_and_upload(tenant_id, plugin_id, plugin_identifier):
  370. plugin_package = marketplace.download_plugin_pkg(plugin_identifier)
  371. if not plugin_package:
  372. raise Exception(f"Failed to download plugin {plugin_identifier}")
  373. # upload
  374. manager.upload_pkg(tenant_id, plugin_package, verify_signature=True)
  375. futures.append(thread_pool.submit(download_and_upload, tenant_id, plugin_id, plugin_identifier))
  376. # Wait for all downloads to complete
  377. for future in futures:
  378. future.result() # This will raise any exceptions that occurred
  379. thread_pool.shutdown(wait=True)
  380. success = []
  381. failed = []
  382. reverse_map = {v: k for k, v in plugin_identifiers_map.items()}
  383. # at most 8 plugins one batch
  384. for i in range(0, len(plugin_identifiers_map), 8):
  385. batch_plugin_ids = list(plugin_identifiers_map.keys())[i : i + 8]
  386. batch_plugin_identifiers = [plugin_identifiers_map[plugin_id] for plugin_id in batch_plugin_ids]
  387. try:
  388. response = manager.install_from_identifiers(
  389. tenant_id=tenant_id,
  390. identifiers=batch_plugin_identifiers,
  391. source=PluginInstallationSource.Marketplace,
  392. metas=[
  393. {
  394. "plugin_unique_identifier": identifier,
  395. }
  396. for identifier in batch_plugin_identifiers
  397. ],
  398. )
  399. except Exception:
  400. # add to failed
  401. failed.extend(batch_plugin_identifiers)
  402. continue
  403. if response.all_installed:
  404. success.extend(batch_plugin_identifiers)
  405. continue
  406. task_id = response.task_id
  407. done = False
  408. while not done:
  409. status = manager.fetch_plugin_installation_task(tenant_id, task_id)
  410. if status.status in [PluginInstallTaskStatus.Failed, PluginInstallTaskStatus.Success]:
  411. for plugin in status.plugins:
  412. if plugin.status == PluginInstallTaskStatus.Success:
  413. success.append(reverse_map[plugin.plugin_unique_identifier])
  414. else:
  415. failed.append(reverse_map[plugin.plugin_unique_identifier])
  416. logger.error(
  417. f"Failed to install plugin {plugin.plugin_unique_identifier}, error: {plugin.message}"
  418. )
  419. done = True
  420. else:
  421. time.sleep(1)
  422. return {"success": success, "failed": failed}