
security/SSRF vulns (#6682)

Yeuoly, 9 months ago
parent
commit 79cb23e8ac
3 changed files with 13 additions and 28 deletions
  1. api/core/helper/ssrf_proxy.py (+5 -2)
  2. api/core/rag/extractor/extract_processor.py (+2 -3)
  3. api/core/tools/utils/web_reader_tool.py (+6 -23)

+ 5 - 2
api/core/helper/ssrf_proxy.py

@@ -17,12 +17,15 @@ proxies = {
     'https://': SSRF_PROXY_HTTPS_URL
 } if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None
 
-
 BACKOFF_FACTOR = 0.5
 STATUS_FORCELIST = [429, 500, 502, 503, 504]
 
-
 def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+    if "allow_redirects" in kwargs:
+        allow_redirects = kwargs.pop("allow_redirects")
+        if "follow_redirects" not in kwargs:
+            kwargs["follow_redirects"] = allow_redirects
+    
     retries = 0
     while retries <= max_retries:
         try:

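The hunk above adds a small compatibility shim: callers written against the requests API pass allow_redirects, while the helper is assumed to be backed by httpx, which names the same option follow_redirects. The shim translates one to the other so existing call sites (and libraries that call through this helper) keep working. A minimal sketch of how the full helper plausibly fits together; the retry/backoff details and proxy values are assumptions for illustration, not the verbatim module:

```python
import time
import httpx

# Placeholder proxy mapping for the sketch; the real module builds this from
# SSRF_PROXY_HTTP_URL / SSRF_PROXY_HTTPS_URL (visible in the diff context).
proxies = {
    "http://": "http://ssrf-proxy:3128",
    "https://": "http://ssrf-proxy:3128",
}
SSRF_DEFAULT_MAX_RETRIES = 3  # assumed default
BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]

def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    # requests-style callers pass allow_redirects; httpx expects follow_redirects.
    if "allow_redirects" in kwargs:
        allow_redirects = kwargs.pop("allow_redirects")
        if "follow_redirects" not in kwargs:
            kwargs["follow_redirects"] = allow_redirects

    retries = 0
    while retries <= max_retries:
        try:
            # Route every outbound request through the egress proxy so the
            # proxy, not the app container, decides which hosts are reachable.
            # (httpx <1.0 style "proxies" argument.)
            response = httpx.request(method=method, url=url, proxies=proxies, **kwargs)
            if response.status_code not in STATUS_FORCELIST:
                return response
        except httpx.RequestError:
            pass  # transient network error: retry below
        retries += 1
        if retries <= max_retries:
            time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
    raise Exception(f"Reached maximum retries ({max_retries}) for URL {url}")
```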
+ 2 - 3
api/core/rag/extractor/extract_processor.py

@@ -4,9 +4,8 @@ from pathlib import Path
 from typing import Union
 from urllib.parse import unquote
 
-import requests
-
 from configs import dify_config
+from core.helper import ssrf_proxy
 from core.rag.extractor.csv_extractor import CSVExtractor
 from core.rag.extractor.entity.datasource_type import DatasourceType
 from core.rag.extractor.entity.extract_setting import ExtractSetting
@@ -51,7 +50,7 @@ class ExtractProcessor:
 
     @classmethod
     def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
-        response = requests.get(url, headers={
+        response = ssrf_proxy.get(url, headers={
             "User-Agent": USER_AGENT
         })
 

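extract_processor.py previously fetched remote documents with requests directly, bypassing the egress proxy; load_from_url now goes through the shared ssrf_proxy helper instead. A sketch of how the helper's convenience wrappers are assumed to delegate to make_request (function bodies are an assumption, not the verbatim helper), followed by the resulting call site:

```python
# Assumed shape of the GET/HEAD wrappers in core/helper/ssrf_proxy.py:
# thin shims over make_request so all fetches share the same proxy,
# retry and redirect handling.
from core.helper.ssrf_proxy import make_request, SSRF_DEFAULT_MAX_RETRIES

def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    return make_request("GET", url, max_retries=max_retries, **kwargs)

def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    return make_request("HEAD", url, max_retries=max_retries, **kwargs)

# The call site in ExtractProcessor.load_from_url is otherwise unchanged:
# response = ssrf_proxy.get(url, headers={"User-Agent": USER_AGENT})
```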
+ 6 - 23
api/core/tools/utils/web_reader_tool.py

@@ -11,11 +11,10 @@ from contextlib import contextmanager
 from urllib.parse import unquote
 
 import cloudscraper
-import requests
 from bs4 import BeautifulSoup, CData, Comment, NavigableString
-from newspaper import Article
 from regex import regex
 
+from core.helper import ssrf_proxy
 from core.rag.extractor import extract_processor
 from core.rag.extractor.extract_processor import ExtractProcessor
 
@@ -45,7 +44,7 @@ def get_url(url: str, user_agent: str = None) -> str:
 
     main_content_type = None
     supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
-    response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))
+    response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10))
 
     if response.status_code == 200:
         # check content-type
@@ -67,10 +66,11 @@ def get_url(url: str, user_agent: str = None) -> str:
         if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
             return ExtractProcessor.load_from_url(url, return_text=True)
 
-        response = requests.get(url, headers=headers, allow_redirects=True, timeout=(120, 300))
+        response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
     elif response.status_code == 403:
         scraper = cloudscraper.create_scraper()
-        response = scraper.get(url, headers=headers, allow_redirects=True, timeout=(120, 300))
+        scraper.perform_request = ssrf_proxy.make_request
+        response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
 
     if response.status_code != 200:
         return "URL returned status code {}.".format(response.status_code)
@@ -78,7 +78,7 @@ def get_url(url: str, user_agent: str = None) -> str:
     a = extract_using_readabilipy(response.text)
 
     if not a['plain_text'] or not a['plain_text'].strip():
-        return get_url_from_newspaper3k(url)
+        return ''
 
     res = FULL_TEMPLATE.format(
         title=a['title'],
@@ -91,23 +91,6 @@ def get_url(url: str, user_agent: str = None) -> str:
     return res
 
 
-def get_url_from_newspaper3k(url: str) -> str:
-
-    a = Article(url)
-    a.download()
-    a.parse()
-
-    res = FULL_TEMPLATE.format(
-        title=a.title,
-        authors=a.authors,
-        publish_date=a.publish_date,
-        top_image=a.top_image,
-        text=a.text,
-    )
-
-    return res
-
-
 def extract_using_readabilipy(html):
     with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html:
         f_html.write(html)
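The web reader tool loses two unproxied code paths. The 403 fallback still uses cloudscraper for Cloudflare challenges, but its low-level perform_request hook is pointed at ssrf_proxy.make_request so the bypass request also flows through the SSRF-safe helper; the newspaper3k fallback, a second direct fetch that skipped the proxy entirely, is removed in favour of returning an empty string. A hedged sketch of the patched fallback in isolation, assuming cloudscraper funnels its outbound calls through perform_request (which is what the change relies on); fetch_behind_cloudflare is a hypothetical name, the real logic lives inline in get_url():

```python
import cloudscraper

from core.helper import ssrf_proxy

def fetch_behind_cloudflare(url: str, headers: dict):
    scraper = cloudscraper.create_scraper()
    # Monkey-patch the instance so scraper.get(...) ultimately calls
    # ssrf_proxy.make_request(method, url, **kwargs) instead of issuing a
    # direct outbound request from the app container.
    scraper.perform_request = ssrf_proxy.make_request
    return scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
```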