未验证 提交 eae31856 编写于 作者: W WangXi 提交者: GitHub

Add fused elemwise gelu and optimize performance (#33480)

上级 fa5ddfd9
......@@ -57,6 +57,10 @@ constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
*mod = dividend_copy % divisor; \
} while (0)
#define DIVUP(x, y) (((x) + (y)-1) / (y))
#define ROUNDUP(x, y) (DIVUP((x), (y)) * (y))
namespace paddle {
namespace operators {
......@@ -2156,10 +2160,10 @@ template <typename T, typename CompoundFunctor, bool BcastY,
static __global__ void FusedElemwiseAndActBroadcast1CUDAKernel(
const T *x, const T *y, int h, int w, CompoundFunctor compound_functor,
T *out, T *intermediate_out) {
int j = blockIdx.x;
int i = threadIdx.x;
int i = blockIdx.x;
int j = threadIdx.x;
while (i < h) {
while (j < w) {
int offset = i * w + j;
T y_val = BcastY ? y[j] : y[offset];
......@@ -2185,7 +2189,7 @@ static __global__ void FusedElemwiseAndActBroadcast1CUDAKernel(
out[offset] = compound_functor.GetOut(x_val, y_val);
}
i += ELEMWISE_MAX_BLOCK_DIM;
j += ELEMWISE_MAX_BLOCK_DIM;
}
}
......@@ -2196,8 +2200,8 @@ static void FusedElemwiseAndActBroadcast1CUDA(gpuStream_t stream, const T *x,
CompoundFunctor compound_functor,
int h, int w, T *out,
T *intermediate_out) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
int gird_size = w;
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, w);
int gird_size = h;
FusedElemwiseAndActBroadcast1CUDAKernel<
T, CompoundFunctor, BcastY, KeepIntermediateOut,
SameShapeOfIntermediateOutAndOut><<<gird_size, block_size, 0, stream>>>(
......@@ -2585,106 +2589,129 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
const T *x, const T *y, const T *intermediate_out, const T *out,
const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
int j = blockIdx.x;
int i = threadIdx.x;
int tid = threadIdx.x;
T val(0), inter_val(0);
int64_t tmp_out_idx, x_idx, y_idx;
__shared__ T sdata[BLOCK_Y][BLOCK_X];
size_t idx = threadIdx.x + BLOCK_X * blockIdx.x;
size_t width_stride = gridDim.x * BLOCK_X;
size_t full_w = ROUNDUP(w, BLOCK_X);
T zero = static_cast<T>(0);
do {
int offset = i * w + j;
for (size_t j = idx; j < full_w; j += width_stride) {
T val(0), inter_val(0);
if (j < w) {
for (size_t i = threadIdx.y; i < h; i += BLOCK_Y) {
size_t offset = i * w + j;
tmp_out_idx = BcastY ? j : offset;
y_idx = BcastY ? j : offset;
x_idx = BcastY ? offset : j;
T x_val = (x == nullptr) ? zero : x[x_idx];
T y_val = (y == nullptr) ? zero : y[y_idx];
size_t tmp_out_idx = BcastY ? j : offset;
size_t y_idx = BcastY ? j : offset;
size_t x_idx = BcastY ? offset : j;
T x_val = (x == nullptr) ? zero : x[x_idx];
T y_val = (y == nullptr) ? zero : y[y_idx];
if (SameShapeOfIntermediateOutAndOut) {
tmp_out_idx = offset;
}
if (SameShapeOfIntermediateOutAndOut) {
tmp_out_idx = offset;
}
if (dx != nullptr) {
T tmp = UseIntermediateOut
if (dx != nullptr) {
T tmp =
UseIntermediateOut
? dx_op.UseIntermediateOut(x_val, y_val,
intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dx_op.Recompute(x_val, y_val, out[offset], dout[offset]);
if (BcastY) {
dx[x_idx] = tmp;
} else {
val += tmp;
}
}
if (dy != nullptr) {
T tmp = UseIntermediateOut
if (BcastY) {
dx[x_idx] = tmp;
} else {
val += tmp;
}
}
if (dy != nullptr) {
T tmp =
UseIntermediateOut
? dy_op.UseIntermediateOut(x_val, y_val,
intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dy_op.Recompute(x_val, y_val, out[offset], dout[offset]);
if (BcastY) {
val += tmp;
} else {
dy[y_idx] = tmp;
}
}
if (d_intermediate != nullptr) {
T tmp = UseIntermediateOut
? dintermediate_op.UseIntermediateOut(
y[y_idx], intermediate_out[tmp_out_idx], out[offset],
dout[offset])
: dintermediate_op.Recompute(x_val, y_val, out[offset],
dout[offset]);
if (SameShapeOfIntermediateOutAndOut) {
d_intermediate[tmp_out_idx] = tmp;
} else {
inter_val += tmp;
if (BcastY) {
val += tmp;
} else {
dy[y_idx] = tmp;
}
}
if (d_intermediate != nullptr) {
T tmp = UseIntermediateOut
? dintermediate_op.UseIntermediateOut(
y[y_idx], intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dintermediate_op.Recompute(x_val, y_val, out[offset],
dout[offset]);
if (SameShapeOfIntermediateOutAndOut) {
d_intermediate[tmp_out_idx] = tmp;
} else {
inter_val += tmp;
}
}
}
}
i += ELEMWISE_MAX_BLOCK_DIM;
} while (i < h);
// transpose, for ReduceSum with wrap
sdata[threadIdx.y][threadIdx.x] = val;
__syncthreads();
val = sdata[threadIdx.x][threadIdx.y];
#pragma unroll
for (int i = BLOCK_X >> 1; i > 0; i >>= 1) {
// reduce sum with wrap
val += platform::CudaShuffleXorSync(0xFFFFFFFF, val, i);
}
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
if (BcastY) {
if (dy) {
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
size_t idx_j = j + threadIdx.y;
if (BcastY) {
if (dy) {
if (threadIdx.x == 0 && (idx_j < w)) dy[idx_j] = val;
}
}
} else {
if (dx) {
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dx[j] = val;
} else {
if (dx) {
if (threadIdx.x == 0 && (idx_j < w)) dx[idx_j] = val;
}
}
}
if (!SameShapeOfIntermediateOutAndOut) {
if (d_intermediate) {
inter_val = paddle::platform::reduceSum(inter_val, tid, h);
if (threadIdx.x == 0) {
d_intermediate[j] = inter_val;
if (!SameShapeOfIntermediateOutAndOut) {
if (d_intermediate) {
sdata[threadIdx.y][threadIdx.x] = inter_val;
__syncthreads();
inter_val = sdata[threadIdx.x][threadIdx.y];
#pragma unroll
for (int i = BLOCK_X >> 1; i > 0; i >>= 1) {
// reduce sum with wrap
inter_val += platform::CudaShuffleXorSync(0xFFFFFFFF, inter_val, i);
}
if (threadIdx.x == 0 && (idx_j < w)) d_intermediate[idx_j] = inter_val;
}
}
}
} // end for
}
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
bool UseIntermediateOut, bool BcastY,
bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast1CUDA(
gpuStream_t stream, const T *x, const T *y, const T *intermediate_out,
const T *out, const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
int gird_size = w;
const framework::ExecutionContext &ctx, const T *x, const T *y,
const T *intermediate_out, const T *out, const T *dout, int h, int w,
DX_OP dx_op, DY_OP dy_op, DIntermediate_OP dintermediate_op, T *dx, T *dy,
T *d_intermediate) {
gpuStream_t stream = ctx.cuda_device_context().stream();
dim3 blocks(BLOCK_X, BLOCK_Y);
int max_gpu_threads = ctx.cuda_device_context().GetMaxPhysicalThreadCount();
int max_blocks = std::max(max_gpu_threads / (BLOCK_X * BLOCK_Y), 1);
int theory_block = (w + BLOCK_X - 1) / BLOCK_X;
dim3 grids(std::min(theory_block, max_blocks));
FusedElemwiseAndActGradBroadcast1CUDAKernel<
T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut, BcastY,
SameShapeOfIntermediateOutAndOut><<<gird_size, block_size, 0, stream>>>(
SameShapeOfIntermediateOutAndOut><<<grids, blocks, 0, stream>>>(
x, y, intermediate_out, out, dout, h, w, dx_op, dy_op, dintermediate_op,
dx, dy, d_intermediate);
}
......@@ -2836,7 +2863,7 @@ void FusedElemwiseAndActGradComputeWithBroadcast(
FusedElemwiseAndActGradBroadcast1CUDA<T, DX_OP, DY_OP, DIntermediate_OP,
UseIntermediateOut, BcastY,
SameShapeOfIntermediateOutAndOut>(
ctx.template device_context<DeviceContext>().stream(), x_data, y_data,
ctx, x_data, y_data,
intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
out->data<T>(), dout->data<T>(), h, w, dx_op, dy_op, dintermediate_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
......
......@@ -69,7 +69,7 @@ static bool IsSupportedCompound(const std::vector<std::string> &functors) {
functors.size(), 2));
static std::unordered_set<std::string> unary_fun = {"scale", "relu", "tanh",
"sigmoid"};
"sigmoid", "gelu"};
static std::unordered_set<std::string> binary_fun = {"elementwise_add",
"elementwise_mul"};
......
......@@ -275,6 +275,13 @@ static void RunFunctors(const framework::ExecutionContext &ctx,
paddle::operators::math::SigmoidFunctor<T>>(
ctx, paddle::operators::math::MulFunctor<T>(),
paddle::operators::math::SigmoidFunctor<T>(), in_x, in_y, outputs);
} else if (funcs_str == "gelu,elementwise_add") {
// Z = Unary(Binary(X, Y))
RunUnaryCompoundFunctors<DeviceContext, T,
paddle::operators::math::GeluFunctor<T>,
paddle::operators::math::AddFunctor<T>>(
ctx, paddle::operators::math::GeluFunctor<T>(),
paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s has not been implemented.", funcs_str));
......@@ -374,6 +381,16 @@ static void RunGradFunctors(
paddle::operators::math::SigmoidFunctor<T>(),
paddle::operators::math::SigmoidGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "gelu_grad,elementwise_add_grad") {
// The backward of Z = Unary(Binary(X, Y))
RunUnaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::GeluGradFunctor<T>,
paddle::operators::math::AddFunctor<T>,
paddle::operators::math::AddGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::GeluGradFunctor<T>(),
paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s has not been implemented.", funcs_str));
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math.h"
namespace paddle {
......@@ -130,6 +131,63 @@ struct SigmoidGradFunctor {
}
};
template <typename T>
struct GeluFunctor {
using MT = typename details::MPTypeTrait<T>::Type;
inline HOSTDEVICE T operator()(T x) {
// this function is tanh approximation of gelu
// actual gelu is:
// x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
MT mx = static_cast<MT>(x);
MT out = mx * static_cast<MT>(0.5) *
(static_cast<MT>(1.0) +
tanh(static_cast<MT>(0.79788456) * mx *
(static_cast<MT>(1) + static_cast<MT>(0.044715) * mx * mx)));
return static_cast<T>(out);
}
};
template <typename T>
struct GeluGradFunctor {
using MT = typename details::MPTypeTrait<T>::Type;
inline HOSTDEVICE T UseX(T x) {
MT mx = static_cast<MT>(x);
MT tanh_out =
tanh(static_cast<MT>(0.79788456) * mx *
(static_cast<MT>(1) + static_cast<MT>(0.044715) * mx * mx));
MT ans = static_cast<MT>(0.5) * mx *
((static_cast<MT>(1) - tanh_out * tanh_out) *
(static_cast<MT>(0.79788456) +
static_cast<MT>(0.1070322243) * mx * mx)) +
static_cast<MT>(0.5) * (static_cast<MT>(1) + tanh_out);
return static_cast<T>(ans);
}
inline HOSTDEVICE T UseOut(T x) {
MT mx = static_cast<MT>(x);
MT tanh_out =
tanh(static_cast<MT>(0.79788456) * mx *
(static_cast<MT>(1) + static_cast<MT>(0.044715) * mx * mx));
MT ans = static_cast<MT>(0.5) * mx *
((static_cast<MT>(1) - tanh_out * tanh_out) *
(static_cast<MT>(0.79788456) +
static_cast<MT>(0.1070322243) * mx * mx)) +
static_cast<MT>(0.5) * (static_cast<MT>(1) + tanh_out);
return static_cast<T>(ans);
}
inline HOSTDEVICE T UseXAndOut(T x, T out) {
MT mx = static_cast<MT>(x);
MT tanh_out =
tanh(static_cast<MT>(0.79788456) * mx *
(static_cast<MT>(1) + static_cast<MT>(0.044715) * mx * mx));
MT ans = static_cast<MT>(0.5) * mx *
((static_cast<MT>(1) - tanh_out * tanh_out) *
(static_cast<MT>(0.79788456) +
static_cast<MT>(0.1070322243) * mx * mx)) +
static_cast<MT>(0.5) * (static_cast<MT>(1) + tanh_out);
return static_cast<T>(ans);
}
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -305,6 +305,15 @@ def mul_scale_func(x, y, x_bcast, y_bcast, scale, mode=0):
return y, x, x * scale, y_bcast * (x_bcast * scale)
def gelu_add_func(x, y, x_bcast, y_bcast, mode=0):
im = x_bcast + y_bcast
out = im * 0.5 * (1.0 + np.tanh(0.79788456 * im * (1 + 0.044715 * im * im)))
if mode == 0:
return x, y, im, out
else:
return y, x, im, out
scale = 0.1
scale_add_func = partial(scale_add_func, scale=scale)
add_scale_func = partial(add_scale_func, scale=scale)
......@@ -316,6 +325,7 @@ for mode in {0, 1}:
mul_scale_func = partial(mul_scale_func, mode=mode)
relu_add_func = partial(relu_add_func, mode=mode)
add_relu_func = partial(add_relu_func, mode=mode)
gelu_add_func = partial(gelu_add_func, mode=mode)
for save_intermediate_out in {True, False}:
suffix = ("_save_intermediate_out" if save_intermediate_out else "") \
......@@ -343,6 +353,11 @@ for mode in {0, 1}:
'functor_list': ["elementwise_mul", "scale"],
'save_intermediate_out': save_intermediate_out,
})
create_test_class('gelu_add' + suffix, gelu_add_func, {
'functor_list': ["gelu", "elementwise_add"],
'save_intermediate_out': save_intermediate_out,
})
if core.is_compiled_with_cuda():
create_test_class(
'scale_add_fp16' + suffix,
......@@ -388,6 +403,14 @@ for mode in {0, 1}:
},
dtype=np.float16,
grad_chek=False)
create_test_class(
'gelu_add_fp16' + suffix,
gelu_add_func, {
'functor_list': ["gelu", "elementwise_add"],
'save_intermediate_out': save_intermediate_out,
},
dtype=np.float16,
grad_chek=False)
if __name__ == '__main__':
import paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册