Ver Fonte

Fix/shared lock (#210)

John Wang há 1 ano
pai
commit
1a5acf43aa

+ 5 - 1
api/core/callback_handler/index_tool_callback_handler.py

@@ -34,5 +34,9 @@ class DatasetIndexToolCallbackHandler(IndexToolCallbackHandler):
             db.session.query(DocumentSegment).filter(
                 DocumentSegment.dataset_id == self.dataset_id,
                 DocumentSegment.index_node_id == index_node_id
-            ).update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
+            ).update(
+                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
+                synchronize_session=False
+            )
 
+            db.session.commit()

+ 9 - 1
api/core/completion.py

@@ -1,14 +1,17 @@
+import logging
 from typing import Optional, List, Union, Tuple
 
 from langchain.callbacks import CallbackManager
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms import BaseLLM
 from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage
+from requests.exceptions import ChunkedEncodingError
+
 from core.constant import llm_constant
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
     DifyStdOutCallbackHandler
-from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
+from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, PubHandler
 from core.llm.error import LLMBadRequestError
 from core.llm.llm_builder import LLMBuilder
 from core.chain.main_chain_builder import MainChainBuilder
@@ -84,6 +87,11 @@ class Completion:
             )
         except ConversationTaskStoppedException:
             return
+        except ChunkedEncodingError as e:
+            # Interrupt by LLM (like OpenAI), handle it.
+            logging.warning(f'ChunkedEncodingError: {e}')
+            conversation_message_task.end()
+            return
 
     @classmethod
     def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,

+ 4 - 1
api/core/conversation_message_task.py

@@ -171,7 +171,7 @@ class ConversationMessageTask:
         )
 
         if not by_stopped:
-            self._pub_handler.pub_end()
+            self.end()
 
     def update_provider_quota(self):
         llm_provider_service = LLMProviderService(
@@ -268,6 +268,9 @@ class ConversationMessageTask:
         total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price
         return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
 
+    def end(self):
+        self._pub_handler.pub_end()
+
 
 class PubHandler:
     def __init__(self, user: Union[Account | EndUser], task_id: str,