Explorar el Código

chore: Extract common functions of the base model in Azure OpenAI Provider (#9907)

ice yao hace 5 meses
padre
commit
22776f24ab

+ 3 - 0
api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml

@@ -53,6 +53,9 @@ model_credential_schema:
       type: select
       required: true
       options:
+        - label:
+            en_US: 2024-10-01-preview
+          value: 2024-10-01-preview
         - label:
             en_US: 2024-09-01-preview
           value: 2024-09-01-preview

+ 10 - 17
api/core/model_runtime/model_providers/azure_openai/llm/llm.py

@@ -45,9 +45,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         stream: bool = True,
         user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
-        base_model_name = credentials.get("base_model_name")
-        if not base_model_name:
-            raise ValueError("Base Model Name is required")
+        base_model_name = self._get_base_model_name(credentials)
         ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
 
         if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
@@ -81,9 +79,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         prompt_messages: list[PromptMessage],
         tools: Optional[list[PromptMessageTool]] = None,
     ) -> int:
-        base_model_name = credentials.get("base_model_name")
-        if not base_model_name:
-            raise ValueError("Base Model Name is required")
+        base_model_name = self._get_base_model_name(credentials)
         model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
         if not model_entity:
             raise ValueError(f"Base Model Name {base_model_name} is invalid")
@@ -108,9 +104,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         if "base_model_name" not in credentials:
             raise CredentialsValidateFailedError("Base Model Name is required")
 
-        base_model_name = credentials.get("base_model_name")
-        if not base_model_name:
-            raise CredentialsValidateFailedError("Base Model Name is required")
+        base_model_name = self._get_base_model_name(credentials)
         ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
 
         if not ai_model_entity:
@@ -149,9 +143,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             raise CredentialsValidateFailedError(str(ex))
 
     def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
-        base_model_name = credentials.get("base_model_name")
-        if not base_model_name:
-            raise ValueError("Base Model Name is required")
+        base_model_name = self._get_base_model_name(credentials)
         ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
         return ai_model_entity.entity if ai_model_entity else None
 
@@ -308,11 +300,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
 
         if tools:
             extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
-            # extra_model_kwargs['functions'] = [{
-            #     "name": tool.name,
-            #     "description": tool.description,
-            #     "parameters": tool.parameters
-            # } for tool in tools]
 
         if stop:
             extra_model_kwargs["stop"] = stop
@@ -769,3 +756,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 ai_model_entity_copy.entity.label.en_US = model
                 ai_model_entity_copy.entity.label.zh_Hans = model
                 return ai_model_entity_copy
+
+    def _get_base_model_name(self, credentials: dict) -> str:
+        base_model_name = credentials.get("base_model_name")
+        if not base_model_name:
+            raise ValueError("Base Model Name is required")
+        return base_model_name