提交 d38af191 编写于 作者: X xj.lin

MS-351 enable index move between cpu<=>gpu


Former-commit-id: 32ded76c691a23e4926626f35b1dda9ec77dbc09
上级 d3f5cf3c
......@@ -134,6 +134,24 @@ IndexType VecIndexImpl::GetType() {
return type;
}
VecIndexPtr VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) {
//if (auto new_type = GetGpuIndexType(type)) {
// auto device_index = index_->CopyToGpu(device_id);
// return std::make_shared<VecIndexImpl>(device_index, new_type);
//}
//return nullptr;
// TODO(linxj): update type
auto gpu_index = zilliz::knowhere::CopyCpuToGpu(index_, device_id, cfg);
return std::make_shared<VecIndexImpl>(gpu_index, type);
}
// TODO(linxj): rename copytocpu => copygputocpu
VecIndexPtr VecIndexImpl::CopyToCpu(const Config &cfg) {
auto cpu_index = zilliz::knowhere::CopyGpuToCpu(index_, cfg);
return std::make_shared<VecIndexImpl>(cpu_index, type);
}
float *BFIndex::GetRawVectors() {
auto raw_index = std::dynamic_pointer_cast<IDMAP>(index_);
if (raw_index) { return raw_index->GetRawVectors(); }
......
......@@ -25,6 +25,8 @@ class VecIndexImpl : public VecIndex {
const Config &cfg,
const long &nt,
const float *xt) override;
VecIndexPtr CopyToGpu(const int64_t &device_id, const Config &cfg) override;
VecIndexPtr CopyToCpu(const Config &cfg) override;
IndexType GetType() override;
int64_t Dimension() override;
int64_t Count() override;
......
......@@ -35,6 +35,9 @@ enum class IndexType {
NSG_MIX,
};
class VecIndex;
using VecIndexPtr = std::shared_ptr<VecIndex>;
class VecIndex {
public:
virtual server::KnowhereError BuildAll(const long &nb,
......@@ -55,6 +58,11 @@ class VecIndex {
long *ids,
const Config &cfg = Config()) = 0;
virtual VecIndexPtr CopyToGpu(const int64_t& device_id,
const Config &cfg = Config()) = 0;
virtual VecIndexPtr CopyToCpu(const Config &cfg = Config()) = 0;
virtual IndexType GetType() = 0;
virtual int64_t Dimension() = 0;
......@@ -66,8 +74,6 @@ class VecIndex {
virtual server::KnowhereError Load(const zilliz::knowhere::BinarySet &index_binary) = 0;
};
using VecIndexPtr = std::shared_ptr<VecIndex>;
extern server::KnowhereError write_index(VecIndexPtr index, const std::string &location);
extern VecIndexPtr read_index(const std::string &location);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册