提交 36359554 编写于 作者: X xj.lin

MS-442 enable knowhere wrapper test


Former-commit-id: 6912138096ccc0966999edd6b9b199d7f2092202
上级 01f2db33
......@@ -29,7 +29,6 @@ VectorIndexPtr CopyCpuToGpu(const VectorIndexPtr &index, const int64_t &device_i
if (auto cpu_index = std::dynamic_pointer_cast<IVFSQ>(index)) {
return cpu_index->CopyCpuToGpu(device_id, config);
//KNOWHERE_THROW_MSG("IVFSQ not support tranfer to gpu");
} else if (auto cpu_index = std::dynamic_pointer_cast<IVFPQ>(index)) {
KNOWHERE_THROW_MSG("IVFPQ not support tranfer to gpu");
} else if (auto cpu_index = std::dynamic_pointer_cast<IVF>(index)) {
......
......@@ -31,10 +31,11 @@ IndexModelPtr GPUIVF::Train(const DatasetPtr &dataset, const Config &config) {
GETTENSOR(dataset)
// TODO(linxj): use device_id
auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_device);
ResScope rs(gpu_device, res);
faiss::gpu::GpuIndexIVFFlat device_index(res.get(), dim, nlist, metric_type);
faiss::gpu::GpuIndexIVFFlatConfig idx_config;
idx_config.device = gpu_device;
faiss::gpu::GpuIndexIVFFlat device_index(res.get(), dim, nlist, metric_type, idx_config);
device_index.train(rows, (float *) p_data);
std::shared_ptr<faiss::Index> host_index = nullptr;
......
......@@ -114,6 +114,7 @@ server::KnowhereError VecIndexImpl::Search(const long &nq, const float *xq, floa
}
zilliz::knowhere::BinarySet VecIndexImpl::Serialize() {
type = ConvertToCpuIndexType(type);
return index_->Serialize();
}
......@@ -136,26 +137,23 @@ IndexType VecIndexImpl::GetType() {
}
VecIndexPtr VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) {
//if (auto new_type = GetGpuIndexType(type)) {
// auto device_index = index_->CopyToGpu(device_id);
// return std::make_shared<VecIndexImpl>(device_index, new_type);
//}
//return nullptr;
// TODO(linxj): update type
// TODO(linxj): exception handle
auto gpu_index = zilliz::knowhere::CopyCpuToGpu(index_, device_id, cfg);
auto new_index = std::make_shared<VecIndexImpl>(gpu_index, type);
auto new_index = std::make_shared<VecIndexImpl>(gpu_index, ConvertToGpuIndexType(type));
new_index->dim = dim;
return new_index;
}
// TODO(linxj): rename copytocpu => copygputocpu
VecIndexPtr VecIndexImpl::CopyToCpu(const Config &cfg) {
// TODO(linxj): exception handle
auto cpu_index = zilliz::knowhere::CopyGpuToCpu(index_, cfg);
return std::make_shared<VecIndexImpl>(cpu_index, type);
auto new_index = std::make_shared<VecIndexImpl>(cpu_index, ConvertToCpuIndexType(type));
new_index->dim = dim;
return new_index;
}
VecIndexPtr VecIndexImpl::Clone() {
// TODO(linxj): exception handle
auto clone_index = std::make_shared<VecIndexImpl>(index_->Clone(), type);
clone_index->dim = dim;
return clone_index;
......@@ -165,10 +163,8 @@ int64_t VecIndexImpl::GetDeviceId() {
if (auto device_idx = std::dynamic_pointer_cast<GPUIndex>(index_)){
return device_idx->GetGpuDevice();
}
else {
return -1; // -1 == cpu
}
return 0;
// else
return -1; // -1 == cpu
}
float *BFIndex::GetRawVectors() {
......@@ -243,9 +239,10 @@ server::KnowhereError IVFMixIndex::BuildAll(const long &nb,
if (auto device_index = std::dynamic_pointer_cast<GPUIVF>(index_)) {
auto host_index = device_index->CopyGpuToCpu(Config());
index_ = host_index;
type = TransferToCpuIndexType(type);
type = ConvertToCpuIndexType(type);
} else {
WRAPPER_LOG_ERROR << "Build IVFMIXIndex Failed";
return server::KNOWHERE_ERROR;
}
} catch (KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what();
......@@ -261,7 +258,7 @@ server::KnowhereError IVFMixIndex::BuildAll(const long &nb,
}
server::KnowhereError IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) {
index_ = std::make_shared<IVF>();
//index_ = std::make_shared<IVF>();
index_->Load(index_binary);
dim = Dimension();
return server::KNOWHERE_SUCCESS;
......
......@@ -71,8 +71,9 @@ size_t FileIOWriter::operator()(void *ptr, size_t size) {
}
VecIndexPtr GetVecIndexFactory(const IndexType &type) {
VecIndexPtr GetVecIndexFactory(const IndexType &type, const Config& cfg) {
std::shared_ptr<zilliz::knowhere::VectorIndex> index;
auto gpu_device = cfg.get_with_default("gpu_id", 0);
switch (type) {
case IndexType::FAISS_IDMAP: {
index = std::make_shared<zilliz::knowhere::IDMAP>();
......@@ -83,7 +84,8 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type) {
break;
}
case IndexType::FAISS_IVFFLAT_GPU: {
index = std::make_shared<zilliz::knowhere::GPUIVF>(0);
// TODO(linxj): 规范化参数
index = std::make_shared<zilliz::knowhere::GPUIVF>(gpu_device);
break;
}
case IndexType::FAISS_IVFFLAT_MIX: {
......@@ -95,7 +97,7 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type) {
break;
}
case IndexType::FAISS_IVFPQ_GPU: {
index = std::make_shared<zilliz::knowhere::GPUIVFPQ>(0);
index = std::make_shared<zilliz::knowhere::GPUIVFPQ>(gpu_device);
break;
}
case IndexType::SPTAG_KDT_RNT_CPU: {
......@@ -103,15 +105,19 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type) {
break;
}
case IndexType::FAISS_IVFSQ8_MIX: {
index = std::make_shared<zilliz::knowhere::GPUIVFSQ>(0);
index = std::make_shared<zilliz::knowhere::GPUIVFSQ>(gpu_device);
return std::make_shared<IVFMixIndex>(index, IndexType::FAISS_IVFSQ8_MIX);
}
case IndexType::FAISS_IVFSQ8: {
case IndexType::FAISS_IVFSQ8_CPU: {
index = std::make_shared<zilliz::knowhere::IVFSQ>();
break;
}
case IndexType::FAISS_IVFSQ8_GPU: {
index = std::make_shared<zilliz::knowhere::GPUIVFSQ>(gpu_device);
break;
}
case IndexType::NSG_MIX: { // TODO(linxj): bug.
index = std::make_shared<zilliz::knowhere::NSG>(0);
index = std::make_shared<zilliz::knowhere::NSG>(gpu_device);
break;
}
default: {
......@@ -229,20 +235,40 @@ void AutoGenParams(const IndexType &type, const long &size, zilliz::knowhere::Co
}
}
IndexType TransferToCpuIndexType(const IndexType &type) {
IndexType ConvertToCpuIndexType(const IndexType &type) {
// TODO(linxj): add IDMAP
switch (type) {
case IndexType::FAISS_IVFFLAT_GPU:
case IndexType::FAISS_IVFFLAT_MIX: {
return IndexType::FAISS_IVFFLAT_CPU;
}
case IndexType::FAISS_IVFSQ8_GPU:
case IndexType::FAISS_IVFSQ8_MIX: {
return IndexType::FAISS_IVFSQ8;
return IndexType::FAISS_IVFSQ8_CPU;
}
default: {
return IndexType::INVALID;
return type;
}
}
}
IndexType ConvertToGpuIndexType(const IndexType &type) {
switch (type) {
case IndexType::FAISS_IVFFLAT_MIX:
case IndexType::FAISS_IVFFLAT_CPU: {
return IndexType::FAISS_IVFFLAT_GPU;
}
case IndexType::FAISS_IVFSQ8_MIX:
case IndexType::FAISS_IVFSQ8_CPU: {
return IndexType::FAISS_IVFSQ8_GPU;
}
default: {
return type;
}
}
}
}
}
}
......@@ -32,7 +32,8 @@ enum class IndexType {
FAISS_IVFPQ_GPU,
SPTAG_KDT_RNT_CPU,
FAISS_IVFSQ8_MIX,
FAISS_IVFSQ8,
FAISS_IVFSQ8_CPU,
FAISS_IVFSQ8_GPU,
NSG_MIX,
};
......@@ -83,13 +84,14 @@ extern server::KnowhereError write_index(VecIndexPtr index, const std::string &l
extern VecIndexPtr read_index(const std::string &location);
extern VecIndexPtr GetVecIndexFactory(const IndexType &type);
extern VecIndexPtr GetVecIndexFactory(const IndexType &type, const Config& cfg = Config());
extern VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary);
extern void AutoGenParams(const IndexType& type, const long& size, Config& cfg);
extern IndexType TransferToCpuIndexType(const IndexType& type);
extern IndexType ConvertToCpuIndexType(const IndexType& type);
extern IndexType ConvertToGpuIndexType(const IndexType& type);
}
}
......
......@@ -21,15 +21,15 @@ using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::Combine;
constexpr int64_t DIM = 512;
constexpr int64_t NB = 1000000;
constexpr int64_t DIM = 128;
constexpr int64_t NB = 100000;
constexpr int64_t DEVICE_ID = 0;
class KnowhereWrapperTest
: public TestWithParam<::std::tuple<IndexType, std::string, int, int, int, int, Config, Config>> {
protected:
void SetUp() override {
zilliz::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(0);
zilliz::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(1);
zilliz::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICE_ID);
std::string generator_type;
std::tie(index_type, generator_type, dim, nb, nq, k, train_cfg, search_cfg) = GetParam();
......@@ -90,29 +90,40 @@ class KnowhereWrapperTest
INSTANTIATE_TEST_CASE_P(WrapperParam, KnowhereWrapperTest,
Values(
//["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"]
//std::make_tuple(IndexType::FAISS_IVFFLAT_CPU, "Default",
// 64, 100000, 10, 10,
// Config::object{{"nlist", 100}, {"dim", 64}},
// Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 10}}
//),
//std::make_tuple(IndexType::FAISS_IVFFLAT_GPU, "Default",
// 64, 10000, 10, 10,
// Config::object{{"nlist", 100}, {"dim", 64}},
// Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 40}}
//),
// std::make_tuple(IndexType::FAISS_IVFFLAT_MIX, "Default",
// 64, 100000, 10, 10,
// Config::object{{"nlist", 1000}, {"dim", 64}, {"metric_type", "L2"}},
// Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 5}}
// ),
// std::make_tuple(IndexType::FAISS_IDMAP, "Default",
// 64, 100000, 10, 10,
// Config::object{{"dim", 64}, {"metric_type", "L2"}},
// Config::object{{"dim", 64}, {"k", 10}}
// ),
std::make_tuple(IndexType::FAISS_IVFFLAT_CPU, "Default",
64, 100000, 10, 10,
Config::object{{"nlist", 100}, {"dim", 64}, {"metric_type", "L2"}},
Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 10}}
),
// to_gpu_test Failed
std::make_tuple(IndexType::FAISS_IVFFLAT_GPU, "Default",
DIM, NB, 10, 10,
Config::object{{"nlist", 100}, {"dim", DIM}, {"metric_type", "L2"}, {"gpu_id", DEVICE_ID}},
Config::object{{"dim", DIM}, {"k", 10}, {"nprobe", 40}}
),
std::make_tuple(IndexType::FAISS_IVFFLAT_MIX, "Default",
64, 100000, 10, 10,
Config::object{{"nlist", 1000}, {"dim", 64}, {"metric_type", "L2"}},
Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 5}}
),
std::make_tuple(IndexType::FAISS_IDMAP, "Default",
64, 100000, 10, 10,
Config::object{{"dim", 64}, {"metric_type", "L2"}},
Config::object{{"dim", 64}, {"k", 10}}
),
std::make_tuple(IndexType::FAISS_IVFSQ8_CPU, "Default",
DIM, NB, 10, 10,
Config::object{{"dim", DIM}, {"nlist", 1000}, {"nbits", 8}, {"metric_type", "L2"}, {"gpu_id", DEVICE_ID}},
Config::object{{"dim", DIM}, {"k", 10}, {"nprobe", 5}}
),
std::make_tuple(IndexType::FAISS_IVFSQ8_GPU, "Default",
DIM, NB, 10, 10,
Config::object{{"dim", DIM}, {"nlist", 1000}, {"nbits", 8}, {"metric_type", "L2"}, {"gpu_id", DEVICE_ID}},
Config::object{{"dim", DIM}, {"k", 10}, {"nprobe", 5}}
),
std::make_tuple(IndexType::FAISS_IVFSQ8_MIX, "Default",
DIM, NB, 10, 10,
Config::object{{"dim", DIM}, {"nlist", 1000}, {"nbits", 8}, {"metric_type", "L2"}},
Config::object{{"dim", DIM}, {"nlist", 1000}, {"nbits", 8}, {"metric_type", "L2"}, {"gpu_id", DEVICE_ID}},
Config::object{{"dim", DIM}, {"k", 10}, {"nprobe", 5}}
)
// std::make_tuple(IndexType::NSG_MIX, "Default",
......@@ -151,19 +162,30 @@ TEST_P(KnowhereWrapperTest, to_gpu_test) {
index_->BuildAll(nb, xb.data(), ids.data(), train_cfg);
index_->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg);
AssertResult(res_ids, res_dis);
{
index_->CopyToGpu(1);
auto dev_idx = index_->CopyToGpu(DEVICE_ID);
for (int i = 0; i < 10; ++i) {
dev_idx->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg);
}
AssertResult(res_ids, res_dis);
}
std::string file_location = "/tmp/whatever";
write_index(index_, file_location);
auto new_index = read_index(file_location);
{
std::string file_location = "/tmp/test_gpu_file";
write_index(index_, file_location);
auto new_index = read_index(file_location);
auto dev_idx = new_index->CopyToGpu(1);
for (int i = 0; i < 10000; ++i) {
dev_idx->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg);
auto dev_idx = new_index->CopyToGpu(DEVICE_ID);
for (int i = 0; i < 10; ++i) {
dev_idx->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg);
}
AssertResult(res_ids, res_dis);
}
AssertResult(res_ids, res_dis);
}
TEST_P(KnowhereWrapperTest, to_cpu_test) {
// dev
}
TEST_P(KnowhereWrapperTest, serialize) {
......@@ -194,7 +216,7 @@ TEST_P(KnowhereWrapperTest, serialize) {
std::string file_location = "/tmp/whatever";
write_index(index_, file_location);
auto new_index = read_index(file_location);
EXPECT_EQ(new_index->GetType(), index_type);
EXPECT_EQ(new_index->GetType(), ConvertToCpuIndexType(index_type));
EXPECT_EQ(new_index->Dimension(), index_->Dimension());
EXPECT_EQ(new_index->Count(), index_->Count());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册