diff --git a/paddle/fluid/operators/math/gru_compute.cu b/paddle/fluid/operators/math/gru_compute.cu index 75417cced237c48dda1f6e87c0647b10a66d0907..b564f990b4920a3a01b6ce0dd53e8f5e5d0464aa 100644 --- a/paddle/fluid/operators/math/gru_compute.cu +++ b/paddle/fluid/operators/math/gru_compute.cu @@ -30,25 +30,31 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> { dim3 threads; dim3 grid; if (batch_size == 1) { - constexpr int tiled_size = 16; - int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size; - threads = dim3(tiled_size, 1); - grid = dim3(frame_blocks, 1); - - detail::KeFastCollectiveGruGate<T, - tiled_size><<<grid, threads, 0, stream>>>( - value.gate_value, value.prev_out_value, value.gate_weight, - value.reset_output_value, frame_size, active_gate); - - frame_blocks = (frame_size + tiled_size - 1) / tiled_size; - grid = dim3(frame_blocks, 1); - detail::KeFastCollectiveGruOut<T, - tiled_size><<<grid, threads, 0, stream>>>( - value.state_weight, value.prev_out_value, value.output_value, - value.gate_value, value.reset_output_value, frame_size, active_node, - origin_mode); - - return; + if (context.GetComputeCapability() >= 70) { + constexpr int tiled_size = 16; + int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size; + threads = dim3(tiled_size, 1); + grid = dim3(frame_blocks, 1); + detail::KeFastCollectiveGruGate< + T, tiled_size><<<grid, threads, 0, stream>>>( + value.gate_value, value.prev_out_value, value.gate_weight, + value.reset_output_value, frame_size, active_gate); + + frame_blocks = (frame_size + tiled_size - 1) / tiled_size; + grid = dim3(frame_blocks, 1); + detail::KeFastCollectiveGruOut< + T, tiled_size><<<grid, threads, 0, stream>>>( + value.state_weight, value.prev_out_value, value.output_value, + value.gate_value, value.reset_output_value, frame_size, active_node, + origin_mode); + + return; + } else { + int frame_per_block = frame_size <= 1024 ? frame_size : 1024; + int frame_blocks = (frame_size + 1024 - 1) / 1024; + threads = dim3(frame_per_block, 1); + grid = dim3(frame_blocks, 1); + } } else { threads = dim3(32, 32); grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);