|
@@ -114,6 +114,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
|
|
self.agent_thought_count = db.session.query(MessageAgentThought).filter(
|
|
|
MessageAgentThought.message_id == self.message.id,
|
|
|
).count()
|
|
|
+ db.session.close()
|
|
|
|
|
|
# check if model supports stream tool call
|
|
|
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
|
@@ -341,13 +342,16 @@ class BaseAssistantApplicationRunner(AppRunner):
|
|
|
created_by=self.user_id,
|
|
|
)
|
|
|
db.session.add(message_file)
|
|
|
+ db.session.commit()
|
|
|
+ db.session.refresh(message_file)
|
|
|
+
|
|
|
result.append((
|
|
|
message_file,
|
|
|
message.save_as
|
|
|
))
|
|
|
-
|
|
|
- db.session.commit()
|
|
|
|
|
|
+ db.session.close()
|
|
|
+
|
|
|
return result
|
|
|
|
|
|
def create_agent_thought(self, message_id: str, message: str,
|
|
@@ -384,6 +388,8 @@ class BaseAssistantApplicationRunner(AppRunner):
|
|
|
|
|
|
db.session.add(thought)
|
|
|
db.session.commit()
|
|
|
+ db.session.refresh(thought)
|
|
|
+ db.session.close()
|
|
|
|
|
|
self.agent_thought_count += 1
|
|
|
|
|
@@ -401,6 +407,10 @@ class BaseAssistantApplicationRunner(AppRunner):
|
|
|
"""
|
|
|
Save agent thought
|
|
|
"""
|
|
|
+ agent_thought = db.session.query(MessageAgentThought).filter(
|
|
|
+ MessageAgentThought.id == agent_thought.id
|
|
|
+ ).first()
|
|
|
+
|
|
|
if thought is not None:
|
|
|
agent_thought.thought = thought
|
|
|
|
|
@@ -451,6 +461,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
|
|
agent_thought.tool_labels_str = json.dumps(labels)
|
|
|
|
|
|
db.session.commit()
|
|
|
+ db.session.close()
|
|
|
|
|
|
def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
|
|
|
"""
|
|
@@ -523,9 +534,14 @@ class BaseAssistantApplicationRunner(AppRunner):
|
|
|
"""
|
|
|
convert tool variables to db variables
|
|
|
"""
|
|
|
+ db_variables = db.session.query(ToolConversationVariables).filter(
|
|
|
+ ToolConversationVariables.conversation_id == self.message.conversation_id,
|
|
|
+ ).first()
|
|
|
+
|
|
|
db_variables.updated_at = datetime.utcnow()
|
|
|
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
|
|
|
db.session.commit()
|
|
|
+ db.session.close()
|
|
|
|
|
|
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
|
|
"""
|
|
@@ -581,4 +597,6 @@ class BaseAssistantApplicationRunner(AppRunner):
|
|
|
if message.answer:
|
|
|
result.append(AssistantPromptMessage(content=message.answer))
|
|
|
|
|
|
+ db.session.close()
|
|
|
+
|
|
|
return result
|