diff --git a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h b/paddle/fluid/operators/math/detail/gru_gpu_kernel.h index 77d7ff57cda7416705bed7eb393366e1f87232a0..6d83cff1cdaeea0bfb15e8ecd9ead71af6fa82b5 100644 --- a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h +++ b/paddle/fluid/operators/math/detail/gru_gpu_kernel.h @@ -105,7 +105,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output, * threads(tile_size, 1) * grid(frame_blocks, 1) */ -template +template __global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value, T *gate_weight, T *reset_output, int frame_size, @@ -113,7 +113,9 @@ __global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value, T xt_0 = 0.0f; T a0 = 0.0f; T c0 = 0.0f; - T b0[Tiled_size]; + + int Tiled_size = blockDim.x; + T b0[16]; int COL = blockIdx.x * blockDim.x + threadIdx.x; int Tiled_mask = ((1 << Tiled_size) - 1); @@ -163,7 +165,7 @@ __global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value, * threads(tile_size, 1) * grid(frame_blocks, 1) */ -template +template __global__ void KeFastCollectiveGruOut(T *gate_weight, T *prev_out_value, T *output_value, T *gate_value, T *reset_value, int frame_size, @@ -172,9 +174,10 @@ __global__ void KeFastCollectiveGruOut(T *gate_weight, T *prev_out_value, int COL = blockIdx.x * blockDim.x + threadIdx.x; T a0 = 0.0f; - T b0[Tiled_size]; + T b0[16]; T c0 = 0.0f; + int Tiled_size = blockDim.x; int Tiled_mask = ((1 << Tiled_size) - 1); //- Tiled matrix multiply with register shift if (prev_out_value) { diff --git a/paddle/fluid/operators/math/gru_compute.cu b/paddle/fluid/operators/math/gru_compute.cu index b564f990b4920a3a01b6ce0dd53e8f5e5d0464aa..b3695c97a12863a2deaecdb1f0a344593e528d46 100644 --- a/paddle/fluid/operators/math/gru_compute.cu +++ b/paddle/fluid/operators/math/gru_compute.cu @@ -31,19 +31,25 @@ struct GRUUnitFunctor { dim3 grid; if (batch_size == 1) { if (context.GetComputeCapability() >= 70) { - constexpr int tiled_size = 16; + auto ComputeTiledSize = [](int frame_size) { + if (frame_size >= 16) + return 16; + else if (frame_size < 16) + return 8; + }; + + auto tiled_size = ComputeTiledSize(frame_size); int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size; threads = dim3(tiled_size, 1); grid = dim3(frame_blocks, 1); - detail::KeFastCollectiveGruGate< - T, tiled_size><<>>( + + detail::KeFastCollectiveGruGate<<>>( value.gate_value, value.prev_out_value, value.gate_weight, value.reset_output_value, frame_size, active_gate); frame_blocks = (frame_size + tiled_size - 1) / tiled_size; grid = dim3(frame_blocks, 1); - detail::KeFastCollectiveGruOut< - T, tiled_size><<>>( + detail::KeFastCollectiveGruOut<<>>( value.state_weight, value.prev_out_value, value.output_value, value.gate_value, value.reset_output_value, frame_size, active_node, origin_mode);