import datetime
import json
import logging
import time
from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Optional
from uuid import uuid4

import click
import tqdm
from flask import Flask, current_app
from sqlalchemy.orm import Session

from core.agent.entities import AgentToolEntity
from core.helper import marketplace
from core.plugin.entities.plugin import ModelProviderID, PluginInstallationSource, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus
from core.plugin.manager.plugin import PluginInstallationManager
from core.tools.entities.tool_entities import ToolProviderType
from models.account import Tenant
from models.engine import db
from models.model import App, AppMode, AppModelConfig
from models.tools import BuiltinToolProvider
from models.workflow import Workflow

logger = logging.getLogger(__name__)

excluded_providers = ["time", "audio", "code", "webscraper"]
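
# One-off helpers for migrating tenants' models and tools to the new plugin mechanism:
# scan what each tenant actually uses, resolve the matching marketplace identifiers, then
# batch-install the plugins per tenant (see the illustrative flow at the end of this module).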


class PluginMigration:
    @classmethod
    def extract_plugins(cls, filepath: str, workers: int) -> None:
        """
        Extract the installed plugin IDs of every tenant and append them, one JSON object
        per tenant, to the given file.
        """
        from threading import Lock

        click.echo(click.style("Migrating models/tools to the new plugin mechanism", fg="white"))
        ended_at = datetime.datetime.now()
        started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
        current_time = started_at

        with Session(db.engine) as session:
            total_tenant_count = session.query(Tenant.id).count()
            click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))

        handled_tenant_count = 0
        file_lock = Lock()
        counter_lock = Lock()

        thread_pool = ThreadPoolExecutor(max_workers=workers)

        def process_tenant(flask_app: Flask, tenant_id: str) -> None:
            with flask_app.app_context():
                nonlocal handled_tenant_count
                try:
                    plugins = cls.extract_installed_plugin_ids(tenant_id)
                    # Use lock when writing to file
                    with file_lock:
                        with open(filepath, "a") as f:
                            f.write(json.dumps({"tenant_id": tenant_id, "plugins": plugins}) + "\n")

                    # Use lock when updating counter
                    with counter_lock:
                        handled_tenant_count += 1
                        click.echo(
                            click.style(
                                f"[{datetime.datetime.now()}] "
                                f"Processed {handled_tenant_count} tenants "
                                f"({(handled_tenant_count / total_tenant_count) * 100:.1f}%), "
                                f"{handled_tenant_count}/{total_tenant_count}",
                                fg="green",
                            )
                        )
                except Exception:
                    logger.exception(f"Failed to process tenant {tenant_id}")

        futures = []

        while current_time < ended_at:
            click.echo(
                click.style(f"Current time: {current_time}, wall-clock now: {datetime.datetime.now()}", fg="white")
            )
            # Initial interval of 1 day, will be dynamically adjusted based on tenant count
            interval = datetime.timedelta(days=1)
            # Process tenants in this batch
            with Session(db.engine) as session:
                # Calculate tenant count in next batch with current interval
                # Try different intervals until we find one with a reasonable tenant count
                test_intervals = [
                    datetime.timedelta(days=1),
                    datetime.timedelta(hours=12),
                    datetime.timedelta(hours=6),
                    datetime.timedelta(hours=3),
                    datetime.timedelta(hours=1),
                ]

                for test_interval in test_intervals:
                    tenant_count = (
                        session.query(Tenant.id)
                        .filter(Tenant.created_at.between(current_time, current_time + test_interval))
                        .count()
                    )
                    if tenant_count <= 100:
                        interval = test_interval
                        break
                else:
                    # If all intervals have too many tenants, use minimum interval
                    interval = datetime.timedelta(hours=1)

                # Adjust interval to target ~100 tenants per batch
                if tenant_count > 0:
                    # Scale interval based on ratio to target count
                    interval = min(
                        datetime.timedelta(days=1),  # Max 1 day
                        max(
                            datetime.timedelta(hours=1),  # Min 1 hour
                            interval * (100 / tenant_count),  # Scale to target 100
                        ),
                    )

                batch_end = min(current_time + interval, ended_at)

                rs = (
                    session.query(Tenant.id)
                    .filter(Tenant.created_at.between(current_time, batch_end))
                    .order_by(Tenant.created_at)
                )

                tenants = []
                for row in rs:
                    tenant_id = str(row.id)
                    try:
                        tenants.append(tenant_id)
                    except Exception:
                        logger.exception(f"Failed to process tenant {tenant_id}")
                        continue

                    futures.append(
                        thread_pool.submit(
                            process_tenant,
                            current_app._get_current_object(),  # type: ignore[attr-defined]
                            tenant_id,
                        )
                    )

            current_time = batch_end

        # wait for all threads to finish
        for future in futures:
            future.result()
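
        # At this point `filepath` holds one JSON object per processed tenant, e.g.
        # (identifiers illustrative): {"tenant_id": "<tenant-uuid>", "plugins": ["langgenius/openai"]}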

    @classmethod
    def extract_installed_plugin_ids(cls, tenant_id: str) -> Sequence[str]:
        """
        Extract installed plugin ids.
        """
        tools = cls.extract_tool_tables(tenant_id)
        models = cls.extract_model_tables(tenant_id)
        workflows = cls.extract_workflow_tables(tenant_id)
        apps = cls.extract_app_tables(tenant_id)

        return list({*tools, *models, *workflows, *apps})

    @classmethod
    def extract_model_tables(cls, tenant_id: str) -> Sequence[str]:
        """
        Extract model tables.
        """
        models: list[str] = []
        table_pairs = [
            ("providers", "provider_name"),
            ("provider_models", "provider_name"),
            ("provider_orders", "provider_name"),
            ("tenant_default_models", "provider_name"),
            ("tenant_preferred_model_providers", "provider_name"),
            ("provider_model_settings", "provider_name"),
            ("load_balancing_model_configs", "provider_name"),
        ]

        for table, column in table_pairs:
            models.extend(cls.extract_model_table(tenant_id, table, column))

        # deduplicate models
        models = list(set(models))

        return models

    @classmethod
    def extract_model_table(cls, tenant_id: str, table: str, column: str) -> Sequence[str]:
        """
        Extract model table.
        """
        with Session(db.engine) as session:
            # In this module, table/column only come from the hard-coded table_pairs list above,
            # so interpolating them into the SQL text is safe here.
            rs = session.execute(
                db.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"),
                {"tenant_id": tenant_id},
            )
            result = []
            for row in rs:
                provider_name = str(row[0])
                result.append(ModelProviderID(provider_name).plugin_id)

            return result

    @classmethod
    def extract_tool_tables(cls, tenant_id: str) -> Sequence[str]:
        """
        Extract tool tables.
        """
        with Session(db.engine) as session:
            rs = session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
            result = []
            for row in rs:
                result.append(ToolProviderID(row.provider).plugin_id)

            return result

    @classmethod
    def extract_workflow_tables(cls, tenant_id: str) -> Sequence[str]:
        """
        Extract plugin ids from workflow graphs; only tool nodes are relevant.
        """
        with Session(db.engine) as session:
            rs = session.query(Workflow).filter(Workflow.tenant_id == tenant_id).all()
            result = []
            for row in rs:
                graph = row.graph_dict
                # get nodes
                nodes = graph.get("nodes", [])

                for node in nodes:
                    data = node.get("data", {})
                    if data.get("type") == "tool":
                        provider_name = data.get("provider_name")
                        provider_type = data.get("provider_type")
                        if (
                            provider_name not in excluded_providers
                            and provider_type == ToolProviderType.BUILT_IN.value
                        ):
                            result.append(ToolProviderID(provider_name).plugin_id)

            return result

    @classmethod
    def extract_app_tables(cls, tenant_id: str) -> Sequence[str]:
        """
        Extract app tables.
        """
        with Session(db.engine) as session:
            apps = session.query(App).filter(App.tenant_id == tenant_id).all()
            if not apps:
                return []

            agent_app_model_config_ids = [
                app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value
            ]

            rs = session.query(AppModelConfig).filter(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
            result = []
            for row in rs:
                agent_config = row.agent_mode_dict
                if "tools" in agent_config and isinstance(agent_config["tools"], list):
                    for tool in agent_config["tools"]:
                        if isinstance(tool, dict):
                            try:
                                tool_entity = AgentToolEntity(**tool)
                                if (
                                    tool_entity.provider_type == ToolProviderType.BUILT_IN.value
                                    and tool_entity.provider_id not in excluded_providers
                                ):
                                    result.append(ToolProviderID(tool_entity.provider_id).plugin_id)
                            except Exception:
                                logger.exception(f"Failed to process tool {tool}")
                                continue

            return result

    @classmethod
    def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> Optional[str]:
        """
        Fetch plugin unique identifier using plugin id.
        """
        plugin_manifest = marketplace.batch_fetch_plugin_manifests([plugin_id])
        if not plugin_manifest:
            return None

        return plugin_manifest[0].latest_package_identifier

    @classmethod
    def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str) -> None:
        """
        Extract unique plugins from the extracted-plugins file and write the result to a JSON file.
        """
        Path(output_file).write_text(json.dumps(cls.extract_unique_plugins(extracted_plugins)))
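
        # The written JSON mirrors the mapping returned by extract_unique_plugins(), roughly
        # (values illustrative):
        #   {"plugins": {"<plugin_id>": "<unique_identifier>"}, "plugin_not_exist": ["<plugin_id>"]}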

    @classmethod
    def extract_unique_plugins(cls, extracted_plugins: str) -> Mapping[str, Any]:
        plugins: dict[str, str] = {}
        plugin_ids = []
        plugin_not_exist = []
        logger.info(f"Extracting unique plugins from {extracted_plugins}")

        with open(extracted_plugins) as f:
            for line in f:
                data = json.loads(line)
                new_plugin_ids = data.get("plugins", [])
                for plugin_id in new_plugin_ids:
                    if plugin_id not in plugin_ids:
                        plugin_ids.append(plugin_id)

        def fetch_plugin(plugin_id):
            try:
                unique_identifier = cls._fetch_plugin_unique_identifier(plugin_id)
                if unique_identifier:
                    plugins[plugin_id] = unique_identifier
                else:
                    plugin_not_exist.append(plugin_id)
            except Exception:
                logger.exception(f"Failed to fetch plugin unique identifier for {plugin_id}")
                plugin_not_exist.append(plugin_id)

        with ThreadPoolExecutor(max_workers=10) as executor:
            list(tqdm.tqdm(executor.map(fetch_plugin, plugin_ids), total=len(plugin_ids)))

        return {"plugins": plugins, "plugin_not_exist": plugin_not_exist}

    @classmethod
    def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100) -> None:
        """
        Install plugins.
        """
        manager = PluginInstallationManager()

        plugins = cls.extract_unique_plugins(extracted_plugins)
        not_installed = []
        plugin_install_failed = []

        # use a fake tenant id to install all the plugins
        fake_tenant_id = uuid4().hex
        logger.info(f"Installing {len(plugins['plugins'])} plugin instances for fake tenant {fake_tenant_id}")

        thread_pool = ThreadPoolExecutor(max_workers=workers)

        response = cls.handle_plugin_instance_install(fake_tenant_id, plugins["plugins"])
        if response.get("failed"):
            plugin_install_failed.extend(response.get("failed", []))

        def install(tenant_id: str, plugin_ids: list[str]) -> None:
            logger.info(f"Installing {len(plugin_ids)} plugins for tenant {tenant_id}")
            # fetch plugin already installed
            installed_plugins = manager.list_plugins(tenant_id)
            installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
            # at most 64 plugins one batch
            for i in range(0, len(plugin_ids), 64):
                batch_plugin_ids = plugin_ids[i : i + 64]
                batch_plugin_identifiers = [
                    plugins["plugins"][plugin_id]
                    for plugin_id in batch_plugin_ids
                    if plugin_id not in installed_plugins_ids and plugin_id in plugins["plugins"]
                ]
                manager.install_from_identifiers(
                    tenant_id,
                    batch_plugin_identifiers,
                    PluginInstallationSource.Marketplace,
                    metas=[
                        {
                            "plugin_unique_identifier": identifier,
                        }
                        for identifier in batch_plugin_identifiers
                    ],
                )

        with open(extracted_plugins) as f:
            # Read line by line, and install plugins for each tenant.
            for line in f:
                data = json.loads(line)
                tenant_id = data.get("tenant_id")
                plugin_ids = data.get("plugins", [])
                current_not_installed = {
                    "tenant_id": tenant_id,
                    "plugin_not_exist": [],
                }
                # record plugins whose unique identifier could not be resolved from the marketplace
                for plugin_id in plugin_ids:
                    unique_identifier = plugins["plugins"].get(plugin_id)
                    if not unique_identifier:
                        current_not_installed["plugin_not_exist"].append(plugin_id)

                if current_not_installed["plugin_not_exist"]:
                    not_installed.append(current_not_installed)

                thread_pool.submit(install, tenant_id, plugin_ids)

        thread_pool.shutdown(wait=True)

        logger.info("Uninstall plugins")

        # get installation
        try:
            installation = manager.list_plugins(fake_tenant_id)
            while installation:
                for plugin in installation:
                    manager.uninstall(fake_tenant_id, plugin.installation_id)

                installation = manager.list_plugins(fake_tenant_id)
        except Exception:
            logger.exception(f"Failed to get installation for tenant {fake_tenant_id}")

        Path(output_file).write_text(
            json.dumps(
                {
                    "not_installed": not_installed,
                    "plugin_install_failed": plugin_install_failed,
                }
            )
        )

    @classmethod
    def handle_plugin_instance_install(
        cls, tenant_id: str, plugin_identifiers_map: Mapping[str, str]
    ) -> Mapping[str, Any]:
        """
        Install plugins for a tenant.
        """
        manager = PluginInstallationManager()

        # download all the plugins and upload
        thread_pool = ThreadPoolExecutor(max_workers=10)
        futures = []

        for plugin_id, plugin_identifier in plugin_identifiers_map.items():

            def download_and_upload(tenant_id, plugin_id, plugin_identifier):
                plugin_package = marketplace.download_plugin_pkg(plugin_identifier)
                if not plugin_package:
                    raise Exception(f"Failed to download plugin {plugin_identifier}")

                # upload
                manager.upload_pkg(tenant_id, plugin_package, verify_signature=True)

            futures.append(thread_pool.submit(download_and_upload, tenant_id, plugin_id, plugin_identifier))

        # Wait for all downloads to complete
        for future in futures:
            future.result()  # This will raise any exceptions that occurred

        thread_pool.shutdown(wait=True)

        success = []
        failed = []

        reverse_map = {v: k for k, v in plugin_identifiers_map.items()}

        # at most 8 plugins one batch
        for i in range(0, len(plugin_identifiers_map), 8):
            batch_plugin_ids = list(plugin_identifiers_map.keys())[i : i + 8]
            batch_plugin_identifiers = [plugin_identifiers_map[plugin_id] for plugin_id in batch_plugin_ids]

            try:
                response = manager.install_from_identifiers(
                    tenant_id=tenant_id,
                    identifiers=batch_plugin_identifiers,
                    source=PluginInstallationSource.Marketplace,
                    metas=[
                        {
                            "plugin_unique_identifier": identifier,
                        }
                        for identifier in batch_plugin_identifiers
                    ],
                )
            except Exception:
                # add to failed
                failed.extend(batch_plugin_identifiers)
                continue

            if response.all_installed:
                success.extend(batch_plugin_identifiers)
                continue

            task_id = response.task_id
            done = False
            while not done:
                status = manager.fetch_plugin_installation_task(tenant_id, task_id)
                if status.status in [PluginInstallTaskStatus.Failed, PluginInstallTaskStatus.Success]:
                    for plugin in status.plugins:
                        if plugin.status == PluginInstallTaskStatus.Success:
                            success.append(reverse_map[plugin.plugin_unique_identifier])
                        else:
                            failed.append(reverse_map[plugin.plugin_unique_identifier])
                            logger.error(
                                f"Failed to install plugin {plugin.plugin_unique_identifier}, error: {plugin.message}"
                            )

                    done = True
                else:
                    time.sleep(1)

        return {"success": success, "failed": failed}
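

# Illustrative usage sketch, not the project's real entry point (that lives in the CLI layer).
# The file names and worker count below are placeholders, and every call assumes an
# initialized Flask app context and database connection.
def _example_migration_flow() -> None:  # pragma: no cover - illustrative only
    # 1. Dump every tenant's installed plugin ids to a JSONL file.
    PluginMigration.extract_plugins("plugins.jsonl", workers=10)
    # 2. Optionally resolve and inspect the unique marketplace identifiers.
    PluginMigration.extract_unique_plugins_to_file("plugins.jsonl", "unique_plugins.json")
    # 3. Install the resolved plugins for each tenant, recording anything that failed.
    PluginMigration.install_plugins("plugins.jsonl", "install_result.json")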