From eae318569f97b3fd01394112000a6ddfbd0f693f Mon Sep 17 00:00:00 2001 From: WangXi Date: Mon, 5 Jul 2021 18:12:49 +0800 Subject: [PATCH] Add fused elemwise gelu and optimize performance (#33480) --- .../elementwise/elementwise_op_function.h | 179 ++++++++++-------- .../fused/fused_elemwise_activation_op.cc | 2 +- .../fused/fused_elemwise_activation_op.h | 17 ++ paddle/fluid/operators/math/functors.h | 58 ++++++ .../test_fused_elemwise_activation_op.py | 23 +++ 5 files changed, 202 insertions(+), 77 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index dce9a54f39..cc291ae471 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -57,6 +57,10 @@ constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024; *mod = dividend_copy % divisor; \ } while (0) +#define DIVUP(x, y) (((x) + (y)-1) / (y)) + +#define ROUNDUP(x, y) (DIVUP((x), (y)) * (y)) + namespace paddle { namespace operators { @@ -2156,10 +2160,10 @@ template <<>>( @@ -2585,106 +2589,129 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel( const T *x, const T *y, const T *intermediate_out, const T *out, const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op, DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) { - int j = blockIdx.x; - int i = threadIdx.x; - int tid = threadIdx.x; - T val(0), inter_val(0); - int64_t tmp_out_idx, x_idx, y_idx; + __shared__ T sdata[BLOCK_Y][BLOCK_X]; + size_t idx = threadIdx.x + BLOCK_X * blockIdx.x; + size_t width_stride = gridDim.x * BLOCK_X; + + size_t full_w = ROUNDUP(w, BLOCK_X); + T zero = static_cast(0); - do { - int offset = i * w + j; + for (size_t j = idx; j < full_w; j += width_stride) { + T val(0), inter_val(0); + if (j < w) { + for (size_t i = threadIdx.y; i < h; i += BLOCK_Y) { + size_t offset = i * w + j; - tmp_out_idx = BcastY ? j : offset; - y_idx = BcastY ? j : offset; - x_idx = BcastY ? offset : j; - T x_val = (x == nullptr) ? zero : x[x_idx]; - T y_val = (y == nullptr) ? zero : y[y_idx]; + size_t tmp_out_idx = BcastY ? j : offset; + size_t y_idx = BcastY ? j : offset; + size_t x_idx = BcastY ? offset : j; + T x_val = (x == nullptr) ? zero : x[x_idx]; + T y_val = (y == nullptr) ? zero : y[y_idx]; - if (SameShapeOfIntermediateOutAndOut) { - tmp_out_idx = offset; - } + if (SameShapeOfIntermediateOutAndOut) { + tmp_out_idx = offset; + } - if (dx != nullptr) { - T tmp = UseIntermediateOut + if (dx != nullptr) { + T tmp = + UseIntermediateOut ? dx_op.UseIntermediateOut(x_val, y_val, intermediate_out[tmp_out_idx], out[offset], dout[offset]) : dx_op.Recompute(x_val, y_val, out[offset], dout[offset]); - if (BcastY) { - dx[x_idx] = tmp; - } else { - val += tmp; - } - } - if (dy != nullptr) { - T tmp = UseIntermediateOut + if (BcastY) { + dx[x_idx] = tmp; + } else { + val += tmp; + } + } + if (dy != nullptr) { + T tmp = + UseIntermediateOut ? dy_op.UseIntermediateOut(x_val, y_val, intermediate_out[tmp_out_idx], out[offset], dout[offset]) : dy_op.Recompute(x_val, y_val, out[offset], dout[offset]); - if (BcastY) { - val += tmp; - } else { - dy[y_idx] = tmp; - } - } - if (d_intermediate != nullptr) { - T tmp = UseIntermediateOut - ? dintermediate_op.UseIntermediateOut( - y[y_idx], intermediate_out[tmp_out_idx], out[offset], - dout[offset]) - : dintermediate_op.Recompute(x_val, y_val, out[offset], - dout[offset]); - if (SameShapeOfIntermediateOutAndOut) { - d_intermediate[tmp_out_idx] = tmp; - } else { - inter_val += tmp; + if (BcastY) { + val += tmp; + } else { + dy[y_idx] = tmp; + } + } + if (d_intermediate != nullptr) { + T tmp = UseIntermediateOut + ? dintermediate_op.UseIntermediateOut( + y[y_idx], intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dintermediate_op.Recompute(x_val, y_val, out[offset], + dout[offset]); + if (SameShapeOfIntermediateOutAndOut) { + d_intermediate[tmp_out_idx] = tmp; + } else { + inter_val += tmp; + } + } } } - i += ELEMWISE_MAX_BLOCK_DIM; - } while (i < h); + // transpose, for ReduceSum with wrap + sdata[threadIdx.y][threadIdx.x] = val; + __syncthreads(); + val = sdata[threadIdx.x][threadIdx.y]; +#pragma unroll + for (int i = BLOCK_X >> 1; i > 0; i >>= 1) { + // reduce sum with wrap + val += platform::CudaShuffleXorSync(0xFFFFFFFF, val, i); + } - h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; - if (BcastY) { - if (dy) { - val = paddle::platform::reduceSum(val, tid, h); - if (threadIdx.x == 0) { - dy[j] = val; + size_t idx_j = j + threadIdx.y; + if (BcastY) { + if (dy) { + if (threadIdx.x == 0 && (idx_j < w)) dy[idx_j] = val; } - } - } else { - if (dx) { - val = paddle::platform::reduceSum(val, tid, h); - if (threadIdx.x == 0) { - dx[j] = val; + } else { + if (dx) { + if (threadIdx.x == 0 && (idx_j < w)) dx[idx_j] = val; } } - } - if (!SameShapeOfIntermediateOutAndOut) { - if (d_intermediate) { - inter_val = paddle::platform::reduceSum(inter_val, tid, h); - if (threadIdx.x == 0) { - d_intermediate[j] = inter_val; + + if (!SameShapeOfIntermediateOutAndOut) { + if (d_intermediate) { + sdata[threadIdx.y][threadIdx.x] = inter_val; + __syncthreads(); + inter_val = sdata[threadIdx.x][threadIdx.y]; +#pragma unroll + for (int i = BLOCK_X >> 1; i > 0; i >>= 1) { + // reduce sum with wrap + inter_val += platform::CudaShuffleXorSync(0xFFFFFFFF, inter_val, i); + } + if (threadIdx.x == 0 && (idx_j < w)) d_intermediate[idx_j] = inter_val; } } - } + } // end for } template static void FusedElemwiseAndActGradBroadcast1CUDA( - gpuStream_t stream, const T *x, const T *y, const T *intermediate_out, - const T *out, const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op, - DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) { - int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h); - int gird_size = w; + const framework::ExecutionContext &ctx, const T *x, const T *y, + const T *intermediate_out, const T *out, const T *dout, int h, int w, + DX_OP dx_op, DY_OP dy_op, DIntermediate_OP dintermediate_op, T *dx, T *dy, + T *d_intermediate) { + gpuStream_t stream = ctx.cuda_device_context().stream(); + + dim3 blocks(BLOCK_X, BLOCK_Y); + int max_gpu_threads = ctx.cuda_device_context().GetMaxPhysicalThreadCount(); + int max_blocks = std::max(max_gpu_threads / (BLOCK_X * BLOCK_Y), 1); + int theory_block = (w + BLOCK_X - 1) / BLOCK_X; + dim3 grids(std::min(theory_block, max_blocks)); + FusedElemwiseAndActGradBroadcast1CUDAKernel< T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut, BcastY, - SameShapeOfIntermediateOutAndOut><<>>( + SameShapeOfIntermediateOutAndOut><<>>( x, y, intermediate_out, out, dout, h, w, dx_op, dy_op, dintermediate_op, dx, dy, d_intermediate); } @@ -2836,7 +2863,7 @@ void FusedElemwiseAndActGradComputeWithBroadcast( FusedElemwiseAndActGradBroadcast1CUDA( - ctx.template device_context().stream(), x_data, y_data, + ctx, x_data, y_data, intermediate_out == nullptr ? nullptr : intermediate_out->data(), out->data(), dout->data(), h, w, dx_op, dy_op, dintermediate_op, dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), diff --git a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc index 4ff66d0d2b..d51e0de380 100644 --- a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc +++ b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc @@ -69,7 +69,7 @@ static bool IsSupportedCompound(const std::vector &functors) { functors.size(), 2)); static std::unordered_set unary_fun = {"scale", "relu", "tanh", - "sigmoid"}; + "sigmoid", "gelu"}; static std::unordered_set binary_fun = {"elementwise_add", "elementwise_mul"}; diff --git a/paddle/fluid/operators/fused/fused_elemwise_activation_op.h b/paddle/fluid/operators/fused/fused_elemwise_activation_op.h index c61b9a9e48..b7dd89a8a2 100644 --- a/paddle/fluid/operators/fused/fused_elemwise_activation_op.h +++ b/paddle/fluid/operators/fused/fused_elemwise_activation_op.h @@ -275,6 +275,13 @@ static void RunFunctors(const framework::ExecutionContext &ctx, paddle::operators::math::SigmoidFunctor>( ctx, paddle::operators::math::MulFunctor(), paddle::operators::math::SigmoidFunctor(), in_x, in_y, outputs); + } else if (funcs_str == "gelu,elementwise_add") { + // Z = Unary(Binary(X, Y)) + RunUnaryCompoundFunctors, + paddle::operators::math::AddFunctor>( + ctx, paddle::operators::math::GeluFunctor(), + paddle::operators::math::AddFunctor(), in_x, in_y, outputs); } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s has not been implemented.", funcs_str)); @@ -374,6 +381,16 @@ static void RunGradFunctors( paddle::operators::math::SigmoidFunctor(), paddle::operators::math::SigmoidGradFunctor(), in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out); + } else if (funcs_str == "gelu_grad,elementwise_add_grad") { + // The backward of Z = Unary(Binary(X, Y)) + RunUnaryCompoundGradFunctors< + DeviceContext, T, paddle::operators::math::GeluGradFunctor, + paddle::operators::math::AddFunctor, + paddle::operators::math::AddGradFunctor, InPlace>( + ctx, paddle::operators::math::GeluGradFunctor(), + paddle::operators::math::AddFunctor(), + paddle::operators::math::AddGradFunctor(), in_x, in_y, in_out, + in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out); } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s has not been implemented.", funcs_str)); diff --git a/paddle/fluid/operators/math/functors.h b/paddle/fluid/operators/math/functors.h index bf64d7e8ce..2eb6d00935 100644 --- a/paddle/fluid/operators/math/functors.h +++ b/paddle/fluid/operators/math/functors.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/math.h" namespace paddle { @@ -130,6 +131,63 @@ struct SigmoidGradFunctor { } }; +template +struct GeluFunctor { + using MT = typename details::MPTypeTrait::Type; + inline HOSTDEVICE T operator()(T x) { + // this function is tanh approximation of gelu + // actual gelu is: + // x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) + MT mx = static_cast(x); + MT out = mx * static_cast(0.5) * + (static_cast(1.0) + + tanh(static_cast(0.79788456) * mx * + (static_cast(1) + static_cast(0.044715) * mx * mx))); + return static_cast(out); + } +}; + +template +struct GeluGradFunctor { + using MT = typename details::MPTypeTrait::Type; + inline HOSTDEVICE T UseX(T x) { + MT mx = static_cast(x); + MT tanh_out = + tanh(static_cast(0.79788456) * mx * + (static_cast(1) + static_cast(0.044715) * mx * mx)); + MT ans = static_cast(0.5) * mx * + ((static_cast(1) - tanh_out * tanh_out) * + (static_cast(0.79788456) + + static_cast(0.1070322243) * mx * mx)) + + static_cast(0.5) * (static_cast(1) + tanh_out); + return static_cast(ans); + } + inline HOSTDEVICE T UseOut(T x) { + MT mx = static_cast(x); + MT tanh_out = + tanh(static_cast(0.79788456) * mx * + (static_cast(1) + static_cast(0.044715) * mx * mx)); + MT ans = static_cast(0.5) * mx * + ((static_cast(1) - tanh_out * tanh_out) * + (static_cast(0.79788456) + + static_cast(0.1070322243) * mx * mx)) + + static_cast(0.5) * (static_cast(1) + tanh_out); + return static_cast(ans); + } + inline HOSTDEVICE T UseXAndOut(T x, T out) { + MT mx = static_cast(x); + MT tanh_out = + tanh(static_cast(0.79788456) * mx * + (static_cast(1) + static_cast(0.044715) * mx * mx)); + MT ans = static_cast(0.5) * mx * + ((static_cast(1) - tanh_out * tanh_out) * + (static_cast(0.79788456) + + static_cast(0.1070322243) * mx * mx)) + + static_cast(0.5) * (static_cast(1) + tanh_out); + return static_cast(ans); + } +}; + } // namespace math } // namespace operators } // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py b/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py index 80bb14adf7..ba9e05470e 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py @@ -305,6 +305,15 @@ def mul_scale_func(x, y, x_bcast, y_bcast, scale, mode=0): return y, x, x * scale, y_bcast * (x_bcast * scale) +def gelu_add_func(x, y, x_bcast, y_bcast, mode=0): + im = x_bcast + y_bcast + out = im * 0.5 * (1.0 + np.tanh(0.79788456 * im * (1 + 0.044715 * im * im))) + if mode == 0: + return x, y, im, out + else: + return y, x, im, out + + scale = 0.1 scale_add_func = partial(scale_add_func, scale=scale) add_scale_func = partial(add_scale_func, scale=scale) @@ -316,6 +325,7 @@ for mode in {0, 1}: mul_scale_func = partial(mul_scale_func, mode=mode) relu_add_func = partial(relu_add_func, mode=mode) add_relu_func = partial(add_relu_func, mode=mode) + gelu_add_func = partial(gelu_add_func, mode=mode) for save_intermediate_out in {True, False}: suffix = ("_save_intermediate_out" if save_intermediate_out else "") \ @@ -343,6 +353,11 @@ for mode in {0, 1}: 'functor_list': ["elementwise_mul", "scale"], 'save_intermediate_out': save_intermediate_out, }) + create_test_class('gelu_add' + suffix, gelu_add_func, { + 'functor_list': ["gelu", "elementwise_add"], + 'save_intermediate_out': save_intermediate_out, + }) + if core.is_compiled_with_cuda(): create_test_class( 'scale_add_fp16' + suffix, @@ -388,6 +403,14 @@ for mode in {0, 1}: }, dtype=np.float16, grad_chek=False) + create_test_class( + 'gelu_add_fp16' + suffix, + gelu_add_func, { + 'functor_list': ["gelu", "elementwise_add"], + 'save_intermediate_out': save_intermediate_out, + }, + dtype=np.float16, + grad_chek=False) if __name__ == '__main__': import paddle -- GitLab