瀏覽代碼

chore: apply flake8-comprehensions Ruff rules to improve collection comprehensions (#5652)

Co-authored-by: -LAN- <laipz8200@outlook.com>
Bowen Liang 10 月之前
父節點
當前提交
dcb72e0067
共有 58 個文件被更改,包括 123 次插入和 136 次刪除
  1. 1 1
      api/core/app/app_config/easy_ui_based_app/agent/manager.py
  2. 1 1
      api/core/app/apps/advanced_chat/app_generator.py
  3. 1 1
      api/core/app/apps/agent_chat/app_generator.py
  4. 1 1
      api/core/app/apps/agent_chat/app_runner.py
  5. 1 1
      api/core/app/apps/chat/app_generator.py
  6. 2 2
      api/core/entities/provider_configuration.py
  7. 1 1
      api/core/indexing_runner.py
  8. 1 1
      api/core/model_runtime/model_providers/azure_openai/tts/tts.py
  9. 5 5
      api/core/model_runtime/model_providers/bedrock/llm/llm.py
  10. 3 3
      api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py
  11. 3 5
      api/core/model_runtime/model_providers/cohere/llm/llm.py
  12. 1 4
      api/core/model_runtime/model_providers/google/llm/llm.py
  13. 2 2
      api/core/model_runtime/model_providers/moonshot/llm/llm.py
  14. 4 4
      api/core/model_runtime/model_providers/nvidia/llm/llm.py
  15. 3 5
      api/core/model_runtime/model_providers/openai/llm/llm.py
  16. 1 1
      api/core/model_runtime/model_providers/openai/tts/tts.py
  17. 7 8
      api/core/model_runtime/model_providers/replicate/llm/llm.py
  18. 1 1
      api/core/model_runtime/model_providers/tongyi/tts/tts.py
  19. 1 4
      api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
  20. 1 1
      api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py
  21. 1 1
      api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py
  22. 1 1
      api/core/model_runtime/model_providers/xinference/llm/llm.py
  23. 1 1
      api/core/model_runtime/model_providers/zhipuai/_common.py
  24. 3 3
      api/core/prompt/simple_prompt_transform.py
  25. 3 3
      api/core/provider_manager.py
  26. 1 1
      api/core/rag/datasource/keyword/jieba/jieba.py
  27. 1 1
      api/core/rag/retrieval/router/multi_dataset_react_route.py
  28. 4 4
      api/core/tools/provider/builtin/bing/tools/bing_web_search.py
  29. 2 2
      api/core/tools/provider/builtin/chart/tools/bar.py
  30. 2 2
      api/core/tools/provider/builtin/chart/tools/line.py
  31. 2 2
      api/core/tools/provider/builtin/chart/tools/pie.py
  32. 2 2
      api/core/tools/provider/builtin/gaode/tools/gaode_weather.py
  33. 2 2
      api/core/tools/provider/builtin/github/tools/github_repositories.py
  34. 3 3
      api/core/tools/provider/builtin/jina/tools/jina_reader.py
  35. 1 1
      api/core/tools/provider/builtin/jina/tools/jina_search.py
  36. 1 1
      api/core/tools/provider/builtin/searchapi/tools/google.py
  37. 2 2
      api/core/tools/provider/builtin/searchapi/tools/google_jobs.py
  38. 1 1
      api/core/tools/provider/builtin/searchapi/tools/google_news.py
  39. 1 1
      api/core/tools/provider/builtin/searxng/tools/searxng_search.py
  40. 1 1
      api/core/tools/provider/builtin/websearch/tools/get_markdown.py
  41. 2 2
      api/core/tools/provider/builtin/websearch/tools/job_search.py
  42. 1 1
      api/core/tools/provider/builtin/websearch/tools/news_search.py
  43. 1 1
      api/core/tools/provider/builtin/websearch/tools/scholar_search.py
  44. 1 1
      api/core/tools/provider/builtin_tool_provider.py
  45. 2 2
      api/core/tools/tool/api_tool.py
  46. 1 1
      api/core/tools/tool/dataset_retriever_tool.py
  47. 1 1
      api/core/tools/tool_manager.py
  48. 8 11
      api/core/tools/utils/parser.py
  49. 1 1
      api/core/tools/utils/web_reader_tool.py
  50. 2 2
      api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
  51. 2 8
      api/libs/oauth_data_source.py
  52. 16 9
      api/pyproject.toml
  53. 1 1
      api/services/dataset_service.py
  54. 1 1
      api/services/recommended_app_service.py
  55. 2 2
      api/services/workflow/workflow_converter.py
  56. 2 2
      api/tests/integration_tests/workflow/nodes/test_llm.py
  57. 1 1
      api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
  58. 2 2
      api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py

+ 1 - 1
api/core/app/app_config/easy_ui_based_app/agent/manager.py

@@ -40,7 +40,7 @@ class AgentConfigManager:
                         'provider_type': tool['provider_type'],
                         'provider_id': tool['provider_id'],
                         'tool_name': tool['tool_name'],
-                        'tool_parameters': tool['tool_parameters'] if 'tool_parameters' in tool else {}
+                        'tool_parameters': tool.get('tool_parameters', {})
                     }
 
                     agent_tools.append(AgentToolEntity(**agent_tool_properties))

+ 1 - 1
api/core/app/apps/advanced_chat/app_generator.py

@@ -59,7 +59,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         inputs = args['inputs']
 
         extras = {
-            "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else False
+            "auto_generate_conversation_name": args.get('auto_generate_name', False)
         }
 
         # get conversation

+ 1 - 1
api/core/app/apps/agent_chat/app_generator.py

@@ -57,7 +57,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
         inputs = args['inputs']
 
         extras = {
-            "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True
+            "auto_generate_conversation_name": args.get('auto_generate_name', True)
         }
 
         # get conversation

+ 1 - 1
api/core/app/apps/agent_chat/app_runner.py

@@ -203,7 +203,7 @@ class AgentChatAppRunner(AppRunner):
         llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
         model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
 
-        if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
+        if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
             agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
 
         conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()

+ 1 - 1
api/core/app/apps/chat/app_generator.py

@@ -55,7 +55,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
         inputs = args['inputs']
 
         extras = {
-            "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True
+            "auto_generate_conversation_name": args.get('auto_generate_name', True)
         }
 
         # get conversation

+ 2 - 2
api/core/entities/provider_configuration.py

@@ -66,8 +66,8 @@ class ProviderConfiguration(BaseModel):
                 original_provider_configurate_methods[self.provider.provider].append(configurate_method)
 
         if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
-            if (any([len(quota_configuration.restrict_models) > 0
-                     for quota_configuration in self.system_configuration.quota_configurations])
+            if (any(len(quota_configuration.restrict_models) > 0
+                     for quota_configuration in self.system_configuration.quota_configurations)
                     and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods):
                 self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
 

+ 1 - 1
api/core/indexing_runner.py

@@ -397,7 +397,7 @@ class IndexingRunner:
             document_id=dataset_document.id,
             after_indexing_status="splitting",
             extra_update_params={
-                DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),
+                DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
                 DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
             }
         )

+ 1 - 1
api/core/model_runtime/model_providers/azure_openai/tts/tts.py

@@ -83,7 +83,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
         max_workers = self._get_model_workers_limit(model, credentials)
         try:
             sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
-            audio_bytes_list = list()
+            audio_bytes_list = []
 
             # Create a thread pool and map the function to the list of sentences
             with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:

+ 5 - 5
api/core/model_runtime/model_providers/bedrock/llm/llm.py

@@ -175,8 +175,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         # - https://docs.anthropic.com/claude/reference/claude-on-amazon-bedrock
         # - https://github.com/anthropics/anthropic-sdk-python
         client = AnthropicBedrock(
-            aws_access_key=credentials.get("aws_access_key_id", None),
-            aws_secret_key=credentials.get("aws_secret_access_key", None),
+            aws_access_key=credentials.get("aws_access_key_id"),
+            aws_secret_key=credentials.get("aws_secret_access_key"),
             aws_region=credentials["aws_region"],
         )
 
@@ -576,7 +576,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         """
         Create payload for bedrock api call depending on model provider
         """
-        payload = dict()
+        payload = {}
         model_prefix = model.split('.')[0]
         model_name = model.split('.')[1]
 
@@ -648,8 +648,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         runtime_client = boto3.client(
             service_name='bedrock-runtime',
             config=client_config,
-            aws_access_key_id=credentials.get("aws_access_key_id", None),
-            aws_secret_access_key=credentials.get("aws_secret_access_key", None)
+            aws_access_key_id=credentials.get("aws_access_key_id"),
+            aws_secret_access_key=credentials.get("aws_secret_access_key")
         )
 
         model_prefix = model.split('.')[0]

+ 3 - 3
api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py

@@ -49,8 +49,8 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
         bedrock_runtime = boto3.client(
             service_name='bedrock-runtime',
             config=client_config,
-            aws_access_key_id=credentials.get("aws_access_key_id", None),
-            aws_secret_access_key=credentials.get("aws_secret_access_key", None)
+            aws_access_key_id=credentials.get("aws_access_key_id"),
+            aws_secret_access_key=credentials.get("aws_secret_access_key")
         )
 
         embeddings = []
@@ -148,7 +148,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
         """
         Create payload for bedrock api call depending on model provider
         """
-        payload = dict()
+        payload = {}
 
         if model_prefix == "amazon":
             payload['inputText'] = texts

+ 3 - 5
api/core/model_runtime/model_providers/cohere/llm/llm.py

@@ -696,12 +696,10 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                 en_US=model
             ),
             model_type=ModelType.LLM,
-            features=[feature for feature in base_model_schema_features],
+            features=list(base_model_schema_features),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
-            model_properties={
-                key: property for key, property in base_model_schema_model_properties.items()
-            },
-            parameter_rules=[rule for rule in base_model_schema_parameters_rules],
+            model_properties=dict(base_model_schema_model_properties.items()),
+            parameter_rules=list(base_model_schema_parameters_rules),
             pricing=base_model_schema.pricing
         )
 

+ 1 - 4
api/core/model_runtime/model_providers/google/llm/llm.py

@@ -277,10 +277,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
                             type='function',
                             function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                 name=part.function_call.name,
-                                arguments=json.dumps({
-                                    key: value 
-                                    for key, value in part.function_call.args.items()
-                                })
+                                arguments=json.dumps(dict(part.function_call.args.items()))
                             )
                         )
                     ]

+ 2 - 2
api/core/model_runtime/model_providers/moonshot/llm/llm.py

@@ -88,9 +88,9 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
 
     def _add_function_call(self, model: str, credentials: dict) -> None:
         model_schema = self.get_model_schema(model, credentials)
-        if model_schema and set([
+        if model_schema and {
             ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL
-        ]).intersection(model_schema.features or []):
+        }.intersection(model_schema.features or []):
             credentials['function_calling_type'] = 'tool_call'
 
     def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:

+ 4 - 4
api/core/model_runtime/model_providers/nvidia/llm/llm.py

@@ -100,10 +100,10 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
             if api_key:
                 headers["Authorization"] = f"Bearer {api_key}"
 
-            endpoint_url = credentials['endpoint_url'] if 'endpoint_url' in credentials else None
+            endpoint_url = credentials.get('endpoint_url')
             if endpoint_url and not endpoint_url.endswith('/'):
                 endpoint_url += '/'
-            server_url = credentials['server_url'] if 'server_url' in credentials else None
+            server_url = credentials.get('server_url')
 
             # prepare the payload for a simple ping to the model
             data = {
@@ -182,10 +182,10 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
         if stream:
             headers['Accept'] = 'text/event-stream'
 
-        endpoint_url = credentials['endpoint_url'] if 'endpoint_url' in credentials else None
+        endpoint_url = credentials.get('endpoint_url')
         if endpoint_url and not endpoint_url.endswith('/'):
             endpoint_url += '/'
-        server_url = credentials['server_url'] if 'server_url' in credentials else None
+        server_url = credentials.get('server_url')
 
         data = {
             "model": model,

+ 3 - 5
api/core/model_runtime/model_providers/openai/llm/llm.py

@@ -1073,12 +1073,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                 en_US=model
             ),
             model_type=ModelType.LLM,
-            features=[feature for feature in base_model_schema_features],
+            features=list(base_model_schema_features),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
-            model_properties={
-                key: property for key, property in base_model_schema_model_properties.items()
-            },
-            parameter_rules=[rule for rule in base_model_schema_parameters_rules],
+            model_properties=dict(base_model_schema_model_properties.items()),
+            parameter_rules=list(base_model_schema_parameters_rules),
             pricing=base_model_schema.pricing    
         )
 

+ 1 - 1
api/core/model_runtime/model_providers/openai/tts/tts.py

@@ -80,7 +80,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
         max_workers = self._get_model_workers_limit(model, credentials)
         try:
             sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
-            audio_bytes_list = list()
+            audio_bytes_list = []
 
             # Create a thread pool and map the function to the list of sentences
             with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:

+ 7 - 8
api/core/model_runtime/model_providers/replicate/llm/llm.py

@@ -275,14 +275,13 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
 
     @classmethod
     def _get_parameter_type(cls, param_type: str) -> str:
-        if param_type == 'integer':
-            return 'int'
-        elif param_type == 'number':
-            return 'float'
-        elif param_type == 'boolean':
-            return 'boolean'
-        elif param_type == 'string':
-            return 'string'
+        type_mapping = {
+            'integer': 'int',
+            'number': 'float',
+            'boolean': 'boolean',
+            'string': 'string'
+        }
+        return type_mapping.get(param_type)
 
     def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
         messages = messages.copy()  # don't mutate the original list

+ 1 - 1
api/core/model_runtime/model_providers/tongyi/tts/tts.py

@@ -80,7 +80,7 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
         max_workers = self._get_model_workers_limit(model, credentials)
         try:
             sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
-            audio_bytes_list = list()
+            audio_bytes_list = []
 
             # Create a thread pool and map the function to the list of sentences
             with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:

+ 1 - 4
api/core/model_runtime/model_providers/vertex_ai/llm/llm.py

@@ -579,10 +579,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
                             type='function',
                             function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                 name=part.function_call.name,
-                                arguments=json.dumps({
-                                    key: value 
-                                    for key, value in part.function_call.args.items()
-                                })
+                                arguments=json.dumps(dict(part.function_call.args.items()))
                             )
                         )
                     ]

+ 1 - 1
api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py

@@ -102,7 +102,7 @@ class Signer:
         body_hash = Util.sha256(request.body)
         request.headers['X-Content-Sha256'] = body_hash
 
-        signed_headers = dict()
+        signed_headers = {}
         for key in request.headers:
             if key in ['Content-Type', 'Content-Md5', 'Host'] or key.startswith('X-'):
                 signed_headers[key.lower()] = request.headers[key]

+ 1 - 1
api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py

@@ -150,7 +150,7 @@ class Request:
         self.headers = OrderedDict()
         self.query = OrderedDict()
         self.body = ''
-        self.form = dict()
+        self.form = {}
         self.connection_timeout = 0
         self.socket_timeout = 0
 

+ 1 - 1
api/core/model_runtime/model_providers/xinference/llm/llm.py

@@ -147,7 +147,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             return self._get_num_tokens_by_gpt2(text)
 
         if is_completion_model:
-            return sum([tokens(str(message.content)) for message in messages])
+            return sum(tokens(str(message.content)) for message in messages)
 
         tokens_per_message = 3
         tokens_per_name = 1

+ 1 - 1
api/core/model_runtime/model_providers/zhipuai/_common.py

@@ -18,7 +18,7 @@ class _CommonZhipuaiAI:
         """
         credentials_kwargs = {
             "api_key": credentials['api_key'] if 'api_key' in credentials else 
-                        credentials['zhipuai_api_key'] if 'zhipuai_api_key' in credentials else None,
+                        credentials.get("zhipuai_api_key"),
         }
 
         return credentials_kwargs

+ 3 - 3
api/core/prompt/simple_prompt_transform.py

@@ -148,7 +148,7 @@ class SimplePromptTransform(PromptTransform):
                 special_variable_keys.append('#histories#')
 
         if query_in_prompt:
-            prompt += prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{#query#}}'
+            prompt += prompt_rules.get('query_prompt', '{{#query#}}')
             special_variable_keys.append('#query#')
 
         return {
@@ -234,8 +234,8 @@ class SimplePromptTransform(PromptTransform):
                     )
                 ),
                 max_token_limit=rest_tokens,
-                human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
-                ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+                human_prefix=prompt_rules.get('human_prefix', 'Human'),
+                ai_prefix=prompt_rules.get('assistant_prefix', 'Assistant')
             )
 
             # get prompt

+ 3 - 3
api/core/provider_manager.py

@@ -417,7 +417,7 @@ class ProviderManager:
             model_load_balancing_enabled = cache_result == 'True'
 
         if not model_load_balancing_enabled:
-            return dict()
+            return {}
 
         provider_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
             .filter(
@@ -451,7 +451,7 @@ class ProviderManager:
             if not provider_records:
                 provider_records = []
 
-            provider_quota_to_provider_record_dict = dict()
+            provider_quota_to_provider_record_dict = {}
             for provider_record in provider_records:
                 if provider_record.provider_type != ProviderType.SYSTEM.value:
                     continue
@@ -661,7 +661,7 @@ class ProviderManager:
         provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider)
 
         # Convert provider_records to dict
-        quota_type_to_provider_records_dict = dict()
+        quota_type_to_provider_records_dict = {}
         for provider_record in provider_records:
             if provider_record.provider_type != ProviderType.SYSTEM.value:
                 continue

+ 1 - 1
api/core/rag/datasource/keyword/jieba/jieba.py

@@ -197,7 +197,7 @@ class Jieba(BaseKeyword):
                 chunk_indices_count[node_id] += 1
 
         sorted_chunk_indices = sorted(
-            list(chunk_indices_count.keys()),
+            chunk_indices_count.keys(),
             key=lambda x: chunk_indices_count[x],
             reverse=True,
         )

+ 1 - 1
api/core/rag/retrieval/router/multi_dataset_react_route.py

@@ -201,7 +201,7 @@ class ReactMultiDatasetRouter:
             tool_strings.append(
                 f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}")
         formatted_tools = "\n".join(tool_strings)
-        unique_tool_names = set(tool.name for tool in tools)
+        unique_tool_names = {tool.name for tool in tools}
         tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
         format_instructions = format_instructions.format(tool_names=tool_names)
         template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])

+ 4 - 4
api/core/tools/provider/builtin/bing/tools/bing_web_search.py

@@ -105,15 +105,15 @@ class BingSearchTool(BuiltinTool):
         
 
     def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None:
-        key = credentials.get('subscription_key', None)
+        key = credentials.get('subscription_key')
         if not key:
             raise Exception('subscription_key is required')
         
-        server_url = credentials.get('server_url', None)
+        server_url = credentials.get('server_url')
         if not server_url:
             server_url = self.url
 
-        query = tool_parameters.get('query', None)
+        query = tool_parameters.get('query')
         if not query:
             raise Exception('query is required')
         
@@ -170,7 +170,7 @@ class BingSearchTool(BuiltinTool):
         if not server_url:
             server_url = self.url
         
-        query = tool_parameters.get('query', None)
+        query = tool_parameters.get('query')
         if not query:
             raise Exception('query is required')
         

+ 2 - 2
api/core/tools/provider/builtin/chart/tools/bar.py

@@ -16,12 +16,12 @@ class BarChartTool(BuiltinTool):
         data = data.split(';')
 
         # if all data is int, convert to int
-        if all([i.isdigit() for i in data]):
+        if all(i.isdigit() for i in data):
             data = [int(i) for i in data]
         else:
             data = [float(i) for i in data]
 
-        axis = tool_parameters.get('x_axis', None) or None
+        axis = tool_parameters.get('x_axis') or None
         if axis:
             axis = axis.split(';')
             if len(axis) != len(data):

+ 2 - 2
api/core/tools/provider/builtin/chart/tools/line.py

@@ -17,14 +17,14 @@ class LinearChartTool(BuiltinTool):
             return self.create_text_message('Please input data')
         data = data.split(';')
 
-        axis = tool_parameters.get('x_axis', None) or None
+        axis = tool_parameters.get('x_axis') or None
         if axis:
             axis = axis.split(';')
             if len(axis) != len(data):
                 axis = None
 
         # if all data is int, convert to int
-        if all([i.isdigit() for i in data]):
+        if all(i.isdigit() for i in data):
             data = [int(i) for i in data]
         else:
             data = [float(i) for i in data]

+ 2 - 2
api/core/tools/provider/builtin/chart/tools/pie.py

@@ -16,10 +16,10 @@ class PieChartTool(BuiltinTool):
         if not data:
             return self.create_text_message('Please input data')
         data = data.split(';')
-        categories = tool_parameters.get('categories', None) or None
+        categories = tool_parameters.get('categories') or None
 
         # if all data is int, convert to int
-        if all([i.isdigit() for i in data]):
+        if all(i.isdigit() for i in data):
             data = [int(i) for i in data]
         else:
             data = [float(i) for i in data]

+ 2 - 2
api/core/tools/provider/builtin/gaode/tools/gaode_weather.py

@@ -37,10 +37,10 @@ class GaodeRepositoriesTool(BuiltinTool):
                                                                    apikey=self.runtime.credentials.get('api_key')))
                     weatherInfo_data = weatherInfo_response.json()
                     if weatherInfo_response.status_code == 200 and weatherInfo_data.get('info') == 'OK':
-                        contents = list()
+                        contents = []
                         if len(weatherInfo_data.get('forecasts')) > 0:
                             for item in weatherInfo_data['forecasts'][0]['casts']:
-                                content = dict()
+                                content = {}
                                 content['date'] = item.get('date')
                                 content['week'] = item.get('week')
                                 content['dayweather'] = item.get('dayweather')

+ 2 - 2
api/core/tools/provider/builtin/github/tools/github_repositories.py

@@ -39,10 +39,10 @@ class GihubRepositoriesTool(BuiltinTool):
                                      f"q={quote(query)}&sort=stars&per_page={top_n}&order=desc")
             response_data = response.json()
             if response.status_code == 200 and isinstance(response_data.get('items'), list):
-                contents = list()
+                contents = []
                 if len(response_data.get('items')) > 0:
                     for item in response_data.get('items'):
-                        content = dict()
+                        content = {}
                         updated_at_object = datetime.strptime(item['updated_at'], "%Y-%m-%dT%H:%M:%SZ")
                         content['owner'] = item['owner']['login']
                         content['name'] = item['name']

+ 3 - 3
api/core/tools/provider/builtin/jina/tools/jina_reader.py

@@ -26,11 +26,11 @@ class JinaReaderTool(BuiltinTool):
         if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'):
             headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key')
 
-        target_selector = tool_parameters.get('target_selector', None)
+        target_selector = tool_parameters.get('target_selector')
         if target_selector is not None and target_selector != '':
             headers['X-Target-Selector'] = target_selector
 
-        wait_for_selector = tool_parameters.get('wait_for_selector', None)
+        wait_for_selector = tool_parameters.get('wait_for_selector')
         if wait_for_selector is not None and wait_for_selector != '':
             headers['X-Wait-For-Selector'] = wait_for_selector
 
@@ -43,7 +43,7 @@ class JinaReaderTool(BuiltinTool):
         if tool_parameters.get('gather_all_images_at_the_end', False):
             headers['X-With-Images-Summary'] = 'true'
 
-        proxy_server = tool_parameters.get('proxy_server', None)
+        proxy_server = tool_parameters.get('proxy_server')
         if proxy_server is not None and proxy_server != '':
             headers['X-Proxy-Url'] = proxy_server
 

+ 1 - 1
api/core/tools/provider/builtin/jina/tools/jina_search.py

@@ -33,7 +33,7 @@ class JinaSearchTool(BuiltinTool):
         if tool_parameters.get('gather_all_images_at_the_end', False):
             headers['X-With-Images-Summary'] = 'true'
 
-        proxy_server = tool_parameters.get('proxy_server', None)
+        proxy_server = tool_parameters.get('proxy_server')
         if proxy_server is not None and proxy_server != '':
             headers['X-Proxy-Url'] = proxy_server
 

+ 1 - 1
api/core/tools/provider/builtin/searchapi/tools/google.py

@@ -94,7 +94,7 @@ class GoogleTool(BuiltinTool):
         google_domain = tool_parameters.get("google_domain", "google.com")
         gl = tool_parameters.get("gl", "us")
         hl = tool_parameters.get("hl", "en")
-        location = tool_parameters.get("location", None)
+        location = tool_parameters.get("location")
 
         api_key = self.runtime.credentials['searchapi_api_key']
         result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location)

+ 2 - 2
api/core/tools/provider/builtin/searchapi/tools/google_jobs.py

@@ -72,11 +72,11 @@ class GoogleJobsTool(BuiltinTool):
         """
         query = tool_parameters['query']
         result_type = tool_parameters['result_type']
-        is_remote = tool_parameters.get("is_remote", None)
+        is_remote = tool_parameters.get("is_remote")
         google_domain = tool_parameters.get("google_domain", "google.com")
         gl = tool_parameters.get("gl", "us")
         hl = tool_parameters.get("hl", "en")
-        location = tool_parameters.get("location", None)
+        location = tool_parameters.get("location")
 
         ltype = 1 if is_remote else None
 

+ 1 - 1
api/core/tools/provider/builtin/searchapi/tools/google_news.py

@@ -82,7 +82,7 @@ class GoogleNewsTool(BuiltinTool):
         google_domain = tool_parameters.get("google_domain", "google.com")
         gl = tool_parameters.get("gl", "us")
         hl = tool_parameters.get("hl", "en")
-        location = tool_parameters.get("location", None)
+        location = tool_parameters.get("location")
 
         api_key = self.runtime.credentials['searchapi_api_key']
         result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location)

+ 1 - 1
api/core/tools/provider/builtin/searxng/tools/searxng_search.py

@@ -107,7 +107,7 @@ class SearXNGSearchTool(BuiltinTool):
         if not host:
             raise Exception('SearXNG api is required')
                 
-        query = tool_parameters.get('query', None)
+        query = tool_parameters.get('query')
         if not query:
             return self.create_text_message('Please input query')
                 

+ 1 - 1
api/core/tools/provider/builtin/websearch/tools/get_markdown.py

@@ -43,7 +43,7 @@ class GetMarkdownTool(BuiltinTool):
         Invoke the SerplyApi tool.
         """
         url = tool_parameters["url"]
-        location = tool_parameters.get("location", None)
+        location = tool_parameters.get("location")
 
         api_key = self.runtime.credentials["serply_api_key"]
         result = SerplyApi(api_key).run(url, location=location)

+ 2 - 2
api/core/tools/provider/builtin/websearch/tools/job_search.py

@@ -55,7 +55,7 @@ class SerplyApi:
                         f"Employer: {job['employer']}",
                         f"Location: {job['location']}",
                         f"Link: {job['link']}",
-                        f"""Highest: {", ".join([h for h in job["highlights"]])}""",
+                        f"""Highest: {", ".join(list(job["highlights"]))}""",
                         "---",
                     ])
                 )
@@ -78,7 +78,7 @@ class JobSearchTool(BuiltinTool):
         query = tool_parameters["query"]
         gl = tool_parameters.get("gl", "us")
         hl = tool_parameters.get("hl", "en")
-        location = tool_parameters.get("location", None)
+        location = tool_parameters.get("location")
 
         api_key = self.runtime.credentials["serply_api_key"]
         result = SerplyApi(api_key).run(query, gl=gl, hl=hl, location=location)

+ 1 - 1
api/core/tools/provider/builtin/websearch/tools/news_search.py

@@ -80,7 +80,7 @@ class NewsSearchTool(BuiltinTool):
         query = tool_parameters["query"]
         gl = tool_parameters.get("gl", "us")
         hl = tool_parameters.get("hl", "en")
-        location = tool_parameters.get("location", None)
+        location = tool_parameters.get("location")
 
         api_key = self.runtime.credentials["serply_api_key"]
         result = SerplyApi(api_key).run(query, gl=gl, hl=hl, location=location)

+ 1 - 1
api/core/tools/provider/builtin/websearch/tools/scholar_search.py

@@ -83,7 +83,7 @@ class ScholarSearchTool(BuiltinTool):
         query = tool_parameters["query"]
         gl = tool_parameters.get("gl", "us")
         hl = tool_parameters.get("hl", "en")
-        location = tool_parameters.get("location", None)
+        location = tool_parameters.get("location")
 
         api_key = self.runtime.credentials["serply_api_key"]
         result = SerplyApi(api_key).run(query, gl=gl, hl=hl, location=location)

+ 1 - 1
api/core/tools/provider/builtin_tool_provider.py

@@ -38,7 +38,7 @@ class BuiltinToolProviderController(ToolProviderController):
 
         super().__init__(**{
             'identity': provider_yaml['identity'],
-            'credentials_schema': provider_yaml['credentials_for_provider'] if 'credentials_for_provider' in provider_yaml else None,
+            'credentials_schema': provider_yaml.get('credentials_for_provider', None),
         })
 
     def _get_builtin_tools(self) -> list[Tool]:

+ 2 - 2
api/core/tools/tool/api_tool.py

@@ -159,8 +159,8 @@ class ApiTool(Tool):
                 for content_type in self.api_bundle.openapi['requestBody']['content']:
                     headers['Content-Type'] = content_type
                     body_schema = self.api_bundle.openapi['requestBody']['content'][content_type]['schema']
-                    required = body_schema['required'] if 'required' in body_schema else []
-                    properties = body_schema['properties'] if 'properties' in body_schema else {}
+                    required = body_schema.get('required', [])
+                    properties = body_schema.get('properties', {})
                     for name, property in properties.items():
                         if name in parameters:
                             # convert type

+ 1 - 1
api/core/tools/tool/dataset_retriever_tool.py

@@ -90,7 +90,7 @@ class DatasetRetrieverTool(Tool):
         """
         invoke dataset retriever tool
         """
-        query = tool_parameters.get('query', None)
+        query = tool_parameters.get('query')
         if not query:
             return self.create_text_message(text='please input query')
 

+ 1 - 1
api/core/tools/tool_manager.py

@@ -209,7 +209,7 @@ class ToolManager:
 
         if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
             # check if tool_parameter_config in options
-            options = list(map(lambda x: x.value, parameter_rule.options))
+            options = [x.value for x in parameter_rule.options]
             if parameter_value is not None and parameter_value not in options:
                 raise ValueError(
                     f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}")

+ 8 - 11
api/core/tools/utils/parser.py

@@ -21,10 +21,7 @@ class ApiBasedToolSchemaParser:
         extra_info = extra_info if extra_info is not None else {}
 
         # set description to extra_info
-        if 'description' in openapi['info']:
-            extra_info['description'] = openapi['info']['description']
-        else:
-            extra_info['description'] = ''
+        extra_info['description'] = openapi['info'].get('description', '')
 
         if len(openapi['servers']) == 0:
             raise ToolProviderNotFoundError('No server found in the openapi yaml.')
@@ -95,8 +92,8 @@ class ApiBasedToolSchemaParser:
                     # parse body parameters
                     if 'schema' in interface['operation']['requestBody']['content'][content_type]:
                         body_schema = interface['operation']['requestBody']['content'][content_type]['schema']
-                        required = body_schema['required'] if 'required' in body_schema else []
-                        properties = body_schema['properties'] if 'properties' in body_schema else {}
+                        required = body_schema.get('required', [])
+                        properties = body_schema.get('properties', {})
                         for name, property in properties.items():
                             tool = ToolParameter(
                                 name=name,
@@ -105,14 +102,14 @@ class ApiBasedToolSchemaParser:
                                     zh_Hans=name
                                 ),
                                 human_description=I18nObject(
-                                    en_US=property['description'] if 'description' in property else '',
-                                    zh_Hans=property['description'] if 'description' in property else ''
+                                    en_US=property.get('description', ''),
+                                    zh_Hans=property.get('description', '')
                                 ),
                                 type=ToolParameter.ToolParameterType.STRING,
                                 required=name in required,
                                 form=ToolParameter.ToolParameterForm.LLM,
-                                llm_description=property['description'] if 'description' in property else '',
-                                default=property['default'] if 'default' in property else None,
+                                llm_description=property.get('description', ''),
+                                default=property.get('default', None),
                             )
 
                             # check if there is a type
@@ -149,7 +146,7 @@ class ApiBasedToolSchemaParser:
                 server_url=server_url + interface['path'],
                 method=interface['method'],
                 summary=interface['operation']['description'] if 'description' in interface['operation'] else 
-                        interface['operation']['summary'] if 'summary' in interface['operation'] else None,
+                        interface['operation'].get('summary', None),
                 operation_id=interface['operation']['operationId'],
                 parameters=parameters,
                 author='',

+ 1 - 1
api/core/tools/utils/web_reader_tool.py

@@ -283,7 +283,7 @@ def strip_control_characters(text):
     #   [Cn]: Other, Not Assigned
     #   [Co]: Other, Private Use
     #   [Cs]: Other, Surrogate
-    control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs'])
+    control_chars = {'Cc', 'Cf', 'Cn', 'Co', 'Cs'}
     retained_chars = ['\t', '\n', '\r', '\f']
 
     # Remove non-printing control characters

+ 2 - 2
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py

@@ -93,7 +93,7 @@ class ParameterExtractorNode(LLMNode):
         # fetch memory
         memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
 
-        if set(model_schema.features or []) & set([ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL]) \
+        if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \
                 and node_data.reasoning_mode == 'function_call':
             # use function call 
             prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
@@ -644,7 +644,7 @@ class ParameterExtractorNode(LLMNode):
         if not model_schema:
             raise ValueError("Model schema not found")
 
-        if set(model_schema.features or []) & set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL]):
+        if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
             prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
         else:
             prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)

+ 2 - 8
api/libs/oauth_data_source.py

@@ -246,10 +246,7 @@ class NotionOAuth(OAuthDataSource):
         }
         response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
         response_json = response.json()
-        if 'results' in response_json:
-            results = response_json['results']
-        else:
-            results = []
+        results = response_json.get('results', [])
         return results
 
     def notion_block_parent_page_id(self, access_token: str, block_id: str):
@@ -293,8 +290,5 @@ class NotionOAuth(OAuthDataSource):
         }
         response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
         response_json = response.json()
-        if 'results' in response_json:
-            results = response_json['results']
-        else:
-            results = []
+        results = response_json.get('results', [])
         return results

+ 16 - 9
api/pyproject.toml

@@ -14,9 +14,11 @@ line-length = 120
 preview = true
 select = [
     "B", # flake8-bugbear rules
+    "C4", # flake8-comprehensions
     "F", # pyflakes rules
     "I", # isort rules
-    "UP",   # pyupgrade rules
+    "UP", # pyupgrade rules
+    "B035", # static-key-dict-comprehension
     "E101", # mixed-spaces-and-tabs
     "E111", # indentation-with-invalid-multiple
     "E112", # no-indented-block
@@ -28,8 +30,13 @@ select = [
     "RUF100", # unused-noqa
     "RUF101", # redirected-noqa
     "S506", # unsafe-yaml-load
+    "SIM116", # if-else-block-instead-of-dict-lookup
+    "SIM401", # if-else-block-instead-of-dict-get
+    "SIM910", # dict-get-with-none-default
     "W191", # tab-indentation
     "W605", # invalid-escape-sequence
+    "F601", # multi-value-repeated-key-literal
+    "F602", # multi-value-repeated-key-variable
 ]
 ignore = [
     "F403", # undefined-local-with-import-star
@@ -82,8 +89,8 @@ HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = "b"
 HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL = "c"
 MOCK_SWITCH = "true"
 CODE_MAX_STRING_LENGTH = "80000"
-CODE_EXECUTION_ENDPOINT="http://127.0.0.1:8194"
-CODE_EXECUTION_API_KEY="dify-sandbox"
+CODE_EXECUTION_ENDPOINT = "http://127.0.0.1:8194"
+CODE_EXECUTION_API_KEY = "dify-sandbox"
 FIRECRAWL_API_KEY = "fc-"
 
 [tool.poetry]
@@ -114,11 +121,11 @@ cachetools = "~5.3.0"
 weaviate-client = "~3.21.0"
 mailchimp-transactional = "~1.0.50"
 scikit-learn = "1.2.2"
-sentry-sdk = {version = "~1.39.2", extras = ["flask"]}
+sentry-sdk = { version = "~1.39.2", extras = ["flask"] }
 sympy = "1.12"
 jieba = "0.42.1"
 celery = "~5.3.6"
-redis = {version = "~5.0.3", extras = ["hiredis"]}
+redis = { version = "~5.0.3", extras = ["hiredis"] }
 chardet = "~5.1.0"
 python-docx = "~1.1.0"
 pypdfium2 = "~4.17.0"
@@ -138,7 +145,7 @@ googleapis-common-protos = "1.63.0"
 google-cloud-storage = "2.16.0"
 replicate = "~0.22.0"
 websocket-client = "~1.7.0"
-dashscope = {version = "~1.17.0", extras = ["tokenizer"]}
+dashscope = { version = "~1.17.0", extras = ["tokenizer"] }
 huggingface-hub = "~0.16.4"
 transformers = "~4.35.0"
 tokenizers = "~0.15.0"
@@ -152,10 +159,10 @@ qdrant-client = "1.7.3"
 cohere = "~5.2.4"
 pyyaml = "~6.0.1"
 numpy = "~1.26.4"
-unstructured = {version = "~0.10.27", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"]}
+unstructured = { version = "~0.10.27", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] }
 bs4 = "~0.0.1"
 markdown = "~3.5.1"
-httpx = {version = "~0.27.0", extras = ["socks"]}
+httpx = { version = "~0.27.0", extras = ["socks"] }
 matplotlib = "~3.8.2"
 yfinance = "~0.2.40"
 pydub = "~0.25.1"
@@ -180,7 +187,7 @@ pgvector = "0.2.5"
 pymysql = "1.1.1"
 tidb-vector = "0.0.9"
 google-cloud-aiplatform = "1.49.0"
-vanna = {version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"]}
+vanna = { version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"] }
 kaleido = "0.2.1"
 tencentcloud-sdk-python-hunyuan = "~3.0.1158"
 tcvectordb = "1.3.2"

+ 1 - 1
api/services/dataset_service.py

@@ -696,7 +696,7 @@ class DocumentService:
             elif document_data["data_source"]["type"] == "notion_import":
                 notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
                 exist_page_ids = []
-                exist_document = dict()
+                exist_document = {}
                 documents = Document.query.filter_by(
                     dataset_id=dataset.id,
                     tenant_id=current_user.current_tenant_id,

+ 1 - 1
api/services/recommended_app_service.py

@@ -95,7 +95,7 @@ class RecommendedAppService:
 
             categories.add(recommended_app.category)  # add category to categories
 
-        return {'recommended_apps': recommended_apps_result, 'categories': sorted(list(categories))}
+        return {'recommended_apps': recommended_apps_result, 'categories': sorted(categories)}
 
     @classmethod
     def _fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:

+ 2 - 2
api/services/workflow/workflow_converter.py

@@ -514,8 +514,8 @@ class WorkflowConverter:
 
                 prompt_rules = prompt_template_config['prompt_rules']
                 role_prefix = {
-                    "user": prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
-                    "assistant": prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+                    "user": prompt_rules.get('human_prefix', 'Human'),
+                    "assistant": prompt_rules.get('assistant_prefix', 'Assistant')
                 }
             else:
                 advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template

+ 2 - 2
api/tests/integration_tests/workflow/nodes/test_llm.py

@@ -112,7 +112,7 @@ def test_execute_llm(setup_openai_mock):
     # Mock db.session.close()
     db.session.close = MagicMock()
 
-    node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config]))
+    node._fetch_model_config = MagicMock(return_value=(model_instance, model_config))
 
     # execute node
     result = node.run(pool)
@@ -229,7 +229,7 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
     # Mock db.session.close()
     db.session.close = MagicMock()
 
-    node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config]))
+    node._fetch_model_config = MagicMock(return_value=(model_instance, model_config))
 
     # execute node
     result = node.run(pool)

+ 1 - 1
api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py

@@ -59,7 +59,7 @@ def get_mocked_fetch_model_config(
         provider_model_bundle=provider_model_bundle
     )
 
-    return MagicMock(return_value=tuple([model_instance, model_config]))
+    return MagicMock(return_value=(model_instance, model_config))
 
 @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
 def test_function_calling_parameter_extractor(setup_openai_mock):

+ 2 - 2
api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py

@@ -238,8 +238,8 @@ def test__get_completion_model_prompt_messages():
     prompt_rules = prompt_template['prompt_rules']
     full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text(
         max_token_limit=2000,
-        human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
-        ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+        human_prefix=prompt_rules.get("human_prefix", "Human"),
+        ai_prefix=prompt_rules.get("assistant_prefix", "Assistant")
     )}
     real_prompt = prompt_template['prompt_template'].format(full_inputs)