|
@@ -7,6 +7,8 @@ from pydantic import BaseModel, Field
|
|
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
|
from core.tools.tool.builtin_tool import BuiltinTool
|
|
|
|
|
|
+BRAVE_BASE_URL = "https://api.search.brave.com/res/v1/web/search"
|
|
|
+
|
|
|
|
|
|
class BraveSearchWrapper(BaseModel):
|
|
|
"""Wrapper around the Brave search engine."""
|
|
@@ -15,8 +17,10 @@ class BraveSearchWrapper(BaseModel):
|
|
|
"""The API key to use for the Brave search engine."""
|
|
|
search_kwargs: dict = Field(default_factory=dict)
|
|
|
"""Additional keyword arguments to pass to the search request."""
|
|
|
- base_url: str = "https://api.search.brave.com/res/v1/web/search"
|
|
|
+ base_url: str = BRAVE_BASE_URL
|
|
|
"""The base URL for the Brave search engine."""
|
|
|
+ ensure_ascii: bool = True
|
|
|
+ """Ensure the JSON output is ASCII encoded."""
|
|
|
|
|
|
def run(self, query: str) -> str:
|
|
|
"""Query the Brave search engine and return the results as a JSON string.
|
|
@@ -36,7 +40,7 @@ class BraveSearchWrapper(BaseModel):
|
|
|
}
|
|
|
for item in web_search_results
|
|
|
]
|
|
|
- return json.dumps(final_results)
|
|
|
+ return json.dumps(final_results, ensure_ascii=self.ensure_ascii)
|
|
|
|
|
|
def _search_request(self, query: str) -> list[dict]:
|
|
|
headers = {
|
|
@@ -68,7 +72,9 @@ class BraveSearch(BaseModel):
|
|
|
search_wrapper: BraveSearchWrapper
|
|
|
|
|
|
@classmethod
|
|
|
- def from_api_key(cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any) -> "BraveSearch":
|
|
|
+ def from_api_key(
|
|
|
+ cls, api_key: str, base_url: str, search_kwargs: Optional[dict] = None, ensure_ascii: bool = True, **kwargs: Any
|
|
|
+ ) -> "BraveSearch":
|
|
|
"""Create a tool from an api key.
|
|
|
|
|
|
Args:
|
|
@@ -79,7 +85,9 @@ class BraveSearch(BaseModel):
|
|
|
Returns:
|
|
|
A tool.
|
|
|
"""
|
|
|
- wrapper = BraveSearchWrapper(api_key=api_key, search_kwargs=search_kwargs or {})
|
|
|
+ wrapper = BraveSearchWrapper(
|
|
|
+ api_key=api_key, base_url=base_url, search_kwargs=search_kwargs or {}, ensure_ascii=ensure_ascii
|
|
|
+ )
|
|
|
return cls(search_wrapper=wrapper, **kwargs)
|
|
|
|
|
|
def _run(
|
|
@@ -109,11 +117,18 @@ class BraveSearchTool(BuiltinTool):
|
|
|
query = tool_parameters.get("query", "")
|
|
|
count = tool_parameters.get("count", 3)
|
|
|
api_key = self.runtime.credentials["brave_search_api_key"]
|
|
|
+ base_url = self.runtime.credentials.get("base_url", BRAVE_BASE_URL)
|
|
|
+ ensure_ascii = tool_parameters.get("ensure_ascii", True)
|
|
|
+
|
|
|
+ if len(base_url) == 0:
|
|
|
+ base_url = BRAVE_BASE_URL
|
|
|
|
|
|
if not query:
|
|
|
return self.create_text_message("Please input query")
|
|
|
|
|
|
- tool = BraveSearch.from_api_key(api_key=api_key, search_kwargs={"count": count})
|
|
|
+ tool = BraveSearch.from_api_key(
|
|
|
+ api_key=api_key, base_url=base_url, search_kwargs={"count": count}, ensure_ascii=ensure_ascii
|
|
|
+ )
|
|
|
|
|
|
results = tool._run(query)
|
|
|
|