base.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import base64
  2. from abc import ABC, abstractmethod
  3. from typing import Optional, Union
  4. from core.constant import llm_constant
  5. from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
  6. from extensions.ext_database import db
  7. from libs import rsa
  8. from models.account import Tenant
  9. from models.provider import Provider, ProviderType, ProviderName
  10. class BaseProvider(ABC):
  11. def __init__(self, tenant_id: str):
  12. self.tenant_id = tenant_id
  13. def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]:
  14. """
  15. Returns the decrypted API key for the given tenant_id and provider_name.
  16. If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
  17. If the provider is not found or not valid, raises a ProviderTokenNotInitError.
  18. """
  19. provider = self.get_provider(only_custom)
  20. if not provider:
  21. raise ProviderTokenNotInitError(
  22. f"No valid {llm_constant.models[model_id]} model provider credentials found. "
  23. f"Please go to Settings -> Model Provider to complete your provider credentials."
  24. )
  25. if provider.provider_type == ProviderType.SYSTEM.value:
  26. quota_used = provider.quota_used if provider.quota_used is not None else 0
  27. quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
  28. if model_id and model_id == 'gpt-4':
  29. raise ModelCurrentlyNotSupportError()
  30. if quota_used >= quota_limit:
  31. raise QuotaExceededError()
  32. return self.get_hosted_credentials()
  33. else:
  34. return self.get_decrypted_token(provider.encrypted_config)
  35. def get_provider(self, only_custom: bool = False) -> Optional[Provider]:
  36. """
  37. Returns the Provider instance for the given tenant_id and provider_name.
  38. If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
  39. """
  40. return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom)
  41. @classmethod
  42. def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[
  43. Provider]:
  44. """
  45. Returns the Provider instance for the given tenant_id and provider_name.
  46. If both CUSTOM and System providers exist.
  47. """
  48. query = db.session.query(Provider).filter(
  49. Provider.tenant_id == tenant_id
  50. )
  51. if provider_name:
  52. query = query.filter(Provider.provider_name == provider_name)
  53. if only_custom:
  54. query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value)
  55. providers = query.order_by(Provider.provider_type.asc()).all()
  56. for provider in providers:
  57. if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
  58. return provider
  59. elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
  60. return provider
  61. return None
  62. def get_hosted_credentials(self) -> Union[str | dict]:
  63. raise ProviderTokenNotInitError(
  64. f"No valid {self.get_provider_name().value} model provider credentials found. "
  65. f"Please go to Settings -> Model Provider to complete your provider credentials."
  66. )
  67. def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
  68. """
  69. Returns the provider configs.
  70. """
  71. try:
  72. config = self.get_provider_api_key(only_custom=only_custom)
  73. except:
  74. config = ''
  75. if obfuscated:
  76. return self.obfuscated_token(config)
  77. return config
  78. def obfuscated_token(self, token: str):
  79. return token[:6] + '*' * (len(token) - 8) + token[-2:]
  80. def get_token_type(self):
  81. return str
  82. def get_encrypted_token(self, config: Union[dict | str]):
  83. return self.encrypt_token(config)
  84. def get_decrypted_token(self, token: str):
  85. return self.decrypt_token(token)
  86. def encrypt_token(self, token):
  87. tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
  88. encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
  89. return base64.b64encode(encrypted_token).decode()
  90. def decrypt_token(self, token):
  91. return rsa.decrypt(base64.b64decode(token), self.tenant_id)
  92. @abstractmethod
  93. def get_provider_name(self):
  94. raise NotImplementedError
  95. @abstractmethod
  96. def get_credentials(self, model_id: Optional[str] = None) -> dict:
  97. raise NotImplementedError
  98. @abstractmethod
  99. def get_models(self, model_id: Optional[str] = None) -> list[dict]:
  100. raise NotImplementedError
  101. @abstractmethod
  102. def config_validate(self, config: str):
  103. raise NotImplementedError