From 873b32de2768362d7937feeb6cd630052aa424a9 Mon Sep 17 00:00:00 2001 From: zhaoyuchen2018 <45989343+zhaoyuchen2018@users.noreply.github.com> Date: Mon, 2 Dec 2019 13:24:12 +0800 Subject: [PATCH] [cherry-pick] Fix gru as small frame_size has error. (#20922) (#21440) seems shuffle_sync cannot handle small size test=develop Signed-off-by: zhaoyuchen --- paddle/fluid/operators/math/gru_compute.cu | 52 +++++++++++++++------- 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/operators/math/gru_compute.cu b/paddle/fluid/operators/math/gru_compute.cu index b564f990b49..cf3d57b0630 100644 --- a/paddle/fluid/operators/math/gru_compute.cu +++ b/paddle/fluid/operators/math/gru_compute.cu @@ -31,23 +31,41 @@ struct GRUUnitFunctor { dim3 grid; if (batch_size == 1) { if (context.GetComputeCapability() >= 70) { - constexpr int tiled_size = 16; - int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size; - threads = dim3(tiled_size, 1); - grid = dim3(frame_blocks, 1); - detail::KeFastCollectiveGruGate< - T, tiled_size><<>>( - value.gate_value, value.prev_out_value, value.gate_weight, - value.reset_output_value, frame_size, active_gate); - - frame_blocks = (frame_size + tiled_size - 1) / tiled_size; - grid = dim3(frame_blocks, 1); - detail::KeFastCollectiveGruOut< - T, tiled_size><<>>( - value.state_weight, value.prev_out_value, value.output_value, - value.gate_value, value.reset_output_value, frame_size, active_node, - origin_mode); - + if (frame_size < 16) { + constexpr int tiled_size = 8; + int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size; + threads = dim3(tiled_size, 1); + grid = dim3(frame_blocks, 1); + detail::KeFastCollectiveGruGate< + T, tiled_size><<>>( + value.gate_value, value.prev_out_value, value.gate_weight, + value.reset_output_value, frame_size, active_gate); + + frame_blocks = (frame_size + tiled_size - 1) / tiled_size; + grid = dim3(frame_blocks, 1); + detail::KeFastCollectiveGruOut< + T, tiled_size><<>>( + value.state_weight, value.prev_out_value, value.output_value, + value.gate_value, value.reset_output_value, frame_size, + active_node, origin_mode); + } else { + constexpr int tiled_size = 16; + int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size; + threads = dim3(tiled_size, 1); + grid = dim3(frame_blocks, 1); + detail::KeFastCollectiveGruGate< + T, tiled_size><<>>( + value.gate_value, value.prev_out_value, value.gate_weight, + value.reset_output_value, frame_size, active_gate); + + frame_blocks = (frame_size + tiled_size - 1) / tiled_size; + grid = dim3(frame_blocks, 1); + detail::KeFastCollectiveGruOut< + T, tiled_size><<>>( + value.state_weight, value.prev_out_value, value.output_value, + value.gate_value, value.reset_output_value, frame_size, + active_node, origin_mode); + } return; } else { int frame_per_block = frame_size <= 1024 ? frame_size : 1024; -- GitLab