Prechádzať zdrojové kódy

feat: added dataset recall testing API (#9300)

Summer-Gu 6 mesiacov pred
rodič
commit
8501af298f

+ 7 - 71
api/controllers/console/datasets/hit_testing.py

@@ -1,88 +1,24 @@
-import logging
+from flask_restful import Resource
 
-from flask_login import current_user
-from flask_restful import Resource, marshal, reqparse
-from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
-
-import services
 from controllers.console import api
-from controllers.console.app.error import (
-    CompletionRequestError,
-    ProviderModelCurrentlyNotSupportError,
-    ProviderNotInitializeError,
-    ProviderQuotaExceededError,
-)
-from controllers.console.datasets.error import DatasetNotInitializedError
+from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.errors.error import (
-    LLMBadRequestError,
-    ModelCurrentlyNotSupportError,
-    ProviderTokenNotInitError,
-    QuotaExceededError,
-)
-from core.model_runtime.errors.invoke import InvokeError
-from fields.hit_testing_fields import hit_testing_record_fields
 from libs.login import login_required
-from services.dataset_service import DatasetService
-from services.hit_testing_service import HitTestingService
 
 
-class HitTestingApi(Resource):
+class HitTestingApi(Resource, DatasetsHitTestingBase):
     @setup_required
     @login_required
     @account_initialization_required
     def post(self, dataset_id):
         dataset_id_str = str(dataset_id)
 
-        dataset = DatasetService.get_dataset(dataset_id_str)
-        if dataset is None:
-            raise NotFound("Dataset not found.")
-
-        try:
-            DatasetService.check_dataset_permission(dataset, current_user)
-        except services.errors.account.NoPermissionError as e:
-            raise Forbidden(str(e))
-
-        parser = reqparse.RequestParser()
-        parser.add_argument("query", type=str, location="json")
-        parser.add_argument("retrieval_model", type=dict, required=False, location="json")
-        parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
-        args = parser.parse_args()
-
-        HitTestingService.hit_testing_args_check(args)
-
-        try:
-            response = HitTestingService.retrieve(
-                dataset=dataset,
-                query=args["query"],
-                account=current_user,
-                retrieval_model=args["retrieval_model"],
-                external_retrieval_model=args["external_retrieval_model"],
-                limit=10,
-            )
+        dataset = self.get_and_validate_dataset(dataset_id_str)
+        args = self.parse_args()
+        self.hit_testing_args_check(args)
 
-            return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
-        except services.errors.index.IndexNotInitializedError:
-            raise DatasetNotInitializedError()
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
-        except QuotaExceededError:
-            raise ProviderQuotaExceededError()
-        except ModelCurrentlyNotSupportError:
-            raise ProviderModelCurrentlyNotSupportError()
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                "No Embedding Model or Reranking Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
-            )
-        except InvokeError as e:
-            raise CompletionRequestError(e.description)
-        except ValueError as e:
-            raise ValueError(str(e))
-        except Exception as e:
-            logging.exception("Hit testing failed.")
-            raise InternalServerError(str(e))
+        return self.perform_hit_testing(dataset, args)
 
 
 api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")

+ 85 - 0
api/controllers/console/datasets/hit_testing_base.py

@@ -0,0 +1,85 @@
+import logging
+
+from flask_login import current_user
+from flask_restful import marshal, reqparse
+from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
+
+import services.dataset_service
+from controllers.console.app.error import (
+    CompletionRequestError,
+    ProviderModelCurrentlyNotSupportError,
+    ProviderNotInitializeError,
+    ProviderQuotaExceededError,
+)
+from controllers.console.datasets.error import DatasetNotInitializedError
+from core.errors.error import (
+    LLMBadRequestError,
+    ModelCurrentlyNotSupportError,
+    ProviderTokenNotInitError,
+    QuotaExceededError,
+)
+from core.model_runtime.errors.invoke import InvokeError
+from fields.hit_testing_fields import hit_testing_record_fields
+from services.dataset_service import DatasetService
+from services.hit_testing_service import HitTestingService
+
+
+class DatasetsHitTestingBase:
+    @staticmethod
+    def get_and_validate_dataset(dataset_id: str):
+        dataset = DatasetService.get_dataset(dataset_id)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+
+        return dataset
+
+    @staticmethod
+    def hit_testing_args_check(args):
+        HitTestingService.hit_testing_args_check(args)
+
+    @staticmethod
+    def parse_args():
+        parser = reqparse.RequestParser()
+
+        parser.add_argument("query", type=str, location="json")
+        parser.add_argument("retrieval_model", type=dict, required=False, location="json")
+        parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
+        return parser.parse_args()
+
+    @staticmethod
+    def perform_hit_testing(dataset, args):
+        try:
+            response = HitTestingService.retrieve(
+                dataset=dataset,
+                query=args["query"],
+                account=current_user,
+                retrieval_model=args["retrieval_model"],
+                external_retrieval_model=args["external_retrieval_model"],
+                limit=10,
+            )
+            return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
+        except services.errors.index.IndexNotInitializedError:
+            raise DatasetNotInitializedError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                "No Embedding Model or Reranking Model available. Please configure a valid provider "
+                "in the Settings -> Model Provider."
+            )
+        except InvokeError as e:
+            raise CompletionRequestError(e.description)
+        except ValueError as e:
+            raise ValueError(str(e))
+        except Exception as e:
+            logging.exception("Hit testing failed.")
+            raise InternalServerError(str(e))

+ 1 - 2
api/controllers/service_api/__init__.py

@@ -5,7 +5,6 @@ from libs.external_api import ExternalApi
 bp = Blueprint("service_api", __name__, url_prefix="/v1")
 api = ExternalApi(bp)
 
-
 from . import index
 from .app import app, audio, completion, conversation, file, message, workflow
-from .dataset import dataset, document, segment
+from .dataset import dataset, document, hit_testing, segment

+ 17 - 0
api/controllers/service_api/dataset/hit_testing.py

@@ -0,0 +1,17 @@
+from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
+from controllers.service_api import api
+from controllers.service_api.wraps import DatasetApiResource
+
+
+class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
+    def post(self, tenant_id, dataset_id):
+        dataset_id_str = str(dataset_id)
+
+        dataset = self.get_and_validate_dataset(dataset_id_str)
+        args = self.parse_args()
+        self.hit_testing_args_check(args)
+
+        return self.perform_hit_testing(dataset, args)
+
+
+api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")

+ 145 - 0
web/app/(commonLayout)/datasets/template/template.en.mdx

@@ -1050,6 +1050,151 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
 
 ---
 
+<Heading
+  url='/datasets/{dataset_id}/hit_testing'
+  method='POST'
+  title='Dataset hit testing'
+  name='#dataset_hit_testing'
+/>
+<Row>
+  <Col>
+    ### Path
+    <Properties>
+      <Property name='dataset_id' type='string' key='dataset_id'>
+        Dataset ID
+      </Property>
+    </Properties>
+
+    ### Request Body
+    <Properties>
+      <Property name='query' type='string' key='query'>
+        retrieval keywordc
+      </Property>
+      <Property name='retrieval_model' type='object' key='retrieval_model'>
+        retrieval keyword(Optional, if not filled, it will be recalled according to the default method)
+        - <code>search_method</code> (text) Search method: One of the following four keywords is required
+          - <code>keyword_search</code> Keyword search
+          - <code>semantic_search</code> Semantic search
+          - <code>full_text_search</code> Full-text search
+          - <code>hybrid_search</code> Hybrid search
+        - <code>reranking_enable</code> (bool) Whether to enable reranking, optional, required if the search mode is semantic_search or hybrid_search
+        - <code>reranking_mode</code> (object) Rerank model configuration, optional, required if reranking is enabled
+            - <code>reranking_provider_name</code> (string) Rerank model provider
+            - <code>reranking_model_name</code> (string) Rerank model name
+        - <code>weights</code> (double) Semantic search weight setting in hybrid search mode
+        - <code>top_k</code> (integer) Number of results to return, optional
+        - <code>score_threshold_enabled</code> (bool) Whether to enable score threshold
+        - <code>score_threshold</code> (double) Score threshold
+      </Property>
+      <Property name='external_retrieval_model' type='object' key='external_retrieval_model'>
+          Unused field
+      </Property>
+    </Properties>
+  </Col>
+  <Col sticky>
+    <CodeGroup
+      title="Request"
+      tag="POST"
+      label="/datasets/{dataset_id}/hit_testing"
+      targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{
+        "query": "test",
+        "retrieval_model": {
+            "search_method": "keyword_search",
+            "reranking_enable": false,
+            "reranking_mode": null,
+            "reranking_model": {
+                "reranking_provider_name": "",
+                "reranking_model_name": ""
+            },
+            "weights": null,
+            "top_k": 1,
+            "score_threshold_enabled": false,
+            "score_threshold": null
+        }
+    }'`}
+    >
+    ```bash {{ title: 'cURL' }}
+    curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \
+    --header 'Authorization: Bearer {api_key}' \
+    --header 'Content-Type: application/json' \
+    --data-raw '{
+        "query": "test",
+        "retrieval_model": {
+            "search_method": "keyword_search",
+            "reranking_enable": false,
+            "reranking_mode": null,
+            "reranking_model": {
+                "reranking_provider_name": "",
+                "reranking_model_name": ""
+            },
+            "weights": null,
+            "top_k": 2,
+            "score_threshold_enabled": false,
+            "score_threshold": null
+        }
+    }'
+    ```
+    </CodeGroup>
+    <CodeGroup title="Response">
+    ```json {{ title: 'Response' }}
+    {
+      "query": {
+        "content": "test"
+      },
+      "records": [
+        {
+          "segment": {
+            "id": "7fa6f24f-8679-48b3-bc9d-bdf28d73f218",
+            "position": 1,
+            "document_id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
+            "content": "Operation guide",
+            "answer": null,
+            "word_count": 847,
+            "tokens": 280,
+            "keywords": [
+              "install",
+              "java",
+              "base",
+              "scripts",
+              "jdk",
+              "manual",
+              "internal",
+              "opens",
+              "add",
+              "vmoptions"
+            ],
+            "index_node_id": "39dd8443-d960-45a8-bb46-7275ad7fbc8e",
+            "index_node_hash": "0189157697b3c6a418ccf8264a09699f25858975578f3467c76d6bfc94df1d73",
+            "hit_count": 0,
+            "enabled": true,
+            "disabled_at": null,
+            "disabled_by": null,
+            "status": "completed",
+            "created_by": "dbcb1ab5-90c8-41a7-8b78-73b235eb6f6f",
+            "created_at": 1728734540,
+            "indexing_at": 1728734552,
+            "completed_at": 1728734584,
+            "error": null,
+            "stopped_at": null,
+            "document": {
+              "id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
+              "data_source_type": "upload_file",
+              "name": "readme.txt",
+              "doc_type": null
+            }
+          },
+          "score": 3.730463140527718e-05,
+          "tsne_position": null
+        }
+      ]
+    }
+    ```
+    </CodeGroup>
+  </Col>
+</Row>
+
+---
+
 <Row>
   <Col>
     ### Error message

+ 146 - 0
web/app/(commonLayout)/datasets/template/template.zh.mdx

@@ -1049,6 +1049,152 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
   </Col>
 </Row>
 
+---
+
+<Heading
+  url='/datasets/{dataset_id}/hit_testing'
+  method='POST'
+  title='知识库召回测试'
+  name='#dataset_hit_testing'
+/>
+<Row>
+  <Col>
+    ### Path
+    <Properties>
+      <Property name='dataset_id' type='string' key='dataset_id'>
+        知识库 ID
+      </Property>
+    </Properties>
+
+    ### Request Body
+    <Properties>
+      <Property name='query' type='string' key='query'>
+        召回关键词
+      </Property>
+      <Property name='retrieval_model' type='object' key='retrieval_model'>
+        召回参数(选填,如不填,按照默认方式召回)
+        - <code>search_method</code> (text) 检索方法:以下三个关键字之一,必填
+          - <code>keyword_search</code> 关键字检索
+          - <code>semantic_search</code> 语义检索
+          - <code>full_text_search</code> 全文检索
+          - <code>hybrid_search</code> 混合检索
+        - <code>reranking_enable</code> (bool) 是否启用 Reranking,非必填,如果检索模式为semantic_search模式或者hybrid_search则传值
+        - <code>reranking_mode</code> (object) Rerank模型配置,非必填,如果启用了 reranking 则传值
+            - <code>reranking_provider_name</code> (string) Rerank 模型提供商
+            - <code>reranking_model_name</code> (string) Rerank 模型名称
+        - <code>weights</code> (double) 混合检索模式下语意检索的权重设置
+        - <code>top_k</code> (integer) 返回结果数量,非必填
+        - <code>score_threshold_enabled</code> (bool) 是否开启Score阈值
+        - <code>score_threshold</code> (double) Score阈值
+      </Property>
+      <Property name='external_retrieval_model' type='object' key='external_retrieval_model'>
+          未启用字段
+      </Property>
+    </Properties>
+  </Col>
+  <Col sticky>
+    <CodeGroup
+      title="Request"
+      tag="POST"
+      label="/datasets/{dataset_id}/hit_testing"
+      targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{
+        "query": "test",
+        "retrieval_model": {
+            "search_method": "keyword_search",
+            "reranking_enable": false,
+            "reranking_mode": null,
+            "reranking_model": {
+                "reranking_provider_name": "",
+                "reranking_model_name": ""
+            },
+            "weights": null,
+            "top_k": 1,
+            "score_threshold_enabled": false,
+            "score_threshold": null
+        }
+    }'`}
+    >
+    ```bash {{ title: 'cURL' }}
+    curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \
+    --header 'Authorization: Bearer {api_key}' \
+    --header 'Content-Type: application/json' \
+    --data-raw '{
+        "query": "test",
+        "retrieval_model": {
+            "search_method": "keyword_search",
+            "reranking_enable": false,
+            "reranking_mode": null,
+            "reranking_model": {
+                "reranking_provider_name": "",
+                "reranking_model_name": ""
+            },
+            "weights": null,
+            "top_k": 2,
+            "score_threshold_enabled": false,
+            "score_threshold": null
+        }
+    }'
+    ```
+    </CodeGroup>
+    <CodeGroup title="Response">
+    ```json {{ title: 'Response' }}
+    {
+      "query": {
+        "content": "test"
+      },
+      "records": [
+        {
+          "segment": {
+            "id": "7fa6f24f-8679-48b3-bc9d-bdf28d73f218",
+            "position": 1,
+            "document_id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
+            "content": "Operation guide",
+            "answer": null,
+            "word_count": 847,
+            "tokens": 280,
+            "keywords": [
+              "install",
+              "java",
+              "base",
+              "scripts",
+              "jdk",
+              "manual",
+              "internal",
+              "opens",
+              "add",
+              "vmoptions"
+            ],
+            "index_node_id": "39dd8443-d960-45a8-bb46-7275ad7fbc8e",
+            "index_node_hash": "0189157697b3c6a418ccf8264a09699f25858975578f3467c76d6bfc94df1d73",
+            "hit_count": 0,
+            "enabled": true,
+            "disabled_at": null,
+            "disabled_by": null,
+            "status": "completed",
+            "created_by": "dbcb1ab5-90c8-41a7-8b78-73b235eb6f6f",
+            "created_at": 1728734540,
+            "indexing_at": 1728734552,
+            "completed_at": 1728734584,
+            "error": null,
+            "stopped_at": null,
+            "document": {
+              "id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
+              "data_source_type": "upload_file",
+              "name": "readme.txt",
+              "doc_type": null
+            }
+          },
+          "score": 3.730463140527718e-05,
+          "tsne_position": null
+        }
+      ]
+    }
+    ```
+    </CodeGroup>
+  </Col>
+</Row>
+
+
 ---
 
 <Row>