|
@@ -1,7 +1,7 @@
|
|
|
import re
|
|
|
from typing import List, Tuple, Any, Union, Sequence, Optional, cast
|
|
|
|
|
|
-from langchain import BasePromptTemplate
|
|
|
+from langchain import BasePromptTemplate, PromptTemplate
|
|
|
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
|
|
|
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
|
|
|
from langchain.callbacks.base import BaseCallbackManager
|
|
@@ -12,6 +12,7 @@ from langchain.tools import BaseTool
|
|
|
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
|
|
|
|
|
from core.chain.llm_chain import LLMChain
|
|
|
+from core.model_providers.models.entity.model_params import ModelMode
|
|
|
from core.model_providers.models.llm.base import BaseLLM
|
|
|
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
|
|
|
|
@@ -92,6 +93,10 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
|
|
rst = tool.run(tool_input={'query': kwargs['input']})
|
|
|
return AgentFinish(return_values={"output": rst}, log=rst)
|
|
|
|
|
|
+ if intermediate_steps:
|
|
|
+ _, observation = intermediate_steps[-1]
|
|
|
+ return AgentFinish(return_values={"output": observation}, log=observation)
|
|
|
+
|
|
|
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
|
|
|
|
|
|
try:
|
|
@@ -107,6 +112,8 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
|
|
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
|
|
|
tool_inputs['query'] = kwargs['input']
|
|
|
agent_decision.tool_input = tool_inputs
|
|
|
+ elif isinstance(tool_inputs, str):
|
|
|
+ agent_decision.tool_input = kwargs['input']
|
|
|
else:
|
|
|
agent_decision.return_values['output'] = ''
|
|
|
return agent_decision
|
|
@@ -143,6 +150,61 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
|
|
]
|
|
|
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
|
|
|
|
|
+ @classmethod
|
|
|
+ def create_completion_prompt(
|
|
|
+ cls,
|
|
|
+ tools: Sequence[BaseTool],
|
|
|
+ prefix: str = PREFIX,
|
|
|
+ format_instructions: str = FORMAT_INSTRUCTIONS,
|
|
|
+ input_variables: Optional[List[str]] = None,
|
|
|
+ ) -> PromptTemplate:
|
|
|
+ """Create prompt in the style of the zero shot agent.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ tools: List of tools the agent will have access to, used to format the
|
|
|
+ prompt.
|
|
|
+ prefix: String to put before the list of tools.
|
|
|
+ input_variables: List of input variables the final prompt will expect.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ A PromptTemplate with the template assembled from the pieces here.
|
|
|
+ """
|
|
|
+ suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
|
|
+Question: {input}
|
|
|
+Thought: {agent_scratchpad}
|
|
|
+"""
|
|
|
+
|
|
|
+ tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
|
|
|
+ tool_names = ", ".join([tool.name for tool in tools])
|
|
|
+ format_instructions = format_instructions.format(tool_names=tool_names)
|
|
|
+ template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
|
|
|
+ if input_variables is None:
|
|
|
+ input_variables = ["input", "agent_scratchpad"]
|
|
|
+ return PromptTemplate(template=template, input_variables=input_variables)
|
|
|
+
|
|
|
+ def _construct_scratchpad(
|
|
|
+ self, intermediate_steps: List[Tuple[AgentAction, str]]
|
|
|
+ ) -> str:
|
|
|
+ agent_scratchpad = ""
|
|
|
+ for action, observation in intermediate_steps:
|
|
|
+ agent_scratchpad += action.log
|
|
|
+ agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
|
|
|
+
|
|
|
+ if not isinstance(agent_scratchpad, str):
|
|
|
+ raise ValueError("agent_scratchpad should be of type string.")
|
|
|
+ if agent_scratchpad:
|
|
|
+ llm_chain = cast(LLMChain, self.llm_chain)
|
|
|
+ if llm_chain.model_instance.model_mode == ModelMode.CHAT:
|
|
|
+ return (
|
|
|
+ f"This was your previous work "
|
|
|
+ f"(but I haven't seen any of it! I only see what "
|
|
|
+ f"you return as final answer):\n{agent_scratchpad}"
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ return agent_scratchpad
|
|
|
+ else:
|
|
|
+ return agent_scratchpad
|
|
|
+
|
|
|
@classmethod
|
|
|
def from_llm_and_tools(
|
|
|
cls,
|
|
@@ -160,15 +222,23 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
|
|
) -> Agent:
|
|
|
"""Construct an agent from an LLM and tools."""
|
|
|
cls._validate_tools(tools)
|
|
|
- prompt = cls.create_prompt(
|
|
|
- tools,
|
|
|
- prefix=prefix,
|
|
|
- suffix=suffix,
|
|
|
- human_message_template=human_message_template,
|
|
|
- format_instructions=format_instructions,
|
|
|
- input_variables=input_variables,
|
|
|
- memory_prompts=memory_prompts,
|
|
|
- )
|
|
|
+ if model_instance.model_mode == ModelMode.CHAT:
|
|
|
+ prompt = cls.create_prompt(
|
|
|
+ tools,
|
|
|
+ prefix=prefix,
|
|
|
+ suffix=suffix,
|
|
|
+ human_message_template=human_message_template,
|
|
|
+ format_instructions=format_instructions,
|
|
|
+ input_variables=input_variables,
|
|
|
+ memory_prompts=memory_prompts,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ prompt = cls.create_completion_prompt(
|
|
|
+ tools,
|
|
|
+ prefix=prefix,
|
|
|
+ format_instructions=format_instructions,
|
|
|
+ input_variables=input_variables
|
|
|
+ )
|
|
|
llm_chain = LLMChain(
|
|
|
model_instance=model_instance,
|
|
|
prompt=prompt,
|