浏览代码

Support knowledge metadata filter (#15982)

Jyong 1 月之前
父节点
当前提交
abeaea4f79
共有 48 个文件被更改,包括 2399 次插入483 次删除
  1. 1 0
      api/controllers/console/__init__.py
  2. 2 2
      api/controllers/console/datasets/datasets_document.py
  3. 155 0
      api/controllers/console/datasets/metadata.py
  4. 24 1
      api/core/app/app_config/easy_ui_based_app/dataset/manager.py
  5. 54 1
      api/core/app/app_config/entities.py
  6. 1 0
      api/core/app/apps/chat/app_runner.py
  7. 1 0
      api/core/app/apps/completion/app_runner.py
  8. 6 5
      api/core/rag/datasource/keyword/jieba/jieba.py
  9. 17 2
      api/core/rag/datasource/retrieval_service.py
  10. 1 1
      api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
  11. 12 2
      api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py
  12. 15 5
      api/core/rag/datasource/vdb/baidu/baidu_vector.py
  13. 9 1
      api/core/rag/datasource/vdb/chroma/chroma_vector.py
  14. 6 0
      api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
  15. 10 2
      api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
  16. 12 0
      api/core/rag/datasource/vdb/milvus/milvus_vector.py
  17. 4 0
      api/core/rag/datasource/vdb/myscale/myscale_vector.py
  18. 6 0
      api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
  19. 6 0
      api/core/rag/datasource/vdb/opensearch/opensearch_vector.py
  20. 13 2
      api/core/rag/datasource/vdb/oracle/oraclevector.py
  21. 3 0
      api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
  22. 13 0
      api/core/rag/datasource/vdb/pgvector/pgvector.py
  23. 38 21
      api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
  24. 7 3
      api/core/rag/datasource/vdb/relyt/relyt_vector.py
  25. 6 1
      api/core/rag/datasource/vdb/tencent/tencent_vector.py
  26. 24 0
      api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
  27. 6 0
      api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
  28. 14 1
      api/core/rag/datasource/vdb/upstash/upstash_vector.py
  29. 5 1
      api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
  30. 8 4
      api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
  31. 45 0
      api/core/rag/entities/metadata_entities.py
  32. 15 0
      api/core/rag/index_processor/constant/built_in_field.py
  33. 429 7
      api/core/rag/retrieval/dataset_retrieval.py
  34. 66 0
      api/core/rag/retrieval/template_prompts.py
  35. 49 1
      api/core/workflow/nodes/knowledge_retrieval/entities.py
  36. 4 0
      api/core/workflow/nodes/knowledge_retrieval/exc.py
  37. 268 22
      api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
  38. 66 0
      api/core/workflow/nodes/knowledge_retrieval/template_prompts.py
  39. 10 0
      api/fields/dataset_fields.py
  40. 9 0
      api/fields/document_fields.py
  41. 90 0
      api/migrations/versions/2025_02_27_0917-d20049ed0af6_add_metadata_function.py
  42. 175 1
      api/models/dataset.py
  43. 355 392
      api/poetry.lock
  44. 57 3
      api/services/dataset_service.py
  45. 33 0
      api/services/entities/knowledge_entities/knowledge_entities.py
  46. 7 1
      api/services/external_knowledge_service.py
  47. 241 0
      api/services/metadata_service.py
  48. 1 1
      api/services/tag_service.py

+ 1 - 0
api/controllers/console/__init__.py

@@ -81,6 +81,7 @@ from .datasets import (
     datasets_segments,
     external,
     hit_testing,
+    metadata,
     website,
 )
 

+ 2 - 2
api/controllers/console/datasets/datasets_document.py

@@ -621,7 +621,7 @@ class DocumentDetailApi(DocumentResource):
             raise InvalidMetadataError(f"Invalid metadata value: {metadata}")
 
         if metadata == "only":
-            response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata}
+            response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
         elif metadata == "without":
             dataset_process_rules = DatasetService.get_process_rules(dataset_id)
             document_process_rules = document.dataset_process_rule.to_dict()
@@ -682,7 +682,7 @@ class DocumentDetailApi(DocumentResource):
                 "disabled_by": document.disabled_by,
                 "archived": document.archived,
                 "doc_type": document.doc_type,
-                "doc_metadata": document.doc_metadata,
+                "doc_metadata": document.doc_metadata_details,
                 "segment_count": document.segment_count,
                 "average_segment_length": document.average_segment_length,
                 "hit_count": document.hit_count,

+ 155 - 0
api/controllers/console/datasets/metadata.py

@@ -0,0 +1,155 @@
+from flask_login import current_user  # type: ignore  # type: ignore
+from flask_restful import Resource, marshal_with, reqparse  # type: ignore
+from werkzeug.exceptions import NotFound
+
+from controllers.console import api
+from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
+from fields.dataset_fields import dataset_metadata_fields
+from libs.login import login_required
+from services.dataset_service import DatasetService
+from services.entities.knowledge_entities.knowledge_entities import (
+    MetadataArgs,
+    MetadataOperationData,
+)
+from services.metadata_service import MetadataService
+
+
+def _validate_name(name):
+    if not name or len(name) < 1 or len(name) > 40:
+        raise ValueError("Name must be between 1 to 40 characters.")
+    return name
+
+
+def _validate_description_length(description):
+    if len(description) > 400:
+        raise ValueError("Description cannot exceed 400 characters.")
+    return description
+
+
+class DatasetMetadataCreateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @enterprise_license_required
+    @marshal_with(dataset_metadata_fields)
+    def post(self, dataset_id):
+        parser = reqparse.RequestParser()
+        parser.add_argument("type", type=str, required=True, nullable=True, location="json")
+        parser.add_argument("name", type=str, required=True, nullable=True, location="json")
+        args = parser.parse_args()
+        metadata_args = MetadataArgs(**args)
+
+        dataset_id_str = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        DatasetService.check_dataset_permission(dataset, current_user)
+
+        metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
+        return metadata, 201
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @enterprise_license_required
+    def get(self, dataset_id):
+        dataset_id_str = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        return MetadataService.get_dataset_metadatas(dataset), 200
+
+
+class DatasetMetadataApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @enterprise_license_required
+    @marshal_with(dataset_metadata_fields)
+    def patch(self, dataset_id, metadata_id):
+        parser = reqparse.RequestParser()
+        parser.add_argument("name", type=str, required=True, nullable=True, location="json")
+        args = parser.parse_args()
+
+        dataset_id_str = str(dataset_id)
+        metadata_id_str = str(metadata_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        DatasetService.check_dataset_permission(dataset, current_user)
+
+        metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
+        return metadata, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @enterprise_license_required
+    def delete(self, dataset_id, metadata_id):
+        dataset_id_str = str(dataset_id)
+        metadata_id_str = str(metadata_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        DatasetService.check_dataset_permission(dataset, current_user)
+
+        MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
+        return 200
+
+
+class DatasetMetadataBuiltInFieldApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @enterprise_license_required
+    def get(self):
+        built_in_fields = MetadataService.get_built_in_fields()
+        return {"fields": built_in_fields}, 200
+
+
+class DatasetMetadataBuiltInFieldActionApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @enterprise_license_required
+    def post(self, dataset_id, action):
+        dataset_id_str = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        DatasetService.check_dataset_permission(dataset, current_user)
+
+        if action == "enable":
+            MetadataService.enable_built_in_field(dataset)
+        elif action == "disable":
+            MetadataService.disable_built_in_field(dataset)
+        return 200
+
+
+class DocumentMetadataEditApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @enterprise_license_required
+    def post(self, dataset_id):
+        dataset_id_str = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        DatasetService.check_dataset_permission(dataset, current_user)
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json")
+        args = parser.parse_args()
+        metadata_args = MetadataOperationData(**args)
+
+        MetadataService.update_documents_metadata(dataset, metadata_args)
+
+        return 200
+
+
+api.add_resource(DatasetMetadataCreateApi, "/datasets/<uuid:dataset_id>/metadata")
+api.add_resource(DatasetMetadataApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
+api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in")
+api.add_resource(DatasetMetadataBuiltInFieldActionApi, "/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
+api.add_resource(DocumentMetadataEditApi, "/datasets/<uuid:dataset_id>/documents/metadata")

+ 24 - 1
api/core/app/app_config/easy_ui_based_app/dataset/manager.py

@@ -1,7 +1,12 @@
 import uuid
 from typing import Optional
 
-from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
+from core.app.app_config.entities import (
+    DatasetEntity,
+    DatasetRetrieveConfigEntity,
+    MetadataFilteringCondition,
+    ModelConfig,
+)
 from core.entities.agent_entities import PlanningStrategy
 from models.model import AppMode
 from services.dataset_service import DatasetService
@@ -78,6 +83,15 @@ class DatasetConfigManager:
                     retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                         dataset_configs["retrieval_model"]
                     ),
+                    metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
+                    metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
+                    if dataset_configs.get("metadata_model_config")
+                    else None,
+                    metadata_filtering_conditions=MetadataFilteringCondition(
+                        **dataset_configs.get("metadata_filtering_conditions", {})
+                    )
+                    if dataset_configs.get("metadata_filtering_conditions")
+                    else None,
                 ),
             )
         else:
@@ -96,6 +110,15 @@ class DatasetConfigManager:
                     weights=dataset_configs.get("weights"),
                     reranking_enabled=dataset_configs.get("reranking_enabled", True),
                     rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
+                    metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
+                    metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
+                    if dataset_configs.get("metadata_model_config")
+                    else None,
+                    metadata_filtering_conditions=MetadataFilteringCondition(
+                        **dataset_configs.get("metadata_filtering_conditions", {})
+                    )
+                    if dataset_configs.get("metadata_filtering_conditions")
+                    else None,
                 ),
             )
 

+ 54 - 1
api/core/app/app_config/entities.py

@@ -1,10 +1,11 @@
 from collections.abc import Sequence
 from enum import Enum, StrEnum
-from typing import Any, Optional
+from typing import Any, Literal, Optional
 
 from pydantic import BaseModel, Field, field_validator
 
 from core.file import FileTransferMethod, FileType, FileUploadConfig
+from core.model_runtime.entities.llm_entities import LLMMode
 from core.model_runtime.entities.message_entities import PromptMessageRole
 from models.model import AppMode
 
@@ -135,6 +136,55 @@ class ExternalDataVariableEntity(BaseModel):
     config: dict[str, Any] = Field(default_factory=dict)
 
 
+SupportedComparisonOperator = Literal[
+    # for string or array
+    "contains",
+    "not contains",
+    "start with",
+    "end with",
+    "is",
+    "is not",
+    "empty",
+    "not empty",
+    # for number
+    "=",
+    "≠",
+    ">",
+    "<",
+    "≥",
+    "≤",
+    # for time
+    "before",
+    "after",
+]
+
+
+class ModelConfig(BaseModel):
+    provider: str
+    name: str
+    mode: LLMMode
+    completion_params: dict[str, Any] = {}
+
+
+class Condition(BaseModel):
+    """
+    Conditon detail
+    """
+
+    name: str
+    comparison_operator: SupportedComparisonOperator
+    value: str | Sequence[str] | None | int | float = None
+
+
+class MetadataFilteringCondition(BaseModel):
+    """
+    Metadata Filtering Condition.
+    """
+
+    logical_operator: Optional[Literal["and", "or"]] = "and"
+    conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
+
+
 class DatasetRetrieveConfigEntity(BaseModel):
     """
     Dataset Retrieve Config Entity.
@@ -171,6 +221,9 @@ class DatasetRetrieveConfigEntity(BaseModel):
     reranking_model: Optional[dict] = None
     weights: Optional[dict] = None
     reranking_enabled: Optional[bool] = True
+    metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
+    metadata_model_config: Optional[ModelConfig] = None
+    metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
 
 
 class DatasetEntity(BaseModel):

+ 1 - 0
api/core/app/apps/chat/app_runner.py

@@ -180,6 +180,7 @@ class ChatAppRunner(AppRunner):
                 hit_callback=hit_callback,
                 memory=memory,
                 message_id=message.id,
+                inputs=inputs,
             )
 
         # reorganize all inputs and template to prompt messages

+ 1 - 0
api/core/app/apps/completion/app_runner.py

@@ -139,6 +139,7 @@ class CompletionAppRunner(AppRunner):
                 show_retrieve_source=app_config.additional_features.show_retrieve_source,
                 hit_callback=hit_callback,
                 message_id=message.id,
+                inputs=inputs,
             )
 
         # reorganize all inputs and template to prompt messages

+ 6 - 5
api/core/rag/datasource/keyword/jieba/jieba.py

@@ -88,16 +88,17 @@ class Jieba(BaseKeyword):
         keyword_table = self._get_dataset_keyword_table()
 
         k = kwargs.get("top_k", 4)
-
+        document_ids_filter = kwargs.get("document_ids_filter")
         sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
 
         documents = []
         for chunk_index in sorted_chunk_indices:
-            segment = (
-                db.session.query(DocumentSegment)
-                .filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index)
-                .first()
+            segment_query = db.session.query(DocumentSegment).filter(
+                DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
             )
+            if document_ids_filter:
+                segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter))
+            segment = segment_query.first()
 
             if segment:
                 documents.append(

+ 17 - 2
api/core/rag/datasource/retrieval_service.py

@@ -41,6 +41,7 @@ class RetrievalService:
         reranking_model: Optional[dict] = None,
         reranking_mode: str = "reranking_model",
         weights: Optional[dict] = None,
+        document_ids_filter: Optional[list[str]] = None,
     ):
         if not query:
             return []
@@ -64,6 +65,7 @@ class RetrievalService:
                         top_k=top_k,
                         all_documents=all_documents,
                         exceptions=exceptions,
+                        document_ids_filter=document_ids_filter,
                     )
                 )
             if RetrievalMethod.is_support_semantic_search(retrieval_method):
@@ -79,6 +81,7 @@ class RetrievalService:
                         all_documents=all_documents,
                         retrieval_method=retrieval_method,
                         exceptions=exceptions,
+                        document_ids_filter=document_ids_filter,
                     )
                 )
             if RetrievalMethod.is_support_fulltext_search(retrieval_method):
@@ -130,7 +133,14 @@ class RetrievalService:
 
     @classmethod
     def keyword_search(
-        cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
+        cls,
+        flask_app: Flask,
+        dataset_id: str,
+        query: str,
+        top_k: int,
+        all_documents: list,
+        exceptions: list,
+        document_ids_filter: Optional[list[str]] = None,
     ):
         with flask_app.app_context():
             try:
@@ -139,7 +149,10 @@ class RetrievalService:
                     raise ValueError("dataset not found")
 
                 keyword = Keyword(dataset=dataset)
-                documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k)
+
+                documents = keyword.search(
+                    cls.escape_query_for_search(query), top_k=top_k, document_ids_filter=document_ids_filter
+                )
                 all_documents.extend(documents)
             except Exception as e:
                 exceptions.append(str(e))
@@ -156,6 +169,7 @@ class RetrievalService:
         all_documents: list,
         retrieval_method: str,
         exceptions: list,
+        document_ids_filter: Optional[list[str]] = None,
     ):
         with flask_app.app_context():
             try:
@@ -170,6 +184,7 @@ class RetrievalService:
                     top_k=top_k,
                     score_threshold=score_threshold,
                     filter={"group_id": [dataset.id]},
+                    document_ids_filter=document_ids_filter,
                 )
 
                 if documents:

+ 1 - 1
api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py

@@ -53,7 +53,7 @@ class AnalyticdbVector(BaseVector):
         self.analyticdb_vector.delete_by_metadata_field(key, value)
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-        return self.analyticdb_vector.search_by_vector(query_vector)
+        return self.analyticdb_vector.search_by_vector(query_vector, **kwargs)
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         return self.analyticdb_vector.search_by_full_text(query, **kwargs)

+ 12 - 2
api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py

@@ -196,6 +196,11 @@ class AnalyticdbVectorBySql:
         top_k = kwargs.get("top_k", 4)
         if not isinstance(top_k, int) or top_k <= 0:
             raise ValueError("top_k must be a positive integer")
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = "WHERE 1=1"
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
         with self._get_cursor() as cur:
             query_vector_str = json.dumps(query_vector)
@@ -204,7 +209,7 @@ class AnalyticdbVectorBySql:
                 f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
                 f"t.page_content as page_content, t.metadata_ AS metadata_ "
                 f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
-                f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t",
+                f"FROM {self.table_name} {where_clause} ORDER BY score LIMIT {top_k} ) t",
                 (query_vector_str,),
             )
             documents = []
@@ -224,12 +229,17 @@ class AnalyticdbVectorBySql:
         top_k = kwargs.get("top_k", 4)
         if not isinstance(top_k, int) or top_k <= 0:
             raise ValueError("top_k must be a positive integer")
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = ""
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
         with self._get_cursor() as cur:
             cur.execute(
                 f"""SELECT id, vector, page_content, metadata_, 
                 ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
                 FROM {self.table_name}
-                WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn')
+                WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
                 ORDER BY score DESC
                 LIMIT {top_k}""",
                 (f"'{query}'", f"'{query}'"),

+ 15 - 5
api/core/rag/datasource/vdb/baidu/baidu_vector.py

@@ -123,11 +123,21 @@ class BaiduVector(BaseVector):
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
-        anns = AnnSearch(
-            vector_field=self.field_vector,
-            vector_floats=query_vector,
-            params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
-        )
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            anns = AnnSearch(
+                vector_field=self.field_vector,
+                vector_floats=query_vector,
+                params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
+                filter=f"document_id IN ({document_ids})",
+            )
+        else:
+            anns = AnnSearch(
+                vector_field=self.field_vector,
+                vector_floats=query_vector,
+                params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
+            )
         res = self._db.table(self._collection_name).search(
             anns=anns,
             projections=[self.field_id, self.field_text, self.field_metadata],

+ 9 - 1
api/core/rag/datasource/vdb/chroma/chroma_vector.py

@@ -95,7 +95,15 @@ class ChromaVector(BaseVector):
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         collection = self._client.get_or_create_collection(self._collection_name)
-        results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            results: QueryResult = collection.query(
+                query_embeddings=query_vector,
+                n_results=kwargs.get("top_k", 4),
+                where={"document_id": {"$in": document_ids_filter}},  # type: ignore
+            )
+        else:
+            results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))  # type: ignore
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
 
         # Check if results contain data

+ 6 - 0
api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py

@@ -117,6 +117,9 @@ class ElasticSearchVector(BaseVector):
         top_k = kwargs.get("top_k", 4)
         num_candidates = math.ceil(top_k * 1.5)
         knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            knn["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}
 
         results = self._client.search(index=self._collection_name, knn=knn, size=top_k)
 
@@ -145,6 +148,9 @@ class ElasticSearchVector(BaseVector):
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         query_str = {"match": {Field.CONTENT_KEY.value: query}}
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            query_str["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}  # type: ignore
         results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
         docs = []
         for hit in results["hits"]["hits"]:

+ 10 - 2
api/core/rag/datasource/vdb/lindorm/lindorm_vector.py

@@ -168,7 +168,12 @@ class LindormVectorStore(BaseVector):
             raise ValueError("All elements in query_vector should be floats")
 
         top_k = kwargs.get("top_k", 10)
-        query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
+        document_ids_filter = kwargs.get("document_ids_filter")
+        filters = []
+        if document_ids_filter:
+            filters.append({"terms": {"metadata.document_id": document_ids_filter}})
+        query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs)
+
         try:
             params = {}
             if self._using_ugc:
@@ -206,7 +211,10 @@ class LindormVectorStore(BaseVector):
         should = kwargs.get("should")
         minimum_should_match = kwargs.get("minimum_should_match", 0)
         top_k = kwargs.get("top_k", 10)
-        filters = kwargs.get("filter")
+        filters = kwargs.get("filter", [])
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            filters.append({"terms": {"metadata.document_id": document_ids_filter}})
         routing = self._routing
         full_text_query = default_text_search_query(
             query_text=query,

+ 12 - 0
api/core/rag/datasource/vdb/milvus/milvus_vector.py

@@ -228,12 +228,18 @@ class MilvusVector(BaseVector):
         """
         Search for documents by vector similarity.
         """
+        document_ids_filter = kwargs.get("document_ids_filter")
+        filter = ""
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            filter = f'metadata["document_id"] in ({document_ids})'
         results = self._client.search(
             collection_name=self._collection_name,
             data=[query_vector],
             anns_field=Field.VECTOR.value,
             limit=kwargs.get("top_k", 4),
             output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
+            filter=filter,
         )
 
         return self._process_search_results(
@@ -249,6 +255,11 @@ class MilvusVector(BaseVector):
         if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
             logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
             return []
+        document_ids_filter = kwargs.get("document_ids_filter")
+        filter = ""
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            filter = f'metadata["document_id"] in ({document_ids})'
 
         results = self._client.search(
             collection_name=self._collection_name,
@@ -256,6 +267,7 @@ class MilvusVector(BaseVector):
             anns_field=Field.SPARSE_VECTOR.value,
             limit=kwargs.get("top_k", 4),
             output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
+            filter=filter,
         )
 
         return self._process_search_results(

+ 4 - 0
api/core/rag/datasource/vdb/myscale/myscale_vector.py

@@ -133,6 +133,10 @@ class MyScaleVector(BaseVector):
             if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
             else ""
         )
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            where_str = f"{where_str} AND metadata['document_id'] in ({document_ids})"
         sql = f"""
             SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
             {where_str} ORDER BY dist {order.value} LIMIT {top_k}

+ 6 - 0
api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py

@@ -154,6 +154,11 @@ class OceanBaseVector(BaseVector):
         return []
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = None
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            where_clause = f"metadata->>'$.document_id' in ({document_ids})"
         ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
         if ef_search != self._hnsw_ef_search:
             self._client.set_ob_hnsw_ef_search(ef_search)
@@ -167,6 +172,7 @@ class OceanBaseVector(BaseVector):
             distance_func=func.l2_distance,
             output_column_names=["text", "metadata"],
             with_dist=True,
+            where_clause=where_clause,
         )
         docs = []
         for text, metadata, distance in cur:

+ 6 - 0
api/core/rag/datasource/vdb/opensearch/opensearch_vector.py

@@ -154,6 +154,9 @@ class OpenSearchVector(BaseVector):
             "size": kwargs.get("top_k", 4),
             "query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}},
         }
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            query["query"] = {"terms": {"metadata.document_id": document_ids_filter}}
 
         try:
             response = self._client.search(index=self._collection_name.lower(), body=query)
@@ -179,6 +182,9 @@ class OpenSearchVector(BaseVector):
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}}
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            full_text_query["query"]["terms"] = {"metadata.document_id": document_ids_filter}
 
         response = self._client.search(index=self._collection_name.lower(), body=full_text_query)
 

+ 13 - 2
api/core/rag/datasource/vdb/oracle/oraclevector.py

@@ -201,10 +201,15 @@ class OracleVector(BaseVector):
         :return: List of Documents that are nearest to the query vector.
         """
         top_k = kwargs.get("top_k", 4)
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = ""
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
         with self._get_cursor() as cur:
             cur.execute(
                 f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
-                f" ORDER BY distance fetch first {top_k} rows only",
+                f" {where_clause} ORDER BY distance fetch first {top_k} rows only",
                 [numpy.array(query_vector)],
             )
             docs = []
@@ -257,9 +262,15 @@ class OracleVector(BaseVector):
                     if token not in stop_words:
                         entities.append(token)
             with self._get_cursor() as cur:
+                document_ids_filter = kwargs.get("document_ids_filter")
+                where_clause = ""
+                if document_ids_filter:
+                    document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+                    where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
                 cur.execute(
                     f"select meta, text, embedding FROM {self.table_name}"
-                    f" WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
+                    f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} "
+                    f"order by score(1) desc fetch first {top_k} rows only",
                     [" ACCUM ".join(entities)],
                 )
                 docs = []

+ 3 - 0
api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py

@@ -189,6 +189,9 @@ class PGVectoRS(BaseVector):
                 .limit(kwargs.get("top_k", 4))
                 .order_by("distance")
             )
+            document_ids_filter = kwargs.get("document_ids_filter")
+            if document_ids_filter:
+                stmt = stmt.where(self._table.meta["document_id"].in_(document_ids_filter))
             res = session.execute(stmt)
             results = [(row[0], row[1]) for row in res]
 

+ 13 - 0
api/core/rag/datasource/vdb/pgvector/pgvector.py

@@ -173,10 +173,16 @@ class PGVector(BaseVector):
         top_k = kwargs.get("top_k", 4)
         if not isinstance(top_k, int) or top_k <= 0:
             raise ValueError("top_k must be a positive integer")
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = ""
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            where_clause = f" WHERE metadata->>'document_id' in ({document_ids}) "
 
         with self._get_cursor() as cur:
             cur.execute(
                 f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
+                f" {where_clause}"
                 f" ORDER BY distance LIMIT {top_k}",
                 (json.dumps(query_vector),),
             )
@@ -195,12 +201,18 @@ class PGVector(BaseVector):
         if not isinstance(top_k, int) or top_k <= 0:
             raise ValueError("top_k must be a positive integer")
         with self._get_cursor() as cur:
+            document_ids_filter = kwargs.get("document_ids_filter")
+            where_clause = ""
+            if document_ids_filter:
+                document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+                where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
             if self.pg_bigm:
                 cur.execute("SET pg_bigm.similarity_limit TO 0.000001")
                 cur.execute(
                     f"""SELECT meta, text, bigm_similarity(unistr(%s), coalesce(text, '')) AS score
                     FROM {self.table_name}
                     WHERE text =%% unistr(%s)
+                    {where_clause}
                     ORDER BY score DESC
                     LIMIT {top_k}""",
                     # f"'{query}'" is required in order to account for whitespace in query
@@ -211,6 +223,7 @@ class PGVector(BaseVector):
                     f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
                     FROM {self.table_name}
                     WHERE to_tsvector(text) @@ plainto_tsquery(%s)
+                    {where_clause}
                     ORDER BY score DESC
                     LIMIT {top_k}""",
                     # f"'{query}'" is required in order to account for whitespace in query

+ 38 - 21
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py

@@ -286,27 +286,26 @@ class QdrantVector(BaseVector):
         from qdrant_client.http import models
         from qdrant_client.http.exceptions import UnexpectedResponse
 
-        for node_id in ids:
-            try:
-                filter = models.Filter(
-                    must=[
-                        models.FieldCondition(
-                            key="metadata.doc_id",
-                            match=models.MatchValue(value=node_id),
-                        ),
-                    ],
-                )
-                self._client.delete(
-                    collection_name=self._collection_name,
-                    points_selector=FilterSelector(filter=filter),
-                )
-            except UnexpectedResponse as e:
-                # Collection does not exist, so return
-                if e.status_code == 404:
-                    return
-                # Some other error occurred, so re-raise the exception
-                else:
-                    raise e
+        try:
+            filter = models.Filter(
+                must=[
+                    models.FieldCondition(
+                        key="metadata.doc_id",
+                        match=models.MatchAny(any=ids),
+                    ),
+                ],
+            )
+            self._client.delete(
+                collection_name=self._collection_name,
+                points_selector=FilterSelector(filter=filter),
+            )
+        except UnexpectedResponse as e:
+            # Collection does not exist, so return
+            if e.status_code == 404:
+                return
+            # Some other error occurred, so re-raise the exception
+            else:
+                raise e
 
     def text_exists(self, id: str) -> bool:
         all_collection_name = []
@@ -331,6 +330,15 @@ class QdrantVector(BaseVector):
                 ),
             ],
         )
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            if filter.must:
+                filter.must.append(
+                    models.FieldCondition(
+                        key="metadata.document_id",
+                        match=models.MatchAny(any=document_ids_filter),
+                    )
+                )
         results = self._client.search(
             collection_name=self._collection_name,
             query_vector=query_vector,
@@ -377,6 +385,15 @@ class QdrantVector(BaseVector):
                 ),
             ]
         )
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            if scroll_filter.must:
+                scroll_filter.must.append(
+                    models.FieldCondition(
+                        key="metadata.document_id",
+                        match=models.MatchAny(any=document_ids_filter),
+                    )
+                )
         response = self._client.scroll(
             collection_name=self._collection_name,
             scroll_filter=scroll_filter,

+ 7 - 3
api/core/rag/datasource/vdb/relyt/relyt_vector.py

@@ -223,8 +223,12 @@ class RelytVector(BaseVector):
         return len(result) > 0
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        document_ids_filter = kwargs.get("document_ids_filter")
+        filter = kwargs.get("filter", {})
+        if document_ids_filter:
+            filter["document_id"] = document_ids_filter
         results = self.similarity_search_with_score_by_vector(
-            k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=kwargs.get("filter")
+            k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=filter
         )
 
         # Organize results.
@@ -246,9 +250,9 @@ class RelytVector(BaseVector):
         filter_condition = ""
         if filter is not None:
             conditions = [
-                f"metadata->>{key!r} in ({', '.join(map(repr, value))})"
+                f"metadata->>'{key!r}' in ({', '.join(map(repr, value))})"
                 if len(value) > 1
-                else f"metadata->>{key!r} = {value[0]!r}"
+                else f"metadata->>'{key!r}' = {value[0]!r}"
                 for key, value in filter.items()
             ]
             filter_condition = f"WHERE {' AND '.join(conditions)}"

+ 6 - 1
api/core/rag/datasource/vdb/tencent/tencent_vector.py

@@ -145,11 +145,16 @@ class TencentVector(BaseVector):
         self._db.collection(self._collection_name).delete(document_ids=ids)
 
     def delete_by_metadata_field(self, key: str, value: str) -> None:
-        self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value])))
+        self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(f"metadata.{key}", [value])))
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        document_ids_filter = kwargs.get("document_ids_filter")
+        filter = None
+        if document_ids_filter:
+            filter = Filter(Filter.In("metadata.document_id", document_ids_filter))
         res = self._db.collection(self._collection_name).search(
             vectors=[query_vector],
+            filter=filter,
             params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
             retrieve_vector=False,
             limit=kwargs.get("top_k", 4),

+ 24 - 0
api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py

@@ -326,6 +326,18 @@ class TidbOnQdrantVector(BaseVector):
                 ),
             ],
         )
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            should_conditions = []
+            for document_id_filter in document_ids_filter:
+                should_conditions.append(
+                    models.FieldCondition(
+                        key="metadata.document_id",
+                        match=models.MatchValue(value=document_id_filter),
+                    )
+                )
+            if should_conditions:
+                filter.should = should_conditions  # type: ignore
         results = self._client.search(
             collection_name=self._collection_name,
             query_vector=query_vector,
@@ -368,6 +380,18 @@ class TidbOnQdrantVector(BaseVector):
                 )
             ]
         )
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            should_conditions = []
+            for document_id_filter in document_ids_filter:
+                should_conditions.append(
+                    models.FieldCondition(
+                        key="metadata.document_id",
+                        match=models.MatchValue(value=document_id_filter),
+                    )
+                )
+            if should_conditions:
+                scroll_filter.should = should_conditions  # type: ignore
         response = self._client.scroll(
             collection_name=self._collection_name,
             scroll_filter=scroll_filter,

+ 6 - 0
api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py

@@ -196,6 +196,11 @@ class TiDBVector(BaseVector):
 
         docs = []
         tidb_dist_func = self._get_distance_func()
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = ""
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            where_clause = f" WHERE meta->>'$.document_id' in ({document_ids}) "
 
         with Session(self._engine) as session:
             select_statement = sql_text(f"""
@@ -206,6 +211,7 @@ class TiDBVector(BaseVector):
                     text,
                     {tidb_dist_func}(vector, :query_vector_str) AS distance
                   FROM {self._collection_name}
+                  {where_clause}
                   ORDER BY distance ASC
                   LIMIT :top_k
                 ) t

+ 14 - 1
api/core/rag/datasource/vdb/upstash/upstash_vector.py

@@ -88,7 +88,20 @@ class UpstashVector(BaseVector):
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 4)
-        result = self.index.query(vector=query_vector, top_k=top_k, include_metadata=True, include_data=True)
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            filter = f"document_id in ({document_ids})"
+        else:
+            filter = ""
+        result = self.index.query(
+            vector=query_vector,
+            top_k=top_k,
+            include_metadata=True,
+            include_data=True,
+            include_vectors=False,
+            filter=filter,
+        )
         docs = []
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
         for record in result:

+ 5 - 1
api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py

@@ -177,7 +177,11 @@ class VikingDBVector(BaseVector):
             query_vector, limit=kwargs.get("top_k", 4)
         )
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
-        return self._get_search_res(results, score_threshold)
+        docs = self._get_search_res(results, score_threshold)
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            docs = [doc for doc in docs if doc.metadata.get("document_id") in document_ids_filter]
+        return docs
 
     def _get_search_res(self, results, score_threshold) -> list[Document]:
         if len(results) == 0:

+ 8 - 4
api/core/rag/datasource/vdb/weaviate/weaviate_vector.py

@@ -187,8 +187,10 @@ class WeaviateVector(BaseVector):
         query_obj = self._client.query.get(collection_name, properties)
 
         vector = {"vector": query_vector}
-        if kwargs.get("where_filter"):
-            query_obj = query_obj.with_where(kwargs.get("where_filter"))
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter}
+            query_obj = query_obj.with_where(where_filter)
         result = (
             query_obj.with_near_vector(vector)
             .with_limit(kwargs.get("top_k", 4))
@@ -233,8 +235,10 @@ class WeaviateVector(BaseVector):
         if kwargs.get("search_distance"):
             content["certainty"] = kwargs.get("search_distance")
         query_obj = self._client.query.get(collection_name, properties)
-        if kwargs.get("where_filter"):
-            query_obj = query_obj.with_where(kwargs.get("where_filter"))
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter}
+            query_obj = query_obj.with_where(where_filter)
         query_obj = query_obj.with_additional(["vector"])
         properties = ["text"]
         result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do()

+ 45 - 0
api/core/rag/entities/metadata_entities.py

@@ -0,0 +1,45 @@
+from collections.abc import Sequence
+from typing import Literal, Optional
+
+from pydantic import BaseModel, Field
+
+SupportedComparisonOperator = Literal[
+    # for string or array
+    "contains",
+    "not contains",
+    "start with",
+    "end with",
+    "is",
+    "is not",
+    "empty",
+    "not empty",
+    # for number
+    "=",
+    "≠",
+    ">",
+    "<",
+    "≥",
+    "≤",
+    # for time
+    "before",
+    "after",
+]
+
+
+class Condition(BaseModel):
+    """
+    Conditon detail
+    """
+
+    name: str
+    comparison_operator: SupportedComparisonOperator
+    value: str | Sequence[str] | None | int | float = None
+
+
+class MetadataCondition(BaseModel):
+    """
+    Metadata Condition.
+    """
+
+    logical_operator: Optional[Literal["and", "or"]] = "and"
+    conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)

+ 15 - 0
api/core/rag/index_processor/constant/built_in_field.py

@@ -0,0 +1,15 @@
+from enum import Enum
+
+
+class BuiltInField(str, Enum):
+    document_name = "document_name"
+    uploader = "uploader"
+    upload_date = "upload_date"
+    last_update_date = "last_update_date"
+    source = "source"
+
+
+class MetadataDataSource(Enum):
+    upload_file = "file_upload"
+    website_crawl = "website"
+    notion_import = "notion"

+ 429 - 7
api/core/rag/retrieval/dataset_retrieval.py

@@ -1,35 +1,61 @@
+import json
 import math
+import re
 import threading
-from collections import Counter
-from typing import Any, Optional, cast
+from collections import Counter, defaultdict
+from collections.abc import Generator, Mapping
+from typing import Any, Optional, Union, cast
 
 from flask import Flask, current_app
-
-from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
+from sqlalchemy import Integer, and_, or_, text
+from sqlalchemy import cast as sqlalchemy_cast
+
+from core.app.app_config.entities import (
+    DatasetEntity,
+    DatasetRetrieveConfigEntity,
+    MetadataFilteringCondition,
+    ModelConfig,
+)
 from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.agent_entities import PlanningStrategy
+from core.entities.model_entities import ModelStatus
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.message_entities import PromptMessageTool
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
 from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.ops.entities.trace_entity import TraceTaskName
 from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.ops.utils import measure_time
+from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
+from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
+from core.prompt.simple_prompt_transform import ModelMode
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.entities.context_entities import DocumentContext
+from core.rag.entities.metadata_entities import Condition, MetadataCondition
 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.models.document import Document
 from core.rag.rerank.rerank_type import RerankMode
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
 from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
+from core.rag.retrieval.template_prompts import (
+    METADATA_FILTER_ASSISTANT_PROMPT_1,
+    METADATA_FILTER_ASSISTANT_PROMPT_2,
+    METADATA_FILTER_COMPLETION_PROMPT,
+    METADATA_FILTER_SYSTEM_PROMPT,
+    METADATA_FILTER_USER_PROMPT_1,
+    METADATA_FILTER_USER_PROMPT_2,
+    METADATA_FILTER_USER_PROMPT_3,
+)
 from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
-from models.dataset import ChildChunk, Dataset, DatasetQuery, DocumentSegment
+from libs.json_in_md_parser import parse_and_check_json_markdown
+from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment
 from models.dataset import Document as DatasetDocument
 from services.external_knowledge_service import ExternalDatasetService
 
@@ -59,6 +85,7 @@ class DatasetRetrieval:
         hit_callback: DatasetIndexToolCallbackHandler,
         message_id: str,
         memory: Optional[TokenBufferMemory] = None,
+        inputs: Optional[Mapping[str, Any]] = None,
     ) -> Optional[str]:
         """
         Retrieve dataset.
@@ -116,6 +143,22 @@ class DatasetRetrieval:
                 continue
 
             available_datasets.append(dataset)
+        if inputs:
+            inputs = {key: str(value) for key, value in inputs.items()}
+        else:
+            inputs = {}
+        available_datasets_ids = [dataset.id for dataset in available_datasets]
+        metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
+            available_datasets_ids,
+            query,
+            tenant_id,
+            user_id,
+            retrieve_config.metadata_filtering_mode,  # type: ignore
+            retrieve_config.metadata_model_config,  # type: ignore
+            retrieve_config.metadata_filtering_conditions,
+            inputs,
+        )
+
         all_documents = []
         user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
         if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
@@ -130,6 +173,8 @@ class DatasetRetrieval:
                 model_config,
                 planning_strategy,
                 message_id,
+                metadata_filter_document_ids,
+                metadata_condition,
             )
         elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
             all_documents = self.multiple_retrieve(
@@ -146,6 +191,8 @@ class DatasetRetrieval:
                 retrieve_config.weights,
                 retrieve_config.reranking_enabled or True,
                 message_id,
+                metadata_filter_document_ids,
+                metadata_condition,
             )
 
         dify_documents = [item for item in all_documents if item.provider == "dify"]
@@ -239,6 +286,8 @@ class DatasetRetrieval:
         model_config: ModelConfigWithCredentialsEntity,
         planning_strategy: PlanningStrategy,
         message_id: Optional[str] = None,
+        metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
+        metadata_condition: Optional[MetadataCondition] = None,
     ):
         tools = []
         for dataset in available_datasets:
@@ -279,6 +328,7 @@ class DatasetRetrieval:
                         dataset_id=dataset_id,
                         query=query,
                         external_retrieval_parameters=dataset.retrieval_model,
+                        metadata_condition=metadata_condition,
                     )
                     for external_document in external_documents:
                         document = Document(
@@ -293,6 +343,15 @@ class DatasetRetrieval:
                             document.metadata["dataset_name"] = dataset.name
                         results.append(document)
                 else:
+                    if metadata_condition and not metadata_filter_document_ids:
+                        return []
+                    document_ids_filter = None
+                    if metadata_filter_document_ids:
+                        document_ids = metadata_filter_document_ids.get(dataset.id, [])
+                        if document_ids:
+                            document_ids_filter = document_ids
+                        else:
+                            return []
                     retrieval_model_config = dataset.retrieval_model or default_retrieval_model
 
                     # get top k
@@ -324,6 +383,7 @@ class DatasetRetrieval:
                             reranking_model=reranking_model,
                             reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
                             weights=retrieval_model_config.get("weights", None),
+                            document_ids_filter=document_ids_filter,
                         )
                 self._on_query(query, [dataset_id], app_id, user_from, user_id)
 
@@ -348,6 +408,8 @@ class DatasetRetrieval:
         weights: Optional[dict[str, Any]] = None,
         reranking_enable: bool = True,
         message_id: Optional[str] = None,
+        metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
+        metadata_condition: Optional[MetadataCondition] = None,
     ):
         if not available_datasets:
             return []
@@ -387,6 +449,16 @@ class DatasetRetrieval:
 
         for dataset in available_datasets:
             index_type = dataset.indexing_technique
+            document_ids_filter = None
+            if dataset.provider != "external":
+                if metadata_condition and not metadata_filter_document_ids:
+                    continue
+                if metadata_filter_document_ids:
+                    document_ids = metadata_filter_document_ids.get(dataset.id, [])
+                    if document_ids:
+                        document_ids_filter = document_ids
+                    else:
+                        continue
             retrieval_thread = threading.Thread(
                 target=self._retriever,
                 kwargs={
@@ -395,6 +467,8 @@ class DatasetRetrieval:
                     "query": query,
                     "top_k": top_k,
                     "all_documents": all_documents,
+                    "document_ids_filter": document_ids_filter,
+                    "metadata_condition": metadata_condition,
                 },
             )
             threads.append(retrieval_thread)
@@ -493,7 +567,16 @@ class DatasetRetrieval:
             db.session.add_all(dataset_queries)
         db.session.commit()
 
-    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
+    def _retriever(
+        self,
+        flask_app: Flask,
+        dataset_id: str,
+        query: str,
+        top_k: int,
+        all_documents: list,
+        document_ids_filter: Optional[list[str]] = None,
+        metadata_condition: Optional[MetadataCondition] = None,
+    ):
         with flask_app.app_context():
             dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
 
@@ -506,6 +589,7 @@ class DatasetRetrieval:
                     dataset_id=dataset_id,
                     query=query,
                     external_retrieval_parameters=dataset.retrieval_model,
+                    metadata_condition=metadata_condition,
                 )
                 for external_document in external_documents:
                     document = Document(
@@ -546,6 +630,7 @@ class DatasetRetrieval:
                             else None,
                             reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
                             weights=retrieval_model.get("weights", None),
+                            document_ids_filter=document_ids_filter,
                         )
 
                         all_documents.extend(documents)
@@ -733,3 +818,340 @@ class DatasetRetrieval:
             filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True
         )
         return filter_documents[:top_k] if top_k else filter_documents
+
+    def _get_metadata_filter_condition(
+        self,
+        dataset_ids: list,
+        query: str,
+        tenant_id: str,
+        user_id: str,
+        metadata_filtering_mode: str,
+        metadata_model_config: ModelConfig,
+        metadata_filtering_conditions: Optional[MetadataFilteringCondition],
+        inputs: dict,
+    ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
+        document_query = db.session.query(DatasetDocument).filter(
+            DatasetDocument.dataset_id.in_(dataset_ids),
+            DatasetDocument.indexing_status == "completed",
+            DatasetDocument.enabled == True,
+            DatasetDocument.archived == False,
+        )
+        filters = []  # type: ignore
+        metadata_condition = None
+        if metadata_filtering_mode == "disabled":
+            return None, None
+        elif metadata_filtering_mode == "automatic":
+            automatic_metadata_filters = self._automatic_metadata_filter_func(
+                dataset_ids, query, tenant_id, user_id, metadata_model_config
+            )
+            if automatic_metadata_filters:
+                conditions = []
+                for filter in automatic_metadata_filters:
+                    self._process_metadata_filter_func(
+                        filter.get("condition"),  # type: ignore
+                        filter.get("metadata_name"),  # type: ignore
+                        filter.get("value"),
+                        filters,  # type: ignore
+                    )
+                    conditions.append(
+                        Condition(
+                            name=filter.get("metadata_name"),  # type: ignore
+                            comparison_operator=filter.get("condition"),  # type: ignore
+                            value=filter.get("value"),
+                        )
+                    )
+                metadata_condition = MetadataCondition(
+                    logical_operator=metadata_filtering_conditions.logical_operator,  # type: ignore
+                    conditions=conditions,
+                )
+        elif metadata_filtering_mode == "manual":
+            if metadata_filtering_conditions:
+                metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump())
+                for condition in metadata_filtering_conditions.conditions:  # type: ignore
+                    metadata_name = condition.name
+                    expected_value = condition.value
+                    if expected_value or condition.comparison_operator in ("empty", "not empty"):
+                        if isinstance(expected_value, str):
+                            expected_value = self._replace_metadata_filter_value(expected_value, inputs)
+                        filters = self._process_metadata_filter_func(
+                            condition.comparison_operator, metadata_name, expected_value, filters
+                        )
+        else:
+            raise ValueError("Invalid metadata filtering mode")
+        if filters:
+            if metadata_filtering_conditions.logical_operator == "or":  # type: ignore
+                document_query = document_query.filter(or_(*filters))
+            else:
+                document_query = document_query.filter(and_(*filters))
+        documents = document_query.all()
+        # group by dataset_id
+        metadata_filter_document_ids = defaultdict(list) if documents else None  # type: ignore
+        for document in documents:
+            metadata_filter_document_ids[document.dataset_id].append(document.id)  # type: ignore
+        return metadata_filter_document_ids, metadata_condition
+
+    def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
+        def replacer(match):
+            key = match.group(1)
+            return str(inputs.get(key, f"{{{{{key}}}}}"))
+
+        pattern = re.compile(r"\{\{(\w+)\}\}")
+        return pattern.sub(replacer, text)
+
+    def _automatic_metadata_filter_func(
+        self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
+    ) -> Optional[list[dict[str, Any]]]:
+        # get all metadata field
+        metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
+        all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
+        # get metadata model config
+        if metadata_model_config is None:
+            raise ValueError("metadata_model_config is required")
+        # get metadata model instance
+        # fetch model config
+        model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config)
+
+        # fetch prompt messages
+        prompt_messages, stop = self._get_prompt_template(
+            model_config=model_config,
+            mode=metadata_model_config.mode,
+            metadata_fields=all_metadata_fields,
+            query=query or "",
+        )
+
+        result_text = ""
+        try:
+            # handle invoke result
+            invoke_result = cast(
+                Generator[LLMResult, None, None],
+                model_instance.invoke_llm(
+                    prompt_messages=prompt_messages,
+                    model_parameters=model_config.parameters,
+                    stop=stop,
+                    stream=True,
+                    user=user_id,
+                ),
+            )
+
+            # handle invoke result
+            result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
+
+            result_text_json = parse_and_check_json_markdown(result_text, [])
+            automatic_metadata_filters = []
+            if "metadata_map" in result_text_json:
+                metadata_map = result_text_json["metadata_map"]
+                for item in metadata_map:
+                    if item.get("metadata_field_name") in all_metadata_fields:
+                        automatic_metadata_filters.append(
+                            {
+                                "metadata_name": item.get("metadata_field_name"),
+                                "value": item.get("metadata_field_value"),
+                                "condition": item.get("comparison_operator"),
+                            }
+                        )
+        except Exception as e:
+            return None
+        return automatic_metadata_filters
+
+    def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: Optional[Any], filters: list):
+        match condition:
+            case "contains":
+                filters.append(
+                    (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}%")
+                )
+            case "not contains":
+                filters.append(
+                    (text("documents.doc_metadata ->> :key NOT LIKE :value")).params(
+                        key=metadata_name, value=f"%{value}%"
+                    )
+                )
+            case "start with":
+                filters.append(
+                    (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"{value}%")
+                )
+
+            case "end with":
+                filters.append(
+                    (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}")
+                )
+            case "is" | "=":
+                if isinstance(value, str):
+                    filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"')
+                else:
+                    filters.append(
+                        sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) == value
+                    )
+            case "is not" | "≠":
+                if isinstance(value, str):
+                    filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"')
+                else:
+                    filters.append(
+                        sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) != value
+                    )
+            case "empty":
+                filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
+            case "not empty":
+                filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
+            case "before" | "<":
+                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) < value)
+            case "after" | ">":
+                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) > value)
+            case "≤" | ">=":
+                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) <= value)
+            case "≥" | ">=":
+                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) >= value)
+            case _:
+                pass
+        return filters
+
+    def _fetch_model_config(
+        self, tenant_id: str, model: ModelConfig
+    ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+        """
+        Fetch model config
+        :param node_data: node data
+        :return:
+        """
+        if model is None:
+            raise ValueError("single_retrieval_config is required")
+        model_name = model.name
+        provider_name = model.provider
+
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
+        )
+
+        provider_model_bundle = model_instance.provider_model_bundle
+        model_type_instance = model_instance.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        model_credentials = model_instance.credentials
+
+        # check model
+        provider_model = provider_model_bundle.configuration.get_provider_model(
+            model=model_name, model_type=ModelType.LLM
+        )
+
+        if provider_model is None:
+            raise ValueError(f"Model {model_name} not exist.")
+
+        if provider_model.status == ModelStatus.NO_CONFIGURE:
+            raise ValueError(f"Model {model_name} credentials is not initialized.")
+        elif provider_model.status == ModelStatus.NO_PERMISSION:
+            raise ValueError(f"Dify Hosted OpenAI {model_name} currently not support.")
+        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
+            raise ValueError(f"Model provider {provider_name} quota exceeded.")
+
+        # model config
+        completion_params = model.completion_params
+        stop = []
+        if "stop" in completion_params:
+            stop = completion_params["stop"]
+            del completion_params["stop"]
+
+        # get model mode
+        model_mode = model.mode
+        if not model_mode:
+            raise ValueError("LLM mode is required.")
+
+        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
+
+        if not model_schema:
+            raise ValueError(f"Model {model_name} not exist.")
+
+        return model_instance, ModelConfigWithCredentialsEntity(
+            provider=provider_name,
+            model=model_name,
+            model_schema=model_schema,
+            mode=model_mode,
+            provider_model_bundle=provider_model_bundle,
+            credentials=model_credentials,
+            parameters=completion_params,
+            stop=stop,
+        )
+
+    def _get_prompt_template(
+        self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
+    ):
+        model_mode = ModelMode.value_of(mode)
+        input_text = query
+
+        prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
+        if model_mode == ModelMode.CHAT:
+            prompt_template = []
+            system_prompt_messages = ChatModelMessage(role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT)
+            prompt_template.append(system_prompt_messages)
+            user_prompt_message_1 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1)
+            prompt_template.append(user_prompt_message_1)
+            assistant_prompt_message_1 = ChatModelMessage(
+                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
+            )
+            prompt_template.append(assistant_prompt_message_1)
+            user_prompt_message_2 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2)
+            prompt_template.append(user_prompt_message_2)
+            assistant_prompt_message_2 = ChatModelMessage(
+                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
+            )
+            prompt_template.append(assistant_prompt_message_2)
+            user_prompt_message_3 = ChatModelMessage(
+                role=PromptMessageRole.USER,
+                text=METADATA_FILTER_USER_PROMPT_3.format(
+                    input_text=input_text,
+                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
+                ),
+            )
+            prompt_template.append(user_prompt_message_3)
+        elif model_mode == ModelMode.COMPLETION:
+            prompt_template = CompletionModelPromptTemplate(
+                text=METADATA_FILTER_COMPLETION_PROMPT.format(
+                    input_text=input_text,
+                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
+                )
+            )
+
+        else:
+            raise ValueError(f"Model mode {model_mode} not support.")
+
+        prompt_transform = AdvancedPromptTransform()
+        prompt_messages = prompt_transform.get_prompt(
+            prompt_template=prompt_template,
+            inputs={},
+            query=query or "",
+            files=[],
+            context=None,
+            memory_config=None,
+            memory=None,
+            model_config=model_config,
+        )
+        stop = model_config.stop
+
+        return prompt_messages, stop
+
+    def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
+        """
+        Handle invoke result
+        :param invoke_result: invoke result
+        :return:
+        """
+        model = None
+        prompt_messages: list[PromptMessage] = []
+        full_text = ""
+        usage = None
+        for result in invoke_result:
+            text = result.delta.message.content
+            full_text += text
+
+            if not model:
+                model = result.model
+
+            if not prompt_messages:
+                prompt_messages = result.prompt_messages
+
+            if not usage and result.delta.usage:
+                usage = result.delta.usage
+
+        if not usage:
+            usage = LLMUsage.empty_usage()
+
+        return full_text, usage

+ 66 - 0
api/core/rag/retrieval/template_prompts.py

@@ -0,0 +1,66 @@
+METADATA_FILTER_SYSTEM_PROMPT = """
+    ### Job Description',
+    You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
+    ### Task
+    Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
+    ### Format
+    The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
+    ### Constraint
+    DO NOT include anything other than the JSON array in your response.
+"""  # noqa: E501
+
+METADATA_FILTER_USER_PROMPT_1 = """
+    { "input_text": "I want to know which company’s email address test@example.com is?",
+    "metadata_fields": ["filename", "email", "phone", "address"]
+    }
+"""
+
+METADATA_FILTER_ASSISTANT_PROMPT_1 = """
+```json
+    {"metadata_map": [
+        {"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}
+    ]
+    }
+```
+"""
+
+METADATA_FILTER_USER_PROMPT_2 = """
+    {"input_text": "What are the movies with a score of more than 9 in 2024?",
+    "metadata_fields": ["name", "year", "rating", "country"]}
+"""
+
+METADATA_FILTER_ASSISTANT_PROMPT_2 = """
+```json
+    {"metadata_map": [
+        {"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
+        {"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"},
+    ]}
+```
+"""
+
+METADATA_FILTER_USER_PROMPT_3 = """
+    '{{"input_text": "{input_text}",',
+    '"metadata_fields": {metadata_fields}}}'
+"""
+
+METADATA_FILTER_COMPLETION_PROMPT = """
+### Job Description
+You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
+### Task
+# Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
+### Format
+The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
+### Constraint 
+DO NOT include anything other than the JSON array in your response.
+### Example
+Here is the chat example between human and assistant, inside <example></example> XML tags.
+<example>
+User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}}
+Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
+User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
+Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
+</example> 
+### User Input
+{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
+### Assistant Output
+"""  # noqa: E501

+ 49 - 1
api/core/workflow/nodes/knowledge_retrieval/entities.py

@@ -1,8 +1,10 @@
+from collections.abc import Sequence
 from typing import Any, Literal, Optional
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 from core.workflow.nodes.base import BaseNodeData
+from core.workflow.nodes.llm.entities import VisionConfig
 
 
 class RerankingModelConfig(BaseModel):
@@ -73,6 +75,48 @@ class SingleRetrievalConfig(BaseModel):
     model: ModelConfig
 
 
+SupportedComparisonOperator = Literal[
+    # for string or array
+    "contains",
+    "not contains",
+    "start with",
+    "end with",
+    "is",
+    "is not",
+    "empty",
+    "not empty",
+    # for number
+    "=",
+    "≠",
+    ">",
+    "<",
+    "≥",
+    "≤",
+    # for time
+    "before",
+    "after",
+]
+
+
+class Condition(BaseModel):
+    """
+    Conditon detail
+    """
+
+    name: str
+    comparison_operator: SupportedComparisonOperator
+    value: str | Sequence[str] | None | int | float = None
+
+
+class MetadataFilteringCondition(BaseModel):
+    """
+    Metadata Filtering Condition.
+    """
+
+    logical_operator: Optional[Literal["and", "or"]] = "and"
+    conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
+
+
 class KnowledgeRetrievalNodeData(BaseNodeData):
     """
     Knowledge retrieval Node Data.
@@ -84,3 +128,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
     retrieval_mode: Literal["single", "multiple"]
     multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
     single_retrieval_config: Optional[SingleRetrievalConfig] = None
+    metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
+    metadata_model_config: Optional[ModelConfig] = None
+    metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
+    vision: VisionConfig = Field(default_factory=VisionConfig)

+ 4 - 0
api/core/workflow/nodes/knowledge_retrieval/exc.py

@@ -16,3 +16,7 @@ class ModelNotSupportedError(KnowledgeRetrievalNodeError):
 
 class ModelQuotaExceededError(KnowledgeRetrievalNodeError):
     """Raised when the model provider quota is exceeded."""
+
+
+class InvalidModelTypeError(KnowledgeRetrievalNodeError):
+    """Raised when the model is not a Large Language Model."""

+ 268 - 22
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py

@@ -1,32 +1,51 @@
+import json
 import logging
 import time
+from collections import defaultdict
 from collections.abc import Mapping, Sequence
-from typing import Any, cast
+from typing import Any, Optional, cast
 
-from sqlalchemy import func
+from sqlalchemy import Integer, and_, func, or_, text
+from sqlalchemy import cast as sqlalchemy_cast
 
 from core.app.app_config.entities import DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.agent_entities import PlanningStrategy
 from core.entities.model_entities import ModelStatus
 from core.model_manager import ModelInstance, ModelManager
+from core.model_runtime.entities.message_entities import PromptMessageRole
 from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.simple_prompt_transform import ModelMode
 from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.entities.metadata_entities import Condition, MetadataCondition
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.variables import StringSegment
 from core.workflow.entities.node_entities import NodeRunResult
-from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
+from core.workflow.nodes.knowledge_retrieval.template_prompts import (
+    METADATA_FILTER_ASSISTANT_PROMPT_1,
+    METADATA_FILTER_ASSISTANT_PROMPT_2,
+    METADATA_FILTER_COMPLETION_PROMPT,
+    METADATA_FILTER_SYSTEM_PROMPT,
+    METADATA_FILTER_USER_PROMPT_1,
+    METADATA_FILTER_USER_PROMPT_3,
+)
+from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
+from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
-from models.dataset import Dataset, Document, RateLimitLog
+from libs.json_in_md_parser import parse_and_check_json_markdown
+from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
 from models.workflow import WorkflowNodeExecutionStatus
 from services.feature_service import FeatureService
 
-from .entities import KnowledgeRetrievalNodeData
+from .entities import KnowledgeRetrievalNodeData, ModelConfig
 from .exc import (
+    InvalidModelTypeError,
     KnowledgeRetrievalNodeError,
     ModelCredentialsNotInitializedError,
     ModelNotExistError,
@@ -45,13 +64,14 @@ default_retrieval_model = {
 }
 
 
-class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
-    _node_data_cls = KnowledgeRetrievalNodeData
+class KnowledgeRetrievalNode(LLMNode):
+    _node_data_cls = KnowledgeRetrievalNodeData  # type: ignore
     _node_type = NodeType.KNOWLEDGE_RETRIEVAL
 
-    def _run(self) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:  # type: ignore
+        node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
         # extract variables
-        variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector)
+        variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector)
         if not isinstance(variable, StringSegment):
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
@@ -91,7 +111,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
 
         # retrieve knowledge
         try:
-            results = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
+            results = self._fetch_dataset_retriever(node_data=node_data, query=query)
             outputs = {"result": results}
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
@@ -145,11 +165,14 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
             if not dataset:
                 continue
             available_datasets.append(dataset)
+        metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
+            [dataset.id for dataset in available_datasets], query, node_data
+        )
         all_documents = []
         dataset_retrieval = DatasetRetrieval()
         if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
             # fetch model config
-            model_instance, model_config = self._fetch_model_config(node_data)
+            model_instance, model_config = self._fetch_model_config(node_data.single_retrieval_config.model)  # type: ignore
             # check model is support tool calling
             model_type_instance = model_config.provider_model_bundle.model_type_instance
             model_type_instance = cast(LargeLanguageModel, model_type_instance)
@@ -174,6 +197,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
                     model_config=model_config,
                     model_instance=model_instance,
                     planning_strategy=planning_strategy,
+                    metadata_filter_document_ids=metadata_filter_document_ids,
+                    metadata_condition=metadata_condition,
                 )
         elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
             if node_data.multiple_retrieval_config is None:
@@ -220,6 +245,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
                 reranking_model=reranking_model,
                 weights=weights,
                 reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
+                metadata_filter_document_ids=metadata_filter_document_ids,
+                metadata_condition=metadata_condition,
             )
         dify_documents = [item for item in all_documents if item.provider == "dify"]
         external_documents = [item for item in all_documents if item.provider == "external"]
@@ -287,13 +314,187 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
                 item["metadata"]["position"] = position
         return retrieval_resource_list
 
+    def _get_metadata_filter_condition(
+        self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
+    ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
+        document_query = db.session.query(Document).filter(
+            Document.dataset_id.in_(dataset_ids),
+            Document.indexing_status == "completed",
+            Document.enabled == True,
+            Document.archived == False,
+        )
+        filters = []  # type: ignore
+        metadata_condition = None
+        if node_data.metadata_filtering_mode == "disabled":
+            return None, None
+        elif node_data.metadata_filtering_mode == "automatic":
+            automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data)
+            if automatic_metadata_filters:
+                conditions = []
+                for filter in automatic_metadata_filters:
+                    self._process_metadata_filter_func(
+                        filter.get("condition", ""),
+                        filter.get("metadata_name", ""),
+                        filter.get("value"),
+                        filters,  # type: ignore
+                    )
+                    conditions.append(
+                        Condition(
+                            name=filter.get("metadata_name"),  # type: ignore
+                            comparison_operator=filter.get("condition"),  # type: ignore
+                            value=filter.get("value"),
+                        )
+                    )
+                metadata_condition = MetadataCondition(
+                    logical_operator=node_data.metadata_filtering_conditions.logical_operator,  # type: ignore
+                    conditions=conditions,
+                )
+        elif node_data.metadata_filtering_mode == "manual":
+            if node_data.metadata_filtering_conditions:
+                metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump())
+                if node_data.metadata_filtering_conditions:
+                    for condition in node_data.metadata_filtering_conditions.conditions:  # type: ignore
+                        metadata_name = condition.name
+                        expected_value = condition.value
+                        if expected_value or condition.comparison_operator in ("empty", "not empty"):
+                            if isinstance(expected_value, str):
+                                expected_value = self.graph_runtime_state.variable_pool.convert_template(
+                                    expected_value
+                                ).text
+
+                            filters = self._process_metadata_filter_func(
+                                condition.comparison_operator, metadata_name, expected_value, filters
+                            )
+        else:
+            raise ValueError("Invalid metadata filtering mode")
+        if filters:
+            if node_data.metadata_filtering_conditions.logical_operator == "and":  # type: ignore
+                document_query = document_query.filter(and_(*filters))
+            else:
+                document_query = document_query.filter(or_(*filters))
+        documents = document_query.all()
+        # group by dataset_id
+        metadata_filter_document_ids = defaultdict(list) if documents else None  # type: ignore
+        for document in documents:
+            metadata_filter_document_ids[document.dataset_id].append(document.id)  # type: ignore
+        return metadata_filter_document_ids, metadata_condition
+
+    def _automatic_metadata_filter_func(
+        self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
+    ) -> list[dict[str, Any]]:
+        # get all metadata field
+        metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
+        all_metadata_fields = [metadata_field.field_name for metadata_field in metadata_fields]
+        # get metadata model config
+        metadata_model_config = node_data.metadata_model_config
+        if metadata_model_config is None:
+            raise ValueError("metadata_model_config is required")
+        # get metadata model instance
+        # fetch model config
+        model_instance, model_config = self._fetch_model_config(node_data.metadata_model_config)  # type: ignore
+        # fetch prompt messages
+        prompt_template = self._get_prompt_template(
+            node_data=node_data,
+            metadata_fields=all_metadata_fields,
+            query=query or "",
+        )
+        prompt_messages, stop = self._fetch_prompt_messages(
+            prompt_template=prompt_template,
+            sys_query=query,
+            memory=None,
+            model_config=model_config,
+            sys_files=[],
+            vision_enabled=node_data.vision.enabled,
+            vision_detail=node_data.vision.configs.detail,
+            variable_pool=self.graph_runtime_state.variable_pool,
+            jinja2_variables=[],
+        )
+
+        result_text = ""
+        try:
+            # handle invoke result
+            generator = self._invoke_llm(
+                node_data_model=node_data.metadata_model_config,  # type: ignore
+                model_instance=model_instance,
+                prompt_messages=prompt_messages,
+                stop=stop,
+            )
+
+            for event in generator:
+                if isinstance(event, ModelInvokeCompletedEvent):
+                    result_text = event.text
+                    break
+
+            result_text_json = parse_and_check_json_markdown(result_text, [])
+            automatic_metadata_filters = []
+            if "metadata_map" in result_text_json:
+                metadata_map = result_text_json["metadata_map"]
+                for item in metadata_map:
+                    if item.get("metadata_field_name") in all_metadata_fields:
+                        automatic_metadata_filters.append(
+                            {
+                                "metadata_name": item.get("metadata_field_name"),
+                                "value": item.get("metadata_field_value"),
+                                "condition": item.get("comparison_operator"),
+                            }
+                        )
+        except Exception as e:
+            return []
+        return automatic_metadata_filters
+
+    def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: Optional[str], filters: list):
+        match condition:
+            case "contains":
+                filters.append(
+                    (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}%")
+                )
+            case "not contains":
+                filters.append(
+                    (text("documents.doc_metadata ->> :key NOT LIKE :value")).params(
+                        key=metadata_name, value=f"%{value}%"
+                    )
+                )
+            case "start with":
+                filters.append(
+                    (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"{value}%")
+                )
+            case "end with":
+                filters.append(
+                    (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}")
+                )
+            case "=" | "is":
+                if isinstance(value, str):
+                    filters.append(Document.doc_metadata[metadata_name] == f'"{value}"')
+                else:
+                    filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) == value)
+            case "is not" | "≠":
+                if isinstance(value, str):
+                    filters.append(Document.doc_metadata[metadata_name] != f'"{value}"')
+                else:
+                    filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) != value)
+            case "empty":
+                filters.append(Document.doc_metadata[metadata_name].is_(None))
+            case "not empty":
+                filters.append(Document.doc_metadata[metadata_name].isnot(None))
+            case "before" | "<":
+                filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) < value)
+            case "after" | ">":
+                filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) > value)
+            case "≤" | ">=":
+                filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) <= value)
+            case "≥" | ">=":
+                filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) >= value)
+            case _:
+                pass
+        return filters
+
     @classmethod
     def _extract_variable_selector_to_variable_mapping(
         cls,
         *,
         graph_config: Mapping[str, Any],
         node_id: str,
-        node_data: KnowledgeRetrievalNodeData,
+        node_data: KnowledgeRetrievalNodeData,  # type: ignore
     ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
@@ -306,18 +507,16 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
         variable_mapping[node_id + ".query"] = node_data.query_variable_selector
         return variable_mapping
 
-    def _fetch_model_config(
-        self, node_data: KnowledgeRetrievalNodeData
-    ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+    def _fetch_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:  # type: ignore
         """
         Fetch model config
-        :param node_data: node data
+        :param model: model
         :return:
         """
-        if node_data.single_retrieval_config is None:
-            raise ValueError("single_retrieval_config is required")
-        model_name = node_data.single_retrieval_config.model.name
-        provider_name = node_data.single_retrieval_config.model.provider
+        if model is None:
+            raise ValueError("model is required")
+        model_name = model.name
+        provider_name = model.provider
 
         model_manager = ModelManager()
         model_instance = model_manager.get_model_instance(
@@ -346,14 +545,14 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
             raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.")
 
         # model config
-        completion_params = node_data.single_retrieval_config.model.completion_params
+        completion_params = model.completion_params
         stop = []
         if "stop" in completion_params:
             stop = completion_params["stop"]
             del completion_params["stop"]
 
         # get model mode
-        model_mode = node_data.single_retrieval_config.model.mode
+        model_mode = model.mode
         if not model_mode:
             raise ModelNotExistError("LLM mode is required.")
 
@@ -372,3 +571,50 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
             parameters=completion_params,
             stop=stop,
         )
+
+    def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
+        model_mode = ModelMode.value_of(node_data.metadata_model_config.mode)  # type: ignore
+        input_text = query
+        memory_str = ""
+
+        prompt_messages: list[LLMNodeChatModelMessage] = []
+        if model_mode == ModelMode.CHAT:
+            system_prompt_messages = LLMNodeChatModelMessage(
+                role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT
+            )
+            prompt_messages.append(system_prompt_messages)
+            user_prompt_message_1 = LLMNodeChatModelMessage(
+                role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1
+            )
+            prompt_messages.append(user_prompt_message_1)
+            assistant_prompt_message_1 = LLMNodeChatModelMessage(
+                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
+            )
+            prompt_messages.append(assistant_prompt_message_1)
+            user_prompt_message_2 = LLMNodeChatModelMessage(
+                role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2
+            )
+            prompt_messages.append(user_prompt_message_2)
+            assistant_prompt_message_2 = LLMNodeChatModelMessage(
+                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
+            )
+            prompt_messages.append(assistant_prompt_message_2)
+            user_prompt_message_3 = LLMNodeChatModelMessage(
+                role=PromptMessageRole.USER,
+                text=METADATA_FILTER_USER_PROMPT_3.format(
+                    input_text=input_text,
+                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
+                ),
+            )
+            prompt_messages.append(user_prompt_message_3)
+            return prompt_messages
+        elif model_mode == ModelMode.COMPLETION:
+            return LLMNodeCompletionModelPromptTemplate(
+                text=METADATA_FILTER_COMPLETION_PROMPT.format(
+                    input_text=input_text,
+                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
+                )
+            )
+
+        else:
+            raise InvalidModelTypeError(f"Model mode {model_mode} not support.")

+ 66 - 0
api/core/workflow/nodes/knowledge_retrieval/template_prompts.py

@@ -0,0 +1,66 @@
+METADATA_FILTER_SYSTEM_PROMPT = """
+    ### Job Description',
+    You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
+    ### Task
+    Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
+    ### Format
+    The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
+    ### Constraint
+    DO NOT include anything other than the JSON array in your response.
+"""  # noqa: E501
+
+METADATA_FILTER_USER_PROMPT_1 = """
+    { "input_text": "I want to know which company’s email address test@example.com is?",
+    "metadata_fields": ["filename", "email", "phone", "address"]
+    }
+"""
+
+METADATA_FILTER_ASSISTANT_PROMPT_1 = """
+```json
+    {"metadata_map": [
+        {"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}
+    ]
+    }
+```
+"""
+
+METADATA_FILTER_USER_PROMPT_2 = """
+    {"input_text": "What are the movies with a score of more than 9 in 2024?",
+    "metadata_fields": ["name", "year", "rating", "country"]}
+"""
+
+METADATA_FILTER_ASSISTANT_PROMPT_2 = """
+```json
+    {"metadata_map": [
+        {"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
+        {"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"},
+    ]}
+```
+"""
+
+METADATA_FILTER_USER_PROMPT_3 = """
+    '{{"input_text": "{input_text}",',
+    '"metadata_fields": {metadata_fields}}}'
+"""
+
+METADATA_FILTER_COMPLETION_PROMPT = """
+### Job Description
+You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
+### Task
+# Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
+### Format
+The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
+### Constraint 
+DO NOT include anything other than the JSON array in your response.
+### Example
+Here is the chat example between human and assistant, inside <example></example> XML tags.
+<example>
+User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}}
+Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
+User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
+Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
+</example> 
+### User Input
+{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
+### Assistant Output
+"""  # noqa: E501

+ 10 - 0
api/fields/dataset_fields.py

@@ -53,6 +53,8 @@ external_knowledge_info_fields = {
     "external_knowledge_api_endpoint": fields.String,
 }
 
+doc_metadata_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
+
 dataset_detail_fields = {
     "id": fields.String,
     "name": fields.String,
@@ -76,6 +78,8 @@ dataset_detail_fields = {
     "doc_form": fields.String,
     "external_knowledge_info": fields.Nested(external_knowledge_info_fields),
     "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True),
+    "doc_metadata": fields.List(fields.Nested(doc_metadata_fields)),
+    "built_in_field_enabled": fields.Boolean,
 }
 
 dataset_query_detail_fields = {
@@ -87,3 +91,9 @@ dataset_query_detail_fields = {
     "created_by": fields.String,
     "created_at": TimestampField,
 }
+
+dataset_metadata_fields = {
+    "id": fields.String,
+    "type": fields.String,
+    "name": fields.String,
+}

+ 9 - 0
api/fields/document_fields.py

@@ -3,6 +3,13 @@ from flask_restful import fields  # type: ignore
 from fields.dataset_fields import dataset_fields
 from libs.helper import TimestampField
 
+document_metadata_fields = {
+    "id": fields.String,
+    "name": fields.String,
+    "type": fields.String,
+    "value": fields.String,
+}
+
 document_fields = {
     "id": fields.String,
     "position": fields.Integer,
@@ -25,6 +32,7 @@ document_fields = {
     "word_count": fields.Integer,
     "hit_count": fields.Integer,
     "doc_form": fields.String,
+    "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
 }
 
 document_with_segments_fields = {
@@ -51,6 +59,7 @@ document_with_segments_fields = {
     "hit_count": fields.Integer,
     "completed_segments": fields.Integer,
     "total_segments": fields.Integer,
+    "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
 }
 
 dataset_and_document_fields = {

+ 90 - 0
api/migrations/versions/2025_02_27_0917-d20049ed0af6_add_metadata_function.py

@@ -0,0 +1,90 @@
+"""add_metadata_function
+
+Revision ID: d20049ed0af6
+Revises: 08ec4f75af5e
+Create Date: 2025-02-27 09:17:48.903213
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = 'd20049ed0af6'
+down_revision = 'f051706725cc'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('dataset_metadata_bindings',
+    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+    sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
+    sa.Column('metadata_id', models.types.StringUUID(), nullable=False),
+    sa.Column('document_id', models.types.StringUUID(), nullable=False),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+    sa.Column('created_by', models.types.StringUUID(), nullable=False),
+    sa.PrimaryKeyConstraint('id', name='dataset_metadata_binding_pkey')
+    )
+    with op.batch_alter_table('dataset_metadata_bindings', schema=None) as batch_op:
+        batch_op.create_index('dataset_metadata_binding_dataset_idx', ['dataset_id'], unique=False)
+        batch_op.create_index('dataset_metadata_binding_document_idx', ['document_id'], unique=False)
+        batch_op.create_index('dataset_metadata_binding_metadata_idx', ['metadata_id'], unique=False)
+        batch_op.create_index('dataset_metadata_binding_tenant_idx', ['tenant_id'], unique=False)
+
+    op.create_table('dataset_metadatas',
+    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+    sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
+    sa.Column('type', sa.String(length=255), nullable=False),
+    sa.Column('name', sa.String(length=255), nullable=False),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.Column('created_by', models.types.StringUUID(), nullable=False),
+    sa.Column('updated_by', models.types.StringUUID(), nullable=True),
+    sa.PrimaryKeyConstraint('id', name='dataset_metadata_pkey')
+    )
+    with op.batch_alter_table('dataset_metadatas', schema=None) as batch_op:
+        batch_op.create_index('dataset_metadata_dataset_idx', ['dataset_id'], unique=False)
+        batch_op.create_index('dataset_metadata_tenant_idx', ['tenant_id'], unique=False)
+
+    with op.batch_alter_table('datasets', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('built_in_field_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False))
+
+    with op.batch_alter_table('documents', schema=None) as batch_op:
+        batch_op.alter_column('doc_metadata',
+               existing_type=postgresql.JSON(astext_type=sa.Text()),
+               type_=postgresql.JSONB(astext_type=sa.Text()),
+               existing_nullable=True)
+        batch_op.create_index('document_metadata_idx', ['doc_metadata'], unique=False, postgresql_using='gin')
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('documents', schema=None) as batch_op:
+        batch_op.drop_index('document_metadata_idx', postgresql_using='gin')
+        batch_op.alter_column('doc_metadata',
+               existing_type=postgresql.JSONB(astext_type=sa.Text()),
+               type_=postgresql.JSON(astext_type=sa.Text()),
+               existing_nullable=True)
+
+    with op.batch_alter_table('datasets', schema=None) as batch_op:
+        batch_op.drop_column('built_in_field_enabled')
+
+    with op.batch_alter_table('dataset_metadatas', schema=None) as batch_op:
+        batch_op.drop_index('dataset_metadata_tenant_idx')
+        batch_op.drop_index('dataset_metadata_dataset_idx')
+
+    op.drop_table('dataset_metadatas')
+    with op.batch_alter_table('dataset_metadata_bindings', schema=None) as batch_op:
+        batch_op.drop_index('dataset_metadata_binding_tenant_idx')
+        batch_op.drop_index('dataset_metadata_binding_metadata_idx')
+        batch_op.drop_index('dataset_metadata_binding_document_idx')
+        batch_op.drop_index('dataset_metadata_binding_dataset_idx')
+
+    op.drop_table('dataset_metadata_bindings')
+    # ### end Alembic commands ###

+ 175 - 1
api/models/dataset.py

@@ -16,6 +16,7 @@ from sqlalchemy.dialects.postgresql import JSONB
 from sqlalchemy.orm import Mapped
 
 from configs import dify_config
+from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from extensions.ext_storage import storage
 from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
@@ -60,6 +61,7 @@ class Dataset(db.Model):  # type: ignore[name-defined]
     embedding_model_provider = db.Column(db.String(255), nullable=True)
     collection_binding_id = db.Column(StringUUID, nullable=True)
     retrieval_model = db.Column(JSONB, nullable=True)
+    built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
 
     @property
     def dataset_keyword_table(self):
@@ -197,6 +199,56 @@ class Dataset(db.Model):  # type: ignore[name-defined]
             "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),
         }
 
+    @property
+    def doc_metadata(self):
+        dataset_metadatas = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == self.id).all()
+
+        doc_metadata = [
+            {
+                "id": dataset_metadata.id,
+                "name": dataset_metadata.name,
+                "type": dataset_metadata.type,
+            }
+            for dataset_metadata in dataset_metadatas
+        ]
+        if self.built_in_field_enabled:
+            doc_metadata.append(
+                {
+                    "id": "built-in",
+                    "name": BuiltInField.document_name.value,
+                    "type": "string",
+                }
+            )
+            doc_metadata.append(
+                {
+                    "id": "built-in",
+                    "name": BuiltInField.uploader.value,
+                    "type": "string",
+                }
+            )
+            doc_metadata.append(
+                {
+                    "id": "built-in",
+                    "name": BuiltInField.upload_date.value,
+                    "type": "time",
+                }
+            )
+            doc_metadata.append(
+                {
+                    "id": "built-in",
+                    "name": BuiltInField.last_update_date.value,
+                    "type": "time",
+                }
+            )
+            doc_metadata.append(
+                {
+                    "id": "built-in",
+                    "name": BuiltInField.source.value,
+                    "type": "string",
+                }
+            )
+        return doc_metadata
+
     @staticmethod
     def gen_collection_name_by_id(dataset_id: str) -> str:
         normalized_dataset_id = dataset_id.replace("-", "_")
@@ -250,6 +302,7 @@ class Document(db.Model):  # type: ignore[name-defined]
         db.Index("document_dataset_id_idx", "dataset_id"),
         db.Index("document_is_paused_idx", "is_paused"),
         db.Index("document_tenant_idx", "tenant_id"),
+        db.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"),
     )
 
     # initial fields
@@ -306,7 +359,7 @@ class Document(db.Model):  # type: ignore[name-defined]
     archived_at = db.Column(db.DateTime, nullable=True)
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     doc_type = db.Column(db.String(40), nullable=True)
-    doc_metadata = db.Column(db.JSON, nullable=True)
+    doc_metadata = db.Column(JSONB, nullable=True)
     doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
     doc_language = db.Column(db.String(255), nullable=True)
 
@@ -396,12 +449,95 @@ class Document(db.Model):  # type: ignore[name-defined]
             .scalar()
         )
 
+    @property
+    def uploader(self):
+        user = db.session.query(Account).filter(Account.id == self.created_by).first()
+        return user.name if user else None
+
+    @property
+    def upload_date(self):
+        return self.created_at
+
+    @property
+    def last_update_date(self):
+        return self.updated_at
+
+    @property
+    def doc_metadata_details(self):
+        if self.doc_metadata:
+            document_metadatas = (
+                db.session.query(DatasetMetadata)
+                .join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
+                .filter(
+                    DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id
+                )
+                .all()
+            )
+            metadata_list = []
+            for metadata in document_metadatas:
+                metadata_dict = {
+                    "id": metadata.id,
+                    "name": metadata.name,
+                    "type": metadata.type,
+                    "value": self.doc_metadata.get(metadata.name),
+                }
+                metadata_list.append(metadata_dict)
+            # deal built-in fields
+            metadata_list.extend(self.get_built_in_fields())
+
+            return metadata_list
+        return None
+
     @property
     def process_rule_dict(self):
         if self.dataset_process_rule_id:
             return self.dataset_process_rule.to_dict()
         return None
 
+    def get_built_in_fields(self):
+        built_in_fields = []
+        built_in_fields.append(
+            {
+                "id": "built-in",
+                "name": BuiltInField.document_name,
+                "type": "string",
+                "value": self.name,
+            }
+        )
+        built_in_fields.append(
+            {
+                "id": "built-in",
+                "name": BuiltInField.uploader,
+                "type": "string",
+                "value": self.uploader,
+            }
+        )
+        built_in_fields.append(
+            {
+                "id": "built-in",
+                "name": BuiltInField.upload_date,
+                "type": "time",
+                "value": self.created_at.timestamp(),
+            }
+        )
+        built_in_fields.append(
+            {
+                "id": "built-in",
+                "name": BuiltInField.last_update_date,
+                "type": "time",
+                "value": self.updated_at.timestamp(),
+            }
+        )
+        built_in_fields.append(
+            {
+                "id": "built-in",
+                "name": BuiltInField.source,
+                "type": "string",
+                "value": MetadataDataSource[self.data_source_type].value,
+            }
+        )
+        return built_in_fields
+
     def to_dict(self):
         return {
             "id": self.id,
@@ -945,3 +1081,41 @@ class RateLimitLog(db.Model):  # type: ignore[name-defined]
     subscription_plan = db.Column(db.String(255), nullable=False)
     operation = db.Column(db.String(255), nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+
+
+class DatasetMetadata(db.Model):  # type: ignore[name-defined]
+    __tablename__ = "dataset_metadatas"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
+        db.Index("dataset_metadata_tenant_idx", "tenant_id"),
+        db.Index("dataset_metadata_dataset_idx", "dataset_id"),
+    )
+
+    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    tenant_id = db.Column(StringUUID, nullable=False)
+    dataset_id = db.Column(StringUUID, nullable=False)
+    type = db.Column(db.String(255), nullable=False)
+    name = db.Column(db.String(255), nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+    created_by = db.Column(StringUUID, nullable=False)
+    updated_by = db.Column(StringUUID, nullable=True)
+
+
+class DatasetMetadataBinding(db.Model):  # type: ignore[name-defined]
+    __tablename__ = "dataset_metadata_bindings"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),
+        db.Index("dataset_metadata_binding_tenant_idx", "tenant_id"),
+        db.Index("dataset_metadata_binding_dataset_idx", "dataset_id"),
+        db.Index("dataset_metadata_binding_metadata_idx", "metadata_id"),
+        db.Index("dataset_metadata_binding_document_idx", "document_id"),
+    )
+
+    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    tenant_id = db.Column(StringUUID, nullable=False)
+    dataset_id = db.Column(StringUUID, nullable=False)
+    metadata_id = db.Column(StringUUID, nullable=False)
+    document_id = db.Column(StringUUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    created_by = db.Column(StringUUID, nullable=False)

文件差异内容过多而无法显示
+ 355 - 392
api/poetry.lock


+ 57 - 3
api/services/dataset_service.py

@@ -1,3 +1,4 @@
+import copy
 import datetime
 import json
 import logging
@@ -17,6 +18,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.plugin.entities.plugin import ModelProviderID
+from core.rag.index_processor.constant.built_in_field import BuiltInField
 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from events.dataset_event import dataset_was_deleted
@@ -643,9 +645,45 @@ class DocumentService:
 
         return document
 
+    @staticmethod
+    def get_document_by_ids(document_ids: list[str]) -> list[Document]:
+        documents = (
+            db.session.query(Document)
+            .filter(
+                Document.id.in_(document_ids),
+                Document.enabled == True,
+                Document.indexing_status == "completed",
+                Document.archived == False,
+            )
+            .all()
+        )
+        return documents
+
     @staticmethod
     def get_document_by_dataset_id(dataset_id: str) -> list[Document]:
-        documents = db.session.query(Document).filter(Document.dataset_id == dataset_id, Document.enabled == True).all()
+        documents = (
+            db.session.query(Document)
+            .filter(
+                Document.dataset_id == dataset_id,
+                Document.enabled == True,
+            )
+            .all()
+        )
+
+        return documents
+
+    @staticmethod
+    def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]:
+        documents = (
+            db.session.query(Document)
+            .filter(
+                Document.dataset_id == dataset_id,
+                Document.enabled == True,
+                Document.indexing_status == "completed",
+                Document.archived == False,
+            )
+            .all()
+        )
 
         return documents
 
@@ -728,8 +766,13 @@ class DocumentService:
         if document.tenant_id != current_user.current_tenant_id:
             raise ValueError("No permission.")
 
-        document.name = name
+        if dataset.built_in_field_enabled:
+            if document.doc_metadata:
+                doc_metadata = copy.deepcopy(document.doc_metadata)
+                doc_metadata[BuiltInField.document_name.value] = name
+                document.doc_metadata = doc_metadata
 
+        document.name = name
         db.session.add(document)
         db.session.commit()
 
@@ -1128,9 +1171,20 @@ class DocumentService:
             doc_form=document_form,
             doc_language=document_language,
         )
+        doc_metadata = {}
+        if dataset.built_in_field_enabled:
+            doc_metadata = {
+                BuiltInField.document_name: name,
+                BuiltInField.uploader: account.name,
+                BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
+                BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
+                BuiltInField.source: data_source_type,
+            }
         if metadata is not None:
-            document.doc_metadata = metadata.doc_metadata
+            doc_metadata.update(metadata.doc_metadata)
             document.doc_type = metadata.doc_type
+        if doc_metadata:
+            document.doc_metadata = doc_metadata
         return document
 
     @staticmethod

+ 33 - 0
api/services/entities/knowledge_entities/knowledge_entities.py

@@ -125,3 +125,36 @@ class SegmentUpdateArgs(BaseModel):
 class ChildChunkUpdateArgs(BaseModel):
     id: Optional[str] = None
     content: str
+
+
+class MetadataArgs(BaseModel):
+    type: Literal["string", "number", "time"]
+    name: str
+
+
+class MetadataUpdateArgs(BaseModel):
+    name: str
+    value: Optional[str | int | float] = None
+
+
+class MetadataValueUpdateArgs(BaseModel):
+    fields: list[MetadataUpdateArgs]
+
+
+class MetadataDetail(BaseModel):
+    id: str
+    name: str
+    value: Optional[str | int | float] = None
+
+
+class DocumentMetadataOperation(BaseModel):
+    document_id: str
+    metadata_list: list[MetadataDetail]
+
+
+class MetadataOperationData(BaseModel):
+    """
+    Metadata operation data
+    """
+
+    operation_data: list[DocumentMetadataOperation]

+ 7 - 1
api/services/external_knowledge_service.py

@@ -8,6 +8,7 @@ import validators
 
 from constants import HIDDEN_VALUE
 from core.helper import ssrf_proxy
+from core.rag.entities.metadata_entities import MetadataCondition
 from extensions.ext_database import db
 from models.dataset import (
     Dataset,
@@ -245,7 +246,11 @@ class ExternalDatasetService:
 
     @staticmethod
     def fetch_external_knowledge_retrieval(
-        tenant_id: str, dataset_id: str, query: str, external_retrieval_parameters: dict
+        tenant_id: str,
+        dataset_id: str,
+        query: str,
+        external_retrieval_parameters: dict,
+        metadata_condition: Optional[MetadataCondition] = None,
     ) -> list:
         external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
             dataset_id=dataset_id, tenant_id=tenant_id
@@ -272,6 +277,7 @@ class ExternalDatasetService:
             },
             "query": query,
             "knowledge_id": external_knowledge_binding.external_knowledge_id,
+            "metadata_condition": metadata_condition.model_dump() if metadata_condition else None,
         }
 
         response = ExternalDatasetService.process_external_api(

+ 241 - 0
api/services/metadata_service.py

@@ -0,0 +1,241 @@
+import copy
+import datetime
+import logging
+from typing import Optional
+
+from flask_login import current_user  # type: ignore
+
+from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
+from extensions.ext_database import db
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding
+from services.dataset_service import DocumentService
+from services.entities.knowledge_entities.knowledge_entities import (
+    MetadataArgs,
+    MetadataOperationData,
+)
+
+
+class MetadataService:
+    @staticmethod
+    def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata:
+        # check if metadata name already exists
+        if DatasetMetadata.query.filter_by(
+            tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name
+        ).first():
+            raise ValueError("Metadata name already exists.")
+        for field in BuiltInField:
+            if field.value == metadata_args.name:
+                raise ValueError("Metadata name already exists in Built-in fields.")
+        metadata = DatasetMetadata(
+            tenant_id=current_user.current_tenant_id,
+            dataset_id=dataset_id,
+            type=metadata_args.type,
+            name=metadata_args.name,
+            created_by=current_user.id,
+        )
+        db.session.add(metadata)
+        db.session.commit()
+        return metadata
+
+    @staticmethod
+    def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata:  # type: ignore
+        lock_key = f"dataset_metadata_lock_{dataset_id}"
+        # check if metadata name already exists
+        if DatasetMetadata.query.filter_by(
+            tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name
+        ).first():
+            raise ValueError("Metadata name already exists.")
+        for field in BuiltInField:
+            if field.value == name:
+                raise ValueError("Metadata name already exists in Built-in fields.")
+        try:
+            MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
+            metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
+            if metadata is None:
+                raise ValueError("Metadata not found.")
+            old_name = metadata.name
+            metadata.name = name
+            metadata.updated_by = current_user.id
+            metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+
+            # update related documents
+            dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
+            if dataset_metadata_bindings:
+                document_ids = [binding.document_id for binding in dataset_metadata_bindings]
+                documents = DocumentService.get_document_by_ids(document_ids)
+                for document in documents:
+                    doc_metadata = copy.deepcopy(document.doc_metadata)
+                    value = doc_metadata.pop(old_name, None)
+                    doc_metadata[name] = value
+                    document.doc_metadata = doc_metadata
+                    db.session.add(document)
+            db.session.commit()
+            return metadata  # type: ignore
+        except Exception:
+            logging.exception("Update metadata name failed")
+        finally:
+            redis_client.delete(lock_key)
+
+    @staticmethod
+    def delete_metadata(dataset_id: str, metadata_id: str):
+        lock_key = f"dataset_metadata_lock_{dataset_id}"
+        try:
+            MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
+            metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
+            if metadata is None:
+                raise ValueError("Metadata not found.")
+            db.session.delete(metadata)
+
+            # deal related documents
+            dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
+            if dataset_metadata_bindings:
+                document_ids = [binding.document_id for binding in dataset_metadata_bindings]
+                documents = DocumentService.get_document_by_ids(document_ids)
+                for document in documents:
+                    doc_metadata = copy.deepcopy(document.doc_metadata)
+                    doc_metadata.pop(metadata.name, None)
+                    document.doc_metadata = doc_metadata
+                    db.session.add(document)
+            db.session.commit()
+            return metadata
+        except Exception:
+            logging.exception("Delete metadata failed")
+        finally:
+            redis_client.delete(lock_key)
+
+    @staticmethod
+    def get_built_in_fields():
+        return [
+            {"name": BuiltInField.document_name.value, "type": "string"},
+            {"name": BuiltInField.uploader.value, "type": "string"},
+            {"name": BuiltInField.upload_date.value, "type": "time"},
+            {"name": BuiltInField.last_update_date.value, "type": "time"},
+            {"name": BuiltInField.source.value, "type": "string"},
+        ]
+
+    @staticmethod
+    def enable_built_in_field(dataset: Dataset):
+        if dataset.built_in_field_enabled:
+            return
+        lock_key = f"dataset_metadata_lock_{dataset.id}"
+        try:
+            MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
+            dataset.built_in_field_enabled = True
+            db.session.add(dataset)
+            documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
+            if documents:
+                for document in documents:
+                    if not document.doc_metadata:
+                        doc_metadata = {}
+                    else:
+                        doc_metadata = copy.deepcopy(document.doc_metadata)
+                    doc_metadata[BuiltInField.document_name.value] = document.name
+                    doc_metadata[BuiltInField.uploader.value] = document.uploader
+                    doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
+                    doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
+                    doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
+                    document.doc_metadata = doc_metadata
+                    db.session.add(document)
+                db.session.commit()
+        except Exception:
+            logging.exception("Enable built-in field failed")
+        finally:
+            redis_client.delete(lock_key)
+
+    @staticmethod
+    def disable_built_in_field(dataset: Dataset):
+        if not dataset.built_in_field_enabled:
+            return
+        lock_key = f"dataset_metadata_lock_{dataset.id}"
+        try:
+            MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
+            dataset.built_in_field_enabled = False
+            db.session.add(dataset)
+            documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
+            document_ids = []
+            if documents:
+                for document in documents:
+                    doc_metadata = copy.deepcopy(document.doc_metadata)
+                    doc_metadata.pop(BuiltInField.document_name.value, None)
+                    doc_metadata.pop(BuiltInField.uploader.value, None)
+                    doc_metadata.pop(BuiltInField.upload_date.value, None)
+                    doc_metadata.pop(BuiltInField.last_update_date.value, None)
+                    doc_metadata.pop(BuiltInField.source.value, None)
+                    document.doc_metadata = doc_metadata
+                    db.session.add(document)
+                    document_ids.append(document.id)
+            db.session.commit()
+        except Exception:
+            logging.exception("Disable built-in field failed")
+        finally:
+            redis_client.delete(lock_key)
+
+    @staticmethod
+    def update_documents_metadata(dataset: Dataset, metadata_args: MetadataOperationData):
+        for operation in metadata_args.operation_data:
+            lock_key = f"document_metadata_lock_{operation.document_id}"
+            try:
+                MetadataService.knowledge_base_metadata_lock_check(None, operation.document_id)
+                document = DocumentService.get_document(dataset.id, operation.document_id)
+                if document is None:
+                    raise ValueError("Document not found.")
+                doc_metadata = {}
+                for metadata_value in operation.metadata_list:
+                    doc_metadata[metadata_value.name] = metadata_value.value
+                if dataset.built_in_field_enabled:
+                    doc_metadata[BuiltInField.document_name.value] = document.name
+                    doc_metadata[BuiltInField.uploader.value] = document.uploader
+                    doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
+                    doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
+                    doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
+                document.doc_metadata = doc_metadata
+                db.session.add(document)
+                db.session.commit()
+                # deal metadata binding
+                DatasetMetadataBinding.query.filter_by(document_id=operation.document_id).delete()
+                for metadata_value in operation.metadata_list:
+                    dataset_metadata_binding = DatasetMetadataBinding(
+                        tenant_id=current_user.current_tenant_id,
+                        dataset_id=dataset.id,
+                        document_id=operation.document_id,
+                        metadata_id=metadata_value.id,
+                        created_by=current_user.id,
+                    )
+                    db.session.add(dataset_metadata_binding)
+                db.session.commit()
+            except Exception:
+                logging.exception("Update documents metadata failed")
+            finally:
+                redis_client.delete(lock_key)
+
+    @staticmethod
+    def knowledge_base_metadata_lock_check(dataset_id: Optional[str], document_id: Optional[str]):
+        if dataset_id:
+            lock_key = f"dataset_metadata_lock_{dataset_id}"
+            if redis_client.get(lock_key):
+                raise ValueError("Another knowledge base metadata operation is running, please wait a moment.")
+            redis_client.set(lock_key, 1, ex=3600)
+        if document_id:
+            lock_key = f"document_metadata_lock_{document_id}"
+            if redis_client.get(lock_key):
+                raise ValueError("Another document metadata operation is running, please wait a moment.")
+            redis_client.set(lock_key, 1, ex=3600)
+
+    @staticmethod
+    def get_dataset_metadatas(dataset: Dataset):
+        return {
+            "doc_metadata": [
+                {
+                    "id": item.get("id"),
+                    "name": item.get("name"),
+                    "type": item.get("type"),
+                    "count": DatasetMetadataBinding.query.filter_by(
+                        metadata_id=item.get("id"), dataset_id=dataset.id
+                    ).count(),
+                }
+                for item in dataset.doc_metadata or []
+                if item.get("id") != "built-in"
+            ],
+            "built_in_field_enabled": dataset.built_in_field_enabled,
+        }

+ 1 - 1
api/services/tag_service.py

@@ -20,7 +20,7 @@ class TagService:
         )
         if keyword:
             query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
-        query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
+        query = query.group_by(Tag.id, Tag.type, Tag.name)
         results: list = query.order_by(Tag.created_at.desc()).all()
         return results
 

部分文件因为文件数量过多而无法显示