|
@@ -1,10 +1,9 @@
|
|
|
import json
|
|
|
import logging
|
|
|
from typing import Any, Optional
|
|
|
-from uuid import uuid4
|
|
|
|
|
|
from pydantic import BaseModel, model_validator
|
|
|
-from pymilvus import MilvusClient, MilvusException, connections
|
|
|
+from pymilvus import MilvusClient, MilvusException
|
|
|
from pymilvus.milvus_client import IndexParams
|
|
|
|
|
|
from configs import dify_config
|
|
@@ -21,20 +20,17 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
class MilvusConfig(BaseModel):
|
|
|
- host: str
|
|
|
- port: int
|
|
|
+ uri: str
|
|
|
+ token: Optional[str] = None
|
|
|
user: str
|
|
|
password: str
|
|
|
- secure: bool = False
|
|
|
batch_size: int = 100
|
|
|
database: str = "default"
|
|
|
|
|
|
@model_validator(mode='before')
|
|
|
def validate_config(cls, values: dict) -> dict:
|
|
|
- if not values.get('host'):
|
|
|
- raise ValueError("config MILVUS_HOST is required")
|
|
|
- if not values.get('port'):
|
|
|
- raise ValueError("config MILVUS_PORT is required")
|
|
|
+ if not values.get('uri'):
|
|
|
+ raise ValueError("config MILVUS_URI is required")
|
|
|
if not values.get('user'):
|
|
|
raise ValueError("config MILVUS_USER is required")
|
|
|
if not values.get('password'):
|
|
@@ -43,11 +39,10 @@ class MilvusConfig(BaseModel):
|
|
|
|
|
|
def to_milvus_params(self):
|
|
|
return {
|
|
|
- 'host': self.host,
|
|
|
- 'port': self.port,
|
|
|
+ 'uri': self.uri,
|
|
|
+ 'token': self.token,
|
|
|
'user': self.user,
|
|
|
'password': self.password,
|
|
|
- 'secure': self.secure,
|
|
|
'db_name': self.database,
|
|
|
}
|
|
|
|
|
@@ -111,32 +106,14 @@ class MilvusVector(BaseVector):
|
|
|
return None
|
|
|
|
|
|
def delete_by_metadata_field(self, key: str, value: str):
|
|
|
- alias = uuid4().hex
|
|
|
- if self._client_config.secure:
|
|
|
- uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
|
|
- else:
|
|
|
- uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
|
|
- connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password,
|
|
|
- db_name=self._client_config.database)
|
|
|
-
|
|
|
- from pymilvus import utility
|
|
|
- if utility.has_collection(self._collection_name, using=alias):
|
|
|
+ if self._client.has_collection(self._collection_name):
|
|
|
|
|
|
ids = self.get_ids_by_metadata_field(key, value)
|
|
|
if ids:
|
|
|
self._client.delete(collection_name=self._collection_name, pks=ids)
|
|
|
|
|
|
def delete_by_ids(self, ids: list[str]) -> None:
|
|
|
- alias = uuid4().hex
|
|
|
- if self._client_config.secure:
|
|
|
- uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
|
|
- else:
|
|
|
- uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
|
|
- connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password,
|
|
|
- db_name=self._client_config.database)
|
|
|
-
|
|
|
- from pymilvus import utility
|
|
|
- if utility.has_collection(self._collection_name, using=alias):
|
|
|
+ if self._client.has_collection(self._collection_name):
|
|
|
|
|
|
result = self._client.query(collection_name=self._collection_name,
|
|
|
filter=f'metadata["doc_id"] in {ids}',
|
|
@@ -146,29 +123,11 @@ class MilvusVector(BaseVector):
|
|
|
self._client.delete(collection_name=self._collection_name, pks=ids)
|
|
|
|
|
|
def delete(self) -> None:
|
|
|
- alias = uuid4().hex
|
|
|
- if self._client_config.secure:
|
|
|
- uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
|
|
- else:
|
|
|
- uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
|
|
- connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password,
|
|
|
- db_name=self._client_config.database)
|
|
|
-
|
|
|
- from pymilvus import utility
|
|
|
- if utility.has_collection(self._collection_name, using=alias):
|
|
|
- utility.drop_collection(self._collection_name, None, using=alias)
|
|
|
+ if self._client.has_collection(self._collection_name):
|
|
|
+ self._client.drop_collection(self._collection_name, None)
|
|
|
|
|
|
def text_exists(self, id: str) -> bool:
|
|
|
- alias = uuid4().hex
|
|
|
- if self._client_config.secure:
|
|
|
- uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
|
|
- else:
|
|
|
- uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
|
|
- connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password,
|
|
|
- db_name=self._client_config.database)
|
|
|
-
|
|
|
- from pymilvus import utility
|
|
|
- if not utility.has_collection(self._collection_name, using=alias):
|
|
|
+ if not self._client.has_collection(self._collection_name):
|
|
|
return False
|
|
|
|
|
|
result = self._client.query(collection_name=self._collection_name,
|
|
@@ -210,15 +169,7 @@ class MilvusVector(BaseVector):
|
|
|
if redis_client.get(collection_exist_cache_key):
|
|
|
return
|
|
|
# Grab the existing collection if it exists
|
|
|
- from pymilvus import utility
|
|
|
- alias = uuid4().hex
|
|
|
- if self._client_config.secure:
|
|
|
- uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
|
|
- else:
|
|
|
- uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
|
|
- connections.connect(alias=alias, uri=uri, user=self._client_config.user,
|
|
|
- password=self._client_config.password, db_name=self._client_config.database)
|
|
|
- if not utility.has_collection(self._collection_name, using=alias):
|
|
|
+ if not self._client.has_collection(self._collection_name):
|
|
|
from pymilvus import CollectionSchema, DataType, FieldSchema
|
|
|
from pymilvus.orm.types import infer_dtype_bydata
|
|
|
|
|
@@ -263,11 +214,7 @@ class MilvusVector(BaseVector):
|
|
|
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
|
|
|
|
|
def _init_client(self, config) -> MilvusClient:
|
|
|
- if config.secure:
|
|
|
- uri = "https://" + str(config.host) + ":" + str(config.port)
|
|
|
- else:
|
|
|
- uri = "http://" + str(config.host) + ":" + str(config.port)
|
|
|
- client = MilvusClient(uri=uri, user=config.user, password=config.password, db_name=config.database)
|
|
|
+ client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
|
|
|
return client
|
|
|
|
|
|
|
|
@@ -285,11 +232,10 @@ class MilvusVectorFactory(AbstractVectorFactory):
|
|
|
return MilvusVector(
|
|
|
collection_name=collection_name,
|
|
|
config=MilvusConfig(
|
|
|
- host=dify_config.MILVUS_HOST,
|
|
|
- port=dify_config.MILVUS_PORT,
|
|
|
+ uri=dify_config.MILVUS_URI,
|
|
|
+ token=dify_config.MILVUS_TOKEN,
|
|
|
user=dify_config.MILVUS_USER,
|
|
|
password=dify_config.MILVUS_PASSWORD,
|
|
|
- secure=dify_config.MILVUS_SECURE,
|
|
|
database=dify_config.MILVUS_DATABASE,
|
|
|
)
|
|
|
)
|