From 813f5151e8161a2c99300a8a455bc94f94c1c67b Mon Sep 17 00:00:00 2001 From: "xj.lin" Date: Mon, 22 Jul 2019 19:53:44 +0800 Subject: [PATCH] MS-267 Support Inner Product update.. Former-commit-id: 8fdbb39fbdd05853f25c374f437f3ed78a46345d --- cpp/src/db/ExecutionEngineImpl.cpp | 5 ++++- cpp/src/wrapper/knowhere/vec_impl.cpp | 8 ++++---- cpp/src/wrapper/knowhere/vec_impl.h | 2 +- cpp/src/wrapper/knowhere/vec_index.cpp | 3 ++- cpp/thirdparty/knowhere | 2 +- cpp/unittest/index_wrapper/knowhere_test.cpp | 6 +++--- 6 files changed, 15 insertions(+), 11 deletions(-) diff --git a/cpp/src/db/ExecutionEngineImpl.cpp b/cpp/src/db/ExecutionEngineImpl.cpp index c3b1afc3..35f68558 100644 --- a/cpp/src/db/ExecutionEngineImpl.cpp +++ b/cpp/src/db/ExecutionEngineImpl.cpp @@ -30,7 +30,10 @@ ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, index_ = CreatetVecIndex(EngineType::FAISS_IDMAP); if (!index_) throw Exception("Create Empty VecIndex"); - auto ec = std::static_pointer_cast(index_)->Build(dimension); + Config build_cfg; + build_cfg["dim"] = dimension; + AutoGenParams(index_->GetType(), 0, build_cfg); + auto ec = std::static_pointer_cast(index_)->Build(build_cfg); if (ec != server::KNOWHERE_SUCCESS) { throw Exception("Build index error"); } } diff --git a/cpp/src/wrapper/knowhere/vec_impl.cpp b/cpp/src/wrapper/knowhere/vec_impl.cpp index 63e4d51c..7efbd54f 100644 --- a/cpp/src/wrapper/knowhere/vec_impl.cpp +++ b/cpp/src/wrapper/knowhere/vec_impl.cpp @@ -144,10 +144,10 @@ int64_t *BFIndex::GetRawIds() { return std::static_pointer_cast(index_)->GetRawIds(); } -server::KnowhereError BFIndex::Build(const int64_t &d) { +server::KnowhereError BFIndex::Build(const Config &cfg) { try { - dim = d; - std::static_pointer_cast(index_)->Train(dim); + dim = cfg["dim"].as(); + std::static_pointer_cast(index_)->Train(cfg); } catch (KnowhereException &e) { WRAPPER_LOG_ERROR << e.what(); return server::KNOWHERE_UNEXPECTED_ERROR; @@ -171,7 +171,7 @@ server::KnowhereError BFIndex::BuildAll(const long &nb, dim = cfg["dim"].as(); auto dataset = GenDatasetWithIds(nb, dim, xb, ids); - std::static_pointer_cast(index_)->Train(dim); + std::static_pointer_cast(index_)->Train(cfg); index_->Add(dataset, cfg); } catch (KnowhereException &e) { WRAPPER_LOG_ERROR << e.what(); diff --git a/cpp/src/wrapper/knowhere/vec_impl.h b/cpp/src/wrapper/knowhere/vec_impl.h index 4f20d17b..c4a0e2ac 100644 --- a/cpp/src/wrapper/knowhere/vec_impl.h +++ b/cpp/src/wrapper/knowhere/vec_impl.h @@ -57,7 +57,7 @@ class BFIndex : public VecIndexImpl { public: explicit BFIndex(std::shared_ptr index) : VecIndexImpl(std::move(index), IndexType::FAISS_IDMAP) {}; - server::KnowhereError Build(const int64_t &d); + server::KnowhereError Build(const Config& cfg); float *GetRawVectors(); server::KnowhereError BuildAll(const long &nb, const float *xb, diff --git a/cpp/src/wrapper/knowhere/vec_index.cpp b/cpp/src/wrapper/knowhere/vec_index.cpp index 6f5d51a3..65364eb0 100644 --- a/cpp/src/wrapper/knowhere/vec_index.cpp +++ b/cpp/src/wrapper/knowhere/vec_index.cpp @@ -180,7 +180,7 @@ server::KnowhereError write_index(VecIndexPtr index, const std::string &location } catch (knowhere::KnowhereException &e) { WRAPPER_LOG_ERROR << e.what(); return server::KNOWHERE_UNEXPECTED_ERROR; - } catch (std::exception& e) { + } catch (std::exception &e) { WRAPPER_LOG_ERROR << e.what(); return server::KNOWHERE_ERROR; } @@ -192,6 +192,7 @@ server::KnowhereError write_index(VecIndexPtr index, const std::string &location void AutoGenParams(const IndexType &type, const long &size, zilliz::knowhere::Config &cfg) { if (!cfg.contains("nlist")) { cfg["nlist"] = int(size / 1000000.0 * 16384); } if (!cfg.contains("gpu_id")) { cfg["gpu_id"] = int(0); } + if (!cfg.contains("metric_type")) { cfg["metric_type"] = "L2"; } switch (type) { case IndexType::FAISS_IVFSQ8_MIX: { diff --git a/cpp/thirdparty/knowhere b/cpp/thirdparty/knowhere index b0b9dd18..f866ac4e 160000 --- a/cpp/thirdparty/knowhere +++ b/cpp/thirdparty/knowhere @@ -1 +1 @@ -Subproject commit b0b9dd18fadbf9dc0fccaad815e14e578a92993e +Subproject commit f866ac4e297dea477ec591a62679cf5cdd219cc8 diff --git a/cpp/unittest/index_wrapper/knowhere_test.cpp b/cpp/unittest/index_wrapper/knowhere_test.cpp index bec4c940..064d6dc9 100644 --- a/cpp/unittest/index_wrapper/knowhere_test.cpp +++ b/cpp/unittest/index_wrapper/knowhere_test.cpp @@ -96,17 +96,17 @@ INSTANTIATE_TEST_CASE_P(WrapperParam, KnowhereWrapperTest, //), std::make_tuple(IndexType::FAISS_IVFFLAT_MIX, "Default", 64, 100000, 10, 10, - Config::object{{"nlist", 1000}, {"dim", 64}}, + Config::object{{"nlist", 1000}, {"dim", 64}, {"metric_type", "L2"}}, Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 5}} ), std::make_tuple(IndexType::FAISS_IDMAP, "Default", 64, 100000, 10, 10, - Config::object{{"dim", 64}}, + Config::object{{"dim", 64}, {"metric_type", "L2"}}, Config::object{{"dim", 64}, {"k", 10}} ), std::make_tuple(IndexType::FAISS_IVFSQ8_MIX, "Default", 64, 100000, 10, 10, - Config::object{{"dim", 64}, {"nlist", 1000}, {"nbits", 8}}, + Config::object{{"dim", 64}, {"nlist", 1000}, {"nbits", 8}, {"metric_type", "L2"}}, Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 5}} ) //std::make_tuple(IndexType::SPTAG_KDT_RNT_CPU, "Default", -- GitLab