Browse Source

improve the code readability of http_executor node (#4360)

非法操作 11 months ago
parent
commit
3271e3e803
1 changed file with 63 additions and 123 deletions
  1. 63 123
      api/core/workflow/nodes/http_request/http_executor.py

+ 63 - 123
api/core/workflow/nodes/http_request/http_executor.py

@@ -14,28 +14,18 @@ from core.workflow.entities.variable_pool import ValueType, VariablePool
 from core.workflow.nodes.http_request.entities import HttpRequestNodeData
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 
-MAX_BINARY_SIZE = int(os.environ.get('HTTP_REQUEST_NODE_MAX_BINARY_SIZE', str(1024 * 1024 * 10))) # 10MB
+MAX_BINARY_SIZE = int(os.environ.get('HTTP_REQUEST_NODE_MAX_BINARY_SIZE', 1024 * 1024 * 10))  # 10MB
 READABLE_MAX_BINARY_SIZE = f'{MAX_BINARY_SIZE / 1024 / 1024:.2f}MB'
-MAX_TEXT_SIZE = int(os.environ.get('HTTP_REQUEST_NODE_MAX_TEXT_SIZE', str(1024 * 1024))) # 10MB # 1MB
+MAX_TEXT_SIZE = int(os.environ.get('HTTP_REQUEST_NODE_MAX_TEXT_SIZE', 1024 * 1024))  # 1MB
 READABLE_MAX_TEXT_SIZE = f'{MAX_TEXT_SIZE / 1024 / 1024:.2f}MB'
 
+
 class HttpExecutorResponse:
     headers: dict[str, str]
     response: Union[httpx.Response, requests.Response]
 
     def __init__(self, response: Union[httpx.Response, requests.Response] = None):
-        """
-        init
-        """
-        headers = {}
-        if isinstance(response, httpx.Response):
-            for k, v in response.headers.items():
-                headers[k] = v
-        elif isinstance(response, requests.Response):
-            for k, v in response.headers.items():
-                headers[k] = v
-
-        self.headers = headers
+        self.headers = response.headers
         self.response = response
 
     @property
@@ -45,21 +35,11 @@ class HttpExecutorResponse:
         """
         content_type = self.get_content_type()
         file_content_types = ['image', 'audio', 'video']
-        for v in file_content_types:
-            if v in content_type:
-                return True
-        
-        return False
+
+        return any(v in content_type for v in file_content_types)
 
     def get_content_type(self) -> str:
-        """
-        get content type
-        """
-        for key, val in self.headers.items():
-            if key.lower() == 'content-type':
-                return val
-        
-        return ''
+        return self.headers.get('content-type')
 
     def extract_file(self) -> tuple[str, bytes]:
         """
@@ -67,29 +47,25 @@ class HttpExecutorResponse:
         """
         if self.is_file:
             return self.get_content_type(), self.body
-            
+
         return '', b''
-    
+
     @property
     def content(self) -> str:
         """
         get content
         """
-        if isinstance(self.response, httpx.Response):
-            return self.response.text
-        elif isinstance(self.response, requests.Response):
+        if isinstance(self.response, httpx.Response | requests.Response):
             return self.response.text
         else:
             raise ValueError(f'Invalid response type {type(self.response)}')
-    
+
     @property
     def body(self) -> bytes:
         """
         get body
         """
-        if isinstance(self.response, httpx.Response):
-            return self.response.content
-        elif isinstance(self.response, requests.Response):
+        if isinstance(self.response, httpx.Response | requests.Response):
             return self.response.content
         else:
             raise ValueError(f'Invalid response type {type(self.response)}')
@@ -99,20 +75,18 @@ class HttpExecutorResponse:
         """
         get status code
         """
-        if isinstance(self.response, httpx.Response):
-            return self.response.status_code
-        elif isinstance(self.response, requests.Response):
+        if isinstance(self.response, httpx.Response | requests.Response):
             return self.response.status_code
         else:
             raise ValueError(f'Invalid response type {type(self.response)}')
-        
+
     @property
     def size(self) -> int:
         """
         get size
         """
         return len(self.body)
-    
+
     @property
     def readable_size(self) -> str:
         """
@@ -138,10 +112,8 @@ class HttpExecutor:
     variable_selectors: list[VariableSelector]
     timeout: HttpRequestNodeData.Timeout
 
-    def __init__(self, node_data: HttpRequestNodeData, timeout: HttpRequestNodeData.Timeout, variable_pool: Optional[VariablePool] = None):
-        """
-        init
-        """
+    def __init__(self, node_data: HttpRequestNodeData, timeout: HttpRequestNodeData.Timeout,
+                 variable_pool: Optional[VariablePool] = None):
         self.server_url = node_data.url
         self.method = node_data.method
         self.authorization = node_data.authorization
@@ -155,7 +127,8 @@ class HttpExecutor:
         self.variable_selectors = []
         self._init_template(node_data, variable_pool)
 
-    def _is_json_body(self, body: HttpRequestNodeData.Body):
+    @staticmethod
+    def _is_json_body(body: HttpRequestNodeData.Body):
         """
         check if body is json
         """
@@ -165,55 +138,46 @@ class HttpExecutor:
                 return True
             except:
                 return False
-        
+
         return False
 
-    def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None):
+    @staticmethod
+    def _to_dict(convert_item: str, convert_text: str, maxsplit: int = -1):
         """
-        init template
+        Convert the string like `aa:bb\n cc:dd` to dict `{aa:bb, cc:dd}`
+        :param convert_item: A label for what item to be converted, params, headers or body.
+        :param convert_text: The string containing key-value pairs separated by '\n'.
+        :param maxsplit: The maximum number of splits allowed for the ':' character in each key-value pair. Default is -1 (no limit).
+        :return: A dictionary containing the key-value pairs from the input string.
         """
-        variable_selectors = []
-
-        # extract all template in url
-        self.server_url, server_url_variable_selectors = self._format_template(node_data.url, variable_pool)
-
-        # extract all template in params
-        params, params_variable_selectors = self._format_template(node_data.params, variable_pool)
-
-        # fill in params
-        kv_paris = params.split('\n')
+        kv_paris = convert_text.split('\n')
+        result = {}
         for kv in kv_paris:
             if not kv.strip():
                 continue
 
-            kv = kv.split(':')
+            kv = kv.split(':', maxsplit=maxsplit)
             if len(kv) == 2:
                 k, v = kv
             elif len(kv) == 1:
                 k, v = kv[0], ''
             else:
-                raise ValueError(f'Invalid params {kv}')
-            
-            self.params[k.strip()] = v
+                raise ValueError(f'Invalid {convert_item} {kv}')
+            result[k.strip()] = v
+        return result
 
-        # extract all template in headers
-        headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool)
+    def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None):
 
-        # fill in headers
-        kv_paris = headers.split('\n')
-        for kv in kv_paris:
-            if not kv.strip():
-                continue
+        # extract all template in url
+        self.server_url, server_url_variable_selectors = self._format_template(node_data.url, variable_pool)
 
-            kv = kv.split(':')
-            if len(kv) == 2:
-                k, v = kv
-            elif len(kv) == 1:
-                k, v = kv[0], ''
-            else:
-                raise ValueError(f'Invalid headers {kv}')
-            
-            self.headers[k.strip()] = v.strip()
+        # extract all template in params
+        params, params_variable_selectors = self._format_template(node_data.params, variable_pool)
+        self.params = self._to_dict("params", params)
+
+        # extract all template in headers
+        headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool)
+        self.headers = self._to_dict("headers", headers)
 
         # extract all template in body
         body_data_variable_selectors = []
@@ -231,18 +195,7 @@ class HttpExecutor:
                 self.headers['Content-Type'] = 'application/x-www-form-urlencoded'
 
             if node_data.body.type in ['form-data', 'x-www-form-urlencoded']:
-                body = {}
-                kv_paris = body_data.split('\n')
-                for kv in kv_paris:
-                    if not kv.strip():
-                        continue
-                    kv = kv.split(':', 1)
-                    if len(kv) == 2:
-                        body[kv[0].strip()] = kv[1]
-                    elif len(kv) == 1:
-                        body[kv[0].strip()] = ''
-                    else:
-                        raise ValueError(f'Invalid body {kv}')
+                body = self._to_dict("body", body_data, 1)
 
                 if node_data.body.type == 'form-data':
                     self.files = {
@@ -261,14 +214,14 @@ class HttpExecutor:
 
         self.variable_selectors = (server_url_variable_selectors + params_variable_selectors
                                    + headers_variable_selectors + body_data_variable_selectors)
-                
+
     def _assembling_headers(self) -> dict[str, Any]:
         authorization = deepcopy(self.authorization)
         headers = deepcopy(self.headers) or {}
         if self.authorization.type == 'api-key':
             if self.authorization.config.api_key is None:
                 raise ValueError('api_key is required')
-            
+
             if not self.authorization.config.header:
                 authorization.config.header = 'Authorization'
 
@@ -278,9 +231,9 @@ class HttpExecutor:
                 headers[authorization.config.header] = f'Basic {authorization.config.api_key}'
             elif self.authorization.config.type == 'custom':
                 headers[authorization.config.header] = authorization.config.api_key
-        
+
         return headers
-    
+
     def _validate_and_parse_response(self, response: Union[httpx.Response, requests.Response]) -> HttpExecutorResponse:
         """
             validate the response
@@ -289,21 +242,22 @@ class HttpExecutor:
             executor_response = HttpExecutorResponse(response)
         else:
             raise ValueError(f'Invalid response type {type(response)}')
-        
+
         if executor_response.is_file:
             if executor_response.size > MAX_BINARY_SIZE:
-                raise ValueError(f'File size is too large, max size is {READABLE_MAX_BINARY_SIZE}, but current size is {executor_response.readable_size}.')
+                raise ValueError(
+                    f'File size is too large, max size is {READABLE_MAX_BINARY_SIZE}, but current size is {executor_response.readable_size}.')
         else:
             if executor_response.size > MAX_TEXT_SIZE:
-                raise ValueError(f'Text size is too large, max size is {READABLE_MAX_TEXT_SIZE}, but current size is {executor_response.readable_size}.')
-        
+                raise ValueError(
+                    f'Text size is too large, max size is {READABLE_MAX_TEXT_SIZE}, but current size is {executor_response.readable_size}.')
+
         return executor_response
-        
+
     def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response:
         """
             do http request depending on api bundle
         """
-        # do http request
         kwargs = {
             'url': self.server_url,
             'headers': headers,
@@ -312,25 +266,14 @@ class HttpExecutor:
             'follow_redirects': True
         }
 
-        if self.method == 'get':
-            response = ssrf_proxy.get(**kwargs)
-        elif self.method == 'post':
-            response = ssrf_proxy.post(data=self.body, files=self.files, **kwargs)
-        elif self.method == 'put':
-            response = ssrf_proxy.put(data=self.body, files=self.files, **kwargs)
-        elif self.method == 'delete':
-            response = ssrf_proxy.delete(data=self.body, files=self.files, **kwargs)
-        elif self.method == 'patch':
-            response = ssrf_proxy.patch(data=self.body, files=self.files, **kwargs)
-        elif self.method == 'head':
-            response = ssrf_proxy.head(**kwargs)
-        elif self.method == 'options':
-            response = ssrf_proxy.options(**kwargs)
+        if self.method in ('get', 'head', 'options'):
+            response = getattr(ssrf_proxy, self.method)(**kwargs)
+        elif self.method in ('post', 'put', 'delete', 'patch'):
+            response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs)
         else:
             raise ValueError(f'Invalid http method {self.method}')
-        
         return response
-    
+
     def invoke(self) -> HttpExecutorResponse:
         """
         invoke http request
@@ -343,14 +286,11 @@ class HttpExecutor:
 
         # validate response
         return self._validate_and_parse_response(response)
-    
+
     def to_raw_request(self, mask_authorization_header: Optional[bool] = True) -> str:
         """
         convert to raw request
         """
-        if mask_authorization_header == None:
-            mask_authorization_header = True
-            
         server_url = self.server_url
         if self.params:
             server_url += f'?{urlencode(self.params)}'
@@ -365,11 +305,11 @@ class HttpExecutor:
                     authorization_header = 'Authorization'
                     if self.authorization.config and self.authorization.config.header:
                         authorization_header = self.authorization.config.header
-                    
+
                     if k.lower() == authorization_header.lower():
                         raw_request += f'{k}: {"*" * len(v)}\n'
                         continue
-            
+
             raw_request += f'{k}: {v}\n'
 
         raw_request += '\n'