From a7a4843c5fabf21cef13fb7de51d5148aa1acdc7 Mon Sep 17 00:00:00 2001 From: zmxdream Date: Wed, 29 Jun 2022 19:54:10 +0800 Subject: [PATCH] [GPUPS]Optimize dymf kernel (#43911) --- .../fleet/heter_ps/hashtable_kernel.cu | 51 +++-- .../framework/fleet/heter_ps/heter_comm_inl.h | 1 + .../fleet/heter_ps/heter_comm_kernel.cu | 174 ++++++++++++++++-- .../fleet/heter_ps/heter_comm_kernel.h | 33 +++- 4 files changed, 217 insertions(+), 42 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu index 04842caef6..a7e00bb083 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu @@ -89,31 +89,42 @@ __global__ void dy_mf_search_kernel(Table* table, char* vals, size_t len, size_t pull_feature_value_size) { - const size_t i = blockIdx.x * blockDim.x + threadIdx.x; - // return; + const size_t i = blockIdx.x * blockDim.y + threadIdx.y; + const size_t k = threadIdx.x; if (i < len) { auto it = table->find(keys[i]); - if (it != table->end()) { uint64_t offset = i * pull_feature_value_size; FeatureValue* cur = (FeatureValue*)(vals + offset); FeatureValue& input = *(FeatureValue*)(it->second); - cur->slot = input.slot; - cur->show = input.show; - cur->clk = input.clk; - cur->mf_dim = input.mf_dim; - cur->lr = input.lr; - cur->mf_size = input.mf_size; - cur->cpu_ptr = input.cpu_ptr; - cur->delta_score = input.delta_score; - cur->lr_g2sum = input.lr_g2sum; - for (int j = 0; j < cur->mf_dim + 1; ++j) { - cur->mf[j] = input.mf[j]; + char* cur_p = (char*)cur; + char* input_p = (char*)(&input); + int len = 9 + input.mf_dim + 1; + if (k == 3 || k == 6 || k == 7) + *(int*)(cur_p + k * 4) = *(int*)(input_p + k * 4); + else if (k < 8) + *(float*)(cur_p + k * 4) = *(float*)(input_p + k * 4); + else if (k == 8) { + *(uint64_t*)(cur_p + k * 4) = *(uint64_t*)(input_p + k * 4); + } else { + int len_per_thread = (len - 9) / (blockDim.y - 9); + int remain = (len - 9) % (blockDim.y - 9); + int real_len = len_per_thread; + if ((k - 9) < remain) real_len++; + int left = -1, right = -1; + if ((k - 9) < remain) { + left = 9 + (k - 9) * (len_per_thread + 1); + right = left + real_len; + } else { + left = 9 + remain * (len_per_thread + 1) + + (k - 9 - remain) * len_per_thread; + right = left + real_len; + } + for (int j = left; j < right; j++) + *(float*)(cur_p + (j + 1) * 4) = *(float*)(input_p + (j + 1) * 4); } } else { - if (keys[i] != 0) { - printf("warning::pull miss key: %llu", keys[i]); - } + if (keys[i] != 0) printf("pull miss key: %llu", keys[i]); } } } @@ -220,8 +231,10 @@ void HashTable::get(const KeyType* d_keys, if (len == 0) { return; } - const int grid_size = (len - 1) / BLOCK_SIZE_ + 1; - dy_mf_search_kernel<<>>( + dim3 block_dims(32, 32); + const int grid_size = (len - 1) / 32 + 1; + dim3 grid_dims(grid_size); + dy_mf_search_kernel<<>>( container_, d_keys, d_vals, len, pull_feature_value_size_); } diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index ace533cb0c..8952039299 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -760,6 +760,7 @@ void HeterComm::dynamic_merge_grad( (char*)d_grads, (char*)d_merge_grads_ptr, uniq_len, + max_mf_dim_, grad_value_size, merger_, stream); diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index fd0dd1a72c..8a13d9abe6 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -144,28 +144,106 @@ __global__ void dy_mf_fill_shard_grads_kernel(KeyType* d_shard_keys, } } -__global__ void merge_gradients_kernel(const uint32_t* offset, - const uint32_t* fea_num, - const uint32_t* index, - const char* input, - char* output, - int n, - size_t grad_value_size, - DynamicGradMerger& merger_) { +// optimized version +template <> +__global__ void +dy_mf_fill_shard_grads_kernel( + FeatureKey* d_shard_keys, + FeatureKey* d_keys, + FeaturePushValue* d_shard_grads, + FeaturePushValue* d_grads, + int* idx, + size_t len, + size_t grad_value_size) { + const size_t i = blockIdx.x * blockDim.y + threadIdx.y; + const size_t k = threadIdx.x; + if (i < len) { + if (k == 0) { + d_shard_keys[i] = d_keys[idx[i]]; + } + FeaturePushValue* cur = + (FeaturePushValue*)((char*)d_shard_grads + i * grad_value_size); + FeaturePushValue& input = *( + FeaturePushValue*)((char*)d_grads + uint64_t(idx[i]) * grad_value_size); + char* cur_p = (char*)cur; + char* input_p = (char*)(&input); + int len = 5 + input.mf_dim; + if (k == 2 || k == 4) + *(int*)(cur_p + k * 4) = *(int*)(input_p + k * 4); + else if (k < 5) + *(float*)(cur_p + k * 4) = *(float*)(input_p + k * 4); + else { + int len_per_thread = (len - 5) / (blockDim.y - 5); + int remain = (len - 5) % (blockDim.y - 5); + int real_len = len_per_thread; + if ((k - 5) < remain) real_len++; + int left = -1, right = -1; + if ((k - 5) < remain) { + left = 5 + (k - 5) * (len_per_thread + 1); + right = left + real_len; + } else { + left = 5 + remain * (len_per_thread + 1) + + (k - 5 - remain) * len_per_thread; + right = left + real_len; + } + for (int j = left; j < right; j++) + *(float*)(cur_p + j * 4) = *(float*)(input_p + j * 4); + } + } +} + +__global__ void merge_gradients_basic_kernel(const uint32_t* offset, + const uint32_t* fea_num, + const uint32_t* index, + const char* input, + char* output, + int n, + size_t grad_value_size, + DynamicGradMerger& merger) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i < n) { uint32_t start = offset[i]; uint32_t num = fea_num[i]; int ori_index = index[start]; - FeaturePushValue& out = *(FeaturePushValue*)(output + i * grad_value_size); + FeaturePushValue& lhs = *(FeaturePushValue*)(output + i * grad_value_size); FeaturePushValue& in = *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size); - merger_.update_one(out, in); + merger.update_basic(lhs, in); for (int j = 1; j < num; ++j) { ori_index = index[start + j]; FeaturePushValue& rhs = *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size); - merger_.merge_one(out, rhs); + merger.merge_basic(lhs, rhs); + } + } +} + +__global__ void merge_gradients_embedx_kernel(const uint32_t* offset, + const uint32_t* fea_num, + const uint32_t* index, + const char* input, + char* output, + int n, + size_t grad_dim, + size_t grad_value_size, + DynamicGradMerger& merger) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) { + size_t value_idx = i / grad_dim; + size_t field_idx = i % grad_dim; + uint32_t start = offset[value_idx]; + uint32_t num = fea_num[value_idx]; + int ori_index = index[start]; + FeaturePushValue& in = + *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size); + FeaturePushValue& lhs = + *(FeaturePushValue*)(output + value_idx * grad_value_size); + merger.update_embedx(lhs, in, field_idx); + for (int j = 1; j < num; ++j) { + int ori_index = index[start + j]; + FeaturePushValue& rhs = + *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size); + merger.merge_embedx(lhs, rhs, field_idx); } } } @@ -184,6 +262,49 @@ __global__ void dy_mf_fill_dvals_kernel(ValType* d_shard_vals, } } +// optimized version +template <> +__global__ void dy_mf_fill_dvals_kernel( + FeatureValue* d_shard_vals, + FeatureValue* d_vals, + int* idx, + size_t len, + size_t val_size) { + const size_t i = blockIdx.x * blockDim.y + threadIdx.y; + const size_t k = threadIdx.x; + if (i < len) { + uint64_t new_offset = uint64_t(idx[i]) * val_size; + FeatureValue* cur = (FeatureValue*)((char*)d_vals + new_offset); + FeatureValue& input = *(FeatureValue*)((char*)d_shard_vals + i * val_size); + char* cur_p = (char*)cur; + char* input_p = (char*)(&input); + int len = 9 + input.mf_dim + 1; + if (k == 3 || k == 6 || k == 7) + *(int*)(cur_p + k * 4) = *(int*)(input_p + k * 4); + else if (k < 8) + *(float*)(cur_p + k * 4) = *(float*)(input_p + k * 4); + else if (k == 8) { + *(uint64_t*)(cur_p + k * 4) = *(uint64_t*)(input_p + k * 4); + } else { + int len_per_thread = (len - 9) / (blockDim.x - 9); + int remain = (len - 9) % (blockDim.y - 9); + int real_len = len_per_thread; + if ((k - 9) < remain) real_len++; + int left = -1, right = -1; + if ((k - 9) < remain) { + left = 9 + (k - 9) * (len_per_thread + 1); + right = left + real_len; + } else { + left = 9 + remain * (len_per_thread + 1) + + (k - 9 - remain) * len_per_thread; + right = left + real_len; + } + for (int j = left; j < right; j++) + *(float*)(cur_p + (j + 1) * 4) = *(float*)(input_p + (j + 1) * 4); + } + } +} + // cuda implemention of heter_comm_kernel.h template void HeterCommKernel::fill_idx(T* idx, @@ -321,9 +442,12 @@ void HeterCommKernel::dy_mf_fill_shard_grads(KeyType* d_shard_keys, long long len, size_t grad_value_size, const StreamType& stream) { - int grid_size = (len - 1) / block_size_ + 1; + // int grid_size = (len - 1) / block_size_ + 1; size_t c_len = (size_t)len; - dy_mf_fill_shard_grads_kernel<<>>( + dim3 block_dims(32, 32); + const size_t grid_size = (len - 1) / 32 + 1; + dim3 grid_dims(grid_size); + dy_mf_fill_shard_grads_kernel<<>>( d_shard_keys, d_keys, d_shard_grads, @@ -340,12 +464,26 @@ void HeterCommKernel::merge_gradient(const uint32_t* offset, const char* input, char* output, int n, + size_t grad_dim, size_t grad_value_size, DynamicGradMerger& merger_, const StreamType& stream) { int grid_size = (n - 1) / block_size_ + 1; - merge_gradients_kernel<<>>( + merge_gradients_basic_kernel<<>>( offset, fea_num, index, input, output, n, grad_value_size, merger_); + if (grad_dim > 0) { + int grid_size2 = (n * grad_dim - 1) / block_size_ + 1; + merge_gradients_embedx_kernel<<>>( + offset, + fea_num, + index, + input, + output, + n * grad_dim, + grad_dim, + grad_value_size, + merger_); + } } template @@ -355,9 +493,12 @@ void HeterCommKernel::dy_mf_fill_dvals(ValType* d_shard_vals, long long len, size_t val_size, const StreamType& stream) { - int grid_size = (len - 1) / block_size_ + 1; + // int grid_size = (len - 1) / block_size_ + 1; size_t c_len = (size_t)len; - dy_mf_fill_dvals_kernel<<>>( + dim3 block_dims(32, 32); + const size_t grid_size_ = (len - 1) / 32 + 1; + dim3 grid_dims(grid_size_); + dy_mf_fill_dvals_kernel<<>>( d_shard_vals, d_vals, idx, c_len, val_size); } @@ -487,6 +628,7 @@ template void HeterCommKernel::merge_gradient( const char* input, char* output, int n, + size_t grad_dim, size_t grad_value_size, DynamicGradMerger& merger_, const cudaStream_t& stream); diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h index d1555dc2e0..6859161a5f 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h @@ -42,23 +42,41 @@ struct DynamicGradMerger { } template - __device__ __forceinline__ void update_one(T& output, const T& input) { + __device__ __forceinline__ void update_basic(T& output, const T& input) { output.slot = input.slot; output.show = input.show; output.clk = input.clk; output.mf_dim = input.mf_dim; output.lr_g = input.lr_g; - for (int i = 0; i < output.mf_dim; ++i) { - output.mf_g[i] = input.mf_g[i]; - } + // for (int i = 0; i < output.mf_dim; ++i) { + // output.mf_g[i] = input.mf_g[i]; + //} } template - __device__ __forceinline__ void merge_one(T& output, const T& input) { + __device__ __forceinline__ void merge_basic(T& output, const T& input) { output.show += input.show; output.clk += input.clk; output.lr_g += input.lr_g; - for (int i = 0; i < input.mf_dim; ++i) { - output.mf_g[i] += input.mf_g[i]; + // for (int i = 0; i < input.mf_dim; ++i) { + // output.mf_g[i] += input.mf_g[i]; + //} + } + + template + __device__ __forceinline__ void update_embedx(T& output, + const T& input, + size_t embedx_id) { + if (embedx_id < output.mf_dim) { + output.mf_g[embedx_id] = input.mf_g[embedx_id]; + } + } + + template + __device__ __forceinline__ void merge_embedx(T& output, + const T& input, + size_t embedx_id) { + if (embedx_id < output.mf_dim) { + output.mf_g[embedx_id] += input.mf_g[embedx_id]; } } }; @@ -165,6 +183,7 @@ class HeterCommKernel { const char* input, char* output, int n, + size_t grad_dim, size_t grad_value_size, DynamicGradMerger& merger_, const StreamType& stream); -- GitLab