From ce8ed978cbfce2e0fa503690d31d2e3244066b31 Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Tue, 1 Mar 2022 16:11:28 +0800 Subject: [PATCH] [bf16] add bf16 kernel: layer_norm p_norm reduce_sum (#39843) * add layer norm * add p norm * add reduce sum * refine layer norm register bf16 for cudnn811 * add bf16 cast for hip * add unittest * refine rocm * refine layer_norm unittest * refine reduce op * refine unittest * enhance atol for reduce unittest --- paddle/fluid/operators/cast_op.cu | 4 - paddle/fluid/operators/layer_norm_kernel.cu.h | 6 +- paddle/fluid/operators/layer_norm_op.cu | 15 ++++ paddle/fluid/operators/p_norm_op.cu | 12 +++ .../reduce_ops/reduce_sum_op.part.cu | 1 + paddle/phi/kernels/gpu/cast_kernel.cu | 4 - paddle/phi/kernels/gpu/math_kernel.cu | 1 + paddle/phi/kernels/math_kernel.cc | 1 + .../paddle/fluid/tests/unittests/op_test.py | 2 +- .../tests/unittests/test_layer_norm_op.py | 47 ++++++++++++ .../fluid/tests/unittests/test_norm_all.py | 76 ++++++++++++++++++- .../fluid/tests/unittests/test_reduce_op.py | 33 +++++++- 12 files changed, 188 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/operators/cast_op.cu b/paddle/fluid/operators/cast_op.cu index 5c7dd0e2561..eb51215790b 100644 --- a/paddle/fluid/operators/cast_op.cu +++ b/paddle/fluid/operators/cast_op.cu @@ -29,9 +29,5 @@ using CUDA = paddle::platform::CUDADeviceContext; ops::CastOpKernel>, \ ops::CastOpKernel>, ##__VA_ARGS__); -#if !defined(PADDLE_WITH_HIP) // See [ why register transfer_dtype_op alias with cast_op? ] in cast_op.cc REGISTER_CAST_CUDA_BASE(transfer_dtype, ops::CastOpKernel) -#else -REGISTER_CAST_CUDA_BASE(transfer_dtype) -#endif diff --git a/paddle/fluid/operators/layer_norm_kernel.cu.h b/paddle/fluid/operators/layer_norm_kernel.cu.h index b31c7a1cde0..62c21dd2eee 100644 --- a/paddle/fluid/operators/layer_norm_kernel.cu.h +++ b/paddle/fluid/operators/layer_norm_kernel.cu.h @@ -474,11 +474,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel( for (int it = 0; it < LDGS; it++) { #pragma unroll for (int jt = 0; jt < VecSize; jt++) { - U x_tmp = x[it][jt]; + U x_tmp = static_cast(x[it][jt]); U y_tmp = var_cur_row * (x_tmp - mean_cur_row); U dy_tmp = static_cast(gamma[it][jt]) * - static_cast(dout[it][jt]); // scale * dy - U dout_tmp = dout[it][jt]; // dy + static_cast(dout[it][jt]); // scale * dy + U dout_tmp = static_cast(dout[it][jt]); // dy // used for get dx (row reduction) sum_loss1 += dy_tmp; // scale * dy, sum_1 diff --git a/paddle/fluid/operators/layer_norm_op.cu b/paddle/fluid/operators/layer_norm_op.cu index d439b3220d9..dfe73d37271 100644 --- a/paddle/fluid/operators/layer_norm_op.cu +++ b/paddle/fluid/operators/layer_norm_op.cu @@ -259,6 +259,21 @@ REGISTER_OP_CUDA_KERNEL( ops::LayerNormGradKernel, ops::LayerNormGradKernel); +#elif CUDNN_VERSION_MIN(8, 1, 0) +REGISTER_OP_CUDA_KERNEL( + layer_norm, + ops::LayerNormKernel, + ops::LayerNormKernel, + ops::LayerNormKernel, + ops::LayerNormKernel); +REGISTER_OP_CUDA_KERNEL( + layer_norm_grad, + ops::LayerNormGradKernel, + ops::LayerNormGradKernel, + ops::LayerNormGradKernel, + ops::LayerNormGradKernel); #else REGISTER_OP_CUDA_KERNEL( layer_norm, diff --git a/paddle/fluid/operators/p_norm_op.cu b/paddle/fluid/operators/p_norm_op.cu index f2cb427a0a5..d0b78b9b064 100644 --- a/paddle/fluid/operators/p_norm_op.cu +++ b/paddle/fluid/operators/p_norm_op.cu @@ -39,6 +39,11 @@ __device__ __forceinline__ int sgn(T val) { __device__ __forceinline__ platform::float16 inline_abs(platform::float16 x) { return static_cast(abs(static_cast(x))); } + +__device__ __forceinline__ platform::bfloat16 inline_abs(platform::bfloat16 x) { + return static_cast(abs(static_cast(x))); +} + __device__ __forceinline__ float inline_abs(float x) { return abs(x); } __device__ __forceinline__ double inline_abs(double x) { return abs(x); } @@ -53,6 +58,11 @@ __device__ __forceinline__ platform::float16 inline_pow( return static_cast( pow(static_cast(base), static_cast(exponent))); } +__device__ __forceinline__ platform::bfloat16 inline_pow( + platform::bfloat16 base, platform::bfloat16 exponent) { + return static_cast( + pow(static_cast(base), static_cast(exponent))); +} __device__ __forceinline__ float inline_pow(float base, float exponent) { return pow(base, exponent); } @@ -202,9 +212,11 @@ using CUDA = paddle::platform::CUDADeviceContext; REGISTER_OP_CUDA_KERNEL(p_norm, ops::PnormCUDAKernel, + ops::PnormCUDAKernel, ops::PnormCUDAKernel, ops::PnormCUDAKernel); REGISTER_OP_CUDA_KERNEL( p_norm_grad, ops::PnormGradCUDAKernel, + ops::PnormGradCUDAKernel, ops::PnormGradCUDAKernel, ops::PnormGradCUDAKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu index c3d3e0cf6ec..2f6bf127518 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu @@ -23,6 +23,7 @@ REGISTER_OP_CUDA_KERNEL( reduce_sum_grad, CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, + CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, CUDAReduceSumGradKernel>, CUDAReduceSumGradKernel>); diff --git a/paddle/phi/kernels/gpu/cast_kernel.cu b/paddle/phi/kernels/gpu/cast_kernel.cu index 7a6c99c5fe1..569a46f56d5 100644 --- a/paddle/phi/kernels/gpu/cast_kernel.cu +++ b/paddle/phi/kernels/gpu/cast_kernel.cu @@ -80,8 +80,4 @@ void CastKernel(const Context& dev_ctx, paddle::experimental::DataType::UNDEFINED); \ } -#if !defined(PADDLE_WITH_HIP) PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, phi::dtype::bfloat16) -#else -PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast) -#endif diff --git a/paddle/phi/kernels/gpu/math_kernel.cu b/paddle/phi/kernels/gpu/math_kernel.cu index 56e8b16ccbe..fc73ccca6de 100644 --- a/paddle/phi/kernels/gpu/math_kernel.cu +++ b/paddle/phi/kernels/gpu/math_kernel.cu @@ -155,6 +155,7 @@ PD_REGISTER_KERNEL(sum_raw, float, double, float16, + bfloat16, int16_t, int, int64_t, diff --git a/paddle/phi/kernels/math_kernel.cc b/paddle/phi/kernels/math_kernel.cc index 3cb7b66ddf7..480eb56c8b0 100644 --- a/paddle/phi/kernels/math_kernel.cc +++ b/paddle/phi/kernels/math_kernel.cc @@ -165,6 +165,7 @@ PD_REGISTER_KERNEL(sum, float, double, phi::dtype::float16, + phi::dtype::bfloat16, int16_t, int, int64_t, diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 5694ef25c79..628791afef5 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -1140,7 +1140,7 @@ class OpTest(unittest.TestCase): else: atol = 2 else: - atol = 1e-2 + atol = 1e-1 if no_check_set is not None: if self.op_type not in no_check_set_white_list.no_check_set_white_list: diff --git a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py index 7dd310d2b88..ca9a489c749 100644 --- a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py @@ -375,6 +375,53 @@ class TestFP16ScaleBiasLayerNorm(unittest.TestCase): assert_equal(b_g_np_1, b_g_np_2) +class TestBF16ScaleBiasLayerNorm(unittest.TestCase): + def check_main(self, x_np, weight_np, bias_np, dtype): + paddle.disable_static() + + x = paddle.to_tensor(x_np) + weight = paddle.to_tensor(weight_np) + bias = paddle.to_tensor(bias_np) + + if dtype == "bfloat16": + x = x.cast(paddle.fluid.core.VarDesc.VarType.BF16) + + x.stop_gradient = False + weight.stop_gradient = False + bias.stop_gradient = False + + y = F.layer_norm(x, x.shape[1:], weight, bias) + x_g, w_g, b_g = paddle.grad(y, [x, weight, bias]) + + y_np = y.cast('float32').numpy() + x_g_np = x_g.cast('float32').numpy() + w_g_np = w_g.cast('float32').numpy() + b_g_np = b_g.cast('float32').numpy() + + paddle.enable_static() + return y_np, x_g_np, w_g_np, b_g_np + + def test_main(self): + if (not core.is_compiled_with_cuda()) or (core.cudnn_version() < 8100): + return + x_np = np.random.random([10, 20]).astype('float32') + weight_np = np.random.random([20]).astype('float32') + bias_np = np.random.random([20]).astype('float32') + + y_np_1, x_g_np_1, w_g_np_1, b_g_np_1 = self.check_main( + x_np, weight_np, bias_np, 'float32') + y_np_2, x_g_np_2, w_g_np_2, b_g_np_2 = self.check_main( + x_np, weight_np, bias_np, 'bfloat16') + + def assert_equal(x, y): + self.assertTrue(np.allclose(x, y, atol=1.e-1)) + + assert_equal(y_np_1, y_np_2) + assert_equal(x_g_np_1, x_g_np_2) + assert_equal(w_g_np_1, w_g_np_2) + assert_equal(b_g_np_1, b_g_np_2) + + class TestGetSetKeepLayerNormScaleBiasFP32Flag(unittest.TestCase): def test_main(self): self.assertTrue(_keep_layer_norm_scale_bias_to_fp32()) diff --git a/python/paddle/fluid/tests/unittests/test_norm_all.py b/python/paddle/fluid/tests/unittests/test_norm_all.py index b20305b78ef..575bc653618 100644 --- a/python/paddle/fluid/tests/unittests/test_norm_all.py +++ b/python/paddle/fluid/tests/unittests/test_norm_all.py @@ -16,7 +16,7 @@ from __future__ import print_function import unittest import numpy as np -from op_test import OpTest +from op_test import OpTest, convert_float_to_uint16 import paddle import paddle.fluid as fluid import paddle.fluid.core as core @@ -282,6 +282,80 @@ class TestPnormOpFP161(TestPnormOpFP16): self.asvector = True +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestPnormBF16Op(OpTest): + def setUp(self): + self.op_type = "p_norm" + self.init_test_case() + self.x = (np.random.random(self.shape) + 0.5).astype(np.float32) + self.norm = p_norm(self.x, self.axis, self.porder, self.keepdim, + self.asvector) + self.gradient = self.calc_gradient() + self.inputs = {'X': convert_float_to_uint16(self.x)} + self.attrs = { + 'epsilon': self.epsilon, + 'axis': self.axis, + 'keepdim': self.keepdim, + 'porder': float(self.porder), + 'asvector': self.asvector + } + self.outputs = {'Out': convert_float_to_uint16(self.norm)} + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place, atol=1e-3) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, ['X'], 'Out', user_defined_grads=self.gradient) + + def init_test_case(self): + self.shape = [2, 3, 4, 5] + self.axis = 1 + self.epsilon = 1e-12 + self.porder = 2.0 + self.keepdim = False + self.dtype = np.uint16 + self.asvector = False + + def calc_gradient(self): + self.attrs = { + 'epsilon': self.epsilon, + 'axis': self.axis, + 'keepdim': self.keepdim, + 'porder': float(self.porder), + 'asvector': self.asvector + } + x = self.x + porder = self.attrs["porder"] + axis = self.attrs["axis"] + asvector = self.attrs["asvector"] + x_dtype = x.dtype + x = x.astype(np.float32) if x.dtype == np.float16 else x + if porder == 0: + grad = np.zeros(x.shape).astype(x.dtype) + elif porder in [float("inf"), float("-inf")]: + norm = p_norm( + x, axis=axis, porder=porder, keepdims=True, reduce_all=asvector) + x_abs = np.abs(x) + grad = np.sign(x) + grad[x_abs != norm] = 0.0 + else: + norm = p_norm( + x, axis=axis, porder=porder, keepdims=True, reduce_all=asvector) + grad = np.power(norm, 1 - porder) * np.power( + np.abs(x), porder - 1) * np.sign(x) + + numel = 1 + for s in x.shape: + numel *= s + divisor = numel if asvector else x.shape[axis] + numel /= divisor + return [grad.astype(x_dtype) * 1 / numel] + + def run_fro(self, p, axis, shape_x, dtype, keep_dim, check_dim=False): with fluid.program_guard(fluid.Program()): data = fluid.data(name="X", shape=shape_x, dtype=dtype) diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index faa67e1d6da..d246356b4ec 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -16,7 +16,7 @@ from __future__ import print_function import unittest import numpy as np -from op_test import OpTest, skip_check_grad_ci +from op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16 import paddle import paddle.fluid.core as core import paddle.fluid as fluid @@ -61,6 +61,37 @@ class TestSumOp_fp16(OpTest): self.check_grad(['X'], 'Out', user_defined_grads=self.gradient) +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSumOp_bf16(OpTest): + def setUp(self): + np.random.seed(100) + self.op_type = "reduce_sum" + self.dtype = np.uint16 + self.x = np.random.uniform(0, 0.1, (2, 5, 10)).astype(np.float32) + self.attrs = {'dim': [0, 1, 2]} + self.out = self.x.sum(axis=tuple(self.attrs['dim'])) + self.gradient = self.calc_gradient() + + self.inputs = {'X': convert_float_to_uint16(self.x)} + self.outputs = {'Out': convert_float_to_uint16(self.out)} + self.gradient = self.calc_gradient() + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, ['X'], 'Out', user_defined_grads=self.gradient) + + def calc_gradient(self): + x = self.x + grad = np.ones(x.shape, dtype=x.dtype) + return [grad] + + class TestSumOp_fp16_withInt(OpTest): def setUp(self): self.op_type = "reduce_sum" -- GitLab