diff --git a/CHANGELOG.md b/CHANGELOG.md index 4dd98ae94121b556dd26dc9b2bbb77c237f69c40..1d16621fad44fe0585390e91163eef79f14805e8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ Please mark all change in change log and use the ticket from JIRA. - \#533 - NSG build failed with MetricType Inner Product - \#543 - client raise exception in shards when search results is empty - \#545 - Avoid dead circle of build index thread when error occurs +- \#547 - NSG build failed using GPU-edition if set gpu_enable false - \#552 - Server down during building index_type: IVF_PQ using GPU-edition - \#561 - Milvus server should report exception/error message or terminate on mysql metadata backend error - \#596 - Frequently insert operation cost too much disk space @@ -73,6 +74,7 @@ Please mark all change in change log and use the ticket from JIRA. - \#449 - Add ShowPartitions example for C++ SDK - \#470 - Small raw files should not be build index - \#584 - Intergrate internal FAISS +- \#611 - Remove MILVUS_CPU_VERSION ## Task diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt index cf022f94a2506d8a160cdbd4a7b39beb7d7079a2..dd482f6464be8edf1db234bcf2c91803691e241a 100644 --- a/core/CMakeLists.txt +++ b/core/CMakeLists.txt @@ -146,7 +146,6 @@ if (CUSTOMIZATION) add_compile_definitions(CUSTOMIZATION) endif () -set(MILVUS_CPU_VERSION false) if (MILVUS_GPU_VERSION) message(STATUS "Building Milvus GPU version") add_compile_definitions("MILVUS_GPU_VERSION") @@ -155,8 +154,6 @@ if (MILVUS_GPU_VERSION) set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler -fPIC -std=c++11 -D_FORCE_INLINES --expt-extended-lambda") else () message(STATUS "Building Milvus CPU version") - set(MILVUS_CPU_VERSION true) - add_compile_definitions("MILVUS_CPU_VERSION") endif () if (MILVUS_WITH_PROMETHEUS) diff --git a/core/src/db/DBImpl.cpp b/core/src/db/DBImpl.cpp index b51ce571fdffd6927e429540c70abfa8f3a0de26..a5571db314873f4f3658e9c9b5836464723fae0c 100644 --- a/core/src/db/DBImpl.cpp +++ b/core/src/db/DBImpl.cpp @@ -1044,10 +1044,10 @@ DBImpl::BuildTableIndexRecursively(const std::string& table_id, const TableIndex if (!failed_files.empty()) { std::string msg = "Failed to build index for " + std::to_string(failed_files.size()) + ((failed_files.size() == 1) ? 
" file" : " files"); -#ifdef MILVUS_CPU_VERSION - msg += ", please double check index parameters."; -#else +#ifdef MILVUS_GPU_VERSION msg += ", file size is too large or gpu memory is not enough."; +#else + msg += ", please double check index parameters."; #endif return Status(DB_ERROR, msg); } diff --git a/core/src/db/engine/ExecutionEngineImpl.cpp b/core/src/db/engine/ExecutionEngineImpl.cpp index 6bb0b978173e3eabb71c30ab2cd540f115a6a8d8..1b9aaf3bf87bc4cbd362e616bd5322d4dbd8f8b0 100644 --- a/core/src/db/engine/ExecutionEngineImpl.cpp +++ b/core/src/db/engine/ExecutionEngineImpl.cpp @@ -93,18 +93,18 @@ ExecutionEngineImpl::CreatetVecIndex(EngineType type) { break; } case EngineType::FAISS_IVFFLAT: { -#ifdef MILVUS_CPU_VERSION - index = GetVecIndexFactory(IndexType::FAISS_IVFFLAT_CPU); -#else +#ifdef MILVUS_GPU_VERSION index = GetVecIndexFactory(IndexType::FAISS_IVFFLAT_MIX); +#else + index = GetVecIndexFactory(IndexType::FAISS_IVFFLAT_CPU); #endif break; } case EngineType::FAISS_IVFSQ8: { -#ifdef MILVUS_CPU_VERSION - index = GetVecIndexFactory(IndexType::FAISS_IVFSQ8_CPU); -#else +#ifdef MILVUS_GPU_VERSION index = GetVecIndexFactory(IndexType::FAISS_IVFSQ8_MIX); +#else + index = GetVecIndexFactory(IndexType::FAISS_IVFSQ8_CPU); #endif break; } @@ -119,10 +119,10 @@ ExecutionEngineImpl::CreatetVecIndex(EngineType type) { } #endif case EngineType::FAISS_PQ: { -#ifdef MILVUS_CPU_VERSION - index = GetVecIndexFactory(IndexType::FAISS_IVFPQ_CPU); -#else +#ifdef MILVUS_GPU_VERSION index = GetVecIndexFactory(IndexType::FAISS_IVFPQ_MIX); +#else + index = GetVecIndexFactory(IndexType::FAISS_IVFPQ_CPU); #endif break; } @@ -618,6 +618,9 @@ ExecutionEngineImpl::Init() { server::Config& config = server::Config::GetInstance(); std::vector gpu_ids; Status s = config.GetGpuResourceConfigBuildIndexResources(gpu_ids); + if (!s.ok()) { + gpu_num_ = knowhere::INVALID_VALUE; + } for (auto id : gpu_ids) { if (gpu_num_ == id) { return Status::OK(); diff --git a/core/src/grpc/README.md b/core/src/grpc/README.md deleted file mode 100644 index 6a3fe1157caff720570eea422d169ccd7473de1f..0000000000000000000000000000000000000000 --- a/core/src/grpc/README.md +++ /dev/null @@ -1,6 +0,0 @@ -We manually change two APIs in "milvus.pb.h": - add_vector_data() - add_row_id_array() - add_ids() - add_distances() -If proto files need be generated again, remember to re-change above APIs. 
\ No newline at end of file
diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.cpp
index 16c0b9172f6a85b6d88479e961ebcf6d6083d8f9..370df76b9bf3743335997e910ea568f0f76e51cf 100644
--- a/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.cpp
+++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.cpp
@@ -116,17 +116,28 @@ NSG::Train(const DatasetPtr& dataset, const Config& config) {
     }
 
     // TODO(linxj): dev IndexFactory, support more IndexType
+    Graph knng;
 #ifdef MILVUS_GPU_VERSION
-    auto preprocess_index = std::make_shared<GPUIVF>(build_cfg->gpu_id);
+    if (build_cfg->gpu_id == knowhere::INVALID_VALUE) {
+        auto preprocess_index = std::make_shared<IVF>();
+        auto model = preprocess_index->Train(dataset, config);
+        preprocess_index->set_index_model(model);
+        preprocess_index->AddWithoutIds(dataset, config);
+        preprocess_index->GenGraph(build_cfg->knng, knng, dataset, config);
+    } else {
+        auto preprocess_index = std::make_shared<GPUIVF>(build_cfg->gpu_id);
+        auto model = preprocess_index->Train(dataset, config);
+        preprocess_index->set_index_model(model);
+        preprocess_index->AddWithoutIds(dataset, config);
+        preprocess_index->GenGraph(build_cfg->knng, knng, dataset, config);
+    }
 #else
     auto preprocess_index = std::make_shared<IVF>();
-#endif
     auto model = preprocess_index->Train(dataset, config);
     preprocess_index->set_index_model(model);
     preprocess_index->AddWithoutIds(dataset, config);
-
-    Graph knng;
     preprocess_index->GenGraph(build_cfg->knng, knng, dataset, config);
+#endif
 
     algo::BuildParams b_params;
     b_params.candidate_pool_size = build_cfg->candidate_pool_size;
diff --git a/core/src/main.cpp b/core/src/main.cpp
index 5c97a061d20ace9a30d9e43c663c7a675a5eb0c6..670a992d298ac3f600ab7345d9d7b062dcc95fcb 100644
--- a/core/src/main.cpp
+++ b/core/src/main.cpp
@@ -58,10 +58,10 @@ print_banner() {
               << "OpenBLAS"
 #endif
               << " library." << std::endl;
-#ifdef MILVUS_CPU_VERSION
-    std::cout << "You are using Milvus CPU edition" << std::endl;
-#else
+#ifdef MILVUS_GPU_VERSION
     std::cout << "You are using Milvus GPU edition" << std::endl;
+#else
+    std::cout << "You are using Milvus CPU edition" << std::endl;
 #endif
     std::cout << std::endl;
 }
diff --git a/core/src/scheduler/SchedInst.h b/core/src/scheduler/SchedInst.h
index 1e8a7acf2e385ba79c4a6f5b9ebb7c6f12897fef..6cca3770336c64ffcb8140c9ebf84bc7e428ab61 100644
--- a/core/src/scheduler/SchedInst.h
+++ b/core/src/scheduler/SchedInst.h
@@ -25,6 +25,7 @@
 #include "optimizer/BuildIndexPass.h"
 #include "optimizer/FaissFlatPass.h"
 #include "optimizer/FaissIVFFlatPass.h"
+#include "optimizer/FaissIVFPQPass.h"
 #include "optimizer/FaissIVFSQ8HPass.h"
 #include "optimizer/FaissIVFSQ8Pass.h"
 #include "optimizer/FallbackPass.h"
@@ -129,7 +130,10 @@ class OptimizerInst {
             pass_list.push_back(std::make_shared<FaissFlatPass>());
             pass_list.push_back(std::make_shared<FaissIVFFlatPass>());
             pass_list.push_back(std::make_shared<FaissIVFSQ8Pass>());
+#ifdef CUSTOMIZATION
             pass_list.push_back(std::make_shared<FaissIVFSQ8HPass>());
+#endif
+            pass_list.push_back(std::make_shared<FaissIVFPQPass>());
         }
 #endif
         pass_list.push_back(std::make_shared<FallbackPass>());
diff --git a/core/src/scheduler/optimizer/FaissIVFPQPass.cpp b/core/src/scheduler/optimizer/FaissIVFPQPass.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f97fec63b4ced91125d2b72afa2847e3f737ffdd
--- /dev/null
+++ b/core/src/scheduler/optimizer/FaissIVFPQPass.cpp
@@ -0,0 +1,74 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.
See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#ifdef MILVUS_GPU_VERSION
+#include "scheduler/optimizer/FaissIVFPQPass.h"
+#include "cache/GpuCacheMgr.h"
+#include "scheduler/SchedInst.h"
+#include "scheduler/Utils.h"
+#include "scheduler/task/SearchTask.h"
+#include "scheduler/tasklabel/SpecResLabel.h"
+#include "server/Config.h"
+#include "utils/Log.h"
+
+namespace milvus {
+namespace scheduler {
+
+void
+FaissIVFPQPass::Init() {
+#ifdef MILVUS_GPU_VERSION
+    server::Config& config = server::Config::GetInstance();
+    Status s = config.GetEngineConfigGpuSearchThreshold(threshold_);
+    if (!s.ok()) {
+        threshold_ = std::numeric_limits<int64_t>::max();
+    }
+    s = config.GetGpuResourceConfigSearchResources(gpus);
+    if (!s.ok()) {
+        throw;
+    }
+#endif
+}
+
+bool
+FaissIVFPQPass::Run(const TaskPtr& task) {
+    if (task->Type() != TaskType::SearchTask) {
+        return false;
+    }
+
+    auto search_task = std::static_pointer_cast<XSearchTask>(task);
+    if (search_task->file_->engine_type_ != (int)engine::EngineType::FAISS_PQ) {
+        return false;
+    }
+
+    auto search_job = std::static_pointer_cast<SearchJob>(search_task->job_.lock());
+    ResourcePtr res_ptr;
+    if (search_job->nq() < threshold_) {
+        SERVER_LOG_DEBUG << "FaissIVFPQPass: nq < gpu_search_threshold, specify cpu to search!";
+        res_ptr = ResMgrInst::GetInstance()->GetResource("cpu");
+    } else {
+        auto best_device_id = count_ % gpus.size();
+        SERVER_LOG_DEBUG << "FaissIVFPQPass: nq > gpu_search_threshold, specify gpu" << best_device_id << " to search!";
+        count_++;
+        res_ptr = ResMgrInst::GetInstance()->GetResource(ResourceType::GPU, gpus[best_device_id]);
+    }
+    auto label = std::make_shared<SpecResLabel>(res_ptr);
+    task->label() = label;
+    return true;
+}
+
+} // namespace scheduler
+} // namespace milvus
+#endif
diff --git a/core/src/scheduler/optimizer/FaissIVFPQPass.h b/core/src/scheduler/optimizer/FaissIVFPQPass.h
new file mode 100644
index 0000000000000000000000000000000000000000..9225f84b7c4d6839fa5687e428eb7b17f44ea38f
--- /dev/null
+++ b/core/src/scheduler/optimizer/FaissIVFPQPass.h
@@ -0,0 +1,58 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#ifdef MILVUS_GPU_VERSION
+#pragma once
+
+#include <condition_variable>
+#include <deque>
+#include <limits>
+#include <list>
+#include <memory>
+#include <mutex>
+#include <queue>
+#include <string>
+#include <thread>
+#include <unordered_map>
+#include <vector>
+
+#include "Pass.h"
+
+namespace milvus {
+namespace scheduler {
+
+class FaissIVFPQPass : public Pass {
+ public:
+    FaissIVFPQPass() = default;
+
+ public:
+    void
+    Init() override;
+
+    bool
+    Run(const TaskPtr& task) override;
+
+ private:
+    int64_t threshold_ = std::numeric_limits<int64_t>::max();
+    int64_t count_ = 0;
+    std::vector<int64_t> gpus;
+};
+
+using FaissIVFPQPassPtr = std::shared_ptr<FaissIVFPQPass>;
+
+} // namespace scheduler
+} // namespace milvus
+#endif
diff --git a/core/src/server/Server.cpp b/core/src/server/Server.cpp
index 169463080e4db1fd8147317fdce78bb627323242..100db5d2ece2f9346f86c9425965312b1400faf6 100644
--- a/core/src/server/Server.cpp
+++ b/core/src/server/Server.cpp
@@ -183,10 +183,10 @@ Server::Start() {
         // print version information
         SERVER_LOG_INFO << "Milvus " << BUILD_TYPE << " version: v" << MILVUS_VERSION << ", built at " << BUILD_TIME;
-#ifdef MILVUS_CPU_VERSION
-        SERVER_LOG_INFO << "CPU edition";
-#else
+#ifdef MILVUS_GPU_VERSION
         SERVER_LOG_INFO << "GPU edition";
+#else
+        SERVER_LOG_INFO << "CPU edition";
 #endif
         server::Metrics::GetInstance().Init();
         server::SystemInfo::GetInstance().Init();
diff --git a/core/src/wrapper/ConfAdapter.cpp b/core/src/wrapper/ConfAdapter.cpp
index 7644e77ef5f4a00b380afeda1c7204b833b1d239..9ee2f060b107b49a5e815e8901056eb1dcdafa99 100644
--- a/core/src/wrapper/ConfAdapter.cpp
+++ b/core/src/wrapper/ConfAdapter.cpp
@@ -39,8 +39,6 @@ void
 ConfAdapter::MatchBase(knowhere::Config conf) {
     if (conf->metric_type == knowhere::DEFAULT_TYPE)
         conf->metric_type = knowhere::METRICTYPE::L2;
-    if (conf->gpu_id == knowhere::INVALID_VALUE)
-        conf->gpu_id = 0;
 }
 
 knowhere::Config
diff --git a/core/unittest/db/utils.cpp b/core/unittest/db/utils.cpp
index 293eeccc691879d519905f34b19d950aab9876ac..a57bae79b5f7c7ab1b27f416f0feed9cd1ee63aa 100644
--- a/core/unittest/db/utils.cpp
+++ b/core/unittest/db/utils.cpp
@@ -68,17 +68,16 @@ static const char* CONFIG_STR =
     "engine_config:\n"
     " use_blas_threshold: 20\n"
     "\n"
-    "resource_config:\n"
-#ifdef MILVUS_CPU_VERSION
-    " search_resources:\n"
-    " - cpu\n"
-    " index_build_device: cpu # CPU used for building index";
-#else
-    " search_resources:\n"
-    " - cpu\n"
+#ifdef MILVUS_GPU_VERSION
+    "gpu_resource_config:\n"
+    " enable: true # whether to enable GPU resources\n"
+    " cache_capacity: 4 # GB, size of GPU memory per card used for cache, must be a positive integer\n"
+    " search_resources: # define the GPU devices used for search computation, must be in format gpux\n"
+    " - gpu0\n"
+    " build_index_resources: # define the GPU devices used for index building, must be in format gpux\n"
     " - gpu0\n"
-    " index_build_device: gpu0 # GPU used for building index";
 #endif
+    "\n";
 
 void
 WriteToFile(const std::string& file_path, const char* content) {
diff --git a/core/unittest/server/utils.cpp b/core/unittest/server/utils.cpp
index 6fb424356df379eb7e228eaaff5540042be0d5d7..c232b1185aa2fbaef315502c2e033218cc03bf85 100644
--- a/core/unittest/server/utils.cpp
+++ b/core/unittest/server/utils.cpp
@@ -54,24 +54,21 @@ static const char* VALID_CONFIG_STR =
     "cache_config:\n"
     " cpu_cache_capacity: 16 # GB, CPU memory used for cache\n"
     " cpu_cache_threshold: 0.85 \n"
-    " gpu_cache_capacity: 4 # GB, GPU memory used for cache\n"
-    " gpu_cache_threshold: 0.85 \n"
     " cache_insert_data: false # whether to load inserted data into cache\n"
     "\n"
     "engine_config:\n"
     " use_blas_threshold: 20 \n"
     "\n"
-
"resource_config:\n" -#ifdef MILVUS_CPU_VERSION - " search_resources:\n" - " - cpu\n" - " index_build_device: cpu # CPU used for building index"; -#else - " search_resources:\n" - " - cpu\n" +#ifdef MILVUS_GPU_VERSION + "gpu_resource_config:\n" + " enable: true # whether to enable GPU resources\n" + " cache_capacity: 4 # GB, size of GPU memory per card used for cache, must be a positive integer\n" + " search_resources: # define the GPU devices used for search computation, must be in format gpux\n" + " - gpu0\n" + " build_index_resources: # define the GPU devices used for index building, must be in format gpux\n" " - gpu0\n" - " index_build_device: gpu0 # GPU used for building index"; #endif + "\n"; static const char* INVALID_CONFIG_STR = "*INVALID*"; diff --git a/core/unittest/wrapper/utils.cpp b/core/unittest/wrapper/utils.cpp index 96b9e643f5d07251362f9ed72b392ad09e297461..a5f8e1b6b22f53d13805d4864ecf8a1084867e21 100644 --- a/core/unittest/wrapper/utils.cpp +++ b/core/unittest/wrapper/utils.cpp @@ -56,17 +56,16 @@ static const char* CONFIG_STR = "engine_config:\n" " blas_threshold: 20\n" "\n" - "resource_config:\n" -#ifdef MILVUS_CPU_VERSION - " search_resources:\n" - " - cpu\n" - " index_build_device: cpu # CPU used for building index"; -#else - " search_resources:\n" - " - cpu\n" +#ifdef MILVUS_GPU_VERSION + "gpu_resource_config:\n" + " enable: true # whether to enable GPU resources\n" + " cache_capacity: 4 # GB, size of GPU memory per card used for cache, must be a positive integer\n" + " search_resources: # define the GPU devices used for search computation, must be in format gpux\n" + " - gpu0\n" + " build_index_resources: # define the GPU devices used for index building, must be in format gpux\n" " - gpu0\n" - " index_build_device: gpu0 # GPU used for building index"; #endif + "\n"; void WriteToFile(const std::string& file_path, const char* content) { diff --git a/tests/milvus_python_test/test_index.py b/tests/milvus_python_test/test_index.py index b253cf02a36b34f990ff9db74c3f2512efdcf752..8ce03b6a61e9e1f0a714a83fefaf0f2160248e0d 100644 --- a/tests/milvus_python_test/test_index.py +++ b/tests/milvus_python_test/test_index.py @@ -497,6 +497,7 @@ class TestIndexBase: status, ids = connect.add_vectors(table, vectors) for i in range(2): status = connect.create_index(table, index_params) + assert status.OK() status, result = connect.describe_index(table) logging.getLogger().info(result) @@ -569,7 +570,10 @@ class TestIndexIP: logging.getLogger().info(index_params) status, ids = connect.add_vectors(ip_table, vectors) status = connect.create_index(ip_table, index_params) - assert status.OK() + if index_params["index_type"] == IndexType.IVF_PQ: + assert not status.OK() + else: + assert status.OK() @pytest.mark.timeout(BUILD_TIMEOUT) def test_create_index_partition(self, connect, ip_table, get_index_params): @@ -584,7 +588,10 @@ class TestIndexIP: status = connect.create_partition(ip_table, partition_name, tag) status, ids = connect.add_vectors(ip_table, vectors, partition_tag=tag) status = connect.create_index(partition_name, index_params) - assert status.OK() + if index_params["index_type"] == IndexType.IVF_PQ: + assert not status.OK() + else: + assert status.OK() @pytest.mark.level(2) def test_create_index_without_connect(self, dis_connect, ip_table): @@ -609,14 +616,17 @@ class TestIndexIP: logging.getLogger().info(index_params) status, ids = connect.add_vectors(ip_table, vectors) status = connect.create_index(ip_table, index_params) - assert status.OK() - 
logging.getLogger().info(connect.describe_index(ip_table)) - query_vecs = [vectors[0], vectors[1], vectors[2]] - top_k = 5 - status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs) - logging.getLogger().info(result) - assert status.OK() - assert len(result) == len(query_vecs) + if index_params["index_type"] == IndexType.IVF_PQ: + assert not status.OK() + else: + assert status.OK() + logging.getLogger().info(connect.describe_index(ip_table)) + query_vecs = [vectors[0], vectors[1], vectors[2]] + top_k = 5 + status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs) + logging.getLogger().info(result) + assert status.OK() + assert len(result) == len(query_vecs) # TODO: enable @pytest.mark.timeout(BUILD_TIMEOUT) @@ -943,16 +953,19 @@ class TestIndexIP: index_params = get_index_params status, ids = connect.add_vectors(ip_table, vectors) status = connect.create_index(ip_table, index_params) - assert status.OK() - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - status = connect.drop_index(ip_table) - assert status.OK() - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - assert result._nlist == 16384 - assert result._table_name == ip_table - assert result._index_type == IndexType.FLAT + if index_params["index_type"] == IndexType.IVF_PQ: + assert not status.OK() + else: + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + status = connect.drop_index(ip_table) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == ip_table + assert result._index_type == IndexType.FLAT def test_drop_index_partition(self, connect, ip_table, get_simple_index_params): ''' @@ -965,16 +978,19 @@ class TestIndexIP: status = connect.create_partition(ip_table, partition_name, tag) status, ids = connect.add_vectors(ip_table, vectors, partition_tag=tag) status = connect.create_index(ip_table, index_params) - assert status.OK() - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - status = connect.drop_index(ip_table) - assert status.OK() - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - assert result._nlist == 16384 - assert result._table_name == ip_table - assert result._index_type == IndexType.FLAT + if index_params["index_type"] == IndexType.IVF_PQ: + assert not status.OK() + else: + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + status = connect.drop_index(ip_table) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == ip_table + assert result._index_type == IndexType.FLAT def test_drop_index_partition_A(self, connect, ip_table, get_simple_index_params): ''' @@ -987,19 +1003,22 @@ class TestIndexIP: status = connect.create_partition(ip_table, partition_name, tag) status, ids = connect.add_vectors(ip_table, vectors, partition_tag=tag) status = connect.create_index(partition_name, index_params) - assert status.OK() - status = connect.drop_index(ip_table) - assert status.OK() - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - assert result._nlist == 16384 - assert result._table_name == ip_table - assert result._index_type == IndexType.FLAT - status, result = 
connect.describe_index(partition_name) - logging.getLogger().info(result) - assert result._nlist == 16384 - assert result._table_name == partition_name - assert result._index_type == IndexType.FLAT + if index_params["index_type"] == IndexType.IVF_PQ: + assert not status.OK() + else: + assert status.OK() + status = connect.drop_index(ip_table) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == ip_table + assert result._index_type == IndexType.FLAT + status, result = connect.describe_index(partition_name) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == partition_name + assert result._index_type == IndexType.FLAT def test_drop_index_partition_B(self, connect, ip_table, get_simple_index_params): ''' @@ -1012,19 +1031,22 @@ class TestIndexIP: status = connect.create_partition(ip_table, partition_name, tag) status, ids = connect.add_vectors(ip_table, vectors, partition_tag=tag) status = connect.create_index(partition_name, index_params) - assert status.OK() - status = connect.drop_index(partition_name) - assert status.OK() - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - assert result._nlist == 16384 - assert result._table_name == ip_table - assert result._index_type == IndexType.FLAT - status, result = connect.describe_index(partition_name) - logging.getLogger().info(result) - assert result._nlist == 16384 - assert result._table_name == partition_name - assert result._index_type == IndexType.FLAT + if index_params["index_type"] == IndexType.IVF_PQ: + assert not status.OK() + else: + assert status.OK() + status = connect.drop_index(partition_name) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == ip_table + assert result._index_type == IndexType.FLAT + status, result = connect.describe_index(partition_name) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == partition_name + assert result._index_type == IndexType.FLAT def test_drop_index_partition_C(self, connect, ip_table, get_simple_index_params): ''' @@ -1040,24 +1062,27 @@ class TestIndexIP: status = connect.create_partition(ip_table, new_partition_name, new_tag) status, ids = connect.add_vectors(ip_table, vectors) status = connect.create_index(ip_table, index_params) - assert status.OK() - status = connect.drop_index(new_partition_name) - assert status.OK() - status, result = connect.describe_index(new_partition_name) - logging.getLogger().info(result) - assert result._nlist == 16384 - assert result._table_name == new_partition_name - assert result._index_type == IndexType.FLAT - status, result = connect.describe_index(partition_name) - logging.getLogger().info(result) - assert result._nlist == index_params["nlist"] - assert result._table_name == partition_name - assert result._index_type == index_params["index_type"] - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - assert result._nlist == index_params["nlist"] - assert result._table_name == ip_table - assert result._index_type == index_params["index_type"] + if index_params["index_type"] == IndexType.IVF_PQ: + assert not status.OK() + else: + assert status.OK() + status = connect.drop_index(new_partition_name) + assert status.OK() + status, result = 
connect.describe_index(new_partition_name) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == new_partition_name + assert result._index_type == IndexType.FLAT + status, result = connect.describe_index(partition_name) + logging.getLogger().info(result) + assert result._nlist == index_params["nlist"] + assert result._table_name == partition_name + assert result._index_type == index_params["index_type"] + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == index_params["nlist"] + assert result._table_name == ip_table + assert result._index_type == index_params["index_type"] def test_drop_index_repeatly(self, connect, ip_table, get_simple_index_params): ''' @@ -1068,18 +1093,21 @@ class TestIndexIP: index_params = get_simple_index_params status, ids = connect.add_vectors(ip_table, vectors) status = connect.create_index(ip_table, index_params) - assert status.OK() - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - status = connect.drop_index(ip_table) - assert status.OK() - status = connect.drop_index(ip_table) - assert status.OK() - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - assert result._nlist == 16384 - assert result._table_name == ip_table - assert result._index_type == IndexType.FLAT + if index_params["index_type"] == IndexType.IVF_PQ: + assert not status.OK() + else: + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + status = connect.drop_index(ip_table) + assert status.OK() + status = connect.drop_index(ip_table) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == ip_table + assert result._index_type == IndexType.FLAT @pytest.mark.level(2) def test_drop_index_without_connect(self, dis_connect, ip_table): @@ -1120,16 +1148,19 @@ class TestIndexIP: status, ids = connect.add_vectors(ip_table, vectors) for i in range(2): status = connect.create_index(ip_table, index_params) - assert status.OK() - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - status = connect.drop_index(ip_table) - assert status.OK() - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - assert result._nlist == 16384 - assert result._table_name == ip_table - assert result._index_type == IndexType.FLAT + if index_params["index_type"] == IndexType.IVF_PQ: + assert not status.OK() + else: + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + status = connect.drop_index(ip_table) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == ip_table + assert result._index_type == IndexType.FLAT def test_create_drop_index_repeatly_different_index_params(self, connect, ip_table): ''' diff --git a/tests/milvus_python_test/utils.py b/tests/milvus_python_test/utils.py index 6f7c81d13502090576b7d839ca3806d54c4e2bcc..e591521815c917c7bc09857db39db95fe5d029ca 100644 --- a/tests/milvus_python_test/utils.py +++ b/tests/milvus_python_test/utils.py @@ -437,7 +437,7 @@ def gen_invalid_index_params(): def gen_index_params(): index_params = [] - index_types = [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H] + index_types = 
[IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H, IndexType.IVF_PQ] nlists = [1, 16384, 50000] def gen_params(index_types, nlists): @@ -450,7 +450,7 @@ def gen_index_params(): def gen_simple_index_params(): index_params = [] - index_types = [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H] + index_types = [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H, IndexType.IVF_PQ] nlists = [1024] def gen_params(index_types, nlists):
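
Editor's note, not part of the patch: the test changes above encode one client-visible behavior of this PR — IVF_PQ is now generated alongside the other index types by gen_index_params/gen_simple_index_params, but building an IVF_PQ index on an inner-product (IP) table is expected to be rejected, so clients should check the returned status and fall back if needed. The sketch below illustrates that pattern; it assumes the pymilvus 0.2.x-style API used in these tests (Milvus, IndexType, MetricType, create_table, add_vectors, create_index), and the host, port, table name, and fallback index type are hypothetical.

```python
# Minimal client-side sketch of the behavior asserted by the updated tests:
# IVF_PQ on an IP-metric table is expected to fail, so fall back to another index.
import random

from milvus import Milvus, IndexType, MetricType

client = Milvus()
client.connect(host="127.0.0.1", port="19530")   # hypothetical endpoint

table_name = "ivf_pq_demo"                       # hypothetical table name
dim = 128
client.create_table({
    "table_name": table_name,
    "dimension": dim,
    "index_file_size": 1024,
    "metric_type": MetricType.IP,                # inner product, like the ip_table fixture
})

vectors = [[random.random() for _ in range(dim)] for _ in range(1000)]
status, ids = client.add_vectors(table_name, vectors)

# Try IVF_PQ first; if the server refuses it (as the tests expect for IP tables),
# fall back to IVF_SQ8 so the table still ends up with a usable index.
status = client.create_index(table_name, {"index_type": IndexType.IVF_PQ, "nlist": 1024})
if not status.OK():
    status = client.create_index(table_name, {"index_type": IndexType.IVF_SQ8, "nlist": 1024})
assert status.OK()
```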