Ver Fonte

feat: optimize error raise (#820)

takatost há 1 ano atrás
pai
commit
1bd0a76a20

+ 1 - 1
api/core/callback_handler/llm_callback_handler.py

@@ -96,4 +96,4 @@ class LLMCallbackHandler(BaseCallbackHandler):
                 )
                 self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
         else:
-            logging.exception(error)
+            logging.debug("on_llm_error: %s", error)

+ 9 - 4
api/core/generator/llm_generator.py

@@ -2,6 +2,7 @@ import logging
 
 from langchain.schema import OutputParserException
 
+from core.model_providers.error import LLMError
 from core.model_providers.model_factory import ModelFactory
 from core.model_providers.models.entity.message import PromptMessage, MessageType
 from core.model_providers.models.entity.model_params import ModelKwargs
@@ -120,8 +121,10 @@ class LLMGenerator:
         try:
             output = model_instance.run(prompts)
             questions = output_parser.parse(output.content)
-        except Exception:
-            logging.exception("Error generating suggested questions after answer")
+        except LLMError:
+            questions = []
+        except Exception as e:
+            logging.exception(e)
             questions = []
 
         return questions
@@ -157,10 +160,12 @@ class LLMGenerator:
         try:
             output = model_instance.run(prompts)
             rule_config = output_parser.parse(output.content)
+        except LLMError as e:
+            raise e
         except OutputParserException:
             raise ValueError('Please give a valid input for intended audience or hoping to solve problems.')
-        except Exception:
-            logging.exception("Error generating prompt")
+        except Exception as e:
+            logging.exception(e)
             rule_config = {
                 "prompt": "",
                 "variables": [],

+ 5 - 1
api/core/model_providers/providers/azure_openai_provider.py

@@ -283,7 +283,11 @@ class AzureOpenAIProvider(BaseModelProvider):
                 if obfuscated:
                     credentials['openai_api_key'] = encrypter.obfuscated_token(credentials['openai_api_key'])
 
-            return credentials
+            return {
+                'openai_api_base': credentials['openai_api_base'],
+                'openai_api_key': credentials['openai_api_key'],
+                'base_model_name': credentials['base_model_name']
+            }
         else:
             if hosted_model_providers.azure_openai:
                 return {

+ 5 - 2
api/tasks/generate_conversation_summary_task.py

@@ -6,6 +6,7 @@ from celery import shared_task
 from werkzeug.exceptions import NotFound
 
 from core.generator.llm_generator import LLMGenerator
+from core.model_providers.error import LLMError
 from extensions.ext_database import db
 from models.model import Conversation, Message
 
@@ -42,5 +43,7 @@ def generate_conversation_summary_task(conversation_id: str):
 
         end_at = time.perf_counter()
         logging.info(click.style('Conversation summary generated: {} latency: {}'.format(conversation_id, end_at - start_at), fg='green'))
-    except Exception:
-        logging.exception("generate conversation summary failed")
+    except LLMError:
+        pass
+    except Exception as e:
+        logging.exception(e)