ソースを参照

fix: count down thread in completion db not commit (#1267)

takatost 1 年間 前
コミット
41d4c5b424

+ 2 - 2
api/core/conversation_message_task.py

@@ -94,7 +94,7 @@ class ConversationMessageTask:
         if not self.conversation:
             self.is_new_conversation = True
             self.conversation = Conversation(
-                app_id=self.app_model_config.app_id,
+                app_id=self.app.id,
                 app_model_config_id=self.app_model_config.id,
                 model_provider=self.provider_name,
                 model_id=self.model_name,
@@ -115,7 +115,7 @@ class ConversationMessageTask:
             db.session.commit()
 
         self.message = Message(
-            app_id=self.app_model_config.app_id,
+            app_id=self.app.id,
             model_provider=self.provider_name,
             model_id=self.model_name,
             override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,

+ 1 - 7
api/core/model_providers/models/llm/openai_model.py

@@ -106,13 +106,7 @@ class OpenAIModel(BaseLLM):
             raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")
 
         prompts = self._get_prompt_from_messages(messages)
-
-        try:
-            return self._client.generate([prompts], stop, callbacks)
-        finally:
-            thread_context = api_requestor._thread_context
-            if hasattr(thread_context, "session") and thread_context.session:
-                thread_context.session.close()
+        return self._client.generate([prompts], stop, callbacks)
 
     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """

+ 18 - 18
api/services/completion_service.py

@@ -155,7 +155,7 @@ class CompletionService:
         generate_worker_thread.start()
 
         # wait for 10 minutes to close the thread
-        cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
+        cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
 
         return cls.compact_response(pubsub, streaming)
 
@@ -210,25 +210,26 @@ class CompletionService:
                 db.session.commit()
 
     @classmethod
-    def countdown_and_close(cls, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
+    def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
         # wait for 10 minutes to close the thread
         timeout = 600
 
         def close_pubsub():
-            sleep_iterations = 0
-            while sleep_iterations < timeout and worker_thread.is_alive():
-                if sleep_iterations > 0 and sleep_iterations % 10 == 0:
-                    PubHandler.ping(user, generate_task_id)
+            with flask_app.app_context():
+                sleep_iterations = 0
+                while sleep_iterations < timeout and worker_thread.is_alive():
+                    if sleep_iterations > 0 and sleep_iterations % 10 == 0:
+                        PubHandler.ping(user, generate_task_id)
 
-                time.sleep(1)
-                sleep_iterations += 1
+                    time.sleep(1)
+                    sleep_iterations += 1
 
-            if worker_thread.is_alive():
-                PubHandler.stop(user, generate_task_id)
-                try:
-                    pubsub.close()
-                except:
-                    pass
+                if worker_thread.is_alive():
+                    PubHandler.stop(user, generate_task_id)
+                    try:
+                        pubsub.close()
+                    except:
+                        pass
 
         countdown_thread = threading.Thread(target=close_pubsub)
         countdown_thread.start()
@@ -288,7 +289,7 @@ class CompletionService:
 
         generate_worker_thread.start()
 
-        cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
+        cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
 
         return cls.compact_response(pubsub, streaming)
 
@@ -313,15 +314,14 @@ class CompletionService:
             except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
                     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
                     ModelCurrentlyNotSupportError) as e:
-                db.session.rollback()
                 PubHandler.pub_error(user, generate_task_id, e)
             except LLMAuthorizationError:
-                db.session.rollback()
                 PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
             except Exception as e:
-                db.session.rollback()
                 logging.exception("Unknown Error in completion")
                 PubHandler.pub_error(user, generate_task_id, e)
+            finally:
+                db.session.commit()
 
     @classmethod
     def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):