|
@@ -20,16 +20,17 @@ class MilvusConfig(BaseModel):
|
|
|
password: str
|
|
|
secure: bool = False
|
|
|
batch_size: int = 100
|
|
|
+ database: str = "default"
|
|
|
|
|
|
@root_validator()
|
|
|
def validate_config(cls, values: dict) -> dict:
|
|
|
- if not values['host']:
|
|
|
+ if not values.get('host'):
|
|
|
raise ValueError("config MILVUS_HOST is required")
|
|
|
- if not values['port']:
|
|
|
+ if not values.get('port'):
|
|
|
raise ValueError("config MILVUS_PORT is required")
|
|
|
- if not values['user']:
|
|
|
+ if not values.get('user'):
|
|
|
raise ValueError("config MILVUS_USER is required")
|
|
|
- if not values['password']:
|
|
|
+ if not values.get('password'):
|
|
|
raise ValueError("config MILVUS_PASSWORD is required")
|
|
|
return values
|
|
|
|
|
@@ -39,7 +40,8 @@ class MilvusConfig(BaseModel):
|
|
|
'port': self.port,
|
|
|
'user': self.user,
|
|
|
'password': self.password,
|
|
|
- 'secure': self.secure
|
|
|
+ 'secure': self.secure,
|
|
|
+ 'db_name': self.database,
|
|
|
}
|
|
|
|
|
|
|
|
@@ -192,7 +194,7 @@ class MilvusVector(BaseVector):
|
|
|
else:
|
|
|
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
|
|
connections.connect(alias=alias, uri=uri, user=self._client_config.user,
|
|
|
- password=self._client_config.password)
|
|
|
+ password=self._client_config.password, db_name=self._client_config.database)
|
|
|
if not utility.has_collection(self._collection_name, using=alias):
|
|
|
from pymilvus import CollectionSchema, DataType, FieldSchema
|
|
|
from pymilvus.orm.types import infer_dtype_bydata
|