Jelajahi Sumber

fix: detached model in completion thread (#1269)

takatost 1 tahun lalu
induk
melakukan
373e90ee6d

+ 0 - 2
api/core/model_providers/models/llm/base.py

@@ -132,8 +132,6 @@ class BaseLLM(BaseProviderModel):
         if self.deduct_quota:
             self.model_provider.check_quota_over_limit()
 
-        db.session.commit()
-
         if not callbacks:
             callbacks = self.callbacks
         else:

+ 30 - 19
api/services/completion_service.py

@@ -3,7 +3,7 @@ import logging
 import threading
 import time
 import uuid
-from typing import Generator, Union, Any
+from typing import Generator, Union, Any, Optional
 
 from flask import current_app, Flask
 from redis.client import PubSub
@@ -141,12 +141,12 @@ class CompletionService:
         generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
             'flask_app': current_app._get_current_object(),
             'generate_task_id': generate_task_id,
-            'app_model': app_model,
+            'detached_app_model': app_model,
             'app_model_config': app_model_config,
             'query': query,
             'inputs': inputs,
-            'user': user,
-            'conversation': conversation,
+            'detached_user': user,
+            'detached_conversation': conversation,
             'streaming': streaming,
             'is_model_config_override': is_model_config_override,
             'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev'
@@ -171,18 +171,22 @@ class CompletionService:
         return user
 
     @classmethod
-    def generate_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App, app_model_config: AppModelConfig,
-                        query: str, inputs: dict, user: Union[Account, EndUser],
-                        conversation: Conversation, streaming: bool, is_model_config_override: bool,
+    def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App, app_model_config: AppModelConfig,
+                        query: str, inputs: dict, detached_user: Union[Account, EndUser],
+                        detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
                         retriever_from: str = 'dev'):
         with flask_app.app_context():
-            try:
-                if conversation:
-                    # fixed the state of the conversation object when it detached from the original session
-                    conversation = db.session.query(Conversation).filter_by(id=conversation.id).first()
+            # fixed the state of the model object when it detached from the original session
+            user = db.session.merge(detached_user)
+            app_model = db.session.merge(detached_app_model)
 
-                # run
+            if detached_conversation:
+                conversation = db.session.merge(detached_conversation)
+            else:
+                conversation = None
 
+            try:
+                # run
                 Completion.generate(
                     task_id=generate_task_id,
                     app=app_model,
@@ -210,12 +214,14 @@ class CompletionService:
                 db.session.commit()
 
     @classmethod
-    def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
+    def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user, generate_task_id) -> threading.Thread:
         # wait for 10 minutes to close the thread
         timeout = 600
 
         def close_pubsub():
             with flask_app.app_context():
+                user = db.session.merge(detached_user)
+
                 sleep_iterations = 0
                 while sleep_iterations < timeout and worker_thread.is_alive():
                     if sleep_iterations > 0 and sleep_iterations % 10 == 0:
@@ -279,11 +285,11 @@ class CompletionService:
         generate_worker_thread = threading.Thread(target=cls.generate_more_like_this_worker, kwargs={
             'flask_app': current_app._get_current_object(),
             'generate_task_id': generate_task_id,
-            'app_model': app_model,
+            'detached_app_model': app_model,
             'app_model_config': app_model_config,
-            'message': message,
+            'detached_message': message,
             'pre_prompt': pre_prompt,
-            'user': user,
+            'detached_user': user,
             'streaming': streaming
         })
 
@@ -294,10 +300,15 @@ class CompletionService:
         return cls.compact_response(pubsub, streaming)
 
     @classmethod
-    def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App,
-                                       app_model_config: AppModelConfig, message: Message, pre_prompt: str,
-                                       user: Union[Account, EndUser], streaming: bool):
+    def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
+                                       app_model_config: AppModelConfig, detached_message: Message, pre_prompt: str,
+                                       detached_user: Union[Account, EndUser], streaming: bool):
         with flask_app.app_context():
+            # fixed the state of the model object when it detached from the original session
+            user = db.session.merge(detached_user)
+            app_model = db.session.merge(detached_app_model)
+            message = db.session.merge(detached_message)
+
             try:
                 # run
                 Completion.generate_more_like_this(