提交 1141bbbb 编写于 作者: X xiaojun.lin

update v2


Former-commit-id: 1240499e9e5f0042a2296300b00588ed11dc07c3
上级 bbdc611c
......@@ -696,6 +696,54 @@ TEST_F(GPURESTEST, copyandsearch) {
std::thread search_thread(search_func);
std::thread load_thread(load_func);
search_thread.join();
load_thread.join();
tc.RecordSection("Copy&search total");
}
TEST_F(GPURESTEST, TrainAndSearch) {
index_type = "GPUIVFSQ";
index_ = IndexFactory(index_type);
auto conf = std::make_shared<knowhere::IVFSQCfg>();
conf->nlist = 1638;
conf->d = dim;
conf->gpu_id = device_id;
conf->metric_type = knowhere::METRICTYPE::L2;
conf->k = k;
conf->nbits = 8;
conf->nprobe = 1;
auto preprocessor = index_->BuildPreprocessor(base_dataset, conf);
index_->set_preprocessor(preprocessor);
auto model = index_->Train(base_dataset, conf);
auto new_index = IndexFactory(index_type);
new_index->set_index_model(model);
new_index->Add(base_dataset, conf);
auto cpu_idx = knowhere::cloner::CopyGpuToCpu(new_index, knowhere::Config());
cpu_idx->Seal();
auto search_idx = knowhere::cloner::CopyCpuToGpu(cpu_idx, device_id, knowhere::Config());
constexpr int train_count = 1;
constexpr int search_count = 5000;
auto train_stage = [&] {
for (int i = 0; i < train_count; ++i) {
auto model = index_->Train(base_dataset, conf);
auto test_idx = IndexFactory(index_type);
test_idx->set_index_model(model);
test_idx->Add(base_dataset, conf);
}
};
auto search_stage = [&](knowhere::VectorIndexPtr& search_idx) {
for (int i = 0; i < search_count; ++i) {
auto result = search_idx->Search(query_dataset, conf);
AssertAnns(result, nq, k);
}
};
// TimeRecorder tc("record");
// train_stage();
// tc.RecordSection("train cost");
// search_stage(search_idx);
// tc.RecordSection("search cost");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册