|
@@ -366,6 +366,7 @@ class CompletionService:
|
|
|
generate_channel = list(pubsub.channels.keys())[0].decode('utf-8')
|
|
|
if not streaming:
|
|
|
try:
|
|
|
+ message_result = {}
|
|
|
for message in pubsub.listen():
|
|
|
if message["type"] == "message":
|
|
|
result = message["data"].decode('utf-8')
|
|
@@ -373,7 +374,10 @@ class CompletionService:
|
|
|
if result.get('error'):
|
|
|
cls.handle_error(result)
|
|
|
if result['event'] == 'message' and 'data' in result:
|
|
|
- return cls.get_message_response_data(result.get('data'))
|
|
|
+ message_result['message'] = result.get('data')
|
|
|
+ if result['event'] == 'message_end' and 'data' in result:
|
|
|
+ message_result['message_end'] = result.get('data')
|
|
|
+ return cls.get_blocking_message_response_data(message_result)
|
|
|
except ValueError as e:
|
|
|
if e.args[0] != "I/O operation on closed file.": # ignore this error
|
|
|
raise CompletionStoppedError()
|
|
@@ -399,7 +403,6 @@ class CompletionService:
|
|
|
if event == "end":
|
|
|
logging.debug("{} finished".format(generate_channel))
|
|
|
break
|
|
|
-
|
|
|
if event == 'message':
|
|
|
yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
|
|
|
elif event == 'chain':
|
|
@@ -441,6 +444,27 @@ class CompletionService:
|
|
|
|
|
|
return response_data
|
|
|
|
|
|
+ @classmethod
|
|
|
+ def get_blocking_message_response_data(cls, data: dict):
|
|
|
+ message = data.get('message')
|
|
|
+ response_data = {
|
|
|
+ 'event': 'message',
|
|
|
+ 'task_id': message.get('task_id'),
|
|
|
+ 'id': message.get('message_id'),
|
|
|
+ 'answer': message.get('text'),
|
|
|
+ 'metadata': {},
|
|
|
+ 'created_at': int(time.time())
|
|
|
+ }
|
|
|
+
|
|
|
+ if message.get('mode') == 'chat':
|
|
|
+ response_data['conversation_id'] = message.get('conversation_id')
|
|
|
+ if 'message_end' in data:
|
|
|
+ message_end = data.get('message_end')
|
|
|
+ if 'retriever_resources' in message_end:
|
|
|
+ response_data['metadata']['retriever_resources'] = message_end.get('retriever_resources')
|
|
|
+
|
|
|
+ return response_data
|
|
|
+
|
|
|
@classmethod
|
|
|
def get_message_end_data(cls, data: dict):
|
|
|
response_data = {
|