model_providers.py 10 KB


  1. from flask_login import current_user
  2. from core.login.login import login_required
  3. from flask_restful import Resource, reqparse
  4. from werkzeug.exceptions import Forbidden
  5. from controllers.console import api
  6. from controllers.console.app.error import ProviderNotInitializeError
  7. from controllers.console.setup import setup_required
  8. from controllers.console.wraps import account_initialization_required
  9. from core.model_providers.error import LLMBadRequestError
  10. from core.model_providers.providers.base import CredentialsValidateFailedError
  11. from services.provider_checkout_service import ProviderCheckoutService
  12. from services.provider_service import ProviderService
  class ModelProviderListApi(Resource):
      """Read-only endpoint listing model providers for the caller's workspace."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self):
          """Return whatever ``ProviderService.get_provider_list`` yields for the current tenant."""
          tenant = current_user.current_tenant_id
          return ProviderService().get_provider_list(tenant)
  class ModelProviderValidateApi(Resource):
      """Dry-run validation of a provider-level configuration (nothing is saved)."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider_name: str):
          """Validate the JSON ``config`` body for ``provider_name``.

          Returns ``{'result': 'success'}`` when the service accepts the config,
          or ``{'result': 'error', 'error': <message>}`` when it raises
          ``CredentialsValidateFailedError``. Any other exception propagates.
          """
          request_parser = reqparse.RequestParser()
          request_parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
          parsed = request_parser.parse_args()

          service = ProviderService()
          try:
              service.custom_provider_config_validate(
                  provider_name=provider_name,
                  config=parsed['config'],
              )
          except CredentialsValidateFailedError as validation_error:
              # Validation failures are reported in the payload, not as an HTTP error.
              return {'result': 'error', 'error': str(validation_error)}
          return {'result': 'success'}
  class ModelProviderUpdateApi(Resource):
      """Save or delete the custom configuration of a model provider.

      Both operations require the caller's role in the current tenant to be
      ``admin`` or ``owner``.
      """

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider_name: str):
          """Validate and store the JSON ``config`` body for ``provider_name``.

          Returns:
              ``({'result': 'success'}, 201)`` on success.

          Raises:
              Forbidden: if the caller's role is neither ``admin`` nor ``owner``.
              ValueError: if the service raises ``CredentialsValidateFailedError``;
                  the original exception is chained as ``__cause__``.
          """
          if current_user.current_tenant.current_role not in ('admin', 'owner'):
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
          args = parser.parse_args()

          provider_service = ProviderService()
          try:
              provider_service.save_custom_provider_config(
                  tenant_id=current_user.current_tenant_id,
                  provider_name=provider_name,
                  config=args['config'],
              )
          except CredentialsValidateFailedError as ex:
              # Still surfaced as ValueError (unchanged type), but keep the
              # original error as the cause so tracebacks show where it came from.
              raise ValueError(str(ex)) from ex

          return {'result': 'success'}, 201

      @setup_required
      @login_required
      @account_initialization_required
      def delete(self, provider_name: str):
          """Delete the custom configuration of ``provider_name`` for the current tenant.

          Raises:
              Forbidden: if the caller's role is neither ``admin`` nor ``owner``.
          """
          if current_user.current_tenant.current_role not in ('admin', 'owner'):
              raise Forbidden()

          provider_service = ProviderService()
          provider_service.delete_custom_provider(
              tenant_id=current_user.current_tenant_id,
              provider_name=provider_name,
          )

          # NOTE(review): HTTP 204 (No Content) responses must not carry a body,
          # so this JSON payload is likely dropped in transit. The status code is
          # kept as-is to avoid changing what clients observe; confirm intent.
          return {'result': 'success'}, 204
  class ModelProviderModelValidateApi(Resource):
      """Dry-run validation of a single model's configuration under a provider."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider_name: str):
          """Validate ``model_name``/``model_type``/``config`` from the JSON body.

          ``model_type`` must be one of ``text-generation``, ``embeddings`` or
          ``speech2text`` (enforced by the request parser).

          Returns ``{'result': 'success'}`` if the service accepts the config,
          otherwise ``{'result': 'error', 'error': <message>}`` when it raises
          ``CredentialsValidateFailedError``. Other exceptions propagate.
          """
          request_parser = reqparse.RequestParser()
          request_parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
          request_parser.add_argument(
              'model_type', type=str, required=True, nullable=False,
              choices=['text-generation', 'embeddings', 'speech2text'], location='json',
          )
          request_parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
          parsed = request_parser.parse_args()

          service = ProviderService()
          try:
              service.custom_provider_model_config_validate(
                  provider_name=provider_name,
                  model_name=parsed['model_name'],
                  model_type=parsed['model_type'],
                  config=parsed['config'],
              )
          except CredentialsValidateFailedError as validation_error:
              # Reported in the payload rather than as an HTTP error status.
              return {'result': 'error', 'error': str(validation_error)}
          return {'result': 'success'}
  class ModelProviderModelUpdateApi(Resource):
      """Add/update or delete the custom configuration of one model under a provider.

      Both operations require the caller's role in the current tenant to be
      ``admin`` or ``owner``.
      """

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider_name: str):
          """Validate and store a model configuration from the JSON body.

          Body fields: ``model_name`` (str), ``model_type`` (one of
          ``text-generation``, ``embeddings``, ``speech2text``), ``config`` (dict).

          Returns:
              ``({'result': 'success'}, 200)`` on success.

          Raises:
              Forbidden: if the caller's role is neither ``admin`` nor ``owner``.
              ValueError: if the service raises ``CredentialsValidateFailedError``;
                  the original exception is chained as ``__cause__``.
          """
          if current_user.current_tenant.current_role not in ('admin', 'owner'):
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
          parser.add_argument('model_type', type=str, required=True, nullable=False,
                              choices=['text-generation', 'embeddings', 'speech2text'], location='json')
          parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
          args = parser.parse_args()

          provider_service = ProviderService()
          try:
              provider_service.add_or_save_custom_provider_model_config(
                  tenant_id=current_user.current_tenant_id,
                  provider_name=provider_name,
                  model_name=args['model_name'],
                  model_type=args['model_type'],
                  config=args['config'],
              )
          except CredentialsValidateFailedError as ex:
              # Same exception type as before, now with the original cause attached.
              raise ValueError(str(ex)) from ex

          return {'result': 'success'}, 200

      @setup_required
      @login_required
      @account_initialization_required
      def delete(self, provider_name: str):
          """Delete a model's custom configuration identified by query-string args.

          Query args: ``model_name`` (str), ``model_type`` (one of
          ``text-generation``, ``embeddings``, ``speech2text``).

          Raises:
              Forbidden: if the caller's role is neither ``admin`` nor ``owner``.
          """
          if current_user.current_tenant.current_role not in ('admin', 'owner'):
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
          parser.add_argument('model_type', type=str, required=True, nullable=False,
                              choices=['text-generation', 'embeddings', 'speech2text'], location='args')
          args = parser.parse_args()

          provider_service = ProviderService()
          provider_service.delete_custom_provider_model(
              tenant_id=current_user.current_tenant_id,
              provider_name=provider_name,
              model_name=args['model_name'],
              model_type=args['model_type'],
          )

          # NOTE(review): HTTP 204 (No Content) responses must not carry a body;
          # status kept unchanged for client compatibility — confirm intent.
          return {'result': 'success'}, 204
  class PreferredProviderTypeUpdateApi(Resource):
      """Switch which provider type (``system`` or ``custom``) a tenant prefers."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider_name: str):
          """Set ``preferred_provider_type`` from the JSON body for ``provider_name``.

          Raises ``Forbidden`` unless the caller's role is ``admin`` or ``owner``.
          Returns ``{'result': 'success'}``.
          """
          role = current_user.current_tenant.current_role
          if role not in ['admin', 'owner']:
              raise Forbidden()

          request_parser = reqparse.RequestParser()
          request_parser.add_argument(
              'preferred_provider_type', type=str, required=True, nullable=False,
              choices=['system', 'custom'], location='json',
          )
          parsed = request_parser.parse_args()

          service = ProviderService()
          service.switch_preferred_provider(
              tenant_id=current_user.current_tenant_id,
              provider_name=provider_name,
              preferred_provider_type=parsed['preferred_provider_type'],
          )
          return {'result': 'success'}
  class ModelProviderModelParameterRuleApi(Resource):
      """Expose the parameter rules of a text-generation model."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider_name: str):
          """Return ``{param_name: {'enabled', 'min', 'max', 'default'}}`` for a model.

          The model is identified by the ``model_name`` query argument; the model
          type is fixed to ``text-generation``.

          Raises:
              ProviderNotInitializeError: if the service raises
                  ``LLMBadRequestError``; the original error is chained as the cause.
          """
          parser = reqparse.RequestParser()
          parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
          args = parser.parse_args()

          provider_service = ProviderService()
          try:
              parameter_rules = provider_service.get_model_parameter_rules(
                  tenant_id=current_user.current_tenant_id,
                  model_provider_name=provider_name,
                  model_name=args['model_name'],
                  model_type='text-generation',
              )
          except LLMBadRequestError as ex:
              raise ProviderNotInitializeError(
                  "Current Text Generation Model is invalid. Please switch to the available model.") from ex

          # Every instance attribute of the rules object is serialized. This
          # presumably means each attribute is a rule exposing enabled/min/max/default
          # — TODO confirm against the rules class.
          return {
              name: {
                  'enabled': rule.enabled,
                  'min': rule.min,
                  'max': rule.max,
                  'default': rule.default,
              }
              for name, rule in vars(parameter_rules).items()
          }
  class ModelProviderPaymentCheckoutUrlApi(Resource):
      """Create a payment checkout for a provider and hand back its URL."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider_name: str):
          """Return ``{'url': <checkout url>}`` for ``provider_name`` in the current tenant."""
          checkout_service = ProviderCheckoutService()
          checkout = checkout_service.create_checkout(
              tenant_id=current_user.current_tenant_id,
              provider_name=provider_name,
              account=current_user,
          )
          checkout_url = checkout.get_checkout_url()
          return {'url': checkout_url}
  class ModelProviderFreeQuotaSubmitApi(Resource):
      """Submit a free-quota request for a provider."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider_name: str):
          """Return the result of ``ProviderService.free_quota_submit`` for the current tenant."""
          return ProviderService().free_quota_submit(
              tenant_id=current_user.current_tenant_id,
              provider_name=provider_name,
          )
  # (resource class, URL rule) pairs, registered with the console API in this order.
  _MODEL_PROVIDER_ROUTES = (
      (ModelProviderListApi, '/workspaces/current/model-providers'),
      (ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate'),
      (ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>'),
      (ModelProviderModelValidateApi,
       '/workspaces/current/model-providers/<string:provider_name>/models/validate'),
      (ModelProviderModelUpdateApi,
       '/workspaces/current/model-providers/<string:provider_name>/models'),
      (PreferredProviderTypeUpdateApi,
       '/workspaces/current/model-providers/<string:provider_name>/preferred-provider-type'),
      (ModelProviderModelParameterRuleApi,
       '/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules'),
      (ModelProviderPaymentCheckoutUrlApi,
       '/workspaces/current/model-providers/<string:provider_name>/checkout-url'),
      (ModelProviderFreeQuotaSubmitApi,
       '/workspaces/current/model-providers/<string:provider_name>/free-quota-submit'),
  )

  for _resource, _url in _MODEL_PROVIDER_ROUTES:
      api.add_resource(_resource, _url)