Explorar el Código

fix: agent parse result error (#676)

John Wang hace 1 año
padre
commit
626c78a690
Se han modificado 1 ficheros con 7 adiciones y 2 borrados
  1. 7 2
      api/core/agent/agent/structured_chat.py

+ 7 - 2
api/core/agent/agent/structured_chat.py

@@ -9,7 +9,7 @@ from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.callbacks.manager import Callbacks
 from langchain.memory.summary import SummarizerMixin
 from langchain.memory.summary import SummarizerMixin
 from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
 from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
-from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage
+from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException
 from langchain.tools import BaseTool
 from langchain.tools import BaseTool
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
 
 
@@ -94,7 +94,12 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
             full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
             full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
 
 
         full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
         full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
-        return self.output_parser.parse(full_output)
+
+        try:
+            return self.output_parser.parse(full_output)
+        except OutputParserException:
+            return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
+                                          "I don't know how to respond to that."}, "")
 
 
     def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
     def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
         if len(intermediate_steps) >= 2:
         if len(intermediate_steps) >= 2: