فهرست منبع

feat: support setting database used in Milvus (#3003)

Leo Q 1 سال پیش
والد
کامیت
9c01bcb3e5

+ 2 - 0
api/config.py

@@ -67,6 +67,7 @@ DEFAULTS = {
     'CODE_EXECUTION_ENDPOINT': '',
     'CODE_EXECUTION_API_KEY': '',
     'TOOL_ICON_CACHE_MAX_AGE': 3600,
+    'MILVUS_DATABASE': 'default',
     'KEYWORD_DATA_SOURCE_TYPE': 'database',
 }
 
@@ -212,6 +213,7 @@ class Config:
         self.MILVUS_USER = get_env('MILVUS_USER')
         self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
         self.MILVUS_SECURE = get_env('MILVUS_SECURE')
+        self.MILVUS_DATABASE = get_env('MILVUS_DATABASE')
 
         # weaviate settings
         self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')

+ 8 - 6
api/core/rag/datasource/vdb/milvus/milvus_vector.py

@@ -20,16 +20,17 @@ class MilvusConfig(BaseModel):
     password: str
     secure: bool = False
     batch_size: int = 100
+    database: str = "default"
 
     @root_validator()
     def validate_config(cls, values: dict) -> dict:
-        if not values['host']:
+        if not values.get('host'):
             raise ValueError("config MILVUS_HOST is required")
-        if not values['port']:
+        if not values.get('port'):
             raise ValueError("config MILVUS_PORT is required")
-        if not values['user']:
+        if not values.get('user'):
             raise ValueError("config MILVUS_USER is required")
-        if not values['password']:
+        if not values.get('password'):
             raise ValueError("config MILVUS_PASSWORD is required")
         return values
 
@@ -39,7 +40,8 @@ class MilvusConfig(BaseModel):
             'port': self.port,
             'user': self.user,
             'password': self.password,
-            'secure': self.secure
+            'secure': self.secure,
+            'db_name': self.database,
         }
 
 
@@ -192,7 +194,7 @@ class MilvusVector(BaseVector):
             else:
                 uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
             connections.connect(alias=alias, uri=uri, user=self._client_config.user,
-                                password=self._client_config.password)
+                                password=self._client_config.password, db_name=self._client_config.database)
             if not utility.has_collection(self._collection_name, using=alias):
                 from pymilvus import CollectionSchema, DataType, FieldSchema
                 from pymilvus.orm.types import infer_dtype_bydata

+ 1 - 0
api/core/rag/datasource/vdb/vector_factory.py

@@ -110,6 +110,7 @@ class Vector:
                     user=config.get('MILVUS_USER'),
                     password=config.get('MILVUS_PASSWORD'),
                     secure=config.get('MILVUS_SECURE'),
+                    database=config.get('MILVUS_DATABASE'),
                 )
             )
         else:

+ 24 - 0
api/tests/unittests/test_model.py

@@ -0,0 +1,24 @@
+import pytest
+from pydantic.error_wrappers import ValidationError
+
+from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig
+
+
+def test_default_value():
+    valid_config = {
+        'host': 'localhost',
+        'port': 19530,
+        'user': 'root',
+        'password': 'Milvus'
+    }
+
+    for key in valid_config:
+        config = valid_config.copy()
+        del config[key]
+        with pytest.raises(ValidationError) as e:
+            MilvusConfig(**config)
+        assert e.value.errors()[1]['msg'] == f'config MILVUS_{key.upper()} is required'
+    
+    config = MilvusConfig(**valid_config)
+    assert config.secure is False
+    assert config.database == 'default'