未验证 提交 1078e064 编写于 作者: Z zmxdream 提交者: GitHub

[pglbox2.0]fix load into memory (#49389)

* fix load into memory

* fix load into memory

* fix code style
上级 df3f74df
......@@ -144,7 +144,7 @@ void PSGPUWrapper::add_key_to_local(const std::vector<uint64_t>& vec_data) {
iter++) {
uint64_t cur_key = *iter;
int shard_id = cur_key % thread_keys_shard_num_;
// TODO: feasign <-> slot <-> multi_dim
// TODO(lxsbupt): feasign <-> slot <-> multi_dim
this->thread_dim_keys_[i][shard_id][0].insert(cur_key);
}
};
......@@ -1304,6 +1304,7 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) {
}
InitSlotInfo();
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) {
std::shared_ptr<HeterContext> gpu_task = gpu_task_pool_.Get();
gpu_task->Reset();
......@@ -1312,7 +1313,12 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) {
} else if (hbm_sparse_table_initialized_ == false) {
SparseTableToHbm();
}
#else
std::shared_ptr<HeterContext> gpu_task = gpu_task_pool_.Get();
gpu_task->Reset();
gpu_task->pass_id_ = (uint16_t)(dataset_->GetPassID());
data_ready_channel_->Put(gpu_task);
#endif
VLOG(3) << "End LoadIntoMemory(), dataset[" << dataset_ << "]";
}
......@@ -1544,7 +1550,7 @@ void PSGPUWrapper::HbmToSparseTable() {
float* gpu_val =
reinterpret_cast<float*>(test_build_values + local_offset);
#ifdef PADDLE_WITH_PSLIB
// TODO: PSLIB DumpFill
// TODO(lxsbupt): PSLIB DumpFill
#endif
#ifdef PADDLE_WITH_PSCORE
accessor_wrapper_ptr->DumpFill(gpu_val, cpu_table_accessor_, mf_dim);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册