diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexHNSW_SQ8NR.cpp b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexHNSW_SQ8NR.cpp index 3f09cd953dcf6862127ce3e04d78e1e39bc4e0e7..5b34701e7d5781c1e2555c436992e0667a34903a 100644 --- a/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexHNSW_SQ8NR.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexHNSW_SQ8NR.cpp @@ -64,6 +64,7 @@ IndexHNSW_SQ8NR::Load(const BinarySet& index_binary) { normalize = (index_->metric_type_ == 1); // 1 == InnerProduct data_ = index_binary.GetByName(SQ8_DATA)->data; + index_->SetSq8((float*)(data_.get() + Dim() * Count())); } catch (std::exception& e) { KNOWHERE_THROW_MSG(e.what()); } @@ -84,7 +85,6 @@ IndexHNSW_SQ8NR::Train(const DatasetPtr& dataset_ptr, const Config& config) { index_ = std::make_shared>( space, rows, config[IndexParams::M].get(), config[IndexParams::efConstruction].get()); auto data_space = new uint8_t[dim * (rows + 2 * sizeof(float))]; - index_->SetSq8(true); index_->sq_train(rows, (const float*)p_data, data_space); data_ = std::shared_ptr(data_space); } catch (std::exception& e) { diff --git a/core/src/index/thirdparty/hnswlib/hnswalg_nm.h b/core/src/index/thirdparty/hnswlib/hnswalg_nm.h index 1037873e5d1bfcdc10747ccf47f4a33189ad5f59..bc385a5491777edf992b1ee7042563941d2ec263 100644 --- a/core/src/index/thirdparty/hnswlib/hnswalg_nm.h +++ b/core/src/index/thirdparty/hnswlib/hnswalg_nm.h @@ -56,7 +56,6 @@ namespace hnswlib_nm { is_sq8_ = false; sq_ = nullptr; - sqdc_ = nullptr; level_generator_.seed(random_seed); @@ -104,7 +103,6 @@ namespace hnswlib_nm { delete visited_list_pool_; if (sq_) delete sq_; - if (sqdc_) delete sqdc_; // linxj: delete delete space; @@ -126,7 +124,6 @@ namespace hnswlib_nm { bool is_sq8_ = false; faiss::ScalarQuantizer *sq_ = nullptr; - faiss::SQDistanceComputer *sqdc_ = nullptr; double mult_, revSize_; int maxlevel_; @@ -160,31 +157,27 @@ namespace hnswlib_nm { return 
((char*)pdata + offset * data_size_); } - void SetSq8(bool set_sq8) { - is_sq8_ = set_sq8; - if (is_sq8_) { - sq_ = new faiss::ScalarQuantizer(*(size_t*)dist_func_param_, faiss::QuantizerType::QT_8bit); // hard code - } + void SetSq8(const float *trained) { + if (!trained) + throw std::runtime_error("trained sq8 data cannot be null in SetSq8!"); + if (sq_) delete sq_; + is_sq8_ = true; + sq_ = new faiss::ScalarQuantizer(*(size_t*)dist_func_param_, faiss::QuantizerType::QT_8bit); // hard code + sq_->trained.resize((sq_->d) << 1); + memcpy(sq_->trained.data(), trained, sq_->trained.size() * sizeof(float)); } void sq_train(size_t nb, const float *xb, uint8_t *p_codes) { - if (!is_sq8_) { - throw std::runtime_error("is_sq8 should be set true by interface SetSq8(true) before you invoke sq_train!"); - } - if (!sq_) { - sq_ = new faiss::ScalarQuantizer(*(size_t*)dist_func_param_, faiss::QuantizerType::QT_8bit); // hard code - } + if (!p_codes) + throw std::runtime_error("p_codes cannot be null in sq_train!"); + if (!xb) + throw std::runtime_error("base vector cannot be null in sq_train!"); + if (sq_) delete sq_; + is_sq8_ = true; + sq_ = new faiss::ScalarQuantizer(*(size_t*)dist_func_param_, faiss::QuantizerType::QT_8bit); // hard code sq_->train(nb, xb); sq_->compute_codes(xb, p_codes, nb); memcpy(p_codes + *(size_t*)dist_func_param_ * nb, sq_->trained.data(), *(size_t*)dist_func_param_ * sizeof(float) * 2); - if (metric_type_ == 0) { // L2 - sqdc_ = new DCClassL2(sq_->d, sq_->trained); - } else if (metric_type_ == 1) { // IP - sqdc_ = new DCClassIP(sq_->d, sq_->trained); - } else { - throw std::runtime_error("unsupported metric_type, it must be 0(L2) or 1(IP)!"); - } - sqdc_->code_size = sq_->code_size; } int getRandomLevel(double reverse_size) { @@ -282,9 +275,18 @@ namespace hnswlib_nm { vl_type *visited_array = vl->mass; vl_type visited_array_tag = vl->curV; + faiss::SQDistanceComputer *sqdc = nullptr; if (is_sq8_) { - sqdc_->codes = (uint8_t*)pdata; - 
sqdc_->set_query((const float*)data_point); + if (metric_type_ == 0) { // L2 + sqdc = new DCClassL2(sq_->d, sq_->trained); + } else if (metric_type_ == 1) { // IP + sqdc = new DCClassIP(sq_->d, sq_->trained); + } else { + throw std::runtime_error("unsupported metric_type, it must be 0(L2) or 1(IP)!"); + } + sqdc->code_size = sq_->code_size; + sqdc->codes = (uint8_t*)pdata; + sqdc->set_query((const float*)data_point); } std::priority_queue, std::vector>, CompareByFirst> top_candidates; @@ -295,7 +297,7 @@ namespace hnswlib_nm { if (!has_deletions || !bitset->test((faiss::ConcurrentBitset::id_type_t)(ep_id))) { dist_t dist; if (is_sq8_) { - dist = (*sqdc_)(ep_id); + dist = (*sqdc)(ep_id); } else { dist = fstdistfunc_(data_point, getDataByInternalId(pdata, ep_id), dist_func_param_); } @@ -345,7 +347,7 @@ namespace hnswlib_nm { dist_t dist; if (is_sq8_) { - dist = (*sqdc_)(candidate_id); + dist = (*sqdc)(candidate_id); } else { char *currObj1 = (getDataByInternalId(pdata, candidate_id)); dist = fstdistfunc_(data_point, currObj1, dist_func_param_); @@ -373,6 +375,7 @@ namespace hnswlib_nm { } visited_list_pool_->releaseVisitedList(vl); + if (is_sq8_) delete sqdc; return top_candidates; } @@ -1117,10 +1120,19 @@ namespace hnswlib_nm { tableint currObj = enterpoint_node_; dist_t curdist; + faiss::SQDistanceComputer *sqdc = nullptr; if (is_sq8_) { - sqdc_->set_query((const float*)query_data); - sqdc_->codes = (uint8_t*)pdata; - curdist = (*sqdc_)(currObj); + if (metric_type_ == 0) { // L2 + sqdc = new DCClassL2(sq_->d, sq_->trained); + } else if (metric_type_ == 1) { // IP + sqdc = new DCClassIP(sq_->d, sq_->trained); + } else { + throw std::runtime_error("unsupported metric_type, it must be 0(L2) or 1(IP)!"); + } + sqdc->code_size = sq_->code_size; + sqdc->set_query((const float*)query_data); + sqdc->codes = (uint8_t*)pdata; + curdist = (*sqdc)(currObj); } else { curdist = fstdistfunc_(query_data, getDataByInternalId(pdata, enterpoint_node_), dist_func_param_); } @@ 
-1140,7 +1152,7 @@ namespace hnswlib_nm { throw std::runtime_error("cand error"); dist_t d; if (is_sq8_) { - d = (*sqdc_)(cand); + d = (*sqdc)(cand); } else { d = fstdistfunc_(query_data, getDataByInternalId(pdata, cand), dist_func_param_); } @@ -1174,6 +1186,7 @@ namespace hnswlib_nm { result.push(std::pair(rez.first, rez.second)); top_candidates.pop(); } + if (is_sq8_) delete sqdc; return result; }; diff --git a/core/src/index/unittest/CMakeLists.txt b/core/src/index/unittest/CMakeLists.txt index 69859f34f0b62d4256a43e4d83b8ad2e0c5e3d54..419b5516b9fb62cd3c516c81d209fd1f4392e7c9 100644 --- a/core/src/index/unittest/CMakeLists.txt +++ b/core/src/index/unittest/CMakeLists.txt @@ -192,6 +192,17 @@ endif () target_link_libraries(test_hnsw ${depend_libs} ${unittest_libs} ${basic_libs}) install(TARGETS test_hnsw DESTINATION unittest) +################################################################################ +# +set(hnsw_sq8nr_srcs + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_offset_index/IndexHNSW_SQ8NR.cpp + ) +if (NOT TARGET test_hnsw_sq8nr) + add_executable(test_hnsw_sq8nr test_hnsw_sq8nr.cpp ${hnsw_sq8nr_srcs} ${util_srcs}) +endif () +target_link_libraries(test_hnsw_sq8nr ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_hnsw_sq8nr DESTINATION unittest) + ################################################################################ # if (MILVUS_SUPPORT_SPTAG) diff --git a/core/src/index/unittest/test_hnsw_sq8nr.cpp b/core/src/index/unittest/test_hnsw_sq8nr.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1fcd708576bcf9cac7552216cebdaf6be9735527 --- /dev/null +++ b/core/src/index/unittest/test_hnsw_sq8nr.cpp @@ -0,0 +1,304 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include +#include +#include "knowhere/common/Exception.h" +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class HNSWSQ8NRTest : public DataGen, public TestWithParam { + protected: + void + SetUp() override { + IndexType = GetParam(); + std::cout << "IndexType from GetParam() is: " << IndexType << std::endl; + Generate(64, 10000, 10); // dim = 64, nb = 10000, nq = 10 + // Generate(2, 10, 2); // dim = 2, nb = 10, nq = 2 + index_ = std::make_shared(); + conf = milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, 64}, {milvus::knowhere::meta::TOPK, 10}, + {milvus::knowhere::IndexParams::M, 16}, {milvus::knowhere::IndexParams::efConstruction, 200}, + {milvus::knowhere::IndexParams::ef, 200}, {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + }; + /* + conf = milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, 2}, {milvus::knowhere::meta::TOPK, 2}, + {milvus::knowhere::IndexParams::M, 2}, {milvus::knowhere::IndexParams::efConstruction, 4}, + {milvus::knowhere::IndexParams::ef, 7}, {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + }; + */ + } + + protected: + milvus::knowhere::Config conf; + std::shared_ptr index_ = nullptr; + std::string IndexType; +}; + +INSTANTIATE_TEST_CASE_P(HNSWParameters, HNSWSQ8NRTest, Values("HNSWSQ8NR")); + +TEST_P(HNSWSQ8NRTest, HNSW_basic) { + assert(!xb.empty()); + + // null faiss index + /* + { + ASSERT_ANY_THROW(index_->Serialize()); + ASSERT_ANY_THROW(index_->Query(query_dataset,
conf)); + ASSERT_ANY_THROW(index_->Add(nullptr, conf)); + ASSERT_ANY_THROW(index_->AddWithoutIds(nullptr, conf)); + ASSERT_ANY_THROW(index_->Count()); + ASSERT_ANY_THROW(index_->Dim()); + } + */ + + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + // Serialize and Load before Query + milvus::knowhere::BinarySet bs = index_->Serialize(); + + // int64_t dim = base_dataset->Get(milvus::knowhere::meta::DIM); + // int64_t rows = base_dataset->Get(milvus::knowhere::meta::ROWS); + // auto raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); + // milvus::knowhere::BinaryPtr bptr = std::make_shared(); + // bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); + // bptr->size = dim * rows * sizeof(float); + // bs.Append(RAW_DATA, bptr); + + index_->Load(bs); + + auto result = index_->Query(query_dataset, conf); + AssertAnns(result, nq, k); +} + +TEST_P(HNSWSQ8NRTest, HNSW_delete) { + assert(!xb.empty()); + + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + faiss::ConcurrentBitsetPtr bitset = std::make_shared(nb); + for (auto i = 0; i < nq; ++i) { + bitset->set(i); + } + + // Serialize and Load before Query + milvus::knowhere::BinarySet bs = index_->Serialize(); + + // int64_t dim = base_dataset->Get(milvus::knowhere::meta::DIM); + // int64_t rows = base_dataset->Get(milvus::knowhere::meta::ROWS); + // auto raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); + // milvus::knowhere::BinaryPtr bptr = std::make_shared(); + // bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); + // bptr->size = dim * rows * sizeof(float); + // bs.Append(RAW_DATA, bptr); + + index_->Load(bs); + + auto result1 = index_->Query(query_dataset, conf); + AssertAnns(result1, nq, k); + + index_->SetBlacklist(bitset); + auto result2 = index_->Query(query_dataset, conf); + 
AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); + + /* + * delete result checked by eyes + auto ids1 = result1->Get(milvus::knowhere::meta::IDS); + auto ids2 = result2->Get(milvus::knowhere::meta::IDS); + std::cout << std::endl; + for (int i = 0; i < nq; ++ i) { + std::cout << "ids1: "; + for (int j = 0; j < k; ++ j) { + std::cout << *(ids1 + i * k + j) << " "; + } + std::cout << "ids2: "; + for (int j = 0; j < k; ++ j) { + std::cout << *(ids2 + i * k + j) << " "; + } + std::cout << std::endl; + for (int j = 0; j < std::min(5, k>>1); ++ j) { + ASSERT_EQ(*(ids1 + i * k + j + 1), *(ids2 + i * k + j)); + } + } + */ +} + +TEST_P(HNSWSQ8NRTest, HNSW_serialize) { + auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) { + { + FileIOWriter writer(filename); + writer(static_cast(bin->data.get()), bin->size); + } + + FileIOReader reader(filename); + reader(ret, bin->size); + }; + + { + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + auto binaryset = index_->Serialize(); + auto bin_index = binaryset.GetByName("HNSW_SQ8"); + auto bin_sq8 = binaryset.GetByName(SQ8_DATA); + + std::string filename = "/tmp/HNSW_SQ8NR_test_serialize_index.bin"; + std::string filename2 = "/tmp/HNSW_SQ8NR_test_serialize_sq8.bin"; + auto load_index_data = new uint8_t[bin_index->size]; + serialize(filename, bin_index, load_index_data); + auto load_sq8_data = new uint8_t[bin_sq8->size]; + serialize(filename2, bin_sq8, load_sq8_data); + + binaryset.clear(); + std::shared_ptr data_index(load_index_data); + binaryset.Append("HNSW_SQ8", data_index, bin_index->size); + std::shared_ptr sq8_index(load_sq8_data); + binaryset.Append(SQ8_DATA, sq8_index, bin_sq8->size); + + index_->Load(binaryset); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + auto result = index_->Query(query_dataset, conf); + AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + } +} + +/* + * faiss style test + * keep it +int +main() { + 
int64_t d = 64; // dimension + int64_t nb = 10000; // database size + int64_t nq = 10; // 10000; // nb of queries + faiss::ConcurrentBitsetPtr bitset = std::make_shared(nb); + + int64_t* ids = new int64_t[nb]; + float* xb = new float[d * nb]; + float* xq = new float[d * nq]; + // int64_t *ids = (int64_t*)malloc(nb * sizeof(int64_t)); + // float* xb = (float*)malloc(d * nb * sizeof(float)); + // float* xq = (float*)malloc(d * nq * sizeof(float)); + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < d; j++) xb[d * i + j] = drand48(); + xb[d * i] += i / 1000.; + ids[i] = i; + } +// printf("gen xb and ids done! \n"); + + // srand((unsigned)time(nullptr)); + auto random_seed = (unsigned)time(nullptr); +// printf("delete ids: \n"); + for (int i = 0; i < nq; i++) { + auto tmp = rand_r(&random_seed) % nb; +// printf("%ld\n", tmp); + // std::cout << "before delete, test result: " << bitset->test(tmp) << std::endl; + bitset->set(tmp); + // std::cout << "after delete, test result: " << bitset->test(tmp) << std::endl; + for (int j = 0; j < d; j++) xq[d * i + j] = xb[d * tmp + j]; + // xq[d * i] += i / 1000.; + } +// printf("\n"); + + int k = 4; + int m = 16; + int ef = 200; + milvus::knowhere::IndexHNSW_NM index; + milvus::knowhere::DatasetPtr base_dataset = generate_dataset(nb, d, (const void*)xb, ids); +// base_dataset->Set(milvus::knowhere::meta::ROWS, nb); +// base_dataset->Set(milvus::knowhere::meta::DIM, d); +// base_dataset->Set(milvus::knowhere::meta::TENSOR, (const void*)xb); +// base_dataset->Set(milvus::knowhere::meta::IDS, (const int64_t*)ids); + + milvus::knowhere::Config base_conf{ + {milvus::knowhere::meta::DIM, d}, + {milvus::knowhere::meta::TOPK, k}, + {milvus::knowhere::IndexParams::M, m}, + {milvus::knowhere::IndexParams::efConstruction, ef}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + }; + milvus::knowhere::DatasetPtr query_dataset = generate_query_dataset(nq, d, (const void*)xq); + milvus::knowhere::Config query_conf{ + 
{milvus::knowhere::meta::DIM, d}, + {milvus::knowhere::meta::TOPK, k}, + {milvus::knowhere::IndexParams::M, m}, + {milvus::knowhere::IndexParams::ef, ef}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + }; + + index.Train(base_dataset, base_conf); + index.Add(base_dataset, base_conf); + +// printf("------------sanity check----------------\n"); + { // sanity check + auto res = index.Query(query_dataset, query_conf); +// printf("Query done!\n"); + const int64_t* I = res->Get(milvus::knowhere::meta::IDS); +// float* D = res->Get(milvus::knowhere::meta::DISTANCE); + +// printf("I=\n"); +// for (int i = 0; i < 5; i++) { +// for (int j = 0; j < k; j++) printf("%5ld ", I[i * k + j]); +// printf("\n"); +// } + +// printf("D=\n"); +// for (int i = 0; i < 5; i++) { +// for (int j = 0; j < k; j++) printf("%7g ", D[i * k + j]); +// printf("\n"); +// } + } + +// printf("---------------search xq-------------\n"); + { // search xq + auto res = index.Query(query_dataset, query_conf); + const int64_t* I = res->Get(milvus::knowhere::meta::IDS); + + printf("I=\n"); + for (int i = 0; i < nq; i++) { + for (int j = 0; j < k; j++) printf("%5ld ", I[i * k + j]); + printf("\n"); + } + } + + printf("----------------search xq with delete------------\n"); + { // search xq with delete + index.SetBlacklist(bitset); + auto res = index.Query(query_dataset, query_conf); + auto I = res->Get(milvus::knowhere::meta::IDS); + + printf("I=\n"); + for (int i = 0; i < nq; i++) { + for (int j = 0; j < k; j++) printf("%5ld ", I[i * k + j]); + printf("\n"); + } + } + + delete[] xb; + delete[] xq; + delete[] ids; + + return 0; +} +*/