From 651f642d3182a5cb8415e6d3221ee022b1bcaf0c Mon Sep 17 00:00:00 2001 From: "xj.lin" Date: Sun, 21 Jul 2019 16:10:26 +0800 Subject: [PATCH] MS-259 add exception Former-commit-id: 8f6b7b9c2e178146d0db55d14a2b716999d8dc0a --- cpp/src/db/ExecutionEngineImpl.cpp | 71 +++-- cpp/src/utils/Error.h | 6 + cpp/src/wrapper/knowhere/vec_impl.cpp | 277 ++++++++++++------- cpp/src/wrapper/knowhere/vec_impl.h | 46 +-- cpp/src/wrapper/knowhere/vec_index.cpp | 43 +-- cpp/src/wrapper/knowhere/vec_index.h | 40 +-- cpp/thirdparty/knowhere | 2 +- cpp/unittest/index_wrapper/knowhere_test.cpp | 4 + 8 files changed, 310 insertions(+), 179 deletions(-) diff --git a/cpp/src/db/ExecutionEngineImpl.cpp b/cpp/src/db/ExecutionEngineImpl.cpp index 5927d093..63ed00d2 100644 --- a/cpp/src/db/ExecutionEngineImpl.cpp +++ b/cpp/src/db/ExecutionEngineImpl.cpp @@ -3,6 +3,8 @@ * Unauthorized copying of this file, via any medium is strictly prohibited. * Proprietary and confidential. ******************************************************************************/ +#include + #include #include #include "Log.h" @@ -11,6 +13,8 @@ #include "ExecutionEngineImpl.h" #include "wrapper/knowhere/vec_index.h" #include "wrapper/knowhere/vec_impl.h" +#include "knowhere/common/exception.h" +#include "Exception.h" namespace zilliz { @@ -21,9 +25,13 @@ ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string &location, EngineType type) : location_(location), dim(dimension), build_type(type) { - index_ = CreatetVecIndex(EngineType::FAISS_IDMAP); current_type = EngineType::FAISS_IDMAP; - std::static_pointer_cast(index_)->Build(dimension); + + index_ = CreatetVecIndex(EngineType::FAISS_IDMAP); + if (!index_) throw Exception("Create Empty VecIndex"); + + auto ec = std::static_pointer_cast(index_)->Build(dimension); + if (ec != server::KNOWHERE_SUCCESS) { throw Exception("Build index error"); } } ExecutionEngineImpl::ExecutionEngineImpl(VecIndexPtr index, @@ -61,7 +69,10 @@ VecIndexPtr ExecutionEngineImpl::CreatetVecIndex(EngineType type) { } Status ExecutionEngineImpl::AddWithIds(long n, const float *xdata, const long *xids) { - index_->Add(n, xdata, xids, Config::object{{"dim", dim}}); + auto ec = index_->Add(n, xdata, xids, Config::object{{"dim", dim}}); + if (ec != server::KNOWHERE_SUCCESS) { + return Status::Error("Add error"); + } return Status::OK(); } @@ -82,7 +93,10 @@ size_t ExecutionEngineImpl::PhysicalSize() const { } Status ExecutionEngineImpl::Serialize() { - write_index(index_, location_); + auto ec = write_index(index_, location_); + if (ec != server::KNOWHERE_SUCCESS) { + return Status::Error("Serialize: write to disk error"); + } return Status::OK(); } @@ -91,9 +105,16 @@ Status ExecutionEngineImpl::Load() { bool to_cache = false; auto start_time = METRICS_NOW_TIME; if (!index_) { - index_ = read_index(location_); - to_cache = true; - ENGINE_LOG_DEBUG << "Disk io from: " << location_; + try { + index_ = read_index(location_); + to_cache = true; + ENGINE_LOG_DEBUG << "Disk io from: " << location_; + } catch (knowhere::KnowhereException &e) { + ENGINE_LOG_ERROR << e.what(); + return Status::Error(e.what()); + } catch (std::exception &e) { + return Status::Error(e.what()); + } } if (to_cache) { @@ -118,11 +139,22 @@ Status ExecutionEngineImpl::Merge(const std::string &location) { auto to_merge = zilliz::milvus::cache::CpuCacheMgr::GetInstance()->GetIndex(location); if (!to_merge) { - to_merge = read_index(location); + try { + to_merge = read_index(location); + } catch (knowhere::KnowhereException &e) { + ENGINE_LOG_ERROR << e.what(); + return Status::Error(e.what()); + } catch (std::exception &e) { + return Status::Error(e.what()); + } } if (auto file_index = std::dynamic_pointer_cast(to_merge)) { - index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds()); + auto ec = index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds()); + if (ec != server::KNOWHERE_SUCCESS) { + ENGINE_LOG_ERROR << "Merge: Add Error"; + return Status::Error("Merge: Add Error"); + } return Status::OK(); } else { return Status::Error("file index type is not idmap"); @@ -134,13 +166,16 @@ ExecutionEngineImpl::BuildIndex(const std::string &location) { ENGINE_LOG_DEBUG << "Build index file: " << location << " from: " << location_; auto from_index = std::dynamic_pointer_cast(index_); - ENGINE_LOG_DEBUG << "BuildIndex EngineTypee: " << int(build_type); auto to_index = CreatetVecIndex(build_type); - ENGINE_LOG_DEBUG << "Build Params: [gpu_id] " << gpu_num; - to_index->BuildAll(Count(), - from_index->GetRawVectors(), - from_index->GetRawIds(), - Config::object{{"dim", Dimension()}, {"gpu_id", gpu_num}}); + if (!to_index) { + throw Exception("Create Empty VecIndex"); + } + + auto ec = to_index->BuildAll(Count(), + from_index->GetRawVectors(), + from_index->GetRawIds(), + Config::object{{"dim", Dimension()}, {"gpu_id", gpu_num}}); + if (ec != server::KNOWHERE_SUCCESS) { throw Exception("Build index error"); } return std::make_shared(to_index, location, build_type); } @@ -151,7 +186,11 @@ Status ExecutionEngineImpl::Search(long n, float *distances, long *labels) const { ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe_; - index_->Search(n, data, distances, labels, Config::object{{"k", k}, {"nprobe", nprobe_}}); + auto ec = index_->Search(n, data, distances, labels, Config::object{{"k", k}, {"nprobe", nprobe_}}); + if (ec != server::KNOWHERE_SUCCESS) { + ENGINE_LOG_ERROR << "Search error"; + return Status::Error("Search: Search Error"); + } return Status::OK(); } diff --git a/cpp/src/utils/Error.h b/cpp/src/utils/Error.h index 8c4da703..82b22d57 100644 --- a/cpp/src/utils/Error.h +++ b/cpp/src/utils/Error.h @@ -54,6 +54,12 @@ constexpr ServerError SERVER_LICENSE_VALIDATION_FAIL = ToGlobalServerErrorCode(5 constexpr ServerError DB_META_TRANSACTION_FAILED = ToGlobalServerErrorCode(1000); +using KnowhereError = int32_t; +constexpr KnowhereError KNOWHERE_SUCCESS = 0; +constexpr KnowhereError KNOWHERE_ERROR = ToGlobalServerErrorCode(1); +constexpr KnowhereError KNOWHERE_INVALID_ARGUMENT = ToGlobalServerErrorCode(2); +constexpr KnowhereError KNOWHERE_UNEXPECTED_ERROR = ToGlobalServerErrorCode(3); + class ServerException : public std::exception { public: ServerException(ServerError error_code, diff --git a/cpp/src/wrapper/knowhere/vec_impl.cpp b/cpp/src/wrapper/knowhere/vec_impl.cpp index 4ca48bfe..f0bcd30f 100644 --- a/cpp/src/wrapper/knowhere/vec_impl.cpp +++ b/cpp/src/wrapper/knowhere/vec_impl.cpp @@ -7,6 +7,7 @@ #include #include "knowhere/index/vector_index/idmap.h" #include "knowhere/index/vector_index/gpu_ivf.h" +#include "knowhere/common/exception.h" #include "vec_impl.h" #include "data_transfer.h" @@ -19,77 +20,110 @@ namespace engine { using namespace zilliz::knowhere; -void VecIndexImpl::BuildAll(const long &nb, - const float *xb, - const long *ids, - const Config &cfg, - const long &nt, - const float *xt) { - dim = cfg["dim"].as(); - auto dataset = GenDatasetWithIds(nb, dim, xb, ids); - - auto preprocessor = index_->BuildPreprocessor(dataset, cfg); - index_->set_preprocessor(preprocessor); - auto nlist = int(nb / 1000000.0 * 16384); - auto cfg_t = Config::object{{"nlist", nlist}, {"dim", dim}}; - auto model = index_->Train(dataset, cfg_t); - index_->set_index_model(model); - index_->Add(dataset, cfg); -} - -void VecIndexImpl::Add(const long &nb, const float *xb, const long *ids, const Config &cfg) { - // TODO(linxj): Assert index is trained; - - auto d = cfg.get_with_default("dim", dim); - auto dataset = GenDatasetWithIds(nb, d, xb, ids); - - index_->Add(dataset, cfg); -} - -void VecIndexImpl::Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) { - // TODO: Assert index is trained; - - auto k = cfg["k"].as(); - auto d = cfg.get_with_default("dim", dim); - auto dataset = GenDataset(nq, d, xq); - - Config search_cfg; - auto res = index_->Search(dataset, cfg); - auto ids_array = res->array()[0]; - auto dis_array = res->array()[1]; - - //{ - // auto& ids = ids_array; - // auto& dists = dis_array; - // std::stringstream ss_id; - // std::stringstream ss_dist; - // for (auto i = 0; i < 10; i++) { - // for (auto j = 0; j < k; ++j) { - // ss_id << *(ids->data()->GetValues(1, i * k + j)) << " "; - // ss_dist << *(dists->data()->GetValues(1, i * k + j)) << " "; - // } - // ss_id << std::endl; - // ss_dist << std::endl; - // } - // std::cout << "id\n" << ss_id.str() << std::endl; - // std::cout << "dist\n" << ss_dist.str() << std::endl; - //} - - auto p_ids = ids_array->data()->GetValues(1, 0); - auto p_dist = dis_array->data()->GetValues(1, 0); - - // TODO(linxj): avoid copy here. - memcpy(ids, p_ids, sizeof(int64_t) * nq * k); - memcpy(dist, p_dist, sizeof(float) * nq * k); +server::KnowhereError VecIndexImpl::BuildAll(const long &nb, + const float *xb, + const long *ids, + const Config &cfg, + const long &nt, + const float *xt) { + try { + dim = cfg["dim"].as(); + auto dataset = GenDatasetWithIds(nb, dim, xb, ids); + + auto preprocessor = index_->BuildPreprocessor(dataset, cfg); + index_->set_preprocessor(preprocessor); + auto nlist = int(nb / 1000000.0 * 16384); + auto cfg_t = Config::object{{"nlist", nlist}, {"dim", dim}}; + auto model = index_->Train(dataset, cfg_t); + index_->set_index_model(model); + index_->Add(dataset, cfg); + } catch (KnowhereException &e) { + WRAPPER_LOG_ERROR << e.what(); + return server::KNOWHERE_UNEXPECTED_ERROR; + } catch (jsoncons::json_exception &e) { + WRAPPER_LOG_ERROR << e.what(); + return server::KNOWHERE_INVALID_ARGUMENT; + } catch (std::exception &e) { + WRAPPER_LOG_ERROR << e.what(); + return server::KNOWHERE_ERROR; + } + return server::KNOWHERE_SUCCESS; +} + +server::KnowhereError VecIndexImpl::Add(const long &nb, const float *xb, const long *ids, const Config &cfg) { + try { + auto d = cfg.get_with_default("dim", dim); + auto dataset = GenDatasetWithIds(nb, d, xb, ids); + + index_->Add(dataset, cfg); + } catch (KnowhereException &e) { + WRAPPER_LOG_ERROR << e.what(); + return server::KNOWHERE_UNEXPECTED_ERROR; + } catch (jsoncons::json_exception &e) { + WRAPPER_LOG_ERROR << e.what(); + return server::KNOWHERE_INVALID_ARGUMENT; + } catch (std::exception &e) { + WRAPPER_LOG_ERROR << e.what(); + return server::KNOWHERE_ERROR; + } + return server::KNOWHERE_SUCCESS; +} + +server::KnowhereError VecIndexImpl::Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) { + try { + auto k = cfg["k"].as(); + auto d = cfg.get_with_default("dim", dim); + auto dataset = GenDataset(nq, d, xq); + + Config search_cfg; + auto res = index_->Search(dataset, cfg); + auto ids_array = res->array()[0]; + auto dis_array = res->array()[1]; + + //{ + // auto& ids = ids_array; + // auto& dists = dis_array; + // std::stringstream ss_id; + // std::stringstream ss_dist; + // for (auto i = 0; i < 10; i++) { + // for (auto j = 0; j < k; ++j) { + // ss_id << *(ids->data()->GetValues(1, i * k + j)) << " "; + // ss_dist << *(dists->data()->GetValues(1, i * k + j)) << " "; + // } + // ss_id << std::endl; + // ss_dist << std::endl; + // } + // std::cout << "id\n" << ss_id.str() << std::endl; + // std::cout << "dist\n" << ss_dist.str() << std::endl; + //} + + auto p_ids = ids_array->data()->GetValues(1, 0); + auto p_dist = dis_array->data()->GetValues(1, 0); + + // TODO(linxj): avoid copy here. + memcpy(ids, p_ids, sizeof(int64_t) * nq * k); + memcpy(dist, p_dist, sizeof(float) * nq * k); + } catch (KnowhereException &e) { + WRAPPER_LOG_ERROR << e.what(); + return server::KNOWHERE_UNEXPECTED_ERROR; + } catch (jsoncons::json_exception &e) { + WRAPPER_LOG_ERROR << e.what(); + return server::KNOWHERE_INVALID_ARGUMENT; + } catch (std::exception &e) { + WRAPPER_LOG_ERROR << e.what(); + return server::KNOWHERE_ERROR; + } + return server::KNOWHERE_SUCCESS; } zilliz::knowhere::BinarySet VecIndexImpl::Serialize() { return index_->Serialize(); } -void VecIndexImpl::Load(const zilliz::knowhere::BinarySet &index_binary) { +server::KnowhereError VecIndexImpl::Load(const zilliz::knowhere::BinarySet &index_binary) { index_->Load(index_binary); dim = Dimension(); + return server::KNOWHERE_SUCCESS; } int64_t VecIndexImpl::Dimension() { @@ -114,56 +148,91 @@ int64_t *BFIndex::GetRawIds() { return std::static_pointer_cast(index_)->GetRawIds(); } -void BFIndex::Build(const int64_t &d) { - dim = d; - std::static_pointer_cast(index_)->Train(dim); -} - -void BFIndex::BuildAll(const long &nb, - const float *xb, - const long *ids, - const Config &cfg, - const long &nt, - const float *xt) { - dim = cfg["dim"].as(); - auto dataset = GenDatasetWithIds(nb, dim, xb, ids); - - std::static_pointer_cast(index_)->Train(dim); - index_->Add(dataset, cfg); +server::KnowhereError BFIndex::Build(const int64_t &d) { + try { + dim = d; + std::static_pointer_cast(index_)->Train(dim); + } catch (KnowhereException &e) { + WRAPPER_LOG_ERROR << e.what(); + return server::KNOWHERE_UNEXPECTED_ERROR; + } catch (jsoncons::json_exception &e) { + WRAPPER_LOG_ERROR << e.what(); + return server::KNOWHERE_INVALID_ARGUMENT; + } catch (std::exception &e) { + WRAPPER_LOG_ERROR << e.what(); + return server::KNOWHERE_ERROR; + } + return server::KNOWHERE_SUCCESS; +} + +server::KnowhereError BFIndex::BuildAll(const long &nb, + const float *xb, + const long *ids, + const Config &cfg, + const long &nt, + const float *xt) { + try { + dim = cfg["dim"].as(); + auto dataset = GenDatasetWithIds(nb, dim, xb, ids); + + std::static_pointer_cast(index_)->Train(dim); + index_->Add(dataset, cfg); + } catch (KnowhereException &e) { + WRAPPER_LOG_ERROR << e.what(); + return server::KNOWHERE_UNEXPECTED_ERROR; + } catch (jsoncons::json_exception &e) { + WRAPPER_LOG_ERROR << e.what(); + return server::KNOWHERE_INVALID_ARGUMENT; + } catch (std::exception &e) { + WRAPPER_LOG_ERROR << e.what(); + return server::KNOWHERE_ERROR; + } + return server::KNOWHERE_SUCCESS; } // TODO(linxj): add lock here. -void IVFMixIndex::BuildAll(const long &nb, - const float *xb, - const long *ids, - const Config &cfg, - const long &nt, - const float *xt) { - WRAPPER_LOG_DEBUG << "Get Into Build IVFMIX"; - - dim = cfg["dim"].as(); - auto dataset = GenDatasetWithIds(nb, dim, xb, ids); - - auto preprocessor = index_->BuildPreprocessor(dataset, cfg); - index_->set_preprocessor(preprocessor); - auto nlist = int(nb / 1000000.0 * 16384); - auto cfg_t = Config::object{{"nlist", nlist}, {"dim", dim}}; - auto model = index_->Train(dataset, cfg_t); - index_->set_index_model(model); - index_->Add(dataset, cfg); - - if (auto device_index = std::dynamic_pointer_cast(index_)) { - auto host_index = device_index->Copy_index_gpu_to_cpu(); - index_ = host_index; - } else { - WRAPPER_LOG_ERROR << "Build IVFMIXIndex Failed"; +server::KnowhereError IVFMixIndex::BuildAll(const long &nb, + const float *xb, + const long *ids, + const Config &cfg, + const long &nt, + const float *xt) { + try { + dim = cfg["dim"].as(); + auto dataset = GenDatasetWithIds(nb, dim, xb, ids); + + auto preprocessor = index_->BuildPreprocessor(dataset, cfg); + index_->set_preprocessor(preprocessor); + auto nlist = int(nb / 1000000.0 * 16384); + auto cfg_t = Config::object{{"nlist", nlist}, {"dim", dim}}; + auto model = index_->Train(dataset, cfg_t); + index_->set_index_model(model); + index_->Add(dataset, cfg); + + if (auto device_index = std::dynamic_pointer_cast(index_)) { + auto host_index = device_index->Copy_index_gpu_to_cpu(); + index_ = host_index; + } else { + WRAPPER_LOG_ERROR << "Build IVFMIXIndex Failed"; + } + } catch (KnowhereException &e) { + WRAPPER_LOG_ERROR << e.what(); + return server::KNOWHERE_UNEXPECTED_ERROR; + } catch (jsoncons::json_exception &e) { + WRAPPER_LOG_ERROR << e.what(); + return server::KNOWHERE_INVALID_ARGUMENT; + } catch (std::exception &e) { + WRAPPER_LOG_ERROR << e.what(); + return server::KNOWHERE_ERROR; } + return server::KNOWHERE_SUCCESS; } -void IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) { +server::KnowhereError IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) { index_ = std::make_shared(); index_->Load(index_binary); dim = Dimension(); + return server::KNOWHERE_SUCCESS; } } diff --git a/cpp/src/wrapper/knowhere/vec_impl.h b/cpp/src/wrapper/knowhere/vec_impl.h index 1d09a069..3d432ff0 100644 --- a/cpp/src/wrapper/knowhere/vec_impl.h +++ b/cpp/src/wrapper/knowhere/vec_impl.h @@ -19,19 +19,19 @@ class VecIndexImpl : public VecIndex { public: explicit VecIndexImpl(std::shared_ptr index, const IndexType &type) : index_(std::move(index)), type(type) {}; - void BuildAll(const long &nb, - const float *xb, - const long *ids, - const Config &cfg, - const long &nt, - const float *xt) override; + server::KnowhereError BuildAll(const long &nb, + const float *xb, + const long *ids, + const Config &cfg, + const long &nt, + const float *xt) override; IndexType GetType() override; int64_t Dimension() override; int64_t Count() override; - void Add(const long &nb, const float *xb, const long *ids, const Config &cfg) override; + server::KnowhereError Add(const long &nb, const float *xb, const long *ids, const Config &cfg) override; zilliz::knowhere::BinarySet Serialize() override; - void Load(const zilliz::knowhere::BinarySet &index_binary) override; - void Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) override; + server::KnowhereError Load(const zilliz::knowhere::BinarySet &index_binary) override; + server::KnowhereError Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) override; protected: int64_t dim = 0; @@ -43,27 +43,27 @@ class IVFMixIndex : public VecIndexImpl { public: explicit IVFMixIndex(std::shared_ptr index) : VecIndexImpl(std::move(index), IndexType::FAISS_IVFFLAT_MIX) {}; - void BuildAll(const long &nb, - const float *xb, - const long *ids, - const Config &cfg, - const long &nt, - const float *xt) override; - void Load(const zilliz::knowhere::BinarySet &index_binary) override; + server::KnowhereError BuildAll(const long &nb, + const float *xb, + const long *ids, + const Config &cfg, + const long &nt, + const float *xt) override; + server::KnowhereError Load(const zilliz::knowhere::BinarySet &index_binary) override; }; class BFIndex : public VecIndexImpl { public: explicit BFIndex(std::shared_ptr index) : VecIndexImpl(std::move(index), IndexType::FAISS_IDMAP) {}; - void Build(const int64_t &d); + server::KnowhereError Build(const int64_t &d); float *GetRawVectors(); - void BuildAll(const long &nb, - const float *xb, - const long *ids, - const Config &cfg, - const long &nt, - const float *xt) override; + server::KnowhereError BuildAll(const long &nb, + const float *xb, + const long *ids, + const Config &cfg, + const long &nt, + const float *xt) override; int64_t *GetRawIds(); }; diff --git a/cpp/src/wrapper/knowhere/vec_index.cpp b/cpp/src/wrapper/knowhere/vec_index.cpp index 55e1ea4c..342f10a6 100644 --- a/cpp/src/wrapper/knowhere/vec_index.cpp +++ b/cpp/src/wrapper/knowhere/vec_index.cpp @@ -7,9 +7,11 @@ #include "knowhere/index/vector_index/idmap.h" #include "knowhere/index/vector_index/gpu_ivf.h" #include "knowhere/index/vector_index/cpu_kdt_rng.h" +#include "knowhere/common/exception.h" #include "vec_index.h" #include "vec_impl.h" +#include "wrapper_log.h" namespace zilliz { @@ -153,23 +155,32 @@ VecIndexPtr read_index(const std::string &location) { return LoadVecIndex(current_type, load_data_list); } -void write_index(VecIndexPtr index, const std::string &location) { - auto binaryset = index->Serialize(); - auto index_type = index->GetType(); - - FileIOWriter writer(location); - writer(&index_type, sizeof(IndexType)); - for (auto &iter: binaryset.binary_map_) { - auto meta = iter.first.c_str(); - size_t meta_length = iter.first.length(); - writer(&meta_length, sizeof(meta_length)); - writer((void *) meta, meta_length); - - auto binary = iter.second; - int64_t binary_length = binary->size; - writer(&binary_length, sizeof(binary_length)); - writer((void *) binary->data.get(), binary_length); +server::KnowhereError write_index(VecIndexPtr index, const std::string &location) { + try { + auto binaryset = index->Serialize(); + auto index_type = index->GetType(); + + FileIOWriter writer(location); + writer(&index_type, sizeof(IndexType)); + for (auto &iter: binaryset.binary_map_) { + auto meta = iter.first.c_str(); + size_t meta_length = iter.first.length(); + writer(&meta_length, sizeof(meta_length)); + writer((void *) meta, meta_length); + + auto binary = iter.second; + int64_t binary_length = binary->size; + writer(&binary_length, sizeof(binary_length)); + writer((void *) binary->data.get(), binary_length); + } + } catch (knowhere::KnowhereException &e) { + WRAPPER_LOG_ERROR << e.what(); + return server::KNOWHERE_UNEXPECTED_ERROR; + } catch (std::exception& e) { + WRAPPER_LOG_ERROR << e.what(); + return server::KNOWHERE_ERROR; } + return server::KNOWHERE_SUCCESS; } } diff --git a/cpp/src/wrapper/knowhere/vec_index.h b/cpp/src/wrapper/knowhere/vec_index.h index a488922d..c3f55286 100644 --- a/cpp/src/wrapper/knowhere/vec_index.h +++ b/cpp/src/wrapper/knowhere/vec_index.h @@ -9,6 +9,8 @@ #include #include +#include "utils/Error.h" + #include "knowhere/common/config.h" #include "knowhere/common/binary_set.h" @@ -34,23 +36,23 @@ enum class IndexType { class VecIndex { public: - virtual void BuildAll(const long &nb, - const float *xb, - const long *ids, - const Config &cfg, - const long &nt = 0, - const float *xt = nullptr) = 0; - - virtual void Add(const long &nb, - const float *xb, - const long *ids, - const Config &cfg = Config()) = 0; - - virtual void Search(const long &nq, - const float *xq, - float *dist, - long *ids, - const Config &cfg = Config()) = 0; + virtual server::KnowhereError BuildAll(const long &nb, + const float *xb, + const long *ids, + const Config &cfg, + const long &nt = 0, + const float *xt = nullptr) = 0; + + virtual server::KnowhereError Add(const long &nb, + const float *xb, + const long *ids, + const Config &cfg = Config()) = 0; + + virtual server::KnowhereError Search(const long &nq, + const float *xq, + float *dist, + long *ids, + const Config &cfg = Config()) = 0; virtual IndexType GetType() = 0; @@ -60,12 +62,12 @@ class VecIndex { virtual zilliz::knowhere::BinarySet Serialize() = 0; - virtual void Load(const zilliz::knowhere::BinarySet &index_binary) = 0; + virtual server::KnowhereError Load(const zilliz::knowhere::BinarySet &index_binary) = 0; }; using VecIndexPtr = std::shared_ptr; -extern void write_index(VecIndexPtr index, const std::string &location); +extern server::KnowhereError write_index(VecIndexPtr index, const std::string &location); extern VecIndexPtr read_index(const std::string &location); diff --git a/cpp/thirdparty/knowhere b/cpp/thirdparty/knowhere index afaf6528..1a4dc447 160000 --- a/cpp/thirdparty/knowhere +++ b/cpp/thirdparty/knowhere @@ -1 +1 @@ -Subproject commit afaf65282737514e232bf477aacb2772a4d32d5d +Subproject commit 1a4dc447797d281c3c83255c1b8a7709fc8d7738 diff --git a/cpp/unittest/index_wrapper/knowhere_test.cpp b/cpp/unittest/index_wrapper/knowhere_test.cpp index 928d61d0..83a4d440 100644 --- a/cpp/unittest/index_wrapper/knowhere_test.cpp +++ b/cpp/unittest/index_wrapper/knowhere_test.cpp @@ -163,3 +163,7 @@ TEST_P(KnowhereWrapperTest, serialize) { } } +// TODO(linxj): add exception test +//TEST_P(KnowhereWrapperTest, exception_test) { +//} + -- GitLab