diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h
index 0e8d202a9aa384fe22dc612dc0c20e075b5aafc9..44c233be5750d4b48a63f3b274c5c6a5830c0482 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h
@@ -19,6 +19,11 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 #include "paddle/fluid/operators/math/blas.h"
+#ifdef PADDLE_WITH_CUDA
+#ifdef __NVCC__
+#include "cub/cub.cuh"
+#endif
+#endif
 
 namespace paddle {
 namespace operators {
@@ -121,6 +126,20 @@ elementwise_add_grad(const framework::ExecutionContext &ctx,
 
 #ifdef PADDLE_WITH_CUDA
 #ifdef __NVCC__
+template <typename T, int Size>
+struct alignas(sizeof(T) * Size) AlignedVector {
+  T val[Size];
+};
+
+template <typename T>
+inline int VectorizedSize(const T *pointer) {
+  uint64_t address = reinterpret_cast<uint64_t>(pointer);
+  constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value;  // NOLINT
+  if (address % vec4 == 0) {
+    return 4;
+  }
+  return 1;
+}
 template <typename T, int BLOCK_W, int BLOCK_H>
 __global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out,
                                 size_t width, size_t height) {
@@ -200,6 +219,45 @@ __global__ void FP16MatrixColReduce(
     if ((threadIdx.y == 0) && ((w) < width)) out[w] = sdata[0][threadIdx.x];
   }
 }
+
+template <typename T>
+__global__ void MatrixReduceLongWidth(const T *__restrict__ in, T *out,
+                                      size_t width, size_t height) {
+  int idx = threadIdx.x + blockIdx.x * blockDim.x;
+
+  for (; idx < width; idx += blockDim.x * gridDim.x) {
+    T sum = static_cast<T>(0);
+    for (int row = 0; row < height; row++) {
+      sum += in[idx + row * width];
+    }
+
+    out[idx] = sum;
+  }
+}
+
+template <typename T, int VEC_SIZE>
+__global__ void VecMatrixReduceLongWidth(const T *__restrict__ in, T *out,
+                                         size_t width, size_t height) {
+  using LoadT = AlignedVector<T, VEC_SIZE>;
+  int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  int w = idx * VEC_SIZE;
+  int width_stride = blockDim.x * gridDim.x * VEC_SIZE;
+  for (; w < width; w += width_stride) {
+    T zero = static_cast<T>(0);
+    T sum[VEC_SIZE] = {zero};
+    T tmp_vec[VEC_SIZE] = {zero};
+    LoadT *tmp_ptr = reinterpret_cast<LoadT *>(&tmp_vec);
+    for (int row = 0; row < height; row++) {
+      int offset = width * row + w;
+      *tmp_ptr = *reinterpret_cast<const LoadT *>(&in[offset]);
+      for (int v = 0; v < VEC_SIZE; v++) {
+        sum[v] += tmp_vec[v];
+      }
+    }
+
+    for (int v = 0; v < VEC_SIZE; v++) out[w + v] = sum[v];
+  }
+}
 #endif
 #endif
 bool static RunSpecialDims(const framework::DDim &dx_dims,
@@ -301,6 +359,21 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
             *dout, ctx.GetPlace(),
             ctx.template device_context<platform::DeviceContext>(), dy);
       }
+      // special optimization using cub
+      if (width == 1) {
+        int nums = height;
+        size_t temp_storage_bytes = 0;
+        auto err = cub::DeviceReduce::Sum(nullptr, temp_storage_bytes,
+                                          dout_data, out_data, nums, stream);
+        PADDLE_ENFORCE_CUDA_SUCCESS(err);
+        framework::Tensor tmp;
+        auto *temp_storage = tmp.mutable_data<uint8_t>(
+            framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
+            ctx.GetPlace());
+        err = cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes,
+                                     dout_data, out_data, nums, stream);
+        PADDLE_ENFORCE_CUDA_SUCCESS(err);
+      }
 
       constexpr int block_x = 32;
       constexpr int block_y = 32;
@@ -311,7 +384,8 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
       int max_blocks = std::max(max_physical_threads / (block_x * block_y), 1);
      int theory_block = (width + blocks.x - 1) / blocks.x;
       dim3 grids(std::min(theory_block, max_blocks));
-      if (std::is_same<T, paddle::platform::float16>::value) {
+      if (std::is_same<T, paddle::platform::float16>::value &&
+          (width / height) < 32) {
         const paddle::platform::float16 *ptr1 =
             reinterpret_cast<const paddle::platform::float16 *>(dout_data);
         paddle::platform::float16 *ptr2 =
@@ -325,8 +399,24 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
         }
         return;
       }
-      MatrixColReduce<T, block_x, block_y><<<grids, blocks, 0, stream>>>(
-          dout_data, out_data, width, height);
+
+      if (width / height < 32) {
+        MatrixColReduce<T, block_x, block_y><<<grids, blocks, 0, stream>>>(
+            dout_data, out_data, width, height);
+      } else {
+        size_t thread_nums = 1024;
+        size_t block_nums = (width + thread_nums - 1) / thread_nums;
+        int vec_size = VectorizedSize(dx_data);
+        if (vec_size == 4 && width % 4 == 0) {
+          block_nums = (width / vec_size + thread_nums - 1) / thread_nums;
+          VecMatrixReduceLongWidth<T,
+                                   4><<<block_nums, thread_nums, 0, stream>>>(
+              dout_data, out_data, width, height);
+        } else {
+          MatrixReduceLongWidth<T><<<block_nums, thread_nums, 0, stream>>>(
+              dout_data, out_data, width, height);
+        }
+      }
       return;
     }
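
Note: the width == 1 branch above relies on CUB's standard two-call pattern: the first cub::DeviceReduce::Sum call with a null workspace pointer only reports the required temporary-storage size, and the second call performs the actual reduction. The snippet below is a minimal standalone sketch of that pattern for reference only; it is not part of the patch, names such as d_in, d_out and n are illustrative, and the patch obtains its workspace from a framework::Tensor rather than a raw cudaMalloc.

#include <cub/cub.cuh>
#include <cstdio>
#include <vector>

int main() {
  const int n = 1 << 20;
  std::vector<float> h_in(n, 1.0f);

  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  // First call with a null workspace pointer: only computes temp_bytes.
  void *d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);
  cudaMalloc(&d_temp, temp_bytes);

  // Second call with the allocated workspace: performs the reduction.
  cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);

  float result = 0.0f;
  cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f (expected %d)\n", result, n);

  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_temp);
  return 0;
}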