Przeglądaj źródła

Azure openai init (#1929)

Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Charlie.Wei 1 rok temu
rodzic
commit
5b24d7129e

+ 77 - 9
api/core/entities/provider_configuration.py

@@ -1,7 +1,7 @@
 import datetime
 import json
 import logging
-import time
+
 from json import JSONDecodeError
 from typing import Optional, List, Dict, Tuple, Iterator
 
@@ -11,8 +11,9 @@ from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, S
 from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
 from core.helper import encrypter
 from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
-from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
+from core.model_runtime.entities.model_entities import ModelType, FetchFrom
+from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \
+    ConfigurateMethod
 from core.model_runtime.model_providers import model_provider_factory
 from core.model_runtime.model_providers.__base.ai_model import AIModel
 from core.model_runtime.model_providers.__base.model_provider import ModelProvider
@@ -22,6 +23,8 @@ from models.provider import ProviderType, Provider, ProviderModel, TenantPreferr
 
 logger = logging.getLogger(__name__)
 
+original_provider_configurate_methods = {}
+
 
 class ProviderConfiguration(BaseModel):
     """
@@ -34,6 +37,20 @@ class ProviderConfiguration(BaseModel):
     system_configuration: SystemConfiguration
     custom_configuration: CustomConfiguration
 
+    def __init__(self, **data):
+        super().__init__(**data)
+
+        if self.provider.provider not in original_provider_configurate_methods:
+            original_provider_configurate_methods[self.provider.provider] = []
+            for configurate_method in self.provider.configurate_methods:
+                original_provider_configurate_methods[self.provider.provider].append(configurate_method)
+
+        if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
+            if (any([len(quota_configuration.restrict_models) > 0
+                     for quota_configuration in self.system_configuration.quota_configurations])
+                    and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods):
+                self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
+
     def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
         """
         Get current credentials.
@@ -123,7 +140,8 @@ class ProviderConfiguration(BaseModel):
 
         if provider_record:
             try:
-                original_credentials = json.loads(provider_record.encrypted_config) if provider_record.encrypted_config else {}
+                original_credentials = json.loads(
+                    provider_record.encrypted_config) if provider_record.encrypted_config else {}
             except JSONDecodeError:
                 original_credentials = {}
 
@@ -265,7 +283,8 @@ class ProviderConfiguration(BaseModel):
 
         if provider_model_record:
             try:
-                original_credentials = json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
+                original_credentials = json.loads(
+                    provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
             except JSONDecodeError:
                 original_credentials = {}
 
@@ -534,21 +553,70 @@ class ProviderConfiguration(BaseModel):
                 ]
             )
 
+        if self.provider.provider not in original_provider_configurate_methods:
+            original_provider_configurate_methods[self.provider.provider] = []
+            for configurate_method in provider_instance.get_provider_schema().configurate_methods:
+                original_provider_configurate_methods[self.provider.provider].append(configurate_method)
+
+        should_use_custom_model = False
+        if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
+            should_use_custom_model = True
+
         for quota_configuration in self.system_configuration.quota_configurations:
             if self.system_configuration.current_quota_type != quota_configuration.quota_type:
                 continue
 
-            restrict_llms = quota_configuration.restrict_llms
-            if not restrict_llms:
+            restrict_models = quota_configuration.restrict_models
+            if len(restrict_models) == 0:
                 break
 
+            if should_use_custom_model:
+                if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
+                    # only customizable model
+                    for restrict_model in restrict_models:
+                        copy_credentials = self.system_configuration.credentials.copy()
+                        if restrict_model.base_model_name:
+                            copy_credentials['base_model_name'] = restrict_model.base_model_name
+
+                        try:
+                            custom_model_schema = (
+                                provider_instance.get_model_instance(restrict_model.model_type)
+                                .get_customizable_model_schema_from_credentials(
+                                    restrict_model.model,
+                                    copy_credentials
+                                )
+                            )
+                        except Exception as ex:
+                            logger.warning(f'get custom model schema failed, {ex}')
+                            continue
+
+                        if not custom_model_schema:
+                            continue
+
+                        if custom_model_schema.model_type not in model_types:
+                            continue
+
+                        provider_models.append(
+                            ModelWithProviderEntity(
+                                model=custom_model_schema.model,
+                                label=custom_model_schema.label,
+                                model_type=custom_model_schema.model_type,
+                                features=custom_model_schema.features,
+                                fetch_from=FetchFrom.PREDEFINED_MODEL,
+                                model_properties=custom_model_schema.model_properties,
+                                deprecated=custom_model_schema.deprecated,
+                                provider=SimpleModelProviderEntity(self.provider),
+                                status=ModelStatus.ACTIVE
+                            )
+                        )
+
             # if llm name not in restricted llm list, remove it
+            restrict_model_names = [rm.model for rm in restrict_models]
             for m in provider_models:
-                if m.model_type == ModelType.LLM and m.model not in restrict_llms:
+                if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
                     m.status = ModelStatus.NO_PERMISSION
                 elif not quota_configuration.is_valid:
                     m.status = ModelStatus.QUOTA_EXCEEDED
-
         return provider_models
 
     def _get_custom_provider_models(self,

+ 7 - 1
api/core/entities/provider_entities.py

@@ -21,6 +21,12 @@ class SystemConfigurationStatus(Enum):
     UNSUPPORTED = 'unsupported'
 
 
+class RestrictModel(BaseModel):
+    model: str
+    base_model_name: Optional[str] = None
+    model_type: ModelType
+
+
 class QuotaConfiguration(BaseModel):
     """
     Model class for provider quota configuration.
@@ -30,7 +36,7 @@ class QuotaConfiguration(BaseModel):
     quota_limit: int
     quota_used: int
     is_valid: bool
-    restrict_llms: list[str] = []
+    restrict_models: list[RestrictModel] = []
 
 
 class SystemConfiguration(BaseModel):

+ 52 - 11
api/core/hosting_configuration.py

@@ -4,13 +4,14 @@ from typing import Optional
 from flask import Flask
 from pydantic import BaseModel
 
-from core.entities.provider_entities import QuotaUnit
+from core.entities.provider_entities import QuotaUnit, RestrictModel
+from core.model_runtime.entities.model_entities import ModelType
 from models.provider import ProviderQuotaType
 
 
 class HostingQuota(BaseModel):
     quota_type: ProviderQuotaType
-    restrict_llms: list[str] = []
+    restrict_models: list[RestrictModel] = []
 
 
 class TrialHostingQuota(HostingQuota):
@@ -47,10 +48,9 @@ class HostingConfiguration:
     provider_map: dict[str, HostingProvider] = {}
     moderation_config: HostedModerationConfig = None
 
-    def init_app(self, app: Flask):
-        if app.config.get('EDITION') != 'CLOUD':
-            return
+    def init_app(self, app: Flask) -> None:
 
+        self.provider_map["azure_openai"] = self.init_azure_openai()
         self.provider_map["openai"] = self.init_openai()
         self.provider_map["anthropic"] = self.init_anthropic()
         self.provider_map["minimax"] = self.init_minimax()
@@ -59,6 +59,47 @@ class HostingConfiguration:
 
         self.moderation_config = self.init_moderation_config()
 
+    def init_azure_openai(self) -> HostingProvider:
+        quota_unit = QuotaUnit.TIMES
+        if os.environ.get("HOSTED_AZURE_OPENAI_ENABLED") and os.environ.get("HOSTED_AZURE_OPENAI_ENABLED").lower() == 'true':
+            credentials = {
+                "openai_api_key": os.environ.get("HOSTED_AZURE_OPENAI_API_KEY"),
+                "openai_api_base": os.environ.get("HOSTED_AZURE_OPENAI_API_BASE"),
+                "base_model_name": "gpt-35-turbo"
+            }
+
+            quotas = []
+            hosted_quota_limit = int(os.environ.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000"))
+            if hosted_quota_limit != -1 or hosted_quota_limit > 0:
+                trial_quota = TrialHostingQuota(
+                    quota_limit=hosted_quota_limit,
+                    restrict_models=[
+                        RestrictModel(model="gpt-4", base_model_name="gpt-4", model_type=ModelType.LLM),
+                        RestrictModel(model="gpt-4-32k", base_model_name="gpt-4-32k", model_type=ModelType.LLM),
+                        RestrictModel(model="gpt-4-1106-preview", base_model_name="gpt-4-1106-preview", model_type=ModelType.LLM),
+                        RestrictModel(model="gpt-4-vision-preview", base_model_name="gpt-4-vision-preview", model_type=ModelType.LLM),
+                        RestrictModel(model="gpt-35-turbo", base_model_name="gpt-35-turbo", model_type=ModelType.LLM),
+                        RestrictModel(model="gpt-35-turbo-1106", base_model_name="gpt-35-turbo-1106", model_type=ModelType.LLM),
+                        RestrictModel(model="gpt-35-turbo-instruct", base_model_name="gpt-35-turbo-instruct", model_type=ModelType.LLM),
+                        RestrictModel(model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM),
+                        RestrictModel(model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM),
+                        RestrictModel(model="text-embedding-ada-002", base_model_name="text-embedding-ada-002", model_type=ModelType.TEXT_EMBEDDING),
+                    ]
+                )
+                quotas.append(trial_quota)
+
+            return HostingProvider(
+                enabled=True,
+                credentials=credentials,
+                quota_unit=quota_unit,
+                quotas=quotas
+            )
+
+        return HostingProvider(
+            enabled=False,
+            quota_unit=quota_unit,
+        )
+
     def init_openai(self) -> HostingProvider:
         quota_unit = QuotaUnit.TIMES
         if os.environ.get("HOSTED_OPENAI_ENABLED") and os.environ.get("HOSTED_OPENAI_ENABLED").lower() == 'true':
@@ -77,12 +118,12 @@ class HostingConfiguration:
             if hosted_quota_limit != -1 or hosted_quota_limit > 0:
                 trial_quota = TrialHostingQuota(
                     quota_limit=hosted_quota_limit,
-                    restrict_llms=[
-                        "gpt-3.5-turbo",
-                        "gpt-3.5-turbo-1106",
-                        "gpt-3.5-turbo-instruct",
-                        "gpt-3.5-turbo-16k",
-                        "text-davinci-003"
+                    restrict_models=[
+                        RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
+                        RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
+                        RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
+                        RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
+                        RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
                     ]
                 )
                 quotas.append(trial_quota)

+ 3 - 2
api/core/model_manager.py

@@ -144,7 +144,7 @@ class ModelInstance:
             user=user
         )
 
-    def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
+    def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None, **params) \
             -> str:
         """
         Invoke large language model
@@ -161,7 +161,8 @@ class ModelInstance:
             model=self.model,
             credentials=self.credentials,
             file=file,
-            user=user
+            user=user,
+            **params
         )
 
 

+ 1 - 1
api/core/model_runtime/entities/model_entities.py

@@ -32,7 +32,7 @@ class ModelType(Enum):
             return cls.TEXT_EMBEDDING
         elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value:
             return cls.RERANK
-        elif origin_model_type == cls.SPEECH2TEXT.value:
+        elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value:
             return cls.SPEECH2TEXT
         elif origin_model_type == cls.MODERATION.value:
             return cls.MODERATION

+ 3 - 3
api/core/model_runtime/model_providers/azure_openai/_constant.py

@@ -2,7 +2,7 @@ from pydantic import BaseModel
 
 from core.model_runtime.entities.llm_entities import LLMMode
 from core.model_runtime.entities.model_entities import ModelFeature, ModelType, FetchFrom, ParameterRule, \
-    DefaultParameterName, PriceConfig
+    DefaultParameterName, PriceConfig, ModelPropertyKey
 from core.model_runtime.entities.model_entities import AIModelEntity, I18nObject
 from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
 
@@ -502,8 +502,8 @@ EMBEDDING_BASE_MODELS = [
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.TEXT_EMBEDDING,
             model_properties={
-                'context_size': 8097,
-                'max_chunks': 32,
+                ModelPropertyKey.CONTEXT_SIZE: 8097,
+                ModelPropertyKey.MAX_CHUNKS: 32,
             },
             pricing=PriceConfig(
                 input=0.0001,

+ 5 - 5
api/core/model_runtime/model_providers/azure_openai/llm/llm.py

@@ -30,7 +30,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:
 
-        ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
+        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
 
         if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
             # chat model
@@ -59,7 +59,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
 
-        model_mode = self._get_ai_model_entity(credentials['base_model_name'], model).entity.model_properties.get(
+        model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get(
             ModelPropertyKey.MODE)
 
         if model_mode == LLMMode.CHAT.value:
@@ -79,7 +79,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         if 'base_model_name' not in credentials:
             raise CredentialsValidateFailedError('Base Model Name is required')
 
-        ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
+        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
 
         if not ai_model_entity:
             raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
@@ -109,8 +109,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             raise CredentialsValidateFailedError(str(ex))
 
     def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
-        ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
-        return ai_model_entity.entity
+        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        return ai_model_entity.entity if ai_model_entity else None
 
     def _generate(self, model: str, credentials: dict,
                   prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,

+ 3 - 2
api/core/provider_manager.py

@@ -12,7 +12,8 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC
 from core.helper import encrypter
 from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
 from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
+from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \
+    ConfigurateMethod
 from core.model_runtime.model_providers import model_provider_factory
 from extensions import ext_hosting_provider
 from extensions.ext_database import db
@@ -607,7 +608,7 @@ class ProviderManager:
                 quota_used=provider_record.quota_used,
                 quota_limit=provider_record.quota_limit,
                 is_valid=provider_record.quota_limit > provider_record.quota_used or provider_record.quota_limit == -1,
-                restrict_llms=provider_quota.restrict_llms
+                restrict_models=provider_quota.restrict_models
             )
 
             quota_configurations.append(quota_configuration)