Browse Source

chore: remove Langchain tools import (#3407)

Jyong 1 year ago
parent
commit
0737e930cb

+ 1 - 1
api/core/rag/extractor/blod/blod.py

@@ -159,7 +159,7 @@ class BlobLoader(ABC):
     def yield_blobs(
     def yield_blobs(
         self,
         self,
     ) -> Iterable[Blob]:
     ) -> Iterable[Blob]:
-        """A lazy loader for raw data represented by LangChain's Blob object.
+        """A lazy loader for raw data represented by Blob object.
 
 
         Returns:
         Returns:
             A generator over blobs
             A generator over blobs

+ 2 - 2
api/core/rag/retrieval/dataset_retrieval.py

@@ -2,7 +2,6 @@ import threading
 from typing import Optional, cast
 from typing import Optional, cast
 
 
 from flask import Flask, current_app
 from flask import Flask, current_app
-from langchain.tools import BaseTool
 
 
 from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
 from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
 from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
@@ -19,6 +18,7 @@ from core.rag.retrieval.router.multi_dataset_function_call_router import Functio
 from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
 from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
 from core.rerank.rerank import RerankRunner
 from core.rerank.rerank import RerankRunner
 from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
 from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
+from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
 from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
 from extensions.ext_database import db
 from extensions.ext_database import db
 from models.dataset import Dataset, DatasetQuery, DocumentSegment
 from models.dataset import Dataset, DatasetQuery, DocumentSegment
@@ -383,7 +383,7 @@ class DatasetRetrieval:
                                   return_resource: bool,
                                   return_resource: bool,
                                   invoke_from: InvokeFrom,
                                   invoke_from: InvokeFrom,
                                   hit_callback: DatasetIndexToolCallbackHandler) \
                                   hit_callback: DatasetIndexToolCallbackHandler) \
-            -> Optional[list[BaseTool]]:
+            -> Optional[list[DatasetRetrieverBaseTool]]:
         """
         """
         A dataset tool is a tool that can be used to retrieve information from a dataset
         A dataset tool is a tool that can be used to retrieve information from a dataset
         :param tenant_id: tenant id
         :param tenant_id: tenant id

+ 25 - 0
api/core/rag/retrieval/output_parser/react_output.py

@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import NamedTuple, Union
+
+
+@dataclass
+class ReactAction:
+    """A full description of an action for a ReAct agent to execute."""
+
+    tool: str
+    """The name of the Tool to execute."""
+    tool_input: Union[str, dict]
+    """The input to pass in to the Tool."""
+    log: str
+    """Additional information to log about the action."""
+
+
+class ReactFinish(NamedTuple):
+    """The final return value of a ReAct agent."""
+
+    return_values: dict
+    """Dictionary of return values."""
+    log: str
+    """Additional information to log about the return value."""

+ 7 - 11
api/core/rag/retrieval/output_parser/structured_chat.py

@@ -2,28 +2,24 @@ import json
 import re
 import re
 from typing import Union
 from typing import Union
 
 
-from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser as LCStructuredChatOutputParser
-from langchain.agents.structured_chat.output_parser import logger
-from langchain.schema import AgentAction, AgentFinish, OutputParserException
+from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish
 
 
 
 
-class StructuredChatOutputParser(LCStructuredChatOutputParser):
-    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
+class StructuredChatOutputParser:
+    def parse(self, text: str) -> Union[ReactAction, ReactFinish]:
         try:
         try:
             action_match = re.search(r"```(\w*)\n?({.*?)```", text, re.DOTALL)
             action_match = re.search(r"```(\w*)\n?({.*?)```", text, re.DOTALL)
             if action_match is not None:
             if action_match is not None:
                 response = json.loads(action_match.group(2).strip(), strict=False)
                 response = json.loads(action_match.group(2).strip(), strict=False)
                 if isinstance(response, list):
                 if isinstance(response, list):
-                    # gpt turbo frequently ignores the directive to emit a single action
-                    logger.warning("Got multiple action responses: %s", response)
                     response = response[0]
                     response = response[0]
                 if response["action"] == "Final Answer":
                 if response["action"] == "Final Answer":
-                    return AgentFinish({"output": response["action_input"]}, text)
+                    return ReactFinish({"output": response["action_input"]}, text)
                 else:
                 else:
-                    return AgentAction(
+                    return ReactAction(
                         response["action"], response.get("action_input", {}), text
                         response["action"], response.get("action_input", {}), text
                     )
                     )
             else:
             else:
-                return AgentFinish({"output": text}, text)
+                return ReactFinish({"output": text}, text)
         except Exception as e:
         except Exception as e:
-            raise OutputParserException(f"Could not parse LLM output: {text}")
+            raise ValueError(f"Could not parse LLM output: {text}")

+ 13 - 22
api/core/rag/retrieval/router/multi_dataset_react_route.py

@@ -1,20 +1,21 @@
 from collections.abc import Generator, Sequence
 from collections.abc import Generator, Sequence
-from typing import Optional, Union
-
-from langchain import PromptTemplate
-from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
-from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
-from langchain.schema import AgentAction
+from typing import Union
 
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.model_manager import ModelInstance
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
-from core.prompt.entities.advanced_prompt_entities import ChatModelMessage
+from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
+from core.rag.retrieval.output_parser.react_output import ReactAction
 from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
 from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
 from core.workflow.nodes.llm.llm_node import LLMNode
 from core.workflow.nodes.llm.llm_node import LLMNode
 
 
+PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
+
+SUFFIX = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
+Thought:"""
+
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
 The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
 The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
 Valid "action" values: "Final Answer" or {tool_names}
 Valid "action" values: "Final Answer" or {tool_names}
@@ -86,7 +87,6 @@ class ReactMultiDatasetRouter:
             tenant_id: str,
             tenant_id: str,
             prefix: str = PREFIX,
             prefix: str = PREFIX,
             suffix: str = SUFFIX,
             suffix: str = SUFFIX,
-            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
             format_instructions: str = FORMAT_INSTRUCTIONS,
             format_instructions: str = FORMAT_INSTRUCTIONS,
     ) -> Union[str, None]:
     ) -> Union[str, None]:
         if model_config.mode == "chat":
         if model_config.mode == "chat":
@@ -95,7 +95,6 @@ class ReactMultiDatasetRouter:
                 tools=tools,
                 tools=tools,
                 prefix=prefix,
                 prefix=prefix,
                 suffix=suffix,
                 suffix=suffix,
-                human_message_template=human_message_template,
                 format_instructions=format_instructions,
                 format_instructions=format_instructions,
             )
             )
         else:
         else:
@@ -103,7 +102,6 @@ class ReactMultiDatasetRouter:
                 tools=tools,
                 tools=tools,
                 prefix=prefix,
                 prefix=prefix,
                 format_instructions=format_instructions,
                 format_instructions=format_instructions,
-                input_variables=None
             )
             )
         stop = ['Observation:']
         stop = ['Observation:']
         # handle invoke result
         # handle invoke result
@@ -127,9 +125,9 @@ class ReactMultiDatasetRouter:
             tenant_id=tenant_id
             tenant_id=tenant_id
         )
         )
         output_parser = StructuredChatOutputParser()
         output_parser = StructuredChatOutputParser()
-        agent_decision = output_parser.parse(result_text)
-        if isinstance(agent_decision, AgentAction):
-            return agent_decision.tool
+        react_decision = output_parser.parse(result_text)
+        if isinstance(react_decision, ReactAction):
+            return react_decision.tool
         return None
         return None
 
 
     def _invoke_llm(self, completion_param: dict,
     def _invoke_llm(self, completion_param: dict,
@@ -139,7 +137,6 @@ class ReactMultiDatasetRouter:
                     ) -> tuple[str, LLMUsage]:
                     ) -> tuple[str, LLMUsage]:
         """
         """
             Invoke large language model
             Invoke large language model
-            :param node_data: node data
             :param model_instance: model instance
             :param model_instance: model instance
             :param prompt_messages: prompt messages
             :param prompt_messages: prompt messages
             :param stop: stop
             :param stop: stop
@@ -197,7 +194,6 @@ class ReactMultiDatasetRouter:
             tools: Sequence[PromptMessageTool],
             tools: Sequence[PromptMessageTool],
             prefix: str = PREFIX,
             prefix: str = PREFIX,
             suffix: str = SUFFIX,
             suffix: str = SUFFIX,
-            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
             format_instructions: str = FORMAT_INSTRUCTIONS,
             format_instructions: str = FORMAT_INSTRUCTIONS,
     ) -> list[ChatModelMessage]:
     ) -> list[ChatModelMessage]:
         tool_strings = []
         tool_strings = []
@@ -227,16 +223,13 @@ class ReactMultiDatasetRouter:
             tools: Sequence[PromptMessageTool],
             tools: Sequence[PromptMessageTool],
             prefix: str = PREFIX,
             prefix: str = PREFIX,
             format_instructions: str = FORMAT_INSTRUCTIONS,
             format_instructions: str = FORMAT_INSTRUCTIONS,
-            input_variables: Optional[list[str]] = None,
-    ) -> PromptTemplate:
+    ) -> CompletionModelPromptTemplate:
         """Create prompt in the style of the zero shot agent.
         """Create prompt in the style of the zero shot agent.
 
 
         Args:
         Args:
             tools: List of tools the agent will have access to, used to format the
             tools: List of tools the agent will have access to, used to format the
                 prompt.
                 prompt.
             prefix: String to put before the list of tools.
             prefix: String to put before the list of tools.
-            input_variables: List of input variables the final prompt will expect.
-
         Returns:
         Returns:
             A CompletionModelPromptTemplate with the template assembled from the pieces here.
             A CompletionModelPromptTemplate with the template assembled from the pieces here.
         """
         """
@@ -249,6 +242,4 @@ Thought: {agent_scratchpad}
         tool_names = ", ".join([tool.name for tool in tools])
         tool_names = ", ".join([tool.name for tool in tools])
         format_instructions = format_instructions.format(tool_names=tool_names)
         format_instructions = format_instructions.format(tool_names=tool_names)
         template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
         template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
-        if input_variables is None:
-            input_variables = ["input", "agent_scratchpad"]
-        return PromptTemplate(template=template, input_variables=input_variables)
+        return CompletionModelPromptTemplate(text=template)

+ 3 - 12
api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py

@@ -1,8 +1,6 @@
 import threading
 import threading
-from typing import Optional
 
 
 from flask import Flask, current_app
 from flask import Flask, current_app
-from langchain.tools import BaseTool
 from pydantic import BaseModel, Field
 from pydantic import BaseModel, Field
 
 
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
@@ -10,6 +8,7 @@ from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.entities.model_entities import ModelType
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rerank.rerank import RerankRunner
 from core.rerank.rerank import RerankRunner
+from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
 from models.dataset import Dataset, Document, DocumentSegment
 
 
@@ -29,20 +28,15 @@ class DatasetMultiRetrieverToolInput(BaseModel):
     query: str = Field(..., description="dataset multi retriever and rerank")
     query: str = Field(..., description="dataset multi retriever and rerank")
 
 
 
 
-class DatasetMultiRetrieverTool(BaseTool):
+class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
     """Tool for querying multi dataset."""
     """Tool for querying multi dataset."""
     name: str = "dataset_"
     name: str = "dataset_"
     args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
     args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
     description: str = "dataset multi retriever and rerank. "
     description: str = "dataset multi retriever and rerank. "
-    tenant_id: str
     dataset_ids: list[str]
     dataset_ids: list[str]
-    top_k: int = 2
-    score_threshold: Optional[float] = None
     reranking_provider_name: str
     reranking_provider_name: str
     reranking_model_name: str
     reranking_model_name: str
-    return_resource: bool
-    retriever_from: str
-    hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
+
 
 
     @classmethod
     @classmethod
     def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
     def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
@@ -149,9 +143,6 @@ class DatasetMultiRetrieverTool(BaseTool):
 
 
             return str("\n".join(document_context_list))
             return str("\n".join(document_context_list))
 
 
-    async def _arun(self, tool_input: str) -> str:
-        raise NotImplementedError()
-
     def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: list,
     def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: list,
                    hit_callbacks: list[DatasetIndexToolCallbackHandler]):
                    hit_callbacks: list[DatasetIndexToolCallbackHandler]):
         with flask_app.app_context():
         with flask_app.app_context():

+ 34 - 0
api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py

@@ -0,0 +1,34 @@
+from abc import abstractmethod
+from typing import Any, Optional
+
+from msal_extensions.persistence import ABC
+from pydantic import BaseModel
+
+from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+
+
+class DatasetRetrieverBaseTool(BaseModel, ABC):
+    """Tool for querying a Dataset."""
+    name: str = "dataset"
+    description: str = "use this to retrieve a dataset. "
+    tenant_id: str
+    top_k: int = 2
+    score_threshold: Optional[float] = None
+    hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
+    return_resource: bool
+    retriever_from: str
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    @abstractmethod
+    def _run(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Any:
+        """Use the tool.
+
+        Add run_manager: Optional[CallbackManagerForToolRun] = None
+        to child implementations to enable tracing.
+        """

+ 4 - 15
api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py

@@ -1,10 +1,8 @@
-from typing import Optional
 
 
-from langchain.tools import BaseTool
 from pydantic import BaseModel, Field
 from pydantic import BaseModel, Field
 
 
-from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.datasource.retrieval_service import RetrievalService
+from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
 from models.dataset import Dataset, Document, DocumentSegment
 
 
@@ -24,19 +22,13 @@ class DatasetRetrieverToolInput(BaseModel):
     query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
     query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
 
 
 
 
-class DatasetRetrieverTool(BaseTool):
+class DatasetRetrieverTool(DatasetRetrieverBaseTool):
     """Tool for querying a Dataset."""
     """Tool for querying a Dataset."""
     name: str = "dataset"
     name: str = "dataset"
     args_schema: type[BaseModel] = DatasetRetrieverToolInput
     args_schema: type[BaseModel] = DatasetRetrieverToolInput
     description: str = "use this to retrieve a dataset. "
     description: str = "use this to retrieve a dataset. "
-
-    tenant_id: str
     dataset_id: str
     dataset_id: str
-    top_k: int = 2
-    score_threshold: Optional[float] = None
-    hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
-    return_resource: bool
-    retriever_from: str
+
 
 
     @classmethod
     @classmethod
     def from_dataset(cls, dataset: Dataset, **kwargs):
     def from_dataset(cls, dataset: Dataset, **kwargs):
@@ -153,7 +145,4 @@ class DatasetRetrieverTool(BaseTool):
                     for hit_callback in self.hit_callbacks:
                     for hit_callback in self.hit_callbacks:
                         hit_callback.return_retriever_resource_info(context_list)
                         hit_callback.return_retriever_resource_info(context_list)
 
 
-            return str("\n".join(document_context_list))
-
-    async def _arun(self, tool_input: str) -> str:
-        raise NotImplementedError()
+            return str("\n".join(document_context_list))

+ 9 - 10
api/core/tools/tool/dataset_retriever_tool.py

@@ -1,7 +1,5 @@
 from typing import Any
 from typing import Any
 
 
-from langchain.tools import BaseTool
-
 from core.app.app_config.entities import DatasetRetrieveConfigEntity
 from core.app.app_config.entities import DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
@@ -14,11 +12,12 @@ from core.tools.entities.tool_entities import (
     ToolParameter,
     ToolParameter,
     ToolProviderType,
     ToolProviderType,
 )
 )
+from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from core.tools.tool.tool import Tool
 from core.tools.tool.tool import Tool
 
 
 
 
 class DatasetRetrieverTool(Tool):
 class DatasetRetrieverTool(Tool):
-    langchain_tool: BaseTool
+    retrival_tool: DatasetRetrieverBaseTool
 
 
     @staticmethod
     @staticmethod
     def get_dataset_tools(tenant_id: str,
     def get_dataset_tools(tenant_id: str,
@@ -43,7 +42,7 @@ class DatasetRetrieverTool(Tool):
         # Agent only support SINGLE mode
         # Agent only support SINGLE mode
         original_retriever_mode = retrieve_config.retrieve_strategy
         original_retriever_mode = retrieve_config.retrieve_strategy
         retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
         retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
-        langchain_tools = feature.to_dataset_retriever_tool(
+        retrival_tools = feature.to_dataset_retriever_tool(
             tenant_id=tenant_id,
             tenant_id=tenant_id,
             dataset_ids=dataset_ids,
             dataset_ids=dataset_ids,
             retrieve_config=retrieve_config,
             retrieve_config=retrieve_config,
@@ -54,17 +53,17 @@ class DatasetRetrieverTool(Tool):
         # restore retrieve strategy
         # restore retrieve strategy
         retrieve_config.retrieve_strategy = original_retriever_mode
         retrieve_config.retrieve_strategy = original_retriever_mode
 
 
-        # convert langchain tools to Tools
+        # convert retrieval tools to Tools
         tools = []
         tools = []
-        for langchain_tool in langchain_tools:
+        for retrival_tool in retrival_tools:
             tool = DatasetRetrieverTool(
             tool = DatasetRetrieverTool(
-                langchain_tool=langchain_tool,
-                identity=ToolIdentity(provider='', author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')),
+                retrival_tool=retrival_tool,
+                identity=ToolIdentity(provider='', author='', name=retrival_tool.name, label=I18nObject(en_US='', zh_Hans='')),
                 parameters=[],
                 parameters=[],
                 is_team_authorization=True,
                 is_team_authorization=True,
                 description=ToolDescription(
                 description=ToolDescription(
                     human=I18nObject(en_US='', zh_Hans=''),
                     human=I18nObject(en_US='', zh_Hans=''),
-                    llm=langchain_tool.description),
+                    llm=retrival_tool.description),
                 runtime=DatasetRetrieverTool.Runtime()
                 runtime=DatasetRetrieverTool.Runtime()
             )
             )
 
 
@@ -96,7 +95,7 @@ class DatasetRetrieverTool(Tool):
             return self.create_text_message(text='please input query')
             return self.create_text_message(text='please input query')
 
 
         # invoke dataset retriever tool
         # invoke dataset retriever tool
-        result = self.langchain_tool._run(query=query)
+        result = self.retrival_tool._run(query=query)
 
 
         return self.create_text_message(text=result)
         return self.create_text_message(text=result)