|
@@ -155,7 +155,7 @@ class CompletionService:
|
|
|
generate_worker_thread.start()
|
|
|
|
|
|
# wait for 10 minutes to close the thread
|
|
|
- cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
|
|
|
+ cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
|
|
|
|
|
|
return cls.compact_response(pubsub, streaming)
|
|
|
|
|
@@ -210,25 +210,26 @@ class CompletionService:
|
|
|
db.session.commit()
|
|
|
|
|
|
@classmethod
|
|
|
- def countdown_and_close(cls, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
|
|
|
+ def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
|
|
|
# wait for 10 minutes to close the thread
|
|
|
timeout = 600
|
|
|
|
|
|
def close_pubsub():
|
|
|
- sleep_iterations = 0
|
|
|
- while sleep_iterations < timeout and worker_thread.is_alive():
|
|
|
- if sleep_iterations > 0 and sleep_iterations % 10 == 0:
|
|
|
- PubHandler.ping(user, generate_task_id)
|
|
|
+ with flask_app.app_context():
|
|
|
+ sleep_iterations = 0
|
|
|
+ while sleep_iterations < timeout and worker_thread.is_alive():
|
|
|
+ if sleep_iterations > 0 and sleep_iterations % 10 == 0:
|
|
|
+ PubHandler.ping(user, generate_task_id)
|
|
|
|
|
|
- time.sleep(1)
|
|
|
- sleep_iterations += 1
|
|
|
+ time.sleep(1)
|
|
|
+ sleep_iterations += 1
|
|
|
|
|
|
- if worker_thread.is_alive():
|
|
|
- PubHandler.stop(user, generate_task_id)
|
|
|
- try:
|
|
|
- pubsub.close()
|
|
|
- except:
|
|
|
- pass
|
|
|
+ if worker_thread.is_alive():
|
|
|
+ PubHandler.stop(user, generate_task_id)
|
|
|
+ try:
|
|
|
+ pubsub.close()
|
|
|
+ except:
|
|
|
+ pass
|
|
|
|
|
|
countdown_thread = threading.Thread(target=close_pubsub)
|
|
|
countdown_thread.start()
|
|
@@ -288,7 +289,7 @@ class CompletionService:
|
|
|
|
|
|
generate_worker_thread.start()
|
|
|
|
|
|
- cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
|
|
|
+ cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
|
|
|
|
|
|
return cls.compact_response(pubsub, streaming)
|
|
|
|
|
@@ -313,15 +314,14 @@ class CompletionService:
|
|
|
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
|
|
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
|
|
|
ModelCurrentlyNotSupportError) as e:
|
|
|
- db.session.rollback()
|
|
|
PubHandler.pub_error(user, generate_task_id, e)
|
|
|
except LLMAuthorizationError:
|
|
|
- db.session.rollback()
|
|
|
PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
|
|
|
except Exception as e:
|
|
|
- db.session.rollback()
|
|
|
logging.exception("Unknown Error in completion")
|
|
|
PubHandler.pub_error(user, generate_task_id, e)
|
|
|
+ finally:
|
|
|
+ db.session.commit()
|
|
|
|
|
|
@classmethod
|
|
|
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
|