Преглед на файлове

feat: optimize error record in agent (#869)

takatost преди 1 година
родител
ревизия
2dfb3e95f6

+ 5 - 1
api/core/agent/agent/multi_dataset_router_agent.py

@@ -59,7 +59,11 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
             _, observation = intermediate_steps[-1]
             return AgentFinish(return_values={"output": observation}, log=observation)
 
-        return super().plan(intermediate_steps, callbacks, **kwargs)
+        try:
+            return super().plan(intermediate_steps, callbacks, **kwargs)
+        except Exception as e:
+            new_exception = self.model_instance.handle_exceptions(e)
+            raise new_exception
 
     async def aplan(
             self,

+ 7 - 3
api/core/agent/agent/openai_function_call.py

@@ -50,9 +50,13 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
         prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
         messages = prompt.to_messages()
 
-        predicted_message = self.llm.predict_messages(
-            messages, functions=self.functions, callbacks=None
-        )
+        try:
+            predicted_message = self.llm.predict_messages(
+                messages, functions=self.functions, callbacks=None
+            )
+        except Exception as e:
+            new_exception = self.model_instance.handle_exceptions(e)
+            raise new_exception
 
         function_call = predicted_message.additional_kwargs.get("function_call", {})
 

+ 7 - 3
api/core/agent/agent/openai_multi_function_call.py

@@ -50,9 +50,13 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope
         prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
         messages = prompt.to_messages()
 
-        predicted_message = self.llm.predict_messages(
-            messages, functions=self.functions, callbacks=None
-        )
+        try:
+            predicted_message = self.llm.predict_messages(
+                messages, functions=self.functions, callbacks=None
+            )
+        except Exception as e:
+            new_exception = self.model_instance.handle_exceptions(e)
+            raise new_exception
 
         function_call = predicted_message.additional_kwargs.get("function_call", {})
 

+ 6 - 1
api/core/agent/agent/structed_multi_dataset_router_agent.py

@@ -94,7 +94,12 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
             return AgentFinish(return_values={"output": rst}, log=rst)
 
         full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
-        full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
+
+        try:
+            full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
+        except Exception as e:
+            new_exception = self.model_instance.handle_exceptions(e)
+            raise new_exception
 
         try:
             return self.output_parser.parse(full_output)

+ 6 - 2
api/core/agent/agent/structured_chat.py

@@ -89,8 +89,8 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
             Action specifying what tool to use.
         """
         full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
-
         prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
+
         messages = []
         if prompts:
             messages = prompts[0].to_messages()
@@ -99,7 +99,11 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
         if rest_tokens < 0:
             full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
 
-        full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
+        try:
+            full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
+        except Exception as e:
+            new_exception = self.model_instance.handle_exceptions(e)
+            raise new_exception
 
         try:
             return self.output_parser.parse(full_output)

+ 2 - 2
api/core/callback_handler/agent_loop_gather_callback_handler.py

@@ -85,7 +85,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     def on_llm_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
-        logging.exception(error)
+        logging.debug("Agent on_llm_error: %s", error)
         self._agent_loops = []
         self._current_loop = None
         self._message_agent_thought = None
@@ -164,7 +164,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
         """Do nothing."""
-        logging.exception(error)
+        logging.debug("Agent on_tool_error: %s", error)
         self._agent_loops = []
         self._current_loop = None
         self._message_agent_thought = None

+ 1 - 1
api/core/callback_handler/dataset_tool_callback_handler.py

@@ -68,4 +68,4 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
         """Do nothing."""
-        logging.exception(error)
+        logging.debug("Dataset tool on_llm_error: %s", error)

+ 1 - 1
api/core/callback_handler/main_chain_gather_callback_handler.py

@@ -72,5 +72,5 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
     def on_chain_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
-        logging.exception(error)
+        logging.debug("Dataset tool on_chain_error: %s", error)
         self.clear_chain_results()