瀏覽代碼

feat(tools/podcast_generator): add support for setting openai base url with the podcast_generationor tool (#10517)

Xiao Ley 5 月之前
父節點
當前提交
451ccb778d

+ 12 - 0
api/core/tools/provider/builtin/podcast_generator/podcast_generator.yaml

@@ -32,3 +32,15 @@ credentials_for_provider:
     placeholder:
       en_US: Enter your TTS service API key
       zh_Hans: 输入您的 TTS 服务 API 密钥
+  openai_base_url:
+    type: text-input
+    required: false
+    label:
+      en_US: OpenAI base URL
+      zh_Hans: OpenAI base URL
+    help:
+      en_US: Please input your OpenAI base URL
+      zh_Hans: 请输入你的 OpenAI base URL
+    placeholder:
+      en_US: Please input your OpenAI base URL
+      zh_Hans: 请输入你的 OpenAI base URL

+ 12 - 2
api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py

@@ -5,6 +5,7 @@ import warnings
 from typing import Any, Literal, Optional, Union
 
 import openai
+from yarl import URL
 
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.errors import ToolParameterValidationError, ToolProviderCredentialValidationError
@@ -53,15 +54,24 @@ class PodcastAudioGeneratorTool(BuiltinTool):
         if not host1_voice or not host2_voice:
             raise ToolParameterValidationError("Host voices are required")
 
-        # Get OpenAI API key from credentials
+        # Ensure runtime and credentials
         if not self.runtime or not self.runtime.credentials:
             raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing")
+
+        # Get OpenAI API key from credentials
         api_key = self.runtime.credentials.get("api_key")
         if not api_key:
             raise ToolProviderCredentialValidationError("OpenAI API key is missing")
 
+        # Get OpenAI base URL
+        openai_base_url = self.runtime.credentials.get("openai_base_url", None)
+        openai_base_url = str(URL(openai_base_url) / "v1") if openai_base_url else None
+
         # Initialize OpenAI client
-        client = openai.OpenAI(api_key=api_key)
+        client = openai.OpenAI(
+            api_key=api_key,
+            base_url=openai_base_url,
+        )
 
         # Create a thread pool
         max_workers = 5