@@ -1,35 +1,61 @@
+import json
import math
+import re
import threading
-from collections import Counter
-from typing import Any, Optional, cast
+from collections import Counter, defaultdict
+from collections.abc import Generator, Mapping
+from typing import Any, Optional, Union, cast

from flask import Flask, current_app
-
-from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
+from sqlalchemy import Integer, and_, or_, text
+from sqlalchemy import cast as sqlalchemy_cast
+
+from core.app.app_config.entities import (
+    DatasetEntity,
+    DatasetRetrieveConfigEntity,
+    MetadataFilteringCondition,
+    ModelConfig,
+)
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
+from core.entities.model_entities import ModelStatus
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.message_entities import PromptMessageTool
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
+from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
+from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
+from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.context_entities import DocumentContext
+from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
+from core.rag.retrieval.template_prompts import (
+    METADATA_FILTER_ASSISTANT_PROMPT_1,
+    METADATA_FILTER_ASSISTANT_PROMPT_2,
+    METADATA_FILTER_COMPLETION_PROMPT,
+    METADATA_FILTER_SYSTEM_PROMPT,
+    METADATA_FILTER_USER_PROMPT_1,
+    METADATA_FILTER_USER_PROMPT_2,
+    METADATA_FILTER_USER_PROMPT_3,
+)
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
-from models.dataset import ChildChunk, Dataset, DatasetQuery, DocumentSegment
+from libs.json_in_md_parser import parse_and_check_json_markdown
+from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService

@@ -59,6 +85,7 @@ class DatasetRetrieval:
        hit_callback: DatasetIndexToolCallbackHandler,
        message_id: str,
        memory: Optional[TokenBufferMemory] = None,
+        inputs: Optional[Mapping[str, Any]] = None,
    ) -> Optional[str]:
        """
        Retrieve dataset.
@@ -116,6 +143,22 @@ class DatasetRetrieval:
                continue

            available_datasets.append(dataset)
+        if inputs:
+            inputs = {key: str(value) for key, value in inputs.items()}
+        else:
+            inputs = {}
+        available_datasets_ids = [dataset.id for dataset in available_datasets]
+        metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
+            available_datasets_ids,
+            query,
+            tenant_id,
+            user_id,
+            retrieve_config.metadata_filtering_mode,  # type: ignore
+            retrieve_config.metadata_model_config,  # type: ignore
+            retrieve_config.metadata_filtering_conditions,
+            inputs,
+        )
+
        all_documents = []
        user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
        if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
@@ -130,6 +173,8 @@ class DatasetRetrieval:
                model_config,
                planning_strategy,
                message_id,
+                metadata_filter_document_ids,
+                metadata_condition,
            )
        elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
            all_documents = self.multiple_retrieve(
@@ -146,6 +191,8 @@ class DatasetRetrieval:
                retrieve_config.weights,
                retrieve_config.reranking_enabled or True,
                message_id,
+                metadata_filter_document_ids,
+                metadata_condition,
            )

        dify_documents = [item for item in all_documents if item.provider == "dify"]
@@ -239,6 +286,8 @@ class DatasetRetrieval:
        model_config: ModelConfigWithCredentialsEntity,
        planning_strategy: PlanningStrategy,
        message_id: Optional[str] = None,
+        metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
+        metadata_condition: Optional[MetadataCondition] = None,
    ):
        tools = []
        for dataset in available_datasets:
@@ -279,6 +328,7 @@ class DatasetRetrieval:
                        dataset_id=dataset_id,
                        query=query,
                        external_retrieval_parameters=dataset.retrieval_model,
+                    metadata_condition=metadata_condition,
                    )
                    for external_document in external_documents:
                        document = Document(
@@ -293,6 +343,15 @@ class DatasetRetrieval:
                        document.metadata["dataset_name"] = dataset.name
                    results.append(document)
            else:
+                if metadata_condition and not metadata_filter_document_ids:
+                    return []
+                document_ids_filter = None
+                if metadata_filter_document_ids:
+                    document_ids = metadata_filter_document_ids.get(dataset.id, [])
+                    if document_ids:
+                        document_ids_filter = document_ids
+                    else:
+                        return []
                retrieval_model_config = dataset.retrieval_model or default_retrieval_model

                # get top k
@@ -324,6 +383,7 @@ class DatasetRetrieval:
                    reranking_model=reranking_model,
                    reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
                    weights=retrieval_model_config.get("weights", None),
+                    document_ids_filter=document_ids_filter,
                )
            self._on_query(query, [dataset_id], app_id, user_from, user_id)

@@ -348,6 +408,8 @@ class DatasetRetrieval:
        weights: Optional[dict[str, Any]] = None,
        reranking_enable: bool = True,
        message_id: Optional[str] = None,
+        metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
+        metadata_condition: Optional[MetadataCondition] = None,
    ):
        if not available_datasets:
            return []
@@ -387,6 +449,16 @@ class DatasetRetrieval:

        for dataset in available_datasets:
            index_type = dataset.indexing_technique
+            document_ids_filter = None
+            if dataset.provider != "external":
+                if metadata_condition and not metadata_filter_document_ids:
+                    continue
+                if metadata_filter_document_ids:
+                    document_ids = metadata_filter_document_ids.get(dataset.id, [])
+                    if document_ids:
+                        document_ids_filter = document_ids
+                    else:
+                        continue
            retrieval_thread = threading.Thread(
                target=self._retriever,
                kwargs={
@@ -395,6 +467,8 @@ class DatasetRetrieval:
                    "query": query,
                    "top_k": top_k,
                    "all_documents": all_documents,
+                    "document_ids_filter": document_ids_filter,
+                    "metadata_condition": metadata_condition,
                },
            )
            threads.append(retrieval_thread)
@@ -493,7 +567,16 @@ class DatasetRetrieval:
        db.session.add_all(dataset_queries)
        db.session.commit()

-    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
+    def _retriever(
+        self,
+        flask_app: Flask,
+        dataset_id: str,
+        query: str,
+        top_k: int,
+        all_documents: list,
+        document_ids_filter: Optional[list[str]] = None,
+        metadata_condition: Optional[MetadataCondition] = None,
+    ):
        with flask_app.app_context():
            dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()

@@ -506,6 +589,7 @@ class DatasetRetrieval:
                    dataset_id=dataset_id,
                    query=query,
                    external_retrieval_parameters=dataset.retrieval_model,
+                metadata_condition=metadata_condition,
                )
                for external_document in external_documents:
                    document = Document(
@@ -546,6 +630,7 @@ class DatasetRetrieval:
                else None,
                reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
                weights=retrieval_model.get("weights", None),
+                document_ids_filter=document_ids_filter,
            )

            all_documents.extend(documents)
@@ -733,3 +818,349 @@ class DatasetRetrieval:
            filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True
        )
        return filter_documents[:top_k] if top_k else filter_documents
+
+    def _get_metadata_filter_condition(
+        self,
+        dataset_ids: list,
+        query: str,
+        tenant_id: str,
+        user_id: str,
+        metadata_filtering_mode: str,
+        metadata_model_config: ModelConfig,
+        metadata_filtering_conditions: Optional[MetadataFilteringCondition],
+        inputs: dict,
+    ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
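+        # Three modes: "disabled" bypasses filtering entirely, "automatic" asks an
+        # LLM to extract filter conditions from the query, and "manual" applies the
+        # user-configured conditions (with {{variable}} substitution from inputs).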
+        document_query = db.session.query(DatasetDocument).filter(
+            DatasetDocument.dataset_id.in_(dataset_ids),
+            DatasetDocument.indexing_status == "completed",
+            DatasetDocument.enabled == True,
+            DatasetDocument.archived == False,
+        )
+        filters = []  # type: ignore
+        metadata_condition = None
+        if metadata_filtering_mode == "disabled":
+            return None, None
+        elif metadata_filtering_mode == "automatic":
+            automatic_metadata_filters = self._automatic_metadata_filter_func(
+                dataset_ids, query, tenant_id, user_id, metadata_model_config
+            )
+            if automatic_metadata_filters:
+                conditions = []
+                for filter in automatic_metadata_filters:
+                    self._process_metadata_filter_func(
+                        filter.get("condition"),  # type: ignore
+                        filter.get("metadata_name"),  # type: ignore
+                        filter.get("value"),
+                        filters,  # type: ignore
+                    )
+                    conditions.append(
+                        Condition(
+                            name=filter.get("metadata_name"),  # type: ignore
+                            comparison_operator=filter.get("condition"),  # type: ignore
+                            value=filter.get("value"),
+                        )
+                    )
+                metadata_condition = MetadataCondition(
+                    logical_operator=metadata_filtering_conditions.logical_operator if metadata_filtering_conditions else "and",
+                    conditions=conditions,
+                )
+        elif metadata_filtering_mode == "manual":
+            if metadata_filtering_conditions:
+                metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump())
+                for condition in metadata_filtering_conditions.conditions:  # type: ignore
+                    metadata_name = condition.name
+                    expected_value = condition.value
+                    if expected_value or condition.comparison_operator in ("empty", "not empty"):
+                        if isinstance(expected_value, str):
+                            expected_value = self._replace_metadata_filter_value(expected_value, inputs)
+                        filters = self._process_metadata_filter_func(
+                            condition.comparison_operator, metadata_name, expected_value, filters
+                        )
+        else:
+            raise ValueError("Invalid metadata filtering mode")
+        if filters:
+            if metadata_filtering_conditions and metadata_filtering_conditions.logical_operator == "or":
+                document_query = document_query.filter(or_(*filters))
+            else:
+                document_query = document_query.filter(and_(*filters))
+        documents = document_query.all()
+        # group by dataset_id
+        metadata_filter_document_ids = defaultdict(list) if documents else None  # type: ignore
+        for document in documents:
+            metadata_filter_document_ids[document.dataset_id].append(document.id)  # type: ignore
+        return metadata_filter_document_ids, metadata_condition
+
+    def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
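+        # Substitute {{variable}} placeholders in a condition value with entries
+        # from `inputs`; placeholders with no matching key are left untouched.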
+        def replacer(match):
+            key = match.group(1)
+            return str(inputs.get(key, f"{{{{{key}}}}}"))
+
+        pattern = re.compile(r"\{\{(\w+)\}\}")
+        return pattern.sub(replacer, text)
+
+    def _automatic_metadata_filter_func(
+        self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
+    ) -> Optional[list[dict[str, Any]]]:
+        # get all metadata fields
+        metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
+        all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
+        # get metadata model config
+        if metadata_model_config is None:
+            raise ValueError("metadata_model_config is required")
+        # get metadata model instance
+        # fetch model config
+        model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config)
+
+        # fetch prompt messages
+        prompt_messages, stop = self._get_prompt_template(
+            model_config=model_config,
+            mode=metadata_model_config.mode,
+            metadata_fields=all_metadata_fields,
+            query=query or "",
+        )
+
+        result_text = ""
+        try:
+            # invoke model in streaming mode
+            invoke_result = cast(
+                Generator[LLMResult, None, None],
+                model_instance.invoke_llm(
+                    prompt_messages=prompt_messages,
+                    model_parameters=model_config.parameters,
+                    stop=stop,
+                    stream=True,
+                    user=user_id,
+                ),
+            )
+
+            # handle invoke result
+            result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
+
+            result_text_json = parse_and_check_json_markdown(result_text, [])
+            automatic_metadata_filters = []
+            if "metadata_map" in result_text_json:
+                metadata_map = result_text_json["metadata_map"]
+                for item in metadata_map:
+                    if item.get("metadata_field_name") in all_metadata_fields:
+                        automatic_metadata_filters.append(
+                            {
+                                "metadata_name": item.get("metadata_field_name"),
+                                "value": item.get("metadata_field_value"),
+                                "condition": item.get("comparison_operator"),
+                            }
+                        )
+        except Exception:
+            return None
+        return automatic_metadata_filters
+
+    def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: Optional[Any], filters: list):
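+        # Build SQLAlchemy filters against the JSONB doc_metadata column. String
+        # operators use the PostgreSQL ->> text extractor; numeric comparisons cast
+        # the extracted text to Integer before comparing.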
+        match condition:
+            case "contains":
+                filters.append(
+                    (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}%")
+                )
+            case "not contains":
+                filters.append(
+                    (text("documents.doc_metadata ->> :key NOT LIKE :value")).params(
+                        key=metadata_name, value=f"%{value}%"
+                    )
+                )
+            case "start with":
+                filters.append(
+                    (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"{value}%")
+                )
+            case "end with":
+                filters.append(
+                    (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}")
+                )
+            case "is" | "=":
+                if isinstance(value, str):
+                    filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"')
+                else:
+                    filters.append(
+                        sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) == value
+                    )
+            case "is not" | "≠":
+                if isinstance(value, str):
+                    filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"')
+                else:
+                    filters.append(
+                        sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) != value
+                    )
+            case "empty":
+                filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
+            case "not empty":
+                filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
+            case "before" | "<":
+                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) < value)
+            case "after" | ">":
+                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) > value)
+            case "≤" | "<=":
+                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) <= value)
+            case "≥" | ">=":
+                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) >= value)
+            case _:
+                pass
+        return filters
+
+    def _fetch_model_config(
+        self, tenant_id: str, model: ModelConfig
+    ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+        """
+        Fetch model config
+        :param tenant_id: tenant id
+        :param model: metadata model config
+        :return:
+        """
+        if model is None:
+            raise ValueError("model config is required")
+        model_name = model.name
+        provider_name = model.provider
+
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
+        )
+
+        provider_model_bundle = model_instance.provider_model_bundle
+        model_type_instance = model_instance.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        model_credentials = model_instance.credentials
+
+        # check model
+        provider_model = provider_model_bundle.configuration.get_provider_model(
+            model=model_name, model_type=ModelType.LLM
+        )
+
+        if provider_model is None:
+            raise ValueError(f"Model {model_name} does not exist.")
+
+        if provider_model.status == ModelStatus.NO_CONFIGURE:
+            raise ValueError(f"Model {model_name} credentials are not initialized.")
+        elif provider_model.status == ModelStatus.NO_PERMISSION:
+            raise ValueError(f"Dify Hosted OpenAI {model_name} is currently not supported.")
+        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
+            raise ValueError(f"Model provider {provider_name} quota exceeded.")
+
+        # model config
+        completion_params = model.completion_params
+        stop = []
+        if "stop" in completion_params:
+            stop = completion_params["stop"]
+            del completion_params["stop"]
+
+        # get model mode
+        model_mode = model.mode
+        if not model_mode:
+            raise ValueError("LLM mode is required.")
+
+        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
+
+        if not model_schema:
+            raise ValueError(f"Model {model_name} does not exist.")
+
+        return model_instance, ModelConfigWithCredentialsEntity(
+            provider=provider_name,
+            model=model_name,
+            model_schema=model_schema,
+            mode=model_mode,
+            provider_model_bundle=provider_model_bundle,
+            credentials=model_credentials,
+            parameters=completion_params,
+            stop=stop,
+        )
+
+    def _get_prompt_template(
+        self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
+    ):
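+        # Few-shot prompt assembly: chat models get two worked user/assistant
+        # exchanges before the real query; completion models get a single template.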
+        model_mode = ModelMode.value_of(mode)
+        input_text = query
+
+        prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
+        if model_mode == ModelMode.CHAT:
+            prompt_template = []
+            system_prompt_messages = ChatModelMessage(role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT)
+            prompt_template.append(system_prompt_messages)
+            user_prompt_message_1 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1)
+            prompt_template.append(user_prompt_message_1)
+            assistant_prompt_message_1 = ChatModelMessage(
+                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
+            )
+            prompt_template.append(assistant_prompt_message_1)
+            user_prompt_message_2 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2)
+            prompt_template.append(user_prompt_message_2)
+            assistant_prompt_message_2 = ChatModelMessage(
+                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
+            )
+            prompt_template.append(assistant_prompt_message_2)
+            user_prompt_message_3 = ChatModelMessage(
+                role=PromptMessageRole.USER,
+                text=METADATA_FILTER_USER_PROMPT_3.format(
+                    input_text=input_text,
+                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
+                ),
+            )
+            prompt_template.append(user_prompt_message_3)
+        elif model_mode == ModelMode.COMPLETION:
+            prompt_template = CompletionModelPromptTemplate(
+                text=METADATA_FILTER_COMPLETION_PROMPT.format(
+                    input_text=input_text,
+                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
+                )
+            )
+        else:
+            raise ValueError(f"Model mode {model_mode} is not supported.")
+
+        prompt_transform = AdvancedPromptTransform()
+        prompt_messages = prompt_transform.get_prompt(
+            prompt_template=prompt_template,
+            inputs={},
+            query=query or "",
+            files=[],
+            context=None,
+            memory_config=None,
+            memory=None,
+            model_config=model_config,
+        )
+        stop = model_config.stop
+
+        return prompt_messages, stop
+
+    def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
+        """
+        Handle invoke result
+        :param invoke_result: invoke result
+        :return:
+        """
+        model = None
+        prompt_messages: list[PromptMessage] = []
+        full_text = ""
+        usage = None
+        for result in invoke_result:
+            text = result.delta.message.content
+            full_text += text
+
+            if not model:
+                model = result.model
+
+            if not prompt_messages:
+                prompt_messages = result.prompt_messages
+
+            if not usage and result.delta.usage:
+                usage = result.delta.usage
+
+        if not usage:
+            usage = LLMUsage.empty_usage()
+
+ return full_text, usage
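
---

Notes (reviewer commentary, not part of the patch):

The {{variable}} substitution added in _replace_metadata_filter_value is easy to
check in isolation. A minimal runnable sketch of the same logic as a free
function (the standalone name is ours); matching the patch, placeholders with no
key in `inputs` are deliberately left intact:

    import re

    def replace_metadata_filter_value(text: str, inputs: dict) -> str:
        # Replace each {{word}} placeholder with str(inputs[word]); keep the
        # placeholder verbatim when the key is absent from `inputs`.
        pattern = re.compile(r"\{\{(\w+)\}\}")
        return pattern.sub(lambda m: str(inputs.get(m.group(1), f"{{{{{m.group(1)}}}}}")), text)

    assert replace_metadata_filter_value("author is {{name}}", {"name": "Ada"}) == "author is Ada"
    assert replace_metadata_filter_value("year is {{year}}", {}) == "year is {{year}}"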
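In "automatic" mode, _automatic_metadata_filter_func expects the model to return
markdown-fenced JSON containing a metadata_map list. A sketch of the shape the
parsing loop consumes (the key names come from the code; the concrete field
names and values here are illustrative only):

    parsed = {
        "metadata_map": [
            {
                "metadata_field_name": "author",
                "metadata_field_value": "Ada Lovelace",
                "comparison_operator": "is",
            }
        ]
    }

    # Mirror of the loop in _automatic_metadata_filter_func: keep only fields
    # that exist on the datasets, then rename keys into the shape that
    # _process_metadata_filter_func and Condition(...) expect.
    all_metadata_fields = ["author", "year"]
    automatic_metadata_filters = [
        {
            "metadata_name": item["metadata_field_name"],
            "value": item["metadata_field_value"],
            "condition": item["comparison_operator"],
        }
        for item in parsed["metadata_map"]
        if item.get("metadata_field_name") in all_metadata_fields
    ]
    assert automatic_metadata_filters[0]["condition"] == "is"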