@@ -5,6 +5,9 @@ from collections.abc import Generator, Mapping
 from threading import Thread
 from typing import Any, Optional, Union

+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
 from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
 from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -79,8 +82,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

     _task_state: WorkflowTaskState
     _application_generate_entity: AdvancedChatAppGenerateEntity
-    _workflow: Workflow
-    _user: Union[Account, EndUser]
     _workflow_system_variables: dict[SystemVariableKey, Any]
     _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
     _conversation_name_generate_thread: Optional[Thread] = None
@@ -96,32 +97,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         stream: bool,
         dialogue_count: int,
     ) -> None:
-        """
-        Initialize AdvancedChatAppGenerateTaskPipeline.
-        :param application_generate_entity: application generate entity
-        :param workflow: workflow
-        :param queue_manager: queue manager
-        :param conversation: conversation
-        :param message: message
-        :param user: user
-        :param stream: stream
-        :param dialogue_count: dialogue count
-        """
-        super().__init__(application_generate_entity, queue_manager, user, stream)
+        super().__init__(
+            application_generate_entity=application_generate_entity,
+            queue_manager=queue_manager,
+            stream=stream,
+        )

-        if isinstance(self._user, EndUser):
-            user_id = self._user.session_id
+        if isinstance(user, EndUser):
+            self._user_id = user.session_id
+            self._created_by_role = CreatedByRole.END_USER
+        elif isinstance(user, Account):
+            self._user_id = user.id
+            self._created_by_role = CreatedByRole.ACCOUNT
         else:
-            user_id = self._user.id
+            raise NotImplementedError(f"User type not supported: {type(user)}")
+
+        self._workflow_id = workflow.id
+        self._workflow_features_dict = workflow.features_dict
+
+        self._conversation_id = conversation.id
+        self._conversation_mode = conversation.mode
+
+        self._message_id = message.id
+        self._message_created_at = int(message.created_at.timestamp())

-        self._workflow = workflow
-        self._conversation = conversation
-        self._message = message
         self._workflow_system_variables = {
             SystemVariableKey.QUERY: message.query,
             SystemVariableKey.FILES: application_generate_entity.files,
             SystemVariableKey.CONVERSATION_ID: conversation.id,
-            SystemVariableKey.USER_ID: user_id,
+            SystemVariableKey.USER_ID: self._user_id,
             SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
             SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
             SystemVariableKey.WORKFLOW_ID: workflow.id,
@@ -139,13 +143,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         Process generate task pipeline.
         :return:
         """
-        db.session.refresh(self._workflow)
-        db.session.refresh(self._user)
-        db.session.close()
-
         # start generate conversation name thread
         self._conversation_name_generate_thread = self._generate_conversation_name(
-            self._conversation, self._application_generate_entity.query
+            conversation_id=self._conversation_id, query=self._application_generate_entity.query
         )

         generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
@@ -171,12 +171,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 return ChatbotAppBlockingResponse(
                     task_id=stream_response.task_id,
                     data=ChatbotAppBlockingResponse.Data(
-                        id=self._message.id,
-                        mode=self._conversation.mode,
-                        conversation_id=self._conversation.id,
-                        message_id=self._message.id,
+                        id=self._message_id,
+                        mode=self._conversation_mode,
+                        conversation_id=self._conversation_id,
+                        message_id=self._message_id,
                         answer=self._task_state.answer,
-                        created_at=int(self._message.created_at.timestamp()),
+                        created_at=self._message_created_at,
                         **extras,
                     ),
                 )
@@ -194,9 +194,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         """
         for stream_response in generator:
             yield ChatbotAppStreamResponse(
-                conversation_id=self._conversation.id,
-                message_id=self._message.id,
-                created_at=int(self._message.created_at.timestamp()),
+                conversation_id=self._conversation_id,
+                message_id=self._message_id,
+                created_at=self._message_created_at,
                 stream_response=stream_response,
             )

@@ -214,7 +214,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         tts_publisher = None
         task_id = self._application_generate_entity.task_id
         tenant_id = self._application_generate_entity.app_config.tenant_id
-        features_dict = self._workflow.features_dict
+        features_dict = self._workflow_features_dict

         if (
             features_dict.get("text_to_speech")
@@ -274,26 +274,33 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             if isinstance(event, QueuePingEvent):
                 yield self._ping_stream_response()
             elif isinstance(event, QueueErrorEvent):
-                err = self._handle_error(event, self._message)
+                with Session(db.engine) as session:
+                    err = self._handle_error(event=event, session=session, message_id=self._message_id)
+                    session.commit()
                 yield self._error_to_stream_response(err)
                 break
             elif isinstance(event, QueueWorkflowStartedEvent):
                 # override graph runtime state
                 graph_runtime_state = event.graph_runtime_state

-                # init workflow run
-                workflow_run = self._handle_workflow_run_start()
-
-                self._refetch_message()
-                self._message.workflow_run_id = workflow_run.id
-
-                db.session.commit()
-                db.session.refresh(self._message)
-                db.session.close()
-
-                yield self._workflow_start_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
-                )
+                with Session(db.engine) as session:
+                    # init workflow run
+                    workflow_run = self._handle_workflow_run_start(
+                        session=session,
+                        workflow_id=self._workflow_id,
+                        user_id=self._user_id,
+                        created_by_role=self._created_by_role,
+                    )
+                    message = self._get_message(session=session)
+                    if not message:
+                        raise ValueError(f"Message not found: {self._message_id}")
+                    message.workflow_run_id = workflow_run.id
+                    session.commit()
+
+                    workflow_start_resp = self._workflow_start_to_stream_response(
+                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+                    )
+                yield workflow_start_resp
             elif isinstance(
                 event,
                 QueueNodeRetryEvent,
@@ -304,28 +311,28 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     workflow_run=workflow_run, event=event
                 )

-                response = self._workflow_node_retry_to_stream_response(
+                node_retry_resp = self._workflow_node_retry_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,
                 )

-                if response:
-                    yield response
+                if node_retry_resp:
+                    yield node_retry_resp
             elif isinstance(event, QueueNodeStartedEvent):
                 if not workflow_run:
                     raise ValueError("workflow run not initialized.")

                 workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)

-                response_start = self._workflow_node_start_to_stream_response(
+                node_start_resp = self._workflow_node_start_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,
                 )

-                if response_start:
-                    yield response_start
+                if node_start_resp:
+                    yield node_start_resp
             elif isinstance(event, QueueNodeSucceededEvent):
                 workflow_node_execution = self._handle_workflow_node_execution_success(event)

@@ -333,25 +340,24 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if event.node_type in [NodeType.ANSWER, NodeType.END]:
                     self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))

-                response_finish = self._workflow_node_finish_to_stream_response(
+                node_finish_resp = self._workflow_node_finish_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,
                 )

-                if response_finish:
-                    yield response_finish
+                if node_finish_resp:
+                    yield node_finish_resp
             elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
                 workflow_node_execution = self._handle_workflow_node_execution_failed(event)

-                response_finish = self._workflow_node_finish_to_stream_response(
+                node_finish_resp = self._workflow_node_finish_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,
                 )
-
-                if response:
-                    yield response
+                if node_finish_resp:
+                    yield node_finish_resp

             elif isinstance(event, QueueParallelBranchRunStartedEvent):
                 if not workflow_run:
@@ -395,20 +401,24 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not graph_runtime_state:
                     raise ValueError("workflow run not initialized.")

-                workflow_run = self._handle_workflow_run_success(
-                    workflow_run=workflow_run,
-                    start_at=graph_runtime_state.start_at,
-                    total_tokens=graph_runtime_state.total_tokens,
-                    total_steps=graph_runtime_state.node_run_steps,
-                    outputs=event.outputs,
-                    conversation_id=self._conversation.id,
-                    trace_manager=trace_manager,
-                )
+                with Session(db.engine) as session:
+                    workflow_run = self._handle_workflow_run_success(
+                        session=session,
+                        workflow_run=workflow_run,
+                        start_at=graph_runtime_state.start_at,
+                        total_tokens=graph_runtime_state.total_tokens,
+                        total_steps=graph_runtime_state.node_run_steps,
+                        outputs=event.outputs,
+                        conversation_id=self._conversation_id,
+                        trace_manager=trace_manager,
+                    )

-                yield self._workflow_finish_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
-                )
+                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+                    )
+                    session.commit()

+                yield workflow_finish_resp
                 self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
             elif isinstance(event, QueueWorkflowPartialSuccessEvent):
                 if not workflow_run:
@@ -417,21 +427,25 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")

-                workflow_run = self._handle_workflow_run_partial_success(
-                    workflow_run=workflow_run,
-                    start_at=graph_runtime_state.start_at,
-                    total_tokens=graph_runtime_state.total_tokens,
-                    total_steps=graph_runtime_state.node_run_steps,
-                    outputs=event.outputs,
-                    exceptions_count=event.exceptions_count,
-                    conversation_id=None,
-                    trace_manager=trace_manager,
-                )
+                with Session(db.engine) as session:
+                    workflow_run = self._handle_workflow_run_partial_success(
+                        session=session,
+                        workflow_run=workflow_run,
+                        start_at=graph_runtime_state.start_at,
+                        total_tokens=graph_runtime_state.total_tokens,
+                        total_steps=graph_runtime_state.node_run_steps,
+                        outputs=event.outputs,
+                        exceptions_count=event.exceptions_count,
+                        conversation_id=None,
+                        trace_manager=trace_manager,
+                    )

-                yield self._workflow_finish_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
-                )
+                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+                    )
+                    session.commit()

+                yield workflow_finish_resp
                 self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
             elif isinstance(event, QueueWorkflowFailedEvent):
                 if not workflow_run:
@@ -440,71 +454,73 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")

-                workflow_run = self._handle_workflow_run_failed(
-                    workflow_run=workflow_run,
-                    start_at=graph_runtime_state.start_at,
-                    total_tokens=graph_runtime_state.total_tokens,
-                    total_steps=graph_runtime_state.node_run_steps,
-                    status=WorkflowRunStatus.FAILED,
-                    error=event.error,
-                    conversation_id=self._conversation.id,
-                    trace_manager=trace_manager,
-                    exceptions_count=event.exceptions_count,
-                )
-
-                yield self._workflow_finish_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
-                )
-
-                err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
-                yield self._error_to_stream_response(self._handle_error(err_event, self._message))
-                break
-            elif isinstance(event, QueueStopEvent):
-                if workflow_run and graph_runtime_state:
+                with Session(db.engine) as session:
                     workflow_run = self._handle_workflow_run_failed(
+                        session=session,
                         workflow_run=workflow_run,
                         start_at=graph_runtime_state.start_at,
                         total_tokens=graph_runtime_state.total_tokens,
                         total_steps=graph_runtime_state.node_run_steps,
-                        status=WorkflowRunStatus.STOPPED,
-                        error=event.get_stop_reason(),
-                        conversation_id=self._conversation.id,
+                        status=WorkflowRunStatus.FAILED,
+                        error=event.error,
+                        conversation_id=self._conversation_id,
                         trace_manager=trace_manager,
+                        exceptions_count=event.exceptions_count,
                     )
-
-                    yield self._workflow_finish_to_stream_response(
-                        task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
-
-                # Save message
-                self._save_message(graph_runtime_state=graph_runtime_state)
+                    err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
+                    err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
+                    session.commit()
+                yield workflow_finish_resp
+                yield self._error_to_stream_response(err)
+                break
+            elif isinstance(event, QueueStopEvent):
+                if workflow_run and graph_runtime_state:
+                    with Session(db.engine) as session:
+                        workflow_run = self._handle_workflow_run_failed(
+                            session=session,
+                            workflow_run=workflow_run,
+                            start_at=graph_runtime_state.start_at,
+                            total_tokens=graph_runtime_state.total_tokens,
+                            total_steps=graph_runtime_state.node_run_steps,
+                            status=WorkflowRunStatus.STOPPED,
+                            error=event.get_stop_reason(),
+                            conversation_id=self._conversation_id,
+                            trace_manager=trace_manager,
+                        )
+
+                        workflow_finish_resp = self._workflow_finish_to_stream_response(
+                            session=session,
+                            task_id=self._application_generate_entity.task_id,
+                            workflow_run=workflow_run,
+                        )
+                        # Save message
+                        self._save_message(session=session, graph_runtime_state=graph_runtime_state)
+                        session.commit()
+                    yield workflow_finish_resp

                 yield self._message_end_to_stream_response()
                 break
             elif isinstance(event, QueueRetrieverResourcesEvent):
                 self._handle_retriever_resources(event)

-                self._refetch_message()
-
-                self._message.message_metadata = (
-                    json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
-                )
-
-                db.session.commit()
-                db.session.refresh(self._message)
-                db.session.close()
+                with Session(db.engine) as session:
+                    message = self._get_message(session=session)
+                    message.message_metadata = (
+                        json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
+                    )
+                    session.commit()
             elif isinstance(event, QueueAnnotationReplyEvent):
                 self._handle_annotation_reply(event)

-                self._refetch_message()
-
-                self._message.message_metadata = (
-                    json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
-                )
-
-                db.session.commit()
-                db.session.refresh(self._message)
-                db.session.close()
+                with Session(db.engine) as session:
+                    message = self._get_message(session=session)
+                    message.message_metadata = (
+                        json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
+                    )
+                    session.commit()
             elif isinstance(event, QueueTextChunkEvent):
                 delta_text = event.text
                 if delta_text is None:
@@ -521,7 +537,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

                 self._task_state.answer += delta_text
                 yield self._message_to_stream_response(
-                    answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector
+                    answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
                 )
             elif isinstance(event, QueueMessageReplaceEvent):
                 # published by moderation
@@ -536,7 +552,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     yield self._message_replace_to_stream_response(answer=output_moderation_answer)

                 # Save message
-                self._save_message(graph_runtime_state=graph_runtime_state)
+                with Session(db.engine) as session:
+                    self._save_message(session=session, graph_runtime_state=graph_runtime_state)
+                    session.commit()

                 yield self._message_end_to_stream_response()
             else:
@@ -549,54 +567,46 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         if self._conversation_name_generate_thread:
             self._conversation_name_generate_thread.join()

-    def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
-        self._refetch_message()
-
-        self._message.answer = self._task_state.answer
-        self._message.provider_response_latency = time.perf_counter() - self._start_at
-        self._message.message_metadata = (
+    def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
+        message = self._get_message(session=session)
+        message.answer = self._task_state.answer
+        message.provider_response_latency = time.perf_counter() - self._start_at
+        message.message_metadata = (
             json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
         )
         message_files = [
             MessageFile(
-                message_id=self._message.id,
+                message_id=message.id,
                 type=file["type"],
                 transfer_method=file["transfer_method"],
                 url=file["remote_url"],
                 belongs_to="assistant",
                 upload_file_id=file["related_id"],
                 created_by_role=CreatedByRole.ACCOUNT
-                if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
+                if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
                 else CreatedByRole.END_USER,
-                created_by=self._message.from_account_id or self._message.from_end_user_id or "",
+                created_by=message.from_account_id or message.from_end_user_id or "",
             )
             for file in self._recorded_files
        ]
-        db.session.add_all(message_files)
+        session.add_all(message_files)

         if graph_runtime_state and graph_runtime_state.llm_usage:
             usage = graph_runtime_state.llm_usage
-            self._message.message_tokens = usage.prompt_tokens
-            self._message.message_unit_price = usage.prompt_unit_price
-            self._message.message_price_unit = usage.prompt_price_unit
-            self._message.answer_tokens = usage.completion_tokens
-            self._message.answer_unit_price = usage.completion_unit_price
-            self._message.answer_price_unit = usage.completion_price_unit
-            self._message.total_price = usage.total_price
-            self._message.currency = usage.currency
-
+            message.message_tokens = usage.prompt_tokens
+            message.message_unit_price = usage.prompt_unit_price
+            message.message_price_unit = usage.prompt_price_unit
+            message.answer_tokens = usage.completion_tokens
+            message.answer_unit_price = usage.completion_unit_price
+            message.answer_price_unit = usage.completion_price_unit
+            message.total_price = usage.total_price
+            message.currency = usage.currency
             self._task_state.metadata["usage"] = jsonable_encoder(usage)
         else:
             self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage())
-
-        db.session.commit()
-
         message_was_created.send(
-            self._message,
+            message,
             application_generate_entity=self._application_generate_entity,
-            conversation=self._conversation,
-            is_first_message=self._application_generate_entity.conversation_id is None,
-            extras=self._application_generate_entity.extras,
         )

     def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
@@ -613,7 +623,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

         return MessageEndStreamResponse(
             task_id=self._application_generate_entity.task_id,
-            id=self._message.id,
+            id=self._message_id,
             files=self._recorded_files,
             metadata=extras.get("metadata", {}),
         )
@@ -641,11 +651,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

         return False

-    def _refetch_message(self) -> None:
-        """
-        Refetch message.
-        :return:
-        """
-        message = db.session.query(Message).filter(Message.id == self._message.id).first()
-        if message:
-            self._message = message
+    def _get_message(self, *, session: Session):
+        stmt = select(Message).where(Message.id == self._message_id)
+        message = session.scalar(stmt)
+        if not message:
+            raise ValueError(f"Message not found: {self._message_id}")
+        return message
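
The whole diff is one unit-of-work refactor: the pipeline stops caching detached ORM objects (`self._workflow`, `self._user`, `self._message`) across its long-lived event loop and instead keeps plain scalars (`self._message_id`, `self._workflow_id`, ...), re-fetching rows inside a short-lived `Session(db.engine)` scoped to each event. Below is a minimal, self-contained sketch of that pattern; the `Message` model and the in-memory SQLite engine are hypothetical stand-ins, not Dify's actual models or `db.engine`:

```python
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Message(Base):  # hypothetical stand-in for the real Message model
    __tablename__ = "messages"
    id: Mapped[str] = mapped_column(primary_key=True)
    answer: Mapped[str] = mapped_column(default="")


engine = create_engine("sqlite://")  # stand-in for db.engine
Base.metadata.create_all(engine)


class Pipeline:
    def __init__(self, message_id: str) -> None:
        # Hold only the scalar ID; never keep an ORM instance across sessions.
        self._message_id = message_id

    def _get_message(self, *, session: Session) -> Message:
        # Re-fetch the row inside the caller's session, as _get_message does above.
        message = session.scalar(select(Message).where(Message.id == self._message_id))
        if not message:
            raise ValueError(f"Message not found: {self._message_id}")
        return message

    def handle_event(self, answer: str) -> None:
        # One session per event: open, mutate, commit, and release the connection.
        with Session(engine) as session:
            message = self._get_message(session=session)
            message.answer = answer
            session.commit()


# Usage: seed a row, then mutate it through the pipeline.
with Session(engine) as session:
    session.add(Message(id="msg-1"))
    session.commit()

Pipeline(message_id="msg-1").handle_event(answer="hello")
```

The payoff is the one visible in the diff: no connection stays checked out between queue events, and each `with Session(...)` block commits exactly the rows it touched, so stale detached instances (and the old `db.session.refresh(...)` / `db.session.close()` bookkeeping) disappear.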