
Feat/firecrawl data source (#5232)

Co-authored-by: Nicolas <nicolascamara29@gmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: takatost <takatost@gmail.com>
Jyong 10 months ago
parent
commit
ba5f8afaa8
36 changed files with 1174 additions and 64 deletions
  1. api/.env.example (+2 -1)
  2. api/controllers/console/__init__.py (+2 -2)
  3. api/controllers/console/auth/data_source_bearer_auth.py (+67 -0)
  4. api/controllers/console/auth/error.py (+7 -0)
  5. api/controllers/console/datasets/data_source.py (+11 -11)
  6. api/controllers/console/datasets/datasets.py (+17 -0)
  7. api/controllers/console/datasets/datasets_document.py (+43 -0)
  8. api/controllers/console/datasets/error.py (+6 -0)
  9. api/controllers/console/datasets/website.py (+49 -0)
  10. api/core/indexing_runner.py (+18 -1)
  11. api/core/model_runtime/model_providers/togetherai/llm/llm.py (+1 -1)
  12. api/core/rag/extractor/entity/datasource_type.py (+1 -0)
  13. api/core/rag/extractor/entity/extract_setting.py (+24 -3)
  14. api/core/rag/extractor/extract_processor.py (+13 -0)
  15. api/core/rag/extractor/firecrawl/firecrawl_app.py (+132 -0)
  16. api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py (+60 -0)
  17. api/core/rag/extractor/notion_extractor.py (+6 -6)
  18. api/libs/bearer_data_source.py (+64 -0)
  19. api/libs/oauth_data_source.py (+16 -16)
  20. api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py (+67 -0)
  21. api/models/dataset.py (+2 -2)
  22. api/models/source.py (+34 -2)
  23. api/pyproject.toml (+3 -0)
  24. api/services/auth/__init__.py (+0 -0)
  25. api/services/auth/api_key_auth_base.py (+10 -0)
  26. api/services/auth/api_key_auth_factory.py (+14 -0)
  27. api/services/auth/api_key_auth_service.py (+70 -0)
  28. api/services/auth/firecrawl.py (+56 -0)
  29. api/services/dataset_service.py (+79 -13)
  30. api/services/website_service.py (+171 -0)
  31. api/tasks/document_indexing_sync_task.py (+6 -6)
  32. api/tasks/sync_website_document_indexing_task.py (+90 -0)
  33. api/tests/unit_tests/core/rag/extractor/firecrawl/__init__.py (+0 -0)
  34. api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py (+33 -0)
  35. api/tests/unit_tests/oss/__init__.py (+0 -0)
  36. api/tests/unit_tests/oss/local/__init__.py (+0 -0)

+ 2 - 1
api/.env.example

@@ -215,4 +215,5 @@ WORKFLOW_MAX_EXECUTION_TIME=1200
 WORKFLOW_CALL_MAX_DEPTH=5
 
 # App configuration
-APP_MAX_EXECUTION_TIME=1200
+APP_MAX_EXECUTION_TIME=1200
+

+ 2 - 2
api/controllers/console/__init__.py

@@ -29,13 +29,13 @@ from .app import (
 )
 
 # Import auth controllers
-from .auth import activate, data_source_oauth, login, oauth
+from .auth import activate, data_source_bearer_auth, data_source_oauth, login, oauth
 
 # Import billing controllers
 from .billing import billing
 
 # Import datasets controllers
-from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing
+from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing, website
 
 # Import explore controllers
 from .explore import (

+ 67 - 0
api/controllers/console/auth/data_source_bearer_auth.py

@@ -0,0 +1,67 @@
+from flask_login import current_user
+from flask_restful import Resource, reqparse
+from werkzeug.exceptions import Forbidden
+
+from controllers.console import api
+from controllers.console.auth.error import ApiKeyAuthFailedError
+from libs.login import login_required
+from services.auth.api_key_auth_service import ApiKeyAuthService
+
+from ..setup import setup_required
+from ..wraps import account_initialization_required
+
+
+class ApiKeyAuthDataSource(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        # The role of the current user in the table must be admin or owner
+        if not current_user.is_admin_or_owner:
+            raise Forbidden()
+        data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
+        if data_source_api_key_bindings:
+            return {
+                'settings': [data_source_api_key_binding.to_dict() for data_source_api_key_binding in
+                             data_source_api_key_bindings]}
+        return {'settings': []}
+
+
+class ApiKeyAuthDataSourceBinding(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        # The role of the current user in the table must be admin or owner
+        if not current_user.is_admin_or_owner:
+            raise Forbidden()
+        parser = reqparse.RequestParser()
+        parser.add_argument('category', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+        ApiKeyAuthService.validate_api_key_auth_args(args)
+        try:
+            ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
+        except Exception as e:
+            raise ApiKeyAuthFailedError(str(e))
+        return {'result': 'success'}, 200
+
+
+class ApiKeyAuthDataSourceBindingDelete(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, binding_id):
+        # The role of the current user in the table must be admin or owner
+        if not current_user.is_admin_or_owner:
+            raise Forbidden()
+
+        ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
+
+        return {'result': 'success'}, 200
+
+
+api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source')
+api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding')
+api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/<uuid:binding_id>')
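
Note: a client could exercise these new endpoints roughly as sketched below. The console base path, the authenticated session, and the 'website' category value are assumptions for a local deployment, not taken from this diff.

import requests

session = requests.Session()                 # assumed to carry a logged-in console session
BASE = "http://localhost:5001/console/api"   # assumed console API prefix

payload = {
    "category": "website",                   # category value is an assumption
    "provider": "firecrawl",
    "credentials": {
        "auth_type": "bearer",               # FirecrawlAuth (added later in this diff) only accepts 'bearer'
        "config": {"api_key": "fc-...", "base_url": "https://api.firecrawl.dev"},
    },
}

# Create a binding for the current workspace, then list existing bindings.
print(session.post(f"{BASE}/api-key-auth/data-source/binding", json=payload).json())
print(session.get(f"{BASE}/api-key-auth/data-source").json())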

+ 7 - 0
api/controllers/console/auth/error.py

@@ -0,0 +1,7 @@
+from libs.exception import BaseHTTPException
+
+
+class ApiKeyAuthFailedError(BaseHTTPException):
+    error_code = 'auth_failed'
+    description = "{message}"
+    code = 500

+ 11 - 11
api/controllers/console/datasets/data_source.py

@@ -16,7 +16,7 @@ from extensions.ext_database import db
 from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
 from libs.login import login_required
 from models.dataset import Document
-from models.source import DataSourceBinding
+from models.source import DataSourceOauthBinding
 from services.dataset_service import DatasetService, DocumentService
 from tasks.document_indexing_sync_task import document_indexing_sync_task
 
@@ -29,9 +29,9 @@ class DataSourceApi(Resource):
     @marshal_with(integrate_list_fields)
     def get(self):
         # get workspace data source integrates
-        data_source_integrates = db.session.query(DataSourceBinding).filter(
-            DataSourceBinding.tenant_id == current_user.current_tenant_id,
-            DataSourceBinding.disabled == False
+        data_source_integrates = db.session.query(DataSourceOauthBinding).filter(
+            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+            DataSourceOauthBinding.disabled == False
         ).all()
 
         base_url = request.url_root.rstrip('/')
@@ -71,7 +71,7 @@ class DataSourceApi(Resource):
     def patch(self, binding_id, action):
         binding_id = str(binding_id)
         action = str(action)
-        data_source_binding = DataSourceBinding.query.filter_by(
+        data_source_binding = DataSourceOauthBinding.query.filter_by(
             id=binding_id
         ).first()
         if data_source_binding is None:
@@ -124,7 +124,7 @@ class DataSourceNotionListApi(Resource):
                     data_source_info = json.loads(document.data_source_info)
                     exist_page_ids.append(data_source_info['notion_page_id'])
         # get all authorized pages
-        data_source_bindings = DataSourceBinding.query.filter_by(
+        data_source_bindings = DataSourceOauthBinding.query.filter_by(
             tenant_id=current_user.current_tenant_id,
             provider='notion',
             disabled=False
@@ -163,12 +163,12 @@ class DataSourceNotionApi(Resource):
     def get(self, workspace_id, page_id, page_type):
         workspace_id = str(workspace_id)
         page_id = str(page_id)
-        data_source_binding = DataSourceBinding.query.filter(
+        data_source_binding = DataSourceOauthBinding.query.filter(
             db.and_(
-                DataSourceBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceBinding.provider == 'notion',
-                DataSourceBinding.disabled == False,
-                DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                DataSourceOauthBinding.provider == 'notion',
+                DataSourceOauthBinding.disabled == False,
+                DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
             )
         ).first()
         if not data_source_binding:

+ 17 - 0
api/controllers/console/datasets/datasets.py

@@ -315,6 +315,22 @@ class DatasetIndexingEstimateApi(Resource):
                         document_model=args['doc_form']
                     )
                     extract_settings.append(extract_setting)
+        elif args['info_list']['data_source_type'] == 'website_crawl':
+            website_info_list = args['info_list']['website_info_list']
+            for url in website_info_list['urls']:
+                extract_setting = ExtractSetting(
+                    datasource_type="website_crawl",
+                    website_info={
+                        "provider": website_info_list['provider'],
+                        "job_id": website_info_list['job_id'],
+                        "url": url,
+                        "tenant_id": current_user.current_tenant_id,
+                        "mode": 'crawl',
+                        "only_main_content": website_info_list['only_main_content']
+                    },
+                    document_model=args['doc_form']
+                )
+                extract_settings.append(extract_setting)
         else:
             raise ValueError('Data source type not support')
         indexing_runner = IndexingRunner()
@@ -519,6 +535,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                 raise ValueError(f"Unsupported vector db type {vector_type}.")
 
 
+
 class DatasetErrorDocs(Resource):
     @setup_required
     @login_required

+ 43 - 0
api/controllers/console/datasets/datasets_document.py

@@ -465,6 +465,20 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                     document_model=document.doc_form
                 )
                 extract_settings.append(extract_setting)
+            elif document.data_source_type == 'website_crawl':
+                extract_setting = ExtractSetting(
+                    datasource_type="website_crawl",
+                    website_info={
+                        "provider": data_source_info['provider'],
+                        "job_id": data_source_info['job_id'],
+                        "url": data_source_info['url'],
+                        "tenant_id": current_user.current_tenant_id,
+                        "mode": data_source_info['mode'],
+                        "only_main_content": data_source_info['only_main_content']
+                    },
+                    document_model=document.doc_form
+                )
+                extract_settings.append(extract_setting)
 
             else:
                 raise ValueError('Data source type not support')
@@ -952,6 +966,33 @@ class DocumentRenameApi(DocumentResource):
         return document
 
 
+class WebsiteDocumentSyncApi(DocumentResource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id, document_id):
+        """sync website document."""
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound('Dataset not found.')
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset.id, document_id)
+        if not document:
+            raise NotFound('Document not found.')
+        if document.tenant_id != current_user.current_tenant_id:
+            raise Forbidden('No permission.')
+        if document.data_source_type != 'website_crawl':
+            raise ValueError('Document is not a website document.')
+        # 403 if document is archived
+        if DocumentService.check_archived(document):
+            raise ArchivedDocumentImmutableError()
+        # sync document
+        DocumentService.sync_website_document(dataset_id, document)
+
+        return {'result': 'success'}, 200
+
+
 api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
 api.add_resource(DatasetDocumentListApi,
                  '/datasets/<uuid:dataset_id>/documents')
@@ -980,3 +1021,5 @@ api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uui
 api.add_resource(DocumentRetryApi, '/datasets/<uuid:dataset_id>/retry')
 api.add_resource(DocumentRenameApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename')
+
+api.add_resource(WebsiteDocumentSyncApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync')

+ 6 - 0
api/controllers/console/datasets/error.py

@@ -73,6 +73,12 @@ class InvalidMetadataError(BaseHTTPException):
     code = 400
 
 
+class WebsiteCrawlError(BaseHTTPException):
+    error_code = 'crawl_failed'
+    description = "{message}"
+    code = 500
+
+
 class DatasetInUseError(BaseHTTPException):
     error_code = 'dataset_in_use'
     description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."

+ 49 - 0
api/controllers/console/datasets/website.py

@@ -0,0 +1,49 @@
+from flask_restful import Resource, reqparse
+
+from controllers.console import api
+from controllers.console.datasets.error import WebsiteCrawlError
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from libs.login import login_required
+from services.website_service import WebsiteService
+
+
+class WebsiteCrawlApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('provider', type=str, choices=['firecrawl'],
+                            required=True, nullable=True, location='json')
+        parser.add_argument('url', type=str, required=True, nullable=True, location='json')
+        parser.add_argument('options', type=dict, required=True, nullable=True, location='json')
+        args = parser.parse_args()
+        WebsiteService.document_create_args_validate(args)
+        # crawl url
+        try:
+            result = WebsiteService.crawl_url(args)
+        except Exception as e:
+            raise WebsiteCrawlError(str(e))
+        return result, 200
+
+
+class WebsiteCrawlStatusApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, job_id: str):
+        parser = reqparse.RequestParser()
+        parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args')
+        args = parser.parse_args()
+        # get crawl status
+        try:
+            result = WebsiteService.get_crawl_status(job_id, args['provider'])
+        except Exception as e:
+            raise WebsiteCrawlError(str(e))
+        return result, 200
+
+
+api.add_resource(WebsiteCrawlApi, '/website/crawl')
+api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/<string:job_id>')
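
Note: a rough usage sketch of the two endpoints above. The console base path, the option names, and the response field names are assumptions; the 'status' value mirrors what the Firecrawl client added later in this diff reports.

import time
import requests

BASE = "http://localhost:5001/console/api"   # assumed console API prefix
session = requests.Session()                 # assumed to carry a logged-in console session

# Start a crawl job; 'options' is passed through to WebsiteService.crawl_url.
start = session.post(f"{BASE}/website/crawl", json={
    "provider": "firecrawl",
    "url": "https://example.com",
    "options": {"limit": 5, "only_main_content": True},   # option names are assumptions
}).json()
job_id = start.get("job_id")                 # response field name is an assumption

# Poll the status endpoint until the provider reports completion.
while True:
    status = session.get(f"{BASE}/website/crawl/status/{job_id}",
                         params={"provider": "firecrawl"}).json()
    if status.get("status") == "completed":
        break
    time.sleep(2)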

+ 18 - 1
api/core/indexing_runner.py

@@ -339,7 +339,7 @@ class IndexingRunner:
     def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \
             -> list[Document]:
         # load file
-        if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
+        if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]:
             return []
 
         data_source_info = dataset_document.data_source_info_dict
@@ -375,6 +375,23 @@ class IndexingRunner:
                 document_model=dataset_document.doc_form
             )
             text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
+        elif dataset_document.data_source_type == 'website_crawl':
+            if (not data_source_info or 'provider' not in data_source_info
+                    or 'url' not in data_source_info or 'job_id' not in data_source_info):
+                raise ValueError("no website import info found")
+            extract_setting = ExtractSetting(
+                datasource_type="website_crawl",
+                website_info={
+                    "provider": data_source_info['provider'],
+                    "job_id": data_source_info['job_id'],
+                    "tenant_id": dataset_document.tenant_id,
+                    "url": data_source_info['url'],
+                    "mode": data_source_info['mode'],
+                    "only_main_content": data_source_info['only_main_content']
+                },
+                document_model=dataset_document.doc_form
+            )
+            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
         # update document status to splitting
         self._update_document_index_status(
             document_id=dataset_document.id,

+ 1 - 1
api/core/model_runtime/model_providers/togetherai/llm/llm.py

@@ -124,7 +124,7 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
                     default=float(credentials.get('presence_penalty', 0)),
                     min=-2,
                     max=2
-                )
+                ),
             ],
             pricing=PriceConfig(
                 input=Decimal(cred_with_endpoint.get('input_price', 0)),

+ 1 - 0
api/core/rag/extractor/entity/datasource_type.py

@@ -4,3 +4,4 @@ from enum import Enum
 class DatasourceType(Enum):
     FILE = "upload_file"
     NOTION = "notion_import"
+    WEBSITE = "website_crawl"

+ 24 - 3
api/core/rag/extractor/entity/extract_setting.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from pydantic import BaseModel, ConfigDict
 
 from models.dataset import Document
@@ -19,14 +21,33 @@ class NotionInfo(BaseModel):
         super().__init__(**data)
 
 
+class WebsiteInfo(BaseModel):
+    """
+    website import info.
+    """
+    provider: str
+    job_id: str
+    url: str
+    mode: str
+    tenant_id: str
+    only_main_content: bool = False
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def __init__(self, **data) -> None:
+        super().__init__(**data)
+
+
 class ExtractSetting(BaseModel):
     """
     Model class for provider response.
     """
     datasource_type: str
-    upload_file: UploadFile = None
-    notion_info: NotionInfo = None
-    document_model: str = None
+    upload_file: Optional[UploadFile]
+    notion_info: Optional[NotionInfo]
+    website_info: Optional[WebsiteInfo]
+    document_model: Optional[str]
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
     def __init__(self, **data) -> None:
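
Note: a minimal sketch of building the new website-backed setting, mirroring the call sites added in datasets.py and indexing_runner.py above; all identifier values are placeholders.

from core.rag.extractor.entity.extract_setting import ExtractSetting

extract_setting = ExtractSetting(
    datasource_type="website_crawl",
    website_info={
        "provider": "firecrawl",
        "job_id": "job-123",              # placeholder Firecrawl job id
        "url": "https://example.com/docs",
        "tenant_id": "tenant-uuid",       # placeholder workspace id
        "mode": "crawl",
        "only_main_content": True,
    },
    document_model="text_model",
)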

+ 13 - 0
api/core/rag/extractor/extract_processor.py

@@ -11,6 +11,7 @@ from core.rag.extractor.csv_extractor import CSVExtractor
 from core.rag.extractor.entity.datasource_type import DatasourceType
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.excel_extractor import ExcelExtractor
+from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
 from core.rag.extractor.html_extractor import HtmlExtractor
 from core.rag.extractor.markdown_extractor import MarkdownExtractor
 from core.rag.extractor.notion_extractor import NotionExtractor
@@ -154,5 +155,17 @@ class ExtractProcessor:
                 tenant_id=extract_setting.notion_info.tenant_id,
             )
             return extractor.extract()
+        elif extract_setting.datasource_type == DatasourceType.WEBSITE.value:
+            if extract_setting.website_info.provider == 'firecrawl':
+                extractor = FirecrawlWebExtractor(
+                    url=extract_setting.website_info.url,
+                    job_id=extract_setting.website_info.job_id,
+                    tenant_id=extract_setting.website_info.tenant_id,
+                    mode=extract_setting.website_info.mode,
+                    only_main_content=extract_setting.website_info.only_main_content
+                )
+                return extractor.extract()
+            else:
+                raise ValueError(f"Unsupported website provider: {extract_setting.website_info.provider}")
         else:
             raise ValueError(f"Unsupported datasource type: {extract_setting.datasource_type}")

+ 132 - 0
api/core/rag/extractor/firecrawl/firecrawl_app.py

@@ -0,0 +1,132 @@
+import json
+import time
+
+import requests
+
+from extensions.ext_storage import storage
+
+
+class FirecrawlApp:
+    def __init__(self, api_key=None, base_url=None):
+        self.api_key = api_key
+        self.base_url = base_url or 'https://api.firecrawl.dev'
+        if self.api_key is None and self.base_url == 'https://api.firecrawl.dev':
+            raise ValueError('No API key provided')
+
+    def scrape_url(self, url, params=None) -> dict:
+        headers = {
+            'Content-Type': 'application/json',
+            'Authorization': f'Bearer {self.api_key}'
+        }
+        json_data = {'url': url}
+        if params:
+            json_data.update(params)
+        response = requests.post(
+            f'{self.base_url}/v0/scrape',
+            headers=headers,
+            json=json_data
+        )
+        if response.status_code == 200:
+            response = response.json()
+            if response['success'] == True:
+                data = response['data']
+                return {
+                    'title': data.get('metadata').get('title'),
+                    'description': data.get('metadata').get('description'),
+                    'source_url': data.get('metadata').get('sourceURL'),
+                    'markdown': data.get('markdown')
+                }
+            else:
+                raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
+
+        elif response.status_code in [402, 409, 500]:
+            error_message = response.json().get('error', 'Unknown error occurred')
+            raise Exception(f'Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}')
+        else:
+            raise Exception(f'Failed to scrape URL. Status code: {response.status_code}')
+
+    def crawl_url(self, url, params=None) -> str:
+        start_time = time.time()
+        headers = self._prepare_headers()
+        json_data = {'url': url}
+        if params:
+            json_data.update(params)
+        response = self._post_request(f'{self.base_url}/v0/crawl', json_data, headers)
+        if response.status_code == 200:
+            job_id = response.json().get('jobId')
+            return job_id
+        else:
+            self._handle_error(response, 'start crawl job')
+
+    def check_crawl_status(self, job_id) -> dict:
+        headers = self._prepare_headers()
+        response = self._get_request(f'{self.base_url}/v0/crawl/status/{job_id}', headers)
+        if response.status_code == 200:
+            crawl_status_response = response.json()
+            if crawl_status_response.get('status') == 'completed':
+                total = crawl_status_response.get('total', 0)
+                if total == 0:
+                    raise Exception('Failed to check crawl status. Error: No page found')
+                data = crawl_status_response.get('data', [])
+                url_data_list = []
+                for item in data:
+                    if isinstance(item, dict) and 'metadata' in item and 'markdown' in item:
+                        url_data = {
+                            'title': item.get('metadata').get('title'),
+                            'description': item.get('metadata').get('description'),
+                            'source_url': item.get('metadata').get('sourceURL'),
+                            'markdown': item.get('markdown')
+                        }
+                        url_data_list.append(url_data)
+                if url_data_list:
+                    file_key = 'website_files/' + job_id + '.txt'
+                    if storage.exists(file_key):
+                        storage.delete(file_key)
+                    storage.save(file_key, json.dumps(url_data_list).encode('utf-8'))
+                return {
+                    'status': 'completed',
+                    'total': crawl_status_response.get('total'),
+                    'current': crawl_status_response.get('current'),
+                    'data': url_data_list
+                }
+
+            else:
+                return {
+                    'status': crawl_status_response.get('status'),
+                    'total': crawl_status_response.get('total'),
+                    'current': crawl_status_response.get('current'),
+                    'data': []
+                }
+
+        else:
+            self._handle_error(response, 'check crawl status')
+
+    def _prepare_headers(self):
+        return {
+            'Content-Type': 'application/json',
+            'Authorization': f'Bearer {self.api_key}'
+        }
+
+    def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5):
+        for attempt in range(retries):
+            response = requests.post(url, headers=headers, json=data)
+            if response.status_code == 502:
+                time.sleep(backoff_factor * (2 ** attempt))
+            else:
+                return response
+        return response
+
+    def _get_request(self, url, headers, retries=3, backoff_factor=0.5):
+        for attempt in range(retries):
+            response = requests.get(url, headers=headers)
+            if response.status_code == 502:
+                time.sleep(backoff_factor * (2 ** attempt))
+            else:
+                return response
+        return response
+
+    def _handle_error(self, response, action):
+        error_message = response.json().get('error', 'Unknown error occurred')
+        raise Exception(f'Failed to {action}. Status code: {response.status_code}. Error: {error_message}')
+
+
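
Note: a usage sketch for the client above. The API key and URLs are placeholders; 'crawlerOptions' mirrors the options used by FirecrawlAuth later in this diff.

from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp

app = FirecrawlApp(api_key="fc-...")        # placeholder key; required when using api.firecrawl.dev

# Single-page scrape: returns title, description, source_url and markdown.
page = app.scrape_url("https://example.com")
print(page["title"])

# Site crawl: start a job, then check status until it reports 'completed'.
# On completion the results are also cached in storage under website_files/<job_id>.txt.
job_id = app.crawl_url("https://example.com", params={"crawlerOptions": {"limit": 5}})
status = app.check_crawl_status(job_id)
print(status["status"], status["total"])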

+ 60 - 0
api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py

@@ -0,0 +1,60 @@
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
+from services.website_service import WebsiteService
+
+
+class FirecrawlWebExtractor(BaseExtractor):
+    """
+    Crawl and scrape websites and return content in clean llm-ready markdown. 
+
+
+    Args:
+        url: The URL to scrape.
+        api_key: The API key for Firecrawl.
+        base_url: The base URL for the Firecrawl API. Defaults to 'https://api.firecrawl.dev'.
+        mode: The mode of operation. Defaults to 'scrape'. Options are 'crawl', 'scrape' and 'crawl_return_urls'.
+    """
+
+    def __init__(
+            self,
+            url: str,
+            job_id: str,
+            tenant_id: str,
+            mode: str = 'crawl',
+            only_main_content: bool = False
+    ):
+        """Initialize with url, api_key, base_url and mode."""
+        self._url = url
+        self.job_id = job_id
+        self.tenant_id = tenant_id
+        self.mode = mode
+        self.only_main_content = only_main_content
+
+    def extract(self) -> list[Document]:
+        """Extract content from the URL."""
+        documents = []
+        if self.mode == 'crawl':
+            crawl_data = WebsiteService.get_crawl_url_data(self.job_id, 'firecrawl', self._url, self.tenant_id)
+            if crawl_data is None:
+                return []
+            document = Document(page_content=crawl_data.get('markdown', ''),
+                                metadata={
+                                    'source_url': crawl_data.get('source_url'),
+                                    'description': crawl_data.get('description'),
+                                    'title': crawl_data.get('title')
+                                }
+                                )
+            documents.append(document)
+        elif self.mode == 'scrape':
+            scrape_data = WebsiteService.get_scrape_url_data('firecrawl', self._url, self.tenant_id,
+                                                             self.only_main_content)
+
+            document = Document(page_content=scrape_data.get('markdown', ''),
+                                metadata={
+                                    'source_url': scrape_data.get('source_url'),
+                                    'description': scrape_data.get('description'),
+                                    'title': scrape_data.get('title')
+                                }
+                                )
+            documents.append(document)
+        return documents
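
Note: for illustration, the extractor can be driven directly as sketched below (ids are placeholders). In 'crawl' mode the content comes from the results cached for the job; in 'scrape' mode the URL is fetched on demand.

from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor

extractor = FirecrawlWebExtractor(
    url="https://example.com/docs",
    job_id="job-123",            # placeholder job id
    tenant_id="tenant-uuid",     # placeholder workspace id
    mode="crawl",
    only_main_content=True,
)
for doc in extractor.extract():
    print(doc.metadata.get("source_url"), len(doc.page_content))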

+ 6 - 6
api/core/rag/extractor/notion_extractor.py

@@ -9,7 +9,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.dataset import Document as DocumentModel
-from models.source import DataSourceBinding
+from models.source import DataSourceOauthBinding
 
 logger = logging.getLogger(__name__)
 
@@ -345,12 +345,12 @@ class NotionExtractor(BaseExtractor):
 
     @classmethod
     def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
-        data_source_binding = DataSourceBinding.query.filter(
+        data_source_binding = DataSourceOauthBinding.query.filter(
             db.and_(
-                DataSourceBinding.tenant_id == tenant_id,
-                DataSourceBinding.provider == 'notion',
-                DataSourceBinding.disabled == False,
-                DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
+                DataSourceOauthBinding.tenant_id == tenant_id,
+                DataSourceOauthBinding.provider == 'notion',
+                DataSourceOauthBinding.disabled == False,
+                DataSourceOauthBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
             )
         ).first()
 

+ 64 - 0
api/libs/bearer_data_source.py

@@ -0,0 +1,64 @@
+# [REVIEW] Implement if Needed? Do we need a new type of data source
+from abc import abstractmethod
+
+import requests
+from api.models.source import DataSourceBearerBinding
+from flask_login import current_user
+
+from extensions.ext_database import db
+
+
+class BearerDataSource:
+    def __init__(self, api_key: str, api_base_url: str):
+        self.api_key = api_key
+        self.api_base_url = api_base_url
+
+    @abstractmethod
+    def validate_bearer_data_source(self):
+        """
+        Validate the data source
+        """
+
+
+class FireCrawlDataSource(BearerDataSource):
+    def validate_bearer_data_source(self):
+        TEST_CRAWL_SITE_URL = "https://www.google.com"
+        FIRECRAWL_API_VERSION = "v0"
+
+        test_api_endpoint = self.api_base_url.rstrip('/') + f"/{FIRECRAWL_API_VERSION}/scrape"
+
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+
+        data = {
+            "url": TEST_CRAWL_SITE_URL,
+        }
+
+        response = requests.get(test_api_endpoint, headers=headers, json=data)
+
+        return response.json().get("status") == "success"
+
+    def save_credentials(self):
+        # save data source binding
+        data_source_binding = DataSourceBearerBinding.query.filter(
+            db.and_(
+                DataSourceBearerBinding.tenant_id == current_user.current_tenant_id,
+                DataSourceBearerBinding.provider == 'firecrawl',
+                DataSourceBearerBinding.endpoint_url == self.api_base_url,
+                DataSourceBearerBinding.bearer_key == self.api_key
+            )
+        ).first()
+        if data_source_binding:
+            data_source_binding.disabled = False
+            db.session.commit()
+        else:
+            new_data_source_binding = DataSourceBearerBinding(
+                tenant_id=current_user.current_tenant_id,
+                provider='firecrawl',
+                endpoint_url=self.api_base_url,
+                bearer_key=self.api_key
+            )
+            db.session.add(new_data_source_binding)
+            db.session.commit()

+ 16 - 16
api/libs/oauth_data_source.py

@@ -4,7 +4,7 @@ import requests
 from flask_login import current_user
 
 from extensions.ext_database import db
-from models.source import DataSourceBinding
+from models.source import DataSourceOauthBinding
 
 
 class OAuthDataSource:
@@ -63,11 +63,11 @@ class NotionOAuth(OAuthDataSource):
             'total': len(pages)
         }
         # save data source binding
-        data_source_binding = DataSourceBinding.query.filter(
+        data_source_binding = DataSourceOauthBinding.query.filter(
             db.and_(
-                DataSourceBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceBinding.provider == 'notion',
-                DataSourceBinding.access_token == access_token
+                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                DataSourceOauthBinding.provider == 'notion',
+                DataSourceOauthBinding.access_token == access_token
             )
         ).first()
         if data_source_binding:
@@ -75,7 +75,7 @@ class NotionOAuth(OAuthDataSource):
             data_source_binding.disabled = False
             db.session.commit()
         else:
-            new_data_source_binding = DataSourceBinding(
+            new_data_source_binding = DataSourceOauthBinding(
                 tenant_id=current_user.current_tenant_id,
                 access_token=access_token,
                 source_info=source_info,
@@ -98,11 +98,11 @@ class NotionOAuth(OAuthDataSource):
             'total': len(pages)
         }
         # save data source binding
-        data_source_binding = DataSourceBinding.query.filter(
+        data_source_binding = DataSourceOauthBinding.query.filter(
             db.and_(
-                DataSourceBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceBinding.provider == 'notion',
-                DataSourceBinding.access_token == access_token
+                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                DataSourceOauthBinding.provider == 'notion',
+                DataSourceOauthBinding.access_token == access_token
             )
         ).first()
         if data_source_binding:
@@ -110,7 +110,7 @@ class NotionOAuth(OAuthDataSource):
             data_source_binding.disabled = False
             db.session.commit()
         else:
-            new_data_source_binding = DataSourceBinding(
+            new_data_source_binding = DataSourceOauthBinding(
                 tenant_id=current_user.current_tenant_id,
                 access_token=access_token,
                 source_info=source_info,
@@ -121,12 +121,12 @@ class NotionOAuth(OAuthDataSource):
 
     def sync_data_source(self, binding_id: str):
         # save data source binding
-        data_source_binding = DataSourceBinding.query.filter(
+        data_source_binding = DataSourceOauthBinding.query.filter(
             db.and_(
-                DataSourceBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceBinding.provider == 'notion',
-                DataSourceBinding.id == binding_id,
-                DataSourceBinding.disabled == False
+                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                DataSourceOauthBinding.provider == 'notion',
+                DataSourceOauthBinding.id == binding_id,
+                DataSourceOauthBinding.disabled == False
             )
         ).first()
         if data_source_binding:

+ 67 - 0
api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py

@@ -0,0 +1,67 @@
+"""add-api-key-auth-binding
+
+Revision ID: 7b45942e39bb
+Revises: 47cc7df8c4f3
+Create Date: 2024-05-14 07:31:29.702766
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+import models as models
+
+# revision identifiers, used by Alembic.
+revision = '7b45942e39bb'
+down_revision = '4e99a8df00ff'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('data_source_api_key_auth_bindings',
+    sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('tenant_id', models.StringUUID(), nullable=False),
+    sa.Column('category', sa.String(length=255), nullable=False),
+    sa.Column('provider', sa.String(length=255), nullable=False),
+    sa.Column('credentials', sa.Text(), nullable=True),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.Column('disabled', sa.Boolean(), server_default=sa.text('false'), nullable=True),
+    sa.PrimaryKeyConstraint('id', name='data_source_api_key_auth_binding_pkey')
+    )
+    with op.batch_alter_table('data_source_api_key_auth_bindings', schema=None) as batch_op:
+        batch_op.create_index('data_source_api_key_auth_binding_provider_idx', ['provider'], unique=False)
+        batch_op.create_index('data_source_api_key_auth_binding_tenant_id_idx', ['tenant_id'], unique=False)
+
+    with op.batch_alter_table('data_source_bindings', schema=None) as batch_op:
+        batch_op.drop_index('source_binding_tenant_id_idx')
+        batch_op.drop_index('source_info_idx')
+
+    op.rename_table('data_source_bindings', 'data_source_oauth_bindings')
+
+    with op.batch_alter_table('data_source_oauth_bindings', schema=None) as batch_op:
+        batch_op.create_index('source_binding_tenant_id_idx', ['tenant_id'], unique=False)
+        batch_op.create_index('source_info_idx', ['source_info'], unique=False, postgresql_using='gin')
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+
+    with op.batch_alter_table('data_source_oauth_bindings', schema=None) as batch_op:
+        batch_op.drop_index('source_info_idx', postgresql_using='gin')
+        batch_op.drop_index('source_binding_tenant_id_idx')
+
+    op.rename_table('data_source_oauth_bindings', 'data_source_bindings')
+
+    with op.batch_alter_table('data_source_bindings', schema=None) as batch_op:
+        batch_op.create_index('source_info_idx', ['source_info'], unique=False)
+        batch_op.create_index('source_binding_tenant_id_idx', ['tenant_id'], unique=False)
+
+    with op.batch_alter_table('data_source_api_key_auth_bindings', schema=None) as batch_op:
+        batch_op.drop_index('data_source_api_key_auth_binding_tenant_id_idx')
+        batch_op.drop_index('data_source_api_key_auth_binding_provider_idx')
+
+    op.drop_table('data_source_api_key_auth_bindings')
+    # ### end Alembic commands ###

+ 2 - 2
api/models/dataset.py

@@ -270,7 +270,7 @@ class Document(db.Model):
         255), nullable=False, server_default=db.text("'text_model'::character varying"))
     doc_language = db.Column(db.String(255), nullable=True)
 
-    DATA_SOURCES = ['upload_file', 'notion_import']
+    DATA_SOURCES = ['upload_file', 'notion_import', 'website_crawl']
 
     @property
     def display_status(self):
@@ -322,7 +322,7 @@ class Document(db.Model):
                             'created_at': file_detail.created_at.timestamp()
                         }
                     }
-            elif self.data_source_type == 'notion_import':
+            elif self.data_source_type == 'notion_import' or self.data_source_type == 'website_crawl':
                 return json.loads(self.data_source_info)
         return {}
 

+ 34 - 2
api/models/source.py

@@ -1,11 +1,13 @@
+import json
+
 from sqlalchemy.dialects.postgresql import JSONB
 
 from extensions.ext_database import db
 from models import StringUUID
 
 
-class DataSourceBinding(db.Model):
-    __tablename__ = 'data_source_bindings'
+class DataSourceOauthBinding(db.Model):
+    __tablename__ = 'data_source_oauth_bindings'
     __table_args__ = (
         db.PrimaryKeyConstraint('id', name='source_binding_pkey'),
         db.Index('source_binding_tenant_id_idx', 'tenant_id'),
@@ -20,3 +22,33 @@ class DataSourceBinding(db.Model):
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
+
+
+class DataSourceApiKeyAuthBinding(db.Model):
+    __tablename__ = 'data_source_api_key_auth_bindings'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='data_source_api_key_auth_binding_pkey'),
+        db.Index('data_source_api_key_auth_binding_tenant_id_idx', 'tenant_id'),
+        db.Index('data_source_api_key_auth_binding_provider_idx', 'provider'),
+    )
+
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
+    category = db.Column(db.String(255), nullable=False)
+    provider = db.Column(db.String(255), nullable=False)
+    credentials = db.Column(db.Text, nullable=True)  # JSON
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
+
+    def to_dict(self):
+        return {
+            'id': self.id,
+            'tenant_id': self.tenant_id,
+            'category': self.category,
+            'provider': self.provider,
+            'credentials': json.loads(self.credentials),
+            'created_at': self.created_at.timestamp(),
+            'updated_at': self.updated_at.timestamp(),
+            'disabled': self.disabled
+        }

+ 3 - 0
api/pyproject.toml

@@ -78,6 +78,9 @@ CODE_MAX_STRING_LENGTH = "80000"
 CODE_EXECUTION_ENDPOINT="http://127.0.0.1:8194"
 CODE_EXECUTION_API_KEY="dify-sandbox"
 
+FIRECRAWL_API_KEY = "fc-"
+
+
 
 [tool.poetry]
 name = "dify-api"

+ 0 - 0
api/services/auth/__init__.py


+ 10 - 0
api/services/auth/api_key_auth_base.py

@@ -0,0 +1,10 @@
+from abc import ABC, abstractmethod
+
+
+class ApiKeyAuthBase(ABC):
+    def __init__(self, credentials: dict):
+        self.credentials = credentials
+
+    @abstractmethod
+    def validate_credentials(self):
+        raise NotImplementedError

+ 14 - 0
api/services/auth/api_key_auth_factory.py

@@ -0,0 +1,14 @@
+
+from services.auth.firecrawl import FirecrawlAuth
+
+
+class ApiKeyAuthFactory:
+
+    def __init__(self, provider: str, credentials: dict):
+        if provider == 'firecrawl':
+            self.auth = FirecrawlAuth(credentials)
+        else:
+            raise ValueError('Invalid provider')
+
+    def validate_credentials(self):
+        return self.auth.validate_credentials()

+ 70 - 0
api/services/auth/api_key_auth_service.py

@@ -0,0 +1,70 @@
+import json
+
+from core.helper import encrypter
+from extensions.ext_database import db
+from models.source import DataSourceApiKeyAuthBinding
+from services.auth.api_key_auth_factory import ApiKeyAuthFactory
+
+
+class ApiKeyAuthService:
+
+    @staticmethod
+    def get_provider_auth_list(tenant_id: str) -> list:
+        data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
+            DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
+            DataSourceApiKeyAuthBinding.disabled.is_(False)
+        ).all()
+        return data_source_api_key_bindings
+
+    @staticmethod
+    def create_provider_auth(tenant_id: str, args: dict):
+        auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials()
+        if auth_result:
+            # Encrypt the api key
+            api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key'])
+            args['credentials']['config']['api_key'] = api_key
+
+            data_source_api_key_binding = DataSourceApiKeyAuthBinding()
+            data_source_api_key_binding.tenant_id = tenant_id
+            data_source_api_key_binding.category = args['category']
+            data_source_api_key_binding.provider = args['provider']
+            data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False)
+            db.session.add(data_source_api_key_binding)
+            db.session.commit()
+
+    @staticmethod
+    def get_auth_credentials(tenant_id: str, category: str, provider: str):
+        data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
+            DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
+            DataSourceApiKeyAuthBinding.category == category,
+            DataSourceApiKeyAuthBinding.provider == provider,
+            DataSourceApiKeyAuthBinding.disabled.is_(False)
+        ).first()
+        if not data_source_api_key_bindings:
+            return None
+        credentials = json.loads(data_source_api_key_bindings.credentials)
+        return credentials
+
+    @staticmethod
+    def delete_provider_auth(tenant_id: str, binding_id: str):
+        data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter(
+            DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
+            DataSourceApiKeyAuthBinding.id == binding_id
+        ).first()
+        if data_source_api_key_binding:
+            db.session.delete(data_source_api_key_binding)
+            db.session.commit()
+
+    @classmethod
+    def validate_api_key_auth_args(cls, args):
+        if 'category' not in args or not args['category']:
+            raise ValueError('category is required')
+        if 'provider' not in args or not args['provider']:
+            raise ValueError('provider is required')
+        if 'credentials' not in args or not args['credentials']:
+            raise ValueError('credentials is required')
+        if not isinstance(args['credentials'], dict):
+            raise ValueError('credentials must be a dictionary')
+        if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']:
+            raise ValueError('auth_type is required')
+

+ 56 - 0
api/services/auth/firecrawl.py

@@ -0,0 +1,56 @@
+import json
+
+import requests
+
+from services.auth.api_key_auth_base import ApiKeyAuthBase
+
+
+class FirecrawlAuth(ApiKeyAuthBase):
+    def __init__(self, credentials: dict):
+        super().__init__(credentials)
+        auth_type = credentials.get('auth_type')
+        if auth_type != 'bearer':
+            raise ValueError('Invalid auth type, Firecrawl auth type must be Bearer')
+        self.api_key = credentials.get('config').get('api_key', None)
+        self.base_url = credentials.get('config').get('base_url', 'https://api.firecrawl.dev')
+
+        if not self.api_key:
+            raise ValueError('No API key provided')
+
+    def validate_credentials(self):
+        headers = self._prepare_headers()
+        options = {
+            'url': 'https://example.com',
+            'crawlerOptions': {
+                'excludes': [],
+                'includes': [],
+                'limit': 1
+            },
+            'pageOptions': {
+                'onlyMainContent': True
+            }
+        }
+        response = self._post_request(f'{self.base_url}/v0/crawl', options, headers)
+        if response.status_code == 200:
+            return True
+        else:
+            self._handle_error(response)
+
+    def _prepare_headers(self):
+        return {
+            'Content-Type': 'application/json',
+            'Authorization': f'Bearer {self.api_key}'
+        }
+
+    def _post_request(self, url, data, headers):
+        return requests.post(url, headers=headers, json=data)
+
+    def _handle_error(self, response):
+        if response.status_code in [402, 409, 500]:
+            error_message = response.json().get('error', 'Unknown error occurred')
+            raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
+        else:
+            if response.text:
+                error_message = json.loads(response.text).get('error', 'Unknown error occurred')
+                raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
+            raise Exception(f'Unexpected error occurred while trying to authorize. Status code: {response.status_code}')
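
Note: a sketch of validating credentials through the factory added above; the key is a placeholder, and an invalid key raises via _handle_error.

from services.auth.api_key_auth_factory import ApiKeyAuthFactory

credentials = {
    "auth_type": "bearer",
    "config": {"api_key": "fc-...", "base_url": "https://api.firecrawl.dev"},
}

# Returns True when Firecrawl accepts a one-page test crawl; raises otherwise.
if ApiKeyAuthFactory("firecrawl", credentials).validate_credentials():
    print("Firecrawl credentials are valid")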

+ 79 - 13
api/services/dataset_service.py

@@ -31,7 +31,7 @@ from models.dataset import (
     DocumentSegment,
 )
 from models.model import UploadFile
-from models.source import DataSourceBinding
+from models.source import DataSourceOauthBinding
 from services.errors.account import NoPermissionError
 from services.errors.dataset import DatasetInUseError, DatasetNameDuplicateError
 from services.errors.document import DocumentIndexingError
@@ -48,6 +48,7 @@ from tasks.document_indexing_update_task import document_indexing_update_task
 from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
 from tasks.recover_document_indexing_task import recover_document_indexing_task
 from tasks.retry_document_indexing_task import retry_document_indexing_task
+from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task
 
 
 class DatasetService:
@@ -508,18 +509,40 @@ class DocumentService:
     @staticmethod
     def retry_document(dataset_id: str, documents: list[Document]):
         for document in documents:
+            # add retry flag
+            retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
+            cache_result = redis_client.get(retry_indexing_cache_key)
+            if cache_result is not None:
+                raise ValueError("Document is being retried, please try again later")
             # retry document indexing
             document.indexing_status = 'waiting'
             db.session.add(document)
             db.session.commit()
-            # add retry flag
-            retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
+
             redis_client.setex(retry_indexing_cache_key, 600, 1)
         # trigger async task
         document_ids = [document.id for document in documents]
         retry_document_indexing_task.delay(dataset_id, document_ids)
 
     @staticmethod
+    def sync_website_document(dataset_id: str, document: Document):
+        # add sync flag
+        sync_indexing_cache_key = 'document_{}_is_sync'.format(document.id)
+        cache_result = redis_client.get(sync_indexing_cache_key)
+        if cache_result is not None:
+            raise ValueError("Document is being synced, please try again later")
+        # sync document indexing
+        document.indexing_status = 'waiting'
+        data_source_info = document.data_source_info_dict
+        data_source_info['mode'] = 'scrape'
+        document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
+        db.session.add(document)
+        db.session.commit()
+
+        redis_client.setex(sync_indexing_cache_key, 600, 1)
+
+        sync_website_document_indexing_task.delay(dataset_id, document.id)
+    @staticmethod
     def get_documents_position(dataset_id):
         document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
         if document:
@@ -545,6 +568,9 @@ class DocumentService:
                     notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
                     for notion_info in notion_info_list:
                         count = count + len(notion_info['pages'])
+                elif document_data["data_source"]["type"] == "website_crawl":
+                    website_info = document_data["data_source"]['info_list']['website_info_list']
+                    count = len(website_info['urls'])
                 batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
                 if count > batch_upload_limit:
                     raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -683,12 +709,12 @@ class DocumentService:
                         exist_document[data_source_info['notion_page_id']] = document.id
                 for notion_info in notion_info_list:
                     workspace_id = notion_info['workspace_id']
-                    data_source_binding = DataSourceBinding.query.filter(
+                    data_source_binding = DataSourceOauthBinding.query.filter(
                         db.and_(
-                            DataSourceBinding.tenant_id == current_user.current_tenant_id,
-                            DataSourceBinding.provider == 'notion',
-                            DataSourceBinding.disabled == False,
-                            DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+                            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                            DataSourceOauthBinding.provider == 'notion',
+                            DataSourceOauthBinding.disabled == False,
+                            DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                         )
                     ).first()
                     if not data_source_binding:
@@ -717,6 +743,28 @@ class DocumentService:
                 # delete not selected documents
                 if len(exist_document) > 0:
                     clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
+            elif document_data["data_source"]["type"] == "website_crawl":
+                website_info = document_data["data_source"]['info_list']['website_info_list']
+                urls = website_info['urls']
+                for url in urls:
+                    data_source_info = {
+                        'url': url,
+                        'provider': website_info['provider'],
+                        'job_id': website_info['job_id'],
+                        'only_main_content': website_info.get('only_main_content', False),
+                        'mode': 'crawl',
+                    }
+                    document = DocumentService.build_document(dataset, dataset_process_rule.id,
+                                                              document_data["data_source"]["type"],
+                                                              document_data["doc_form"],
+                                                              document_data["doc_language"],
+                                                              data_source_info, created_from, position,
+                                                              account, url, batch)
+                    db.session.add(document)
+                    db.session.flush()
+                    document_ids.append(document.id)
+                    documents.append(document)
+                    position += 1
             db.session.commit()
 
             # trigger async task
@@ -818,12 +866,12 @@ class DocumentService:
                 notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
                 for notion_info in notion_info_list:
                     workspace_id = notion_info['workspace_id']
-                    data_source_binding = DataSourceBinding.query.filter(
+                    data_source_binding = DataSourceOauthBinding.query.filter(
                         db.and_(
-                            DataSourceBinding.tenant_id == current_user.current_tenant_id,
-                            DataSourceBinding.provider == 'notion',
-                            DataSourceBinding.disabled == False,
-                            DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+                            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                            DataSourceOauthBinding.provider == 'notion',
+                            DataSourceOauthBinding.disabled == False,
+                            DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                         )
                     ).first()
                     if not data_source_binding:
@@ -835,6 +883,17 @@ class DocumentService:
                             "notion_page_icon": page['page_icon'],
                             "type": page['type']
                         }
+            elif document_data["data_source"]["type"] == "website_crawl":
+                website_info = document_data["data_source"]['info_list']['website_info_list']
+                urls = website_info['urls']
+                for url in urls:
+                    data_source_info = {
+                        'url': url,
+                        'provider': website_info['provider'],
+                        'job_id': website_info['job_id'],
+                        'only_main_content': website_info.get('only_main_content', False),
+                        'mode': 'crawl',
+                    }
             document.data_source_type = document_data["data_source"]["type"]
             document.data_source_info = json.dumps(data_source_info)
             document.name = file_name
@@ -873,6 +932,9 @@ class DocumentService:
                 notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
                 for notion_info in notion_info_list:
                     count = count + len(notion_info['pages'])
+            elif document_data["data_source"]["type"] == "website_crawl":
+                website_info = document_data["data_source"]['info_list']['website_info_list']
+                count = len(website_info['urls'])
             batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
             if count > batch_upload_limit:
                 raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -973,6 +1035,10 @@ class DocumentService:
             if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
                 'notion_info_list']:
                 raise ValueError("Notion source info is required")
+        if args['data_source']['type'] == 'website_crawl':
+            if 'website_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
+                'website_info_list']:
+                raise ValueError("Website source info is required")
 
     @classmethod
     def process_rule_args_validate(cls, args: dict):
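
For orientation, the new website_crawl branch above only reads a handful of keys from the incoming payload. A minimal sketch of the shape it expects (the concrete values are illustrative, not taken from the commit):

# Sketch of the document_data payload consumed by the website_crawl branch above.
# Only keys actually read in the diff are shown; all values are illustrative.
document_data = {
    "data_source": {
        "type": "website_crawl",
        "info_list": {
            "website_info_list": {
                "provider": "firecrawl",
                "job_id": "example-job-id",             # job id returned by WebsiteService.crawl_url
                "urls": ["https://example.com/docs"],   # one Document row is built per URL
                "only_main_content": True,
            }
        },
    },
    "doc_form": "text_model",       # assumed value; the diff only passes it through
    "doc_language": "English",      # assumed value; the diff only passes it through
}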

+ 171 - 0
api/services/website_service.py

@@ -0,0 +1,171 @@
+import datetime
+import json
+
+from flask_login import current_user
+
+from core.helper import encrypter
+from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
+from extensions.ext_redis import redis_client
+from extensions.ext_storage import storage
+from services.auth.api_key_auth_service import ApiKeyAuthService
+
+
+class WebsiteService:
+
+    @classmethod
+    def document_create_args_validate(cls, args: dict):
+        if 'url' not in args or not args['url']:
+            raise ValueError('url is required')
+        if 'options' not in args or not args['options']:
+            raise ValueError('options is required')
+        if 'limit' not in args['options'] or not args['options']['limit']:
+            raise ValueError('limit is required')
+
+    @classmethod
+    def crawl_url(cls, args: dict) -> dict:
+        provider = args.get('provider')
+        url = args.get('url')
+        options = args.get('options')
+        credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
+                                                             'website',
+                                                             provider)
+        if provider == 'firecrawl':
+            # decrypt api_key
+            api_key = encrypter.decrypt_token(
+                tenant_id=current_user.current_tenant_id,
+                token=credentials.get('config').get('api_key')
+            )
+            firecrawl_app = FirecrawlApp(api_key=api_key,
+                                         base_url=credentials.get('config').get('base_url', None))
+            crawl_sub_pages = options.get('crawl_sub_pages', False)
+            only_main_content = options.get('only_main_content', False)
+            if not crawl_sub_pages:
+                params = {
+                    'crawlerOptions': {
+                        "includes": [],
+                        "excludes": [],
+                        "generateImgAltText": True,
+                        "limit": 1,
+                        'returnOnlyUrls': False,
+                        'pageOptions': {
+                            'onlyMainContent': only_main_content,
+                            "includeHtml": False
+                        }
+                    }
+                }
+            else:
+                includes = options.get('includes').split(',') if options.get('includes') else []
+                excludes = options.get('excludes').split(',') if options.get('excludes') else []
+                params = {
+                    'crawlerOptions': {
+                        "includes": includes if includes else [],
+                        "excludes": excludes if excludes else [],
+                        "generateImgAltText": True,
+                        "limit": options.get('limit', 1),
+                        'returnOnlyUrls': False,
+                        'pageOptions': {
+                            'onlyMainContent': only_main_content,
+                            "includeHtml": False
+                        }
+                    }
+                }
+                if options.get('max_depth'):
+                    params['crawlerOptions']['maxDepth'] = options.get('max_depth')
+            job_id = firecrawl_app.crawl_url(url, params)
+            website_crawl_time_cache_key = f'website_crawl_{job_id}'
+            time = str(datetime.datetime.now().timestamp())
+            redis_client.setex(website_crawl_time_cache_key, 3600, time)
+            return {
+                'status': 'active',
+                'job_id': job_id
+            }
+        else:
+            raise ValueError('Invalid provider')
+
+    @classmethod
+    def get_crawl_status(cls, job_id: str, provider: str) -> dict:
+        credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
+                                                             'website',
+                                                             provider)
+        if provider == 'firecrawl':
+            # decrypt api_key
+            api_key = encrypter.decrypt_token(
+                tenant_id=current_user.current_tenant_id,
+                token=credentials.get('config').get('api_key')
+            )
+            firecrawl_app = FirecrawlApp(api_key=api_key,
+                                         base_url=credentials.get('config').get('base_url', None))
+            result = firecrawl_app.check_crawl_status(job_id)
+            crawl_status_data = {
+                'status': result.get('status', 'active'),
+                'job_id': job_id,
+                'total': result.get('total', 0),
+                'current': result.get('current', 0),
+                'data': result.get('data', [])
+            }
+            if crawl_status_data['status'] == 'completed':
+                website_crawl_time_cache_key = f'website_crawl_{job_id}'
+                start_time = redis_client.get(website_crawl_time_cache_key)
+                if start_time:
+                    end_time = datetime.datetime.now().timestamp()
+                    time_consuming = abs(end_time - float(start_time))
+                    crawl_status_data['time_consuming'] = f"{time_consuming:.2f}"
+                    redis_client.delete(website_crawl_time_cache_key)
+        else:
+            raise ValueError('Invalid provider')
+        return crawl_status_data
+
+    @classmethod
+    def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None:
+        credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
+                                                             'website',
+                                                             provider)
+        if provider == 'firecrawl':
+            file_key = 'website_files/' + job_id + '.txt'
+            if storage.exists(file_key):
+                data = storage.load_once(file_key)
+                if data:
+                    data = json.loads(data.decode('utf-8'))
+            else:
+                # decrypt api_key
+                api_key = encrypter.decrypt_token(
+                    tenant_id=tenant_id,
+                    token=credentials.get('config').get('api_key')
+                )
+                firecrawl_app = FirecrawlApp(api_key=api_key,
+                                             base_url=credentials.get('config').get('base_url', None))
+                result = firecrawl_app.check_crawl_status(job_id)
+                if result.get('status') != 'completed':
+                    raise ValueError('Crawl job is not completed')
+                data = result.get('data')
+            if data:
+                for item in data:
+                    if item.get('source_url') == url:
+                        return item
+            return None
+        else:
+            raise ValueError('Invalid provider')
+
+    @classmethod
+    def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None:
+        credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
+                                                             'website',
+                                                             provider)
+        if provider == 'firecrawl':
+            # decrypt api_key
+            api_key = encrypter.decrypt_token(
+                tenant_id=tenant_id,
+                token=credentials.get('config').get('api_key')
+            )
+            firecrawl_app = FirecrawlApp(api_key=api_key,
+                                         base_url=credentials.get('config').get('base_url', None))
+            params = {
+                'pageOptions': {
+                    'onlyMainContent': only_main_content,
+                    "includeHtml": False
+                }
+            }
+            result = firecrawl_app.scrape_url(url, params)
+            return result
+        else:
+            raise ValueError('Invalid provider')
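
A rough usage sketch for the new WebsiteService, inferred only from the keys its methods read above. The option values are illustrative assumptions, and the calls must run in a request context where current_user is available:

# Start a Firecrawl crawl for the current tenant (requires a stored 'website'/'firecrawl'
# API-key binding), then poll its status. All option values here are illustrative.
crawl = WebsiteService.crawl_url({
    'provider': 'firecrawl',
    'url': 'https://docs.example.com',
    'options': {
        'crawl_sub_pages': True,
        'only_main_content': True,
        'includes': 'docs/*,blog/*',   # comma-separated; split into a list by the service
        'excludes': '',
        'limit': 10,
        'max_depth': 2,
    },
})

status = WebsiteService.get_crawl_status(crawl['job_id'], 'firecrawl')
# status carries 'status', 'job_id', 'total', 'current', 'data', and, once the job
# completes, a 'time_consuming' string derived from the Redis timestamp set above.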

+ 6 - 6
api/tasks/document_indexing_sync_task.py

@@ -11,7 +11,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
-from models.source import DataSourceBinding
+from models.source import DataSourceOauthBinding
 
 
 @shared_task(queue='dataset')
@@ -43,12 +43,12 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
         page_id = data_source_info['notion_page_id']
         page_type = data_source_info['type']
         page_edited_time = data_source_info['last_edited_time']
-        data_source_binding = DataSourceBinding.query.filter(
+        data_source_binding = DataSourceOauthBinding.query.filter(
             db.and_(
-                DataSourceBinding.tenant_id == document.tenant_id,
-                DataSourceBinding.provider == 'notion',
-                DataSourceBinding.disabled == False,
-                DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+                DataSourceOauthBinding.tenant_id == document.tenant_id,
+                DataSourceOauthBinding.provider == 'notion',
+                DataSourceOauthBinding.disabled == False,
+                DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
             )
         ).first()
         if not data_source_binding:

+ 90 - 0
api/tasks/sync_website_document_indexing_task.py

@@ -0,0 +1,90 @@
+import datetime
+import logging
+import time
+
+import click
+from celery import shared_task
+
+from core.indexing_runner import IndexingRunner
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
+from extensions.ext_database import db
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset, Document, DocumentSegment
+from services.feature_service import FeatureService
+
+
+@shared_task(queue='dataset')
+def sync_website_document_indexing_task(dataset_id: str, document_id: str):
+    """
+    Re-index a website-crawled document: clean its old segments and run indexing again.
+    :param dataset_id:
+    :param document_id:
+
+    Usage: sync_website_document_indexing_task.delay(dataset_id, document_id)
+    """
+    start_at = time.perf_counter()
+
+    dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+
+    sync_indexing_cache_key = 'document_{}_is_sync'.format(document_id)
+    # check document limit
+    features = FeatureService.get_features(dataset.tenant_id)
+    try:
+        if features.billing.enabled:
+            vector_space = features.vector_space
+            if 0 < vector_space.limit <= vector_space.size:
+                raise ValueError("Your total number of documents plus the number of uploads has exceeded the "
+                                 "limit of your subscription.")
+    except Exception as e:
+        document = db.session.query(Document).filter(
+            Document.id == document_id,
+            Document.dataset_id == dataset_id
+        ).first()
+        if document:
+            document.indexing_status = 'error'
+            document.error = str(e)
+            document.stopped_at = datetime.datetime.utcnow()
+            db.session.add(document)
+            db.session.commit()
+        redis_client.delete(sync_indexing_cache_key)
+        return
+
+    logging.info(click.style('Start syncing website document: {}'.format(document_id), fg='green'))
+    document = db.session.query(Document).filter(
+        Document.id == document_id,
+        Document.dataset_id == dataset_id
+    ).first()
+    try:
+        if document:
+            # clean old data
+            index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
+
+            segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
+            if segments:
+                index_node_ids = [segment.index_node_id for segment in segments]
+                # delete from vector index
+                index_processor.clean(dataset, index_node_ids)
+
+                for segment in segments:
+                    db.session.delete(segment)
+                db.session.commit()
+
+            document.indexing_status = 'parsing'
+            document.processing_started_at = datetime.datetime.utcnow()
+            db.session.add(document)
+            db.session.commit()
+
+            indexing_runner = IndexingRunner()
+            indexing_runner.run([document])
+            redis_client.delete(sync_indexing_cache_key)
+    except Exception as ex:
+        document.indexing_status = 'error'
+        document.error = str(ex)
+        document.stopped_at = datetime.datetime.utcnow()
+        db.session.add(document)
+        db.session.commit()
+        logging.info(click.style(str(ex), fg='yellow'))
+        redis_client.delete(sync_indexing_cache_key)
+        pass
+    end_at = time.perf_counter()
+    logging.info(click.style('Sync document: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
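
The task deletes a document_{id}_is_sync Redis key on both success and failure. A caller-side sketch of how dispatch could be guarded with that same key; the guard and the TTL are assumptions for illustration, and only the key format and the .delay() signature come from the task above:

# Hypothetical dispatch-side guard: skip queueing a second sync while one is pending.
# Only the cache-key format and the .delay(dataset_id, document_id) call are taken
# from the task above; the guard itself and the TTL are illustrative.
sync_indexing_cache_key = 'document_{}_is_sync'.format(document.id)
if not redis_client.get(sync_indexing_cache_key):
    redis_client.setex(sync_indexing_cache_key, 600, 1)
    sync_website_document_indexing_task.delay(dataset.id, document.id)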

+ 0 - 0
api/tests/unit_tests/core/rag/extractor/firecrawl/__init__.py


+ 33 - 0
api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py

@@ -0,0 +1,33 @@
+import os
+from unittest import mock
+
+from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
+from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
+from core.rag.models.document import Document
+from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response
+
+
+def test_firecrawl_web_extractor_crawl_mode(mocker):
+    url = "https://firecrawl.dev"
+    api_key = os.getenv('FIRECRAWL_API_KEY') or 'fc-'
+    base_url = 'https://api.firecrawl.dev'
+    firecrawl_app = FirecrawlApp(api_key=api_key,
+                                 base_url=base_url)
+    params = {
+        'crawlerOptions': {
+            "includes": [],
+            "excludes": [],
+            "generateImgAltText": True,
+            "maxDepth": 1,
+            "limit": 1,
+            'returnOnlyUrls': False,
+
+        }
+    }
+    mocked_firecrawl = {
+        "jobId": "test",
+    }
+    mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl))
+    job_id = firecrawl_app.crawl_url(url, params)
+    print(job_id)
+    assert isinstance(job_id, str)
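
The test above patches requests.post with a _mock_response helper borrowed from the Notion extractor tests. For readers of the diff, a typical shape for such a helper is sketched below; the real implementation in test_notion_extractor may differ in detail:

# Illustrative only: a minimal fake of the requests.Response surface the patched
# call needs (a status code plus a .json() body). The actual helper lives in
# api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py.
from unittest import mock

def _mock_response(data, status_code=200):
    response = mock.Mock()
    response.status_code = status_code
    response.json.return_value = data
    return response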

+ 0 - 0
api/tests/unit_tests/oss/__init__.py


+ 0 - 0
api/tests/unit_tests/oss/local/__init__.py