providers.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. # -*- coding:utf-8 -*-
  2. import base64
  3. import json
  4. import logging
  5. from flask import current_app
  6. from flask_login import login_required, current_user
  7. from flask_restful import Resource, reqparse, abort
  8. from werkzeug.exceptions import Forbidden
  9. from controllers.console import api
  10. from controllers.console.setup import setup_required
  11. from controllers.console.wraps import account_initialization_required
  12. from core.llm.provider.errors import ValidateFailedError
  13. from extensions.ext_database import db
  14. from libs import rsa
  15. from models.provider import Provider, ProviderType, ProviderName
  16. from services.provider_service import ProviderService
  17. class ProviderListApi(Resource):
  18. @setup_required
  19. @login_required
  20. @account_initialization_required
  21. def get(self):
  22. tenant_id = current_user.current_tenant_id
  23. """
  24. If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:,
  25. azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the
  26. rest is replaced by * and the last two bits are displayed in plaintext
  27. If the type is other, decode and return the Token field directly, the field displays the first 6 bits in
  28. plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
  29. """
  30. ProviderService.init_supported_provider(current_user.current_tenant)
  31. providers = Provider.query.filter_by(tenant_id=tenant_id).all()
  32. provider_list = [
  33. {
  34. 'provider_name': p.provider_name,
  35. 'provider_type': p.provider_type,
  36. 'is_valid': p.is_valid,
  37. 'last_used': p.last_used,
  38. 'is_enabled': p.is_enabled,
  39. **({
  40. 'quota_type': p.quota_type,
  41. 'quota_limit': p.quota_limit,
  42. 'quota_used': p.quota_used
  43. } if p.provider_type == ProviderType.SYSTEM.value else {}),
  44. 'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
  45. ProviderName(p.provider_name), only_custom=True)
  46. if p.provider_type == ProviderType.CUSTOM.value else None
  47. }
  48. for p in providers
  49. ]
  50. return provider_list
  51. class ProviderTokenApi(Resource):
  52. @setup_required
  53. @login_required
  54. @account_initialization_required
  55. def post(self, provider):
  56. if provider not in [p.value for p in ProviderName]:
  57. abort(404)
  58. # The role of the current user in the ta table must be admin or owner
  59. if current_user.current_tenant.current_role not in ['admin', 'owner']:
  60. logging.log(logging.ERROR,
  61. f'User {current_user.id} is not authorized to update provider token, current_role is {current_user.current_tenant.current_role}')
  62. raise Forbidden()
  63. parser = reqparse.RequestParser()
  64. parser.add_argument('token', type=ProviderService.get_token_type(
  65. tenant=current_user.current_tenant,
  66. provider_name=ProviderName(provider)
  67. ), required=True, nullable=False, location='json')
  68. args = parser.parse_args()
  69. if args['token']:
  70. try:
  71. ProviderService.validate_provider_configs(
  72. tenant=current_user.current_tenant,
  73. provider_name=ProviderName(provider),
  74. configs=args['token']
  75. )
  76. token_is_valid = True
  77. except ValidateFailedError as ex:
  78. raise ValueError(str(ex))
  79. base64_encrypted_token = ProviderService.get_encrypted_token(
  80. tenant=current_user.current_tenant,
  81. provider_name=ProviderName(provider),
  82. configs=args['token']
  83. )
  84. else:
  85. base64_encrypted_token = None
  86. token_is_valid = False
  87. tenant = current_user.current_tenant
  88. provider_model = db.session.query(Provider).filter(
  89. Provider.tenant_id == tenant.id,
  90. Provider.provider_name == provider,
  91. Provider.provider_type == ProviderType.CUSTOM.value
  92. ).first()
  93. # Only allow updating token for CUSTOM provider type
  94. if provider_model:
  95. provider_model.encrypted_config = base64_encrypted_token
  96. provider_model.is_valid = token_is_valid
  97. else:
  98. provider_model = Provider(tenant_id=tenant.id, provider_name=provider,
  99. provider_type=ProviderType.CUSTOM.value,
  100. encrypted_config=base64_encrypted_token,
  101. is_valid=token_is_valid)
  102. db.session.add(provider_model)
  103. if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid:
  104. other_providers = db.session.query(Provider).filter(
  105. Provider.tenant_id == tenant.id,
  106. Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]),
  107. Provider.provider_name != provider,
  108. Provider.provider_type == ProviderType.CUSTOM.value
  109. ).all()
  110. for other_provider in other_providers:
  111. other_provider.is_valid = False
  112. db.session.commit()
  113. if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
  114. ProviderName.HUGGINGFACEHUB.value]:
  115. return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
  116. return {'result': 'success'}, 201
  117. class ProviderTokenValidateApi(Resource):
  118. @setup_required
  119. @login_required
  120. @account_initialization_required
  121. def post(self, provider):
  122. if provider not in [p.value for p in ProviderName]:
  123. abort(404)
  124. parser = reqparse.RequestParser()
  125. parser.add_argument('token', type=ProviderService.get_token_type(
  126. tenant=current_user.current_tenant,
  127. provider_name=ProviderName(provider)
  128. ), required=True, nullable=False, location='json')
  129. args = parser.parse_args()
  130. # todo: remove this when the provider is supported
  131. if provider in [ProviderName.COHERE.value,
  132. ProviderName.HUGGINGFACEHUB.value]:
  133. return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
  134. result = True
  135. error = None
  136. try:
  137. ProviderService.validate_provider_configs(
  138. tenant=current_user.current_tenant,
  139. provider_name=ProviderName(provider),
  140. configs=args['token']
  141. )
  142. except ValidateFailedError as e:
  143. result = False
  144. error = str(e)
  145. response = {'result': 'success' if result else 'error'}
  146. if not result:
  147. response['error'] = error
  148. return response
  149. class ProviderSystemApi(Resource):
  150. @setup_required
  151. @login_required
  152. @account_initialization_required
  153. def put(self, provider):
  154. if provider not in [p.value for p in ProviderName]:
  155. abort(404)
  156. parser = reqparse.RequestParser()
  157. parser.add_argument('is_enabled', type=bool, required=True, location='json')
  158. args = parser.parse_args()
  159. tenant = current_user.current_tenant_id
  160. provider_model = Provider.query.filter_by(tenant_id=tenant.id, provider_name=provider).first()
  161. if provider_model and provider_model.provider_type == ProviderType.SYSTEM.value:
  162. provider_model.is_valid = args['is_enabled']
  163. db.session.commit()
  164. elif not provider_model:
  165. if provider == ProviderName.OPENAI.value:
  166. quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT']
  167. elif provider == ProviderName.ANTHROPIC.value:
  168. quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT']
  169. else:
  170. quota_limit = 0
  171. ProviderService.create_system_provider(
  172. tenant,
  173. provider,
  174. quota_limit,
  175. args['is_enabled']
  176. )
  177. else:
  178. abort(403)
  179. return {'result': 'success'}
  180. @setup_required
  181. @login_required
  182. @account_initialization_required
  183. def get(self, provider):
  184. if provider not in [p.value for p in ProviderName]:
  185. abort(404)
  186. # The role of the current user in the ta table must be admin or owner
  187. if current_user.current_tenant.current_role not in ['admin', 'owner']:
  188. raise Forbidden()
  189. provider_model = db.session.query(Provider).filter(Provider.tenant_id == current_user.current_tenant_id,
  190. Provider.provider_name == provider,
  191. Provider.provider_type == ProviderType.SYSTEM.value).first()
  192. system_model = None
  193. if provider_model:
  194. system_model = {
  195. 'result': 'success',
  196. 'provider': {
  197. 'provider_name': provider_model.provider_name,
  198. 'provider_type': provider_model.provider_type,
  199. 'is_valid': provider_model.is_valid,
  200. 'last_used': provider_model.last_used,
  201. 'is_enabled': provider_model.is_enabled,
  202. 'quota_type': provider_model.quota_type,
  203. 'quota_limit': provider_model.quota_limit,
  204. 'quota_used': provider_model.quota_used
  205. }
  206. }
  207. else:
  208. abort(404)
  209. return system_model
  210. api.add_resource(ProviderTokenApi, '/providers/<provider>/token',
  211. endpoint='current_providers_token') # Deprecated
  212. api.add_resource(ProviderTokenValidateApi, '/providers/<provider>/token-validate',
  213. endpoint='current_providers_token_validate') # Deprecated
  214. api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
  215. endpoint='workspaces_current_providers_token') # PUT for updating provider token
  216. api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
  217. endpoint='workspaces_current_providers_token_validate') # POST for validating provider token
  218. api.add_resource(ProviderListApi, '/workspaces/current/providers') # GET for getting providers list
  219. api.add_resource(ProviderSystemApi, '/workspaces/current/providers/<provider>/system',
  220. endpoint='workspaces_current_providers_system') # GET for getting provider quota, PUT for updating provider status