From b02f2aff04400dd590c2957e165ae5eea81889b1 Mon Sep 17 00:00:00 2001 From: zhaoyuchen2018 <45989343+zhaoyuchen2018@users.noreply.github.com> Date: Thu, 16 May 2019 21:29:22 +0800 Subject: [PATCH] Add conditional compile for gru opt (#17368) * improve gru unit performance. refine code test=develop Signed-off-by: zhaoyuchen * Add conditional compile for gru opt Not enable gru opt if compute ability < 700 test=develop Signed-off-by: zhaoyuchen * refine code. test=develop Signed-off-by: zhaoyuchen --- paddle/fluid/operators/math/gru_compute.cu | 44 ++++++++++++---------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/operators/math/gru_compute.cu b/paddle/fluid/operators/math/gru_compute.cu index 75417cced23..b564f990b49 100644 --- a/paddle/fluid/operators/math/gru_compute.cu +++ b/paddle/fluid/operators/math/gru_compute.cu @@ -30,25 +30,31 @@ struct GRUUnitFunctor { dim3 threads; dim3 grid; if (batch_size == 1) { - constexpr int tiled_size = 16; - int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size; - threads = dim3(tiled_size, 1); - grid = dim3(frame_blocks, 1); - - detail::KeFastCollectiveGruGate<<>>( - value.gate_value, value.prev_out_value, value.gate_weight, - value.reset_output_value, frame_size, active_gate); - - frame_blocks = (frame_size + tiled_size - 1) / tiled_size; - grid = dim3(frame_blocks, 1); - detail::KeFastCollectiveGruOut<<>>( - value.state_weight, value.prev_out_value, value.output_value, - value.gate_value, value.reset_output_value, frame_size, active_node, - origin_mode); - - return; + if (context.GetComputeCapability() >= 70) { + constexpr int tiled_size = 16; + int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size; + threads = dim3(tiled_size, 1); + grid = dim3(frame_blocks, 1); + detail::KeFastCollectiveGruGate< + T, tiled_size><<>>( + value.gate_value, value.prev_out_value, value.gate_weight, + value.reset_output_value, frame_size, active_gate); + + frame_blocks = (frame_size + tiled_size - 1) / tiled_size; + grid = dim3(frame_blocks, 1); + detail::KeFastCollectiveGruOut< + T, tiled_size><<>>( + value.state_weight, value.prev_out_value, value.output_value, + value.gate_value, value.reset_output_value, frame_size, active_node, + origin_mode); + + return; + } else { + int frame_per_block = frame_size <= 1024 ? frame_size : 1024; + int frame_blocks = (frame_size + 1024 - 1) / 1024; + threads = dim3(frame_per_block, 1); + grid = dim3(frame_blocks, 1); + } } else { threads = dim3(32, 32); grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32); -- GitLab