From d38af1912f86a20f97c5439292487c87bbfbbd83 Mon Sep 17 00:00:00 2001 From: "xj.lin" Date: Wed, 14 Aug 2019 19:48:48 +0800 Subject: [PATCH] MS-351 enable index move between cpu<=>gpu Former-commit-id: 32ded76c691a23e4926626f35b1dda9ec77dbc09 --- cpp/src/wrapper/knowhere/vec_impl.cpp | 18 ++++++++++++++++++ cpp/src/wrapper/knowhere/vec_impl.h | 2 ++ cpp/src/wrapper/knowhere/vec_index.h | 10 ++++++++-- 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/cpp/src/wrapper/knowhere/vec_impl.cpp b/cpp/src/wrapper/knowhere/vec_impl.cpp index 7efbd54f..d1cc1ae4 100644 --- a/cpp/src/wrapper/knowhere/vec_impl.cpp +++ b/cpp/src/wrapper/knowhere/vec_impl.cpp @@ -134,6 +134,24 @@ IndexType VecIndexImpl::GetType() { return type; } +VecIndexPtr VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) { + //if (auto new_type = GetGpuIndexType(type)) { + // auto device_index = index_->CopyToGpu(device_id); + // return std::make_shared(device_index, new_type); + //} + //return nullptr; + + // TODO(linxj): update type + auto gpu_index = zilliz::knowhere::CopyCpuToGpu(index_, device_id, cfg); + return std::make_shared(gpu_index, type); +} + +// TODO(linxj): rename copytocpu => copygputocpu +VecIndexPtr VecIndexImpl::CopyToCpu(const Config &cfg) { + auto cpu_index = zilliz::knowhere::CopyGpuToCpu(index_, cfg); + return std::make_shared(cpu_index, type); +} + float *BFIndex::GetRawVectors() { auto raw_index = std::dynamic_pointer_cast(index_); if (raw_index) { return raw_index->GetRawVectors(); } diff --git a/cpp/src/wrapper/knowhere/vec_impl.h b/cpp/src/wrapper/knowhere/vec_impl.h index c4a0e2ac..5e46c16f 100644 --- a/cpp/src/wrapper/knowhere/vec_impl.h +++ b/cpp/src/wrapper/knowhere/vec_impl.h @@ -25,6 +25,8 @@ class VecIndexImpl : public VecIndex { const Config &cfg, const long &nt, const float *xt) override; + VecIndexPtr CopyToGpu(const int64_t &device_id, const Config &cfg) override; + VecIndexPtr CopyToCpu(const Config &cfg) override; IndexType GetType() override; int64_t Dimension() override; int64_t Count() override; diff --git a/cpp/src/wrapper/knowhere/vec_index.h b/cpp/src/wrapper/knowhere/vec_index.h index 80c8771d..08822838 100644 --- a/cpp/src/wrapper/knowhere/vec_index.h +++ b/cpp/src/wrapper/knowhere/vec_index.h @@ -35,6 +35,9 @@ enum class IndexType { NSG_MIX, }; +class VecIndex; +using VecIndexPtr = std::shared_ptr; + class VecIndex { public: virtual server::KnowhereError BuildAll(const long &nb, @@ -55,6 +58,11 @@ class VecIndex { long *ids, const Config &cfg = Config()) = 0; + virtual VecIndexPtr CopyToGpu(const int64_t& device_id, + const Config &cfg = Config()) = 0; + + virtual VecIndexPtr CopyToCpu(const Config &cfg = Config()) = 0; + virtual IndexType GetType() = 0; virtual int64_t Dimension() = 0; @@ -66,8 +74,6 @@ class VecIndex { virtual server::KnowhereError Load(const zilliz::knowhere::BinarySet &index_binary) = 0; }; -using VecIndexPtr = std::shared_ptr; - extern server::KnowhereError write_index(VecIndexPtr index, const std::string &location); extern VecIndexPtr read_index(const std::string &location); -- GitLab