Prechádzať zdrojové kódy

refactor(api/core): Improve type hints and apply ruff formatter in agent runner and model manager. (#8166)

-LAN- 7 mesiacov pred
rodič
commit
ed37439ef7
2 zmenil súbory, kde vykonal 199 pridanie a 197 odobranie
  1. 129 110
      api/core/agent/base_agent_runner.py
  2. 70 87
      api/core/model_manager.py

+ 129 - 110
api/core/agent/base_agent_runner.py

@@ -1,6 +1,7 @@
 import json
 import logging
 import uuid
+from collections.abc import Mapping, Sequence
 from datetime import datetime, timezone
 from typing import Optional, Union, cast
 
@@ -45,22 +46,25 @@ from models.tools import ToolConversationVariables
 
 logger = logging.getLogger(__name__)
 
+
 class BaseAgentRunner(AppRunner):
-    def __init__(self, tenant_id: str,
-                 application_generate_entity: AgentChatAppGenerateEntity,
-                 conversation: Conversation,
-                 app_config: AgentChatAppConfig,
-                 model_config: ModelConfigWithCredentialsEntity,
-                 config: AgentEntity,
-                 queue_manager: AppQueueManager,
-                 message: Message,
-                 user_id: str,
-                 memory: Optional[TokenBufferMemory] = None,
-                 prompt_messages: Optional[list[PromptMessage]] = None,
-                 variables_pool: Optional[ToolRuntimeVariablePool] = None,
-                 db_variables: Optional[ToolConversationVariables] = None,
-                 model_instance: ModelInstance = None
-                 ) -> None:
+    def __init__(
+        self,
+        tenant_id: str,
+        application_generate_entity: AgentChatAppGenerateEntity,
+        conversation: Conversation,
+        app_config: AgentChatAppConfig,
+        model_config: ModelConfigWithCredentialsEntity,
+        config: AgentEntity,
+        queue_manager: AppQueueManager,
+        message: Message,
+        user_id: str,
+        memory: Optional[TokenBufferMemory] = None,
+        prompt_messages: Optional[list[PromptMessage]] = None,
+        variables_pool: Optional[ToolRuntimeVariablePool] = None,
+        db_variables: Optional[ToolConversationVariables] = None,
+        model_instance: ModelInstance = None,
+    ) -> None:
         """
         Agent runner
         :param tenant_id: tenant id
@@ -88,9 +92,7 @@ class BaseAgentRunner(AppRunner):
         self.message = message
         self.user_id = user_id
         self.memory = memory
-        self.history_prompt_messages = self.organize_agent_history(
-            prompt_messages=prompt_messages or []
-        )
+        self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
         self.variables_pool = variables_pool
         self.db_variables_pool = db_variables
         self.model_instance = model_instance
@@ -111,12 +113,16 @@ class BaseAgentRunner(AppRunner):
             retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
             return_resource=app_config.additional_features.show_retrieve_source,
             invoke_from=application_generate_entity.invoke_from,
-            hit_callback=hit_callback
+            hit_callback=hit_callback,
         )
         # get how many agent thoughts have been created
-        self.agent_thought_count = db.session.query(MessageAgentThought).filter(
-            MessageAgentThought.message_id == self.message.id,
-        ).count()
+        self.agent_thought_count = (
+            db.session.query(MessageAgentThought)
+            .filter(
+                MessageAgentThought.message_id == self.message.id,
+            )
+            .count()
+        )
         db.session.close()
 
         # check if model supports stream tool call
@@ -135,25 +141,26 @@ class BaseAgentRunner(AppRunner):
         self.query = None
         self._current_thoughts: list[PromptMessage] = []
 
-    def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
-            -> AgentChatAppGenerateEntity:
+    def _repack_app_generate_entity(
+        self, app_generate_entity: AgentChatAppGenerateEntity
+    ) -> AgentChatAppGenerateEntity:
         """
         Repack app generate entity
         """
         if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
-            app_generate_entity.app_config.prompt_template.simple_prompt_template = ''
+            app_generate_entity.app_config.prompt_template.simple_prompt_template = ""
 
         return app_generate_entity
-    
+
     def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
         """
-            convert tool to prompt message tool
+        convert tool to prompt message tool
         """
         tool_entity = ToolManager.get_agent_tool_runtime(
             tenant_id=self.tenant_id,
             app_id=self.app_config.app_id,
             agent_tool=tool,
-            invoke_from=self.application_generate_entity.invoke_from
+            invoke_from=self.application_generate_entity.invoke_from,
         )
         tool_entity.load_variables(self.variables_pool)
 
@@ -164,7 +171,7 @@ class BaseAgentRunner(AppRunner):
                 "type": "object",
                 "properties": {},
                 "required": [],
-            }
+            },
         )
 
         parameters = tool_entity.get_all_runtime_parameters()
@@ -177,19 +184,19 @@ class BaseAgentRunner(AppRunner):
             if parameter.type == ToolParameter.ToolParameterType.SELECT:
                 enum = [option.value for option in parameter.options]
 
-            message_tool.parameters['properties'][parameter.name] = {
+            message_tool.parameters["properties"][parameter.name] = {
                 "type": parameter_type,
-                "description": parameter.llm_description or '',
+                "description": parameter.llm_description or "",
             }
 
             if len(enum) > 0:
-                message_tool.parameters['properties'][parameter.name]['enum'] = enum
+                message_tool.parameters["properties"][parameter.name]["enum"] = enum
 
             if parameter.required:
-                message_tool.parameters['required'].append(parameter.name)
+                message_tool.parameters["required"].append(parameter.name)
 
         return message_tool, tool_entity
-    
+
     def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
         """
         convert dataset retriever tool to prompt message tool
@@ -201,24 +208,24 @@ class BaseAgentRunner(AppRunner):
                 "type": "object",
                 "properties": {},
                 "required": [],
-            }
+            },
         )
 
         for parameter in tool.get_runtime_parameters():
-            parameter_type = 'string'
-        
-            prompt_tool.parameters['properties'][parameter.name] = {
+            parameter_type = "string"
+
+            prompt_tool.parameters["properties"][parameter.name] = {
                 "type": parameter_type,
-                "description": parameter.llm_description or '',
+                "description": parameter.llm_description or "",
             }
 
             if parameter.required:
-                if parameter.name not in prompt_tool.parameters['required']:
-                    prompt_tool.parameters['required'].append(parameter.name)
+                if parameter.name not in prompt_tool.parameters["required"]:
+                    prompt_tool.parameters["required"].append(parameter.name)
 
         return prompt_tool
-    
-    def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
+
+    def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
         """
         Init tools
         """
@@ -261,51 +268,51 @@ class BaseAgentRunner(AppRunner):
             enum = []
             if parameter.type == ToolParameter.ToolParameterType.SELECT:
                 enum = [option.value for option in parameter.options]
-        
-            prompt_tool.parameters['properties'][parameter.name] = {
+
+            prompt_tool.parameters["properties"][parameter.name] = {
                 "type": parameter_type,
-                "description": parameter.llm_description or '',
+                "description": parameter.llm_description or "",
             }
 
             if len(enum) > 0:
-                prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
+                prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
 
             if parameter.required:
-                if parameter.name not in prompt_tool.parameters['required']:
-                    prompt_tool.parameters['required'].append(parameter.name)
+                if parameter.name not in prompt_tool.parameters["required"]:
+                    prompt_tool.parameters["required"].append(parameter.name)
 
         return prompt_tool
-        
-    def create_agent_thought(self, message_id: str, message: str, 
-                             tool_name: str, tool_input: str, messages_ids: list[str]
-                             ) -> MessageAgentThought:
+
+    def create_agent_thought(
+        self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
+    ) -> MessageAgentThought:
         """
         Create agent thought
         """
         thought = MessageAgentThought(
             message_id=message_id,
             message_chain_id=None,
-            thought='',
+            thought="",
             tool=tool_name,
-            tool_labels_str='{}',
-            tool_meta_str='{}',
+            tool_labels_str="{}",
+            tool_meta_str="{}",
             tool_input=tool_input,
             message=message,
             message_token=0,
             message_unit_price=0,
             message_price_unit=0,
-            message_files=json.dumps(messages_ids) if messages_ids else '',
-            answer='',
-            observation='',
+            message_files=json.dumps(messages_ids) if messages_ids else "",
+            answer="",
+            observation="",
             answer_token=0,
             answer_unit_price=0,
             answer_price_unit=0,
             tokens=0,
             total_price=0,
             position=self.agent_thought_count + 1,
-            currency='USD',
+            currency="USD",
             latency=0,
-            created_by_role='account',
+            created_by_role="account",
             created_by=self.user_id,
         )
 
@@ -318,22 +325,22 @@ class BaseAgentRunner(AppRunner):
 
         return thought
 
-    def save_agent_thought(self, 
-                           agent_thought: MessageAgentThought, 
-                           tool_name: str,
-                           tool_input: Union[str, dict],
-                           thought: str, 
-                           observation: Union[str, dict], 
-                           tool_invoke_meta: Union[str, dict],
-                           answer: str,
-                           messages_ids: list[str],
-                           llm_usage: LLMUsage = None) -> MessageAgentThought:
+    def save_agent_thought(
+        self,
+        agent_thought: MessageAgentThought,
+        tool_name: str,
+        tool_input: Union[str, dict],
+        thought: str,
+        observation: Union[str, dict],
+        tool_invoke_meta: Union[str, dict],
+        answer: str,
+        messages_ids: list[str],
+        llm_usage: LLMUsage = None,
+    ) -> MessageAgentThought:
         """
         Save agent thought
         """
-        agent_thought = db.session.query(MessageAgentThought).filter(
-            MessageAgentThought.id == agent_thought.id
-        ).first()
+        agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
 
         if thought is not None:
             agent_thought.thought = thought
@@ -356,7 +363,7 @@ class BaseAgentRunner(AppRunner):
                     observation = json.dumps(observation, ensure_ascii=False)
                 except Exception as e:
                     observation = json.dumps(observation)
-                    
+
             agent_thought.observation = observation
 
         if answer is not None:
@@ -364,7 +371,7 @@ class BaseAgentRunner(AppRunner):
 
         if messages_ids is not None and len(messages_ids) > 0:
             agent_thought.message_files = json.dumps(messages_ids)
-        
+
         if llm_usage:
             agent_thought.message_token = llm_usage.prompt_tokens
             agent_thought.message_price_unit = llm_usage.prompt_price_unit
@@ -377,7 +384,7 @@ class BaseAgentRunner(AppRunner):
 
         # check if tool labels is not empty
         labels = agent_thought.tool_labels or {}
-        tools = agent_thought.tool.split(';') if agent_thought.tool else []
+        tools = agent_thought.tool.split(";") if agent_thought.tool else []
         for tool in tools:
             if not tool:
                 continue
@@ -386,7 +393,7 @@ class BaseAgentRunner(AppRunner):
                 if tool_label:
                     labels[tool] = tool_label.to_dict()
                 else:
-                    labels[tool] = {'en_US': tool, 'zh_Hans': tool}
+                    labels[tool] = {"en_US": tool, "zh_Hans": tool}
 
         agent_thought.tool_labels_str = json.dumps(labels)
 
@@ -401,14 +408,18 @@ class BaseAgentRunner(AppRunner):
 
         db.session.commit()
         db.session.close()
-    
+
     def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
         """
         convert tool variables to db variables
         """
-        db_variables = db.session.query(ToolConversationVariables).filter(
-            ToolConversationVariables.conversation_id == self.message.conversation_id,
-        ).first()
+        db_variables = (
+            db.session.query(ToolConversationVariables)
+            .filter(
+                ToolConversationVariables.conversation_id == self.message.conversation_id,
+            )
+            .first()
+        )
 
         db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
         db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
@@ -425,9 +436,14 @@ class BaseAgentRunner(AppRunner):
             if isinstance(prompt_message, SystemPromptMessage):
                 result.append(prompt_message)
 
-        messages: list[Message] = db.session.query(Message).filter(
-            Message.conversation_id == self.message.conversation_id,
-        ).order_by(Message.created_at.asc()).all()
+        messages: list[Message] = (
+            db.session.query(Message)
+            .filter(
+                Message.conversation_id == self.message.conversation_id,
+            )
+            .order_by(Message.created_at.asc())
+            .all()
+        )
 
         for message in messages:
             if message.id == self.message.id:
@@ -439,13 +455,13 @@ class BaseAgentRunner(AppRunner):
                 for agent_thought in agent_thoughts:
                     tools = agent_thought.tool
                     if tools:
-                        tools = tools.split(';')
+                        tools = tools.split(";")
                         tool_calls: list[AssistantPromptMessage.ToolCall] = []
                         tool_call_response: list[ToolPromptMessage] = []
                         try:
                             tool_inputs = json.loads(agent_thought.tool_input)
                         except Exception as e:
-                            tool_inputs = { tool: {} for tool in tools }
+                            tool_inputs = {tool: {} for tool in tools}
                         try:
                             tool_responses = json.loads(agent_thought.observation)
                         except Exception as e:
@@ -454,27 +470,33 @@ class BaseAgentRunner(AppRunner):
                         for tool in tools:
                             # generate a uuid for tool call
                             tool_call_id = str(uuid.uuid4())
-                            tool_calls.append(AssistantPromptMessage.ToolCall(
-                                id=tool_call_id,
-                                type='function',
-                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            tool_calls.append(
+                                AssistantPromptMessage.ToolCall(
+                                    id=tool_call_id,
+                                    type="function",
+                                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                        name=tool,
+                                        arguments=json.dumps(tool_inputs.get(tool, {})),
+                                    ),
+                                )
+                            )
+                            tool_call_response.append(
+                                ToolPromptMessage(
+                                    content=tool_responses.get(tool, agent_thought.observation),
                                     name=tool,
-                                    arguments=json.dumps(tool_inputs.get(tool, {})),
+                                    tool_call_id=tool_call_id,
                                 )
-                            ))
-                            tool_call_response.append(ToolPromptMessage(
-                                content=tool_responses.get(tool, agent_thought.observation),
-                                name=tool,
-                                tool_call_id=tool_call_id,
-                            ))
-
-                        result.extend([
-                            AssistantPromptMessage(
-                                content=agent_thought.thought,
-                                tool_calls=tool_calls,
-                            ),
-                            *tool_call_response
-                        ])
+                            )
+
+                        result.extend(
+                            [
+                                AssistantPromptMessage(
+                                    content=agent_thought.thought,
+                                    tool_calls=tool_calls,
+                                ),
+                                *tool_call_response,
+                            ]
+                        )
                     if not tools:
                         result.append(AssistantPromptMessage(content=agent_thought.thought))
             else:
@@ -496,10 +518,7 @@ class BaseAgentRunner(AppRunner):
             file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
 
             if file_extra_config:
-                file_objs = message_file_parser.transform_message_files(
-                    files,
-                    file_extra_config
-                )
+                file_objs = message_file_parser.transform_message_files(files, file_extra_config)
             else:
                 file_objs = []
 

+ 70 - 87
api/core/model_manager.py

@@ -1,6 +1,6 @@
 import logging
 import os
-from collections.abc import Callable, Generator
+from collections.abc import Callable, Generator, Sequence
 from typing import IO, Optional, Union, cast
 
 from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
@@ -41,7 +41,7 @@ class ModelInstance:
             configuration=provider_model_bundle.configuration,
             model_type=provider_model_bundle.model_type_instance.model_type,
             model=model,
-            credentials=self.credentials
+            credentials=self.credentials,
         )
 
     @staticmethod
@@ -54,10 +54,7 @@ class ModelInstance:
         """
         configuration = provider_model_bundle.configuration
         model_type = provider_model_bundle.model_type_instance.model_type
-        credentials = configuration.get_current_credentials(
-            model_type=model_type,
-            model=model
-        )
+        credentials = configuration.get_current_credentials(model_type=model_type, model=model)
 
         if credentials is None:
             raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")
@@ -65,10 +62,9 @@ class ModelInstance:
         return credentials
 
     @staticmethod
-    def _get_load_balancing_manager(configuration: ProviderConfiguration,
-                                    model_type: ModelType,
-                                    model: str,
-                                    credentials: dict) -> Optional["LBModelManager"]:
+    def _get_load_balancing_manager(
+        configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict
+    ) -> Optional["LBModelManager"]:
         """
         Get load balancing model credentials
         :param configuration: provider configuration
@@ -81,8 +77,7 @@ class ModelInstance:
             current_model_setting = None
             # check if model is disabled by admin
             for model_setting in configuration.model_settings:
-                if (model_setting.model_type == model_type
-                        and model_setting.model == model):
+                if model_setting.model_type == model_type and model_setting.model == model:
                     current_model_setting = model_setting
                     break
 
@@ -95,17 +90,23 @@ class ModelInstance:
                     model_type=model_type,
                     model=model,
                     load_balancing_configs=current_model_setting.load_balancing_configs,
-                    managed_credentials=credentials if configuration.custom_configuration.provider else None
+                    managed_credentials=credentials if configuration.custom_configuration.provider else None,
                 )
 
                 return lb_model_manager
 
         return None
 
-    def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
-                   tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
-                   stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
-            -> Union[LLMResult, Generator]:
+    def invoke_llm(
+        self,
+        prompt_messages: list[PromptMessage],
+        model_parameters: Optional[dict] = None,
+        tools: Sequence[PromptMessageTool] | None = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+        callbacks: Optional[list[Callback]] = None,
+    ) -> Union[LLMResult, Generator]:
         """
         Invoke large language model
 
@@ -132,11 +133,12 @@ class ModelInstance:
             stop=stop,
             stream=stream,
             user=user,
-            callbacks=callbacks
+            callbacks=callbacks,
         )
 
-    def get_llm_num_tokens(self, prompt_messages: list[PromptMessage],
-                           tools: Optional[list[PromptMessageTool]] = None) -> int:
+    def get_llm_num_tokens(
+        self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
+    ) -> int:
         """
         Get number of tokens for llm
 
@@ -153,11 +155,10 @@ class ModelInstance:
             model=self.model,
             credentials=self.credentials,
             prompt_messages=prompt_messages,
-            tools=tools
+            tools=tools,
         )
 
-    def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
-            -> TextEmbeddingResult:
+    def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) -> TextEmbeddingResult:
         """
         Invoke large language model
 
@@ -174,7 +175,7 @@ class ModelInstance:
             model=self.model,
             credentials=self.credentials,
             texts=texts,
-            user=user
+            user=user,
         )
 
     def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
@@ -192,13 +193,17 @@ class ModelInstance:
             function=self.model_type_instance.get_num_tokens,
             model=self.model,
             credentials=self.credentials,
-            texts=texts
+            texts=texts,
         )
 
-    def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None,
-                      top_n: Optional[int] = None,
-                      user: Optional[str] = None) \
-            -> RerankResult:
+    def invoke_rerank(
+        self,
+        query: str,
+        docs: list[str],
+        score_threshold: Optional[float] = None,
+        top_n: Optional[int] = None,
+        user: Optional[str] = None,
+    ) -> RerankResult:
         """
         Invoke rerank model
 
@@ -221,11 +226,10 @@ class ModelInstance:
             docs=docs,
             score_threshold=score_threshold,
             top_n=top_n,
-            user=user
+            user=user,
         )
 
-    def invoke_moderation(self, text: str, user: Optional[str] = None) \
-            -> bool:
+    def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool:
         """
         Invoke moderation model
 
@@ -242,11 +246,10 @@ class ModelInstance:
             model=self.model,
             credentials=self.credentials,
             text=text,
-            user=user
+            user=user,
         )
 
-    def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
-            -> str:
+    def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str:
         """
         Invoke large language model
 
@@ -263,11 +266,10 @@ class ModelInstance:
             model=self.model,
             credentials=self.credentials,
             file=file,
-            user=user
+            user=user,
         )
 
-    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) \
-            -> str:
+    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> str:
         """
         Invoke large language tts model
 
@@ -288,7 +290,7 @@ class ModelInstance:
             content_text=content_text,
             user=user,
             tenant_id=tenant_id,
-            voice=voice
+            voice=voice,
         )
 
     def _round_robin_invoke(self, function: Callable, *args, **kwargs):
@@ -312,8 +314,8 @@ class ModelInstance:
                     raise last_exception
 
             try:
-                if 'credentials' in kwargs:
-                    del kwargs['credentials']
+                if "credentials" in kwargs:
+                    del kwargs["credentials"]
                 return function(*args, **kwargs, credentials=lb_config.credentials)
             except InvokeRateLimitError as e:
                 # expire in 60 seconds
@@ -340,9 +342,7 @@ class ModelInstance:
 
         self.model_type_instance = cast(TTSModel, self.model_type_instance)
         return self.model_type_instance.get_tts_model_voices(
-            model=self.model,
-            credentials=self.credentials,
-            language=language
+            model=self.model, credentials=self.credentials, language=language
         )
 
 
@@ -363,9 +363,7 @@ class ModelManager:
             return self.get_default_model_instance(tenant_id, model_type)
 
         provider_model_bundle = self._provider_manager.get_provider_model_bundle(
-            tenant_id=tenant_id,
-            provider=provider,
-            model_type=model_type
+            tenant_id=tenant_id, provider=provider, model_type=model_type
         )
 
         return ModelInstance(provider_model_bundle, model)
@@ -386,10 +384,7 @@ class ModelManager:
         :param model_type: model type
         :return:
         """
-        default_model_entity = self._provider_manager.get_default_model(
-            tenant_id=tenant_id,
-            model_type=model_type
-        )
+        default_model_entity = self._provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type)
 
         if not default_model_entity:
             raise ProviderTokenNotInitError(f"Default model not found for {model_type}")
@@ -398,17 +393,20 @@ class ModelManager:
             tenant_id=tenant_id,
             provider=default_model_entity.provider.provider,
             model_type=model_type,
-            model=default_model_entity.model
+            model=default_model_entity.model,
         )
 
 
 class LBModelManager:
-    def __init__(self, tenant_id: str,
-                 provider: str,
-                 model_type: ModelType,
-                 model: str,
-                 load_balancing_configs: list[ModelLoadBalancingConfiguration],
-                 managed_credentials: Optional[dict] = None) -> None:
+    def __init__(
+        self,
+        tenant_id: str,
+        provider: str,
+        model_type: ModelType,
+        model: str,
+        load_balancing_configs: list[ModelLoadBalancingConfiguration],
+        managed_credentials: Optional[dict] = None,
+    ) -> None:
         """
         Load balancing model manager
         :param tenant_id: tenant_id
@@ -439,10 +437,7 @@ class LBModelManager:
         :return:
         """
         cache_key = "model_lb_index:{}:{}:{}:{}".format(
-            self._tenant_id,
-            self._provider,
-            self._model_type.value,
-            self._model
+            self._tenant_id, self._provider, self._model_type.value, self._model
         )
 
         cooldown_load_balancing_configs = []
@@ -473,10 +468,12 @@ class LBModelManager:
 
                 continue
 
-            if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
-                logger.info(f"Model LB\nid: {config.id}\nname:{config.name}\n"
-                            f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"
-                            f"model_type: {self._model_type.value}\nmodel: {self._model}")
+            if bool(os.environ.get("DEBUG", "False").lower() == "true"):
+                logger.info(
+                    f"Model LB\nid: {config.id}\nname:{config.name}\n"
+                    f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"
+                    f"model_type: {self._model_type.value}\nmodel: {self._model}"
+                )
 
             return config
 
@@ -490,14 +487,10 @@ class LBModelManager:
         :return:
         """
         cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
-            self._tenant_id,
-            self._provider,
-            self._model_type.value,
-            self._model,
-            config.id
+            self._tenant_id, self._provider, self._model_type.value, self._model, config.id
         )
 
-        redis_client.setex(cooldown_cache_key, expire, 'true')
+        redis_client.setex(cooldown_cache_key, expire, "true")
 
     def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool:
         """
@@ -506,11 +499,7 @@ class LBModelManager:
         :return:
         """
         cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
-            self._tenant_id,
-            self._provider,
-            self._model_type.value,
-            self._model,
-            config.id
+            self._tenant_id, self._provider, self._model_type.value, self._model, config.id
         )
 
         res = redis_client.exists(cooldown_cache_key)
@@ -518,11 +507,9 @@ class LBModelManager:
         return res
 
     @staticmethod
-    def get_config_in_cooldown_and_ttl(tenant_id: str,
-                                       provider: str,
-                                       model_type: ModelType,
-                                       model: str,
-                                       config_id: str) -> tuple[bool, int]:
+    def get_config_in_cooldown_and_ttl(
+        tenant_id: str, provider: str, model_type: ModelType, model: str, config_id: str
+    ) -> tuple[bool, int]:
         """
         Get model load balancing config is in cooldown and ttl
         :param tenant_id: workspace id
@@ -533,11 +520,7 @@ class LBModelManager:
         :return:
         """
         cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
-            tenant_id,
-            provider,
-            model_type.value,
-            model,
-            config_id
+            tenant_id, provider, model_type.value, model, config_id
         )
 
         ttl = redis_client.ttl(cooldown_cache_key)