提交 1cd2c5c5 编写于 作者: J jinhai

Merge branch 'MS-351' into 'branch-0.4.0'

Support MS-351

Closes MS-351

See merge request megasearch/milvus!357

Former-commit-id: 43c7852d7e42215058f031c947c08ba4e21aa556
...@@ -134,6 +134,24 @@ IndexType VecIndexImpl::GetType() { ...@@ -134,6 +134,24 @@ IndexType VecIndexImpl::GetType() {
return type; return type;
} }
VecIndexPtr VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) {
//if (auto new_type = GetGpuIndexType(type)) {
// auto device_index = index_->CopyToGpu(device_id);
// return std::make_shared<VecIndexImpl>(device_index, new_type);
//}
//return nullptr;
// TODO(linxj): update type
auto gpu_index = zilliz::knowhere::CopyCpuToGpu(index_, device_id, cfg);
return std::make_shared<VecIndexImpl>(gpu_index, type);
}
// TODO(linxj): rename copytocpu => copygputocpu
VecIndexPtr VecIndexImpl::CopyToCpu(const Config &cfg) {
auto cpu_index = zilliz::knowhere::CopyGpuToCpu(index_, cfg);
return std::make_shared<VecIndexImpl>(cpu_index, type);
}
float *BFIndex::GetRawVectors() { float *BFIndex::GetRawVectors() {
auto raw_index = std::dynamic_pointer_cast<IDMAP>(index_); auto raw_index = std::dynamic_pointer_cast<IDMAP>(index_);
if (raw_index) { return raw_index->GetRawVectors(); } if (raw_index) { return raw_index->GetRawVectors(); }
......
...@@ -25,6 +25,8 @@ class VecIndexImpl : public VecIndex { ...@@ -25,6 +25,8 @@ class VecIndexImpl : public VecIndex {
const Config &cfg, const Config &cfg,
const long &nt, const long &nt,
const float *xt) override; const float *xt) override;
VecIndexPtr CopyToGpu(const int64_t &device_id, const Config &cfg) override;
VecIndexPtr CopyToCpu(const Config &cfg) override;
IndexType GetType() override; IndexType GetType() override;
int64_t Dimension() override; int64_t Dimension() override;
int64_t Count() override; int64_t Count() override;
......
...@@ -35,6 +35,9 @@ enum class IndexType { ...@@ -35,6 +35,9 @@ enum class IndexType {
NSG_MIX, NSG_MIX,
}; };
class VecIndex;
using VecIndexPtr = std::shared_ptr<VecIndex>;
class VecIndex { class VecIndex {
public: public:
virtual server::KnowhereError BuildAll(const long &nb, virtual server::KnowhereError BuildAll(const long &nb,
...@@ -55,6 +58,11 @@ class VecIndex { ...@@ -55,6 +58,11 @@ class VecIndex {
long *ids, long *ids,
const Config &cfg = Config()) = 0; const Config &cfg = Config()) = 0;
virtual VecIndexPtr CopyToGpu(const int64_t& device_id,
const Config &cfg = Config()) = 0;
virtual VecIndexPtr CopyToCpu(const Config &cfg = Config()) = 0;
virtual IndexType GetType() = 0; virtual IndexType GetType() = 0;
virtual int64_t Dimension() = 0; virtual int64_t Dimension() = 0;
...@@ -66,8 +74,6 @@ class VecIndex { ...@@ -66,8 +74,6 @@ class VecIndex {
virtual server::KnowhereError Load(const zilliz::knowhere::BinarySet &index_binary) = 0; virtual server::KnowhereError Load(const zilliz::knowhere::BinarySet &index_binary) = 0;
}; };
using VecIndexPtr = std::shared_ptr<VecIndex>;
extern server::KnowhereError write_index(VecIndexPtr index, const std::string &location); extern server::KnowhereError write_index(VecIndexPtr index, const std::string &location);
extern VecIndexPtr read_index(const std::string &location); extern VecIndexPtr read_index(const std::string &location);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册