Переглянути джерело

feat: fix dataset retrieve agent error when the LLM is not supported (#656)

John Wang 1 рік тому
батько
коміт
ba3dc8cae0

+ 0 - 1
api/core/agent/agent/multi_dataset_router_agent.py

@@ -73,7 +73,6 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
             ),
             **kwargs: Any,
     ) -> BaseSingleActionAgent:
-        llm.model_name = 'gpt-3.5-turbo'
         return super().from_llm_and_tools(
             llm=llm,
             tools=tools,

+ 2 - 1
api/core/agent/agent_executor.py

@@ -31,6 +31,7 @@ class AgentConfiguration(BaseModel):
     llm: BaseLanguageModel
     tools: list[BaseTool]
     summary_llm: BaseLanguageModel
+    dataset_llm: BaseLanguageModel
     memory: Optional[BaseChatMemory] = None
     callbacks: Callbacks = None
     max_iterations: int = 6
@@ -84,7 +85,7 @@ class AgentExecutor:
         elif self.configuration.strategy == PlanningStrategy.ROUTER:
             self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
             agent = MultiDatasetRouterAgent.from_llm_and_tools(
-                llm=self.configuration.llm,
+                llm=self.configuration.dataset_llm,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
                 verbose=True

+ 10 - 0
api/core/orchestrator_rule_parser.py

@@ -32,6 +32,7 @@ class OrchestratorRuleParser:
         self.tenant_id = tenant_id
         self.app_model_config = app_model_config
         self.agent_summary_model_name = "gpt-3.5-turbo-16k"
+        self.dataset_retrieve_model_name = "gpt-3.5-turbo"
 
     def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
                        rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \
@@ -89,11 +90,20 @@ class OrchestratorRuleParser:
             if len(tools) == 0:
                 return None
 
+            dataset_llm = LLMBuilder.to_llm(
+                tenant_id=self.tenant_id,
+                model_name=self.dataset_retrieve_model_name,
+                temperature=0,
+                max_tokens=500,
+                callbacks=[DifyStdOutCallbackHandler()]
+            )
+
             agent_configuration = AgentConfiguration(
                 strategy=planning_strategy,
                 llm=agent_llm,
                 tools=tools,
                 summary_llm=summary_llm,
+                dataset_llm=dataset_llm,
                 memory=memory,
                 callbacks=[chain_callback, agent_callback],
                 max_iterations=10,