provider.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. from pydantic import Field
  2. from core.entities.provider_entities import ProviderConfig
  3. from core.tools.__base.tool_provider import ToolProviderController
  4. from core.tools.__base.tool_runtime import ToolRuntime
  5. from core.tools.custom_tool.tool import ApiTool
  6. from core.tools.entities.common_entities import I18nObject
  7. from core.tools.entities.tool_bundle import ApiToolBundle
  8. from core.tools.entities.tool_entities import (
  9. ApiProviderAuthType,
  10. ToolDescription,
  11. ToolEntity,
  12. ToolIdentity,
  13. ToolProviderEntity,
  14. ToolProviderIdentity,
  15. ToolProviderType,
  16. )
  17. from extensions.ext_database import db
  18. from models.tools import ApiToolProvider
  19. class ApiToolProviderController(ToolProviderController):
  20. provider_id: str
  21. tenant_id: str
  22. tools: list[ApiTool] = Field(default_factory=list)
  23. def __init__(self, entity: ToolProviderEntity, provider_id: str, tenant_id: str) -> None:
  24. super().__init__(entity)
  25. self.provider_id = provider_id
  26. self.tenant_id = tenant_id
  27. self.tools = []
  28. @classmethod
  29. def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
  30. credentials_schema = [
  31. ProviderConfig(
  32. name="auth_type",
  33. required=True,
  34. type=ProviderConfig.Type.SELECT,
  35. options=[
  36. ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
  37. ProviderConfig.Option(value="api_key", label=I18nObject(en_US="api_key", zh_Hans="api_key")),
  38. ],
  39. default="none",
  40. help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"),
  41. )
  42. ]
  43. if auth_type == ApiProviderAuthType.API_KEY:
  44. credentials_schema = [
  45. *credentials_schema,
  46. ProviderConfig(
  47. name="api_key_header",
  48. required=False,
  49. default="api_key",
  50. type=ProviderConfig.Type.TEXT_INPUT,
  51. help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"),
  52. ),
  53. ProviderConfig(
  54. name="api_key_value",
  55. required=True,
  56. type=ProviderConfig.Type.SECRET_INPUT,
  57. help=I18nObject(en_US="The api key", zh_Hans="api key的值"),
  58. ),
  59. ProviderConfig(
  60. name="api_key_header_prefix",
  61. required=False,
  62. default="basic",
  63. type=ProviderConfig.Type.SELECT,
  64. help=I18nObject(en_US="The prefix of the api key header", zh_Hans="api key header 的前缀"),
  65. options=[
  66. ProviderConfig.Option(value="basic", label=I18nObject(en_US="Basic", zh_Hans="Basic")),
  67. ProviderConfig.Option(value="bearer", label=I18nObject(en_US="Bearer", zh_Hans="Bearer")),
  68. ProviderConfig.Option(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")),
  69. ],
  70. ),
  71. ]
  72. elif auth_type == ApiProviderAuthType.NONE:
  73. pass
  74. user = db_provider.user
  75. user_name = user.name if user else ""
  76. return ApiToolProviderController(
  77. entity=ToolProviderEntity(
  78. identity=ToolProviderIdentity(
  79. author=user_name,
  80. name=db_provider.name,
  81. label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
  82. description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
  83. icon=db_provider.icon,
  84. ),
  85. credentials_schema=credentials_schema,
  86. plugin_id=None,
  87. ),
  88. provider_id=db_provider.id or "",
  89. tenant_id=db_provider.tenant_id or "",
  90. )
  91. @property
  92. def provider_type(self) -> ToolProviderType:
  93. return ToolProviderType.API
  94. def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool:
  95. """
  96. parse tool bundle to tool
  97. :param tool_bundle: the tool bundle
  98. :return: the tool
  99. """
  100. return ApiTool(
  101. api_bundle=tool_bundle,
  102. provider_id=self.provider_id,
  103. entity=ToolEntity(
  104. identity=ToolIdentity(
  105. author=tool_bundle.author,
  106. name=tool_bundle.operation_id or "default_tool",
  107. label=I18nObject(
  108. en_US=tool_bundle.operation_id or "default_tool",
  109. zh_Hans=tool_bundle.operation_id or "default_tool",
  110. ),
  111. icon=self.entity.identity.icon,
  112. provider=self.provider_id,
  113. ),
  114. description=ToolDescription(
  115. human=I18nObject(en_US=tool_bundle.summary or "", zh_Hans=tool_bundle.summary or ""),
  116. llm=tool_bundle.summary or "",
  117. ),
  118. parameters=tool_bundle.parameters or [],
  119. ),
  120. runtime=ToolRuntime(tenant_id=self.tenant_id),
  121. )
  122. def load_bundled_tools(self, tools: list[ApiToolBundle]):
  123. """
  124. load bundled tools
  125. :param tools: the bundled tools
  126. :return: the tools
  127. """
  128. self.tools = [self._parse_tool_bundle(tool) for tool in tools]
  129. return self.tools
  130. def get_tools(self, tenant_id: str) -> list[ApiTool]:
  131. """
  132. fetch tools from database
  133. :param user_id: the user id
  134. :param tenant_id: the tenant id
  135. :return: the tools
  136. """
  137. if len(self.tools) > 0:
  138. return self.tools
  139. tools: list[ApiTool] = []
  140. # get tenant api providers
  141. db_providers: list[ApiToolProvider] = (
  142. db.session.query(ApiToolProvider)
  143. .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
  144. .all()
  145. )
  146. if db_providers and len(db_providers) != 0:
  147. for db_provider in db_providers:
  148. for tool in db_provider.tools:
  149. assistant_tool = self._parse_tool_bundle(tool)
  150. tools.append(assistant_tool)
  151. self.tools = tools
  152. return tools
  153. def get_tool(self, tool_name: str):
  154. """
  155. get tool by name
  156. :param tool_name: the name of the tool
  157. :return: the tool
  158. """
  159. if self.tools is None:
  160. self.get_tools(self.tenant_id)
  161. for tool in self.tools:
  162. if tool.entity.identity.name == tool_name:
  163. return tool
  164. raise ValueError(f"tool {tool_name} not found")