From 746b774b8b925a04bb3c3e9d9908280050bf7131 Mon Sep 17 00:00:00 2001 From: limingshu <61349199+JamesLim-sy@users.noreply.github.com> Date: Thu, 23 Feb 2023 16:41:37 +0800 Subject: [PATCH] [OptionalOptimization]: LayerNorm forward Optimization with Welford (#50362) * first commit * main codes has been developed * fix all bugs * add vectorize input&output * a test for optimization_of_layer_norm_fwd * add some changes * fix memory coalesced access for more optimization. * fix addition ctest error * fix according to ci-approval * remove change on slice --- paddle/phi/kernels/gpu/layer_norm_kernel.cu | 474 +++++++++++++++++- .../tests/unittests/test_layer_norm_op.py | 77 +++ 2 files changed, 539 insertions(+), 12 deletions(-) diff --git a/paddle/phi/kernels/gpu/layer_norm_kernel.cu b/paddle/phi/kernels/gpu/layer_norm_kernel.cu index cccf93f9446..2f770098925 100644 --- a/paddle/phi/kernels/gpu/layer_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/layer_norm_kernel.cu @@ -13,14 +13,445 @@ // limitations under the License. #include "paddle/phi/kernels/layer_norm_kernel.h" - +#include "gflags/gflags.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h" #include "paddle/phi/kernels/funcs/layer_norm_util.h" +DECLARE_bool(use_fast_math); + namespace phi { +#ifdef PADDLE_WITH_CUDA +template +__device__ inline void WelfordOnline(U val, U *mean, U *square, U *count) { + *count += 1; + U delta1 = val - *mean; + *mean += delta1 / (*count); + U delta2 = val - *mean; + *square += delta1 * delta2; +} + +template +__device__ inline void WelfordOnline( + U b_mean, U b_square, U b_cnt, U *mean, U *square, U *count) { + if (b_cnt == 0) { + return; + } + + U new_cnt = *count + b_cnt; + U nb_n = b_cnt / new_cnt; + U delta = b_mean - *mean; + *mean += delta * nb_n; + *square += b_square + delta * delta * (*count) * nb_n; + *count = new_cnt; +} + +template +__device__ inline void WelfordWarpAllReduce(U *mean, U *square, U *count) { + constexpr int kWarpSize = 32; +#pragma unroll + for (int mask = 1; mask < kWarpSize; mask *= 2) { + U b_mean = __shfl_down_sync(0xffffffff, *mean, mask); + U b_square = __shfl_down_sync(0xffffffff, *square, mask); + U b_cnt = __shfl_down_sync(0xffffffff, *count, mask); + WelfordOnline(b_mean, b_square, b_cnt, mean, square, count); + } + + *mean = __shfl_sync(0xffffffff, *mean, 0, kWarpSize); + *square = __shfl_sync(0xffffffff, *square, 0, kWarpSize); + *count = __shfl_sync(0xffffffff, *count, 0, kWarpSize); +} + +template +struct ThreadAssigner { + __device__ __forceinline__ int operator()(const int cols, + const int cols_per_thread, + int32_t *last_tid_idx) { + return cols_per_thread; + } +}; + +template <> +struct ThreadAssigner<1> { + __device__ inline int operator()(const int cols, + const int cols_per_thread, + int *last_tid_idx) { + int cols_this_thread = cols_per_thread; + int last_tid = (cols / cols_per_thread); + *last_tid_idx = last_tid; + if (threadIdx.x == last_tid) { + cols_this_thread = cols - cols_per_thread * last_tid; + } else if (threadIdx.x > last_tid) { + cols_this_thread = 0; + } + return cols_this_thread; + } +}; + +template +struct LayerNormDataReader { + __device__ inline void operator()(const T *__restrict__ row_src, + U *buffer, + const int last_tid_idx, + const int read_times, + const int cols_this_thread) { + using VecT = phi::AlignedVector; + const VecT *__restrict__ v_src = + reinterpret_cast(row_src); + + for (int i = 0; i < read_times; ++i) { + VecT temp_src = 
v_src[threadIdx.x + i * blockDim.x]; +#pragma unroll + for (int j = 0; j < VecSize; ++j) { + buffer[i * VecSize + j] = static_cast(temp_src[j]); + } + } + } +}; + +template +struct LayerNormDataReader { + __device__ inline void operator()(const T *__restrict__ row_src, + U *buffer, + const int last_tid_idx, + const int read_times, + const int cols_this_thread) { + // read_time is just cols_per_thread while VecSize is 1. + if (threadIdx.x < last_tid_idx) { + for (int i = 0; i < cols_this_thread; ++i) { + buffer[i] = static_cast(row_src[threadIdx.x + last_tid_idx * i]); + } + } else { + for (int i = 0; i < cols_this_thread; ++i) { + buffer[i] = static_cast(row_src[i + read_times * last_tid_idx]); + } + } + } +}; + +template +struct LayerNormDataWritter { + __device__ inline void operator()( + T *__restrict__ row_dst, + const U *__restrict__ buffer, + const funcs::LayerNormScaleBiasT *__restrict__ scale, + const funcs::LayerNormScaleBiasT *__restrict__ bias, + const U row_mean, + const U row_inv_var, + const int write_times, + const int cols_this_thread, + const int last_tid_idx, + const bool valid_scale, + const bool valid_bias) { + using VecT = phi::AlignedVector; + using ScaleT = funcs::LayerNormScaleBiasT; + using VecScaleT = phi::AlignedVector; + VecT *v_dst = reinterpret_cast(row_dst); + + // cols_this_thread is just cols_per_thread + if ((!valid_scale) && (!valid_bias)) { + for (int i = 0; i < write_times; ++i) { + VecT temp_dst; +#pragma unroll + for (int j = 0; j < VecSize; ++j) { + temp_dst[j] = static_cast((buffer[i * VecSize + j] - row_mean) * + row_inv_var); + } + v_dst[threadIdx.x + blockDim.x * i] = temp_dst; + } + } else { + const VecScaleT *__restrict__ v_scale = + reinterpret_cast(scale); + const VecScaleT *__restrict__ v_bias = + reinterpret_cast(bias); + if (valid_scale && valid_bias) { + for (int i = 0; i < write_times; ++i) { + int idx = threadIdx.x + blockDim.x * i; + VecT temp_dst; + VecScaleT temp_v_scale = v_scale[idx]; + VecScaleT temp_v_bias = v_bias[idx]; +#pragma unroll + for (int j = 0; j < VecSize; ++j) { + temp_dst[j] = static_cast( + static_cast(temp_v_scale[j]) * + (buffer[i * VecSize + j] - row_mean) * row_inv_var + + static_cast(temp_v_bias[j])); + } + v_dst[idx] = temp_dst; + } + } else { + if (valid_scale) { + for (int i = 0; i < write_times; ++i) { + int idx = threadIdx.x + blockDim.x * i; + VecT temp_dst; + VecScaleT temp_v_scale = v_scale[idx]; +#pragma unroll + for (int j = 0; j < VecSize; ++j) { + temp_dst[j] = static_cast( + static_cast(temp_v_scale[j]) * + (buffer[i * VecSize + j] - row_mean) * row_inv_var); + } + v_dst[idx] = temp_dst; + } + } else { + for (int i = 0; i < write_times; ++i) { + int idx = threadIdx.x + blockDim.x * i; + VecT temp_dst; + VecScaleT temp_v_bias = v_bias[idx]; +#pragma unroll + for (int j = 0; j < VecSize; ++j) { + temp_dst[j] = static_cast( + (buffer[i * VecSize + j] - row_mean) * row_inv_var + + static_cast(temp_v_bias[j])); + } + v_dst[idx] = temp_dst; + } + } + } + } + } +}; + +template +struct LayerNormDataWritter { + __device__ __forceinline__ void operator()( + T *__restrict__ row_dst, + U *__restrict__ buffer, + const funcs::LayerNormScaleBiasT *__restrict__ scale, + const funcs::LayerNormScaleBiasT *__restrict__ bias, + const U row_mean, + const U row_inv_var, + const int write_times, + const int cols_this_thread, + const int last_tid_idx, + const bool valid_scale, + const bool valid_bias) { + // write_times is just col_per_thread. 
+ if ((!valid_scale) && (!valid_bias)) { + if (threadIdx.x < last_tid_idx) { + for (int i = 0; i < cols_this_thread; ++i) { + row_dst[threadIdx.x + last_tid_idx * i] = + (buffer[i] - row_mean) * row_inv_var; + } + } else { + for (int i = 0; i < cols_this_thread; ++i) { + row_dst[last_tid_idx * write_times + i] = + (buffer[i] - row_mean) * row_inv_var; + } + } + } else if (valid_scale && valid_bias) { + if (threadIdx.x < last_tid_idx) { + for (int i = 0; i < cols_this_thread; ++i) { + int idx = threadIdx.x + last_tid_idx * i; + row_dst[idx] = + static_cast(static_cast(scale[idx]) * + (buffer[i] - row_mean) * row_inv_var + + static_cast(bias[idx])); + } + } else { + for (int i = 0; i < cols_this_thread; ++i) { + int idx = last_tid_idx * write_times + i; + row_dst[idx] = + static_cast(static_cast(scale[idx]) * + (buffer[i] - row_mean) * row_inv_var + + static_cast(bias[idx])); + } + } + } else { + if (valid_scale) { + if (threadIdx.x < last_tid_idx) { + for (int i = 0; i < cols_this_thread; ++i) { + int idx = threadIdx.x + last_tid_idx * i; + row_dst[idx] = static_cast(static_cast(scale[idx]) * + (buffer[i] - row_mean) * row_inv_var); + } + } else { + for (int i = 0; i < cols_this_thread; ++i) { + int idx = last_tid_idx * write_times + i; + row_dst[idx] = static_cast(static_cast(scale[idx]) * + (buffer[i] - row_mean) * row_inv_var); + } + } + } else { + if (threadIdx.x < last_tid_idx) { + for (int i = 0; i < cols_this_thread; ++i) { + int idx = threadIdx.x + last_tid_idx * i; + row_dst[idx] = static_cast((buffer[i] - row_mean) * row_inv_var + + static_cast(bias[idx])); + } + } else { + for (int i = 0; i < cols_this_thread; ++i) { + int idx = last_tid_idx * write_times + i; + row_dst[idx] = static_cast((buffer[i] - row_mean) * row_inv_var + + static_cast(bias[idx])); + } + } + } + } + } +}; + +template +__global__ void LayerNormFwdWithWelford( + const T *__restrict__ src_data, + T *dst_data, + const funcs::LayerNormScaleBiasT *__restrict__ scale, + const funcs::LayerNormScaleBiasT *__restrict__ bias, + U *mean, + U *var, + const U epsilon, + const IndexT rows, + const int32_t cols, + const int32_t cols_per_thread, + const bool valid_scale, + const bool valid_bias) { + constexpr int kWarpSize = 32; + int last_tid_idx = 0; // For condition once vecSize is 1. + IndexT row_offset = blockIdx.x * blockDim.y + threadIdx.y; + int cols_this_thread = + ThreadAssigner()(cols, cols_per_thread, &last_tid_idx); + int read_times = cols_per_thread / VecSize; + + if (row_offset < rows) { + U buffer[kWarpSize]; + U tid_cnt = static_cast(0); + U tid_mean = static_cast(0); + U tid_square = static_cast(0); + + const T *__restrict__ row_src = src_data + row_offset * cols; + T *row_dst = dst_data + row_offset * cols; + LayerNormDataReader()( + row_src, buffer, last_tid_idx, read_times, cols_this_thread); + + for (int i = 0; i < cols_this_thread; i++) { + WelfordOnline(buffer[i], &tid_mean, &tid_square, &tid_cnt); + } + + U warp_cnt = tid_cnt; + U warp_mean = tid_mean; + U warp_square = tid_square; + WelfordWarpAllReduce(&warp_mean, &warp_square, &warp_cnt); + + U row_variance = max(warp_square / warp_cnt, 0.f); + U row_inv_var = funcs::rsqrt_(row_variance + epsilon); + + // TODO(limingshu): make code below vectorization. + if (threadIdx.x == 0) { + // warp_mean is just row_mean here. 
+ mean[row_offset] = warp_mean; + var[row_offset] = row_variance; + } + LayerNormDataWritter()(row_dst, + buffer, + scale, + bias, + warp_mean, + row_inv_var, + read_times, + cols_this_thread, + last_tid_idx, + valid_scale, + valid_bias); + } +} + +template +void LaunchLayerNormKernel(const Context &dev_ctx, + const T *x_data, + T *y_data, + const void *void_scale_data, + const void *void_bias_data, + U *mean_data, + U *var_data, + float epsilon, + const int64_t rows, + const int cols, + const bool valid_scale, + const bool valid_bias, + const bool is_same_type) { + constexpr int WarpSize = 32; + constexpr int RowPerBlock = 4; + int64_t block_size = (rows + (RowPerBlock - 1)) / RowPerBlock; + dim3 threads(WarpSize, RowPerBlock, 1); + + int vec_size = 1; + int cols_per_thread = (cols + (WarpSize - 1)) / WarpSize; + if (cols_per_thread > 1 && (cols % WarpSize == 0)) { + int data_vec_size = 0; + uint64_t addr = (reinterpret_cast(x_data) | + reinterpret_cast(y_data)); + if (valid_bias || valid_scale) { + if (is_same_type) { + addr = valid_scale + ? (addr | reinterpret_cast(void_scale_data)) + : addr; + addr = valid_bias ? (addr | reinterpret_cast(void_bias_data)) + : addr; + data_vec_size = phi::GetVectorizedSize(reinterpret_cast(addr)); + } else { + uint64_t bias_addr = reinterpret_cast(void_bias_data); + uint64_t attr_addr = valid_scale + ? reinterpret_cast(void_scale_data) + : bias_addr; + attr_addr = valid_bias + ? (valid_scale ? (attr_addr | bias_addr) : attr_addr) + : attr_addr; + data_vec_size = std::min( + phi::GetVectorizedSize(reinterpret_cast(addr)), + phi::GetVectorizedSize(reinterpret_cast(attr_addr))); + } + } + for (int size = data_vec_size; size > 0; size /= 2) { + if (cols_per_thread % size == 0) { + vec_size = size; + break; + } + } + } + +#define IMPL_LAYER_NORM_WELFORD_CASE(index_t, scale_t, is_same_, vec_size_) \ + case (vec_size_): { \ + LayerNormFwdWithWelford \ + <<>>( \ + x_data, \ + y_data, \ + static_cast(void_scale_data), \ + static_cast(void_bias_data), \ + mean_data, \ + var_data, \ + static_cast(epsilon), \ + rows, \ + cols, \ + cols_per_thread, \ + valid_scale, \ + valid_bias); \ + } break + +#define IMPL_LAYER_NORM_WELFORD(index_t, scale_t, is_same_) \ + IMPL_LAYER_NORM_WELFORD_CASE(index_t, scale_t, is_same_, 4); \ + IMPL_LAYER_NORM_WELFORD_CASE(index_t, scale_t, is_same_, 2); \ + IMPL_LAYER_NORM_WELFORD_CASE(index_t, scale_t, is_same_, 1); + + if (rows < std::numeric_limits::max()) { + if (is_same_type) { + switch (vec_size) { IMPL_LAYER_NORM_WELFORD(int32_t, T, true); } + } else { + switch (vec_size) { IMPL_LAYER_NORM_WELFORD(int32_t, U, false); } + } + } else { + if (is_same_type) { + switch (vec_size) { IMPL_LAYER_NORM_WELFORD(int64_t, T, true); } + } else { + switch (vec_size) { IMPL_LAYER_NORM_WELFORD(int64_t, U, false); } + } + } +#undef IMPL_LAYER_NORM_WELFORD_CASE +#undef IMPL_LAYER_NORM_WELFORD +} +#endif // PADDLE_WITH_CUDA + template void LayerNormDirectCUDAFunctor::operator()(gpuStream_t stream, const T *input, @@ -75,14 +506,16 @@ void LayerNormKernel(const Context &dev_ctx, auto *mean_data = dev_ctx.template Alloc(mean); auto *var_data = dev_ctx.template Alloc(var); - auto *void_scale_data = (scale == nullptr ? nullptr : scale->data()); - auto *void_bias_data = (bias == nullptr ? nullptr : bias->data()); + bool valid_scale = (scale != nullptr); + bool valid_bias = (bias != nullptr); + auto *void_scale_data = valid_scale ? scale->data() : nullptr; + auto *void_bias_data = valid_bias ? 
bias->data() : nullptr; auto x_dtype = x.dtype(); phi::DataType scale_bias_dtype; - if (void_scale_data != nullptr) { + if (valid_scale) { scale_bias_dtype = scale->dtype(); - if (void_bias_data != nullptr) { + if (valid_bias) { PADDLE_ENFORCE_EQ( scale->dtype(), bias->dtype(), @@ -90,7 +523,7 @@ void LayerNormKernel(const Context &dev_ctx, "should have the same data type.")); } } else { - scale_bias_dtype = (void_bias_data != nullptr ? bias->dtype() : x_dtype); + scale_bias_dtype = valid_bias ? bias->dtype() : x_dtype; } bool is_scale_bias_same_dtype_with_x = x_dtype == scale_bias_dtype; @@ -104,7 +537,6 @@ void LayerNormKernel(const Context &dev_ctx, auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis); int64_t batch_size = static_cast(matrix_dim[0]); int64_t feature_size = static_cast(matrix_dim[1]); - auto stream = dev_ctx.stream(); #define PADDLE_LAUNCH_LAYERNORM_FWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \ @@ -200,13 +632,31 @@ void LayerNormKernel(const Context &dev_ctx, } } } else { -#endif - if (is_scale_bias_same_dtype_with_x) { - PADDLE_LAUNCH_LAYERNORM_FWD(T, true); + // WarpShuffle intrinsics is involved in LaunchLayerNormKernel. + if (FLAGS_use_fast_math && feature_size <= 1024 && + (!std::is_same::value)) { + LaunchLayerNormKernel(dev_ctx, + x_data, + y_data, + void_scale_data, + void_bias_data, + mean_data, + var_data, + epsilon, + batch_size, + feature_size, + valid_scale, + valid_bias, + is_scale_bias_same_dtype_with_x); } else { - PADDLE_LAUNCH_LAYERNORM_FWD(U, false); - } +#endif + if (is_scale_bias_same_dtype_with_x) { + PADDLE_LAUNCH_LAYERNORM_FWD(T, true); + } else { + PADDLE_LAUNCH_LAYERNORM_FWD(U, false); + } #ifdef PADDLE_WITH_CUDA + } } #endif diff --git a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py index 0878f468074..d9bef2efa39 100644 --- a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py @@ -489,6 +489,83 @@ class TestGetSetKeepLayerNormScaleBiasFP32Flag(unittest.TestCase): self.assertTrue(_keep_layer_norm_scale_bias_to_fp32()) +class TestFastMathLayerNormOp(unittest.TestCase): + def check_layer_norm( + self, dtype, x_np, scale_np, bias_np, norm_axis, has_scale, has_bias + ): + paddle.disable_static() + epsilon = 0.00001 + + x = paddle.to_tensor(x_np) + if dtype == "bfloat16": + x = x.cast(paddle.fluid.core.VarDesc.VarType.BF16) + + x.stop_gradient = True + bias = paddle.to_tensor(bias_np) if has_scale else None + scale = paddle.to_tensor(scale_np) if has_bias else None + if bias is not None: + bias.stop_gradient = True + if scale is not None: + scale.stop_gradient = True + + y = F.layer_norm(x, x.shape[norm_axis:], scale, bias) + y_np = y.cast('float32').numpy() + paddle.enable_static() + return y_np + + def check_with_fast_math( + self, dtype, shape, norm_axis, has_scale, has_bias + ): + def use_fast_math(enabled): + paddle.set_flags({'FLAGS_use_fast_math': enabled}) + + def __assert_close(x, y): + np.testing.assert_allclose(x, y, rtol=1e-05, atol=1e-04) + + x_np = np.random.random(shape).astype('float32') + bias_np = np.random.random(shape[norm_axis:]).astype('float32') + scale_np = np.random.random(shape[norm_axis:]).astype('float32') + + use_fast_math(False) + y_fast = self.check_layer_norm( + dtype, x_np, scale_np, bias_np, norm_axis, has_scale, has_bias + ) + use_fast_math(True) + y_dev = self.check_layer_norm( + dtype, x_np, scale_np, bias_np, norm_axis, has_scale, has_bias + ) + 
__assert_close(y_fast, y_dev) + + def check_with_dtype(self, dtype): + self.check_with_fast_math( + dtype, + shape=[17, 129], + norm_axis=1, + has_scale=False, + has_bias=True, + ) + self.check_with_fast_math( + dtype, + shape=[8, 512], + norm_axis=1, + has_scale=False, + has_bias=False, + ) + self.check_with_fast_math( + dtype, + shape=[2, 768], + norm_axis=1, + has_scale=False, + has_bias=False, + ) + + def test_main(self): + if not paddle.is_compiled_with_cuda(): + return + self.check_with_dtype(dtype="float32") + self.check_with_dtype(dtype="bfloat16") + + if __name__ == '__main__': paddle.enable_static() unittest.main() -- GitLab
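
For reference, below is a minimal host-side C++ sketch (not part of the patch) of the numerics the kernel relies on: Welford's single-value update and the pairwise merge (Chan et al.), mirroring the two WelfordOnline device overloads above, which the kernel applies per lane and then combines across a warp via __shfl_down_sync before normalizing each row with rsqrt(variance + epsilon). The struct and function names here (Welford, welford_push, welford_merge) are illustrative only and do not appear in the Paddle source.

// welford_sketch.cc -- illustrative only; compile with: g++ -std=c++11 welford_sketch.cc
#include <cstdio>
#include <vector>

struct Welford {
  double mean = 0.0, m2 = 0.0, count = 0.0;
};

// Single-value update: mirrors the first WelfordOnline overload.
void welford_push(Welford *w, double val) {
  w->count += 1.0;
  double delta1 = val - w->mean;
  w->mean += delta1 / w->count;
  double delta2 = val - w->mean;
  w->m2 += delta1 * delta2;
}

// Pairwise merge: mirrors the second WelfordOnline overload, which the
// warp-level reduction applies to partial (mean, m2, count) triples.
void welford_merge(Welford *w, const Welford &b) {
  if (b.count == 0.0) return;
  double new_cnt = w->count + b.count;
  double nb_n = b.count / new_cnt;
  double delta = b.mean - w->mean;
  w->mean += delta * nb_n;
  w->m2 += b.m2 + delta * delta * w->count * nb_n;
  w->count = new_cnt;
}

int main() {
  std::vector<double> row = {1, 2, 3, 4, 5, 6, 7, 8};
  // Split the row across two "lanes", then merge, as the warp reduction does.
  Welford lane0, lane1;
  for (size_t i = 0; i < row.size(); ++i) {
    welford_push(i % 2 == 0 ? &lane0 : &lane1, row[i]);
  }
  welford_merge(&lane0, lane1);
  double variance = lane0.m2 / lane0.count;  // biased row variance, as in the kernel
  std::printf("mean=%f variance=%f\n", lane0.mean, variance);  // 4.5, 5.25
  return 0;
}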