Sfoglia il codice sorgente

fix(core): Fix incorrect type hints. (#5427)

-LAN- 10 mesi fa
parent
commit
23fa3dedc4

+ 4 - 2
api/core/extension/extensible.py

@@ -1,5 +1,5 @@
 import enum
-import importlib
+import importlib.util
 import json
 import logging
 import os
@@ -74,6 +74,8 @@ class Extensible:
                 # Dynamic loading {subdir_name}.py file and find the subclass of Extensible
                 py_path = os.path.join(subdir_path, extension_name + '.py')
                 spec = importlib.util.spec_from_file_location(extension_name, py_path)
+                if not spec or not spec.loader:
+                    raise Exception(f"Failed to load module {extension_name} from {py_path}")
                 mod = importlib.util.module_from_spec(spec)
                 spec.loader.exec_module(mod)
 
@@ -108,6 +110,6 @@ class Extensible:
                     position=position
                 ))
 
-        sorted_extensions = sort_to_dict_by_position_map(position_map, extensions, lambda x: x.name)
+        sorted_extensions = sort_to_dict_by_position_map(position_map=position_map, data=extensions, name_func=lambda x: x.name)
 
         return sorted_extensions

+ 10 - 11
api/core/helper/module_import_helper.py

@@ -5,11 +5,7 @@ from types import ModuleType
 from typing import AnyStr
 
 
-def import_module_from_source(
-        module_name: str,
-        py_file_path: AnyStr,
-        use_lazy_loader: bool = False
-) -> ModuleType:
+def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType:
     """
     Importing a module from the source file directly
     """
@@ -17,9 +13,13 @@ def import_module_from_source(
         existed_spec = importlib.util.find_spec(module_name)
         if existed_spec:
             spec = existed_spec
+            if not spec.loader:
+                raise Exception(f"Failed to load module {module_name} from {py_file_path}")
         else:
             # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
             spec = importlib.util.spec_from_file_location(module_name, py_file_path)
+            if not spec or not spec.loader:
+                raise Exception(f"Failed to load module {module_name} from {py_file_path}")
             if use_lazy_loader:
                 # Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
                 spec.loader = importlib.util.LazyLoader(spec.loader)
@@ -29,7 +29,7 @@ def import_module_from_source(
         spec.loader.exec_module(module)
         return module
     except Exception as e:
-        logging.exception(f'Failed to load module {module_name} from {py_file_path}: {str(e)}')
+        logging.exception(f"Failed to load module {module_name} from {py_file_path}: {str(e)}")
         raise e
 
 
@@ -43,15 +43,14 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]
 
 
 def load_single_subclass_from_source(
-        module_name: str,
-        script_path: AnyStr,
-        parent_type: type,
-        use_lazy_loader: bool = False,
+    *, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False
 ) -> type:
     """
     Load a single subclass from the source
     """
-    module = import_module_from_source(module_name, script_path, use_lazy_loader)
+    module = import_module_from_source(
+        module_name=module_name, py_file_path=script_path, use_lazy_loader=use_lazy_loader
+    )
     subclasses = get_subclasses_from_module(module, parent_type)
     match len(subclasses):
         case 1:

+ 2 - 5
api/core/helper/position_helper.py

@@ -1,15 +1,12 @@
 import os
 from collections import OrderedDict
 from collections.abc import Callable
-from typing import Any, AnyStr
+from typing import Any
 
 from core.tools.utils.yaml_utils import load_yaml_file
 
 
-def get_position_map(
-        folder_path: AnyStr,
-        file_name: str = '_position.yaml',
-) -> dict[str, int]:
+def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> dict[str, int]:
     """
     Get the mapping from name to index from a YAML file
     :param folder_path:

+ 9 - 4
api/core/model_manager.py

@@ -1,6 +1,6 @@
 import logging
 import os
-from collections.abc import Generator
+from collections.abc import Callable, Generator
 from typing import IO, Optional, Union, cast
 
 from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
@@ -102,7 +102,7 @@ class ModelInstance:
 
     def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
                    tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
-                   stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
+                   stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
             -> Union[LLMResult, Generator]:
         """
         Invoke large language model
@@ -291,7 +291,7 @@ class ModelInstance:
             streaming=streaming
         )
 
-    def _round_robin_invoke(self, function: callable, *args, **kwargs):
+    def _round_robin_invoke(self, function: Callable, *args, **kwargs):
         """
         Round-robin invoke
         :param function: function to invoke
@@ -437,6 +437,7 @@ class LBModelManager:
 
         while True:
             current_index = redis_client.incr(cache_key)
+            current_index = cast(int, current_index)
             if current_index >= 10000000:
                 current_index = 1
                 redis_client.set(cache_key, current_index)
@@ -499,7 +500,10 @@ class LBModelManager:
             config.id
         )
 
-        return redis_client.exists(cooldown_cache_key)
+
+        res = redis_client.exists(cooldown_cache_key)
+        res = cast(bool, res)
+        return res
 
     @classmethod
     def get_config_in_cooldown_and_ttl(cls, tenant_id: str,
@@ -528,4 +532,5 @@ class LBModelManager:
         if ttl == -2:
             return False, 0
 
+        ttl = cast(int, ttl)
         return True, ttl

+ 5 - 4
api/core/model_runtime/entities/provider_entities.py

@@ -1,10 +1,11 @@
+from collections.abc import Sequence
 from enum import Enum
 from typing import Optional
 
 from pydantic import BaseModel, ConfigDict
 
 from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.model_entities import AIModelEntity, ModelType, ProviderModel
+from core.model_runtime.entities.model_entities import ModelType, ProviderModel
 
 
 class ConfigurateMethod(Enum):
@@ -93,8 +94,8 @@ class SimpleProviderEntity(BaseModel):
     label: I18nObject
     icon_small: Optional[I18nObject] = None
     icon_large: Optional[I18nObject] = None
-    supported_model_types: list[ModelType]
-    models: list[AIModelEntity] = []
+    supported_model_types: Sequence[ModelType]
+    models: list[ProviderModel] = []
 
 
 class ProviderHelpEntity(BaseModel):
@@ -116,7 +117,7 @@ class ProviderEntity(BaseModel):
     icon_large: Optional[I18nObject] = None
     background: Optional[str] = None
     help: Optional[ProviderHelpEntity] = None
-    supported_model_types: list[ModelType]
+    supported_model_types: Sequence[ModelType]
     configurate_methods: list[ConfigurateMethod]
     models: list[ProviderModel] = []
     provider_credential_schema: Optional[ProviderCredentialSchema] = None

+ 33 - 19
api/core/model_runtime/model_providers/__base/ai_model.py

@@ -1,6 +1,7 @@
 import decimal
 import os
 from abc import ABC, abstractmethod
+from collections.abc import Mapping
 from typing import Optional
 
 from pydantic import ConfigDict
@@ -26,15 +27,16 @@ class AIModel(ABC):
     """
     Base class for all models.
     """
+
     model_type: ModelType
-    model_schemas: list[AIModelEntity] = None
+    model_schemas: Optional[list[AIModelEntity]] = None
     started_at: float = 0
 
     # pydantic configs
     model_config = ConfigDict(protected_namespaces=())
 
     @abstractmethod
-    def validate_credentials(self, model: str, credentials: dict) -> None:
+    def validate_credentials(self, model: str, credentials: Mapping) -> None:
         """
         Validate model credentials
 
@@ -90,8 +92,8 @@ class AIModel(ABC):
 
         # get price info from predefined model schema
         price_config: Optional[PriceConfig] = None
-        if model_schema:
-            price_config: PriceConfig = model_schema.pricing
+        if model_schema and model_schema.pricing:
+            price_config = model_schema.pricing
 
         # get unit price
         unit_price = None
@@ -103,13 +105,15 @@ class AIModel(ABC):
 
         if unit_price is None:
             return PriceInfo(
-                unit_price=decimal.Decimal('0.0'),
-                unit=decimal.Decimal('0.0'),
-                total_amount=decimal.Decimal('0.0'),
+                unit_price=decimal.Decimal("0.0"),
+                unit=decimal.Decimal("0.0"),
+                total_amount=decimal.Decimal("0.0"),
                 currency="USD",
             )
 
         # calculate total amount
+        if not price_config:
+            raise ValueError(f"Price config not found for model {model}")
         total_amount = tokens * unit_price * price_config.unit
         total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
 
@@ -209,7 +213,7 @@ class AIModel(ABC):
 
         return model_schemas
 
-    def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
+    def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> Optional[AIModelEntity]:
         """
         Get model schema by model name and credentials
 
@@ -231,7 +235,7 @@ class AIModel(ABC):
 
         return None
 
-    def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
+    def get_customizable_model_schema_from_credentials(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
         """
         Get customizable model schema from credentials
 
@@ -240,8 +244,8 @@ class AIModel(ABC):
         :return: model schema
         """
         return self._get_customizable_model_schema(model, credentials)
-    
-    def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
+
+    def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
         """
         Get customizable model schema and fill in the template
         """
@@ -249,7 +253,7 @@ class AIModel(ABC):
 
         if not schema:
             return None
-        
+
         # fill in the template
         new_parameter_rules = []
         for parameter_rule in schema.parameter_rules:
@@ -271,10 +275,20 @@ class AIModel(ABC):
                         parameter_rule.help = I18nObject(
                             en_US=default_parameter_rule['help']['en_US'],
                         )
-                    if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']):
-                        parameter_rule.help.en_US = default_parameter_rule['help']['en_US']
-                    if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']):
-                        parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US'])
+                    if (
+                        parameter_rule.help
+                        and not parameter_rule.help.en_US
+                        and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"])
+                    ):
+                        parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"]
+                    if (
+                        parameter_rule.help
+                        and not parameter_rule.help.zh_Hans
+                        and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"])
+                    ):
+                        parameter_rule.help.zh_Hans = default_parameter_rule["help"].get(
+                            "zh_Hans", default_parameter_rule["help"]["en_US"]
+                        )
                 except ValueError:
                     pass
 
@@ -284,7 +298,7 @@ class AIModel(ABC):
 
         return schema
 
-    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
+    def get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
         """
         Get customizable model schema
 
@@ -304,7 +318,7 @@ class AIModel(ABC):
         default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)
 
         if not default_parameter_rule:
-            raise Exception(f'Invalid model parameter rule name {name}')
+            raise Exception(f"Invalid model parameter rule name {name}")
 
         return default_parameter_rule
 
@@ -318,4 +332,4 @@ class AIModel(ABC):
         :param text: plain text of prompt. You need to convert the original message to plain text
         :return: number of tokens
         """
-        return GPT2Tokenizer.get_num_tokens(text)
+        return GPT2Tokenizer.get_num_tokens(text)

+ 17 - 14
api/core/model_runtime/model_providers/__base/large_language_model.py

@@ -3,7 +3,7 @@ import os
 import re
 import time
 from abc import abstractmethod
-from collections.abc import Generator
+from collections.abc import Generator, Mapping
 from typing import Optional, Union
 
 from pydantic import ConfigDict
@@ -43,7 +43,7 @@ class LargeLanguageModel(AIModel):
     def invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
-               stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
+               stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
             -> Union[LLMResult, Generator]:
         """
         Invoke large language model
@@ -129,7 +129,7 @@ class LargeLanguageModel(AIModel):
                 user=user,
                 callbacks=callbacks
             )
-        else:
+        elif isinstance(result, LLMResult):
             self._trigger_after_invoke_callbacks(
                 model=model,
                 result=result,
@@ -148,7 +148,7 @@ class LargeLanguageModel(AIModel):
     def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                            model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
                            stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
-                           callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
+                           callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]:
         """
         Code block mode wrapper, ensure the response is a code block with output markdown quote
 
@@ -196,7 +196,7 @@ if you are not sure about the structure.
             # override the system message
             prompt_messages[0] = SystemPromptMessage(
                 content=block_prompts
-                    .replace("{{instructions}}", prompt_messages[0].content)
+                    .replace("{{instructions}}", str(prompt_messages[0].content))
             )
         else:
             # insert the system message
@@ -274,8 +274,9 @@ if you are not sure about the structure.
             else:
                 yield piece
                 continue
-            new_piece = ""
+            new_piece: str = ""
             for char in piece:
+                char = str(char)
                 if state == "normal":
                     if char == "`":
                         state = "in_backticks"
@@ -340,7 +341,7 @@ if you are not sure about the structure.
             if state == "done":
                 continue
 
-            new_piece = ""
+            new_piece: str = ""
             for char in piece:
                 if state == "search_start":
                     if char == "`":
@@ -365,7 +366,7 @@ if you are not sure about the structure.
                             # If backticks were counted but we're still collecting content, it was a false start
                             new_piece += "`" * backtick_count
                             backtick_count = 0
-                        new_piece += char
+                        new_piece += str(char)
 
                 elif state == "done":
                     break
@@ -388,13 +389,14 @@ if you are not sure about the structure.
                                  prompt_messages: list[PromptMessage], model_parameters: dict,
                                  tools: Optional[list[PromptMessageTool]] = None,
                                  stop: Optional[list[str]] = None, stream: bool = True,
-                                 user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator:
+                                 user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Generator:
         """
         Invoke result generator
 
         :param result: result generator
         :return: result generator
         """
+        callbacks = callbacks or []
         prompt_message = AssistantPromptMessage(
             content=""
         )
@@ -489,6 +491,7 @@ if you are not sure about the structure.
 
     def _llm_result_to_stream(self, result: LLMResult) -> Generator:
         """
+from typing_extensions import deprecated
         Transform llm result to stream
 
         :param result: llm result
@@ -531,7 +534,7 @@ if you are not sure about the structure.
 
         return []
 
-    def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode:
+    def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> LLMMode:
         """
         Get model mode
 
@@ -595,7 +598,7 @@ if you are not sure about the structure.
                                          prompt_messages: list[PromptMessage], model_parameters: dict,
                                          tools: Optional[list[PromptMessageTool]] = None,
                                          stop: Optional[list[str]] = None, stream: bool = True,
-                                         user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
+                                         user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
         """
         Trigger before invoke callbacks
 
@@ -633,7 +636,7 @@ if you are not sure about the structure.
                                      prompt_messages: list[PromptMessage], model_parameters: dict,
                                      tools: Optional[list[PromptMessageTool]] = None,
                                      stop: Optional[list[str]] = None, stream: bool = True,
-                                     user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
+                                     user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
         """
         Trigger new chunk callbacks
 
@@ -672,7 +675,7 @@ if you are not sure about the structure.
                                         prompt_messages: list[PromptMessage], model_parameters: dict,
                                         tools: Optional[list[PromptMessageTool]] = None,
                                         stop: Optional[list[str]] = None, stream: bool = True,
-                                        user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
+                                        user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
         """
         Trigger after invoke callbacks
 
@@ -712,7 +715,7 @@ if you are not sure about the structure.
                                         prompt_messages: list[PromptMessage], model_parameters: dict,
                                         tools: Optional[list[PromptMessageTool]] = None,
                                         stop: Optional[list[str]] = None, stream: bool = True,
-                                        user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
+                                        user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
         """
         Trigger invoke error callbacks
 

+ 18 - 11
api/core/model_runtime/model_providers/__base/model_provider.py

@@ -1,5 +1,6 @@
 import os
 from abc import ABC, abstractmethod
+from typing import Optional
 
 from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
@@ -9,7 +10,7 @@ from core.tools.utils.yaml_utils import load_yaml_file
 
 
 class ModelProvider(ABC):
-    provider_schema: ProviderEntity = None
+    provider_schema: Optional[ProviderEntity] = None
     model_instance_map: dict[str, AIModel] = {}
 
     @abstractmethod
@@ -28,23 +29,23 @@ class ModelProvider(ABC):
     def get_provider_schema(self) -> ProviderEntity:
         """
         Get provider schema
-
+    
         :return: provider schema
         """
         if self.provider_schema:
             return self.provider_schema
-
+    
         # get dirname of the current path
         provider_name = self.__class__.__module__.split('.')[-1]
 
         # get the path of the model_provider classes
         base_path = os.path.abspath(__file__)
         current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name)
-
+    
         # read provider schema from yaml file
         yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
         yaml_data = load_yaml_file(yaml_path, ignore_error=True)
-
+    
         try:
             # yaml_data to entity
             provider_schema = ProviderEntity(**yaml_data)
@@ -53,7 +54,7 @@ class ModelProvider(ABC):
 
         # cache schema
         self.provider_schema = provider_schema
-
+    
         return provider_schema
 
     def models(self, model_type: ModelType) -> list[AIModelEntity]:
@@ -84,7 +85,7 @@ class ModelProvider(ABC):
         :return:
         """
         # get dirname of the current path
-        provider_name = self.__class__.__module__.split('.')[-1]
+        provider_name = self.__class__.__module__.split(".")[-1]
 
         if f"{provider_name}.{model_type.value}" in self.model_instance_map:
             return self.model_instance_map[f"{provider_name}.{model_type.value}"]
@@ -101,11 +102,17 @@ class ModelProvider(ABC):
         # Dynamic loading {model_type_name}.py file and find the subclass of AIModel
         parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
         mod = import_module_from_source(
-            f'{parent_module}.{model_type_name}.{model_type_name}', model_type_py_path)
-        model_class = next(filter(lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
-                                  get_subclasses_from_module(mod, AIModel)), None)
+            module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path
+        )
+        model_class = next(
+            filter(
+                lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
+                get_subclasses_from_module(mod, AIModel),
+            ),
+            None,
+        )
         if not model_class:
-            raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}')
+            raise Exception(f"Missing AIModel Class for model type {model_type} in {model_type_py_path}")
 
         model_instance_map = model_class()
         self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance_map

+ 54 - 26
api/core/model_runtime/model_providers/model_provider_factory.py

@@ -1,5 +1,6 @@
 import logging
 import os
+from collections.abc import Sequence
 from typing import Optional
 
 from pydantic import BaseModel, ConfigDict
@@ -16,20 +17,21 @@ logger = logging.getLogger(__name__)
 
 
 class ModelProviderExtension(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
     provider_instance: ModelProvider
     name: str
     position: Optional[int] = None
-    model_config = ConfigDict(arbitrary_types_allowed=True)
 
 
 class ModelProviderFactory:
-    model_provider_extensions: dict[str, ModelProviderExtension] = None
+    model_provider_extensions: Optional[dict[str, ModelProviderExtension]] = None
 
     def __init__(self) -> None:
         # for cache in memory
         self.get_providers()
 
-    def get_providers(self) -> list[ProviderEntity]:
+    def get_providers(self) -> Sequence[ProviderEntity]:
         """
         Get all providers
         :return: list of providers
@@ -39,7 +41,7 @@ class ModelProviderFactory:
 
         # traverse all model_provider_extensions
         providers = []
-        for name, model_provider_extension in model_provider_extensions.items():
+        for model_provider_extension in model_provider_extensions.values():
             # get model_provider instance
             model_provider_instance = model_provider_extension.provider_instance
 
@@ -57,7 +59,7 @@ class ModelProviderFactory:
         # return providers
         return providers
 
-    def provider_credentials_validate(self, provider: str, credentials: dict) -> dict:
+    def provider_credentials_validate(self, *, provider: str, credentials: dict) -> dict:
         """
         Validate provider credentials
 
@@ -74,6 +76,9 @@ class ModelProviderFactory:
         # get provider_credential_schema and validate credentials according to the rules
         provider_credential_schema = provider_schema.provider_credential_schema
 
+        if not provider_credential_schema:
+            raise ValueError(f"Provider {provider} does not have provider_credential_schema")
+
         # validate provider credential schema
         validator = ProviderCredentialSchemaValidator(provider_credential_schema)
         filtered_credentials = validator.validate_and_filter(credentials)
@@ -83,8 +88,9 @@ class ModelProviderFactory:
 
         return filtered_credentials
 
-    def model_credentials_validate(self, provider: str, model_type: ModelType,
-                                   model: str, credentials: dict) -> dict:
+    def model_credentials_validate(
+        self, *, provider: str, model_type: ModelType, model: str, credentials: dict
+    ) -> dict:
         """
         Validate model credentials
 
@@ -103,6 +109,9 @@ class ModelProviderFactory:
         # get model_credential_schema and validate credentials according to the rules
         model_credential_schema = provider_schema.model_credential_schema
 
+        if not model_credential_schema:
+            raise ValueError(f"Provider {provider} does not have model_credential_schema")
+
         # validate model credential schema
         validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
         filtered_credentials = validator.validate_and_filter(credentials)
@@ -115,11 +124,13 @@ class ModelProviderFactory:
 
         return filtered_credentials
 
-    def get_models(self,
-                   provider: Optional[str] = None,
-                   model_type: Optional[ModelType] = None,
-                   provider_configs: Optional[list[ProviderConfig]] = None) \
-            -> list[SimpleProviderEntity]:
+    def get_models(
+        self,
+        *,
+        provider: Optional[str] = None,
+        model_type: Optional[ModelType] = None,
+        provider_configs: Optional[list[ProviderConfig]] = None,
+    ) -> list[SimpleProviderEntity]:
         """
         Get all models for given model type
 
@@ -128,6 +139,8 @@ class ModelProviderFactory:
         :param provider_configs: list of provider configs
         :return: list of models
         """
+        provider_configs = provider_configs or []
+
         # scan all providers
         model_provider_extensions = self._get_model_provider_map()
 
@@ -184,7 +197,7 @@ class ModelProviderFactory:
         # get the provider extension
         model_provider_extension = model_provider_extensions.get(provider)
         if not model_provider_extension:
-            raise Exception(f'Invalid provider: {provider}')
+            raise Exception(f"Invalid provider: {provider}")
 
         # get the provider instance
         model_provider_instance = model_provider_extension.provider_instance
@@ -192,10 +205,22 @@ class ModelProviderFactory:
         return model_provider_instance
 
     def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]:
+        """
+        Retrieves the model provider map.
+
+        This method retrieves the model provider map, which is a dictionary containing the model provider names as keys
+        and instances of `ModelProviderExtension` as values. The model provider map is used to store information about
+        available model providers.
+
+        Returns:
+            A dictionary containing the model provider map.
+
+        Raises:
+            None.
+        """
         if self.model_provider_extensions:
             return self.model_provider_extensions
 
-
         # get the path of current classes
         current_path = os.path.abspath(__file__)
         model_providers_path = os.path.dirname(current_path)
@@ -204,8 +229,8 @@ class ModelProviderFactory:
         model_provider_dir_paths = [
             os.path.join(model_providers_path, model_provider_dir)
             for model_provider_dir in os.listdir(model_providers_path)
-            if not model_provider_dir.startswith('__')
-               and os.path.isdir(os.path.join(model_providers_path, model_provider_dir))
+            if not model_provider_dir.startswith("__")
+            and os.path.isdir(os.path.join(model_providers_path, model_provider_dir))
         ]
 
         # get _position.yaml file path
@@ -219,30 +244,33 @@ class ModelProviderFactory:
 
             file_names = os.listdir(model_provider_dir_path)
 
-            if (model_provider_name + '.py') not in file_names:
+            if (model_provider_name + ".py") not in file_names:
                 logger.warning(f"Missing {model_provider_name}.py file in {model_provider_dir_path}, Skip.")
                 continue
 
             # Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider
-            py_path = os.path.join(model_provider_dir_path, model_provider_name + '.py')
+            py_path = os.path.join(model_provider_dir_path, model_provider_name + ".py")
             model_provider_class = load_single_subclass_from_source(
-                module_name=f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}',
+                module_name=f"core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}",
                 script_path=py_path,
-                parent_type=ModelProvider)
+                parent_type=ModelProvider,
+            )
 
             if not model_provider_class:
                 logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.")
                 continue
 
-            if f'{model_provider_name}.yaml' not in file_names:
+            if f"{model_provider_name}.yaml" not in file_names:
                 logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.")
                 continue
 
-            model_providers.append(ModelProviderExtension(
-                name=model_provider_name,
-                provider_instance=model_provider_class(),
-                position=position_map.get(model_provider_name)
-            ))
+            model_providers.append(
+                ModelProviderExtension(
+                    name=model_provider_name,
+                    provider_instance=model_provider_class(),
+                    position=position_map.get(model_provider_name),
+                )
+            )
 
         sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name)
 

+ 12 - 20
api/core/model_runtime/model_providers/openai/_common.py

@@ -1,3 +1,5 @@
+from collections.abc import Mapping
+
 import openai
 from httpx import Timeout
 
@@ -12,7 +14,7 @@ from core.model_runtime.errors.invoke import (
 
 
 class _CommonOpenAI:
-    def _to_credential_kwargs(self, credentials: dict) -> dict:
+    def _to_credential_kwargs(self, credentials: Mapping) -> dict:
         """
         Transform credentials to kwargs for model instance
 
@@ -25,9 +27,9 @@ class _CommonOpenAI:
             "max_retries": 1,
         }
 
-        if credentials.get('openai_api_base'):
-            credentials['openai_api_base'] = credentials['openai_api_base'].rstrip('/')
-            credentials_kwargs['base_url'] = credentials['openai_api_base'] + '/v1'
+        if credentials.get("openai_api_base"):
+            openai_api_base = credentials["openai_api_base"].rstrip("/")
+            credentials_kwargs["base_url"] = openai_api_base + "/v1"
 
         if 'openai_organization' in credentials:
             credentials_kwargs['organization'] = credentials['openai_organization']
@@ -45,24 +47,14 @@ class _CommonOpenAI:
         :return: Invoke error mapping
         """
         return {
-            InvokeConnectionError: [
-                openai.APIConnectionError,
-                openai.APITimeoutError
-            ],
-            InvokeServerUnavailableError: [
-                openai.InternalServerError
-            ],
-            InvokeRateLimitError: [
-                openai.RateLimitError
-            ],
-            InvokeAuthorizationError: [
-                openai.AuthenticationError,
-                openai.PermissionDeniedError
-            ],
+            InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
+            InvokeServerUnavailableError: [openai.InternalServerError],
+            InvokeRateLimitError: [openai.RateLimitError],
+            InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError],
             InvokeBadRequestError: [
                 openai.BadRequestError,
                 openai.NotFoundError,
                 openai.UnprocessableEntityError,
-                openai.APIError
-            ]
+                openai.APIError,
+            ],
         }

+ 2 - 1
api/core/model_runtime/model_providers/openai/openai.py

@@ -1,4 +1,5 @@
 import logging
+from collections.abc import Mapping
 
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -9,7 +10,7 @@ logger = logging.getLogger(__name__)
 
 class OpenAIProvider(ModelProvider):
 
-    def validate_provider_credentials(self, credentials: dict) -> None:
+    def validate_provider_credentials(self, credentials: Mapping) -> None:
         """
         Validate provider credentials
         if validate failed, raise exception