
chore: skip unnecessary key checks prior to accessing a dictionary (#4497)

Bowen Liang 11 months ago
parent
commit
04ad46dd31
30 changed files with 45 additions and 44 deletions
  1. api/core/app/app_config/common/sensitive_word_avoidance/manager.py (+1 -1)
  2. api/core/app/app_config/features/file_upload/manager.py (+1 -1)
  3. api/core/app/app_config/features/more_like_this/manager.py (+1 -1)
  4. api/core/app/app_config/features/retrieval_resource/manager.py (+1 -1)
  5. api/core/app/app_config/features/speech_to_text/manager.py (+1 -1)
  6. api/core/app/app_config/features/suggested_questions_after_answer/manager.py (+1 -1)
  7. api/core/app/app_config/features/text_to_speech/manager.py (+1 -1)
  8. api/core/app/apps/advanced_chat/app_generator.py (+1 -1)
  9. api/core/app/apps/agent_chat/app_generator.py (+1 -1)
  10. api/core/app/apps/chat/app_generator.py (+1 -1)
  11. api/core/app/apps/completion/app_generator.py (+1 -1)
  12. api/core/app/apps/workflow/app_generator.py (+1 -1)
  13. api/core/docstore/dataset_docstore.py (+2 -2)
  14. api/core/indexing_runner.py (+1 -1)
  15. api/core/model_runtime/model_providers/anthropic/llm/llm.py (+2 -2)
  16. api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py (+1 -1)
  17. api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py (+2 -2)
  18. api/core/model_runtime/model_providers/openai/_common.py (+1 -1)
  19. api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py (+1 -1)
  20. api/core/rag/retrieval/dataset_retrieval.py (+1 -1)
  21. api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py (+1 -1)
  22. api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py (+1 -1)
  23. api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py (+1 -1)
  24. api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py (+1 -1)
  25. api/core/tools/utils/web_reader_tool.py (+5 -5)
  26. api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py (+1 -1)
  27. api/pyproject.toml (+1 -0)
  28. api/services/annotation_service.py (+1 -1)
  29. api/services/app_service.py (+2 -2)
  30. api/services/dataset_service.py (+8 -8)
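
Every hunk below applies the same refactor: a redundant `'key' in d and d['key']` membership-plus-access check is collapsed into a single `d.get('key')` call. A minimal sketch of the idiom, using an invented dict rather than code from the diff:

    config = {'enabled': True, 'voice': 'alloy'}

    # Before: the key is looked up twice, once for membership and once for the value.
    if 'enabled' in config and config['enabled']:
        enabled = True
    else:
        enabled = False

    # After: dict.get() returns None for a missing key, so one call covers
    # both the membership test and the truthiness check.
    enabled = bool(config.get('enabled'))

    # The two forms agree in boolean context: a missing key and a falsy value
    # ('', 0, False, None) both take the false branch either way.
    assert enabled == bool('enabled' in config and config['enabled'])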

+ 1 - 1
api/core/app/app_config/common/sensitive_word_avoidance/manager.py

@@ -11,7 +11,7 @@ class SensitiveWordAvoidanceConfigManager:
         if not sensitive_word_avoidance_dict:
             return None
 
-        if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']:
+        if sensitive_word_avoidance_dict.get('enabled'):
             return SensitiveWordAvoidanceEntity(
                 type=sensitive_word_avoidance_dict.get('type'),
                 config=sensitive_word_avoidance_dict.get('config'),

+ 1 - 1
api/core/app/app_config/features/file_upload/manager.py

@@ -14,7 +14,7 @@ class FileUploadConfigManager:
         """
         file_upload_dict = config.get('file_upload')
         if file_upload_dict:
-            if 'image' in file_upload_dict and file_upload_dict['image']:
+            if file_upload_dict.get('image'):
                 if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
                     image_config = {
                         'number_limits': file_upload_dict['image']['number_limits'],

+ 1 - 1
api/core/app/app_config/features/more_like_this/manager.py

@@ -9,7 +9,7 @@ class MoreLikeThisConfigManager:
         more_like_this = False
         more_like_this_dict = config.get('more_like_this')
         if more_like_this_dict:
-            if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
+            if more_like_this_dict.get('enabled'):
                 more_like_this = True
 
         return more_like_this

+ 1 - 1
api/core/app/app_config/features/retrieval_resource/manager.py

@@ -4,7 +4,7 @@ class RetrievalResourceConfigManager:
         show_retrieve_source = False
         retriever_resource_dict = config.get('retriever_resource')
         if retriever_resource_dict:
-            if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']:
+            if retriever_resource_dict.get('enabled'):
                 show_retrieve_source = True
 
         return show_retrieve_source

+ 1 - 1
api/core/app/app_config/features/speech_to_text/manager.py

@@ -9,7 +9,7 @@ class SpeechToTextConfigManager:
         speech_to_text = False
         speech_to_text_dict = config.get('speech_to_text')
         if speech_to_text_dict:
-            if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
+            if speech_to_text_dict.get('enabled'):
                 speech_to_text = True
 
         return speech_to_text

+ 1 - 1
api/core/app/app_config/features/suggested_questions_after_answer/manager.py

@@ -9,7 +9,7 @@ class SuggestedQuestionsAfterAnswerConfigManager:
         suggested_questions_after_answer = False
         suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer')
         if suggested_questions_after_answer_dict:
-            if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']:
+            if suggested_questions_after_answer_dict.get('enabled'):
                 suggested_questions_after_answer = True
 
         return suggested_questions_after_answer

+ 1 - 1
api/core/app/app_config/features/text_to_speech/manager.py

@@ -12,7 +12,7 @@ class TextToSpeechConfigManager:
         text_to_speech = False
         text_to_speech_dict = config.get('text_to_speech')
         if text_to_speech_dict:
-            if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
+            if text_to_speech_dict.get('enabled'):
                 text_to_speech = TextToSpeechEntity(
                     enabled=text_to_speech_dict.get('enabled'),
                     voice=text_to_speech_dict.get('voice'),

+ 1 - 1
api/core/app/apps/advanced_chat/app_generator.py

@@ -66,7 +66,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
 
         # parse files
-        files = args['files'] if 'files' in args and args['files'] else []
+        files = args['files'] if args.get('files') else []
         message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
         if file_extra_config:

+ 1 - 1
api/core/app/apps/agent_chat/app_generator.py

@@ -83,7 +83,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             )
 
         # parse files
-        files = args['files'] if 'files' in args and args['files'] else []
+        files = args['files'] if args.get('files') else []
         message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
         if file_extra_config:

+ 1 - 1
api/core/app/apps/chat/app_generator.py

@@ -80,7 +80,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             )
 
         # parse files
-        files = args['files'] if 'files' in args and args['files'] else []
+        files = args['files'] if args.get('files') else []
         message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
         if file_extra_config:

+ 1 - 1
api/core/app/apps/completion/app_generator.py

@@ -75,7 +75,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             )
 
         # parse files
-        files = args['files'] if 'files' in args and args['files'] else []
+        files = args['files'] if args.get('files') else []
         message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
         if file_extra_config:

+ 1 - 1
api/core/app/apps/workflow/app_generator.py

@@ -49,7 +49,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
         inputs = args['inputs']
 
         # parse files
-        files = args['files'] if 'files' in args and args['files'] else []
+        files = args['files'] if args.get('files') else []
         message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
         if file_extra_config:

+ 2 - 2
api/core/docstore/dataset_docstore.py

@@ -121,13 +121,13 @@ class DatasetDocumentStore:
                     enabled=False,
                     created_by=self._user_id,
                 )
-                if 'answer' in doc.metadata and doc.metadata['answer']:
+                if doc.metadata.get('answer'):
                     segment_document.answer = doc.metadata.pop('answer', '')
 
                 db.session.add(segment_document)
             else:
                 segment_document.content = doc.page_content
-                if 'answer' in doc.metadata and doc.metadata['answer']:
+                if doc.metadata.get('answer'):
                     segment_document.answer = doc.metadata.pop('answer', '')
                 segment_document.index_node_hash = doc.metadata['doc_hash']
                 segment_document.word_count = len(doc.page_content)

+ 1 - 1
api/core/indexing_runner.py

@@ -418,7 +418,7 @@ class IndexingRunner:
             if separator:
                 separator = separator.replace('\\n', '\n')
 
-            if 'chunk_overlap' in segmentation and segmentation['chunk_overlap']:
+            if segmentation.get('chunk_overlap'):
                 chunk_overlap = segmentation['chunk_overlap']
             else:
                 chunk_overlap = 0

+ 2 - 2
api/core/model_runtime/model_providers/anthropic/llm/llm.py

@@ -146,7 +146,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         """
         Code block mode wrapper for invoking large language model
         """
-        if 'response_format' in model_parameters and model_parameters['response_format']:
+        if model_parameters.get('response_format'):
             stop = stop or []
             # chat model
             self._transform_chat_json_prompts(
@@ -408,7 +408,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
             "max_retries": 1,
         }
 
-        if 'anthropic_api_url' in credentials and credentials['anthropic_api_url']:
+        if credentials.get('anthropic_api_url'):
             credentials['anthropic_api_url'] = credentials['anthropic_api_url'].rstrip('/')
             credentials_kwargs['base_url'] = credentials['anthropic_api_url']
 

+ 1 - 1
api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py

@@ -89,7 +89,7 @@ class BaichuanModel:
             # save stop reason temporarily
             stop_reason = ''
             for choice in choices:
-                if 'finish_reason' in choice and choice['finish_reason']:
+                if choice.get('finish_reason'):
                     stop_reason = choice['finish_reason']
 
                 if len(choice['delta']['content']) == 0:

+ 2 - 2
api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py

@@ -43,7 +43,7 @@ class MinimaxChatCompletionPro:
         if 'top_p' in model_parameters and type(model_parameters['top_p']) == float:
             extra_kwargs['top_p'] = model_parameters['top_p']
 
-        if 'plugin_web_search' in model_parameters and model_parameters['plugin_web_search']:
+        if model_parameters.get('plugin_web_search'):
             extra_kwargs['plugins'] = [
                 'plugin_web_search'
             ]
@@ -158,7 +158,7 @@ class MinimaxChatCompletionPro:
                 self._handle_error(code, msg)
 
             # final chunk
-            if data['reply'] or 'usage' in data and data['usage']:
+            if data['reply'] or data.get('usage'):
                 total_tokens = data['usage']['total_tokens']
                 minimax_message = MinimaxMessage(
                     role=MinimaxMessage.Role.ASSISTANT.value,
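
One note on the minimax hunk above: the old condition `data['reply'] or 'usage' in data and data['usage']` only worked because `and` binds tighter than `or` in Python, so it parsed as `data['reply'] or ('usage' in data and data['usage'])`; the `.get()` form makes that grouping explicit. A toy equivalence check with a made-up payload:

    data = {'reply': '', 'usage': {'total_tokens': 7}}

    old = data['reply'] or 'usage' in data and data['usage']
    new = data['reply'] or data.get('usage')

    # Both are truthy here, and both are falsy when 'usage' is absent or empty,
    # so the final-chunk branch fires in the same cases.
    assert bool(old) == bool(new)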

+ 1 - 1
api/core/model_runtime/model_providers/openai/_common.py

@@ -25,7 +25,7 @@ class _CommonOpenAI:
             "max_retries": 1,
         }
 
-        if 'openai_api_base' in credentials and credentials['openai_api_base']:
+        if credentials.get('openai_api_base'):
             credentials['openai_api_base'] = credentials['openai_api_base'].rstrip('/')
             credentials_kwargs['base_url'] = credentials['openai_api_base'] + '/v1'
 

+ 1 - 1
api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py

@@ -180,7 +180,7 @@ class OpenLLMGenerate:
                 completion_usage += len(token_ids)
                 message = OpenLLMGenerateMessage(content=text, role=OpenLLMGenerateMessage.Role.ASSISTANT.value)
 
-                if 'finish_reason' in choice and choice['finish_reason']:
+                if choice.get('finish_reason'):
                     finish_reason = choice['finish_reason']
                     prompt_token_ids = data['prompt_token_ids']
                     message.stop_reason = finish_reason

+ 1 - 1
api/core/rag/retrieval/dataset_retrieval.py

@@ -124,7 +124,7 @@ class DatasetRetrieval:
 
         document_score_list = {}
         for item in all_documents:
-            if 'score' in item.metadata and item.metadata['score']:
+            if item.metadata.get('score'):
                 document_score_list[item.metadata['doc_id']] = item.metadata['score']
 
         document_context_list = []

+ 1 - 1
api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py

@@ -70,7 +70,7 @@ class StableDiffusionTool(BuiltinTool):
         if not base_url:
             return self.create_text_message('Please input base_url')
 
-        if 'model' in tool_parameters and tool_parameters['model']:
+        if tool_parameters.get('model'):
             self.runtime.credentials['model'] = tool_parameters['model']
 
         model = self.runtime.credentials.get('model', None)

+ 1 - 1
api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py

@@ -48,7 +48,7 @@ class WolframAlphaTool(BuiltinTool):
             
             if 'success' not in response_data['queryresult'] or response_data['queryresult']['success'] != True:
                 query_result = response_data.get('queryresult', {})
-                if 'error' in query_result and query_result['error']:
+                if query_result.get('error'):
                     if 'msg' in query_result['error']:
                         if query_result['error']['msg'] == 'Invalid appid':
                             raise ToolProviderCredentialValidationError('Invalid appid')

+ 1 - 1
api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py

@@ -79,7 +79,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
 
         document_score_list = {}
         for item in all_documents:
-            if 'score' in item.metadata and item.metadata['score']:
+            if item.metadata.get('score'):
                 document_score_list[item.metadata['doc_id']] = item.metadata['score']
 
         document_context_list = []

+ 1 - 1
api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py

@@ -87,7 +87,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
             document_score_list = {}
             if dataset.indexing_technique != "economy":
                 for item in documents:
-                    if 'score' in item.metadata and item.metadata['score']:
+                    if item.metadata.get('score'):
                         document_score_list[item.metadata['doc_id']] = item.metadata['score']
             document_context_list = []
             index_node_ids = [document.metadata['doc_id'] for document in documents]

+ 5 - 5
api/core/tools/utils/web_reader_tool.py

@@ -132,17 +132,17 @@ def extract_using_readabilipy(html):
     }
     # Populate article fields from readability fields where present
     if input_json:
-        if "title" in input_json and input_json["title"]:
+        if input_json.get("title"):
             article_json["title"] = input_json["title"]
-        if "byline" in input_json and input_json["byline"]:
+        if input_json.get("byline"):
             article_json["byline"] = input_json["byline"]
-        if "date" in input_json and input_json["date"]:
+        if input_json.get("date"):
             article_json["date"] = input_json["date"]
-        if "content" in input_json and input_json["content"]:
+        if input_json.get("content"):
             article_json["content"] = input_json["content"]
             article_json["plain_content"] = plain_content(article_json["content"], False, False)
             article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
-        if "textContent" in input_json and input_json["textContent"]:
+        if input_json.get("textContent"):
             article_json["plain_text"] = input_json["textContent"]
             article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"])
 

+ 1 - 1
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py

@@ -143,7 +143,7 @@ class KnowledgeRetrievalNode(BaseNode):
         if all_documents:
             document_score_list = {}
             for item in all_documents:
-                if 'score' in item.metadata and item.metadata['score']:
+                if item.metadata.get('score'):
                     document_score_list[item.metadata['doc_id']] = item.metadata['score']
 
             index_node_ids = [document.metadata['doc_id'] for document in all_documents]

+ 1 - 0
api/pyproject.toml

@@ -13,6 +13,7 @@ select = [
     "I001", # unsorted-imports
     "I002", # missing-required-import
     "UP",   # pyupgrade rules
+    "RUF019", # unnecessary-key-check
 ]
 ignore = [
     "F403", # undefined-local-with-import-star

+ 1 - 1
api/services/annotation_service.py

@@ -31,7 +31,7 @@ class AppAnnotationService:
 
         if not app:
             raise NotFound("App not found")
-        if 'message_id' in args and args['message_id']:
+        if args.get('message_id'):
             message_id = str(args['message_id'])
             # get message info
             message = db.session.query(Message).filter(

+ 2 - 2
api/services/app_service.py

@@ -47,10 +47,10 @@ class AppService:
         elif args['mode'] == 'channel':
             filters.append(App.mode == AppMode.CHANNEL.value)
 
-        if 'name' in args and args['name']:
+        if args.get('name'):
             name = args['name'][:30]
             filters.append(App.name.ilike(f'%{name}%'))
-        if 'tag_ids' in args and args['tag_ids']:
+        if args.get('tag_ids'):
             target_ids = TagService.get_target_ids_by_tag_ids('app',
                                                               tenant_id,
                                                               args['tag_ids'])

+ 8 - 8
api/services/dataset_service.py

@@ -569,7 +569,7 @@ class DocumentService:
 
         documents = []
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
-        if 'original_document_id' in document_data and document_data["original_document_id"]:
+        if document_data.get("original_document_id"):
             document = DocumentService.update_document_with_dataset_id(dataset, document_data, account)
             documents.append(document)
         else:
@@ -750,10 +750,10 @@ class DocumentService:
         if document.display_status != 'available':
             raise ValueError("Document is not available")
         # update document name
-        if 'name' in document_data and document_data['name']:
+        if document_data.get('name'):
             document.name = document_data['name']
         # save process rule
-        if 'process_rule' in document_data and document_data['process_rule']:
+        if document_data.get('process_rule'):
             process_rule = document_data["process_rule"]
             if process_rule["mode"] == "custom":
                 dataset_process_rule = DatasetProcessRule(
@@ -773,7 +773,7 @@ class DocumentService:
             db.session.commit()
             document.dataset_process_rule_id = dataset_process_rule.id
         # update document data source
-        if 'data_source' in document_data and document_data['data_source']:
+        if document_data.get('data_source'):
             file_name = ''
             data_source_info = {}
             if document_data["data_source"]["type"] == "upload_file":
@@ -871,7 +871,7 @@ class DocumentService:
                 embedding_model.model
             )
             dataset_collection_binding_id = dataset_collection_binding.id
-            if 'retrieval_model' in document_data and document_data['retrieval_model']:
+            if document_data.get('retrieval_model'):
                 retrieval_model = document_data['retrieval_model']
             else:
                 default_retrieval_model = {
@@ -921,9 +921,9 @@ class DocumentService:
                     and ('process_rule' not in args and not args['process_rule']):
                 raise ValueError("Data source or Process rule is required")
             else:
-                if 'data_source' in args and args['data_source']:
+                if args.get('data_source'):
                     DocumentService.data_source_args_validate(args)
-                if 'process_rule' in args and args['process_rule']:
+                if args.get('process_rule'):
                     DocumentService.process_rule_args_validate(args)
 
     @classmethod
@@ -1266,7 +1266,7 @@ class SegmentService:
             if segment.content == content:
                 if document.doc_form == 'qa_model':
                     segment.answer = args['answer']
-                if 'keywords' in args and args['keywords']:
+                if args.get('keywords'):
                     segment.keywords = args['keywords']
                 segment.enabled = True
                 segment.disabled_at = None