providers.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # -*- coding:utf-8 -*-
  2. from flask_login import login_required, current_user
  3. from flask_restful import Resource, reqparse
  4. from werkzeug.exceptions import Forbidden
  5. from controllers.console import api
  6. from controllers.console.setup import setup_required
  7. from controllers.console.wraps import account_initialization_required
  8. from core.model_providers.providers.base import CredentialsValidateFailedError
  9. from models.provider import ProviderType
  10. from services.provider_service import ProviderService
  def _provider_token(provider):
      """Return the value exposed as ``token`` for one provider entry.

      :param provider: one entry from a ``providers`` list produced by
          ``ProviderService.get_provider_list``.
      :return: ``None`` when the entry's ``config`` is falsy; for the
          ``openai`` provider, the ``openai_api_key`` item of ``config``;
          for any other provider, ``config`` unchanged.
      """
      config = provider['config']
      if not config:
          return None
      if provider['provider_name'] == 'openai':
          # OpenAI credentials are stored as {'openai_api_key': <token>}.
          return config['openai_api_key']
      return config


  class ProviderListApi(Resource):
      """List the providers configured for the current user's tenant."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self):
          """Return a flat list of provider entries for the current tenant.

          ``ProviderService.get_provider_list`` returns a mapping whose values
          each carry a ``providers`` list; those lists are flattened here.
          Every entry has ``provider_name``, ``provider_type``, ``is_valid``,
          ``last_used``, ``is_enabled`` and ``token`` (see ``_provider_token``).
          Entries whose ``provider_type`` is ``ProviderType.SYSTEM`` also
          carry ``quota_type``, ``quota_limit`` and ``quota_used``.

          NOTE(review): ``token`` is taken from ``config`` exactly as the
          service returns it. This handler performs no masking, so any
          redaction of secrets must already have happened inside
          ``ProviderService`` — confirm before exposing this endpoint.
          """
          tenant_id = current_user.current_tenant_id
          provider_service = ProviderService()
          provider_info_list = provider_service.get_provider_list(tenant_id)

          return [
              {
                  'provider_name': p['provider_name'],
                  'provider_type': p['provider_type'],
                  'is_valid': p['is_valid'],
                  'last_used': p['last_used'],
                  # NOTE(review): mirrors is_valid rather than reading a
                  # separate flag — confirm this is intended.
                  'is_enabled': p['is_valid'],
                  # Quota fields are only meaningful for system providers.
                  **({
                      'quota_type': p['quota_type'],
                      'quota_limit': p['quota_limit'],
                      'quota_used': p['quota_used'],
                  } if p['provider_type'] == ProviderType.SYSTEM.value else {}),
                  'token': _provider_token(p),
              }
              for provider_info in provider_info_list.values()
              for p in provider_info['providers']
          ]
  class ProviderTokenApi(Resource):
      """Store custom credentials (a token) for one provider of the current tenant."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider):
          """Save the ``token`` from the JSON body as *provider*'s custom config.

          :param provider: provider name taken from the URL.
          :return: ``({'result': 'success'}, 201)`` once the service accepts
              the config.
          :raises Forbidden: if the caller's role in the current tenant is
              neither ``admin`` nor ``owner``.
          :raises ValueError: if ``ProviderService.save_custom_provider_config``
              raises ``CredentialsValidateFailedError``; the original error is
              kept as ``__cause__``.
          """
          # Only tenant admins and owners may change provider credentials.
          if current_user.current_tenant.current_role not in ('admin', 'owner'):
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument('token', required=True, nullable=False, location='json')
          args = parser.parse_args()

          if provider == 'openai':
              # OpenAI credentials are stored as a dict keyed by openai_api_key.
              args['token'] = {'openai_api_key': args['token']}

          provider_service = ProviderService()
          try:
              provider_service.save_custom_provider_config(
                  tenant_id=current_user.current_tenant_id,
                  provider_name=provider,
                  config=args['token'],
              )
          except CredentialsValidateFailedError as ex:
              # Surface as ValueError (presumably translated into a client
              # error by the API's error handler — confirm), chaining the
              # original exception so its traceback is not lost.
              raise ValueError(str(ex)) from ex

          return {'result': 'success'}, 201
  class ProviderTokenValidateApi(Resource):
      """Check a provider token via ProviderService without saving it."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider):
          """Validate the ``token`` in the JSON body against *provider*.

          Unlike ``ProviderTokenApi.post``, no role check is performed and
          ``save_custom_provider_config`` is not called.

          :param provider: provider name taken from the URL.
          :return: ``{'result': 'success'}`` when validation passes, or
              ``{'result': 'error', 'error': <message>}`` when the service
              raises ``CredentialsValidateFailedError``.
          """
          request_parser = reqparse.RequestParser()
          request_parser.add_argument('token', required=True, nullable=False, location='json')
          parsed = request_parser.parse_args()
          service = ProviderService()

          candidate = parsed['token']
          if provider == 'openai':
              # OpenAI credentials are validated as {'openai_api_key': <token>}.
              candidate = {'openai_api_key': candidate}

          try:
              service.custom_provider_config_validate(
                  provider_name=provider,
                  config=candidate,
              )
          except CredentialsValidateFailedError as exc:
              return {'result': 'error', 'error': str(exc)}

          return {'result': 'success'}
  # Route registration for this module's resources.

  # POST: store the current tenant's token for a provider
  # (ProviderTokenApi defines only `post`, so PUT is not served here).
  api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
                   endpoint='workspaces_current_providers_token')
  # POST: check a provider token without saving it.
  api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
                   endpoint='workspaces_current_providers_token_validate')
  # GET: list the current tenant's providers.
  api.add_resource(ProviderListApi, '/workspaces/current/providers')