Unverified commit eae31856, authored by: WangXi, committed by: GitHub

Add fused elemwise gelu and optimize performance (#33480)

Parent: fa5ddfd9
@@ -57,6 +57,10 @@ constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
     *mod = dividend_copy % divisor;           \
   } while (0)
 
+#define DIVUP(x, y) (((x) + (y)-1) / (y))
+#define ROUNDUP(x, y) (DIVUP((x), (y)) * (y))
+
 namespace paddle {
 namespace operators {
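The two new helper macros are ordinary integer ceiling-division and round-up-to-multiple helpers. A minimal standalone check, written here purely for illustration and not part of the patch:

#include <cassert>

#define DIVUP(x, y) (((x) + (y)-1) / (y))
#define ROUNDUP(x, y) (DIVUP((x), (y)) * (y))

int main() {
  // 33 columns need two 32-wide tiles ...
  assert(DIVUP(33, 32) == 2);
  // ... so the padded width the kernel sweeps over is 64.
  assert(ROUNDUP(33, 32) == 64);
  return 0;
}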
@@ -2156,10 +2160,10 @@ template <typename T, typename CompoundFunctor, bool BcastY,
 static __global__ void FusedElemwiseAndActBroadcast1CUDAKernel(
     const T *x, const T *y, int h, int w, CompoundFunctor compound_functor,
     T *out, T *intermediate_out) {
-  int j = blockIdx.x;
-  int i = threadIdx.x;
+  int i = blockIdx.x;
+  int j = threadIdx.x;
 
-  while (i < h) {
+  while (j < w) {
     int offset = i * w + j;
 
     T y_val = BcastY ? y[j] : y[offset];
@@ -2185,7 +2189,7 @@ static __global__ void FusedElemwiseAndActBroadcast1CUDAKernel(
       out[offset] = compound_functor.GetOut(x_val, y_val);
     }
 
-    i += ELEMWISE_MAX_BLOCK_DIM;
+    j += ELEMWISE_MAX_BLOCK_DIM;
   }
 }
@@ -2196,8 +2200,8 @@ static void FusedElemwiseAndActBroadcast1CUDA(gpuStream_t stream, const T *x,
                                               CompoundFunctor compound_functor,
                                               int h, int w, T *out,
                                               T *intermediate_out) {
-  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
-  int gird_size = w;
+  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, w);
+  int gird_size = h;
   FusedElemwiseAndActBroadcast1CUDAKernel<
       T, CompoundFunctor, BcastY, KeepIntermediateOut,
       SameShapeOfIntermediateOutAndOut><<<gird_size, block_size, 0, stream>>>(
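The two hunks above swap the forward broadcast kernel's thread mapping: previously blockIdx.x walked the w columns and threadIdx.x strided down the h rows, so neighbouring threads touched addresses w elements apart; now each block owns one row and neighbouring threads read neighbouring columns, letting warps issue coalesced loads. A minimal sketch of the new access pattern (illustrative only; a plain add stands in for the compound functor):

// One block per row, threads stride across the row. Lanes threadIdx.x and
// threadIdx.x + 1 touch adjacent addresses, so global accesses coalesce.
template <typename T>
__global__ void RowBroadcastAddSketch(const T *x, const T *y, T *out, int h,
                                      int w) {
  int i = blockIdx.x;   // row index
  int j = threadIdx.x;  // column index, strided by the block size
  while (j < w) {
    int offset = i * w + j;
    out[offset] = x[offset] + y[j];  // y has shape [w], broadcast over rows
    j += blockDim.x;
  }
}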
@@ -2585,106 +2589,129 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
     const T *x, const T *y, const T *intermediate_out, const T *out,
     const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
     DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
-  int j = blockIdx.x;
-  int i = threadIdx.x;
-  int tid = threadIdx.x;
-  T val(0), inter_val(0);
-  int64_t tmp_out_idx, x_idx, y_idx;
-  T zero = static_cast<T>(0);
-
-  do {
-    int offset = i * w + j;
-
-    tmp_out_idx = BcastY ? j : offset;
-    y_idx = BcastY ? j : offset;
-    x_idx = BcastY ? offset : j;
-    T x_val = (x == nullptr) ? zero : x[x_idx];
-    T y_val = (y == nullptr) ? zero : y[y_idx];
-
-    if (SameShapeOfIntermediateOutAndOut) {
-      tmp_out_idx = offset;
-    }
-
-    if (dx != nullptr) {
-      T tmp = UseIntermediateOut
-                  ? dx_op.UseIntermediateOut(x_val, y_val,
-                                             intermediate_out[tmp_out_idx],
-                                             out[offset], dout[offset])
-                  : dx_op.Recompute(x_val, y_val, out[offset], dout[offset]);
-
-      if (BcastY) {
-        dx[x_idx] = tmp;
-      } else {
-        val += tmp;
-      }
-    }
-    if (dy != nullptr) {
-      T tmp = UseIntermediateOut
-                  ? dy_op.UseIntermediateOut(x_val, y_val,
-                                             intermediate_out[tmp_out_idx],
-                                             out[offset], dout[offset])
-                  : dy_op.Recompute(x_val, y_val, out[offset], dout[offset]);
-      if (BcastY) {
-        val += tmp;
-      } else {
-        dy[y_idx] = tmp;
-      }
-    }
-    if (d_intermediate != nullptr) {
-      T tmp = UseIntermediateOut
-                  ? dintermediate_op.UseIntermediateOut(
-                        y[y_idx], intermediate_out[tmp_out_idx], out[offset],
-                        dout[offset])
-                  : dintermediate_op.Recompute(x_val, y_val, out[offset],
-                                               dout[offset]);
-      if (SameShapeOfIntermediateOutAndOut) {
-        d_intermediate[tmp_out_idx] = tmp;
-      } else {
-        inter_val += tmp;
-      }
-    }
-
-    i += ELEMWISE_MAX_BLOCK_DIM;
-  } while (i < h);
-
-  h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
-  if (BcastY) {
-    if (dy) {
-      val = paddle::platform::reduceSum(val, tid, h);
-      if (threadIdx.x == 0) {
-        dy[j] = val;
-      }
-    }
-  } else {
-    if (dx) {
-      val = paddle::platform::reduceSum(val, tid, h);
-      if (threadIdx.x == 0) {
-        dx[j] = val;
-      }
-    }
-  }
-  if (!SameShapeOfIntermediateOutAndOut) {
-    if (d_intermediate) {
-      inter_val = paddle::platform::reduceSum(inter_val, tid, h);
-      if (threadIdx.x == 0) {
-        d_intermediate[j] = inter_val;
-      }
-    }
-  }
+  __shared__ T sdata[BLOCK_Y][BLOCK_X];
+  size_t idx = threadIdx.x + BLOCK_X * blockIdx.x;
+  size_t width_stride = gridDim.x * BLOCK_X;
+
+  size_t full_w = ROUNDUP(w, BLOCK_X);
+
+  T zero = static_cast<T>(0);
+
+  for (size_t j = idx; j < full_w; j += width_stride) {
+    T val(0), inter_val(0);
+    if (j < w) {
+      for (size_t i = threadIdx.y; i < h; i += BLOCK_Y) {
+        size_t offset = i * w + j;
+
+        size_t tmp_out_idx = BcastY ? j : offset;
+        size_t y_idx = BcastY ? j : offset;
+        size_t x_idx = BcastY ? offset : j;
+        T x_val = (x == nullptr) ? zero : x[x_idx];
+        T y_val = (y == nullptr) ? zero : y[y_idx];
+
+        if (SameShapeOfIntermediateOutAndOut) {
+          tmp_out_idx = offset;
+        }
+
+        if (dx != nullptr) {
+          T tmp =
+              UseIntermediateOut
+                  ? dx_op.UseIntermediateOut(x_val, y_val,
+                                             intermediate_out[tmp_out_idx],
+                                             out[offset], dout[offset])
+                  : dx_op.Recompute(x_val, y_val, out[offset], dout[offset]);
+
+          if (BcastY) {
+            dx[x_idx] = tmp;
+          } else {
+            val += tmp;
+          }
+        }
+        if (dy != nullptr) {
+          T tmp =
+              UseIntermediateOut
+                  ? dy_op.UseIntermediateOut(x_val, y_val,
+                                             intermediate_out[tmp_out_idx],
+                                             out[offset], dout[offset])
+                  : dy_op.Recompute(x_val, y_val, out[offset], dout[offset]);
+          if (BcastY) {
+            val += tmp;
+          } else {
+            dy[y_idx] = tmp;
+          }
+        }
+        if (d_intermediate != nullptr) {
+          T tmp = UseIntermediateOut
+                      ? dintermediate_op.UseIntermediateOut(
+                            y[y_idx], intermediate_out[tmp_out_idx],
+                            out[offset], dout[offset])
+                      : dintermediate_op.Recompute(x_val, y_val, out[offset],
+                                                   dout[offset]);
+          if (SameShapeOfIntermediateOutAndOut) {
+            d_intermediate[tmp_out_idx] = tmp;
+          } else {
+            inter_val += tmp;
+          }
+        }
+      }
+    }
+
+    // transpose, for ReduceSum with warp
+    sdata[threadIdx.y][threadIdx.x] = val;
+    __syncthreads();
+    val = sdata[threadIdx.x][threadIdx.y];
+#pragma unroll
+    for (int i = BLOCK_X >> 1; i > 0; i >>= 1) {
+      // reduce sum with warp
+      val += platform::CudaShuffleXorSync(0xFFFFFFFF, val, i);
+    }
+
+    size_t idx_j = j + threadIdx.y;
+    if (BcastY) {
+      if (dy) {
+        if (threadIdx.x == 0 && (idx_j < w)) dy[idx_j] = val;
+      }
+    } else {
+      if (dx) {
+        if (threadIdx.x == 0 && (idx_j < w)) dx[idx_j] = val;
+      }
+    }
+
+    if (!SameShapeOfIntermediateOutAndOut) {
+      if (d_intermediate) {
+        sdata[threadIdx.y][threadIdx.x] = inter_val;
+        __syncthreads();
+        inter_val = sdata[threadIdx.x][threadIdx.y];
+#pragma unroll
+        for (int i = BLOCK_X >> 1; i > 0; i >>= 1) {
+          // reduce sum with warp
+          inter_val += platform::CudaShuffleXorSync(0xFFFFFFFF, inter_val, i);
+        }
+        if (threadIdx.x == 0 && (idx_j < w)) d_intermediate[idx_j] = inter_val;
+      }
+    }
+  }  // end for
 }
 
 template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
           bool UseIntermediateOut, bool BcastY,
           bool SameShapeOfIntermediateOutAndOut>
 static void FusedElemwiseAndActGradBroadcast1CUDA(
-    gpuStream_t stream, const T *x, const T *y, const T *intermediate_out,
-    const T *out, const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
-    DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
-  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
-  int gird_size = w;
+    const framework::ExecutionContext &ctx, const T *x, const T *y,
+    const T *intermediate_out, const T *out, const T *dout, int h, int w,
+    DX_OP dx_op, DY_OP dy_op, DIntermediate_OP dintermediate_op, T *dx, T *dy,
+    T *d_intermediate) {
+  gpuStream_t stream = ctx.cuda_device_context().stream();
+
+  dim3 blocks(BLOCK_X, BLOCK_Y);
+  int max_gpu_threads = ctx.cuda_device_context().GetMaxPhysicalThreadCount();
+  int max_blocks = std::max(max_gpu_threads / (BLOCK_X * BLOCK_Y), 1);
+  int theory_block = (w + BLOCK_X - 1) / BLOCK_X;
+  dim3 grids(std::min(theory_block, max_blocks));
+
   FusedElemwiseAndActGradBroadcast1CUDAKernel<
       T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut, BcastY,
-      SameShapeOfIntermediateOutAndOut><<<gird_size, block_size, 0, stream>>>(
+      SameShapeOfIntermediateOutAndOut><<<grids, blocks, 0, stream>>>(
       x, y, intermediate_out, out, dout, h, w, dx_op, dy_op, dintermediate_op,
       dx, dy, d_intermediate);
 }
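The rewritten backward kernel is the core of the performance change: instead of one block per column with paddle::platform::reduceSum over up to 1024 rows, it now uses a BLOCK_X x BLOCK_Y tile in which each thread accumulates a partial column sum with coalesced row reads, the tile of partials is transposed through shared memory so that one column's partials land in one warp, and the warp finishes the sum with XOR shuffles. A self-contained sketch of that reduction pattern, assuming a 32x32 tile and calling __shfl_xor_sync directly instead of Paddle's platform::CudaShuffleXorSync wrapper:

constexpr int BLOCK_X = 32;
constexpr int BLOCK_Y = 32;

// Sums each column of an h x w row-major matrix. Threads accumulate partial
// sums down a column, the tile of partials is transposed through shared
// memory, and each warp finishes one column with XOR shuffles.
__global__ void ColumnSumSketch(const float *in, float *col_sum, int h, int w) {
  __shared__ float sdata[BLOCK_Y][BLOCK_X + 1];  // +1 padding avoids bank conflicts
  int j = blockIdx.x * BLOCK_X + threadIdx.x;    // column handled by this lane

  float val = 0.0f;
  if (j < w) {
    for (int i = threadIdx.y; i < h; i += BLOCK_Y) {
      val += in[i * w + j];  // coalesced: neighbouring lanes read neighbours
    }
  }

  // Transpose so all partials of one column lie in one warp.
  sdata[threadIdx.y][threadIdx.x] = val;
  __syncthreads();
  val = sdata[threadIdx.x][threadIdx.y];

#pragma unroll
  for (int offset = BLOCK_X / 2; offset > 0; offset >>= 1) {
    val += __shfl_xor_sync(0xFFFFFFFF, val, offset);  // warp-level reduce sum
  }

  int out_j = blockIdx.x * BLOCK_X + threadIdx.y;  // column finished by warp y
  if (threadIdx.x == 0 && out_j < w) col_sum[out_j] = val;
}

A possible launch for this sketch is ColumnSumSketch<<<(w + 31) / 32, dim3(32, 32)>>>(in, col_sum, h, w). The launcher above additionally caps the grid by the device's physically resident thread count; with illustrative numbers (not from the patch), BLOCK_X = BLOCK_Y = 32 and a GPU reporting 163840 resident threads give max_blocks = 163840 / 1024 = 160, and for w = 12288 the ideal tile count theory_block = 384 is clamped to 160 blocks, so each block sweeps several column tiles via the width_stride loop.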
@@ -2836,7 +2863,7 @@ void FusedElemwiseAndActGradComputeWithBroadcast(
     FusedElemwiseAndActGradBroadcast1CUDA<T, DX_OP, DY_OP, DIntermediate_OP,
                                           UseIntermediateOut, BcastY,
                                           SameShapeOfIntermediateOutAndOut>(
-        ctx.template device_context<DeviceContext>().stream(), x_data, y_data,
+        ctx, x_data, y_data,
         intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
         out->data<T>(), dout->data<T>(), h, w, dx_op, dy_op, dintermediate_op,
         dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
...
@@ -69,7 +69,7 @@ static bool IsSupportedCompound(const std::vector<std::string> &functors) {
                         functors.size(), 2));
   static std::unordered_set<std::string> unary_fun = {"scale", "relu", "tanh",
-                                                      "sigmoid"};
+                                                      "sigmoid", "gelu"};
   static std::unordered_set<std::string> binary_fun = {"elementwise_add",
                                                        "elementwise_mul"};
...
@@ -275,6 +275,13 @@ static void RunFunctors(const framework::ExecutionContext &ctx,
                              paddle::operators::math::SigmoidFunctor<T>>(
         ctx, paddle::operators::math::MulFunctor<T>(),
         paddle::operators::math::SigmoidFunctor<T>(), in_x, in_y, outputs);
+  } else if (funcs_str == "gelu,elementwise_add") {
+    // Z = Unary(Binary(X, Y))
+    RunUnaryCompoundFunctors<DeviceContext, T,
+                             paddle::operators::math::GeluFunctor<T>,
+                             paddle::operators::math::AddFunctor<T>>(
+        ctx, paddle::operators::math::GeluFunctor<T>(),
+        paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "%s has not been implemented.", funcs_str));
@@ -374,6 +381,16 @@ static void RunGradFunctors(
         paddle::operators::math::SigmoidFunctor<T>(),
         paddle::operators::math::SigmoidGradFunctor<T>(), in_x, in_y, in_out,
         in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
+  } else if (funcs_str == "gelu_grad,elementwise_add_grad") {
+    // The backward of Z = Unary(Binary(X, Y))
+    RunUnaryCompoundGradFunctors<
+        DeviceContext, T, paddle::operators::math::GeluGradFunctor<T>,
+        paddle::operators::math::AddFunctor<T>,
+        paddle::operators::math::AddGradFunctor<T>, InPlace>(
+        ctx, paddle::operators::math::GeluGradFunctor<T>(),
+        paddle::operators::math::AddFunctor<T>(),
+        paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
+        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "%s has not been implemented.", funcs_str));
...
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/math.h"
 
 namespace paddle {
@@ -130,6 +131,63 @@ struct SigmoidGradFunctor {
   }
 };
+
+template <typename T>
+struct GeluFunctor {
+  using MT = typename details::MPTypeTrait<T>::Type;
+  inline HOSTDEVICE T operator()(T x) {
+    // this function is tanh approximation of gelu
+    // actual gelu is:
+    // x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
+    MT mx = static_cast<MT>(x);
+    MT out = mx * static_cast<MT>(0.5) *
+             (static_cast<MT>(1.0) +
+              tanh(static_cast<MT>(0.79788456) * mx *
+                   (static_cast<MT>(1) + static_cast<MT>(0.044715) * mx * mx)));
+    return static_cast<T>(out);
+  }
+};
+
+template <typename T>
+struct GeluGradFunctor {
+  using MT = typename details::MPTypeTrait<T>::Type;
+  inline HOSTDEVICE T UseX(T x) {
+    MT mx = static_cast<MT>(x);
+    MT tanh_out =
+        tanh(static_cast<MT>(0.79788456) * mx *
+             (static_cast<MT>(1) + static_cast<MT>(0.044715) * mx * mx));
+    MT ans = static_cast<MT>(0.5) * mx *
+                 ((static_cast<MT>(1) - tanh_out * tanh_out) *
+                  (static_cast<MT>(0.79788456) +
+                   static_cast<MT>(0.1070322243) * mx * mx)) +
+             static_cast<MT>(0.5) * (static_cast<MT>(1) + tanh_out);
+    return static_cast<T>(ans);
+  }
+
+  inline HOSTDEVICE T UseOut(T x) {
+    MT mx = static_cast<MT>(x);
+    MT tanh_out =
+        tanh(static_cast<MT>(0.79788456) * mx *
+             (static_cast<MT>(1) + static_cast<MT>(0.044715) * mx * mx));
+    MT ans = static_cast<MT>(0.5) * mx *
+                 ((static_cast<MT>(1) - tanh_out * tanh_out) *
+                  (static_cast<MT>(0.79788456) +
+                   static_cast<MT>(0.1070322243) * mx * mx)) +
+             static_cast<MT>(0.5) * (static_cast<MT>(1) + tanh_out);
+    return static_cast<T>(ans);
+  }
+
+  inline HOSTDEVICE T UseXAndOut(T x, T out) {
+    MT mx = static_cast<MT>(x);
+    MT tanh_out =
+        tanh(static_cast<MT>(0.79788456) * mx *
+             (static_cast<MT>(1) + static_cast<MT>(0.044715) * mx * mx));
+    MT ans = static_cast<MT>(0.5) * mx *
+                 ((static_cast<MT>(1) - tanh_out * tanh_out) *
+                  (static_cast<MT>(0.79788456) +
+                   static_cast<MT>(0.1070322243) * mx * mx)) +
+             static_cast<MT>(0.5) * (static_cast<MT>(1) + tanh_out);
+    return static_cast<T>(ans);
+  }
+};
 
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
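GeluFunctor implements the usual tanh approximation of GELU (0.79788456 ≈ sqrt(2/π), 0.044715 is the standard cubic coefficient), evaluated in MT, the wider type selected by MPTypeTrait for fp16 inputs:

\mathrm{GELU}(x) \approx \tfrac{1}{2}\,x\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right), \qquad \text{exact: } \mathrm{GELU}(x) = \tfrac{1}{2}\,x\left(1 + \mathrm{erf}\!\left(x/\sqrt{2}\right)\right)

All three entry points of GeluGradFunctor (UseX, UseOut, UseXAndOut) recompute the same expression from x, since the derivative of the tanh approximation is written most naturally in terms of the input; with t = tanh(sqrt(2/π)(x + 0.044715 x^3)):

\frac{d}{dx}\,\mathrm{GELU}(x) \approx \tfrac{1}{2}\,(1 + t) + \tfrac{1}{2}\,x\,(1 - t^{2})\left(0.79788456 + 0.1070322243\,x^{2}\right), \qquad 0.1070322243 = 3 \cdot 0.044715 \cdot \sqrt{2/\pi}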
@@ -305,6 +305,15 @@ def mul_scale_func(x, y, x_bcast, y_bcast, scale, mode=0):
         return y, x, x * scale, y_bcast * (x_bcast * scale)
 
+
+def gelu_add_func(x, y, x_bcast, y_bcast, mode=0):
+    im = x_bcast + y_bcast
+    out = im * 0.5 * (1.0 + np.tanh(0.79788456 * im * (1 + 0.044715 * im * im)))
+    if mode == 0:
+        return x, y, im, out
+    else:
+        return y, x, im, out
+
 
 scale = 0.1
 scale_add_func = partial(scale_add_func, scale=scale)
 add_scale_func = partial(add_scale_func, scale=scale)
@@ -316,6 +325,7 @@ for mode in {0, 1}:
     mul_scale_func = partial(mul_scale_func, mode=mode)
     relu_add_func = partial(relu_add_func, mode=mode)
     add_relu_func = partial(add_relu_func, mode=mode)
+    gelu_add_func = partial(gelu_add_func, mode=mode)
 
     for save_intermediate_out in {True, False}:
         suffix = ("_save_intermediate_out" if save_intermediate_out else "") \
@@ -343,6 +353,11 @@ for mode in {0, 1}:
             'functor_list': ["elementwise_mul", "scale"],
             'save_intermediate_out': save_intermediate_out,
         })
+        create_test_class('gelu_add' + suffix, gelu_add_func, {
+            'functor_list': ["gelu", "elementwise_add"],
+            'save_intermediate_out': save_intermediate_out,
+        })
+
         if core.is_compiled_with_cuda():
             create_test_class(
                 'scale_add_fp16' + suffix,
@@ -388,6 +403,14 @@ for mode in {0, 1}:
                 },
                 dtype=np.float16,
                 grad_chek=False)
+            create_test_class(
+                'gelu_add_fp16' + suffix,
+                gelu_add_func, {
+                    'functor_list': ["gelu", "elementwise_add"],
+                    'save_intermediate_out': save_intermediate_out,
+                },
+                dtype=np.float16,
+                grad_chek=False)
 
 if __name__ == '__main__':
     import paddle
...