未验证 提交 746b774b 编写于 作者: L limingshu 提交者: GitHub

[OptionalOptimization]: LayerNorm forward Optimization with Welford (#50362)

* first commit

* main codes has been developed

* fix all bugs

* add vectorize input&output

* a test for optimization_of_layer_norm_fwd

* add some changes

* fix memory coalesced access for more optimization.

* fix addition ctest error

* fix according to ci-approval

* remove change on slice
上级 e1956ab5
...@@ -13,14 +13,445 @@ ...@@ -13,14 +13,445 @@
// limitations under the License. // limitations under the License.
#include "paddle/phi/kernels/layer_norm_kernel.h" #include "paddle/phi/kernels/layer_norm_kernel.h"
#include "gflags/gflags.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h" #include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
#include "paddle/phi/kernels/funcs/layer_norm_util.h" #include "paddle/phi/kernels/funcs/layer_norm_util.h"
DECLARE_bool(use_fast_math);
namespace phi { namespace phi {
#ifdef PADDLE_WITH_CUDA
template <typename U>
__device__ inline void WelfordOnline(U val, U *mean, U *square, U *count) {
*count += 1;
U delta1 = val - *mean;
*mean += delta1 / (*count);
U delta2 = val - *mean;
*square += delta1 * delta2;
}
template <typename U>
__device__ inline void WelfordOnline(
U b_mean, U b_square, U b_cnt, U *mean, U *square, U *count) {
if (b_cnt == 0) {
return;
}
U new_cnt = *count + b_cnt;
U nb_n = b_cnt / new_cnt;
U delta = b_mean - *mean;
*mean += delta * nb_n;
*square += b_square + delta * delta * (*count) * nb_n;
*count = new_cnt;
}
template <typename U>
__device__ inline void WelfordWarpAllReduce(U *mean, U *square, U *count) {
constexpr int kWarpSize = 32;
#pragma unroll
for (int mask = 1; mask < kWarpSize; mask *= 2) {
U b_mean = __shfl_down_sync(0xffffffff, *mean, mask);
U b_square = __shfl_down_sync(0xffffffff, *square, mask);
U b_cnt = __shfl_down_sync(0xffffffff, *count, mask);
WelfordOnline<U>(b_mean, b_square, b_cnt, mean, square, count);
}
*mean = __shfl_sync(0xffffffff, *mean, 0, kWarpSize);
*square = __shfl_sync(0xffffffff, *square, 0, kWarpSize);
*count = __shfl_sync(0xffffffff, *count, 0, kWarpSize);
}
template <int VecSize>
struct ThreadAssigner {
__device__ __forceinline__ int operator()(const int cols,
const int cols_per_thread,
int32_t *last_tid_idx) {
return cols_per_thread;
}
};
template <>
struct ThreadAssigner<1> {
__device__ inline int operator()(const int cols,
const int cols_per_thread,
int *last_tid_idx) {
int cols_this_thread = cols_per_thread;
int last_tid = (cols / cols_per_thread);
*last_tid_idx = last_tid;
if (threadIdx.x == last_tid) {
cols_this_thread = cols - cols_per_thread * last_tid;
} else if (threadIdx.x > last_tid) {
cols_this_thread = 0;
}
return cols_this_thread;
}
};
template <typename T, typename U, int VecSize>
struct LayerNormDataReader {
__device__ inline void operator()(const T *__restrict__ row_src,
U *buffer,
const int last_tid_idx,
const int read_times,
const int cols_this_thread) {
using VecT = phi::AlignedVector<T, VecSize>;
const VecT *__restrict__ v_src =
reinterpret_cast<const VecT *__restrict__>(row_src);
for (int i = 0; i < read_times; ++i) {
VecT temp_src = v_src[threadIdx.x + i * blockDim.x];
#pragma unroll
for (int j = 0; j < VecSize; ++j) {
buffer[i * VecSize + j] = static_cast<U>(temp_src[j]);
}
}
}
};
template <typename T, typename U>
struct LayerNormDataReader<T, U, 1> {
__device__ inline void operator()(const T *__restrict__ row_src,
U *buffer,
const int last_tid_idx,
const int read_times,
const int cols_this_thread) {
// read_time is just cols_per_thread while VecSize is 1.
if (threadIdx.x < last_tid_idx) {
for (int i = 0; i < cols_this_thread; ++i) {
buffer[i] = static_cast<U>(row_src[threadIdx.x + last_tid_idx * i]);
}
} else {
for (int i = 0; i < cols_this_thread; ++i) {
buffer[i] = static_cast<U>(row_src[i + read_times * last_tid_idx]);
}
}
}
};
template <typename T, typename U, bool IsSameType, int VecSize>
struct LayerNormDataWritter {
__device__ inline void operator()(
T *__restrict__ row_dst,
const U *__restrict__ buffer,
const funcs::LayerNormScaleBiasT<T, U, IsSameType> *__restrict__ scale,
const funcs::LayerNormScaleBiasT<T, U, IsSameType> *__restrict__ bias,
const U row_mean,
const U row_inv_var,
const int write_times,
const int cols_this_thread,
const int last_tid_idx,
const bool valid_scale,
const bool valid_bias) {
using VecT = phi::AlignedVector<T, VecSize>;
using ScaleT = funcs::LayerNormScaleBiasT<T, U, IsSameType>;
using VecScaleT = phi::AlignedVector<ScaleT, VecSize>;
VecT *v_dst = reinterpret_cast<VecT *>(row_dst);
// cols_this_thread is just cols_per_thread
if ((!valid_scale) && (!valid_bias)) {
for (int i = 0; i < write_times; ++i) {
VecT temp_dst;
#pragma unroll
for (int j = 0; j < VecSize; ++j) {
temp_dst[j] = static_cast<T>((buffer[i * VecSize + j] - row_mean) *
row_inv_var);
}
v_dst[threadIdx.x + blockDim.x * i] = temp_dst;
}
} else {
const VecScaleT *__restrict__ v_scale =
reinterpret_cast<const VecScaleT *__restrict__>(scale);
const VecScaleT *__restrict__ v_bias =
reinterpret_cast<const VecScaleT *__restrict__>(bias);
if (valid_scale && valid_bias) {
for (int i = 0; i < write_times; ++i) {
int idx = threadIdx.x + blockDim.x * i;
VecT temp_dst;
VecScaleT temp_v_scale = v_scale[idx];
VecScaleT temp_v_bias = v_bias[idx];
#pragma unroll
for (int j = 0; j < VecSize; ++j) {
temp_dst[j] = static_cast<T>(
static_cast<U>(temp_v_scale[j]) *
(buffer[i * VecSize + j] - row_mean) * row_inv_var +
static_cast<U>(temp_v_bias[j]));
}
v_dst[idx] = temp_dst;
}
} else {
if (valid_scale) {
for (int i = 0; i < write_times; ++i) {
int idx = threadIdx.x + blockDim.x * i;
VecT temp_dst;
VecScaleT temp_v_scale = v_scale[idx];
#pragma unroll
for (int j = 0; j < VecSize; ++j) {
temp_dst[j] = static_cast<T>(
static_cast<U>(temp_v_scale[j]) *
(buffer[i * VecSize + j] - row_mean) * row_inv_var);
}
v_dst[idx] = temp_dst;
}
} else {
for (int i = 0; i < write_times; ++i) {
int idx = threadIdx.x + blockDim.x * i;
VecT temp_dst;
VecScaleT temp_v_bias = v_bias[idx];
#pragma unroll
for (int j = 0; j < VecSize; ++j) {
temp_dst[j] = static_cast<T>(
(buffer[i * VecSize + j] - row_mean) * row_inv_var +
static_cast<U>(temp_v_bias[j]));
}
v_dst[idx] = temp_dst;
}
}
}
}
}
};
template <typename T, typename U, bool IsSameType>
struct LayerNormDataWritter<T, U, IsSameType, 1> {
__device__ __forceinline__ void operator()(
T *__restrict__ row_dst,
U *__restrict__ buffer,
const funcs::LayerNormScaleBiasT<T, U, IsSameType> *__restrict__ scale,
const funcs::LayerNormScaleBiasT<T, U, IsSameType> *__restrict__ bias,
const U row_mean,
const U row_inv_var,
const int write_times,
const int cols_this_thread,
const int last_tid_idx,
const bool valid_scale,
const bool valid_bias) {
// write_times is just col_per_thread.
if ((!valid_scale) && (!valid_bias)) {
if (threadIdx.x < last_tid_idx) {
for (int i = 0; i < cols_this_thread; ++i) {
row_dst[threadIdx.x + last_tid_idx * i] =
(buffer[i] - row_mean) * row_inv_var;
}
} else {
for (int i = 0; i < cols_this_thread; ++i) {
row_dst[last_tid_idx * write_times + i] =
(buffer[i] - row_mean) * row_inv_var;
}
}
} else if (valid_scale && valid_bias) {
if (threadIdx.x < last_tid_idx) {
for (int i = 0; i < cols_this_thread; ++i) {
int idx = threadIdx.x + last_tid_idx * i;
row_dst[idx] =
static_cast<T>(static_cast<U>(scale[idx]) *
(buffer[i] - row_mean) * row_inv_var +
static_cast<U>(bias[idx]));
}
} else {
for (int i = 0; i < cols_this_thread; ++i) {
int idx = last_tid_idx * write_times + i;
row_dst[idx] =
static_cast<T>(static_cast<U>(scale[idx]) *
(buffer[i] - row_mean) * row_inv_var +
static_cast<U>(bias[idx]));
}
}
} else {
if (valid_scale) {
if (threadIdx.x < last_tid_idx) {
for (int i = 0; i < cols_this_thread; ++i) {
int idx = threadIdx.x + last_tid_idx * i;
row_dst[idx] = static_cast<T>(static_cast<U>(scale[idx]) *
(buffer[i] - row_mean) * row_inv_var);
}
} else {
for (int i = 0; i < cols_this_thread; ++i) {
int idx = last_tid_idx * write_times + i;
row_dst[idx] = static_cast<T>(static_cast<U>(scale[idx]) *
(buffer[i] - row_mean) * row_inv_var);
}
}
} else {
if (threadIdx.x < last_tid_idx) {
for (int i = 0; i < cols_this_thread; ++i) {
int idx = threadIdx.x + last_tid_idx * i;
row_dst[idx] = static_cast<T>((buffer[i] - row_mean) * row_inv_var +
static_cast<U>(bias[idx]));
}
} else {
for (int i = 0; i < cols_this_thread; ++i) {
int idx = last_tid_idx * write_times + i;
row_dst[idx] = static_cast<T>((buffer[i] - row_mean) * row_inv_var +
static_cast<U>(bias[idx]));
}
}
}
}
}
};
template <typename IndexT, typename T, typename U, bool IsSameType, int VecSize>
__global__ void LayerNormFwdWithWelford(
const T *__restrict__ src_data,
T *dst_data,
const funcs::LayerNormScaleBiasT<T, U, IsSameType> *__restrict__ scale,
const funcs::LayerNormScaleBiasT<T, U, IsSameType> *__restrict__ bias,
U *mean,
U *var,
const U epsilon,
const IndexT rows,
const int32_t cols,
const int32_t cols_per_thread,
const bool valid_scale,
const bool valid_bias) {
constexpr int kWarpSize = 32;
int last_tid_idx = 0; // For condition once vecSize is 1.
IndexT row_offset = blockIdx.x * blockDim.y + threadIdx.y;
int cols_this_thread =
ThreadAssigner<VecSize>()(cols, cols_per_thread, &last_tid_idx);
int read_times = cols_per_thread / VecSize;
if (row_offset < rows) {
U buffer[kWarpSize];
U tid_cnt = static_cast<U>(0);
U tid_mean = static_cast<U>(0);
U tid_square = static_cast<U>(0);
const T *__restrict__ row_src = src_data + row_offset * cols;
T *row_dst = dst_data + row_offset * cols;
LayerNormDataReader<T, U, VecSize>()(
row_src, buffer, last_tid_idx, read_times, cols_this_thread);
for (int i = 0; i < cols_this_thread; i++) {
WelfordOnline<U>(buffer[i], &tid_mean, &tid_square, &tid_cnt);
}
U warp_cnt = tid_cnt;
U warp_mean = tid_mean;
U warp_square = tid_square;
WelfordWarpAllReduce<U>(&warp_mean, &warp_square, &warp_cnt);
U row_variance = max(warp_square / warp_cnt, 0.f);
U row_inv_var = funcs::rsqrt_(row_variance + epsilon);
// TODO(limingshu): make code below vectorization.
if (threadIdx.x == 0) {
// warp_mean is just row_mean here.
mean[row_offset] = warp_mean;
var[row_offset] = row_variance;
}
LayerNormDataWritter<T, U, IsSameType, VecSize>()(row_dst,
buffer,
scale,
bias,
warp_mean,
row_inv_var,
read_times,
cols_this_thread,
last_tid_idx,
valid_scale,
valid_bias);
}
}
template <typename Context, typename T, typename U>
void LaunchLayerNormKernel(const Context &dev_ctx,
const T *x_data,
T *y_data,
const void *void_scale_data,
const void *void_bias_data,
U *mean_data,
U *var_data,
float epsilon,
const int64_t rows,
const int cols,
const bool valid_scale,
const bool valid_bias,
const bool is_same_type) {
constexpr int WarpSize = 32;
constexpr int RowPerBlock = 4;
int64_t block_size = (rows + (RowPerBlock - 1)) / RowPerBlock;
dim3 threads(WarpSize, RowPerBlock, 1);
int vec_size = 1;
int cols_per_thread = (cols + (WarpSize - 1)) / WarpSize;
if (cols_per_thread > 1 && (cols % WarpSize == 0)) {
int data_vec_size = 0;
uint64_t addr = (reinterpret_cast<uint64_t>(x_data) |
reinterpret_cast<uint64_t>(y_data));
if (valid_bias || valid_scale) {
if (is_same_type) {
addr = valid_scale
? (addr | reinterpret_cast<uint64_t>(void_scale_data))
: addr;
addr = valid_bias ? (addr | reinterpret_cast<uint64_t>(void_bias_data))
: addr;
data_vec_size = phi::GetVectorizedSize<T>(reinterpret_cast<T *>(addr));
} else {
uint64_t bias_addr = reinterpret_cast<uint64_t>(void_bias_data);
uint64_t attr_addr = valid_scale
? reinterpret_cast<uint64_t>(void_scale_data)
: bias_addr;
attr_addr = valid_bias
? (valid_scale ? (attr_addr | bias_addr) : attr_addr)
: attr_addr;
data_vec_size = std::min(
phi::GetVectorizedSize<T>(reinterpret_cast<T *>(addr)),
phi::GetVectorizedSize<U>(reinterpret_cast<U *>(attr_addr)));
}
}
for (int size = data_vec_size; size > 0; size /= 2) {
if (cols_per_thread % size == 0) {
vec_size = size;
break;
}
}
}
#define IMPL_LAYER_NORM_WELFORD_CASE(index_t, scale_t, is_same_, vec_size_) \
case (vec_size_): { \
LayerNormFwdWithWelford<index_t, T, U, is_same_, vec_size_> \
<<<block_size, threads, 0, dev_ctx.stream()>>>( \
x_data, \
y_data, \
static_cast<const scale_t *>(void_scale_data), \
static_cast<const scale_t *>(void_bias_data), \
mean_data, \
var_data, \
static_cast<const U>(epsilon), \
rows, \
cols, \
cols_per_thread, \
valid_scale, \
valid_bias); \
} break
#define IMPL_LAYER_NORM_WELFORD(index_t, scale_t, is_same_) \
IMPL_LAYER_NORM_WELFORD_CASE(index_t, scale_t, is_same_, 4); \
IMPL_LAYER_NORM_WELFORD_CASE(index_t, scale_t, is_same_, 2); \
IMPL_LAYER_NORM_WELFORD_CASE(index_t, scale_t, is_same_, 1);
if (rows < std::numeric_limits<int32_t>::max()) {
if (is_same_type) {
switch (vec_size) { IMPL_LAYER_NORM_WELFORD(int32_t, T, true); }
} else {
switch (vec_size) { IMPL_LAYER_NORM_WELFORD(int32_t, U, false); }
}
} else {
if (is_same_type) {
switch (vec_size) { IMPL_LAYER_NORM_WELFORD(int64_t, T, true); }
} else {
switch (vec_size) { IMPL_LAYER_NORM_WELFORD(int64_t, U, false); }
}
}
#undef IMPL_LAYER_NORM_WELFORD_CASE
#undef IMPL_LAYER_NORM_WELFORD
}
#endif // PADDLE_WITH_CUDA
template <typename T, typename U> template <typename T, typename U>
void LayerNormDirectCUDAFunctor<T, U>::operator()(gpuStream_t stream, void LayerNormDirectCUDAFunctor<T, U>::operator()(gpuStream_t stream,
const T *input, const T *input,
...@@ -75,14 +506,16 @@ void LayerNormKernel(const Context &dev_ctx, ...@@ -75,14 +506,16 @@ void LayerNormKernel(const Context &dev_ctx,
auto *mean_data = dev_ctx.template Alloc<U>(mean); auto *mean_data = dev_ctx.template Alloc<U>(mean);
auto *var_data = dev_ctx.template Alloc<U>(var); auto *var_data = dev_ctx.template Alloc<U>(var);
auto *void_scale_data = (scale == nullptr ? nullptr : scale->data()); bool valid_scale = (scale != nullptr);
auto *void_bias_data = (bias == nullptr ? nullptr : bias->data()); bool valid_bias = (bias != nullptr);
auto *void_scale_data = valid_scale ? scale->data() : nullptr;
auto *void_bias_data = valid_bias ? bias->data() : nullptr;
auto x_dtype = x.dtype(); auto x_dtype = x.dtype();
phi::DataType scale_bias_dtype; phi::DataType scale_bias_dtype;
if (void_scale_data != nullptr) { if (valid_scale) {
scale_bias_dtype = scale->dtype(); scale_bias_dtype = scale->dtype();
if (void_bias_data != nullptr) { if (valid_bias) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
scale->dtype(), scale->dtype(),
bias->dtype(), bias->dtype(),
...@@ -90,7 +523,7 @@ void LayerNormKernel(const Context &dev_ctx, ...@@ -90,7 +523,7 @@ void LayerNormKernel(const Context &dev_ctx,
"should have the same data type.")); "should have the same data type."));
} }
} else { } else {
scale_bias_dtype = (void_bias_data != nullptr ? bias->dtype() : x_dtype); scale_bias_dtype = valid_bias ? bias->dtype() : x_dtype;
} }
bool is_scale_bias_same_dtype_with_x = x_dtype == scale_bias_dtype; bool is_scale_bias_same_dtype_with_x = x_dtype == scale_bias_dtype;
...@@ -104,7 +537,6 @@ void LayerNormKernel(const Context &dev_ctx, ...@@ -104,7 +537,6 @@ void LayerNormKernel(const Context &dev_ctx,
auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis); auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
int64_t batch_size = static_cast<int64_t>(matrix_dim[0]); int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
int64_t feature_size = static_cast<int64_t>(matrix_dim[1]); int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
auto stream = dev_ctx.stream(); auto stream = dev_ctx.stream();
#define PADDLE_LAUNCH_LAYERNORM_FWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \ #define PADDLE_LAUNCH_LAYERNORM_FWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \
...@@ -200,13 +632,31 @@ void LayerNormKernel(const Context &dev_ctx, ...@@ -200,13 +632,31 @@ void LayerNormKernel(const Context &dev_ctx,
} }
} }
} else { } else {
#endif // WarpShuffle intrinsics is involved in LaunchLayerNormKernel.
if (is_scale_bias_same_dtype_with_x) { if (FLAGS_use_fast_math && feature_size <= 1024 &&
PADDLE_LAUNCH_LAYERNORM_FWD(T, true); (!std::is_same<T, int8_t>::value)) {
LaunchLayerNormKernel<Context, T, U>(dev_ctx,
x_data,
y_data,
void_scale_data,
void_bias_data,
mean_data,
var_data,
epsilon,
batch_size,
feature_size,
valid_scale,
valid_bias,
is_scale_bias_same_dtype_with_x);
} else { } else {
PADDLE_LAUNCH_LAYERNORM_FWD(U, false); #endif
} if (is_scale_bias_same_dtype_with_x) {
PADDLE_LAUNCH_LAYERNORM_FWD(T, true);
} else {
PADDLE_LAUNCH_LAYERNORM_FWD(U, false);
}
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
}
} }
#endif #endif
......
...@@ -489,6 +489,83 @@ class TestGetSetKeepLayerNormScaleBiasFP32Flag(unittest.TestCase): ...@@ -489,6 +489,83 @@ class TestGetSetKeepLayerNormScaleBiasFP32Flag(unittest.TestCase):
self.assertTrue(_keep_layer_norm_scale_bias_to_fp32()) self.assertTrue(_keep_layer_norm_scale_bias_to_fp32())
class TestFastMathLayerNormOp(unittest.TestCase):
def check_layer_norm(
self, dtype, x_np, scale_np, bias_np, norm_axis, has_scale, has_bias
):
paddle.disable_static()
epsilon = 0.00001
x = paddle.to_tensor(x_np)
if dtype == "bfloat16":
x = x.cast(paddle.fluid.core.VarDesc.VarType.BF16)
x.stop_gradient = True
bias = paddle.to_tensor(bias_np) if has_scale else None
scale = paddle.to_tensor(scale_np) if has_bias else None
if bias is not None:
bias.stop_gradient = True
if scale is not None:
scale.stop_gradient = True
y = F.layer_norm(x, x.shape[norm_axis:], scale, bias)
y_np = y.cast('float32').numpy()
paddle.enable_static()
return y_np
def check_with_fast_math(
self, dtype, shape, norm_axis, has_scale, has_bias
):
def use_fast_math(enabled):
paddle.set_flags({'FLAGS_use_fast_math': enabled})
def __assert_close(x, y):
np.testing.assert_allclose(x, y, rtol=1e-05, atol=1e-04)
x_np = np.random.random(shape).astype('float32')
bias_np = np.random.random(shape[norm_axis:]).astype('float32')
scale_np = np.random.random(shape[norm_axis:]).astype('float32')
use_fast_math(False)
y_fast = self.check_layer_norm(
dtype, x_np, scale_np, bias_np, norm_axis, has_scale, has_bias
)
use_fast_math(True)
y_dev = self.check_layer_norm(
dtype, x_np, scale_np, bias_np, norm_axis, has_scale, has_bias
)
__assert_close(y_fast, y_dev)
def check_with_dtype(self, dtype):
self.check_with_fast_math(
dtype,
shape=[17, 129],
norm_axis=1,
has_scale=False,
has_bias=True,
)
self.check_with_fast_math(
dtype,
shape=[8, 512],
norm_axis=1,
has_scale=False,
has_bias=False,
)
self.check_with_fast_math(
dtype,
shape=[2, 768],
norm_axis=1,
has_scale=False,
has_bias=False,
)
def test_main(self):
if not paddle.is_compiled_with_cuda():
return
self.check_with_dtype(dtype="float32")
self.check_with_dtype(dtype="bfloat16")
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册