|
@@ -0,0 +1,163 @@
|
|
|
+import asyncio
|
|
|
+from dataclasses import asdict, dataclass
|
|
|
+from enum import Enum
|
|
|
+from typing import Any, Optional, Union
|
|
|
+
|
|
|
+import aiohttp
|
|
|
+from pydantic import ConfigDict
|
|
|
+
|
|
|
+from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
|
+from core.tools.errors import ToolProviderCredentialValidationError
|
|
|
+from core.tools.tool.builtin_tool import BuiltinTool
|
|
|
+
|
|
|
+
|
|
|
+class SlidesGeneratorTool(BuiltinTool):
|
|
|
+ """
|
|
|
+ Tool for generating presentations using the SlideSpeak API.
|
|
|
+ """
|
|
|
+
|
|
|
+ model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
|
+
|
|
|
+ headers: Optional[dict[str, str]] = None
|
|
|
+ base_url: Optional[str] = None
|
|
|
+ timeout: Optional[aiohttp.ClientTimeout] = None
|
|
|
+ poll_interval: Optional[int] = None
|
|
|
+
|
|
|
+ class TaskState(Enum):
|
|
|
+ FAILURE = "FAILURE"
|
|
|
+ REVOKED = "REVOKED"
|
|
|
+ SUCCESS = "SUCCESS"
|
|
|
+ PENDING = "PENDING"
|
|
|
+ RECEIVED = "RECEIVED"
|
|
|
+ STARTED = "STARTED"
|
|
|
+
|
|
|
+ @dataclass
|
|
|
+ class PresentationRequest:
|
|
|
+ plain_text: str
|
|
|
+ length: Optional[int] = None
|
|
|
+ theme: Optional[str] = None
|
|
|
+
|
|
|
+ async def _generate_presentation(
|
|
|
+ self,
|
|
|
+ session: aiohttp.ClientSession,
|
|
|
+ request: PresentationRequest,
|
|
|
+ ) -> dict[str, Any]:
|
|
|
+ """Generate a new presentation asynchronously"""
|
|
|
+ async with session.post(
|
|
|
+ f"{self.base_url}/presentation/generate",
|
|
|
+ headers=self.headers,
|
|
|
+ json=asdict(request),
|
|
|
+ timeout=self.timeout,
|
|
|
+ ) as response:
|
|
|
+ response.raise_for_status()
|
|
|
+ return await response.json()
|
|
|
+
|
|
|
+ async def _get_task_status(
|
|
|
+ self,
|
|
|
+ session: aiohttp.ClientSession,
|
|
|
+ task_id: str,
|
|
|
+ ) -> dict[str, Any]:
|
|
|
+ """Get the status of a task asynchronously"""
|
|
|
+ async with session.get(
|
|
|
+ f"{self.base_url}/task_status/{task_id}",
|
|
|
+ headers=self.headers,
|
|
|
+ timeout=self.timeout,
|
|
|
+ ) as response:
|
|
|
+ response.raise_for_status()
|
|
|
+ return await response.json()
|
|
|
+
|
|
|
+ async def _wait_for_completion(
|
|
|
+ self,
|
|
|
+ session: aiohttp.ClientSession,
|
|
|
+ task_id: str,
|
|
|
+ ) -> str:
|
|
|
+ """Wait for task completion and return download URL"""
|
|
|
+ while True:
|
|
|
+ status = await self._get_task_status(session, task_id)
|
|
|
+ task_status = self.TaskState(status["task_status"])
|
|
|
+ if task_status == self.TaskState.SUCCESS:
|
|
|
+ return status["task_result"]["url"]
|
|
|
+ if task_status in [self.TaskState.FAILURE, self.TaskState.REVOKED]:
|
|
|
+ raise Exception(f"Task failed with status: {task_status.value}")
|
|
|
+ await asyncio.sleep(self.poll_interval)
|
|
|
+
|
|
|
+ async def _generate_slides(
|
|
|
+ self,
|
|
|
+ plain_text: str,
|
|
|
+ length: Optional[int],
|
|
|
+ theme: Optional[str],
|
|
|
+ ) -> str:
|
|
|
+ """Generate slides and return the download URL"""
|
|
|
+ async with aiohttp.ClientSession() as session:
|
|
|
+ request = self.PresentationRequest(
|
|
|
+ plain_text=plain_text,
|
|
|
+ length=length,
|
|
|
+ theme=theme,
|
|
|
+ )
|
|
|
+ result = await self._generate_presentation(session, request)
|
|
|
+ task_id = result["task_id"]
|
|
|
+ download_url = await self._wait_for_completion(session, task_id)
|
|
|
+ return download_url
|
|
|
+
|
|
|
+ async def _fetch_presentation(
|
|
|
+ self,
|
|
|
+ session: aiohttp.ClientSession,
|
|
|
+ download_url: str,
|
|
|
+ ) -> bytes:
|
|
|
+ """Fetch the presentation file from the download URL"""
|
|
|
+ async with session.get(download_url, timeout=self.timeout) as response:
|
|
|
+ response.raise_for_status()
|
|
|
+ return await response.read()
|
|
|
+
|
|
|
+ def _invoke(
|
|
|
+ self,
|
|
|
+ user_id: str,
|
|
|
+ tool_parameters: dict[str, Any],
|
|
|
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
|
|
+ """Synchronous invoke method that runs asynchronous code"""
|
|
|
+
|
|
|
+ async def async_invoke():
|
|
|
+ # Extract parameters
|
|
|
+ plain_text = tool_parameters.get("plain_text", "")
|
|
|
+ length = tool_parameters.get("length")
|
|
|
+ theme = tool_parameters.get("theme")
|
|
|
+
|
|
|
+ # Ensure runtime and credentials
|
|
|
+ if not self.runtime or not self.runtime.credentials:
|
|
|
+ raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing")
|
|
|
+
|
|
|
+ # Get API key from credentials
|
|
|
+ api_key = self.runtime.credentials.get("slidespeak_api_key")
|
|
|
+ if not api_key:
|
|
|
+ raise ToolProviderCredentialValidationError("SlideSpeak API key is missing")
|
|
|
+
|
|
|
+ # Set configuration
|
|
|
+ self.headers = {
|
|
|
+ "Content-Type": "application/json",
|
|
|
+ "X-API-Key": api_key,
|
|
|
+ }
|
|
|
+ self.base_url = "https://api.slidespeak.co/api/v1"
|
|
|
+ self.timeout = aiohttp.ClientTimeout(total=30)
|
|
|
+ self.poll_interval = 2
|
|
|
+
|
|
|
+ # Run the asynchronous slide generation
|
|
|
+ try:
|
|
|
+ download_url = await self._generate_slides(plain_text, length, theme)
|
|
|
+
|
|
|
+ # Fetch the presentation file
|
|
|
+ async with aiohttp.ClientSession() as session:
|
|
|
+ presentation_bytes = await self._fetch_presentation(session, download_url)
|
|
|
+
|
|
|
+ return [
|
|
|
+ self.create_text_message("Presentation generated successfully"),
|
|
|
+ self.create_blob_message(
|
|
|
+ blob=presentation_bytes,
|
|
|
+ meta={"mime_type": "application/vnd.openxmlformats-officedocument.presentationml.presentation"},
|
|
|
+ ),
|
|
|
+ ]
|
|
|
+ except Exception as e:
|
|
|
+ return [self.create_text_message(f"An error occurred: {str(e)}")]
|
|
|
+
|
|
|
+ # Run the asynchronous code synchronously
|
|
|
+ result = asyncio.run(async_invoke())
|
|
|
+ return result
|