소스 검색

chore: remove Langchain tools import (#3407)

Jyong 1 년 전
부모
커밋
0737e930cb

+ 1 - 1
api/core/rag/extractor/blod/blod.py

@@ -159,7 +159,7 @@ class BlobLoader(ABC):
     def yield_blobs(
         self,
     ) -> Iterable[Blob]:
-        """A lazy loader for raw data represented by LangChain's Blob object.
+        """A lazy loader for raw data represented by Blob object.
 
         Returns:
             A generator over blobs

+ 2 - 2
api/core/rag/retrieval/dataset_retrieval.py

@@ -2,7 +2,6 @@ import threading
 from typing import Optional, cast
 
 from flask import Flask, current_app
-from langchain.tools import BaseTool
 
 from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
@@ -19,6 +18,7 @@ from core.rag.retrieval.router.multi_dataset_function_call_router import Functio
 from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
 from core.rerank.rerank import RerankRunner
 from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
+from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
 from extensions.ext_database import db
 from models.dataset import Dataset, DatasetQuery, DocumentSegment
@@ -383,7 +383,7 @@ class DatasetRetrieval:
                                   return_resource: bool,
                                   invoke_from: InvokeFrom,
                                   hit_callback: DatasetIndexToolCallbackHandler) \
-            -> Optional[list[BaseTool]]:
+            -> Optional[list[DatasetRetrieverBaseTool]]:
         """
         A dataset tool is a tool that can be used to retrieve information from a dataset
         :param tenant_id: tenant id

+ 25 - 0
api/core/rag/retrieval/output_parser/react_output.py

@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import NamedTuple, Union
+
+
+@dataclass
+class ReactAction:
+    """A full description of an action for an ReactAction to execute."""
+
+    tool: str
+    """The name of the Tool to execute."""
+    tool_input: Union[str, dict]
+    """The input to pass in to the Tool."""
+    log: str
+    """Additional information to log about the action."""
+
+
+class ReactFinish(NamedTuple):
+    """The final return value of an ReactFinish."""
+
+    return_values: dict
+    """Dictionary of return values."""
+    log: str
+    """Additional information to log about the return value"""

+ 7 - 11
api/core/rag/retrieval/output_parser/structured_chat.py

@@ -2,28 +2,24 @@ import json
 import re
 from typing import Union
 
-from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser as LCStructuredChatOutputParser
-from langchain.agents.structured_chat.output_parser import logger
-from langchain.schema import AgentAction, AgentFinish, OutputParserException
+from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish
 
 
-class StructuredChatOutputParser(LCStructuredChatOutputParser):
-    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
+class StructuredChatOutputParser:
+    def parse(self, text: str) -> Union[ReactAction, ReactFinish]:
         try:
             action_match = re.search(r"```(\w*)\n?({.*?)```", text, re.DOTALL)
             if action_match is not None:
                 response = json.loads(action_match.group(2).strip(), strict=False)
                 if isinstance(response, list):
-                    # gpt turbo frequently ignores the directive to emit a single action
-                    logger.warning("Got multiple action responses: %s", response)
                     response = response[0]
                 if response["action"] == "Final Answer":
-                    return AgentFinish({"output": response["action_input"]}, text)
+                    return ReactFinish({"output": response["action_input"]}, text)
                 else:
-                    return AgentAction(
+                    return ReactAction(
                         response["action"], response.get("action_input", {}), text
                     )
             else:
-                return AgentFinish({"output": text}, text)
+                return ReactFinish({"output": text}, text)
         except Exception as e:
-            raise OutputParserException(f"Could not parse LLM output: {text}")
+            raise ValueError(f"Could not parse LLM output: {text}")

+ 13 - 22
api/core/rag/retrieval/router/multi_dataset_react_route.py

@@ -1,20 +1,21 @@
 from collections.abc import Generator, Sequence
-from typing import Optional, Union
-
-from langchain import PromptTemplate
-from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
-from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
-from langchain.schema import AgentAction
+from typing import Union
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
-from core.prompt.entities.advanced_prompt_entities import ChatModelMessage
+from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
+from core.rag.retrieval.output_parser.react_output import ReactAction
 from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
 from core.workflow.nodes.llm.llm_node import LLMNode
 
+PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
+
+SUFFIX = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
+Thought:"""
+
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
 The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
 Valid "action" values: "Final Answer" or {tool_names}
@@ -86,7 +87,6 @@ class ReactMultiDatasetRouter:
             tenant_id: str,
             prefix: str = PREFIX,
             suffix: str = SUFFIX,
-            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
             format_instructions: str = FORMAT_INSTRUCTIONS,
     ) -> Union[str, None]:
         if model_config.mode == "chat":
@@ -95,7 +95,6 @@ class ReactMultiDatasetRouter:
                 tools=tools,
                 prefix=prefix,
                 suffix=suffix,
-                human_message_template=human_message_template,
                 format_instructions=format_instructions,
             )
         else:
@@ -103,7 +102,6 @@ class ReactMultiDatasetRouter:
                 tools=tools,
                 prefix=prefix,
                 format_instructions=format_instructions,
-                input_variables=None
             )
         stop = ['Observation:']
         # handle invoke result
@@ -127,9 +125,9 @@ class ReactMultiDatasetRouter:
             tenant_id=tenant_id
         )
         output_parser = StructuredChatOutputParser()
-        agent_decision = output_parser.parse(result_text)
-        if isinstance(agent_decision, AgentAction):
-            return agent_decision.tool
+        react_decision = output_parser.parse(result_text)
+        if isinstance(react_decision, ReactAction):
+            return react_decision.tool
         return None
 
     def _invoke_llm(self, completion_param: dict,
@@ -139,7 +137,6 @@ class ReactMultiDatasetRouter:
                     ) -> tuple[str, LLMUsage]:
         """
             Invoke large language model
-            :param node_data: node data
             :param model_instance: model instance
             :param prompt_messages: prompt messages
             :param stop: stop
@@ -197,7 +194,6 @@ class ReactMultiDatasetRouter:
             tools: Sequence[PromptMessageTool],
             prefix: str = PREFIX,
             suffix: str = SUFFIX,
-            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
             format_instructions: str = FORMAT_INSTRUCTIONS,
     ) -> list[ChatModelMessage]:
         tool_strings = []
@@ -227,16 +223,13 @@ class ReactMultiDatasetRouter:
             tools: Sequence[PromptMessageTool],
             prefix: str = PREFIX,
             format_instructions: str = FORMAT_INSTRUCTIONS,
-            input_variables: Optional[list[str]] = None,
-    ) -> PromptTemplate:
+    ) -> CompletionModelPromptTemplate:
         """Create prompt in the style of the zero shot agent.
 
         Args:
             tools: List of tools the agent will have access to, used to format the
                 prompt.
             prefix: String to put before the list of tools.
-            input_variables: List of input variables the final prompt will expect.
-
         Returns:
             A PromptTemplate with the template assembled from the pieces here.
         """
@@ -249,6 +242,4 @@ Thought: {agent_scratchpad}
         tool_names = ", ".join([tool.name for tool in tools])
         format_instructions = format_instructions.format(tool_names=tool_names)
         template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
-        if input_variables is None:
-            input_variables = ["input", "agent_scratchpad"]
-        return PromptTemplate(template=template, input_variables=input_variables)
+        return CompletionModelPromptTemplate(text=template)

+ 3 - 12
api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py

@@ -1,8 +1,6 @@
 import threading
-from typing import Optional
 
 from flask import Flask, current_app
-from langchain.tools import BaseTool
 from pydantic import BaseModel, Field
 
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
@@ -10,6 +8,7 @@ from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rerank.rerank import RerankRunner
+from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
 
@@ -29,20 +28,15 @@ class DatasetMultiRetrieverToolInput(BaseModel):
     query: str = Field(..., description="dataset multi retriever and rerank")
 
 
-class DatasetMultiRetrieverTool(BaseTool):
+class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
     """Tool for querying multi dataset."""
     name: str = "dataset_"
     args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
     description: str = "dataset multi retriever and rerank. "
-    tenant_id: str
     dataset_ids: list[str]
-    top_k: int = 2
-    score_threshold: Optional[float] = None
     reranking_provider_name: str
     reranking_model_name: str
-    return_resource: bool
-    retriever_from: str
-    hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
+
 
     @classmethod
     def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
@@ -149,9 +143,6 @@ class DatasetMultiRetrieverTool(BaseTool):
 
             return str("\n".join(document_context_list))
 
-    async def _arun(self, tool_input: str) -> str:
-        raise NotImplementedError()
-
     def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: list,
                    hit_callbacks: list[DatasetIndexToolCallbackHandler]):
         with flask_app.app_context():

+ 34 - 0
api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py

@@ -0,0 +1,34 @@
+from abc import abstractmethod
+from typing import Any, Optional
+
+from msal_extensions.persistence import ABC
+from pydantic import BaseModel
+
+from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+
+
+class DatasetRetrieverBaseTool(BaseModel, ABC):
+    """Tool for querying a Dataset."""
+    name: str = "dataset"
+    description: str = "use this to retrieve a dataset. "
+    tenant_id: str
+    top_k: int = 2
+    score_threshold: Optional[float] = None
+    hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
+    return_resource: bool
+    retriever_from: str
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    @abstractmethod
+    def _run(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Any:
+        """Use the tool.
+
+        Add run_manager: Optional[CallbackManagerForToolRun] = None
+        to child implementations to enable tracing,
+        """

+ 4 - 15
api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py

@@ -1,10 +1,8 @@
-from typing import Optional
 
-from langchain.tools import BaseTool
 from pydantic import BaseModel, Field
 
-from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.rag.datasource.retrieval_service import RetrievalService
+from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
 
@@ -24,19 +22,13 @@ class DatasetRetrieverToolInput(BaseModel):
     query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
 
 
-class DatasetRetrieverTool(BaseTool):
+class DatasetRetrieverTool(DatasetRetrieverBaseTool):
     """Tool for querying a Dataset."""
     name: str = "dataset"
     args_schema: type[BaseModel] = DatasetRetrieverToolInput
     description: str = "use this to retrieve a dataset. "
-
-    tenant_id: str
     dataset_id: str
-    top_k: int = 2
-    score_threshold: Optional[float] = None
-    hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
-    return_resource: bool
-    retriever_from: str
+
 
     @classmethod
     def from_dataset(cls, dataset: Dataset, **kwargs):
@@ -153,7 +145,4 @@ class DatasetRetrieverTool(BaseTool):
                     for hit_callback in self.hit_callbacks:
                         hit_callback.return_retriever_resource_info(context_list)
 
-            return str("\n".join(document_context_list))
-
-    async def _arun(self, tool_input: str) -> str:
-        raise NotImplementedError()
+            return str("\n".join(document_context_list))

+ 9 - 10
api/core/tools/tool/dataset_retriever_tool.py

@@ -1,7 +1,5 @@
 from typing import Any
 
-from langchain.tools import BaseTool
-
 from core.app.app_config.entities import DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
@@ -14,11 +12,12 @@ from core.tools.entities.tool_entities import (
     ToolParameter,
     ToolProviderType,
 )
+from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from core.tools.tool.tool import Tool
 
 
 class DatasetRetrieverTool(Tool):
-    langchain_tool: BaseTool
+    retrival_tool: DatasetRetrieverBaseTool
 
     @staticmethod
     def get_dataset_tools(tenant_id: str,
@@ -43,7 +42,7 @@ class DatasetRetrieverTool(Tool):
         # Agent only support SINGLE mode
         original_retriever_mode = retrieve_config.retrieve_strategy
         retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
-        langchain_tools = feature.to_dataset_retriever_tool(
+        retrival_tools = feature.to_dataset_retriever_tool(
             tenant_id=tenant_id,
             dataset_ids=dataset_ids,
             retrieve_config=retrieve_config,
@@ -54,17 +53,17 @@ class DatasetRetrieverTool(Tool):
         # restore retrieve strategy
         retrieve_config.retrieve_strategy = original_retriever_mode
 
-        # convert langchain tools to Tools
+        # convert retrival tools to Tools
         tools = []
-        for langchain_tool in langchain_tools:
+        for retrival_tool in retrival_tools:
             tool = DatasetRetrieverTool(
-                langchain_tool=langchain_tool,
-                identity=ToolIdentity(provider='', author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')),
+                retrival_tool=retrival_tool,
+                identity=ToolIdentity(provider='', author='', name=retrival_tool.name, label=I18nObject(en_US='', zh_Hans='')),
                 parameters=[],
                 is_team_authorization=True,
                 description=ToolDescription(
                     human=I18nObject(en_US='', zh_Hans=''),
-                    llm=langchain_tool.description),
+                    llm=retrival_tool.description),
                 runtime=DatasetRetrieverTool.Runtime()
             )
 
@@ -96,7 +95,7 @@ class DatasetRetrieverTool(Tool):
             return self.create_text_message(text='please input query')
 
         # invoke dataset retriever tool
-        result = self.langchain_tool._run(query=query)
+        result = self.retrival_tool._run(query=query)
 
         return self.create_text_message(text=result)