[Paddle-TRT] Different behaviors between blockReduce and warpReduce
Created by: zlsh80826
- Paddle Version: develop
- GPU/cuda-10.2/cudnn7.6.5
- Ubuntu16.04
Hello! blockReduceSum and warpReduceSum behave differently. warpReduceSum leaves every thread in the warp with a copy of the summation, whereas blockReduceSum only guarantees that the first warp has the summation, which forces the caller to do one more block synchronization, like here.
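For reference, here is a minimal sketch of the pattern in question. It is my reconstruction of the usual shuffle-based reduction, not Paddle's exact source; the names (`shared`, `lane`, `block_span`, `mask`) follow the snippets quoted below.

```cpp
// Sketch only: a reconstruction of the common shuffle-based reduction pattern,
// not Paddle's exact code. `mask` is assumed to be the active-lane mask.
template <typename T>
__inline__ __device__ T warpReduceSum(T val, unsigned mask) {
  // Butterfly (XOR) shuffle: after the loop, EVERY lane in the warp holds the
  // warp-wide sum.
  for (int offset = warpSize >> 1; offset > 0; offset >>= 1)
    val += __shfl_xor_sync(mask, val, offset);
  return val;
}

template <typename T>
__inline__ __device__ T blockReduceSum(T val, unsigned mask) {
  static __shared__ T shared[32];      // one partial sum per warp
  int lane = threadIdx.x & 0x1f;       // lane id within the warp
  int wid = threadIdx.x >> 5;          // warp id within the block

  val = warpReduceSum<T>(val, mask);   // every lane holds its warp's sum
  if (lane == 0) shared[wid] = val;    // lane 0 publishes the partial sum
  __syncthreads();

  int block_span = (blockDim.x + warpSize - 1) >> 5;  // number of warps
  // Because the guard uses threadIdx.x, only threads of the first warp reload
  // valid partial sums; every other warp reduces zeros in the second round.
  val = (threadIdx.x < block_span) ? shared[lane] : static_cast<T>(0.0f);
  val = warpReduceSum<T>(val, mask);   // only the first warp ends with the block sum
  return val;
}
```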
I suggest two solutions:
- Change the behavior of blockReduceSum to an all-reduce, which computes the summation and ensures every thread in the block holds a copy of it. Only one line needs to change:

  ```cpp
  val = (threadIdx.x < block_span) ? shared[lane] : static_cast<T>(0.0f);
  ```

  becomes

  ```cpp
  val = (lane < block_span) ? shared[lane] : static_cast<T>(0.0f);
  ```
  With this change, every warp gets a copy of the first-round (per-warp) partial sums, so all threads in the block hold the reduced value after the second warpReduce (see the sketch after this list).
- In blockReduceSum, add a condition so that only the first warp in the block performs the second reduction. Since the caller only reads the reduced value from the thread with threadIdx.x == 0 (SoftmaxKernelWithEltadd is the only caller), it is enough for the first warp to do the second reduction, i.e.:
  ```cpp
  if (threadIdx.x < warpSize) {
    val = (threadIdx.x < block_span) ? shared[lane] : static_cast<T>(0.0f);
    val = warpReduceSum<T>(val, mask);
  }
  ```
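For concreteness, this is what the first suggestion looks like when applied to the sketch above. `blockReduceSumAll` is a hypothetical name used here only to set the two versions side by side; the suggestion is to change `blockReduceSum` itself.

```cpp
// Sketch of suggestion 1: identical to the blockReduceSum sketch above except
// for the flagged line.
template <typename T>
__inline__ __device__ T blockReduceSumAll(T val, unsigned mask) {
  static __shared__ T shared[32];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  val = warpReduceSum<T>(val, mask);
  if (lane == 0) shared[wid] = val;
  __syncthreads();

  int block_span = (blockDim.x + warpSize - 1) >> 5;
  // Changed line: guard on `lane` instead of `threadIdx.x`, so every warp
  // reloads the per-warp partial sums, not just the first warp.
  val = (lane < block_span) ? shared[lane] : static_cast<T>(0.0f);
  val = warpReduceSum<T>(val, mask);   // now all threads hold the block-wide sum
  return val;
}
```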
In my experiments, the first solution gives higher performance, because the now-unnecessary shared-memory broadcast and __syncthreads after the call to blockReduceXXX can be deleted, as sketched below.
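To make the saving concrete, here is a hypothetical caller built on the two sketches above (it normalizes one row per block; it is not SoftmaxKernelWithEltadd, and `rowNormalize*`, `in`, `out`, `n` are made-up names). The shared-memory broadcast and the extra `__syncthreads` disappear with the all-reduce version.

```cpp
// Current pattern: the block sum is only guaranteed in the first warp, so it
// must be broadcast through shared memory, followed by an extra __syncthreads.
__global__ void rowNormalizeBefore(const float* in, float* out, int n) {
  __shared__ float row_sum;
  float v = (threadIdx.x < n) ? in[blockIdx.x * n + threadIdx.x] : 0.0f;
  float sum = blockReduceSum<float>(v, 0xffffffffu);
  if (threadIdx.x == 0) row_sum = sum;  // broadcast to the whole block
  __syncthreads();                      // the extra synchronization in question
  if (threadIdx.x < n) out[blockIdx.x * n + threadIdx.x] = v / row_sum;
}

// With the all-reduce version every thread already holds the sum, so the
// broadcast and the __syncthreads can be deleted.
__global__ void rowNormalizeAfter(const float* in, float* out, int n) {
  float v = (threadIdx.x < n) ? in[blockIdx.x * n + threadIdx.x] : 0.0f;
  float sum = blockReduceSumAll<float>(v, 0xffffffffu);
  if (threadIdx.x < n) out[blockIdx.x * n + threadIdx.x] = v / sum;
}
```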
BTW, the same issue exists in blockReduceMax too.