未验证 提交 cda57cf7 编写于 作者: T Tinkerrr 提交者: GitHub

Update NSG (#1744)

* enable IP and fix crash
Signed-off-by: Nicky <nicky.xj.lin@gmail.com>

* update.
Signed-off-by: xiaojun.lin <xiaojun.lin@zilliz.com>

* lint pass
Signed-off-by: xiaojun.lin <xiaojun.lin@zilliz.com>
上级 e865e9c8
......@@ -13,6 +13,8 @@ Please mark all change in change log and use the issue from GitHub
- \#1663 PQ index parameter 'm' validation
- \#1686 API search_in_files cannot work correctly when vectors is stored in certain non-default partition
- \#1689 Fix SQ8H search fail on SIFT-1B dataset
- \#1667 Create index failed with type: rnsg if metric_type is IP
- \#1708 NSG search crashed
- \#1724 Remove unused unittests
- \#1734 Opentracing for combined search request
......
......@@ -211,7 +211,7 @@ NSGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
static int64_t MAX_OUT_DEGREE = 300;
static int64_t MIN_CANDIDATE_POOL_SIZE = 50;
static int64_t MAX_CANDIDATE_POOL_SIZE = 1000;
static std::vector<std::string> METRICS{knowhere::Metric::L2};
static std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Metric::IP};
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
......
......@@ -10,6 +10,7 @@
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <fiu-local.h>
#include <string>
#include "knowhere/common/Exception.h"
#include "knowhere/common/Timer.h"
......@@ -139,7 +140,7 @@ NSG::Train(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_ids = dataset_ptr->Get<const int64_t*>(meta::IDS);
GETTENSOR(dataset_ptr)
index_ = std::make_shared<impl::NsgIndex>(dim, rows);
index_ = std::make_shared<impl::NsgIndex>(dim, rows, config[Metric::TYPE].get<std::string>());
index_->SetKnnGraph(knng);
index_->Build_with_ids(rows, (float*)p_data, (int64_t*)p_ids, b_params);
}
......
......@@ -9,6 +9,8 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "knowhere/index/vector_index/impl/nsg/NSG.h"
#include <algorithm>
#include <cstdlib>
#include <cstring>
......@@ -20,7 +22,6 @@
#include "knowhere/common/Exception.h"
#include "knowhere/common/Log.h"
#include "knowhere/common/Timer.h"
#include "knowhere/index/vector_index/impl/nsg/NSG.h"
#include "knowhere/index/vector_index/impl/nsg/NSGHelper.h"
namespace milvus {
......@@ -31,14 +32,11 @@ unsigned int seed = 100;
/**
 * Build an empty NSG index description and pick the distance functor for it.
 *
 * @param dimension  copied into the `dimension` member.
 * @param n          copied into the `ntotal` member.
 * @param metric     metric name, copied into `metric_type`; must equal
 *                   knowhere::Metric::L2 or knowhere::Metric::IP.
 *
 * Throws through KNOWHERE_THROW_MSG when `metric` is neither of the two
 * supported values, so `distance_` is never left without a functor.
 */
NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, std::string metric)
    : dimension(dimension), ntotal(n), metric_type(metric) {
    // Allocate exactly one functor: assigning a default first and then
    // overwriting it in the branch below would leak the first allocation.
    if (metric == knowhere::Metric::L2) {
        distance_ = new DistanceL2;
    } else if (metric == knowhere::Metric::IP) {
        distance_ = new DistanceIP;
    } else {
        // Nothing has been allocated yet, so throwing here leaks nothing.
        KNOWHERE_THROW_MSG("NSG only supports L2 and IP metric types");
    }
}
NsgIndex::~NsgIndex() {
......@@ -697,139 +695,176 @@ NsgIndex::FindUnconnectedNode(boost::dynamic_bitset<>& has_linked, int64_t& root
nsg[root].push_back(id);
}
void
NsgIndex::GetNeighbors(const float* query, node_t* I, float* D, SearchParams* params) {
size_t buffer_size = params ? params->search_length : search_length;
if (buffer_size > ntotal) {
KNOWHERE_THROW_MSG("Search Error, search_length > ntotal");
}
std::vector<Neighbor> resset(buffer_size);
std::vector<node_t> init_ids(buffer_size);
boost::dynamic_bitset<> has_calculated_dist{ntotal, 0};
{
/*
* copy navigation-point neighbor, pick random node if less than buffer size
*/
size_t count = 0;
// Get all neighbors
for (size_t i = 0; i < init_ids.size() && i < nsg[navigation_point].size(); ++i) {
init_ids[i] = nsg[navigation_point][i];
has_calculated_dist[init_ids[i]] = true;
++count;
}
while (count < buffer_size) {
node_t id = rand_r(&seed) % ntotal;
if (has_calculated_dist[id])
continue; // duplicate id
init_ids[count] = id;
++count;
has_calculated_dist[id] = true;
}
}
{
// init resset and sort by distance
for (size_t i = 0; i < init_ids.size(); ++i) {
node_t id = init_ids[i];
if (id >= static_cast<node_t>(ntotal)) {
KNOWHERE_THROW_MSG("Search Error, id > ntotal");
}
float dist = distance_->Compare(ori_data_ + id * dimension, query, dimension);
resset[i] = Neighbor(id, dist, false);
}
std::sort(resset.begin(), resset.end()); // sort by distance
// search nearest neighbor
size_t cursor = 0;
while (cursor < buffer_size) {
size_t nearest_updated_pos = buffer_size;
if (!resset[cursor].has_explored) {
resset[cursor].has_explored = true;
node_t start_pos = resset[cursor].id;
auto& wait_for_search_node_vec = nsg[start_pos];
for (size_t i = 0; i < wait_for_search_node_vec.size(); ++i) {
node_t id = wait_for_search_node_vec[i];
if (has_calculated_dist[id])
continue;
has_calculated_dist[id] = true;
float dist = distance_->Compare(query, ori_data_ + dimension * id, dimension);
if (dist >= resset[buffer_size - 1].distance)
continue;
//// difference from other GetNeighbors
Neighbor nn(id, dist, false);
///////////////////////////////////////
size_t pos = InsertIntoPool(resset.data(), buffer_size, nn); // replace with a closer node
if (pos < nearest_updated_pos)
nearest_updated_pos = pos;
//>> Debug code
/////
// std::cout << "pos: " << pos << ", nn: " << nn.id << ":" << nn.distance << ", nup: " <<
// nearest_updated_pos << std::endl;
/////
// trick: avoid search query search_length < init_ids.size() ...
if (buffer_size + 1 < resset.size())
++buffer_size;
}
}
if (cursor >= nearest_updated_pos) {
cursor = nearest_updated_pos; // re-search from new pos
} else {
++cursor;
}
}
}
if ((resset.size() - params->k) >= 0) {
for (size_t i = 0; i < params->k; ++i) {
I[i] = resset[i].id;
D[i] = resset[i].distance;
}
} else {
size_t i = 0;
for (; i < resset.size(); ++i) {
I[i] = resset[i].id;
D[i] = resset[i].distance;
}
for (; i < params->k; ++i) {
I[i] = -1;
D[i] = -1;
}
}
}
// void
// NsgIndex::GetNeighbors(const float* query, node_t* I, float* D, SearchParams* params) {
// size_t buffer_size = params ? params->search_length : search_length;
// if (buffer_size > ntotal) {
// KNOWHERE_THROW_MSG("Search Error, search_length > ntotal");
// }
// std::vector<Neighbor> resset(buffer_size);
// std::vector<node_t> init_ids(buffer_size);
// boost::dynamic_bitset<> has_calculated_dist{ntotal, 0};
// {
// /*
// * copy navigation-point neighbor, pick random node if less than buffer size
// */
// size_t count = 0;
// // Get all neighbors
// for (size_t i = 0; i < init_ids.size() && i < nsg[navigation_point].size(); ++i) {
// init_ids[i] = nsg[navigation_point][i];
// has_calculated_dist[init_ids[i]] = true;
// ++count;
// }
// while (count < buffer_size) {
// node_t id = rand_r(&seed) % ntotal;
// if (has_calculated_dist[id])
// continue; // duplicate id
// init_ids[count] = id;
// ++count;
// has_calculated_dist[id] = true;
// }
// }
// {
// // init resset and sort by distance
// for (size_t i = 0; i < init_ids.size(); ++i) {
// node_t id = init_ids[i];
// if (id >= static_cast<node_t>(ntotal)) {
// KNOWHERE_THROW_MSG("Search Error, id > ntotal");
// }
// float dist = distance_->Compare(ori_data_ + id * dimension, query, dimension);
// resset[i] = Neighbor(id, dist, false);
// }
// std::sort(resset.begin(), resset.end()); // sort by distance
// // search nearest neighbor
// size_t cursor = 0;
// while (cursor < buffer_size) {
// size_t nearest_updated_pos = buffer_size;
// if (!resset[cursor].has_explored) {
// resset[cursor].has_explored = true;
// node_t start_pos = resset[cursor].id;
// auto& wait_for_search_node_vec = nsg[start_pos];
// for (size_t i = 0; i < wait_for_search_node_vec.size(); ++i) {
// node_t id = wait_for_search_node_vec[i];
// if (has_calculated_dist[id])
// continue;
// has_calculated_dist[id] = true;
// float dist = distance_->Compare(query, ori_data_ + dimension * id, dimension);
// if (dist >= resset[buffer_size - 1].distance)
// continue;
// //// difference from other GetNeighbors
// Neighbor nn(id, dist, false);
// ///////////////////////////////////////
// size_t pos = InsertIntoPool(resset.data(), buffer_size, nn); // replace with a closer node
// if (pos < nearest_updated_pos)
// nearest_updated_pos = pos;
// //>> Debug code
// /////
// // std::cout << "pos: " << pos << ", nn: " << nn.id << ":" << nn.distance << ", nup: " <<
// // nearest_updated_pos << std::endl;
// /////
// // trick: avoid search query search_length < init_ids.size() ...
// if (buffer_size + 1 < resset.size())
// ++buffer_size;
// }
// }
// if (cursor >= nearest_updated_pos) {
// cursor = nearest_updated_pos; // re-search from new pos
// } else {
// ++cursor;
// }
// }
// }
// if ((resset.size() - params->k) >= 0) {
// for (size_t i = 0; i < params->k; ++i) {
// I[i] = resset[i].id;
// D[i] = resset[i].distance;
// }
// } else {
// size_t i = 0;
// for (; i < resset.size(); ++i) {
// I[i] = resset[i].id;
// D[i] = resset[i].distance;
// }
// for (; i < params->k; ++i) {
// I[i] = -1;
// D[i] = -1;
// }
// }
// }
// void
// NsgIndex::Search(const float* query, const unsigned& nq, const unsigned& dim, const unsigned& k, float* dist,
// int64_t* ids, SearchParams& params) {
// // if (k >= 45) {
// // params.search_length = k;
// // }
// TimeRecorder rc("nsgsearch", 1);
// if (nq == 1) {
// GetNeighbors(query, ids, dist, &params);
// } else {
// #pragma omp parallel for
// for (unsigned int i = 0; i < nq; ++i) {
// const float* single_query = query + i * dim;
// GetNeighbors(single_query, ids + i * k, dist + i * k, &params);
// }
// }
// rc.ElapseFromBegin("seach finish");
// }
void
NsgIndex::Search(const float* query, const unsigned& nq, const unsigned& dim, const unsigned& k, float* dist,
int64_t* ids, SearchParams& params) {
// if (k >= 45) {
// params.search_length = k;
// }
TimeRecorder rc("nsgsearch", 1);
std::vector<std::vector<Neighbor>> resset(nq);
TimeRecorder rc("NsgIndex::search", 1);
if (nq == 1) {
GetNeighbors(query, ids, dist, &params);
GetNeighbors(query, resset[0], nsg, &params);
} else {
#pragma omp parallel for
for (unsigned int i = 0; i < nq; ++i) {
const float* single_query = query + i * dim;
GetNeighbors(single_query, ids + i * k, dist + i * k, &params);
GetNeighbors(single_query, resset[i], nsg, &params);
}
}
rc.RecordSection("search");
for (unsigned int i = 0; i < nq; ++i) {
int64_t var = resset[i].size() - k;
if (var >= 0) {
for (unsigned int j = 0; j < k; ++j) {
ids[i * k + j] = ids_[resset[i][j].id];
dist[i * k + j] = resset[i][j].distance;
}
} else {
for (unsigned int j = 0; j < resset[i].size(); ++j) {
ids[i * k + j] = ids_[resset[i][j].id];
dist[i * k + j] = resset[i][j].distance;
}
for (unsigned int j = resset[i].size(); j < k; ++j) {
ids[i * k + j] = -1;
dist[i * k + j] = -1;
}
}
}
rc.ElapseFromBegin("seach finish");
rc.RecordSection("merge");
}
void
......
......@@ -11,16 +11,16 @@
#pragma once
#include <boost/dynamic_bitset.hpp>
#include <cstddef>
#include <mutex>
#include <string>
#include <vector>
#include <boost/dynamic_bitset.hpp>
#include "Distance.h"
#include "Neighbor.h"
#include "knowhere/common/Config.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
namespace milvus {
namespace knowhere {
......@@ -65,7 +65,7 @@ class NsgIndex {
size_t out_degree;
public:
explicit NsgIndex(const size_t& dimension, const size_t& n, std::string metric = "L2");
explicit NsgIndex(const size_t& dimension, const size_t& n, std::string metric = knowhere::Metric::L2);
NsgIndex() = default;
......@@ -111,9 +111,9 @@ class NsgIndex {
void
GetNeighbors(const float* query, std::vector<Neighbor>& resset, Graph& graph, SearchParams* param = nullptr);
// used by search
void
GetNeighbors(const float* query, node_t* I, float* D, SearchParams* params);
// only for search
// void
// GetNeighbors(const float* query, node_t* I, float* D, SearchParams* params);
void
Link();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册