diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h index db5c6eca6e5068648e51052be2916553bcf57328..0ef79667b8d66df8beaf512a95820e119816cbff 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h @@ -19,6 +19,7 @@ limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" #ifdef PADDLE_WITH_CUDA #ifdef __NVCC__ #include "cub/cub.cuh" @@ -176,6 +177,25 @@ __global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out, } } +template +__global__ void VecFP16MatrixColReduce(const __half2 *__restrict__ in, + __half2 *__restrict__ out, size_t width, + size_t height) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + int by = blockIdx.y; + __half2 zero = __half2half2(static_cast<__half>(0)); + const int cols = width / 2; + for (; idx < cols; idx += blockDim.x * gridDim.x) { + __half2 sum = zero; + for (int row = 0; row < SIZE; row++) { + int index = idx + (row + by * SIZE) * cols; + sum = __hadd2(sum, in[index]); + } + + atomicAdd(&(out[idx]), sum); + } +} + template __global__ void MatrixReduceLongWidth(const T *__restrict__ in, T *out, size_t width, size_t height) { @@ -198,7 +218,7 @@ __global__ void VecMatrixReduceLongWidth(const T *__restrict__ in, T *out, int idx = threadIdx.x + blockIdx.x * blockDim.x; int w = idx * VEC_SIZE; int width_stride = blockDim.x * gridDim.x * VEC_SIZE; - for (; w < width; w += width) { + for (; w < width; w += width_stride) { T zero = static_cast(0); T sum[VEC_SIZE] = {zero}; T tmp_vec[VEC_SIZE] = {zero}; @@ -341,6 +361,23 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel { int max_blocks = std::max(max_physical_threads / (block_x * block_y), 1); int theory_block = (width + blocks.x - 1) / blocks.x; dim3 grids(std::min(theory_block, max_blocks)); + if (std::is_same::value && width < 2048 && + width % 2 == 0 && height % 64 == 0) { + auto &dev_ctx = + ctx.template device_context(); + math::SetConstant functor; + if (dout->dims() == dx->dims()) + functor(dev_ctx, dy, static_cast(0)); + else + functor(dev_ctx, dx, static_cast(0)); + const __half2 *ptr1 = reinterpret_cast(dout_data); + __half2 *ptr2 = reinterpret_cast<__half2 *>(out_data); + const int threads = 128; + dim3 grid(1, (height + 64 - 1) / 64); + VecFP16MatrixColReduce<64><<>>(ptr1, ptr2, + width, height); + return; + } if (width / height < 32) { MatrixColReduce<<>>(