From 8a2caacdbc938c35a4535b4e47f5626b504e2972 Mon Sep 17 00:00:00 2001
From: zhaoyuchen2018 <45989343+zhaoyuchen2018@users.noreply.github.com>
Date: Fri, 10 May 2019 14:38:44 +0800
Subject: [PATCH] Improve GRU unit performance. (#16338)

Refine code: fuse the cuBLAS calls and kernels into one CUDA kernel.

test=develop

Signed-off-by: zhaoyuchen
---
 .../operators/math/detail/gru_gpu_kernel.h    | 116 ++++++++++++++++++
 paddle/fluid/operators/math/gru_compute.cu    |  21 +++-
 2 files changed, 134 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h b/paddle/fluid/operators/math/detail/gru_gpu_kernel.h
index 6b57da104..77d7ff57c 100644
--- a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h
+++ b/paddle/fluid/operators/math/detail/gru_gpu_kernel.h
@@ -101,6 +101,122 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
   output_value[frame_idx] = r_output;
 }
 
+/*
+ * threads(tile_size, 1)
+ * grid(frame_blocks, 1)
+ */
+template <class T, int Tiled_size>
+__global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value,
+                                        T *gate_weight, T *reset_output,
+                                        int frame_size,
+                                        ActivationType active_node) {
+  T xt_0 = 0.0f;
+  T a0 = 0.0f;
+  T c0 = 0.0f;
+  T b0[Tiled_size];
+
+  int COL = blockIdx.x * blockDim.x + threadIdx.x;
+  int Tiled_mask = ((1 << Tiled_size) - 1);
+  // Tiled matrix multiply using register shift, faster than sm.
+  if (prev_output_value) {
+    for (int k = 0; k < (((frame_size - 1) / Tiled_size) + 1); ++k) {
+      a0 = 0;
+      if ((threadIdx.x + k * Tiled_size) < frame_size) {
+        a0 = prev_output_value[threadIdx.x + (k * Tiled_size)];
+      }
+      for (int i = 0; i < Tiled_size; i++) {
+        if (COL < frame_size * 2 && (i + k * Tiled_size) < frame_size) {
+          b0[i] = gate_weight[(i + k * Tiled_size) * frame_size * 2 + COL];
+        }
+      }
+
+      for (int i = 0; i < Tiled_size; ++i) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
+        c0 = c0 + __shfl_sync(Tiled_mask, a0, i, Tiled_size) * b0[i];
+#else
+        c0 = c0 + __shfl(a0, i, Tiled_size) * b0[i];
+#endif
+      }
+    }
+  }
+
+  __syncthreads();
+
+  if (COL < frame_size * 2) {
+    xt_0 = gate_value[COL];
+    c0 += xt_0;
+    c0 = forward::activation(c0, active_node);
+    gate_value[COL] = c0;
+    if (frame_size <= COL && COL < frame_size * 2) {
+      T htp_0 = 0.0;
+      if (prev_output_value) {
+        htp_0 = prev_output_value[COL - frame_size];
+      }
+      reset_output[COL - frame_size] = c0 * htp_0;
+    } else if (COL < frame_size) {
+      gate_value[COL] = c0;
+    }
+  }
+}
+
+/*
+ * threads(tile_size, 1)
+ * grid(frame_blocks, 1)
+ */
+template <class T, int Tiled_size>
+__global__ void KeFastCollectiveGruOut(T *gate_weight, T *prev_out_value,
+                                       T *output_value, T *gate_value,
+                                       T *reset_value, int frame_size,
+                                       ActivationType act_node,
+                                       bool origin_mode) {
+  int COL = blockIdx.x * blockDim.x + threadIdx.x;
+
+  T a0 = 0.0f;
+  T b0[Tiled_size];
+  T c0 = 0.0f;
+
+  int Tiled_mask = ((1 << Tiled_size) - 1);
+  //- Tiled matrix multiply with register shift
+  if (prev_out_value) {
+    for (int k = 0; k < (((frame_size - 1) / Tiled_size) + 1); ++k) {
+      a0 = 0;
+      if ((threadIdx.x + k * Tiled_size) < frame_size) {
+        a0 = reset_value[threadIdx.x + (k * Tiled_size)];
+      }
+      for (int i = 0; i < Tiled_size; i++) {
+        if (COL < frame_size && (i + k * Tiled_size) < frame_size) {
+          b0[i] = gate_weight[(i + k * Tiled_size) * frame_size + COL];
+        }
+      }
+
+      for (int i = 0; i < Tiled_size; ++i) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
+        c0 = c0 + __shfl_sync(Tiled_mask, a0, i, Tiled_size) * b0[i];
+#else
+        c0 = c0 + __shfl(a0, i, Tiled_size) * b0[i];
+#endif
+      }
+    }
+  }
+
+  __syncthreads();
+
+  if (COL < frame_size) {
+    T xt_0 = gate_value[COL + 2 * frame_size];
+    T gta_0 = gate_value[COL];
+    T htp_0 = 0;
+    if (prev_out_value) htp_0 = prev_out_value[COL];
+    c0 += xt_0;
+    c0 = forward::activation(c0, act_node);
+    gate_value[COL + 2 * frame_size] = c0;
+    if (origin_mode) {
+      output_value[COL] = htp_0 * gta_0 + (1 - gta_0) * c0;
+    } else {
+      output_value[COL] = c0 * gta_0 + (1 - gta_0) * htp_0;
+    }
+  }
+}
+
 /*
  * threads(frame_per_block, batch_per_block)
  * grid(frame_blocks, batch_blocks)
diff --git a/paddle/fluid/operators/math/gru_compute.cu b/paddle/fluid/operators/math/gru_compute.cu
index ec7e4d222..75417cced 100644
--- a/paddle/fluid/operators/math/gru_compute.cu
+++ b/paddle/fluid/operators/math/gru_compute.cu
@@ -30,10 +30,25 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
     dim3 threads;
     dim3 grid;
     if (batch_size == 1) {
-      int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
-      int frame_blocks = (frame_size + 1024 - 1) / 1024;
-      threads = dim3(frame_per_block, 1);
+      constexpr int tiled_size = 16;
+      int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
+      threads = dim3(tiled_size, 1);
       grid = dim3(frame_blocks, 1);
+
+      detail::KeFastCollectiveGruGate<T,
+                                      tiled_size><<<grid, threads, 0, stream>>>(
+          value.gate_value, value.prev_out_value, value.gate_weight,
+          value.reset_output_value, frame_size, active_gate);
+
+      frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
+      grid = dim3(frame_blocks, 1);
+      detail::KeFastCollectiveGruOut<T,
+                                     tiled_size><<<grid, threads, 0, stream>>>(
+          value.state_weight, value.prev_out_value, value.output_value,
+          value.gate_value, value.reset_output_value, frame_size, active_node,
+          origin_mode);
+
+      return;
     } else {
       threads = dim3(32, 32);
      grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
-- 
GitLab
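
Note (not part of the patch): both new kernels rely on the same warp-shuffle
"register tiling" trick, so here is a minimal standalone sketch of just that
trick, reduced to a plain matrix-vector product y = x * W. The names
(MatVecShfl, TILE) and the tiny main() are invented for illustration and appear
nowhere in the PR, and the sketch assumes CUDA 9+ for __shfl_sync; the patch
itself keeps a __shfl fallback for pre-Volta architectures.

#include <cstdio>
#include <cuda_runtime.h>

constexpr int TILE = 16;  // mirrors tiled_size in the patch

// y = x * W for a single row vector x (length n) and row-major W (n x m).
__global__ void MatVecShfl(const float *x, const float *w, float *y, int n,
                           int m) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;  // output column
  unsigned mask = (1u << TILE) - 1;                 // lanes 0..TILE-1
  float acc = 0.0f;
  float wreg[TILE];  // this thread's column of the current weight tile

  // Walk the input vector TILE elements at a time.
  for (int k = 0; k < (n + TILE - 1) / TILE; ++k) {
    // Each lane loads one element of x for this tile.
    float xi = (threadIdx.x + k * TILE < n) ? x[threadIdx.x + k * TILE] : 0.0f;
    // Each lane loads its own column of the tile's rows of W into registers.
    for (int i = 0; i < TILE; ++i)
      wreg[i] = (col < m && i + k * TILE < n) ? w[(i + k * TILE) * m + col]
                                              : 0.0f;
    // Broadcast each x element across the tile with a shuffle and accumulate;
    // no shared memory and no extra global reads of x are needed.
    for (int i = 0; i < TILE; ++i)
      acc += __shfl_sync(mask, xi, i, TILE) * wreg[i];
  }
  if (col < m) y[col] = acc;
}

int main() {
  const int n = 32, m = 48;
  float hx[n], hw[n * m], hy[m];
  for (int i = 0; i < n; ++i) hx[i] = 1.0f;
  for (int i = 0; i < n * m; ++i) hw[i] = 0.5f;
  float *dx, *dw, *dy;
  cudaMalloc(reinterpret_cast<void **>(&dx), sizeof(hx));
  cudaMalloc(reinterpret_cast<void **>(&dw), sizeof(hw));
  cudaMalloc(reinterpret_cast<void **>(&dy), sizeof(hy));
  cudaMemcpy(dx, hx, sizeof(hx), cudaMemcpyHostToDevice);
  cudaMemcpy(dw, hw, sizeof(hw), cudaMemcpyHostToDevice);
  MatVecShfl<<<(m + TILE - 1) / TILE, TILE>>>(dx, dw, dy, n, m);
  cudaMemcpy(hy, dy, sizeof(hy), cudaMemcpyDeviceToHost);
  printf("y[0] = %f (expect %f)\n", hy[0], 0.5f * n);
  cudaFree(dx);
  cudaFree(dw);
  cudaFree(dy);
  return 0;
}

Each thread owns one output column: it keeps its slice of the current weight
tile in registers (wreg) and receives the input elements from the other lanes
via __shfl_sync, so the per-step matrix-vector product needs neither shared
memory nor a separate cuBLAS GEMV. That is what lets the patch fold the former
cuBLAS call and the gate activations into single kernel launches on the
batch_size == 1 path.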