|
@@ -0,0 +1,526 @@
|
|
|
+import json
|
|
|
+import os
|
|
|
+import uuid
|
|
|
+from collections.abc import Generator, Iterable, Sequence
|
|
|
+from itertools import islice
|
|
|
+from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
|
|
+
|
|
|
+import qdrant_client
|
|
|
+import requests
|
|
|
+from flask import current_app
|
|
|
+from pydantic import BaseModel
|
|
|
+from qdrant_client.http import models as rest
|
|
|
+from qdrant_client.http.models import (
|
|
|
+ FilterSelector,
|
|
|
+ HnswConfigDiff,
|
|
|
+ PayloadSchemaType,
|
|
|
+ TextIndexParams,
|
|
|
+ TextIndexType,
|
|
|
+ TokenizerType,
|
|
|
+)
|
|
|
+from qdrant_client.local.qdrant_local import QdrantLocal
|
|
|
+from requests.auth import HTTPDigestAuth
|
|
|
+
|
|
|
+from configs import dify_config
|
|
|
+from core.rag.datasource.vdb.field import Field
|
|
|
+from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
|
|
+from core.rag.datasource.vdb.vector_base import BaseVector
|
|
|
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
|
|
+from core.rag.datasource.vdb.vector_type import VectorType
|
|
|
+from core.rag.embedding.embedding_base import Embeddings
|
|
|
+from core.rag.models.document import Document
|
|
|
+from extensions.ext_database import db
|
|
|
+from extensions.ext_redis import redis_client
|
|
|
+from models.dataset import Dataset, TidbAuthBinding
|
|
|
+
|
|
|
+if TYPE_CHECKING:
|
|
|
+ from qdrant_client import grpc # noqa
|
|
|
+ from qdrant_client.conversions import common_types
|
|
|
+ from qdrant_client.http import models as rest
|
|
|
+
|
|
|
+ DictFilter = dict[str, Union[str, int, bool, dict, list]]
|
|
|
+ MetadataFilter = Union[DictFilter, common_types.Filter]
|
|
|
+
|
|
|
+
|
|
|
+class TidbOnQdrantConfig(BaseModel):
|
|
|
+ endpoint: str
|
|
|
+ api_key: Optional[str] = None
|
|
|
+ timeout: float = 20
|
|
|
+ root_path: Optional[str] = None
|
|
|
+ grpc_port: int = 6334
|
|
|
+ prefer_grpc: bool = False
|
|
|
+
|
|
|
+ def to_qdrant_params(self):
|
|
|
+ if self.endpoint and self.endpoint.startswith("path:"):
|
|
|
+ path = self.endpoint.replace("path:", "")
|
|
|
+ if not os.path.isabs(path):
|
|
|
+ path = os.path.join(self.root_path, path)
|
|
|
+
|
|
|
+ return {"path": path}
|
|
|
+ else:
|
|
|
+ return {
|
|
|
+ "url": self.endpoint,
|
|
|
+ "api_key": self.api_key,
|
|
|
+ "timeout": self.timeout,
|
|
|
+ "verify": False,
|
|
|
+ "grpc_port": self.grpc_port,
|
|
|
+ "prefer_grpc": self.prefer_grpc,
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+class TidbConfig(BaseModel):
|
|
|
+ api_url: str
|
|
|
+ public_key: str
|
|
|
+ private_key: str
|
|
|
+
|
|
|
+
|
|
|
+class TidbOnQdrantVector(BaseVector):
|
|
|
+ def __init__(self, collection_name: str, group_id: str, config: TidbOnQdrantConfig, distance_func: str = "Cosine"):
|
|
|
+ super().__init__(collection_name)
|
|
|
+ self._client_config = config
|
|
|
+ self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
|
|
|
+ self._distance_func = distance_func.upper()
|
|
|
+ self._group_id = group_id
|
|
|
+
|
|
|
+ def get_type(self) -> str:
|
|
|
+ return VectorType.TIDB_ON_QDRANT
|
|
|
+
|
|
|
+ def to_index_struct(self) -> dict:
|
|
|
+ return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
|
|
+
|
|
|
+ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
|
|
+ if texts:
|
|
|
+ # get embedding vector size
|
|
|
+ vector_size = len(embeddings[0])
|
|
|
+ # get collection name
|
|
|
+ collection_name = self._collection_name
|
|
|
+ # create collection
|
|
|
+ self.create_collection(collection_name, vector_size)
|
|
|
+
|
|
|
+ self.add_texts(texts, embeddings, **kwargs)
|
|
|
+
|
|
|
+ def create_collection(self, collection_name: str, vector_size: int):
|
|
|
+ lock_name = "vector_indexing_lock_{}".format(collection_name)
|
|
|
+ with redis_client.lock(lock_name, timeout=20):
|
|
|
+ collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
|
|
+ if redis_client.get(collection_exist_cache_key):
|
|
|
+ return
|
|
|
+ collection_name = collection_name or uuid.uuid4().hex
|
|
|
+ all_collection_name = []
|
|
|
+ collections_response = self._client.get_collections()
|
|
|
+ collection_list = collections_response.collections
|
|
|
+ for collection in collection_list:
|
|
|
+ all_collection_name.append(collection.name)
|
|
|
+ if collection_name not in all_collection_name:
|
|
|
+ from qdrant_client.http import models as rest
|
|
|
+
|
|
|
+ vectors_config = rest.VectorParams(
|
|
|
+ size=vector_size,
|
|
|
+ distance=rest.Distance[self._distance_func],
|
|
|
+ )
|
|
|
+ hnsw_config = HnswConfigDiff(
|
|
|
+ m=0,
|
|
|
+ payload_m=16,
|
|
|
+ ef_construct=100,
|
|
|
+ full_scan_threshold=10000,
|
|
|
+ max_indexing_threads=0,
|
|
|
+ on_disk=False,
|
|
|
+ )
|
|
|
+ self._client.recreate_collection(
|
|
|
+ collection_name=collection_name,
|
|
|
+ vectors_config=vectors_config,
|
|
|
+ hnsw_config=hnsw_config,
|
|
|
+ timeout=int(self._client_config.timeout),
|
|
|
+ )
|
|
|
+
|
|
|
+ # create group_id payload index
|
|
|
+ self._client.create_payload_index(
|
|
|
+ collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
|
|
|
+ )
|
|
|
+ # create doc_id payload index
|
|
|
+ self._client.create_payload_index(
|
|
|
+ collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
|
|
|
+ )
|
|
|
+ # create full text index
|
|
|
+ text_index_params = TextIndexParams(
|
|
|
+ type=TextIndexType.TEXT,
|
|
|
+ tokenizer=TokenizerType.MULTILINGUAL,
|
|
|
+ min_token_len=2,
|
|
|
+ max_token_len=20,
|
|
|
+ lowercase=True,
|
|
|
+ )
|
|
|
+ self._client.create_payload_index(
|
|
|
+ collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
|
|
|
+ )
|
|
|
+ redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
|
|
+
|
|
|
+ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
|
|
+ uuids = self._get_uuids(documents)
|
|
|
+ texts = [d.page_content for d in documents]
|
|
|
+ metadatas = [d.metadata for d in documents]
|
|
|
+
|
|
|
+ added_ids = []
|
|
|
+ for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
|
|
|
+ self._client.upsert(collection_name=self._collection_name, points=points)
|
|
|
+ added_ids.extend(batch_ids)
|
|
|
+
|
|
|
+ return added_ids
|
|
|
+
|
|
|
+ def _generate_rest_batches(
|
|
|
+ self,
|
|
|
+ texts: Iterable[str],
|
|
|
+ embeddings: list[list[float]],
|
|
|
+ metadatas: Optional[list[dict]] = None,
|
|
|
+ ids: Optional[Sequence[str]] = None,
|
|
|
+ batch_size: int = 64,
|
|
|
+ group_id: Optional[str] = None,
|
|
|
+ ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
|
|
|
+ from qdrant_client.http import models as rest
|
|
|
+
|
|
|
+ texts_iterator = iter(texts)
|
|
|
+ embeddings_iterator = iter(embeddings)
|
|
|
+ metadatas_iterator = iter(metadatas or [])
|
|
|
+ ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
|
|
|
+ while batch_texts := list(islice(texts_iterator, batch_size)):
|
|
|
+ # Take the corresponding metadata and id for each text in a batch
|
|
|
+ batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
|
|
|
+ batch_ids = list(islice(ids_iterator, batch_size))
|
|
|
+
|
|
|
+ # Generate the embeddings for all the texts in a batch
|
|
|
+ batch_embeddings = list(islice(embeddings_iterator, batch_size))
|
|
|
+
|
|
|
+ points = [
|
|
|
+ rest.PointStruct(
|
|
|
+ id=point_id,
|
|
|
+ vector=vector,
|
|
|
+ payload=payload,
|
|
|
+ )
|
|
|
+ for point_id, vector, payload in zip(
|
|
|
+ batch_ids,
|
|
|
+ batch_embeddings,
|
|
|
+ self._build_payloads(
|
|
|
+ batch_texts,
|
|
|
+ batch_metadatas,
|
|
|
+ Field.CONTENT_KEY.value,
|
|
|
+ Field.METADATA_KEY.value,
|
|
|
+ group_id,
|
|
|
+ Field.GROUP_KEY.value,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ ]
|
|
|
+
|
|
|
+ yield batch_ids, points
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _build_payloads(
|
|
|
+ cls,
|
|
|
+ texts: Iterable[str],
|
|
|
+ metadatas: Optional[list[dict]],
|
|
|
+ content_payload_key: str,
|
|
|
+ metadata_payload_key: str,
|
|
|
+ group_id: str,
|
|
|
+ group_payload_key: str,
|
|
|
+ ) -> list[dict]:
|
|
|
+ payloads = []
|
|
|
+ for i, text in enumerate(texts):
|
|
|
+ if text is None:
|
|
|
+ raise ValueError(
|
|
|
+ "At least one of the texts is None. Please remove it before "
|
|
|
+ "calling .from_texts or .add_texts on Qdrant instance."
|
|
|
+ )
|
|
|
+ metadata = metadatas[i] if metadatas is not None else None
|
|
|
+ payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id})
|
|
|
+
|
|
|
+ return payloads
|
|
|
+
|
|
|
+ def delete_by_metadata_field(self, key: str, value: str):
|
|
|
+ from qdrant_client.http import models
|
|
|
+ from qdrant_client.http.exceptions import UnexpectedResponse
|
|
|
+
|
|
|
+ try:
|
|
|
+ filter = models.Filter(
|
|
|
+ must=[
|
|
|
+ models.FieldCondition(
|
|
|
+ key=f"metadata.{key}",
|
|
|
+ match=models.MatchValue(value=value),
|
|
|
+ ),
|
|
|
+ ],
|
|
|
+ )
|
|
|
+
|
|
|
+ self._reload_if_needed()
|
|
|
+
|
|
|
+ self._client.delete(
|
|
|
+ collection_name=self._collection_name,
|
|
|
+ points_selector=FilterSelector(filter=filter),
|
|
|
+ )
|
|
|
+ except UnexpectedResponse as e:
|
|
|
+ # Collection does not exist, so return
|
|
|
+ if e.status_code == 404:
|
|
|
+ return
|
|
|
+ # Some other error occurred, so re-raise the exception
|
|
|
+ else:
|
|
|
+ raise e
|
|
|
+
|
|
|
+ def delete(self):
|
|
|
+ from qdrant_client.http.exceptions import UnexpectedResponse
|
|
|
+
|
|
|
+ try:
|
|
|
+ self._client.delete_collection(collection_name=self._collection_name)
|
|
|
+ except UnexpectedResponse as e:
|
|
|
+ # Collection does not exist, so return
|
|
|
+ if e.status_code == 404:
|
|
|
+ return
|
|
|
+ # Some other error occurred, so re-raise the exception
|
|
|
+ else:
|
|
|
+ raise e
|
|
|
+
|
|
|
+ def delete_by_ids(self, ids: list[str]) -> None:
|
|
|
+ from qdrant_client.http import models
|
|
|
+ from qdrant_client.http.exceptions import UnexpectedResponse
|
|
|
+
|
|
|
+ for node_id in ids:
|
|
|
+ try:
|
|
|
+ filter = models.Filter(
|
|
|
+ must=[
|
|
|
+ models.FieldCondition(
|
|
|
+ key="metadata.doc_id",
|
|
|
+ match=models.MatchValue(value=node_id),
|
|
|
+ ),
|
|
|
+ ],
|
|
|
+ )
|
|
|
+ self._client.delete(
|
|
|
+ collection_name=self._collection_name,
|
|
|
+ points_selector=FilterSelector(filter=filter),
|
|
|
+ )
|
|
|
+ except UnexpectedResponse as e:
|
|
|
+ # Collection does not exist, so return
|
|
|
+ if e.status_code == 404:
|
|
|
+ return
|
|
|
+ # Some other error occurred, so re-raise the exception
|
|
|
+ else:
|
|
|
+ raise e
|
|
|
+
|
|
|
+ def text_exists(self, id: str) -> bool:
|
|
|
+ all_collection_name = []
|
|
|
+ collections_response = self._client.get_collections()
|
|
|
+ collection_list = collections_response.collections
|
|
|
+ for collection in collection_list:
|
|
|
+ all_collection_name.append(collection.name)
|
|
|
+ if self._collection_name not in all_collection_name:
|
|
|
+ return False
|
|
|
+ response = self._client.retrieve(collection_name=self._collection_name, ids=[id])
|
|
|
+
|
|
|
+ return len(response) > 0
|
|
|
+
|
|
|
+ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
|
|
+ from qdrant_client.http import models
|
|
|
+
|
|
|
+ filter = models.Filter(
|
|
|
+ must=[
|
|
|
+ models.FieldCondition(
|
|
|
+ key="group_id",
|
|
|
+ match=models.MatchValue(value=self._group_id),
|
|
|
+ ),
|
|
|
+ ],
|
|
|
+ )
|
|
|
+ results = self._client.search(
|
|
|
+ collection_name=self._collection_name,
|
|
|
+ query_vector=query_vector,
|
|
|
+ query_filter=filter,
|
|
|
+ limit=kwargs.get("top_k", 4),
|
|
|
+ with_payload=True,
|
|
|
+ with_vectors=True,
|
|
|
+ score_threshold=kwargs.get("score_threshold", 0.0),
|
|
|
+ )
|
|
|
+ docs = []
|
|
|
+ for result in results:
|
|
|
+ metadata = result.payload.get(Field.METADATA_KEY.value) or {}
|
|
|
+ # duplicate check score threshold
|
|
|
+ score_threshold = kwargs.get("score_threshold") or 0.0
|
|
|
+ if result.score > score_threshold:
|
|
|
+ metadata["score"] = result.score
|
|
|
+ doc = Document(
|
|
|
+ page_content=result.payload.get(Field.CONTENT_KEY.value),
|
|
|
+ metadata=metadata,
|
|
|
+ )
|
|
|
+ docs.append(doc)
|
|
|
+ # Sort the documents by score in descending order
|
|
|
+ docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
|
|
|
+ return docs
|
|
|
+
|
|
|
+ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
|
|
+ """Return docs most similar by bm25.
|
|
|
+ Returns:
|
|
|
+ List of documents most similar to the query text and distance for each.
|
|
|
+ """
|
|
|
+ from qdrant_client.http import models
|
|
|
+
|
|
|
+ scroll_filter = models.Filter(
|
|
|
+ must=[
|
|
|
+ models.FieldCondition(
|
|
|
+ key="page_content",
|
|
|
+ match=models.MatchText(text=query),
|
|
|
+ )
|
|
|
+ ]
|
|
|
+ )
|
|
|
+ response = self._client.scroll(
|
|
|
+ collection_name=self._collection_name,
|
|
|
+ scroll_filter=scroll_filter,
|
|
|
+ limit=kwargs.get("top_k", 2),
|
|
|
+ with_payload=True,
|
|
|
+ with_vectors=True,
|
|
|
+ )
|
|
|
+ results = response[0]
|
|
|
+ documents = []
|
|
|
+ for result in results:
|
|
|
+ if result:
|
|
|
+ document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
|
|
|
+ document.metadata["vector"] = result.vector
|
|
|
+ documents.append(document)
|
|
|
+
|
|
|
+ return documents
|
|
|
+
|
|
|
+ def _reload_if_needed(self):
|
|
|
+ if isinstance(self._client, QdrantLocal):
|
|
|
+ self._client = cast(QdrantLocal, self._client)
|
|
|
+ self._client._load()
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _document_from_scored_point(
|
|
|
+ cls,
|
|
|
+ scored_point: Any,
|
|
|
+ content_payload_key: str,
|
|
|
+ metadata_payload_key: str,
|
|
|
+ ) -> Document:
|
|
|
+ return Document(
|
|
|
+ page_content=scored_point.payload.get(content_payload_key),
|
|
|
+ metadata=scored_point.payload.get(metadata_payload_key) or {},
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
|
|
+ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
|
|
|
+ tidb_auth_binding = (
|
|
|
+ db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
|
|
|
+ )
|
|
|
+ if not tidb_auth_binding:
|
|
|
+ idle_tidb_auth_binding = (
|
|
|
+ db.session.query(TidbAuthBinding)
|
|
|
+ .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
|
|
|
+ .limit(1)
|
|
|
+ .one_or_none()
|
|
|
+ )
|
|
|
+ if idle_tidb_auth_binding:
|
|
|
+ idle_tidb_auth_binding.active = True
|
|
|
+ idle_tidb_auth_binding.tenant_id = dataset.tenant_id
|
|
|
+ db.session.commit()
|
|
|
+ TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
|
|
|
+ else:
|
|
|
+ with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
|
|
|
+ tidb_auth_binding = (
|
|
|
+ db.session.query(TidbAuthBinding)
|
|
|
+ .filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
|
|
+ .one_or_none()
|
|
|
+ )
|
|
|
+ if tidb_auth_binding:
|
|
|
+ TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
|
|
+
|
|
|
+ else:
|
|
|
+ new_cluster = TidbService.create_tidb_serverless_cluster(
|
|
|
+ dify_config.TIDB_PROJECT_ID,
|
|
|
+ dify_config.TIDB_API_URL,
|
|
|
+ dify_config.TIDB_IAM_API_URL,
|
|
|
+ dify_config.TIDB_PUBLIC_KEY,
|
|
|
+ dify_config.TIDB_PRIVATE_KEY,
|
|
|
+ dify_config.TIDB_REGION,
|
|
|
+ )
|
|
|
+ new_tidb_auth_binding = TidbAuthBinding(
|
|
|
+ cluster_id=new_cluster["cluster_id"],
|
|
|
+ cluster_name=new_cluster["cluster_name"],
|
|
|
+ account=new_cluster["account"],
|
|
|
+ password=new_cluster["password"],
|
|
|
+ tenant_id=dataset.tenant_id,
|
|
|
+ active=True,
|
|
|
+ status="ACTIVE",
|
|
|
+ )
|
|
|
+ db.session.add(new_tidb_auth_binding)
|
|
|
+ db.session.commit()
|
|
|
+ TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
|
|
|
+
|
|
|
+ else:
|
|
|
+ TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
|
|
+
|
|
|
+ if dataset.index_struct_dict:
|
|
|
+ class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
|
|
+ collection_name = class_prefix
|
|
|
+ else:
|
|
|
+ dataset_id = dataset.id
|
|
|
+ collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
|
|
+ dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_ON_QDRANT, collection_name))
|
|
|
+
|
|
|
+ config = current_app.config
|
|
|
+
|
|
|
+ return TidbOnQdrantVector(
|
|
|
+ collection_name=collection_name,
|
|
|
+ group_id=dataset.id,
|
|
|
+ config=TidbOnQdrantConfig(
|
|
|
+ endpoint=dify_config.TIDB_ON_QDRANT_URL,
|
|
|
+ api_key=TIDB_ON_QDRANT_API_KEY,
|
|
|
+ root_path=config.root_path,
|
|
|
+ timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT,
|
|
|
+ grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT,
|
|
|
+ prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ def create_tidb_serverless_cluster(self, tidb_config: TidbConfig, display_name: str, region: str):
|
|
|
+ """
|
|
|
+ Creates a new TiDB Serverless cluster.
|
|
|
+ :param tidb_config: The configuration for the TiDB Cloud API.
|
|
|
+ :param display_name: The user-friendly display name of the cluster (required).
|
|
|
+ :param region: The region where the cluster will be created (required).
|
|
|
+
|
|
|
+ :return: The response from the API.
|
|
|
+ """
|
|
|
+ region_object = {
|
|
|
+ "name": region,
|
|
|
+ }
|
|
|
+
|
|
|
+ labels = {
|
|
|
+ "tidb.cloud/project": "1372813089454548012",
|
|
|
+ }
|
|
|
+ cluster_data = {"displayName": display_name, "region": region_object, "labels": labels}
|
|
|
+
|
|
|
+ response = requests.post(
|
|
|
+ f"{tidb_config.api_url}/clusters",
|
|
|
+ json=cluster_data,
|
|
|
+ auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
|
|
|
+ )
|
|
|
+
|
|
|
+ if response.status_code == 200:
|
|
|
+ return response.json()
|
|
|
+ else:
|
|
|
+ response.raise_for_status()
|
|
|
+
|
|
|
+ def change_tidb_serverless_root_password(self, tidb_config: TidbConfig, cluster_id: str, new_password: str):
|
|
|
+ """
|
|
|
+ Changes the root password of a specific TiDB Serverless cluster.
|
|
|
+
|
|
|
+ :param tidb_config: The configuration for the TiDB Cloud API.
|
|
|
+ :param cluster_id: The ID of the cluster for which the password is to be changed (required).
|
|
|
+ :param new_password: The new password for the root user (required).
|
|
|
+ :return: The response from the API.
|
|
|
+ """
|
|
|
+
|
|
|
+ body = {"password": new_password}
|
|
|
+
|
|
|
+ response = requests.put(
|
|
|
+ f"{tidb_config.api_url}/clusters/{cluster_id}/password",
|
|
|
+ json=body,
|
|
|
+ auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
|
|
|
+ )
|
|
|
+
|
|
|
+ if response.status_code == 200:
|
|
|
+ return response.json()
|
|
|
+ else:
|
|
|
+ response.raise_for_status()
|