|
@@ -0,0 +1,132 @@
|
|
|
+"""Base classes for LLM-powered router chains."""
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+import json
|
|
|
+from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
|
|
|
+
|
|
|
+from langchain.chains.base import Chain
|
|
|
+from pydantic import root_validator
|
|
|
+
|
|
|
+from langchain.chains import LLMChain
|
|
|
+from langchain.prompts import BasePromptTemplate
|
|
|
+from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
|
|
|
+
|
|
|
+
|
|
|
+class Route(NamedTuple):
|
|
|
+ destination: Optional[str]
|
|
|
+ next_inputs: Dict[str, Any]
|
|
|
+
|
|
|
+
|
|
|
+class LLMRouterChain(Chain):
|
|
|
+ """A router chain that uses an LLM chain to perform routing."""
|
|
|
+
|
|
|
+ llm_chain: LLMChain
|
|
|
+ """LLM chain used to perform routing"""
|
|
|
+
|
|
|
+ @root_validator()
|
|
|
+ def validate_prompt(cls, values: dict) -> dict:
|
|
|
+ prompt = values["llm_chain"].prompt
|
|
|
+ if prompt.output_parser is None:
|
|
|
+ raise ValueError(
|
|
|
+ "LLMRouterChain requires base llm_chain prompt to have an output"
|
|
|
+ " parser that converts LLM text output to a dictionary with keys"
|
|
|
+ " 'destination' and 'next_inputs'. Received a prompt with no output"
|
|
|
+ " parser."
|
|
|
+ )
|
|
|
+ return values
|
|
|
+
|
|
|
+ @property
|
|
|
+ def input_keys(self) -> List[str]:
|
|
|
+ """Will be whatever keys the LLM chain prompt expects.
|
|
|
+
|
|
|
+ :meta private:
|
|
|
+ """
|
|
|
+ return self.llm_chain.input_keys
|
|
|
+
|
|
|
+ def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
|
|
|
+ super()._validate_outputs(outputs)
|
|
|
+ if not isinstance(outputs["next_inputs"], dict):
|
|
|
+ raise ValueError
|
|
|
+
|
|
|
+ def _call(
|
|
|
+ self,
|
|
|
+ inputs: Dict[str, Any]
|
|
|
+ ) -> Dict[str, Any]:
|
|
|
+ output = cast(
|
|
|
+ Dict[str, Any],
|
|
|
+ self.llm_chain.predict_and_parse(**inputs),
|
|
|
+ )
|
|
|
+ return output
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def from_llm(
|
|
|
+ cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
|
|
|
+ ) -> LLMRouterChain:
|
|
|
+ """Convenience constructor."""
|
|
|
+ llm_chain = LLMChain(llm=llm, prompt=prompt)
|
|
|
+ return cls(llm_chain=llm_chain, **kwargs)
|
|
|
+
|
|
|
+ @property
|
|
|
+ def output_keys(self) -> List[str]:
|
|
|
+ return ["destination", "next_inputs"]
|
|
|
+
|
|
|
+ def route(self, inputs: Dict[str, Any]) -> Route:
|
|
|
+ result = self(inputs)
|
|
|
+ return Route(result["destination"], result["next_inputs"])
|
|
|
+
|
|
|
+
|
|
|
+class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
|
|
|
+ """Parser for output of router chain int he multi-prompt chain."""
|
|
|
+
|
|
|
+ default_destination: str = "DEFAULT"
|
|
|
+ next_inputs_type: Type = str
|
|
|
+ next_inputs_inner_key: str = "input"
|
|
|
+
|
|
|
+ def parse_json_markdown(self, json_string: str) -> dict:
|
|
|
+ # Remove the triple backticks if present
|
|
|
+ json_string = json_string.replace("```json", "").replace("```", "")
|
|
|
+
|
|
|
+ # Strip whitespace and newlines from the start and end
|
|
|
+ json_string = json_string.strip()
|
|
|
+
|
|
|
+ # Parse the JSON string into a Python dictionary
|
|
|
+ parsed = json.loads(json_string)
|
|
|
+
|
|
|
+ return parsed
|
|
|
+
|
|
|
+ def parse_and_check_json_markdown(self, text: str, expected_keys: List[str]) -> dict:
|
|
|
+ try:
|
|
|
+ json_obj = self.parse_json_markdown(text)
|
|
|
+ except json.JSONDecodeError as e:
|
|
|
+ raise OutputParserException(f"Got invalid JSON object. Error: {e}")
|
|
|
+ for key in expected_keys:
|
|
|
+ if key not in json_obj:
|
|
|
+ raise OutputParserException(
|
|
|
+ f"Got invalid return object. Expected key `{key}` "
|
|
|
+ f"to be present, but got {json_obj}"
|
|
|
+ )
|
|
|
+ return json_obj
|
|
|
+
|
|
|
+ def parse(self, text: str) -> Dict[str, Any]:
|
|
|
+ try:
|
|
|
+ expected_keys = ["destination", "next_inputs"]
|
|
|
+ parsed = self.parse_and_check_json_markdown(text, expected_keys)
|
|
|
+ if not isinstance(parsed["destination"], str):
|
|
|
+ raise ValueError("Expected 'destination' to be a string.")
|
|
|
+ if not isinstance(parsed["next_inputs"], self.next_inputs_type):
|
|
|
+ raise ValueError(
|
|
|
+ f"Expected 'next_inputs' to be {self.next_inputs_type}."
|
|
|
+ )
|
|
|
+ parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
|
|
|
+ if (
|
|
|
+ parsed["destination"].strip().lower()
|
|
|
+ == self.default_destination.lower()
|
|
|
+ ):
|
|
|
+ parsed["destination"] = None
|
|
|
+ else:
|
|
|
+ parsed["destination"] = parsed["destination"].strip()
|
|
|
+ return parsed
|
|
|
+ except Exception as e:
|
|
|
+ raise OutputParserException(
|
|
|
+ f"Parsing text\n{text}\n raised following error:\n{e}"
|
|
|
+ )
|