Sfoglia il codice sorgente

feat: support multi datasets router chain mode (#231)

John Wang 1 anno fa
parent
commit
88545184be

+ 132 - 0
api/core/chain/llm_router_chain.py

@@ -0,0 +1,132 @@
+"""Base classes for LLM-powered router chains."""
+from __future__ import annotations
+
+import json
+from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
+
+from langchain.chains.base import Chain
+from pydantic import root_validator
+
+from langchain.chains import LLMChain
+from langchain.prompts import BasePromptTemplate
+from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
+
+
+class Route(NamedTuple):
+    destination: Optional[str]
+    next_inputs: Dict[str, Any]
+
+
+class LLMRouterChain(Chain):
+    """A router chain that uses an LLM chain to perform routing."""
+
+    llm_chain: LLMChain
+    """LLM chain used to perform routing"""
+
+    @root_validator()
+    def validate_prompt(cls, values: dict) -> dict:
+        prompt = values["llm_chain"].prompt
+        if prompt.output_parser is None:
+            raise ValueError(
+                "LLMRouterChain requires base llm_chain prompt to have an output"
+                " parser that converts LLM text output to a dictionary with keys"
+                " 'destination' and 'next_inputs'. Received a prompt with no output"
+                " parser."
+            )
+        return values
+
+    @property
+    def input_keys(self) -> List[str]:
+        """Will be whatever keys the LLM chain prompt expects.
+
+        :meta private:
+        """
+        return self.llm_chain.input_keys
+
+    def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
+        super()._validate_outputs(outputs)
+        if not isinstance(outputs["next_inputs"], dict):
+            raise ValueError
+
+    def _call(
+        self,
+        inputs: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        output = cast(
+            Dict[str, Any],
+            self.llm_chain.predict_and_parse(**inputs),
+        )
+        return output
+
+    @classmethod
+    def from_llm(
+        cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
+    ) -> LLMRouterChain:
+        """Convenience constructor."""
+        llm_chain = LLMChain(llm=llm, prompt=prompt)
+        return cls(llm_chain=llm_chain, **kwargs)
+
+    @property
+    def output_keys(self) -> List[str]:
+        return ["destination", "next_inputs"]
+
+    def route(self, inputs: Dict[str, Any]) -> Route:
+        result = self(inputs)
+        return Route(result["destination"], result["next_inputs"])
+
+
+class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
+    """Parser for output of router chain int he multi-prompt chain."""
+
+    default_destination: str = "DEFAULT"
+    next_inputs_type: Type = str
+    next_inputs_inner_key: str = "input"
+
+    def parse_json_markdown(self, json_string: str) -> dict:
+        # Remove the triple backticks if present
+        json_string = json_string.replace("```json", "").replace("```", "")
+
+        # Strip whitespace and newlines from the start and end
+        json_string = json_string.strip()
+
+        # Parse the JSON string into a Python dictionary
+        parsed = json.loads(json_string)
+
+        return parsed
+
+    def parse_and_check_json_markdown(self, text: str, expected_keys: List[str]) -> dict:
+        try:
+            json_obj = self.parse_json_markdown(text)
+        except json.JSONDecodeError as e:
+            raise OutputParserException(f"Got invalid JSON object. Error: {e}")
+        for key in expected_keys:
+            if key not in json_obj:
+                raise OutputParserException(
+                    f"Got invalid return object. Expected key `{key}` "
+                    f"to be present, but got {json_obj}"
+                )
+        return json_obj
+
+    def parse(self, text: str) -> Dict[str, Any]:
+        try:
+            expected_keys = ["destination", "next_inputs"]
+            parsed = self.parse_and_check_json_markdown(text, expected_keys)
+            if not isinstance(parsed["destination"], str):
+                raise ValueError("Expected 'destination' to be a string.")
+            if not isinstance(parsed["next_inputs"], self.next_inputs_type):
+                raise ValueError(
+                    f"Expected 'next_inputs' to be {self.next_inputs_type}."
+                )
+            parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
+            if (
+                parsed["destination"].strip().lower()
+                == self.default_destination.lower()
+            ):
+                parsed["destination"] = None
+            else:
+                parsed["destination"] = parsed["destination"].strip()
+            return parsed
+        except Exception as e:
+            raise OutputParserException(
+                f"Parsing text\n{text}\n raised following error:\n{e}"
+            )

+ 22 - 30
api/core/chain/main_chain_builder.py

@@ -1,18 +1,18 @@
 from typing import Optional, List
 
-from langchain.callbacks import SharedCallbackManager
+from langchain.callbacks import SharedCallbackManager, CallbackManager
 from langchain.chains import SequentialChain
 from langchain.chains.base import Chain
 from langchain.memory.chat_memory import BaseChatMemory
 
-from core.agent.agent_builder import AgentBuilder
 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
-from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
+from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
 from core.chain.chain_builder import ChainBuilder
-from core.constant import llm_constant
+from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
 from core.conversation_message_task import ConversationMessageTask
-from core.tool.dataset_tool_builder import DatasetToolBuilder
+from extensions.ext_database import db
+from models.dataset import Dataset
 
 
 class MainChainBuilder:
@@ -31,8 +31,7 @@ class MainChainBuilder:
             tenant_id=tenant_id,
             agent_mode=agent_mode,
             memory=memory,
-            dataset_tool_callback_handler=DatasetToolCallbackHandler(conversation_message_task),
-            agent_loop_gather_callback_handler=chain_callback_handler.agent_loop_gather_callback_handler
+            conversation_message_task=conversation_message_task
         )
         chains += tool_chains
 
@@ -59,15 +58,15 @@ class MainChainBuilder:
 
     @classmethod
     def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
-                         dataset_tool_callback_handler: DatasetToolCallbackHandler,
-                         agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
+                         conversation_message_task: ConversationMessageTask):
         # agent mode
         chains = []
         if agent_mode and agent_mode.get('enabled'):
             tools = agent_mode.get('tools', [])
 
             pre_fixed_chains = []
-            agent_tools = []
+            # agent_tools = []
+            datasets = []
             for tool in tools:
                 tool_type = list(tool.keys())[0]
                 tool_config = list(tool.values())[0]
@@ -76,34 +75,27 @@ class MainChainBuilder:
                     if chain:
                         pre_fixed_chains.append(chain)
                 elif tool_type == "dataset":
-                    dataset_tool = DatasetToolBuilder.build_dataset_tool(
-                        tenant_id=tenant_id,
-                        dataset_id=tool_config.get("id"),
-                        response_mode='no_synthesizer',  # "compact"
-                        callback_handler=dataset_tool_callback_handler
-                    )
+                    # get dataset from dataset id
+                    dataset = db.session.query(Dataset).filter(
+                        Dataset.tenant_id == tenant_id,
+                        Dataset.id == tool_config.get("id")
+                    ).first()
 
-                    if dataset_tool:
-                        agent_tools.append(dataset_tool)
+                    if dataset:
+                        datasets.append(dataset)
 
             # add pre-fixed chains
             chains += pre_fixed_chains
 
-            if len(agent_tools) == 1:
+            if len(datasets) > 0:
                 # tool to chain
-                tool_chain = ChainBuilder.to_tool_chain(tool=agent_tools[0], output_key='tool_output')
-                chains.append(tool_chain)
-            elif len(agent_tools) > 1:
-                # build agent config
-                agent_chain = AgentBuilder.to_agent_chain(
+                multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
                     tenant_id=tenant_id,
-                    tools=agent_tools,
-                    memory=memory,
-                    dataset_tool_callback_handler=dataset_tool_callback_handler,
-                    agent_loop_gather_callback_handler=agent_loop_gather_callback_handler
+                    datasets=datasets,
+                    conversation_message_task=conversation_message_task,
+                    callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
                 )
-
-                chains.append(agent_chain)
+                chains.append(multi_dataset_router_chain)
 
         final_output_key = cls.get_chains_output_key(chains)
 

+ 138 - 0
api/core/chain/multi_dataset_router_chain.py

@@ -0,0 +1,138 @@
+from typing import Mapping, List, Dict, Any, Optional
+
+from langchain import LLMChain, PromptTemplate, ConversationChain
+from langchain.callbacks import CallbackManager
+from langchain.chains.base import Chain
+from langchain.schema import BaseLanguageModel
+from pydantic import Extra
+
+from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
+from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
+from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
+from core.conversation_message_task import ConversationMessageTask
+from core.llm.llm_builder import LLMBuilder
+from core.tool.dataset_tool_builder import DatasetToolBuilder
+from core.tool.llama_index_tool import EnhanceLlamaIndexTool
+from models.dataset import Dataset
+
+MULTI_PROMPT_ROUTER_TEMPLATE = """
+Given a raw text input to a language model select the model prompt best suited for \
+the input. You will be given the names of the available prompts and a description of \
+what the prompt is best suited for. You may also revise the original input if you \
+think that revising it will ultimately lead to a better response from the language \
+model.
+
+<< FORMATTING >>
+Return a markdown code snippet with a JSON object formatted to look like:
+```json
+{{{{
+    "destination": string \\ name of the prompt to use or "DEFAULT"
+    "next_inputs": string \\ a potentially modified version of the original input
+}}}}
+```
+
+REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
+it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
+REMEMBER: "next_inputs" can just be the original input if you don't think any \
+modifications are needed.
+
+<< CANDIDATE PROMPTS >>
+{destinations}
+
+<< INPUT >>
+{{input}}
+
+<< OUTPUT >>
+"""
+
+
+class MultiDatasetRouterChain(Chain):
+    """Use a single chain to route an input to one of multiple candidate chains."""
+
+    router_chain: LLMRouterChain
+    """Chain for deciding a destination chain and the input to it."""
+    dataset_tools: Mapping[str, EnhanceLlamaIndexTool]
+    """Map of name to candidate chains that inputs can be routed to."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    @property
+    def input_keys(self) -> List[str]:
+        """Will be whatever keys the router chain prompt expects.
+
+        :meta private:
+        """
+        return self.router_chain.input_keys
+
+    @property
+    def output_keys(self) -> List[str]:
+        return ["text"]
+
+    @classmethod
+    def from_datasets(
+            cls,
+            tenant_id: str,
+            datasets: List[Dataset],
+            conversation_message_task: ConversationMessageTask,
+            **kwargs: Any,
+    ):
+        """Convenience constructor for instantiating from destination prompts."""
+        llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
+        llm = LLMBuilder.to_llm(
+            tenant_id=tenant_id,
+            model_name='gpt-3.5-turbo',
+            temperature=0,
+            max_tokens=1024,
+            callback_manager=llm_callback_manager
+        )
+
+        destinations = [f"{d.id}: {d.description}" for d in datasets]
+        destinations_str = "\n".join(destinations)
+        router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
+            destinations=destinations_str
+        )
+        router_prompt = PromptTemplate(
+            template=router_template,
+            input_variables=["input"],
+            output_parser=RouterOutputParser(),
+        )
+        router_chain = LLMRouterChain.from_llm(llm, router_prompt)
+        dataset_tools = {}
+        for dataset in datasets:
+            dataset_tool = DatasetToolBuilder.build_dataset_tool(
+                dataset=dataset,
+                response_mode='no_synthesizer',  # "compact"
+                callback_handler=DatasetToolCallbackHandler(conversation_message_task)
+            )
+            dataset_tools[dataset.id] = dataset_tool
+        return cls(
+            router_chain=router_chain,
+            dataset_tools=dataset_tools,
+            **kwargs,
+        )
+
+    def _call(
+        self,
+        inputs: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        if len(self.dataset_tools) == 0:
+            return {"text": ''}
+        elif len(self.dataset_tools) == 1:
+            return {"text": next(iter(self.dataset_tools.values())).run(inputs['input'])}
+
+        route = self.router_chain.route(inputs)
+
+        if not route.destination:
+            return {"text": ''}
+        elif route.destination in self.dataset_tools:
+            return {"text": self.dataset_tools[route.destination].run(
+                route.next_inputs['input']
+            )}
+        else:
+            raise ValueError(
+                f"Received invalid destination chain name '{route.destination}'"
+            )

+ 3 - 13
api/core/tool/dataset_tool_builder.py

@@ -10,24 +10,14 @@ from core.index.keyword_table_index import KeywordTableIndex
 from core.index.vector_index import VectorIndex
 from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE
 from core.tool.llama_index_tool import EnhanceLlamaIndexTool
-from extensions.ext_database import db
 from models.dataset import Dataset
 
 
 class DatasetToolBuilder:
     @classmethod
-    def build_dataset_tool(cls, tenant_id: str, dataset_id: str,
+    def build_dataset_tool(cls, dataset: Dataset,
                            response_mode: str = "no_synthesizer",
                            callback_handler: Optional[DatasetToolCallbackHandler] = None):
-        # get dataset from dataset id
-        dataset = db.session.query(Dataset).filter(
-            Dataset.tenant_id == tenant_id,
-            Dataset.id == dataset_id
-        ).first()
-
-        if not dataset:
-            return None
-
         if dataset.indexing_technique == "economy":
             # use keyword table query
             index = KeywordTableIndex(dataset=dataset).query_index
@@ -65,7 +55,7 @@ class DatasetToolBuilder:
 
         index_tool_config = IndexToolConfig(
             index=index,
-            name=f"dataset-{dataset_id}",
+            name=f"dataset-{dataset.id}",
             description=description,
             index_query_kwargs=query_kwargs,
             tool_kwargs={
@@ -75,7 +65,7 @@ class DatasetToolBuilder:
             # return_direct: Whether to return LLM results directly or process the output data with an Output Parser
         )
 
-        index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset_id)
+        index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset.id)
 
         return EnhanceLlamaIndexTool.from_tool_config(
             tool_config=index_tool_config,