azure_provider.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. import json
  2. from typing import Optional, Union
  3. import requests
  4. from core.llm.provider.base import BaseProvider
  5. from models.provider import ProviderName
  6. class AzureProvider(BaseProvider):
  7. def get_models(self, model_id: Optional[str] = None) -> list[dict]:
  8. credentials = self.get_credentials(model_id)
  9. url = "{}/openai/deployments?api-version={}".format(
  10. credentials.get('openai_api_base'),
  11. credentials.get('openai_api_version')
  12. )
  13. headers = {
  14. "api-key": credentials.get('openai_api_key'),
  15. "content-type": "application/json; charset=utf-8"
  16. }
  17. response = requests.get(url, headers=headers)
  18. if response.status_code == 200:
  19. result = response.json()
  20. return [{
  21. 'id': deployment['id'],
  22. 'name': '{} ({})'.format(deployment['id'], deployment['model'])
  23. } for deployment in result['data'] if deployment['status'] == 'succeeded']
  24. else:
  25. # TODO: optimize in future
  26. raise Exception('Failed to get deployments from Azure OpenAI. Status code: {}'.format(response.status_code))
  27. def get_credentials(self, model_id: Optional[str] = None) -> dict:
  28. """
  29. Returns the API credentials for Azure OpenAI as a dictionary.
  30. """
  31. config = self.get_provider_api_key(model_id=model_id)
  32. config['openai_api_type'] = 'azure'
  33. config['deployment_name'] = model_id.replace('.', '')
  34. return config
  35. def get_provider_name(self):
  36. return ProviderName.AZURE_OPENAI
  37. def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
  38. """
  39. Returns the provider configs.
  40. """
  41. try:
  42. config = self.get_provider_api_key()
  43. except:
  44. config = {
  45. 'openai_api_type': 'azure',
  46. 'openai_api_version': '2023-03-15-preview',
  47. 'openai_api_base': 'https://<your-domain-prefix>.openai.azure.com/',
  48. 'openai_api_key': ''
  49. }
  50. if obfuscated:
  51. if not config.get('openai_api_key'):
  52. config = {
  53. 'openai_api_type': 'azure',
  54. 'openai_api_version': '2023-03-15-preview',
  55. 'openai_api_base': 'https://<your-domain-prefix>.openai.azure.com/',
  56. 'openai_api_key': ''
  57. }
  58. config['openai_api_key'] = self.obfuscated_token(config.get('openai_api_key'))
  59. return config
  60. return config
  61. def get_token_type(self):
  62. # TODO: change to dict when implemented
  63. return lambda value: value
  64. def config_validate(self, config: Union[dict | str]):
  65. """
  66. Validates the given config.
  67. """
  68. # TODO: implement
  69. pass
  70. def get_encrypted_token(self, config: Union[dict | str]):
  71. """
  72. Returns the encrypted token.
  73. """
  74. return json.dumps({
  75. 'openai_api_type': 'azure',
  76. 'openai_api_version': '2023-03-15-preview',
  77. 'openai_api_base': config['openai_api_base'],
  78. 'openai_api_key': self.encrypt_token(config['openai_api_key'])
  79. })
  80. def get_decrypted_token(self, token: str):
  81. """
  82. Returns the decrypted token.
  83. """
  84. config = json.loads(token)
  85. config['openai_api_key'] = self.decrypt_token(config['openai_api_key'])
  86. return config