diff --git a/cpp/src/db/ExecutionEngineImpl.cpp b/cpp/src/db/ExecutionEngineImpl.cpp index c3b1afc3752ac2a521135d97a3cf4fb3692920ec..35f68558c4a24693461b1d8bc8e9a54cc02df535 100644 --- a/cpp/src/db/ExecutionEngineImpl.cpp +++ b/cpp/src/db/ExecutionEngineImpl.cpp @@ -30,7 +30,10 @@ ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, index_ = CreatetVecIndex(EngineType::FAISS_IDMAP); if (!index_) throw Exception("Create Empty VecIndex"); - auto ec = std::static_pointer_cast(index_)->Build(dimension); + Config build_cfg; + build_cfg["dim"] = dimension; + AutoGenParams(index_->GetType(), 0, build_cfg); + auto ec = std::static_pointer_cast(index_)->Build(build_cfg); if (ec != server::KNOWHERE_SUCCESS) { throw Exception("Build index error"); } } diff --git a/cpp/src/wrapper/knowhere/vec_impl.cpp b/cpp/src/wrapper/knowhere/vec_impl.cpp index 63e4d51c26f48f1f4b1538431ae7abc4000da3cb..7efbd54f0f0446d6c5e9c754fe06896953ce3700 100644 --- a/cpp/src/wrapper/knowhere/vec_impl.cpp +++ b/cpp/src/wrapper/knowhere/vec_impl.cpp @@ -144,10 +144,10 @@ int64_t *BFIndex::GetRawIds() { return std::static_pointer_cast(index_)->GetRawIds(); } -server::KnowhereError BFIndex::Build(const int64_t &d) { +server::KnowhereError BFIndex::Build(const Config &cfg) { try { - dim = d; - std::static_pointer_cast(index_)->Train(dim); + dim = cfg["dim"].as(); + std::static_pointer_cast(index_)->Train(cfg); } catch (KnowhereException &e) { WRAPPER_LOG_ERROR << e.what(); return server::KNOWHERE_UNEXPECTED_ERROR; @@ -171,7 +171,7 @@ server::KnowhereError BFIndex::BuildAll(const long &nb, dim = cfg["dim"].as(); auto dataset = GenDatasetWithIds(nb, dim, xb, ids); - std::static_pointer_cast(index_)->Train(dim); + std::static_pointer_cast(index_)->Train(cfg); index_->Add(dataset, cfg); } catch (KnowhereException &e) { WRAPPER_LOG_ERROR << e.what(); diff --git a/cpp/src/wrapper/knowhere/vec_impl.h b/cpp/src/wrapper/knowhere/vec_impl.h index 4f20d17b6ae95a03a2ebdaf97696441e40ddfab8..c4a0e2ac6112fcbb18f08f47d0d33c1efb5ec681 100644 --- a/cpp/src/wrapper/knowhere/vec_impl.h +++ b/cpp/src/wrapper/knowhere/vec_impl.h @@ -57,7 +57,7 @@ class BFIndex : public VecIndexImpl { public: explicit BFIndex(std::shared_ptr index) : VecIndexImpl(std::move(index), IndexType::FAISS_IDMAP) {}; - server::KnowhereError Build(const int64_t &d); + server::KnowhereError Build(const Config& cfg); float *GetRawVectors(); server::KnowhereError BuildAll(const long &nb, const float *xb, diff --git a/cpp/src/wrapper/knowhere/vec_index.cpp b/cpp/src/wrapper/knowhere/vec_index.cpp index 6f5d51a3af4971dff036faecbed198289b502814..65364eb01fb42132367e00d852a93bf66eb8d347 100644 --- a/cpp/src/wrapper/knowhere/vec_index.cpp +++ b/cpp/src/wrapper/knowhere/vec_index.cpp @@ -180,7 +180,7 @@ server::KnowhereError write_index(VecIndexPtr index, const std::string &location } catch (knowhere::KnowhereException &e) { WRAPPER_LOG_ERROR << e.what(); return server::KNOWHERE_UNEXPECTED_ERROR; - } catch (std::exception& e) { + } catch (std::exception &e) { WRAPPER_LOG_ERROR << e.what(); return server::KNOWHERE_ERROR; } @@ -192,6 +192,7 @@ server::KnowhereError write_index(VecIndexPtr index, const std::string &location void AutoGenParams(const IndexType &type, const long &size, zilliz::knowhere::Config &cfg) { if (!cfg.contains("nlist")) { cfg["nlist"] = int(size / 1000000.0 * 16384); } if (!cfg.contains("gpu_id")) { cfg["gpu_id"] = int(0); } + if (!cfg.contains("metric_type")) { cfg["metric_type"] = "L2"; } switch (type) { case IndexType::FAISS_IVFSQ8_MIX: { diff --git a/cpp/thirdparty/knowhere b/cpp/thirdparty/knowhere index b0b9dd18fadbf9dc0fccaad815e14e578a92993e..f866ac4e297dea477ec591a62679cf5cdd219cc8 160000 --- a/cpp/thirdparty/knowhere +++ b/cpp/thirdparty/knowhere @@ -1 +1 @@ -Subproject commit b0b9dd18fadbf9dc0fccaad815e14e578a92993e +Subproject commit f866ac4e297dea477ec591a62679cf5cdd219cc8 diff --git a/cpp/unittest/index_wrapper/knowhere_test.cpp b/cpp/unittest/index_wrapper/knowhere_test.cpp index bec4c940cff29d35f6554a220f32a40ba32d8bc1..064d6dc911d1121d2e2b5862bde89310973b35cd 100644 --- a/cpp/unittest/index_wrapper/knowhere_test.cpp +++ b/cpp/unittest/index_wrapper/knowhere_test.cpp @@ -96,17 +96,17 @@ INSTANTIATE_TEST_CASE_P(WrapperParam, KnowhereWrapperTest, //), std::make_tuple(IndexType::FAISS_IVFFLAT_MIX, "Default", 64, 100000, 10, 10, - Config::object{{"nlist", 1000}, {"dim", 64}}, + Config::object{{"nlist", 1000}, {"dim", 64}, {"metric_type", "L2"}}, Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 5}} ), std::make_tuple(IndexType::FAISS_IDMAP, "Default", 64, 100000, 10, 10, - Config::object{{"dim", 64}}, + Config::object{{"dim", 64}, {"metric_type", "L2"}}, Config::object{{"dim", 64}, {"k", 10}} ), std::make_tuple(IndexType::FAISS_IVFSQ8_MIX, "Default", 64, 100000, 10, 10, - Config::object{{"dim", 64}, {"nlist", 1000}, {"nbits", 8}}, + Config::object{{"dim", 64}, {"nlist", 1000}, {"nbits", 8}, {"metric_type", "L2"}}, Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 5}} ) //std::make_tuple(IndexType::SPTAG_KDT_RNT_CPU, "Default",