From 0059404e77138c75b863c29cae349cd259691106 Mon Sep 17 00:00:00 2001 From: zhaoyuchen2018 <45989343+zhaoyuchen2018@users.noreply.github.com> Date: Tue, 5 Nov 2019 00:48:20 -0600 Subject: [PATCH] Fix ce ocr_recognition test fails (#20987) ocr_recognition fails, so add a path to handle small frame_size. test=develop --- .../operators/math/detail/gru_gpu_kernel.h | 11 ++-- paddle/fluid/operators/math/gru_compute.cu | 58 +++++++++++-------- 2 files changed, 39 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h b/paddle/fluid/operators/math/detail/gru_gpu_kernel.h index 6d83cff1cda..77d7ff57cda 100644 --- a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h +++ b/paddle/fluid/operators/math/detail/gru_gpu_kernel.h @@ -105,7 +105,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output, * threads(tile_size, 1) * grid(frame_blocks, 1) */ -template +template __global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value, T *gate_weight, T *reset_output, int frame_size, @@ -113,9 +113,7 @@ __global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value, T xt_0 = 0.0f; T a0 = 0.0f; T c0 = 0.0f; - - int Tiled_size = blockDim.x; - T b0[16]; + T b0[Tiled_size]; int COL = blockIdx.x * blockDim.x + threadIdx.x; int Tiled_mask = ((1 << Tiled_size) - 1); @@ -165,7 +163,7 @@ __global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value, * threads(tile_size, 1) * grid(frame_blocks, 1) */ -template +template __global__ void KeFastCollectiveGruOut(T *gate_weight, T *prev_out_value, T *output_value, T *gate_value, T *reset_value, int frame_size, @@ -174,10 +172,9 @@ __global__ void KeFastCollectiveGruOut(T *gate_weight, T *prev_out_value, int COL = blockIdx.x * blockDim.x + threadIdx.x; T a0 = 0.0f; - T b0[16]; + T b0[Tiled_size]; T c0 = 0.0f; - int Tiled_size = blockDim.x; int Tiled_mask = ((1 << Tiled_size) - 1); //- Tiled matrix multiply with register shift if (prev_out_value) { diff --git a/paddle/fluid/operators/math/gru_compute.cu b/paddle/fluid/operators/math/gru_compute.cu index b3695c97a12..cf3d57b0630 100644 --- a/paddle/fluid/operators/math/gru_compute.cu +++ b/paddle/fluid/operators/math/gru_compute.cu @@ -31,29 +31,41 @@ struct GRUUnitFunctor { dim3 grid; if (batch_size == 1) { if (context.GetComputeCapability() >= 70) { - auto ComputeTiledSize = [](int frame_size) { - if (frame_size >= 16) - return 16; - else if (frame_size < 16) - return 8; - }; - - auto tiled_size = ComputeTiledSize(frame_size); - int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size; - threads = dim3(tiled_size, 1); - grid = dim3(frame_blocks, 1); - - detail::KeFastCollectiveGruGate<<>>( - value.gate_value, value.prev_out_value, value.gate_weight, - value.reset_output_value, frame_size, active_gate); - - frame_blocks = (frame_size + tiled_size - 1) / tiled_size; - grid = dim3(frame_blocks, 1); - detail::KeFastCollectiveGruOut<<>>( - value.state_weight, value.prev_out_value, value.output_value, - value.gate_value, value.reset_output_value, frame_size, active_node, - origin_mode); - + if (frame_size < 16) { + constexpr int tiled_size = 8; + int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size; + threads = dim3(tiled_size, 1); + grid = dim3(frame_blocks, 1); + detail::KeFastCollectiveGruGate< + T, tiled_size><<>>( + value.gate_value, value.prev_out_value, value.gate_weight, + value.reset_output_value, frame_size, active_gate); + + frame_blocks = (frame_size + tiled_size - 1) / tiled_size; + grid = dim3(frame_blocks, 1); + detail::KeFastCollectiveGruOut< + T, tiled_size><<>>( + value.state_weight, value.prev_out_value, value.output_value, + value.gate_value, value.reset_output_value, frame_size, + active_node, origin_mode); + } else { + constexpr int tiled_size = 16; + int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size; + threads = dim3(tiled_size, 1); + grid = dim3(frame_blocks, 1); + detail::KeFastCollectiveGruGate< + T, tiled_size><<>>( + value.gate_value, value.prev_out_value, value.gate_weight, + value.reset_output_value, frame_size, active_gate); + + frame_blocks = (frame_size + tiled_size - 1) / tiled_size; + grid = dim3(frame_blocks, 1); + detail::KeFastCollectiveGruOut< + T, tiled_size><<>>( + value.state_weight, value.prev_out_value, value.output_value, + value.gate_value, value.reset_output_value, frame_size, + active_node, origin_mode); + } return; } else { int frame_per_block = frame_size <= 1024 ? frame_size : 1024; -- GitLab