From c4a52b83c7321f72dbcfd52985f5b8f7f3926b01 Mon Sep 17 00:00:00 2001
From: zmxdream
Date: Mon, 27 Jun 2022 14:49:58 +0800
Subject: [PATCH] [GPUPS]fix merge_grad&push_sparse (#43840)

---
 .../fleet/heter_ps/hashtable_kernel.cu        | 17 ++---------------
 .../framework/fleet/heter_ps/heter_comm_inl.h | 11 +++--------
 .../fleet/heter_ps/heter_comm_kernel.cu       |  6 +++---
 3 files changed, 8 insertions(+), 26 deletions(-)

diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
index 6bc4e08241a..04842caef6b 100644
--- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
@@ -112,20 +112,7 @@ __global__ void dy_mf_search_kernel(Table* table,
       }
     } else {
       if (keys[i] != 0) {
-        printf("warning::pull miss key: %d", keys[i]);
-      }
-      FeatureValue* cur = (FeatureValue*)(vals + i * pull_feature_value_size);
-      cur->delta_score = 0;
-      cur->show = 0;
-      cur->clk = 0;
-      cur->slot = -1;
-      cur->lr = 0;
-      cur->lr_g2sum = 0;
-      cur->mf_size = 0;
-      cur->mf_dim = 8;
-      cur->cpu_ptr;
-      for (int j = 0; j < cur->mf_dim + 1; j++) {
-        cur->mf[j] = 0;
+        printf("warning::pull miss key: %llu", keys[i]);
       }
     }
   }
@@ -163,7 +150,7 @@ __global__ void dy_mf_update_kernel(Table* table,
       sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, *cur);
     } else {
       if (keys[i] != 0) {
-        printf("warning::push miss key: %d", keys[i]);
+        printf("warning::push miss key: %llu", keys[i]);
       }
     }
   }
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
index 0ac2c3cda58..ace533cb0c7 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
@@ -1026,14 +1026,9 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
   auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType));
   KeyType* d_shard_keys_ptr = reinterpret_cast<KeyType*>(d_shard_keys->ptr());
 
-  GradType* d_shard_grads_ptr;
-  if (!multi_mf_dim_) {
-    auto d_shard_grads = memory::Alloc(place, len * sizeof(GradType));
-    d_shard_grads_ptr = reinterpret_cast<GradType*>(d_shard_grads->ptr());
-  } else {
-    auto d_shard_grads = memory::Alloc(place, len * grad_value_size);
-    d_shard_grads_ptr = reinterpret_cast<GradType*>(d_shard_grads->ptr());
-  }
+  auto d_shard_grads = memory::Alloc(place, len * grad_value_size);
+  GradType* d_shard_grads_ptr =
+      reinterpret_cast<GradType*>(d_shard_grads->ptr());
 
   int uniq_len = len;
   dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len);
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
index 3ad3c5fa151..fd0dd1a72cc 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
@@ -153,7 +153,6 @@ __global__ void merge_gradients_kernel(const uint32_t* offset,
                                        size_t grad_value_size,
                                        DynamicGradMerger& merger_) {
   const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
-
   if (i < n) {
     uint32_t start = offset[i];
     uint32_t num = fea_num[i];
@@ -164,8 +163,9 @@ __global__ void merge_gradients_kernel(const uint32_t* offset,
     merger_.update_one(out, in);
     for (int j = 1; j < num; ++j) {
       ori_index = index[start + j];
-      in = *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
-      merger_.merge_one(out, in);
+      FeaturePushValue& rhs =
+          *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
+      merger_.merge_one(out, rhs);
     }
   }
 }
--
GitLab
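
Note on the heter_comm_inl.h hunk: memory::Alloc returns an owning allocation
object, so the pre-patch code, which declared d_shard_grads inside each branch
of the if/else, released the buffer as soon as the branch ended and left
d_shard_grads_ptr dangling; the patch hoists a single allocation of
len * grad_value_size out of the branches. A minimal standalone sketch of the
same lifetime bug, using std::unique_ptr as a stand-in for the allocation
handle (illustrative only, not Paddle's memory API):

    #include <cstdio>
    #include <memory>

    int main() {
      float* ptr = nullptr;
      {
        // Owning handle scoped to the block, like the old
        // `auto d_shard_grads = memory::Alloc(...)` inside if/else.
        auto scoped_buf = std::make_unique<float[]>(16);
        ptr = scoped_buf.get();
      }  // scoped_buf is destroyed here; ptr now dangles.
      // std::printf("%f\n", ptr[0]);  // would be a use-after-free

      // Fixed pattern: keep the owning handle alive in the enclosing
      // scope, as the patch does by hoisting the allocation.
      auto buf = std::make_unique<float[]>(16);
      ptr = buf.get();
      std::printf("%f\n", ptr[0]);  // safe: buf still owns the memory
      return 0;
    }

The other two changes are smaller: the keys are 64-bit unsigned integers, which
%d would truncate in the miss warnings, so the format specifier becomes %llu;
and in merge_gradients_kernel, the old loop assigned through the existing
reference `in` (references cannot be reseated, so this invoked operator= and
wrote into the first merged input value) where the fix binds a fresh reference
`rhs` to each input value instead.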