From ef6ff4ef0554efa2d480151109ef7ebef24ed496 Mon Sep 17 00:00:00 2001 From: zmxdream Date: Fri, 15 Apr 2022 13:39:14 +0800 Subject: [PATCH] [XPUPS]fix hashtable_kernel.kps (#41790) * refactor heter comm kernel * update. test=develop * update calc_shard_offset. test=develop * update xpu kernel. test=develop * update args of calc_shard_offset * update. test=develop * remove customGradMerger * update. test=develop * update. test=develop * fix. test=develop * update. test=develop * update. test=develop * update optimizer kernel * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * fix. test=develop * fix. test=develop * add optimizer kernel. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix kunlun not support size_t. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * update hashtable. test=develop * update. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * update. test=develop * update. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * template init. test=develop * hashtable template init. test=develop * fix. test=develop * fix. test=devlop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix hashtable_kernel. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop Co-authored-by: WorgenZhang --- .../framework/fleet/heter_ps/hashtable.h | 2 +- .../fleet/heter_ps/hashtable_kernel.kps | 32 ++++---- .../framework/fleet/heter_ps/heter_comm.h | 3 +- .../framework/fleet/heter_ps/heter_comm_inl.h | 6 +- .../fleet/heter_ps/heter_comm_kernel.kps | 78 ++++++++++--------- 5 files changed, 63 insertions(+), 58 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.h b/paddle/fluid/framework/fleet/heter_ps/hashtable.h index 6a51713d74c..b821ccecf0a 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable.h +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.h @@ -74,7 +74,7 @@ class XPUCacheArray { // ValType* find(const KeyType& key) { return NULL; } // bool insert(const KeyType& key, const ValType& val) { return true; } - int prefetch(const int dev_id, XPUStream stream = NULL) {} + int prefetch(const int dev_id, XPUStream stream = NULL) { return 0; } size_t size() { return size_; } private: diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.kps b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.kps index 9d2a20a361e..55edf883271 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.kps +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.kps @@ -38,7 +38,7 @@ namespace framework { #if defined(PADDLE_WITH_XPU_KP) -__device__ void update_lr(float* w, float* g2sum, float g, // NOLINT +__device__ void update_lr(float& w, float& g2sum, float g, // NOLINT float scale) { __local__ float local_learning_rate; __local__ float local_initial_g2sum; @@ -55,17 +55,17 @@ __device__ void update_lr(float* w, float* g2sum, float g, // NOLINT sqrt(local_initial_g2sum / (local_initial_g2sum + g2sum)); double scaled_grad = g / scale; - (*w) += scaled_grad * ratio; + w += scaled_grad * ratio; if (w < local_min_bound) w = local_min_bound; if (w > local_max_bound) w = local_max_bound; add_g2sum += scaled_grad * scaled_grad; - (*g2sum) += add_g2sum; + g2sum += add_g2sum; } -__device__ void update_mf(int n, float* w, float* g2sum, const float* g, +__device__ void update_mf(int n, float* w, float& g2sum, const float* g, float scale) { __local__ float local_mf_learning_rate; __local__ float local_mf_initial_g2sum; @@ -92,16 +92,16 @@ __device__ void update_mf(int n, float* w, float* g2sum, const float* g, add_g2sum += scaled_grad * scaled_grad; } - (*g2sum) += add_g2sum / n; + g2sum += add_g2sum / n; } __device__ float xpu_rand_uniform() { return 0.1; } template -__device__ void update_value(ValType* val, const GradType* grad) { // NOLINT - (*val).slot = (*grad).slot; - (*val).show += (*grad).show; - (*val).clk += (*grad).clk; +__device__ void update_value(ValType& val, const GradType& grad) { // NOLINT + val.slot = grad.slot; + val.show += grad.show; + val.clk += grad.clk; __local__ float local_nonclk_coeff; __local__ float local_clk_coeff; @@ -114,25 +114,23 @@ __device__ void update_value(ValType* val, const GradType* grad) { // NOLINT GM2LM(optimizer_config::mf_create_thresholds, &local_mf_create_thresholds, sizeof(float)); - val.delta_score += local_nonclk_coeff * ((*grad).show - (*grad).clk) + - local_clk_coeff * (*grad).clk; + val.delta_score += + local_nonclk_coeff * (grad.show - grad.clk) + local_clk_coeff * grad.clk; - update_lr(&(*val).lr, &(*val).lr_g2sum, (*grad).lr_g, (*grad).show); + update_lr(val.lr, val.lr_g2sum, grad.lr_g, grad.show); if (val.mf_size == 0) { if (local_mf_create_thresholds <= - local_nonclk_coeff * ((*val).show - (*val).clk) + - local_clk_coeff * (*val).clk) { + local_nonclk_coeff * (val.show - val.clk) + local_clk_coeff * val.clk) { val.mf_size = MF_DIM + 1; val.mf[0] = 0; - xpu_rand_uniform(&); for (int i = 0; i < MF_DIM; ++i) { - (*val).mf[i + 1] = (xpu_rand_uniform()) * local_mf_initial_range; + val.mf[i + 1] = (xpu_rand_uniform()) * local_mf_initial_range; } } } else { - update_mf(MF_DIM, &val.mf[1], &val.mf[0], (*grad).mf_g, (*grad).show); + update_mf(MF_DIM, &val.mf[1], val.mf[0], grad.mf_g, grad.show); } } diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h index 817fd8d38ee..419bd716eb3 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h @@ -92,6 +92,7 @@ class HeterComm { nccl_inter_comms_ = inter_comms; node_size_ = comm_size; } +#endif bool need_transfer(int send_id, int receive_id) { return ((send_id / 4 != receive_id / 4) && (send_id + 4) % 8 != receive_id); @@ -101,8 +102,6 @@ class HeterComm { int get_transfer_devid(int send_id) { return (send_id + 4) % 8; } -#endif - void end_pass(); struct Node { diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 3ced33b490d..1e66b3cb250 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -161,8 +161,8 @@ void HeterComm::destroy_storage(int start_index, nodes[i].key_storage); allocator->DeviceFree(resource_->dev_id(nodes[i].dev_num), nodes[i].val_storage); -#endif } +#endif } template @@ -804,9 +804,9 @@ void HeterComm::push_sparse(int dev_num, auto dst_place = platform::CPUPlace(); auto src_place = place; memory_copy(dst_place, h_left, src_place, d_left_ptr, - total_device * sizeof(int)); + total_device * sizeof(int), stream); memory_copy(dst_place, h_right, src_place, d_right_ptr, - total_device * sizeof(int)); + total_device * sizeof(int), stream); for (int i = 0; i < total_device; ++i) { int shard_len = h_right[i] - h_left[i] + 1; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.kps b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.kps index c3e37d9eba3..a1923a7f601 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.kps +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.kps @@ -236,55 +236,62 @@ __global__ void fill_dvals_kernel(ValType* d_shard_vals, ValType* d_vals, // xpu implementation of heter_comm_kernel.h template -void fill_idx(T* idx, long long len, const StreamType& stream) { +void HeterCommKernel::fill_idx(T* idx, long long len, + const StreamType& stream) { fill_idx_kernel<<<4, 64, stream>>>(idx, len); } template -void calc_shard_offset(T* idx, T* left, T* right, long long len, int total_devs, - const StreamType& stream) { +void HeterCommKernel::calc_shard_offset(T* idx, T* left, T* right, + long long len, int total_devs, + const StreamType& stream) { calc_shard_offset_kernel<<<4, 64, stream>>>(idx, left, right, len, total_devs); } template -void calc_shard_index(KeyType* d_keys, long long len, T* shard_index, - int total_devs, const StreamType& stream) { +void HeterCommKernel::calc_shard_index(KeyType* d_keys, long long len, + T* shard_index, int total_devs, + const StreamType& stream) { calc_shard_index_kernel<<<4, 64, stream>>>( d_keys, len, shard_index, total_devs); } template -void fill_shard_key(KeyType* d_shard_keys, KeyType* d_keys, T* idx, - long long len, const StreamType& stream) { +void HeterCommKernel::fill_shard_key(KeyType* d_shard_keys, KeyType* d_keys, + T* idx, long long len, + const StreamType& stream) { fill_shard_key_kernel<<<4, 64, stream>>>(d_shard_keys, d_keys, idx, len); } template -void fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys, - GradType* d_shard_grads, GradType* d_grads, T* idx, - long long len, const StreamType& stream) { +void HeterCommKernel::fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys, + GradType* d_shard_grads, + GradType* d_grads, T* idx, long long len, + const StreamType& stream) { fill_shard_grads_kernel<<<4, 64, stream>>>( d_shard_keys, d_keys, d_shard_grads, d_grads, idx, len); } template -void fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx, long long len, - const StreamType& stream) { +void HeterCommKernel::fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx, + long long len, const StreamType& stream) { fill_dvals_kernel<<<4, 64, stream>>>(d_shard_vals, d_vals, idx, len); } template -void sort_pairs(void* d_temp_storage, size_t& temp_storage_bytes, // NOLINT - const KeyT* d_keys_in, // NOLINT - KeyT* d_keys_out, const ValueT* d_values_in, - ValueT* d_values_out, int num_items, int begin_bit, int end_bit, - StreamType stream, bool debug_synchronous) {} +void HeterCommKernel::sort_pairs(void* d_temp_storage, + size_t& temp_storage_bytes, // NOLINT + const KeyT* d_keys_in, // NOLINT + KeyT* d_keys_out, const ValueT* d_values_in, + ValueT* d_values_out, int num_items, + int begin_bit, int end_bit, StreamType stream, + bool debug_synchronous) {} template (int* idx, long long len, - const XPUStream& stream); -template void calc_shard_offset(int* idx, int* left, int* right, - long long len, int total_devs, - const XPUStream& stream); -template void calc_shard_index( +template void HeterCommKernel::fill_idx( + int* idx, long long len, const XPUStream& stream); +template void HeterCommKernel::calc_shard_offset( + int* idx, int* left, int* right, long long len, int total_devs, + const XPUStream& stream); +template void HeterCommKernel::calc_shard_index( unsigned long* d_keys, long long len, int* shard_index, int total_devs, const XPUStream& stream); -template void fill_shard_key( +template void HeterCommKernel::fill_shard_key( unsigned long* d_shard_keys, unsigned long* d_keys, int* idx, long long len, const XPUStream& stream); +template void HeterCommKernel::fill_shard_grads< + unsigned long, paddle::framework::FeaturePushValue, int, XPUStream>( + unsigned long* d_shard_keys, unsigned long* d_keys, + paddle::framework::FeaturePushValue* d_shard_grads, + paddle::framework::FeaturePushValue* d_grads, int* idx, long long len, + const XPUStream& stream); template void -fill_shard_grads(unsigned long* d_shard_keys, unsigned long* d_keys, - paddle::framework::FeaturePushValue* d_shard_grads, - paddle::framework::FeaturePushValue* d_grads, - int* idx, long long len, const XPUStream& stream); -template void fill_dvals( +HeterCommKernel::fill_dvals( paddle::framework::FeatureValue* d_shard_vals, paddle::framework::FeatureValue* d_vals, int* idx, long long len, const XPUStream& stream); -template void -sort_pairs( +template void HeterCommKernel::sort_pairs< + unsigned long, paddle::framework::FeaturePushValue, XPUStream>( void* d_temp_storage, size_t& temp_storage_bytes, // NOLINT const unsigned long* d_keys_in, // NOLINT @@ -326,14 +334,14 @@ sort_pairs( paddle::framework::FeaturePushValue* d_values_out, int num_items, int begin_bit, int end_bit, XPUStream stream, bool debug_synchronous); -template void sort_pairs( +template void HeterCommKernel::sort_pairs( void* d_temp_storage, size_t& temp_storage_bytes, // NOLINT const int* d_keys_in, // NOLINT int* d_keys_out, const int* d_values_in, int* d_values_out, int num_items, int begin_bit, int end_bit, XPUStream stream, bool debug_synchronous); -template void reduce_by_key< +template void HeterCommKernel::reduce_by_key< unsigned long*, unsigned long*, paddle::framework::FeaturePushValue*, paddle::framework::FeaturePushValue*, int*, XPUStream>( void* d_temp_storage, -- GitLab