api_tool.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. import json
  2. from os import getenv
  3. from typing import Any
  4. from urllib.parse import urlencode
  5. import httpx
  6. from core.file.file_manager import download
  7. from core.helper import ssrf_proxy
  8. from core.tools.entities.tool_bundle import ApiToolBundle
  9. from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
  10. from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
  11. from core.tools.tool.tool import Tool
  12. API_TOOL_DEFAULT_TIMEOUT = (
  13. int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")),
  14. int(getenv("API_TOOL_DEFAULT_READ_TIMEOUT", "60")),
  15. )
  16. class ApiTool(Tool):
  17. api_bundle: ApiToolBundle
  18. """
  19. Api tool
  20. """
  21. def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
  22. """
  23. fork a new tool with meta data
  24. :param meta: the meta data of a tool call processing, tenant_id is required
  25. :return: the new tool
  26. """
  27. return self.__class__(
  28. identity=self.identity.model_copy() if self.identity else None,
  29. parameters=self.parameters.copy() if self.parameters else None,
  30. description=self.description.model_copy() if self.description else None,
  31. api_bundle=self.api_bundle.model_copy() if self.api_bundle else None,
  32. runtime=Tool.Runtime(**runtime),
  33. )
  34. def validate_credentials(
  35. self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
  36. ) -> str:
  37. """
  38. validate the credentials for Api tool
  39. """
  40. # assemble validate request and request parameters
  41. headers = self.assembling_request(parameters)
  42. if format_only:
  43. return ""
  44. response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
  45. # validate response
  46. return self.validate_and_parse_response(response)
  47. def tool_provider_type(self) -> ToolProviderType:
  48. return ToolProviderType.API
  49. def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
  50. headers = {}
  51. credentials = self.runtime.credentials or {}
  52. if "auth_type" not in credentials:
  53. raise ToolProviderCredentialValidationError("Missing auth_type")
  54. if credentials["auth_type"] == "api_key":
  55. api_key_header = "api_key"
  56. if "api_key_header" in credentials:
  57. api_key_header = credentials["api_key_header"]
  58. if "api_key_value" not in credentials:
  59. raise ToolProviderCredentialValidationError("Missing api_key_value")
  60. elif not isinstance(credentials["api_key_value"], str):
  61. raise ToolProviderCredentialValidationError("api_key_value must be a string")
  62. if "api_key_header_prefix" in credentials:
  63. api_key_header_prefix = credentials["api_key_header_prefix"]
  64. if api_key_header_prefix == "basic" and credentials["api_key_value"]:
  65. credentials["api_key_value"] = f'Basic {credentials["api_key_value"]}'
  66. elif api_key_header_prefix == "bearer" and credentials["api_key_value"]:
  67. credentials["api_key_value"] = f'Bearer {credentials["api_key_value"]}'
  68. elif api_key_header_prefix == "custom":
  69. pass
  70. headers[api_key_header] = credentials["api_key_value"]
  71. needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required]
  72. for parameter in needed_parameters:
  73. if parameter.required and parameter.name not in parameters:
  74. raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
  75. if parameter.default is not None and parameter.name not in parameters:
  76. parameters[parameter.name] = parameter.default
  77. return headers
  78. def validate_and_parse_response(self, response: httpx.Response) -> str:
  79. """
  80. validate the response
  81. """
  82. if isinstance(response, httpx.Response):
  83. if response.status_code >= 400:
  84. raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
  85. if not response.content:
  86. return "Empty response from the tool, please check your parameters and try again."
  87. try:
  88. response = response.json()
  89. try:
  90. return json.dumps(response, ensure_ascii=False)
  91. except Exception as e:
  92. return json.dumps(response)
  93. except Exception as e:
  94. return response.text
  95. else:
  96. raise ValueError(f"Invalid response type {type(response)}")
  97. @staticmethod
  98. def get_parameter_value(parameter, parameters):
  99. if parameter["name"] in parameters:
  100. return parameters[parameter["name"]]
  101. elif parameter.get("required", False):
  102. raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
  103. else:
  104. return (parameter.get("schema", {}) or {}).get("default", "")
  105. def do_http_request(
  106. self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
  107. ) -> httpx.Response:
  108. """
  109. do http request depending on api bundle
  110. """
  111. method = method.lower()
  112. params = {}
  113. path_params = {}
  114. body = {}
  115. cookies = {}
  116. files = []
  117. # check parameters
  118. for parameter in self.api_bundle.openapi.get("parameters", []):
  119. value = self.get_parameter_value(parameter, parameters)
  120. if parameter["in"] == "path":
  121. path_params[parameter["name"]] = value
  122. elif parameter["in"] == "query":
  123. if value != "":
  124. params[parameter["name"]] = value
  125. elif parameter["in"] == "cookie":
  126. cookies[parameter["name"]] = value
  127. elif parameter["in"] == "header":
  128. headers[parameter["name"]] = value
  129. # check if there is a request body and handle it
  130. if "requestBody" in self.api_bundle.openapi and self.api_bundle.openapi["requestBody"] is not None:
  131. # handle json request body
  132. if "content" in self.api_bundle.openapi["requestBody"]:
  133. for content_type in self.api_bundle.openapi["requestBody"]["content"]:
  134. headers["Content-Type"] = content_type
  135. body_schema = self.api_bundle.openapi["requestBody"]["content"][content_type]["schema"]
  136. required = body_schema.get("required", [])
  137. properties = body_schema.get("properties", {})
  138. for name, property in properties.items():
  139. if name in parameters:
  140. if property.get("format") == "binary":
  141. f = parameters[name]
  142. files.append((name, (f.filename, download(f), f.mime_type)))
  143. else:
  144. # convert type
  145. body[name] = self._convert_body_property_type(property, parameters[name])
  146. elif name in required:
  147. raise ToolParameterValidationError(
  148. f"Missing required parameter {name} in operation {self.api_bundle.operation_id}"
  149. )
  150. elif "default" in property:
  151. body[name] = property["default"]
  152. else:
  153. body[name] = None
  154. break
  155. # replace path parameters
  156. for name, value in path_params.items():
  157. url = url.replace(f"{{{name}}}", f"{value}")
  158. # parse http body data if needed
  159. if "Content-Type" in headers:
  160. if headers["Content-Type"] == "application/json":
  161. body = json.dumps(body)
  162. elif headers["Content-Type"] == "application/x-www-form-urlencoded":
  163. body = urlencode(body)
  164. else:
  165. body = body
  166. if method in {"get", "head", "post", "put", "delete", "patch"}:
  167. response = getattr(ssrf_proxy, method)(
  168. url,
  169. params=params,
  170. headers=headers,
  171. cookies=cookies,
  172. data=body,
  173. files=files,
  174. timeout=API_TOOL_DEFAULT_TIMEOUT,
  175. follow_redirects=True,
  176. )
  177. return response
  178. else:
  179. raise ValueError(f"Invalid http method {self.method}")
  180. def _convert_body_property_any_of(
  181. self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
  182. ) -> Any:
  183. if max_recursive <= 0:
  184. raise Exception("Max recursion depth reached")
  185. for option in any_of or []:
  186. try:
  187. if "type" in option:
  188. # Attempt to convert the value based on the type.
  189. if option["type"] == "integer" or option["type"] == "int":
  190. return int(value)
  191. elif option["type"] == "number":
  192. if "." in str(value):
  193. return float(value)
  194. else:
  195. return int(value)
  196. elif option["type"] == "string":
  197. return str(value)
  198. elif option["type"] == "boolean":
  199. if str(value).lower() in {"true", "1"}:
  200. return True
  201. elif str(value).lower() in {"false", "0"}:
  202. return False
  203. else:
  204. continue # Not a boolean, try next option
  205. elif option["type"] == "null" and not value:
  206. return None
  207. else:
  208. continue # Unsupported type, try next option
  209. elif "anyOf" in option and isinstance(option["anyOf"], list):
  210. # Recursive call to handle nested anyOf
  211. return self._convert_body_property_any_of(property, value, option["anyOf"], max_recursive - 1)
  212. except ValueError:
  213. continue # Conversion failed, try next option
  214. # If no option succeeded, you might want to return the value as is or raise an error
  215. return value # or raise ValueError(f"Cannot convert value '{value}' to any specified type in anyOf")
  216. def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> Any:
  217. try:
  218. if "type" in property:
  219. if property["type"] == "integer" or property["type"] == "int":
  220. return int(value)
  221. elif property["type"] == "number":
  222. # check if it is a float
  223. if "." in str(value):
  224. return float(value)
  225. else:
  226. return int(value)
  227. elif property["type"] == "string":
  228. return str(value)
  229. elif property["type"] == "boolean":
  230. return bool(value)
  231. elif property["type"] == "null":
  232. if value is None:
  233. return None
  234. elif property["type"] == "object" or property["type"] == "array":
  235. if isinstance(value, str):
  236. try:
  237. # an array str like '[1,2]' also can convert to list [1,2] through json.loads
  238. # json not support single quote, but we can support it
  239. value = value.replace("'", '"')
  240. return json.loads(value)
  241. except ValueError:
  242. return value
  243. elif isinstance(value, dict):
  244. return value
  245. else:
  246. return value
  247. else:
  248. raise ValueError(f"Invalid type {property['type']} for property {property}")
  249. elif "anyOf" in property and isinstance(property["anyOf"], list):
  250. return self._convert_body_property_any_of(property, value, property["anyOf"])
  251. except ValueError as e:
  252. return value
  253. def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
  254. """
  255. invoke http request
  256. """
  257. # assemble request
  258. headers = self.assembling_request(tool_parameters)
  259. # do http request
  260. response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters)
  261. # validate response
  262. response = self.validate_and_parse_response(response)
  263. # assemble invoke message
  264. return self.create_text_message(response)