Procházet zdrojové kódy

feat: optimize db connection when llm invoking (#2774)

takatost před 1 rokem
rodič
revize
f073dca22a

+ 4 - 0
api/core/app_runner/assistant_app_runner.py

@@ -195,6 +195,10 @@ class AssistantApplicationRunner(AppRunner):
         if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
             agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
 
+        db.session.refresh(conversation)
+        db.session.refresh(message)
+        db.session.close()
+
         # start agent runner
         if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
             assistant_cot_runner = AssistantCotApplicationRunner(

+ 2 - 0
api/core/app_runner/basic_app_runner.py

@@ -192,6 +192,8 @@ class BasicApplicationRunner(AppRunner):
             model=app_orchestration_config.model_config.model
         )
 
+        db.session.close()
+
         invoke_result = model_instance.invoke_llm(
             prompt_messages=prompt_messages,
             model_parameters=app_orchestration_config.model_config.parameters,

+ 8 - 0
api/core/app_runner/generate_task_pipeline.py

@@ -89,6 +89,10 @@ class GenerateTaskPipeline:
         Process generate task pipeline.
         :return:
         """
+        db.session.refresh(self._conversation)
+        db.session.refresh(self._message)
+        db.session.close()
+
         if stream:
             return self._process_stream_response()
         else:
@@ -303,6 +307,7 @@ class GenerateTaskPipeline:
                     .first()
                 )
                 db.session.refresh(agent_thought)
+                db.session.close()
 
                 if agent_thought:
                     response = {
@@ -330,6 +335,8 @@ class GenerateTaskPipeline:
                     .filter(MessageFile.id == event.message_file_id)
                     .first()
                 )
+                db.session.close()
+
                 # get extension
                 if '.' in message_file.url:
                     extension = f'.{message_file.url.split(".")[-1]}'
@@ -413,6 +420,7 @@ class GenerateTaskPipeline:
         usage = llm_result.usage
 
         self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
+        self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
 
         self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
         self._message.message_tokens = usage.prompt_tokens

+ 3 - 3
api/core/application_manager.py

@@ -201,7 +201,7 @@ class ApplicationManager:
                 logger.exception("Unknown Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             finally:
-                db.session.remove()
+                db.session.close()
 
     def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
                          queue_manager: ApplicationQueueManager,
@@ -233,8 +233,6 @@ class ApplicationManager:
             else:
                 logger.exception(e)
                 raise e
-        finally:
-            db.session.remove()
 
     def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
             -> AppOrchestrationConfigEntity:
@@ -651,6 +649,7 @@ class ApplicationManager:
 
             db.session.add(conversation)
             db.session.commit()
+            db.session.refresh(conversation)
         else:
             conversation = (
                 db.session.query(Conversation)
@@ -689,6 +688,7 @@ class ApplicationManager:
 
         db.session.add(message)
         db.session.commit()
+        db.session.refresh(message)
 
         for file in application_generate_entity.files:
             message_file = MessageFile(

+ 20 - 2
api/core/features/assistant_base_runner.py

@@ -114,6 +114,7 @@ class BaseAssistantApplicationRunner(AppRunner):
         self.agent_thought_count = db.session.query(MessageAgentThought).filter(
             MessageAgentThought.message_id == self.message.id,
         ).count()
+        db.session.close()
 
         # check if model supports stream tool call
         llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
@@ -341,13 +342,16 @@ class BaseAssistantApplicationRunner(AppRunner):
                 created_by=self.user_id,
             )
             db.session.add(message_file)
+            db.session.commit()
+            db.session.refresh(message_file)
+
             result.append((
                 message_file,
                 message.save_as
             ))
-            
-        db.session.commit()
 
+        db.session.close()
+            
         return result
         
     def create_agent_thought(self, message_id: str, message: str, 
@@ -384,6 +388,8 @@ class BaseAssistantApplicationRunner(AppRunner):
 
         db.session.add(thought)
         db.session.commit()
+        db.session.refresh(thought)
+        db.session.close()
 
         self.agent_thought_count += 1
 
@@ -401,6 +407,10 @@ class BaseAssistantApplicationRunner(AppRunner):
         """
         Save agent thought
         """
+        agent_thought = db.session.query(MessageAgentThought).filter(
+            MessageAgentThought.id == agent_thought.id
+        ).first()
+
         if thought is not None:
             agent_thought.thought = thought
 
@@ -451,6 +461,7 @@ class BaseAssistantApplicationRunner(AppRunner):
         agent_thought.tool_labels_str = json.dumps(labels)
 
         db.session.commit()
+        db.session.close()
     
     def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
         """
@@ -523,9 +534,14 @@ class BaseAssistantApplicationRunner(AppRunner):
         """
         convert tool variables to db variables
         """
+        db_variables = db.session.query(ToolConversationVariables).filter(
+            ToolConversationVariables.conversation_id == self.message.conversation_id,
+        ).first()
+
         db_variables.updated_at = datetime.utcnow()
         db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
         db.session.commit()
+        db.session.close()
 
     def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
         """
@@ -581,4 +597,6 @@ class BaseAssistantApplicationRunner(AppRunner):
                 if message.answer:
                     result.append(AssistantPromptMessage(content=message.answer))
 
+        db.session.close()
+
         return result