generator.py 3.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. import os
  2. from flask_login import current_user # type: ignore
  3. from flask_restful import Resource, reqparse # type: ignore
  4. from controllers.console import api
  5. from controllers.console.app.error import (
  6. CompletionRequestError,
  7. ProviderModelCurrentlyNotSupportError,
  8. ProviderNotInitializeError,
  9. ProviderQuotaExceededError,
  10. )
  11. from controllers.console.wraps import account_initialization_required, setup_required
  12. from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
  13. from core.llm_generator.llm_generator import LLMGenerator
  14. from core.model_runtime.errors.invoke import InvokeError
  15. from libs.login import login_required
  16. class RuleGenerateApi(Resource):
  17. @setup_required
  18. @login_required
  19. @account_initialization_required
  20. def post(self):
  21. parser = reqparse.RequestParser()
  22. parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
  23. parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
  24. parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
  25. args = parser.parse_args()
  26. account = current_user
  27. PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512"))
  28. try:
  29. rules = LLMGenerator.generate_rule_config(
  30. tenant_id=account.current_tenant_id,
  31. instruction=args["instruction"],
  32. model_config=args["model_config"],
  33. no_variable=args["no_variable"],
  34. rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS,
  35. )
  36. except ProviderTokenNotInitError as ex:
  37. raise ProviderNotInitializeError(ex.description)
  38. except QuotaExceededError:
  39. raise ProviderQuotaExceededError()
  40. except ModelCurrentlyNotSupportError:
  41. raise ProviderModelCurrentlyNotSupportError()
  42. except InvokeError as e:
  43. raise CompletionRequestError(e.description)
  44. return rules
  45. class RuleCodeGenerateApi(Resource):
  46. @setup_required
  47. @login_required
  48. @account_initialization_required
  49. def post(self):
  50. parser = reqparse.RequestParser()
  51. parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
  52. parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
  53. parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
  54. parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
  55. args = parser.parse_args()
  56. account = current_user
  57. CODE_GENERATION_MAX_TOKENS = int(os.getenv("CODE_GENERATION_MAX_TOKENS", "1024"))
  58. try:
  59. code_result = LLMGenerator.generate_code(
  60. tenant_id=account.current_tenant_id,
  61. instruction=args["instruction"],
  62. model_config=args["model_config"],
  63. code_language=args["code_language"],
  64. max_tokens=CODE_GENERATION_MAX_TOKENS,
  65. )
  66. except ProviderTokenNotInitError as ex:
  67. raise ProviderNotInitializeError(ex.description)
  68. except QuotaExceededError:
  69. raise ProviderQuotaExceededError()
  70. except ModelCurrentlyNotSupportError:
  71. raise ProviderModelCurrentlyNotSupportError()
  72. except InvokeError as e:
  73. raise CompletionRequestError(e.description)
  74. return code_result
# Register both generator resources on the console ``api`` object.
# The paths are relative to wherever ``api`` is mounted (see controllers.console).
api.add_resource(RuleGenerateApi, "/rule-generate")
api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")