diff --git a/cpp/src/wrapper/knowhere/vec_impl.cpp b/cpp/src/wrapper/knowhere/vec_impl.cpp index 7efbd54f0f0446d6c5e9c754fe06896953ce3700..d1cc1ae4ffa917cab9740db458f1dccd5a882ec0 100644 --- a/cpp/src/wrapper/knowhere/vec_impl.cpp +++ b/cpp/src/wrapper/knowhere/vec_impl.cpp @@ -134,6 +134,24 @@ IndexType VecIndexImpl::GetType() { return type; } +VecIndexPtr VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) { + //if (auto new_type = GetGpuIndexType(type)) { + // auto device_index = index_->CopyToGpu(device_id); + // return std::make_shared(device_index, new_type); + //} + //return nullptr; + + // TODO(linxj): update type + auto gpu_index = zilliz::knowhere::CopyCpuToGpu(index_, device_id, cfg); + return std::make_shared(gpu_index, type); +} + +// TODO(linxj): rename copytocpu => copygputocpu +VecIndexPtr VecIndexImpl::CopyToCpu(const Config &cfg) { + auto cpu_index = zilliz::knowhere::CopyGpuToCpu(index_, cfg); + return std::make_shared(cpu_index, type); +} + float *BFIndex::GetRawVectors() { auto raw_index = std::dynamic_pointer_cast(index_); if (raw_index) { return raw_index->GetRawVectors(); } diff --git a/cpp/src/wrapper/knowhere/vec_impl.h b/cpp/src/wrapper/knowhere/vec_impl.h index c4a0e2ac6112fcbb18f08f47d0d33c1efb5ec681..5e46c16f70f2c08f5e365d7bf3eaa7a2a5d37a3e 100644 --- a/cpp/src/wrapper/knowhere/vec_impl.h +++ b/cpp/src/wrapper/knowhere/vec_impl.h @@ -25,6 +25,8 @@ class VecIndexImpl : public VecIndex { const Config &cfg, const long &nt, const float *xt) override; + VecIndexPtr CopyToGpu(const int64_t &device_id, const Config &cfg) override; + VecIndexPtr CopyToCpu(const Config &cfg) override; IndexType GetType() override; int64_t Dimension() override; int64_t Count() override; diff --git a/cpp/src/wrapper/knowhere/vec_index.h b/cpp/src/wrapper/knowhere/vec_index.h index 80c8771ddab2a7cfcb89e3b1a87be8cdf4ce58c0..088228386cac4646ff01b98e7e5b7d3d601194a4 100644 --- a/cpp/src/wrapper/knowhere/vec_index.h +++ b/cpp/src/wrapper/knowhere/vec_index.h @@ -35,6 +35,9 @@ enum class IndexType { NSG_MIX, }; +class VecIndex; +using VecIndexPtr = std::shared_ptr; + class VecIndex { public: virtual server::KnowhereError BuildAll(const long &nb, @@ -55,6 +58,11 @@ class VecIndex { long *ids, const Config &cfg = Config()) = 0; + virtual VecIndexPtr CopyToGpu(const int64_t& device_id, + const Config &cfg = Config()) = 0; + + virtual VecIndexPtr CopyToCpu(const Config &cfg = Config()) = 0; + virtual IndexType GetType() = 0; virtual int64_t Dimension() = 0; @@ -66,8 +74,6 @@ class VecIndex { virtual server::KnowhereError Load(const zilliz::knowhere::BinarySet &index_binary) = 0; }; -using VecIndexPtr = std::shared_ptr; - extern server::KnowhereError write_index(VecIndexPtr index, const std::string &location); extern VecIndexPtr read_index(const std::string &location);