Pārlūkot izejas kodu

Feat/add blocking mode resource return (#1171)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 1 gadu atpakaļ
vecāks
revīzija
5d9070bc60
2 mainītis faili ar 27 papildinājumiem un 3 dzēšanām
  1. 1 1
      api/models/model.py
  2. 26 2
      api/services/completion_service.py

+ 1 - 1
api/models/model.py

@@ -147,7 +147,7 @@ class AppModelConfig(db.Model):
             "suggested_questions": self.suggested_questions_list,
             "suggested_questions_after_answer": self.suggested_questions_after_answer_dict,
             "speech_to_text": self.speech_to_text_dict,
-            "retriever_resource": self.retriever_resource,
+            "retriever_resource": self.retriever_resource_dict,
             "more_like_this": self.more_like_this_dict,
             "sensitive_word_avoidance": self.sensitive_word_avoidance_dict,
             "model": self.model_dict,

+ 26 - 2
api/services/completion_service.py

@@ -366,6 +366,7 @@ class CompletionService:
         generate_channel = list(pubsub.channels.keys())[0].decode('utf-8')
         if not streaming:
             try:
+                message_result = {}
                 for message in pubsub.listen():
                     if message["type"] == "message":
                         result = message["data"].decode('utf-8')
@@ -373,7 +374,10 @@ class CompletionService:
                         if result.get('error'):
                             cls.handle_error(result)
                         if result['event'] == 'message' and 'data' in result:
-                            return cls.get_message_response_data(result.get('data'))
+                            message_result['message'] = result.get('data')
+                        if result['event'] == 'message_end' and 'data' in result:
+                            message_result['message_end'] = result.get('data')
+                            return cls.get_blocking_message_response_data(message_result)
             except ValueError as e:
                 if e.args[0] != "I/O operation on closed file.":  # ignore this error
                     raise CompletionStoppedError()
@@ -399,7 +403,6 @@ class CompletionService:
                             if event == "end":
                                 logging.debug("{} finished".format(generate_channel))
                                 break
-
                             if event == 'message':
                                 yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
                             elif event == 'chain':
@@ -441,6 +444,27 @@ class CompletionService:
 
         return response_data
 
+    @classmethod
+    def get_blocking_message_response_data(cls, data: dict):
+        message = data.get('message')
+        response_data = {
+            'event': 'message',
+            'task_id': message.get('task_id'),
+            'id': message.get('message_id'),
+            'answer': message.get('text'),
+            'metadata': {},
+            'created_at': int(time.time())
+        }
+
+        if message.get('mode') == 'chat':
+            response_data['conversation_id'] = message.get('conversation_id')
+        if 'message_end' in data:
+            message_end = data.get('message_end')
+            if 'retriever_resources' in message_end:
+                response_data['metadata']['retriever_resources'] = message_end.get('retriever_resources')
+
+        return response_data
+
     @classmethod
     def get_message_end_data(cls, data: dict):
         response_data = {