api_tool.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. import json
  2. from json import dumps
  3. from typing import Any, Union
  4. import httpx
  5. import requests
  6. import core.helper.ssrf_proxy as ssrf_proxy
  7. from core.tools.entities.tool_bundle import ApiBasedToolBundle
  8. from core.tools.entities.tool_entities import ToolInvokeMessage
  9. from core.tools.errors import ToolProviderCredentialValidationError
  10. from core.tools.tool.tool import Tool
  11. class ApiTool(Tool):
  12. api_bundle: ApiBasedToolBundle
  13. """
  14. Api tool
  15. """
  16. def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
  17. """
  18. fork a new tool with meta data
  19. :param meta: the meta data of a tool call processing, tenant_id is required
  20. :return: the new tool
  21. """
  22. return self.__class__(
  23. identity=self.identity.copy() if self.identity else None,
  24. parameters=self.parameters.copy() if self.parameters else None,
  25. description=self.description.copy() if self.description else None,
  26. api_bundle=self.api_bundle.copy() if self.api_bundle else None,
  27. runtime=Tool.Runtime(**meta)
  28. )
  29. def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> str:
  30. """
  31. validate the credentials for Api tool
  32. """
  33. # assemble validate request and request parameters
  34. headers = self.assembling_request(parameters)
  35. if format_only:
  36. return
  37. response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
  38. # validate response
  39. return self.validate_and_parse_response(response)
  40. def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
  41. headers = {}
  42. credentials = self.runtime.credentials or {}
  43. if 'auth_type' not in credentials:
  44. raise ToolProviderCredentialValidationError('Missing auth_type')
  45. if credentials['auth_type'] == 'api_key':
  46. api_key_header = 'api_key'
  47. if 'api_key_header' in credentials:
  48. api_key_header = credentials['api_key_header']
  49. if 'api_key_value' not in credentials:
  50. raise ToolProviderCredentialValidationError('Missing api_key_value')
  51. elif not isinstance(credentials['api_key_value'], str):
  52. raise ToolProviderCredentialValidationError('api_key_value must be a string')
  53. if 'api_key_header_prefix' in credentials:
  54. api_key_header_prefix = credentials['api_key_header_prefix']
  55. if api_key_header_prefix == 'basic':
  56. credentials['api_key_value'] = f'Basic {credentials["api_key_value"]}'
  57. elif api_key_header_prefix == 'bearer':
  58. credentials['api_key_value'] = f'Bearer {credentials["api_key_value"]}'
  59. elif api_key_header_prefix == 'custom':
  60. pass
  61. headers[api_key_header] = credentials['api_key_value']
  62. needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required]
  63. for parameter in needed_parameters:
  64. if parameter.required and parameter.name not in parameters:
  65. raise ToolProviderCredentialValidationError(f"Missing required parameter {parameter.name}")
  66. if parameter.default is not None and parameter.name not in parameters:
  67. parameters[parameter.name] = parameter.default
  68. return headers
  69. def validate_and_parse_response(self, response: Union[httpx.Response, requests.Response]) -> str:
  70. """
  71. validate the response
  72. """
  73. if isinstance(response, httpx.Response):
  74. if response.status_code >= 400:
  75. raise ToolProviderCredentialValidationError(f"Request failed with status code {response.status_code}")
  76. if not response.content:
  77. return 'Empty response from the tool, please check your parameters and try again.'
  78. try:
  79. response = response.json()
  80. try:
  81. return json.dumps(response, ensure_ascii=False)
  82. except Exception as e:
  83. return json.dumps(response)
  84. except Exception as e:
  85. return response.text
  86. elif isinstance(response, requests.Response):
  87. if not response.ok:
  88. raise ToolProviderCredentialValidationError(f"Request failed with status code {response.status_code}")
  89. if not response.content:
  90. return 'Empty response from the tool, please check your parameters and try again.'
  91. try:
  92. response = response.json()
  93. try:
  94. return json.dumps(response, ensure_ascii=False)
  95. except Exception as e:
  96. return json.dumps(response)
  97. except Exception as e:
  98. return response.text
  99. else:
  100. raise ValueError(f'Invalid response type {type(response)}')
  101. def do_http_request(self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]) -> httpx.Response:
  102. """
  103. do http request depending on api bundle
  104. """
  105. method = method.lower()
  106. params = {}
  107. path_params = {}
  108. body = {}
  109. cookies = {}
  110. # check parameters
  111. for parameter in self.api_bundle.openapi.get('parameters', []):
  112. if parameter['in'] == 'path':
  113. value = ''
  114. if parameter['name'] in parameters:
  115. value = parameters[parameter['name']]
  116. elif parameter['required']:
  117. raise ToolProviderCredentialValidationError(f"Missing required parameter {parameter['name']}")
  118. else:
  119. value = (parameter.get('schema', {}) or {}).get('default', '')
  120. path_params[parameter['name']] = value
  121. elif parameter['in'] == 'query':
  122. value = ''
  123. if parameter['name'] in parameters:
  124. value = parameters[parameter['name']]
  125. elif parameter['required']:
  126. raise ToolProviderCredentialValidationError(f"Missing required parameter {parameter['name']}")
  127. else:
  128. value = (parameter.get('schema', {}) or {}).get('default', '')
  129. params[parameter['name']] = value
  130. elif parameter['in'] == 'cookie':
  131. value = ''
  132. if parameter['name'] in parameters:
  133. value = parameters[parameter['name']]
  134. elif parameter['required']:
  135. raise ToolProviderCredentialValidationError(f"Missing required parameter {parameter['name']}")
  136. else:
  137. value = (parameter.get('schema', {}) or {}).get('default', '')
  138. cookies[parameter['name']] = value
  139. elif parameter['in'] == 'header':
  140. value = ''
  141. if parameter['name'] in parameters:
  142. value = parameters[parameter['name']]
  143. elif parameter['required']:
  144. raise ToolProviderCredentialValidationError(f"Missing required parameter {parameter['name']}")
  145. else:
  146. value = (parameter.get('schema', {}) or {}).get('default', '')
  147. headers[parameter['name']] = value
  148. # check if there is a request body and handle it
  149. if 'requestBody' in self.api_bundle.openapi and self.api_bundle.openapi['requestBody'] is not None:
  150. # handle json request body
  151. if 'content' in self.api_bundle.openapi['requestBody']:
  152. for content_type in self.api_bundle.openapi['requestBody']['content']:
  153. headers['Content-Type'] = content_type
  154. body_schema = self.api_bundle.openapi['requestBody']['content'][content_type]['schema']
  155. required = body_schema['required'] if 'required' in body_schema else []
  156. properties = body_schema['properties'] if 'properties' in body_schema else {}
  157. for name, property in properties.items():
  158. if name in parameters:
  159. # convert type
  160. try:
  161. value = parameters[name]
  162. if property['type'] == 'integer':
  163. value = int(value)
  164. elif property['type'] == 'number':
  165. # check if it is a float
  166. if '.' in value:
  167. value = float(value)
  168. else:
  169. value = int(value)
  170. elif property['type'] == 'boolean':
  171. value = bool(value)
  172. body[name] = value
  173. except ValueError as e:
  174. body[name] = parameters[name]
  175. elif name in required:
  176. raise ToolProviderCredentialValidationError(
  177. f"Missing required parameter {name} in operation {self.api_bundle.operation_id}"
  178. )
  179. elif 'default' in property:
  180. body[name] = property['default']
  181. else:
  182. body[name] = None
  183. break
  184. # replace path parameters
  185. for name, value in path_params.items():
  186. url = url.replace(f'{{{name}}}', f'{value}')
  187. # parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored
  188. if 'Content-Type' in headers:
  189. if headers['Content-Type'] == 'application/json':
  190. body = dumps(body)
  191. else:
  192. body = body
  193. # do http request
  194. if method == 'get':
  195. response = ssrf_proxy.get(url, params=params, headers=headers, cookies=cookies, timeout=10, follow_redirects=True)
  196. elif method == 'post':
  197. response = ssrf_proxy.post(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, follow_redirects=True)
  198. elif method == 'put':
  199. response = ssrf_proxy.put(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, follow_redirects=True)
  200. elif method == 'delete':
  201. """
  202. request body data is unsupported for DELETE method in standard http protocol
  203. however, OpenAPI 3.0 supports request body data for DELETE method, so we support it here by using requests
  204. """
  205. response = ssrf_proxy.delete(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, allow_redirects=True)
  206. elif method == 'patch':
  207. response = ssrf_proxy.patch(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, follow_redirects=True)
  208. elif method == 'head':
  209. response = ssrf_proxy.head(url, params=params, headers=headers, cookies=cookies, timeout=10, follow_redirects=True)
  210. elif method == 'options':
  211. response = ssrf_proxy.options(url, params=params, headers=headers, cookies=cookies, timeout=10, follow_redirects=True)
  212. else:
  213. raise ValueError(f'Invalid http method {method}')
  214. return response
  215. def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
  216. """
  217. invoke http request
  218. """
  219. # assemble request
  220. headers = self.assembling_request(tool_parameters)
  221. # do http request
  222. response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters)
  223. # validate response
  224. response = self.validate_and_parse_response(response)
  225. # assemble invoke message
  226. return self.create_text_message(response)