
Feat: support azure openai for temporary (#101)

John Wang · 1 year ago · commit f68b05d5ec

+ 5 - 0
api/config.py

@@ -47,6 +47,7 @@ DEFAULTS = {
     'PDF_PREVIEW': 'True',
     'LOG_LEVEL': 'INFO',
     'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
+    'DEFAULT_LLM_PROVIDER': 'openai'
 }
 
 
@@ -181,6 +182,10 @@ class Config:
         # You could disable it for compatibility with certain OpenAI-compatible providers
         self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')
 
+        # For temporary use only
+        # Set the default LLM provider; defaults to 'openai', `azure_openai` is also supported
+        self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
+
 class CloudEditionConfig(Config):
 
     def __init__(self):
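
Note: a minimal standalone sketch of the fallback behaviour added above. The get_env helper and DEFAULTS table mirror api/config.py; this trimmed copy is illustrative, not the project code.

import os

DEFAULTS = {'DEFAULT_LLM_PROVIDER': 'openai'}

def get_env(key: str) -> str:
    # Environment variable wins; otherwise fall back to the DEFAULTS table.
    return os.environ.get(key, DEFAULTS.get(key))

assert get_env('DEFAULT_LLM_PROVIDER') == 'openai'        # built-in default
os.environ['DEFAULT_LLM_PROVIDER'] = 'azure_openai'       # e.g. set via .env
assert get_env('DEFAULT_LLM_PROVIDER') == 'azure_openai'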

+ 29 - 15
api/controllers/console/workspace/providers.py

@@ -82,29 +82,33 @@ class ProviderTokenApi(Resource):
 
         args = parser.parse_args()
 
-        if not args['token']:
-            raise ValueError('Token is empty')
-
-        try:
-            ProviderService.validate_provider_configs(
+        if args['token']:
+            try:
+                ProviderService.validate_provider_configs(
+                    tenant=current_user.current_tenant,
+                    provider_name=ProviderName(provider),
+                    configs=args['token']
+                )
+                token_is_valid = True
+            except ValidateFailedError:
+                token_is_valid = False
+
+            base64_encrypted_token = ProviderService.get_encrypted_token(
                 tenant=current_user.current_tenant,
                 provider_name=ProviderName(provider),
                 configs=args['token']
             )
-            token_is_valid = True
-        except ValidateFailedError:
+        else:
+            base64_encrypted_token = None
             token_is_valid = False
 
         tenant = current_user.current_tenant
 
-        base64_encrypted_token = ProviderService.get_encrypted_token(
-            tenant=current_user.current_tenant,
-            provider_name=ProviderName(provider),
-            configs=args['token']
-        )
-
-        provider_model = Provider.query.filter_by(tenant_id=tenant.id, provider_name=provider,
-                                                  provider_type=ProviderType.CUSTOM.value).first()
+        provider_model = db.session.query(Provider).filter(
+                Provider.tenant_id == tenant.id,
+                Provider.provider_name == provider,
+                Provider.provider_type == ProviderType.CUSTOM.value
+            ).first()
 
         # Only allow updating token for CUSTOM provider type
         if provider_model:
@@ -117,6 +121,16 @@ class ProviderTokenApi(Resource):
                                       is_valid=token_is_valid)
             db.session.add(provider_model)
 
+        if provider_model.is_valid:
+            other_providers = db.session.query(Provider).filter(
+                Provider.tenant_id == tenant.id,
+                Provider.provider_name != provider,
+                Provider.provider_type == ProviderType.CUSTOM.value
+            ).all()
+
+            for other_provider in other_providers:
+                other_provider.is_valid = False
+
         db.session.commit()
 
         if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
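
Note: with this change an empty token no longer raises; it stores an empty credential and marks the provider invalid, while saving a provider as valid flips every other CUSTOM provider of the tenant to invalid, so at most one custom provider is active at a time. A sketch of that exclusivity rule over plain dicts (illustrative stand-ins for the SQLAlchemy models):

def mark_only_valid(providers: list, tenant_id: str, provider_name: str) -> None:
    # Mirrors the invalidation loop above: the saved provider stays valid,
    # every other CUSTOM provider of the tenant is switched off.
    for p in providers:
        if p['tenant_id'] == tenant_id and p['provider_type'] == 'custom':
            p['is_valid'] = (p['provider_name'] == provider_name)

providers = [
    {'tenant_id': 't1', 'provider_name': 'openai', 'provider_type': 'custom', 'is_valid': True},
    {'tenant_id': 't1', 'provider_name': 'azure_openai', 'provider_type': 'custom', 'is_valid': False},
]
mark_only_valid(providers, 't1', 'azure_openai')
assert [p['is_valid'] for p in providers] == [False, True]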

+ 48 - 24
api/core/embedding/openai_embedding.py

@@ -11,9 +11,10 @@ from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_except
 
 @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 def get_embedding(
-    text: str,
-    engine: Optional[str] = None,
-    openai_api_key: Optional[str] = None,
+        text: str,
+        engine: Optional[str] = None,
+        api_key: Optional[str] = None,
+        **kwargs
 ) -> List[float]:
     """Get embedding.
 
@@ -25,11 +26,12 @@ def get_embedding(
 
     """
     text = text.replace("\n", " ")
-    return openai.Embedding.create(input=[text], engine=engine, api_key=openai_api_key)["data"][0]["embedding"]
+    return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]
 
 
 @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
-async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key: Optional[str] = None) -> List[float]:
+async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None,
+                         **kwargs) -> List[float]:
     """Asynchronously get embedding.
 
     NOTE: Copied from OpenAI's embedding utils:
@@ -42,16 +44,17 @@ async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key
     # replace newlines, which can negatively affect performance.
     text = text.replace("\n", " ")
 
-    return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=openai_api_key))["data"][0][
+    return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][
         "embedding"
     ]
 
 
 @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 def get_embeddings(
-    list_of_text: List[str],
-    engine: Optional[str] = None,
-    openai_api_key: Optional[str] = None
+        list_of_text: List[str],
+        engine: Optional[str] = None,
+        api_key: Optional[str] = None,
+        **kwargs
 ) -> List[List[float]]:
     """Get embeddings.
 
@@ -67,14 +70,14 @@ def get_embeddings(
     # replace newlines, which can negatively affect performance.
     list_of_text = [text.replace("\n", " ") for text in list_of_text]
 
-    data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=openai_api_key).data
+    data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
     data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
     return [d["embedding"] for d in data]
 
 
 @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 async def aget_embeddings(
-    list_of_text: List[str], engine: Optional[str] = None, openai_api_key: Optional[str] = None
+        list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs
 ) -> List[List[float]]:
     """Asynchronously get embeddings.
 
@@ -90,7 +93,7 @@ async def aget_embeddings(
     # replace newlines, which can negatively affect performance.
     list_of_text = [text.replace("\n", " ") for text in list_of_text]
 
-    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=openai_api_key)).data
+    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
     data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
     return [d["embedding"] for d in data]
 
@@ -98,19 +101,30 @@ async def aget_embeddings(
 class OpenAIEmbedding(BaseEmbedding):
 
     def __init__(
-        self,
-        mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
-        model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
-        deployment_name: Optional[str] = None,
-        openai_api_key: Optional[str] = None,
-        **kwargs: Any,
+            self,
+            mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
+            model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
+            deployment_name: Optional[str] = None,
+            openai_api_key: Optional[str] = None,
+            **kwargs: Any,
     ) -> None:
         """Init params."""
-        super().__init__(**kwargs)
+        new_kwargs = {}
+
+        if 'embed_batch_size' in kwargs:
+            new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']
+
+        if 'tokenizer' in kwargs:
+            new_kwargs['tokenizer'] = kwargs['tokenizer']
+
+        super().__init__(**new_kwargs)
         self.mode = OpenAIEmbeddingMode(mode)
         self.model = OpenAIEmbeddingModelType(model)
         self.deployment_name = deployment_name
         self.openai_api_key = openai_api_key
+        self.openai_api_type = kwargs.get('openai_api_type')
+        self.openai_api_version = kwargs.get('openai_api_version')
+        self.openai_api_base = kwargs.get('openai_api_base')
 
     @handle_llm_exceptions
     def _get_query_embedding(self, query: str) -> List[float]:
@@ -122,7 +136,9 @@ class OpenAIEmbedding(BaseEmbedding):
             if key not in _QUERY_MODE_MODEL_DICT:
                 raise ValueError(f"Invalid mode, model combination: {key}")
             engine = _QUERY_MODE_MODEL_DICT[key]
-        return get_embedding(query, engine=engine, openai_api_key=self.openai_api_key)
+        return get_embedding(query, engine=engine, api_key=self.openai_api_key,
+                             api_type=self.openai_api_type, api_version=self.openai_api_version,
+                             api_base=self.openai_api_base)
 
     def _get_text_embedding(self, text: str) -> List[float]:
         """Get text embedding."""
@@ -133,7 +149,9 @@ class OpenAIEmbedding(BaseEmbedding):
             if key not in _TEXT_MODE_MODEL_DICT:
                 raise ValueError(f"Invalid mode, model combination: {key}")
             engine = _TEXT_MODE_MODEL_DICT[key]
-        return get_embedding(text, engine=engine, openai_api_key=self.openai_api_key)
+        return get_embedding(text, engine=engine, api_key=self.openai_api_key,
+                             api_type=self.openai_api_type, api_version=self.openai_api_version,
+                             api_base=self.openai_api_base)
 
     async def _aget_text_embedding(self, text: str) -> List[float]:
         """Asynchronously get text embedding."""
@@ -144,7 +162,9 @@ class OpenAIEmbedding(BaseEmbedding):
             if key not in _TEXT_MODE_MODEL_DICT:
                 raise ValueError(f"Invalid mode, model combination: {key}")
             engine = _TEXT_MODE_MODEL_DICT[key]
-        return await aget_embedding(text, engine=engine, openai_api_key=self.openai_api_key)
+        return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
+                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
+                                    api_base=self.openai_api_base)
 
     def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
         """Get text embeddings.
@@ -160,7 +180,9 @@ class OpenAIEmbedding(BaseEmbedding):
             if key not in _TEXT_MODE_MODEL_DICT:
                 raise ValueError(f"Invalid mode, model combination: {key}")
             engine = _TEXT_MODE_MODEL_DICT[key]
-        embeddings = get_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key)
+        embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key,
+                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
+                                    api_base=self.openai_api_base)
         return embeddings
 
     async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
@@ -172,5 +194,7 @@ class OpenAIEmbedding(BaseEmbedding):
             if key not in _TEXT_MODE_MODEL_DICT:
                 raise ValueError(f"Invalid mode, model combination: {key}")
             engine = _TEXT_MODE_MODEL_DICT[key]
-        embeddings = await aget_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key)
+        embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key,
+                                           api_type=self.openai_api_type, api_version=self.openai_api_version,
+                                           api_base=self.openai_api_base)
         return embeddings
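
Note: the helpers now take a provider-neutral api_key and forward **kwargs straight into openai.Embedding.create, which is what lets one code path serve both OpenAI and Azure. A usage sketch against the pre-1.0 openai SDK this module targets; key and endpoint values are placeholders:

from core.embedding.openai_embedding import get_embedding

# Plain OpenAI: no extra routing kwargs needed.
vec = get_embedding('hello world', engine='text-embedding-ada-002',
                    api_key='sk-...')

# Azure OpenAI: the same helper, passing the per-request routing kwargs
# that OpenAIEmbedding forwards from its constructor.
vec = get_embedding(
    'hello world',
    engine='text-embedding-ada-002',  # the Azure deployment name
    api_key='<azure-key>',
    api_type='azure',
    api_base='https://<your-domain-prefix>.openai.azure.com/',
    api_version='2023-03-15-preview',
)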

+ 3 - 0
api/core/index/index_builder.py

@@ -33,8 +33,11 @@ class IndexBuilder:
             max_chunk_overlap=20
         )
 
+        provider = LLMBuilder.get_default_provider(tenant_id)
+
         model_credentials = LLMBuilder.get_model_credentials(
             tenant_id=tenant_id,
+            model_provider=provider,
             model_name='text-embedding-ada-002'
         )
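
Note: credential lookup is now a two-step affair: resolve the tenant's active provider, then fetch credentials for that provider/model pair. A sketch of the caller pattern (the tenant id is illustrative); get_default_provider raises ProviderTokenNotInitError when the tenant has no valid provider, so the failure surfaces before any index is built:

from core.llm.llm_builder import LLMBuilder

provider = LLMBuilder.get_default_provider('tenant-123')   # 'openai' or 'azure_openai'
credentials = LLMBuilder.get_model_credentials(
    tenant_id='tenant-123',
    model_provider=provider,
    model_name='text-embedding-ada-002',
)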
 

+ 34 - 9
api/core/llm/llm_builder.py

@@ -4,9 +4,14 @@ from langchain.callbacks import CallbackManager
 from langchain.llms.fake import FakeListLLM
 
 from core.constant import llm_constant
+from core.llm.error import ProviderTokenNotInitError
+from core.llm.provider.base import BaseProvider
 from core.llm.provider.llm_provider_service import LLMProviderService
+from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
+from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
 from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
 from core.llm.streamable_open_ai import StreamableOpenAI
+from models.provider import ProviderType
 
 
 class LLMBuilder:
@@ -31,16 +36,23 @@ class LLMBuilder:
         if model_name == 'fake':
             return FakeListLLM(responses=[])
 
+        provider = cls.get_default_provider(tenant_id)
+
         mode = cls.get_mode_by_model(model_name)
         if mode == 'chat':
-            # llm_cls = StreamableAzureChatOpenAI
-            llm_cls = StreamableChatOpenAI
+            if provider == 'openai':
+                llm_cls = StreamableChatOpenAI
+            else:
+                llm_cls = StreamableAzureChatOpenAI
         elif mode == 'completion':
-            llm_cls = StreamableOpenAI
+            if provider == 'openai':
+                llm_cls = StreamableOpenAI
+            else:
+                llm_cls = StreamableAzureOpenAI
         else:
             raise ValueError(f"model name {model_name} is not supported.")
 
-        model_credentials = cls.get_model_credentials(tenant_id, model_name)
+        model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
 
         return llm_cls(
             model_name=model_name,
@@ -86,18 +98,31 @@ class LLMBuilder:
             raise ValueError(f"model name {model_name} is not supported.")
 
     @classmethod
-    def get_model_credentials(cls, tenant_id: str, model_name: str) -> dict:
+    def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict:
         """
         Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
         Raises an exception if the model_name is not found or if the provider is not found.
         """
         if not model_name:
             raise Exception('model name not found')
+        #
+        # if model_name not in llm_constant.models:
+        #     raise Exception('model {} not found'.format(model_name))
 
-        if model_name not in llm_constant.models:
-            raise Exception('model {} not found'.format(model_name))
-
-        model_provider = llm_constant.models[model_name]
+        # model_provider = llm_constant.models[model_name]
 
         provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
         return provider_service.get_credentials(model_name)
+
+    @classmethod
+    def get_default_provider(cls, tenant_id: str) -> str:
+        provider = BaseProvider.get_valid_provider(tenant_id)
+        if not provider:
+            raise ProviderTokenNotInitError()
+
+        if provider.provider_type == ProviderType.SYSTEM.value:
+            provider_name = 'openai'
+        else:
+            provider_name = provider.provider_name
+
+        return provider_name
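
Note: to_llm now dispatches on the resolved provider: 'openai' keeps the existing streamable classes and anything else (in practice 'azure_openai') gets the Azure variants, while get_default_provider maps a SYSTEM (hosted-quota) provider back to 'openai'. The if/else above rewritten as a lookup table, for illustration only:

from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI

LLM_CLASSES = {
    ('chat', 'openai'): StreamableChatOpenAI,
    ('chat', 'azure_openai'): StreamableAzureChatOpenAI,
    ('completion', 'openai'): StreamableOpenAI,
    ('completion', 'azure_openai'): StreamableAzureOpenAI,
}

tenant_id, model_name = 'tenant-123', 'gpt-3.5-turbo'      # illustrative values
provider = LLMBuilder.get_default_provider(tenant_id)      # may raise ProviderTokenNotInitError
mode = LLMBuilder.get_mode_by_model(model_name)            # 'chat' or 'completion'
llm_cls = LLM_CLASSES[(mode, provider)]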

+ 4 - 6
api/core/llm/provider/azure_provider.py

@@ -36,10 +36,9 @@ class AzureProvider(BaseProvider):
         """
         Returns the API credentials for Azure OpenAI as a dictionary.
         """
-        encrypted_config = self.get_provider_api_key(model_id=model_id)
-        config = json.loads(encrypted_config)
+        config = self.get_provider_api_key(model_id=model_id)
         config['openai_api_type'] = 'azure'
-        config['deployment_name'] = model_id
+        config['deployment_name'] = model_id.replace('.', '')
         return config
 
     def get_provider_name(self):
@@ -51,12 +50,11 @@ class AzureProvider(BaseProvider):
         """
         try:
             config = self.get_provider_api_key()
-            config = json.loads(config)
         except:
             config = {
                 'openai_api_type': 'azure',
                 'openai_api_version': '2023-03-15-preview',
-                'openai_api_base': 'https://foo.microsoft.com/bar',
+                'openai_api_base': 'https://<your-domain-prefix>.openai.azure.com/',
                 'openai_api_key': ''
             }
 
@@ -65,7 +63,7 @@ class AzureProvider(BaseProvider):
                 config = {
                     'openai_api_type': 'azure',
                     'openai_api_version': '2023-03-15-preview',
-                    'openai_api_base': 'https://foo.microsoft.com/bar',
+                    'openai_api_base': 'https://<your-domain-prefix>.openai.azure.com/',
                     'openai_api_key': ''
                 }
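
Note: get_credentials now returns the decrypted config as a dict directly (the json.loads round-trip is gone) and derives the Azure deployment name by stripping dots from the model id, since Azure deployment names do not accept '.'. An illustrative result for model_id='gpt-3.5-turbo'; key and endpoint are placeholders:

credentials = {
    'openai_api_type': 'azure',
    'openai_api_version': '2023-03-15-preview',
    'openai_api_base': 'https://<your-domain-prefix>.openai.azure.com/',
    'openai_api_key': '<azure-key>',
    'deployment_name': 'gpt-35-turbo',  # 'gpt-3.5-turbo'.replace('.', '')
}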
 

+ 22 - 10
api/core/llm/provider/base.py

@@ -14,7 +14,7 @@ class BaseProvider(ABC):
     def __init__(self, tenant_id: str):
         self.tenant_id = tenant_id
 
-    def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> str:
+    def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str, dict]:
         """
         Returns the decrypted API key for the given tenant_id and provider_name.
         If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
@@ -43,23 +43,35 @@ class BaseProvider(ABC):
         Returns the Provider instance for the given tenant_id and provider_name.
         If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
         """
-        providers = db.session.query(Provider).filter(
-            Provider.tenant_id == self.tenant_id,
-            Provider.provider_name == self.get_provider_name().value
-        ).order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()
+        return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom)
+
+    @classmethod
+    def get_valid_provider(cls, tenant_id: str, provider_name: Optional[str] = None, prefer_custom: bool = False) -> Optional[Provider]:
+        """
+        Returns the Provider instance for the given tenant_id and provider_name.
+        If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
+        """
+        query = db.session.query(Provider).filter(
+            Provider.tenant_id == tenant_id
+        )
+
+        if provider_name:
+            query = query.filter(Provider.provider_name == provider_name)
+
+        providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()
 
         custom_provider = None
         system_provider = None
 
         for provider in providers:
-            if provider.provider_type == ProviderType.CUSTOM.value:
+            if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
                 custom_provider = provider
-            elif provider.provider_type == ProviderType.SYSTEM.value:
+            elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
                 system_provider = provider
 
-        if custom_provider and custom_provider.is_valid and custom_provider.encrypted_config:
+        if custom_provider:
             return custom_provider
-        elif system_provider and system_provider.is_valid:
+        elif system_provider:
             return system_provider
         else:
             return None
@@ -80,7 +92,7 @@ class BaseProvider(ABC):
         try:
             config = self.get_provider_api_key()
         except:
-            config = 'THIS-IS-A-MOCK-TOKEN'
+            config = ''
 
         if obfuscated:
             return self.obfuscated_token(config)
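
Note: get_valid_provider is now a classmethod usable without a provider name, and the validity/credential checks moved into the selection loop, so a CUSTOM provider only wins when it is valid and actually has an encrypted_config, with a valid SYSTEM provider as the fallback. A usage sketch (tenant id illustrative):

from core.llm.provider.base import BaseProvider

any_provider = BaseProvider.get_valid_provider('tenant-123')                 # any valid provider
azure_only = BaseProvider.get_valid_provider('tenant-123', 'azure_openai')   # a specific one
if any_provider is None:
    # This is the case LLMBuilder.get_default_provider turns into ProviderTokenNotInitError.
    print('tenant has no usable provider configured')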

+ 40 - 2
api/core/llm/streamable_azure_chat_open_ai.py

@@ -1,12 +1,50 @@
-import requests
 from langchain.schema import BaseMessage, ChatResult, LLMResult
 from langchain.chat_models import AzureChatOpenAI
-from typing import Optional, List
+from typing import Optional, List, Dict, Any
+
+from pydantic import root_validator
 
 from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
 
 
 class StreamableAzureChatOpenAI(AzureChatOpenAI):
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        try:
+            import openai
+        except ImportError:
+            raise ValueError(
+                "Could not import openai python package. "
+                "Please install it with `pip install openai`."
+            )
+        try:
+            values["client"] = openai.ChatCompletion
+        except AttributeError:
+            raise ValueError(
+                "`openai` has no `ChatCompletion` attribute, this is likely "
+                "due to an old version of the openai package. Try upgrading it "
+                "with `pip install --upgrade openai`."
+            )
+        if values["n"] < 1:
+            raise ValueError("n must be at least 1.")
+        if values["n"] > 1 and values["streaming"]:
+            raise ValueError("n must be 1 when streaming.")
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling OpenAI API."""
+        return {
+            **super()._default_params,
+            "engine": self.deployment_name,
+            "api_type": self.openai_api_type,
+            "api_base": self.openai_api_base,
+            "api_version": self.openai_api_version,
+            "api_key": self.openai_api_key,
+            "organization": self.openai_organization if self.openai_organization else None,
+        }
+
     def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
         """Get the number of tokens in a list of messages.
 

+ 64 - 0
api/core/llm/streamable_azure_open_ai.py

@@ -0,0 +1,64 @@
+import os
+
+from langchain.llms import AzureOpenAI
+from langchain.schema import LLMResult
+from typing import Optional, List, Dict, Mapping, Any
+
+from pydantic import root_validator
+
+from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
+
+
+class StreamableAzureOpenAI(AzureOpenAI):
+    openai_api_type: str = "azure"
+    openai_api_version: str = ""
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        try:
+            import openai
+
+            values["client"] = openai.Completion
+        except ImportError:
+            raise ValueError(
+                "Could not import openai python package. "
+                "Please install it with `pip install openai`."
+            )
+        if values["streaming"] and values["n"] > 1:
+            raise ValueError("Cannot stream results when n > 1.")
+        if values["streaming"] and values["best_of"] > 1:
+            raise ValueError("Cannot stream results when best_of > 1.")
+        return values
+
+    @property
+    def _invocation_params(self) -> Dict[str, Any]:
+        return {**super()._invocation_params, **{
+            "api_type": self.openai_api_type,
+            "api_base": self.openai_api_base,
+            "api_version": self.openai_api_version,
+            "api_key": self.openai_api_key,
+            "organization": self.openai_organization if self.openai_organization else None,
+        }}
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        return {**super()._identifying_params, **{
+            "api_type": self.openai_api_type,
+            "api_base": self.openai_api_base,
+            "api_version": self.openai_api_version,
+            "api_key": self.openai_api_key,
+            "organization": self.openai_organization if self.openai_organization else None,
+        }}
+
+    @handle_llm_exceptions
+    def generate(
+            self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        return super().generate(prompts, stop)
+
+    @handle_llm_exceptions_async
+    async def agenerate(
+            self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        return await super().agenerate(prompts, stop)

+ 41 - 1
api/core/llm/streamable_chat_open_ai.py

@@ -1,12 +1,52 @@
+import os
+
 from langchain.schema import BaseMessage, ChatResult, LLMResult
 from langchain.chat_models import ChatOpenAI
-from typing import Optional, List
+from typing import Optional, List, Dict, Any
+
+from pydantic import root_validator
 
 from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
 
 
 class StreamableChatOpenAI(ChatOpenAI):
 
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        try:
+            import openai
+        except ImportError:
+            raise ValueError(
+                "Could not import openai python package. "
+                "Please install it with `pip install openai`."
+            )
+        try:
+            values["client"] = openai.ChatCompletion
+        except AttributeError:
+            raise ValueError(
+                "`openai` has no `ChatCompletion` attribute, this is likely "
+                "due to an old version of the openai package. Try upgrading it "
+                "with `pip install --upgrade openai`."
+            )
+        if values["n"] < 1:
+            raise ValueError("n must be at least 1.")
+        if values["n"] > 1 and values["streaming"]:
+            raise ValueError("n must be 1 when streaming.")
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling OpenAI API."""
+        return {
+            **super()._default_params,
+            "api_type": 'openai',
+            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
+            "api_version": None,
+            "api_key": self.openai_api_key,
+            "organization": self.openai_organization if self.openai_organization else None,
+        }
+
     def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
         """Get the number of tokens in a list of messages.
 

+ 43 - 1
api/core/llm/streamable_open_ai.py

@@ -1,12 +1,54 @@
+import os
+
 from langchain.schema import LLMResult
-from typing import Optional, List
+from typing import Optional, List, Dict, Any, Mapping
 from langchain import OpenAI
+from pydantic import root_validator
 
 from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
 
 
 class StreamableOpenAI(OpenAI):
 
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        try:
+            import openai
+
+            values["client"] = openai.Completion
+        except ImportError:
+            raise ValueError(
+                "Could not import openai python package. "
+                "Please install it with `pip install openai`."
+            )
+        if values["streaming"] and values["n"] > 1:
+            raise ValueError("Cannot stream results when n > 1.")
+        if values["streaming"] and values["best_of"] > 1:
+            raise ValueError("Cannot stream results when best_of > 1.")
+        return values
+
+    @property
+    def _invocation_params(self) -> Dict[str, Any]:
+        return {**super()._invocation_params, **{
+            "api_type": 'openai',
+            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
+            "api_version": None,
+            "api_key": self.openai_api_key,
+            "organization": self.openai_organization if self.openai_organization else None,
+        }}
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        return {**super()._identifying_params, **{
+            "api_type": 'openai',
+            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
+            "api_version": None,
+            "api_key": self.openai_api_key,
+            "organization": self.openai_organization if self.openai_organization else None,
+        }}
+
+
     @handle_llm_exceptions
     def generate(
             self, prompts: List[str], stop: Optional[List[str]] = None
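
Note: the plain-OpenAI variants pin api_type back to 'openai' and send api_base/api_key with every call. The pre-1.0 openai SDK otherwise reads module-level globals (openai.api_type, openai.api_key, ...), so per-request params keep an Azure tenant and an OpenAI tenant from clobbering each other in one process. A sketch of the failure mode being avoided; the key is a placeholder and the call is illustrative:

import openai

openai.api_type = 'azure'  # mutated globally by some earlier Azure request

# Per-request routing, as in _invocation_params above: this call ignores the
# poisoned global because its kwargs travel with the request.
openai.Completion.create(
    model='text-davinci-003',
    prompt='hi',
    api_type='openai',
    api_base='https://api.openai.com/v1',
    api_key='sk-...',
)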

+ 8 - 22
web/app/components/header/account-setting/provider-page/azure-provider/index.tsx

@@ -20,7 +20,7 @@ const AzureProvider = ({
   const [token, setToken] = useState(provider.token as ProviderAzureToken || {})
   const handleFocus = () => {
     if (token === provider.token) {
-      token.azure_api_key = ''
+      token.openai_api_key = ''
       setToken({...token})
       onTokenChange({...token})
     }
@@ -35,31 +35,17 @@ const AzureProvider = ({
     <div className='px-4 py-3'>
       <ProviderInput 
         className='mb-4'
-        name={t('common.provider.azure.resourceName')}
-        placeholder={t('common.provider.azure.resourceNamePlaceholder')}
-        value={token.azure_api_base}
-        onChange={(v) => handleChange('azure_api_base', v)}
-      />
-      <ProviderInput 
-        className='mb-4'
-        name={t('common.provider.azure.deploymentId')}
-        placeholder={t('common.provider.azure.deploymentIdPlaceholder')}
-        value={token.azure_api_type}
-        onChange={v => handleChange('azure_api_type', v)}
-      />
-      <ProviderInput 
-        className='mb-4'
-        name={t('common.provider.azure.apiVersion')}
-        placeholder={t('common.provider.azure.apiVersionPlaceholder')}
-        value={token.azure_api_version}
-        onChange={v => handleChange('azure_api_version', v)}
+        name={t('common.provider.azure.apiBase')}
+        placeholder={t('common.provider.azure.apiBasePlaceholder')}
+        value={token.openai_api_base}
+        onChange={(v) => handleChange('openai_api_base', v)}
       />
       <ProviderValidateTokenInput 
         className='mb-4'
         name={t('common.provider.azure.apiKey')}
         placeholder={t('common.provider.azure.apiKeyPlaceholder')}
-        value={token.azure_api_key}
-        onChange={v => handleChange('azure_api_key', v)}
+        value={token.openai_api_key}
+        onChange={v => handleChange('openai_api_key', v)}
         onFocus={handleFocus}
         onValidatedStatus={onValidatedStatus}
         providerName={provider.provider_name}
@@ -72,4 +58,4 @@ const AzureProvider = ({
   )
 }
 
-export default AzureProvider
+export default AzureProvider

+ 3 - 3
web/app/components/header/account-setting/provider-page/provider-item/index.tsx

@@ -33,12 +33,12 @@ const ProviderItem = ({
   const { notify } = useContext(ToastContext)
   const [token, setToken] = useState<ProviderAzureToken | string>(
     provider.provider_name === 'azure_openai' 
-      ? { azure_api_base: '', azure_api_type: '', azure_api_version: '', azure_api_key: '' } 
+      ? { openai_api_base: '', openai_api_key: '' }
       : ''
     )
   const id = `${provider.provider_name}-${provider.provider_type}`
   const isOpen = id === activeId
-  const providerKey = provider.provider_name === 'azure_openai' ? (provider.token as ProviderAzureToken)?.azure_api_key  : provider.token
+  const providerKey = provider.provider_name === 'azure_openai' ? (provider.token as ProviderAzureToken)?.openai_api_key  : provider.token
   const comingSoon = false
   const isValid = provider.is_valid
 
@@ -135,4 +135,4 @@ const ProviderItem = ({
   )
 }
 
-export default ProviderItem
+export default ProviderItem

+ 2 - 6
web/i18n/lang/common.en.ts

@@ -148,12 +148,8 @@ const translation = {
     editKey: 'Edit',
     invalidApiKey: 'Invalid API key',
     azure: {
-      resourceName: 'Resource Name',
-      resourceNamePlaceholder: 'The name of your Azure OpenAI Resource.',
-      deploymentId: 'Deployment ID',
-      deploymentIdPlaceholder: 'The deployment name you chose when you deployed the model.',
-      apiVersion: 'API Version',
-      apiVersionPlaceholder: 'The API version to use for this operation.',
+      apiBase: 'API Base',
+      apiBasePlaceholder: 'The API Base URL of your Azure OpenAI Resource.',
       apiKey: 'API Key',
       apiKeyPlaceholder: 'Enter your API key here',
       helpTip: 'Learn Azure OpenAI Service',

+ 3 - 7
web/i18n/lang/common.zh.ts

@@ -149,14 +149,10 @@ const translation = {
     editKey: '编辑',
     invalidApiKey: '无效的 API 密钥',
     azure: {
-      resourceName: 'Resource Name',
-      resourceNamePlaceholder: 'The name of your Azure OpenAI Resource.',
-      deploymentId: 'Deployment ID',
-      deploymentIdPlaceholder: 'The deployment name you chose when you deployed the model.',
-      apiVersion: 'API Version',
-      apiVersionPlaceholder: 'The API version to use for this operation.',
+      apiBase: 'API Base',
+      apiBasePlaceholder: '输入您的 Azure OpenAI API Base 地址',
       apiKey: 'API Key',
-      apiKeyPlaceholder: 'Enter your API key here',
+      apiKeyPlaceholder: '输入你的 API 密钥',
       helpTip: '了解 Azure OpenAI Service',
     },
     openaiHosted: {

+ 2 - 4
web/models/common.ts

@@ -55,10 +55,8 @@ export type Member = Pick<UserProfileResponse, 'id' | 'name' | 'email' | 'last_l
 }
 
 export type ProviderAzureToken = {
-  azure_api_base: string
-  azure_api_key: string
-  azure_api_type: string
-  azure_api_version: string
+  openai_api_base: string
+  openai_api_key: string
 }
 export type Provider = {
   provider_name: string