瀏覽代碼

chore: remove langchain in tools (#3247)

Yeuoly 1 年之前
父節點
當前提交
e635f3dc1d

+ 24 - 4
api/core/callback_handler/agent_tool_callback_handler.py

@@ -1,12 +1,32 @@
 import os
-from typing import Any, Optional, Union
+from typing import Any, Optional, TextIO, Union
 
-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.input import print_text
 from pydantic import BaseModel
 
+_TEXT_COLOR_MAPPING = {
+    "blue": "36;1",
+    "yellow": "33;1",
+    "pink": "38;5;200",
+    "green": "32;1",
+    "red": "31;1",
+}
 
-class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
+def get_colored_text(text: str, color: str) -> str:
+    """Get colored text."""
+    color_str = _TEXT_COLOR_MAPPING[color]
+    return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
+
+
+def print_text(
+    text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
+) -> None:
+    """Print text with highlighting and no end characters."""
+    text_to_print = get_colored_text(text, color) if color else text
+    print(text_to_print, end=end, file=file)
+    if file:
+        file.flush()  # ensure all printed content are written to file
+
+class DifyAgentCallbackHandler(BaseModel):
     """Callback Handler that prints to std out."""
     color: Optional[str] = ''
     current_loop = 1

+ 83 - 2
api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py

@@ -1,11 +1,92 @@
-from typing import Any
+import logging
+from typing import Any, Optional
 
-from langchain.utilities import ArxivAPIWrapper
+import arxiv
 from pydantic import BaseModel, Field
 
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
 
+logger = logging.getLogger(__name__)
+class ArxivAPIWrapper(BaseModel):
+    """Wrapper around ArxivAPI.
+
+    To use, you should have the ``arxiv`` python package installed.
+    https://lukasschwab.me/arxiv.py/index.html
+    This wrapper will use the Arxiv API to conduct searches and
+    fetch document summaries. By default, it will return the document summaries
+    of the top-k results.
+    It limits the Document content by doc_content_chars_max.
+    Set doc_content_chars_max=None if you don't want to limit the content size.
+
+    Args:
+        top_k_results: number of the top-scored document used for the arxiv tool
+        ARXIV_MAX_QUERY_LENGTH: the cut limit on the query used for the arxiv tool.
+        load_max_docs: a limit to the number of loaded documents
+        load_all_available_meta:
+            if True: the `metadata` of the loaded Documents contains all available
+            meta info (see https://lukasschwab.me/arxiv.py/index.html#Result),
+            if False: the `metadata` contains only the published date, title,
+            authors and summary.
+        doc_content_chars_max: an optional cut limit for the length of a document's
+            content
+
+    Example:
+        .. code-block:: python
+
+            arxiv = ArxivAPIWrapper(
+                top_k_results = 3,
+                ARXIV_MAX_QUERY_LENGTH = 300,
+                load_max_docs = 3,
+                load_all_available_meta = False,
+                doc_content_chars_max = 40000
+            )
+            arxiv.run("tree of thought llm)
+    """
+
+    arxiv_search = arxiv.Search  #: :meta private:
+    arxiv_exceptions = (
+        arxiv.ArxivError,
+        arxiv.UnexpectedEmptyPageError,
+        arxiv.HTTPError,
+    )  # :meta private:
+    top_k_results: int = 3
+    ARXIV_MAX_QUERY_LENGTH = 300
+    load_max_docs: int = 100
+    load_all_available_meta: bool = False
+    doc_content_chars_max: Optional[int] = 4000
+
+    def run(self, query: str) -> str:
+        """
+        Performs an arxiv search and A single string
+        with the publish date, title, authors, and summary
+        for each article separated by two newlines.
+
+        If an error occurs or no documents found, error text
+        is returned instead. Wrapper for
+        https://lukasschwab.me/arxiv.py/index.html#Search
+
+        Args:
+            query: a plaintext search query
+        """  # noqa: E501
+        try:
+            results = self.arxiv_search(  # type: ignore
+                query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
+            ).results()
+        except self.arxiv_exceptions as ex:
+            return f"Arxiv exception: {ex}"
+        docs = [
+            f"Published: {result.updated.date()}\n"
+            f"Title: {result.title}\n"
+            f"Authors: {', '.join(a.name for a in result.authors)}\n"
+            f"Summary: {result.summary}"
+            for result in results
+        ]
+        if docs:
+            return "\n\n".join(docs)[: self.doc_content_chars_max]
+        else:
+            return "No good Arxiv Result was found"
+
 
 class ArxivSearchInput(BaseModel):
     query: str = Field(..., description="Search query.")

+ 87 - 3
api/core/tools/provider/builtin/brave/tools/brave_search.py

@@ -1,11 +1,95 @@
-from typing import Any
+import json
+from typing import Any, Optional
 
-from langchain.tools import BraveSearch
+import requests
+from pydantic import BaseModel, Field
 
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
 
 
+class BraveSearchWrapper(BaseModel):
+    """Wrapper around the Brave search engine."""
+
+    api_key: str
+    """The API key to use for the Brave search engine."""
+    search_kwargs: dict = Field(default_factory=dict)
+    """Additional keyword arguments to pass to the search request."""
+    base_url = "https://api.search.brave.com/res/v1/web/search"
+    """The base URL for the Brave search engine."""
+
+    def run(self, query: str) -> str:
+        """Query the Brave search engine and return the results as a JSON string.
+
+        Args:
+            query: The query to search for.
+
+        Returns: The results as a JSON string.
+
+        """
+        web_search_results = self._search_request(query=query)
+        final_results = [
+            {
+                "title": item.get("title"),
+                "link": item.get("url"),
+                "snippet": item.get("description"),
+            }
+            for item in web_search_results
+        ]
+        return json.dumps(final_results)
+    
+    def _search_request(self, query: str) -> list[dict]:
+        headers = {
+            "X-Subscription-Token": self.api_key,
+            "Accept": "application/json",
+        }
+        req = requests.PreparedRequest()
+        params = {**self.search_kwargs, **{"q": query}}
+        req.prepare_url(self.base_url, params)
+        if req.url is None:
+            raise ValueError("prepared url is None, this should not happen")
+
+        response = requests.get(req.url, headers=headers)
+        if not response.ok:
+            raise Exception(f"HTTP error {response.status_code}")
+
+        return response.json().get("web", {}).get("results", [])
+
+class BraveSearch(BaseModel):
+    """Tool that queries the BraveSearch."""
+
+    name = "brave_search"
+    description = (
+        "a search engine. "
+        "useful for when you need to answer questions about current events."
+        " input should be a search query."
+    )
+    search_wrapper: BraveSearchWrapper
+
+    @classmethod
+    def from_api_key(
+        cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any
+    ) -> "BraveSearch":
+        """Create a tool from an api key.
+
+        Args:
+            api_key: The api key to use.
+            search_kwargs: Any additional kwargs to pass to the search wrapper.
+            **kwargs: Any additional kwargs to pass to the tool.
+
+        Returns:
+            A tool.
+        """
+        wrapper = BraveSearchWrapper(api_key=api_key, search_kwargs=search_kwargs or {})
+        return cls(search_wrapper=wrapper, **kwargs)
+
+    def _run(
+        self,
+        query: str,
+    ) -> str:
+        """Use the tool."""
+        return self.search_wrapper.run(query)
+
 class BraveSearchTool(BuiltinTool):
     """
     Tool for performing a search using Brave search engine.
@@ -31,7 +115,7 @@ class BraveSearchTool(BuiltinTool):
 
         tool = BraveSearch.from_api_key(api_key=api_key, search_kwargs={"count": count})
 
-        results = tool.run(query)
+        results = tool._run(query)
 
         if not results:
             return self.create_text_message(f"No results found for '{query}' in Tavily")

+ 135 - 4
api/core/tools/provider/builtin/duckduckgo/tools/duckduckgo_search.py

@@ -1,16 +1,147 @@
-from typing import Any
+from typing import Any, Optional
 
-from langchain.tools import DuckDuckGoSearchRun
 from pydantic import BaseModel, Field
 
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
 
 
+class DuckDuckGoSearchAPIWrapper(BaseModel):
+    """Wrapper for DuckDuckGo Search API.
+
+    Free and does not require any setup.
+    """
+
+    region: Optional[str] = "wt-wt"
+    safesearch: str = "moderate"
+    time: Optional[str] = "y"
+    max_results: int = 5
+
+    def get_snippets(self, query: str) -> list[str]:
+        """Run query through DuckDuckGo and return concatenated results."""
+        from duckduckgo_search import DDGS
+
+        with DDGS() as ddgs:
+            results = ddgs.text(
+                query,
+                region=self.region,
+                safesearch=self.safesearch,
+                timelimit=self.time,
+            )
+            if results is None:
+                return ["No good DuckDuckGo Search Result was found"]
+            snippets = []
+            for i, res in enumerate(results, 1):
+                if res is not None:
+                    snippets.append(res["body"])
+                if len(snippets) == self.max_results:
+                    break
+        return snippets
+
+    def run(self, query: str) -> str:
+        snippets = self.get_snippets(query)
+        return " ".join(snippets)
+
+    def results(
+        self, query: str, num_results: int, backend: str = "api"
+    ) -> list[dict[str, str]]:
+        """Run query through DuckDuckGo and return metadata.
+
+        Args:
+            query: The query to search for.
+            num_results: The number of results to return.
+
+        Returns:
+            A list of dictionaries with the following keys:
+                snippet - The description of the result.
+                title - The title of the result.
+                link - The link to the result.
+        """
+        from duckduckgo_search import DDGS
+
+        with DDGS() as ddgs:
+            results = ddgs.text(
+                query,
+                region=self.region,
+                safesearch=self.safesearch,
+                timelimit=self.time,
+                backend=backend,
+            )
+            if results is None:
+                return [{"Result": "No good DuckDuckGo Search Result was found"}]
+
+            def to_metadata(result: dict) -> dict[str, str]:
+                if backend == "news":
+                    return {
+                        "date": result["date"],
+                        "title": result["title"],
+                        "snippet": result["body"],
+                        "source": result["source"],
+                        "link": result["url"],
+                    }
+                return {
+                    "snippet": result["body"],
+                    "title": result["title"],
+                    "link": result["href"],
+                }
+
+            formatted_results = []
+            for i, res in enumerate(results, 1):
+                if res is not None:
+                    formatted_results.append(to_metadata(res))
+                if len(formatted_results) == num_results:
+                    break
+        return formatted_results
+
+
+class DuckDuckGoSearchRun(BaseModel):
+    """Tool that queries the DuckDuckGo search API."""
+
+    name = "duckduckgo_search"
+    description = (
+        "A wrapper around DuckDuckGo Search. "
+        "Useful for when you need to answer questions about current events. "
+        "Input should be a search query."
+    )
+    api_wrapper: DuckDuckGoSearchAPIWrapper = Field(
+        default_factory=DuckDuckGoSearchAPIWrapper
+    )
+
+    def _run(
+        self,
+        query: str,
+    ) -> str:
+        """Use the tool."""
+        return self.api_wrapper.run(query)
+
+
+class DuckDuckGoSearchResults(BaseModel):
+    """Tool that queries the DuckDuckGo search API and gets back json."""
+
+    name = "DuckDuckGo Results JSON"
+    description = (
+        "A wrapper around Duck Duck Go Search. "
+        "Useful for when you need to answer questions about current events. "
+        "Input should be a search query. Output is a JSON array of the query results"
+    )
+    num_results: int = 4
+    api_wrapper: DuckDuckGoSearchAPIWrapper = Field(
+        default_factory=DuckDuckGoSearchAPIWrapper
+    )
+    backend: str = "api"
+
+    def _run(
+        self,
+        query: str,
+    ) -> str:
+        """Use the tool."""
+        res = self.api_wrapper.results(query, self.num_results, backend=self.backend)
+        res_strs = [", ".join([f"{k}: {v}" for k, v in d.items()]) for d in res]
+        return ", ".join([f"[{rs}]" for rs in res_strs])
+
 class DuckDuckGoInput(BaseModel):
     query: str = Field(..., description="Search query.")
 
-
 class DuckDuckGoSearchTool(BuiltinTool):
     """
     Tool for performing a search using DuckDuckGo search engine.
@@ -34,7 +165,7 @@ class DuckDuckGoSearchTool(BuiltinTool):
 
         tool = DuckDuckGoSearchRun(args_schema=DuckDuckGoInput)
 
-        result = tool.run(query)
+        result = tool._run(query)
 
         return self.create_text_message(self.summary(user_id=user_id, content=result))
     

+ 174 - 3
api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py

@@ -1,16 +1,187 @@
+import json
+import time
+import urllib.error
+import urllib.parse
+import urllib.request
 from typing import Any
 
-from langchain.tools import PubmedQueryRun
 from pydantic import BaseModel, Field
 
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
 
 
+class PubMedAPIWrapper(BaseModel):
+    """
+    Wrapper around PubMed API.
+
+    This wrapper will use the PubMed API to conduct searches and fetch
+    document summaries. By default, it will return the document summaries
+    of the top-k results of an input search.
+
+    Parameters:
+        top_k_results: number of the top-scored document used for the PubMed tool
+        load_max_docs: a limit to the number of loaded documents
+        load_all_available_meta:
+          if True: the `metadata` of the loaded Documents gets all available meta info
+            (see https://www.ncbi.nlm.nih.gov/books/NBK25499/#chapter4.ESearch)
+          if False: the `metadata` gets only the most informative fields.
+    """
+
+    base_url_esearch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?"
+    base_url_efetch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?"
+    max_retry = 5
+    sleep_time = 0.2
+
+    # Default values for the parameters
+    top_k_results: int = 3
+    load_max_docs: int = 25
+    ARXIV_MAX_QUERY_LENGTH = 300
+    doc_content_chars_max: int = 2000
+    load_all_available_meta: bool = False
+    email: str = "your_email@example.com"
+
+    def run(self, query: str) -> str:
+        """
+        Run PubMed search and get the article meta information.
+        See https://www.ncbi.nlm.nih.gov/books/NBK25499/#chapter4.ESearch
+        It uses only the most informative fields of article meta information.
+        """
+
+        try:
+            # Retrieve the top-k results for the query
+            docs = [
+                f"Published: {result['pub_date']}\nTitle: {result['title']}\n"
+                f"Summary: {result['summary']}"
+                for result in self.load(query[: self.ARXIV_MAX_QUERY_LENGTH])
+            ]
+
+            # Join the results and limit the character count
+            return (
+                "\n\n".join(docs)[:self.doc_content_chars_max]
+                if docs
+                else "No good PubMed Result was found"
+            )
+        except Exception as ex:
+            return f"PubMed exception: {ex}"
+
+    def load(self, query: str) -> list[dict]:
+        """
+        Search PubMed for documents matching the query.
+        Return a list of dictionaries containing the document metadata.
+        """
+
+        url = (
+            self.base_url_esearch
+            + "db=pubmed&term="
+            + str({urllib.parse.quote(query)})
+            + f"&retmode=json&retmax={self.top_k_results}&usehistory=y"
+        )
+        result = urllib.request.urlopen(url)
+        text = result.read().decode("utf-8")
+        json_text = json.loads(text)
+
+        articles = []
+        webenv = json_text["esearchresult"]["webenv"]
+        for uid in json_text["esearchresult"]["idlist"]:
+            article = self.retrieve_article(uid, webenv)
+            articles.append(article)
+
+        # Convert the list of articles to a JSON string
+        return articles
+
+    def retrieve_article(self, uid: str, webenv: str) -> dict:
+        url = (
+            self.base_url_efetch
+            + "db=pubmed&retmode=xml&id="
+            + uid
+            + "&webenv="
+            + webenv
+        )
+
+        retry = 0
+        while True:
+            try:
+                result = urllib.request.urlopen(url)
+                break
+            except urllib.error.HTTPError as e:
+                if e.code == 429 and retry < self.max_retry:
+                    # Too Many Requests error
+                    # wait for an exponentially increasing amount of time
+                    print(
+                        f"Too Many Requests, "
+                        f"waiting for {self.sleep_time:.2f} seconds..."
+                    )
+                    time.sleep(self.sleep_time)
+                    self.sleep_time *= 2
+                    retry += 1
+                else:
+                    raise e
+
+        xml_text = result.read().decode("utf-8")
+
+        # Get title
+        title = ""
+        if "<ArticleTitle>" in xml_text and "</ArticleTitle>" in xml_text:
+            start_tag = "<ArticleTitle>"
+            end_tag = "</ArticleTitle>"
+            title = xml_text[
+                xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)
+            ]
+
+        # Get abstract
+        abstract = ""
+        if "<AbstractText>" in xml_text and "</AbstractText>" in xml_text:
+            start_tag = "<AbstractText>"
+            end_tag = "</AbstractText>"
+            abstract = xml_text[
+                xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)
+            ]
+
+        # Get publication date
+        pub_date = ""
+        if "<PubDate>" in xml_text and "</PubDate>" in xml_text:
+            start_tag = "<PubDate>"
+            end_tag = "</PubDate>"
+            pub_date = xml_text[
+                xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)
+            ]
+
+        # Return article as dictionary
+        article = {
+            "uid": uid,
+            "title": title,
+            "summary": abstract,
+            "pub_date": pub_date,
+        }
+        return article
+
+
+class PubmedQueryRun(BaseModel):
+    """Tool that searches the PubMed API."""
+
+    name = "PubMed"
+    description = (
+        "A wrapper around PubMed.org "
+        "Useful for when you need to answer questions about Physics, Mathematics, "
+        "Computer Science, Quantitative Biology, Quantitative Finance, Statistics, "
+        "Electrical Engineering, and Economics "
+        "from scientific articles on PubMed.org. "
+        "Input should be a search query."
+    )
+    api_wrapper: PubMedAPIWrapper = Field(default_factory=PubMedAPIWrapper)
+
+    def _run(
+        self,
+        query: str,
+    ) -> str:
+        """Use the Arxiv tool."""
+        return self.api_wrapper.run(query)
+
+
 class PubMedInput(BaseModel):
     query: str = Field(..., description="Search query.")
 
-
 class PubMedSearchTool(BuiltinTool):
     """
     Tool for performing a search using PubMed search engine.
@@ -34,7 +205,7 @@ class PubMedSearchTool(BuiltinTool):
 
         tool = PubmedQueryRun(args_schema=PubMedInput)
 
-        result = tool.run(query)
+        result = tool._run(query)
 
         return self.create_text_message(self.summary(user_id=user_id, content=result))
     

+ 72 - 2
api/core/tools/provider/builtin/twilio/tools/send_message.py

@@ -1,11 +1,81 @@
-from typing import Any, Union
+from typing import Any, Optional, Union
 
-from langchain.utilities import TwilioAPIWrapper
+from pydantic import BaseModel, validator
 
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
 
 
+class TwilioAPIWrapper(BaseModel):
+    """Messaging Client using Twilio.
+
+    To use, you should have the ``twilio`` python package installed,
+    and the environment variables ``TWILIO_ACCOUNT_SID``, ``TWILIO_AUTH_TOKEN``, and
+    ``TWILIO_FROM_NUMBER``, or pass `account_sid`, `auth_token`, and `from_number` as
+    named parameters to the constructor.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.utilities.twilio import TwilioAPIWrapper
+            twilio = TwilioAPIWrapper(
+                account_sid="ACxxx",
+                auth_token="xxx",
+                from_number="+10123456789"
+            )
+            twilio.run('test', '+12484345508')
+    """
+
+    client: Any  #: :meta private:
+    account_sid: Optional[str] = None
+    """Twilio account string identifier."""
+    auth_token: Optional[str] = None
+    """Twilio auth token."""
+    from_number: Optional[str] = None
+    """A Twilio phone number in [E.164](https://www.twilio.com/docs/glossary/what-e164) 
+        format, an 
+        [alphanumeric sender ID](https://www.twilio.com/docs/sms/send-messages#use-an-alphanumeric-sender-id), 
+        or a [Channel Endpoint address](https://www.twilio.com/docs/sms/channels#channel-addresses) 
+        that is enabled for the type of message you want to send. Phone numbers or 
+        [short codes](https://www.twilio.com/docs/sms/api/short-code) purchased from 
+        Twilio also work here. You cannot, for example, spoof messages from a private 
+        cell phone number. If you are using `messaging_service_sid`, this parameter 
+        must be empty.
+    """  # noqa: E501
+
+    @validator("client", pre=True, always=True)
+    def set_validator(cls, values: dict) -> dict:
+        """Validate that api key and python package exists in environment."""
+        try:
+            from twilio.rest import Client
+        except ImportError:
+            raise ImportError(
+                "Could not import twilio python package. "
+                "Please install it with `pip install twilio`."
+            )
+        account_sid = values.get("account_sid")
+        auth_token = values.get("auth_token")
+        values["from_number"] = values.get("from_number")
+        values["client"] = Client(account_sid, auth_token)
+
+        return values
+
+    def run(self, body: str, to: str) -> str:
+        """Run body through Twilio and respond with message sid.
+
+        Args:
+            body: The text of the message you want to send. Can be up to 1,600
+                characters in length.
+            to: The destination phone number in
+                [E.164](https://www.twilio.com/docs/glossary/what-e164) format for
+                SMS/MMS or
+                [Channel user address](https://www.twilio.com/docs/sms/channels#channel-addresses)
+                for other 3rd-party channels.
+        """  # noqa: E501
+        message = self.client.messages.create(to, from_=self.from_number, body=body)
+        return message.sid
+
+
 class SendMessageTool(BuiltinTool):
     """
     A tool for sending messages using Twilio API.

+ 70 - 11
api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py

@@ -1,16 +1,79 @@
-from typing import Any, Union
+from typing import Any, Optional, Union
 
-from langchain import WikipediaAPIWrapper
-from langchain.tools import WikipediaQueryRun
-from pydantic import BaseModel, Field
+import wikipedia
 
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
 
+WIKIPEDIA_MAX_QUERY_LENGTH = 300
 
-class WikipediaInput(BaseModel):
-    query: str = Field(..., description="search query.")
+class WikipediaAPIWrapper:
+    """Wrapper around WikipediaAPI.
 
+    To use, you should have the ``wikipedia`` python package installed.
+    This wrapper will use the Wikipedia API to conduct searches and
+    fetch page summaries. By default, it will return the page summaries
+    of the top-k results.
+    It limits the Document content by doc_content_chars_max.
+    """
+
+    top_k_results: int = 3
+    lang: str = "en"
+    load_all_available_meta: bool = False
+    doc_content_chars_max: int = 4000
+
+    def __init__(self, doc_content_chars_max: int = 4000):
+        self.doc_content_chars_max = doc_content_chars_max
+
+    def run(self, query: str) -> str:
+        wikipedia.set_lang(self.lang)
+        wiki_client = wikipedia
+
+        """Run Wikipedia search and get page summaries."""
+        page_titles = wiki_client.search(query[:WIKIPEDIA_MAX_QUERY_LENGTH])
+        summaries = []
+        for page_title in page_titles[: self.top_k_results]:
+            if wiki_page := self._fetch_page(page_title):
+                if summary := self._formatted_page_summary(page_title, wiki_page):
+                    summaries.append(summary)
+        if not summaries:
+            return "No good Wikipedia Search Result was found"
+        return "\n\n".join(summaries)[: self.doc_content_chars_max]
+
+    @staticmethod
+    def _formatted_page_summary(page_title: str, wiki_page: Any) -> Optional[str]:
+        return f"Page: {page_title}\nSummary: {wiki_page.summary}"
+
+    def _fetch_page(self, page: str) -> Optional[str]:
+        try:
+            return wikipedia.page(title=page, auto_suggest=False)
+        except (
+            wikipedia.exceptions.PageError,
+            wikipedia.exceptions.DisambiguationError,
+        ):
+            return None
+
+class WikipediaQueryRun:
+    """Tool that searches the Wikipedia API."""
+
+    name = "Wikipedia"
+    description = (
+        "A wrapper around Wikipedia. "
+        "Useful for when you need to answer general questions about "
+        "people, places, companies, facts, historical events, or other subjects. "
+        "Input should be a search query."
+    )
+    api_wrapper: WikipediaAPIWrapper
+
+    def __init__(self, api_wrapper: WikipediaAPIWrapper):
+        self.api_wrapper = api_wrapper
+
+    def _run(
+        self,
+        query: str,
+    ) -> str:
+        """Use the Wikipedia tool."""
+        return self.api_wrapper.run(query)
 class WikiPediaSearchTool(BuiltinTool):
     def _invoke(self, 
                 user_id: str, 
@@ -24,14 +87,10 @@ class WikiPediaSearchTool(BuiltinTool):
             return self.create_text_message('Please input query')
         
         tool = WikipediaQueryRun(
-            name="wikipedia",
             api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
-            args_schema=WikipediaInput
         )
 
-        result = tool.run(tool_input={
-            'query': query
-        })
+        result = tool._run(query)
 
         return self.create_text_message(self.summary(user_id=user_id,content=result))