未验证 提交 4fa45dc7 编写于 作者: O op-hunter 提交者: GitHub

#1661 support HNSW deletion on nmslib (#1729)

* support HNSW deletion on nmslib
Signed-off-by: Nlichengming <chengming.li@zilliz.com>

* update changelog
Signed-off-by: Nlichengming <chengming.li@zilliz.com>

* fix lint error on test_hnsw.cpp
Signed-off-by: Nlichengming <chengming.li@zilliz.com>
Co-authored-by: Nlichengming <chengming.li@zilliz.com>
上级 e3786a24
......@@ -17,6 +17,7 @@ Please mark all change in change log and use the issue from GitHub
## Feature
- \#1603 BinaryFlat add 2 Metric: Substructure and Superstructure
- \#1660 IVF PQ CPU support deleted vectors searching
- \#1661 HNSW support deleted vectors searching
## Improvement
- \#267 Improve search performance: reduce delay
......
......@@ -132,8 +132,9 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
}
GETTENSOR(dataset_ptr)
size_t id_size = sizeof(int64_t) * config[meta::TOPK].get<int64_t>();
size_t dist_size = sizeof(float) * config[meta::TOPK].get<int64_t>();
size_t k = config[meta::TOPK].get<int64_t>();
size_t id_size = sizeof(int64_t) * k;
size_t dist_size = sizeof(float) * k;
auto p_id = (int64_t*)malloc(id_size * rows);
auto p_dist = (float*)malloc(dist_size * rows);
......@@ -141,6 +142,9 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
using P = std::pair<float, int64_t>;
auto compare = [](const P& v1, const P& v2) { return v1.first < v2.first; };
faiss::ConcurrentBitsetPtr blacklist = nullptr;
GetBlacklist(blacklist);
#pragma omp parallel for
for (unsigned int i = 0; i < rows; ++i) {
std::vector<P> ret;
......@@ -153,9 +157,9 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
// } else {
// ret = index_->searchKnn((float*)single_query, config[meta::TOPK].get<int64_t>(), compare);
// }
ret = index_->searchKnn((float*)single_query, config[meta::TOPK].get<int64_t>(), compare);
ret = index_->searchKnn((float*)single_query, k, compare, blacklist);
while (ret.size() < config[meta::TOPK]) {
while (ret.size() < k) {
ret.push_back(std::make_pair(-1, -1));
}
std::vector<float> dist;
......@@ -171,8 +175,8 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
std::transform(ret.begin(), ret.end(), std::back_inserter(ids),
[](const std::pair<float, int64_t>& e) { return e.second; });
memcpy(p_dist + i * config[meta::TOPK].get<int64_t>(), dist.data(), dist_size);
memcpy(p_id + i * config[meta::TOPK].get<int64_t>(), ids.data(), id_size);
memcpy(p_dist + i * k, dist.data(), dist_size);
memcpy(p_id + i * k, ids.data(), id_size);
}
auto ret_ds = std::make_shared<Dataset>();
......
......@@ -253,7 +253,7 @@ public:
template <bool has_deletions>
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const {
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, faiss::ConcurrentBitsetPtr bitset) const {
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
vl_type *visited_array = vl->mass;
vl_type visited_array_tag = vl->curV;
......@@ -262,7 +262,8 @@ public:
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
dist_t lowerBound;
if (!has_deletions || !isMarkedDeleted(ep_id)) {
// if (!has_deletions || !isMarkedDeleted(ep_id)) {
if (!has_deletions || !bitset->test((faiss::ConcurrentBitset::id_type_t)getExternalLabel(ep_id))) {
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
lowerBound = dist;
top_candidates.emplace(dist, ep_id);
......@@ -318,7 +319,8 @@ public:
_MM_HINT_T0);////////////////////////
#endif
if (!has_deletions || !isMarkedDeleted(candidate_id))
// if (!has_deletions || !isMarkedDeleted(candidate_id))
if (!has_deletions || (!bitset->test((faiss::ConcurrentBitset::id_type_t)getExternalLabel(candidate_id))))
top_candidates.emplace(dist, candidate_id);
if (top_candidates.size() > ef)
......@@ -1061,7 +1063,7 @@ public:
};
std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k) const {
searchKnn(const void *query_data, size_t k, faiss::ConcurrentBitsetPtr bitset) const {
std::priority_queue<std::pair<dist_t, labeltype >> result;
if (cur_element_count == 0) return result;
......@@ -1093,14 +1095,14 @@ public:
}
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
if (has_deletions_) {
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates1=searchBaseLayerST<true>(
currObj, query_data, std::max(ef_, k));
if (bitset != nullptr) {
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
top_candidates1 = searchBaseLayerST<true>(currObj, query_data, std::max(ef_, k), bitset);
top_candidates.swap(top_candidates1);
}
else{
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates1=searchBaseLayerST<false>(
currObj, query_data, std::max(ef_, k));
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
top_candidates1 = searchBaseLayerST<false>(currObj, query_data, std::max(ef_, k), bitset);
top_candidates.swap(top_candidates1);
}
while (top_candidates.size() > k) {
......@@ -1116,11 +1118,11 @@ public:
template <typename Comp>
std::vector<std::pair<dist_t, labeltype>>
searchKnn(const void* query_data, size_t k, Comp comp) {
searchKnn(const void* query_data, size_t k, Comp comp, faiss::ConcurrentBitsetPtr bitset) {
std::vector<std::pair<dist_t, labeltype>> result;
if (cur_element_count == 0) return result;
auto ret = searchKnn(query_data, k);
auto ret = searchKnn(query_data, k, bitset);
while (!ret.empty()) {
result.push_back(ret.top());
......
......@@ -27,6 +27,7 @@
#include <vector>
#include <string.h>
#include <faiss/utils/ConcurrentBitset.h>
namespace hnswlib {
typedef int64_t labeltype;
......@@ -80,9 +81,9 @@ namespace hnswlib {
class AlgorithmInterface {
public:
virtual void addPoint(const void *datapoint, labeltype label)=0;
virtual std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *, size_t) const = 0;
virtual std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *, size_t, faiss::ConcurrentBitsetPtr bitset) const = 0;
template <typename Comp>
std::vector<std::pair<dist_t, labeltype>> searchKnn(const void*, size_t, Comp) {
std::vector<std::pair<dist_t, labeltype>> searchKnn(const void*, size_t, Comp, faiss::ConcurrentBitsetPtr bitset) {
}
virtual void saveIndex(const std::string &location)=0;
virtual ~AlgorithmInterface(){
......
......@@ -89,6 +89,16 @@ if (NOT TARGET test_idmap)
endif ()
target_link_libraries(test_idmap ${depend_libs} ${unittest_libs} ${basic_libs})
#<HNSW-TEST>
set(hnsw_srcs
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexHNSW.cpp
)
if (NOT TARGET test_hnsw)
add_executable(test_hnsw test_hnsw.cpp ${hnsw_srcs} ${util_srcs})
endif ()
target_link_libraries(test_hnsw ${depend_libs} ${unittest_libs} ${basic_libs})
#<BinaryIDMAP-TEST>
if (NOT TARGET test_binaryidmap)
add_executable(test_binaryidmap test_binaryidmap.cpp ${ivf_srcs} ${util_srcs})
......@@ -128,6 +138,7 @@ endif ()
target_link_libraries(test_knowhere_common ${depend_libs} ${unittest_libs} ${basic_libs})
install(TARGETS test_ivf DESTINATION unittest)
install(TARGETS test_hnsw DESTINATION unittest)
install(TARGETS test_binaryivf DESTINATION unittest)
install(TARGETS test_idmap DESTINATION unittest)
install(TARGETS test_binaryidmap DESTINATION unittest)
......
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <knowhere/index/vector_index/IndexHNSW.h>
#include <src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h>
#include <iostream>
#include <random>
#include "./utils.h"
int
main() {
int64_t d = 64; // dimension
int64_t nb = 10000; // database size
int64_t nq = 10; // 10000; // nb of queries
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(nb);
int64_t* ids = new int64_t[nb];
float* xb = new float[d * nb];
float* xq = new float[d * nq];
// int64_t *ids = (int64_t*)malloc(nb * sizeof(int64_t));
// float* xb = (float*)malloc(d * nb * sizeof(float));
// float* xq = (float*)malloc(d * nq * sizeof(float));
for (int i = 0; i < nb; i++) {
for (int j = 0; j < d; j++) xb[d * i + j] = drand48();
xb[d * i] += i / 1000.;
ids[i] = i;
}
printf("gen xb and ids done! \n");
// srand((unsigned)time(NULL));
auto random_seed = (unsigned)time(NULL);
printf("delete ids: \n");
for (int i = 0; i < nq; i++) {
auto tmp = rand_r(&random_seed) % nb;
printf("%ld\n", tmp);
// std::cout << "before delete, test result: " << bitset->test(tmp) << std::endl;
bitset->set(tmp);
// std::cout << "after delete, test result: " << bitset->test(tmp) << std::endl;
for (int j = 0; j < d; j++) xq[d * i + j] = xb[d * tmp + j];
// xq[d * i] += i / 1000.;
}
printf("\n");
int k = 4;
int m = 16;
int ef = 200;
milvus::knowhere::IndexHNSW index;
milvus::knowhere::DatasetPtr base_dataset = generate_dataset(nb, d, (const void*)xb, ids);
/*
base_dataset->Set(milvus::knowhere::meta::ROWS, nb);
base_dataset->Set(milvus::knowhere::meta::DIM, d);
base_dataset->Set(milvus::knowhere::meta::TENSOR, (const void*)xb);
base_dataset->Set(milvus::knowhere::meta::IDS, (const int64_t*)ids);
*/
milvus::knowhere::Config base_conf{
{milvus::knowhere::meta::DIM, d},
{milvus::knowhere::meta::TOPK, k},
{milvus::knowhere::IndexParams::M, m},
{milvus::knowhere::IndexParams::efConstruction, ef},
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
};
milvus::knowhere::DatasetPtr query_dataset = generate_query_dataset(nq, d, (const void*)xq);
milvus::knowhere::Config query_conf{
{milvus::knowhere::meta::DIM, d},
{milvus::knowhere::meta::TOPK, k},
{milvus::knowhere::IndexParams::M, m},
{milvus::knowhere::IndexParams::ef, ef},
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
};
index.Train(base_dataset, base_conf);
index.Add(base_dataset, base_conf);
printf("------------sanity check----------------\n");
{ // sanity check
auto res = index.Query(query_dataset, query_conf);
printf("Query done!\n");
const int64_t* I = res->Get<int64_t*>(milvus::knowhere::meta::IDS);
float* D = res->Get<float*>(milvus::knowhere::meta::DISTANCE);
printf("I=\n");
for (int i = 0; i < 5; i++) {
for (int j = 0; j < k; j++) printf("%5ld ", I[i * k + j]);
printf("\n");
}
printf("D=\n");
for (int i = 0; i < 5; i++) {
for (int j = 0; j < k; j++) printf("%7g ", D[i * k + j]);
printf("\n");
}
}
printf("---------------search xq-------------\n");
{ // search xq
auto res = index.Query(query_dataset, query_conf);
const int64_t* I = res->Get<int64_t*>(milvus::knowhere::meta::IDS);
printf("I=\n");
for (int i = 0; i < nq; i++) {
for (int j = 0; j < k; j++) printf("%5ld ", I[i * k + j]);
printf("\n");
}
}
printf("----------------search xq with delete------------\n");
{ // search xq with delete
index.SetBlacklist(bitset);
auto res = index.Query(query_dataset, query_conf);
auto I = res->Get<int64_t*>(milvus::knowhere::meta::IDS);
printf("I=\n");
for (int i = 0; i < nq; i++) {
for (int j = 0; j < k; j++) printf("%5ld ", I[i * k + j]);
printf("\n");
}
}
delete[] xb;
delete[] xq;
delete[] ids;
return 0;
}
......@@ -70,6 +70,7 @@ DeleteByIDRequest::OnExecute() {
// Check table's index type supports delete
if (table_schema.engine_type_ != (int32_t)engine::EngineType::FAISS_IDMAP &&
table_schema.engine_type_ != (int32_t)engine::EngineType::FAISS_BIN_IDMAP &&
table_schema.engine_type_ != (int32_t)engine::EngineType::HNSW &&
table_schema.engine_type_ != (int32_t)engine::EngineType::FAISS_IVFFLAT &&
table_schema.engine_type_ != (int32_t)engine::EngineType::FAISS_BIN_IVFFLAT &&
table_schema.engine_type_ != (int32_t)engine::EngineType::FAISS_IVFSQ8 &&
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册