瀏覽代碼

chore: cleanup ruff flake8-simplify linter rules (#8286)

Co-authored-by: -LAN- <laipz8200@outlook.com>
Bowen Liang 7 月之前
父節點
當前提交
0f14873255
共有 34 個文件被更改,包括 110 次插入138 次删除
  1. 1 1
      api/core/app/task_pipeline/based_generate_task_pipeline.py
  2. 2 2
      api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py
  3. 1 3
      api/core/model_runtime/model_providers/google/llm/llm.py
  4. 1 3
      api/core/model_runtime/model_providers/oci/llm/llm.py
  5. 1 3
      api/core/model_runtime/model_providers/tongyi/llm/llm.py
  6. 1 3
      api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
  7. 1 5
      api/core/model_runtime/model_providers/xinference/llm/llm.py
  8. 5 6
      api/core/model_runtime/model_providers/zhipuai/llm/llm.py
  9. 3 10
      api/core/moderation/keywords/keywords.py
  10. 1 1
      api/core/ops/ops_trace_manager.py
  11. 23 25
      api/core/rag/datasource/vdb/relyt/relyt_vector.py
  12. 1 4
      api/core/rag/datasource/vdb/tencent/tencent_vector.py
  13. 16 18
      api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
  14. 2 1
      api/core/rag/extractor/word_extractor.py
  15. 2 2
      api/core/rag/rerank/weight_rerank.py
  16. 2 2
      api/core/rag/retrieval/dataset_retrieval.py
  17. 1 3
      api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py
  18. 1 1
      api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py
  19. 1 1
      api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py
  20. 10 10
      api/core/tools/provider/builtin/searchapi/tools/google.py
  21. 3 3
      api/core/tools/provider/builtin/searchapi/tools/google_jobs.py
  22. 5 5
      api/core/tools/provider/builtin/searchapi/tools/google_news.py
  23. 2 2
      api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py
  24. 1 1
      api/core/tools/provider/builtin/stability/tools/text2image.py
  25. 3 4
      api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py
  26. 1 1
      api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py
  27. 1 3
      api/core/tools/utils/message_transformer.py
  28. 1 3
      api/core/workflow/graph_engine/entities/graph.py
  29. 1 1
      api/core/workflow/nodes/answer/answer_stream_generate_router.py
  30. 1 1
      api/core/workflow/nodes/end/end_stream_generate_router.py
  31. 1 1
      api/core/workflow/nodes/tool/entities.py
  32. 11 4
      api/pyproject.toml
  33. 1 3
      api/services/file_service.py
  34. 2 2
      api/services/ops_service.py

+ 1 - 1
api/core/app/task_pipeline/based_generate_task_pipeline.py

@@ -65,7 +65,7 @@ class BasedGenerateTaskPipeline:
 
         if isinstance(e, InvokeAuthorizationError):
             err = InvokeAuthorizationError("Incorrect API key provided")
-        elif isinstance(e, InvokeError) or isinstance(e, ValueError):
+        elif isinstance(e, InvokeError | ValueError):
             err = e
         else:
             err = Exception(e.description if getattr(e, "description", None) is not None else str(e))

+ 2 - 2
api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py

@@ -45,7 +45,7 @@ class BaichuanModel:
         parameters: dict[str, Any],
         tools: Optional[list[PromptMessageTool]] = None,
     ) -> dict[str, Any]:
-        if model in self._model_mapping.keys():
+        if model in self._model_mapping:
             # the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters.
             # we need to rename it to res_format to get its value
             if parameters.get("res_format") == "json_object":
@@ -94,7 +94,7 @@ class BaichuanModel:
         timeout: int,
         tools: Optional[list[PromptMessageTool]] = None,
     ) -> Union[Iterator, dict]:
-        if model in self._model_mapping.keys():
+        if model in self._model_mapping:
             api_base = "https://api.baichuan-ai.com/v1/chat/completions"
         else:
             raise BadRequestError(f"Unknown model: {model}")

+ 1 - 3
api/core/model_runtime/model_providers/google/llm/llm.py

@@ -337,9 +337,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             message_text = f"{human_prompt} {content}"
         elif isinstance(message, AssistantPromptMessage):
             message_text = f"{ai_prompt} {content}"
-        elif isinstance(message, SystemPromptMessage):
-            message_text = f"{human_prompt} {content}"
-        elif isinstance(message, ToolPromptMessage):
+        elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
             message_text = f"{human_prompt} {content}"
         else:
             raise ValueError(f"Got unknown type {message}")

+ 1 - 3
api/core/model_runtime/model_providers/oci/llm/llm.py

@@ -442,9 +442,7 @@ class OCILargeLanguageModel(LargeLanguageModel):
             message_text = f"{human_prompt} {content}"
         elif isinstance(message, AssistantPromptMessage):
             message_text = f"{ai_prompt} {content}"
-        elif isinstance(message, SystemPromptMessage):
-            message_text = f"{human_prompt} {content}"
-        elif isinstance(message, ToolPromptMessage):
+        elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
             message_text = f"{human_prompt} {content}"
         else:
             raise ValueError(f"Got unknown type {message}")

+ 1 - 3
api/core/model_runtime/model_providers/tongyi/llm/llm.py

@@ -350,9 +350,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
                         break
         elif isinstance(message, AssistantPromptMessage):
             message_text = f"{ai_prompt} {content}"
-        elif isinstance(message, SystemPromptMessage):
-            message_text = content
-        elif isinstance(message, ToolPromptMessage):
+        elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
             message_text = content
         else:
             raise ValueError(f"Got unknown type {message}")

+ 1 - 3
api/core/model_runtime/model_providers/vertex_ai/llm/llm.py

@@ -633,9 +633,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
             message_text = f"{human_prompt} {content}"
         elif isinstance(message, AssistantPromptMessage):
             message_text = f"{ai_prompt} {content}"
-        elif isinstance(message, SystemPromptMessage):
-            message_text = f"{human_prompt} {content}"
-        elif isinstance(message, ToolPromptMessage):
+        elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
             message_text = f"{human_prompt} {content}"
         else:
             raise ValueError(f"Got unknown type {message}")

+ 1 - 5
api/core/model_runtime/model_providers/xinference/llm/llm.py

@@ -272,11 +272,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         """
         text = ""
         for item in message:
-            if isinstance(item, UserPromptMessage):
-                text += item.content
-            elif isinstance(item, SystemPromptMessage):
-                text += item.content
-            elif isinstance(item, AssistantPromptMessage):
+            if isinstance(item, UserPromptMessage | SystemPromptMessage | AssistantPromptMessage):
                 text += item.content
             else:
                 raise NotImplementedError(f"PromptMessage type {type(item)} is not supported")

+ 5 - 6
api/core/model_runtime/model_providers/zhipuai/llm/llm.py

@@ -209,9 +209,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                 ):
                     new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
                 else:
-                    if copy_prompt_message.role == PromptMessageRole.USER:
-                        new_prompt_messages.append(copy_prompt_message)
-                    elif copy_prompt_message.role == PromptMessageRole.TOOL:
+                    if (
+                        copy_prompt_message.role == PromptMessageRole.USER
+                        or copy_prompt_message.role == PromptMessageRole.TOOL
+                    ):
                         new_prompt_messages.append(copy_prompt_message)
                     elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
                         new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
@@ -461,9 +462,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             message_text = f"{human_prompt} {content}"
         elif isinstance(message, AssistantPromptMessage):
             message_text = f"{ai_prompt} {content}"
-        elif isinstance(message, SystemPromptMessage):
-            message_text = content
-        elif isinstance(message, ToolPromptMessage):
+        elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
             message_text = content
         else:
             raise ValueError(f"Got unknown type {message}")

+ 3 - 10
api/core/moderation/keywords/keywords.py

@@ -56,14 +56,7 @@ class KeywordsModeration(Moderation):
         )
 
     def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
-        for value in inputs.values():
-            if self._check_keywords_in_value(keywords_list, value):
-                return True
+        return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())
 
-        return False
-
-    def _check_keywords_in_value(self, keywords_list, value):
-        for keyword in keywords_list:
-            if keyword.lower() in value.lower():
-                return True
-        return False
+    def _check_keywords_in_value(self, keywords_list, value) -> bool:
+        return any(keyword.lower() in value.lower() for keyword in keywords_list)

+ 1 - 1
api/core/ops/ops_trace_manager.py

@@ -223,7 +223,7 @@ class OpsTraceManager:
         :return:
         """
         # auth check
-        if tracing_provider not in provider_config_map.keys() and tracing_provider is not None:
+        if tracing_provider not in provider_config_map and tracing_provider is not None:
             raise ValueError(f"Invalid tracing provider: {tracing_provider}")
 
         app_config: App = db.session.query(App).filter(App.id == app_id).first()

+ 23 - 25
api/core/rag/datasource/vdb/relyt/relyt_vector.py

@@ -127,27 +127,26 @@ class RelytVector(BaseVector):
         )
 
         chunks_table_data = []
-        with self.client.connect() as conn:
-            with conn.begin():
-                for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
-                    chunks_table_data.append(
-                        {
-                            "id": chunk_id,
-                            "embedding": embedding,
-                            "document": document,
-                            "metadata": metadata,
-                        }
-                    )
-
-                    # Execute the batch insert when the batch size is reached
-                    if len(chunks_table_data) == 500:
-                        conn.execute(insert(chunks_table).values(chunks_table_data))
-                        # Clear the chunks_table_data list for the next batch
-                        chunks_table_data.clear()
-
-                # Insert any remaining records that didn't make up a full batch
-                if chunks_table_data:
+        with self.client.connect() as conn, conn.begin():
+            for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
+                chunks_table_data.append(
+                    {
+                        "id": chunk_id,
+                        "embedding": embedding,
+                        "document": document,
+                        "metadata": metadata,
+                    }
+                )
+
+                # Execute the batch insert when the batch size is reached
+                if len(chunks_table_data) == 500:
                     conn.execute(insert(chunks_table).values(chunks_table_data))
+                    # Clear the chunks_table_data list for the next batch
+                    chunks_table_data.clear()
+
+            # Insert any remaining records that didn't make up a full batch
+            if chunks_table_data:
+                conn.execute(insert(chunks_table).values(chunks_table_data))
 
         return ids
 
@@ -186,11 +185,10 @@ class RelytVector(BaseVector):
         )
 
         try:
-            with self.client.connect() as conn:
-                with conn.begin():
-                    delete_condition = chunks_table.c.id.in_(ids)
-                    conn.execute(chunks_table.delete().where(delete_condition))
-                    return True
+            with self.client.connect() as conn, conn.begin():
+                delete_condition = chunks_table.c.id.in_(ids)
+                conn.execute(chunks_table.delete().where(delete_condition))
+                return True
         except Exception as e:
             print("Delete operation failed:", str(e))
             return False

+ 1 - 4
api/core/rag/datasource/vdb/tencent/tencent_vector.py

@@ -63,10 +63,7 @@ class TencentVector(BaseVector):
 
     def _has_collection(self) -> bool:
         collections = self._db.list_collections()
-        for collection in collections:
-            if collection.collection_name == self._collection_name:
-                return True
-        return False
+        return any(collection.collection_name == self._collection_name for collection in collections)
 
     def _create_collection(self, dimension: int) -> None:
         lock_name = "vector_indexing_lock_{}".format(self._collection_name)

+ 16 - 18
api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py

@@ -124,20 +124,19 @@ class TiDBVector(BaseVector):
         texts = [d.page_content for d in documents]
 
         chunks_table_data = []
-        with self._engine.connect() as conn:
-            with conn.begin():
-                for id, text, meta, embedding in zip(ids, texts, metas, embeddings):
-                    chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
-
-                    # Execute the batch insert when the batch size is reached
-                    if len(chunks_table_data) == 500:
-                        conn.execute(insert(table).values(chunks_table_data))
-                        # Clear the chunks_table_data list for the next batch
-                        chunks_table_data.clear()
-
-                # Insert any remaining records that didn't make up a full batch
-                if chunks_table_data:
+        with self._engine.connect() as conn, conn.begin():
+            for id, text, meta, embedding in zip(ids, texts, metas, embeddings):
+                chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
+
+                # Execute the batch insert when the batch size is reached
+                if len(chunks_table_data) == 500:
                     conn.execute(insert(table).values(chunks_table_data))
+                    # Clear the chunks_table_data list for the next batch
+                    chunks_table_data.clear()
+
+            # Insert any remaining records that didn't make up a full batch
+            if chunks_table_data:
+                conn.execute(insert(table).values(chunks_table_data))
         return ids
 
     def text_exists(self, id: str) -> bool:
@@ -160,11 +159,10 @@ class TiDBVector(BaseVector):
             raise ValueError("No ids provided to delete.")
         table = self._table(self._dimension)
         try:
-            with self._engine.connect() as conn:
-                with conn.begin():
-                    delete_condition = table.c.id.in_(ids)
-                    conn.execute(table.delete().where(delete_condition))
-                    return True
+            with self._engine.connect() as conn, conn.begin():
+                delete_condition = table.c.id.in_(ids)
+                conn.execute(table.delete().where(delete_condition))
+                return True
         except Exception as e:
             print("Delete operation failed:", str(e))
             return False

+ 2 - 1
api/core/rag/extractor/word_extractor.py

@@ -48,7 +48,8 @@ class WordExtractor(BaseExtractor):
                 raise ValueError(f"Check the url of your file; returned status code {r.status_code}")
 
             self.web_path = self.file_path
-            self.temp_file = tempfile.NamedTemporaryFile()
+            # TODO: use a better way to handle the file
+            self.temp_file = tempfile.NamedTemporaryFile()  # noqa: SIM115
             self.temp_file.write(r.content)
             self.file_path = self.temp_file.name
         elif not os.path.isfile(self.file_path):

+ 2 - 2
api/core/rag/rerank/weight_rerank.py

@@ -120,8 +120,8 @@ class WeightRerankRunner:
             intersection = set(vec1.keys()) & set(vec2.keys())
             numerator = sum(vec1[x] * vec2[x] for x in intersection)
 
-            sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
-            sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
+            sum1 = sum(vec1[x] ** 2 for x in vec1)
+            sum2 = sum(vec2[x] ** 2 for x in vec2)
             denominator = math.sqrt(sum1) * math.sqrt(sum2)
 
             if not denominator:

+ 2 - 2
api/core/rag/retrieval/dataset_retrieval.py

@@ -581,8 +581,8 @@ class DatasetRetrieval:
             intersection = set(vec1.keys()) & set(vec2.keys())
             numerator = sum(vec1[x] * vec2[x] for x in intersection)
 
-            sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
-            sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
+            sum1 = sum(vec1[x] ** 2 for x in vec1)
+            sum2 = sum(vec2[x] ** 2 for x in vec2)
             denominator = math.sqrt(sum1) * math.sqrt(sum2)
 
             if not denominator:

+ 1 - 3
api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py

@@ -201,9 +201,7 @@ class ListWorksheetRecordsTool(BuiltinTool):
             elif value.startswith('[{"organizeId"'):
                 value = json.loads(value)
                 value = "、".join([item["organizeName"] for item in value])
-            elif value.startswith('[{"file_id"'):
-                value = ""
-            elif value == "[]":
+            elif value.startswith('[{"file_id"') or value == "[]":
                 value = ""
         elif hasattr(value, "accountId"):
             value = value["fullname"]

+ 1 - 1
api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py

@@ -35,7 +35,7 @@ class NovitaAiModelQueryTool(BuiltinTool):
             models_data=[],
             headers=headers,
             params=params,
-            recursive=False if result_type == "first sd_name" or result_type == "first name sd_name pair" else True,
+            recursive=not (result_type == "first sd_name" or result_type == "first name sd_name pair"),
         )
 
         result_str = ""

+ 1 - 1
api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py

@@ -39,7 +39,7 @@ class QRCodeGeneratorTool(BuiltinTool):
 
         # get error_correction
         error_correction = tool_parameters.get("error_correction", "")
-        if error_correction not in self.error_correction_levels.keys():
+        if error_correction not in self.error_correction_levels:
             return self.create_text_message("Invalid parameter error_correction")
 
         try:

+ 10 - 10
api/core/tools/provider/builtin/searchapi/tools/google.py

@@ -44,36 +44,36 @@ class SearchAPI:
     @staticmethod
     def _process_response(res: dict, type: str) -> str:
         """Process response from SearchAPI."""
-        if "error" in res.keys():
+        if "error" in res:
             raise ValueError(f"Got error from SearchApi: {res['error']}")
 
         toret = ""
         if type == "text":
-            if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
+            if "answer_box" in res and "answer" in res["answer_box"]:
                 toret += res["answer_box"]["answer"] + "\n"
-            if "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
+            if "answer_box" in res and "snippet" in res["answer_box"]:
                 toret += res["answer_box"]["snippet"] + "\n"
-            if "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys():
+            if "knowledge_graph" in res and "description" in res["knowledge_graph"]:
                 toret += res["knowledge_graph"]["description"] + "\n"
-            if "organic_results" in res.keys() and "snippet" in res["organic_results"][0].keys():
+            if "organic_results" in res and "snippet" in res["organic_results"][0]:
                 for item in res["organic_results"]:
                     toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n"
             if toret == "":
                 toret = "No good search result found"
 
         elif type == "link":
-            if "answer_box" in res.keys() and "organic_result" in res["answer_box"].keys():
-                if "title" in res["answer_box"]["organic_result"].keys():
+            if "answer_box" in res and "organic_result" in res["answer_box"]:
+                if "title" in res["answer_box"]["organic_result"]:
                     toret = f"[{res['answer_box']['organic_result']['title']}]({res['answer_box']['organic_result']['link']})\n"
-            elif "organic_results" in res.keys() and "link" in res["organic_results"][0].keys():
+            elif "organic_results" in res and "link" in res["organic_results"][0]:
                 toret = ""
                 for item in res["organic_results"]:
                     toret += f"[{item['title']}]({item['link']})\n"
-            elif "related_questions" in res.keys() and "link" in res["related_questions"][0].keys():
+            elif "related_questions" in res and "link" in res["related_questions"][0]:
                 toret = ""
                 for item in res["related_questions"]:
                     toret += f"[{item['title']}]({item['link']})\n"
-            elif "related_searches" in res.keys() and "link" in res["related_searches"][0].keys():
+            elif "related_searches" in res and "link" in res["related_searches"][0]:
                 toret = ""
                 for item in res["related_searches"]:
                     toret += f"[{item['title']}]({item['link']})\n"

+ 3 - 3
api/core/tools/provider/builtin/searchapi/tools/google_jobs.py

@@ -44,12 +44,12 @@ class SearchAPI:
     @staticmethod
     def _process_response(res: dict, type: str) -> str:
         """Process response from SearchAPI."""
-        if "error" in res.keys():
+        if "error" in res:
             raise ValueError(f"Got error from SearchApi: {res['error']}")
 
         toret = ""
         if type == "text":
-            if "jobs" in res.keys() and "title" in res["jobs"][0].keys():
+            if "jobs" in res and "title" in res["jobs"][0]:
                 for item in res["jobs"]:
                     toret += (
                         "title: "
@@ -65,7 +65,7 @@ class SearchAPI:
                 toret = "No good search result found"
 
         elif type == "link":
-            if "jobs" in res.keys() and "apply_link" in res["jobs"][0].keys():
+            if "jobs" in res and "apply_link" in res["jobs"][0]:
                 for item in res["jobs"]:
                     toret += f"[{item['title']} - {item['company_name']}]({item['apply_link']})\n"
             else:

+ 5 - 5
api/core/tools/provider/builtin/searchapi/tools/google_news.py

@@ -44,25 +44,25 @@ class SearchAPI:
     @staticmethod
     def _process_response(res: dict, type: str) -> str:
         """Process response from SearchAPI."""
-        if "error" in res.keys():
+        if "error" in res:
             raise ValueError(f"Got error from SearchApi: {res['error']}")
 
         toret = ""
         if type == "text":
-            if "organic_results" in res.keys() and "snippet" in res["organic_results"][0].keys():
+            if "organic_results" in res and "snippet" in res["organic_results"][0]:
                 for item in res["organic_results"]:
                     toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n"
-            if "top_stories" in res.keys() and "title" in res["top_stories"][0].keys():
+            if "top_stories" in res and "title" in res["top_stories"][0]:
                 for item in res["top_stories"]:
                     toret += "title: " + item["title"] + "\n" + "link: " + item["link"] + "\n"
             if toret == "":
                 toret = "No good search result found"
 
         elif type == "link":
-            if "organic_results" in res.keys() and "title" in res["organic_results"][0].keys():
+            if "organic_results" in res and "title" in res["organic_results"][0]:
                 for item in res["organic_results"]:
                     toret += f"[{item['title']}]({item['link']})\n"
-            elif "top_stories" in res.keys() and "title" in res["top_stories"][0].keys():
+            elif "top_stories" in res and "title" in res["top_stories"][0]:
                 for item in res["top_stories"]:
                     toret += f"[{item['title']}]({item['link']})\n"
             else:

+ 2 - 2
api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py

@@ -44,11 +44,11 @@ class SearchAPI:
     @staticmethod
     def _process_response(res: dict) -> str:
         """Process response from SearchAPI."""
-        if "error" in res.keys():
+        if "error" in res:
             raise ValueError(f"Got error from SearchApi: {res['error']}")
 
         toret = ""
-        if "transcripts" in res.keys() and "text" in res["transcripts"][0].keys():
+        if "transcripts" in res and "text" in res["transcripts"][0]:
             for item in res["transcripts"]:
                 toret += item["text"] + " "
         if toret == "":

+ 1 - 1
api/core/tools/provider/builtin/stability/tools/text2image.py

@@ -35,7 +35,7 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
         if model in ["sd3", "sd3-turbo"]:
             payload["model"] = tool_parameters.get("model")
 
-        if not model == "sd3-turbo":
+        if model != "sd3-turbo":
             payload["negative_prompt"] = tool_parameters.get("negative_prompt", "")
 
         response = post(

+ 3 - 4
api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py

@@ -206,10 +206,9 @@ class StableDiffusionTool(BuiltinTool):
 
         # Convert image to RGB and save as PNG
         try:
-            with Image.open(io.BytesIO(image_binary)) as image:
-                with io.BytesIO() as buffer:
-                    image.convert("RGB").save(buffer, format="PNG")
-                    image_binary = buffer.getvalue()
+            with Image.open(io.BytesIO(image_binary)) as image, io.BytesIO() as buffer:
+                image.convert("RGB").save(buffer, format="PNG")
+                image_binary = buffer.getvalue()
         except Exception as e:
             return self.create_text_message(f"Failed to process the image: {str(e)}")
 

+ 1 - 1
api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py

@@ -27,7 +27,7 @@ class WikipediaAPIWrapper:
         self.doc_content_chars_max = doc_content_chars_max
 
     def run(self, query: str, lang: str = "") -> str:
-        if lang in wikipedia.languages().keys():
+        if lang in wikipedia.languages():
             self.lang = lang
 
         wikipedia.set_lang(self.lang)

+ 1 - 3
api/core/tools/utils/message_transformer.py

@@ -19,9 +19,7 @@ class ToolFileMessageTransformer:
         result = []
 
         for message in messages:
-            if message.type == ToolInvokeMessage.MessageType.TEXT:
-                result.append(message)
-            elif message.type == ToolInvokeMessage.MessageType.LINK:
+            if message.type == ToolInvokeMessage.MessageType.TEXT or message.type == ToolInvokeMessage.MessageType.LINK:
                 result.append(message)
             elif message.type == ToolInvokeMessage.MessageType.IMAGE:
                 # try to download image

+ 1 - 3
api/core/workflow/graph_engine/entities/graph.py

@@ -224,9 +224,7 @@ class Graph(BaseModel):
         """
         leaf_node_ids = []
         for node_id in self.node_ids:
-            if node_id not in self.edge_mapping:
-                leaf_node_ids.append(node_id)
-            elif (
+            if node_id not in self.edge_mapping or (
                 len(self.edge_mapping[node_id]) == 1
                 and self.edge_mapping[node_id][0].target_node_id == self.root_node_id
             ):

+ 1 - 1
api/core/workflow/nodes/answer/answer_stream_generate_router.py

@@ -24,7 +24,7 @@ class AnswerStreamGeneratorRouter:
         # parse stream output node value selectors of answer nodes
         answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
         for answer_node_id, node_config in node_id_config_mapping.items():
-            if not node_config.get("data", {}).get("type") == NodeType.ANSWER.value:
+            if node_config.get("data", {}).get("type") != NodeType.ANSWER.value:
                 continue
 
             # get generate route for stream output

+ 1 - 1
api/core/workflow/nodes/end/end_stream_generate_router.py

@@ -17,7 +17,7 @@ class EndStreamGeneratorRouter:
         # parse stream output node value selector of end nodes
         end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {}
         for end_node_id, node_config in node_id_config_mapping.items():
-            if not node_config.get("data", {}).get("type") == NodeType.END.value:
+            if node_config.get("data", {}).get("type") != NodeType.END.value:
                 continue
 
             # skip end node in parallel

+ 1 - 1
api/core/workflow/nodes/tool/entities.py

@@ -20,7 +20,7 @@ class ToolEntity(BaseModel):
         if not isinstance(value, dict):
             raise ValueError("tool_configurations must be a dictionary")
 
-        for key in values.data.get("tool_configurations", {}).keys():
+        for key in values.data.get("tool_configurations", {}):
             value = values.data.get("tool_configurations", {}).get(key)
             if not isinstance(value, str | int | float | bool):
                 raise ValueError(f"{key} must be a string")

+ 11 - 4
api/pyproject.toml

@@ -17,14 +17,12 @@ select = [
     "F", # pyflakes rules
     "I", # isort rules
     "N", # pep8-naming
-    "UP", # pyupgrade rules
     "RUF019", # unnecessary-key-check
     "RUF100", # unused-noqa
     "RUF101", # redirected-noqa
     "S506", # unsafe-yaml-load
-    "SIM116", # if-else-block-instead-of-dict-lookup
-    "SIM401", # if-else-block-instead-of-dict-get
-    "SIM910", # dict-get-with-none-default
+    "SIM", # flake8-simplify rules
+    "UP", # pyupgrade rules
     "W191", # tab-indentation
     "W605", # invalid-escape-sequence
 ]
@@ -50,6 +48,15 @@ ignore = [
     "B905", # zip-without-explicit-strict
     "N806", # non-lowercase-variable-in-function
     "N815", # mixed-case-variable-in-class-scope
+    "SIM102", # collapsible-if
+    "SIM103", # needless-bool
+    "SIM105", # suppressible-exception
+    "SIM107", # return-in-try-except-finally
+    "SIM108", # if-else-block-instead-of-if-exp
+    "SIM113", # eumerate-for-loop
+    "SIM117", # multiple-with-statements
+    "SIM210", # if-expr-with-true-false
+    "SIM300", # yoda-conditions
 ]
 
 [tool.ruff.lint.per-file-ignores]

+ 1 - 3
api/services/file_service.py

@@ -56,9 +56,7 @@ class FileService:
             if etl_type == "Unstructured"
             else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS
         )
-        if extension.lower() not in allowed_extensions:
-            raise UnsupportedFileTypeError()
-        elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
+        if extension.lower() not in allowed_extensions or only_image and extension.lower() not in IMAGE_EXTENSIONS:
             raise UnsupportedFileTypeError()
 
         # read file content

+ 2 - 2
api/services/ops_service.py

@@ -54,7 +54,7 @@ class OpsService:
         :param tracing_config: tracing config
         :return:
         """
-        if tracing_provider not in provider_config_map.keys() and tracing_provider:
+        if tracing_provider not in provider_config_map and tracing_provider:
             return {"error": f"Invalid tracing provider: {tracing_provider}"}
 
         config_class, other_keys = (
@@ -113,7 +113,7 @@ class OpsService:
         :param tracing_config: tracing config
         :return:
         """
-        if tracing_provider not in provider_config_map.keys():
+        if tracing_provider not in provider_config_map:
             raise ValueError(f"Invalid tracing provider: {tracing_provider}")
 
         # check if trace config already exists