From ce4637c1d9acba3356a3730d258aba121f3d79f8 Mon Sep 17 00:00:00 2001
From: 201716010711 <87008376+201716010711@users.noreply.github.com>
Date: Mon, 30 Jan 2023 21:42:39 -0800
Subject: [PATCH] support fp16 squaredl2norm (#48315)

---
 .../gpu/squared_l2_norm_grad_kernel.cu        | 38 +++++++++++++++++-
 .../phi/kernels/gpu/squared_l2_norm_kernel.cu | 29 ++++++++++++--
 .../tests/unittests/test_gradient_clip.py     |  6 +--
 .../unittests/test_squared_l2_norm_op.py      | 40 +++++++++++++++++++
 python/paddle/nn/clip.py                      |  9 ++---
 5 files changed, 106 insertions(+), 16 deletions(-)
 mode change 100644 => 100755 python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py

diff --git a/paddle/phi/kernels/gpu/squared_l2_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/squared_l2_norm_grad_kernel.cu
index 908a7557d1b..7fc355b51ac 100644
--- a/paddle/phi/kernels/gpu/squared_l2_norm_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/squared_l2_norm_grad_kernel.cu
@@ -15,12 +15,46 @@
 #include "paddle/phi/kernels/squared_l2_norm_grad_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/float16.h"
+#include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/squared_l2_norm_grad_kernel_impl.h"
+#include "paddle/phi/kernels/funcs/broadcast_function.h"
+
+namespace phi {
+/**
+ * x*y*2.0
+ */
+template <typename T>
+struct DoubleMulFunctor {
+  __device__ __forceinline__ T operator()(const T a, const T b) const {
+    return b * a * static_cast<T>(2.0f);
+  }
+};
+
+template <typename T, typename Context>
+void SquaredL2NormGradKernel(const Context& dev_ctx,
+                             const DenseTensor& x,
+                             const DenseTensor& dout,
+                             DenseTensor* dx) {
+  dev_ctx.template Alloc<T>(dx);
+
+  PADDLE_ENFORCE_EQ(
+      dout.numel(),
+      1,
+      phi::errors::InvalidArgument(
+          "Input(GRAD@Out) of SquaredL2NormGradOP should be a scalar."));
+  std::vector<const DenseTensor*> ins{&x, &dout};
+  std::vector<DenseTensor*> outs{dx};
+
+  funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
+      dev_ctx, ins, &outs, -1, phi::DoubleMulFunctor<T>());
+}
+}  // namespace phi
 
 PD_REGISTER_KERNEL(squared_l2_norm_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::SquaredL2NormGradKernel,
                    float,
-                   double) {}
+                   double,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/gpu/squared_l2_norm_kernel.cu b/paddle/phi/kernels/gpu/squared_l2_norm_kernel.cu
index d585d209b42..81108145653 100644
--- a/paddle/phi/kernels/gpu/squared_l2_norm_kernel.cu
+++ b/paddle/phi/kernels/gpu/squared_l2_norm_kernel.cu
@@ -15,9 +15,30 @@
 #include "paddle/phi/kernels/squared_l2_norm_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/float16.h"
+#include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/squared_l2_norm_kernel_impl.h"
-
-PD_REGISTER_KERNEL(
-    squared_l2_norm, GPU, ALL_LAYOUT, phi::SquaredL2NormKernel, float, double) {
+#include "paddle/phi/kernels/funcs/reduce_function.h"
+namespace phi {
+template <typename T, typename Context>
+void SquaredL2NormKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         DenseTensor* out) {
+  dev_ctx.template Alloc<T>(out);
+  std::vector<int> origin_reduce_dims;
+  for (size_t i = 0; i < x.dims().size(); i++) {
+    origin_reduce_dims.push_back(i);
+  }
+  phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::SquareFunctor<T, T>>(
+      dev_ctx, x, out, kps::SquareFunctor<T, T>(), origin_reduce_dims, false);
 }
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(squared_l2_norm,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SquaredL2NormKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
diff --git a/python/paddle/fluid/tests/unittests/test_gradient_clip.py b/python/paddle/fluid/tests/unittests/test_gradient_clip.py
index c74917c2a07..66fe40bf8ab 100644
--- a/python/paddle/fluid/tests/unittests/test_gradient_clip.py
+++ b/python/paddle/fluid/tests/unittests/test_gradient_clip.py
@@ -254,10 +254,8 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
         self.assertListEqual(
             ops,
             [
-                'square',
-                'reduce_sum',
-                'square',
-                'reduce_sum',
+                'squared_l2_norm',
+                'squared_l2_norm',
                 'sum',
                 'cast',
                 'sqrt',
diff --git a/python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py b/python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py
old mode 100644
new mode 100755
index 8124254e7b2..a7076e18a58
--- a/python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py
@@ -30,6 +30,46 @@ def test_squared_l2_norm(x):
     return _legacy_C_ops.squared_l2_norm(x)
 
 
+class TestSquaredL2NormF16Op(unittest.TestCase):
+    def init_test_case(self):
+        X = np.random.uniform(-0.1, 0.1, (8, 5, 10)).astype('float32')
+        return X
+
+    def check_main(self, x_np, dtype):
+        paddle.disable_static()
+        x = paddle.to_tensor(x_np)
+
+        x.stop_gradient = False
+        y = test_squared_l2_norm(x)
+        x_g = paddle.grad(y, [x])
+
+        paddle.enable_static()
+        return y, x_g
+
+    def test_main(self):
+        x_np = self.init_test_case()
+        y_np_1, x_g_np_1 = self.check_main(x_np, 'float32')
+        y_np_2, x_g_np_2 = self.check_main(x_np, 'float16')
+
+        def assert_equal(x, y):
+            np.testing.assert_allclose(x, y, rtol=1e-05, atol=0.0)
+
+        assert_equal(y_np_1, y_np_2)
+        assert_equal(x_g_np_1, x_g_np_2)
+
+
+class TestSquaredL2NormF16Op1(TestSquaredL2NormF16Op):
+    def init_test_case(self):
+        X = np.random.uniform(-2.0, 2.0, (30, 10)).astype('float32')
+        return X
+
+
+class TestSquaredL2NormF16Op2(TestSquaredL2NormF16Op):
+    def init_test_case(self):
+        X = np.random.uniform(-5.0, 5.0, (20, 10, 20)).astype('float32')
+        return X
+
+
 class TestL2LossOp(OpTest):
     """Test squared_l2_norm"""
 
diff --git a/python/paddle/nn/clip.py b/python/paddle/nn/clip.py
index 10eeb631906..53eed3cae58 100644
--- a/python/paddle/nn/clip.py
+++ b/python/paddle/nn/clip.py
@@ -207,11 +207,8 @@ def _squared_l2_norm(x):
     """
 
     x = _cast_to_mp_type_if_enabled(x)
-    if (
-        core.is_compiled_with_xpu()
-        or x.dtype == core.VarDesc.VarType.FP16
-        or x.dtype == core.VarDesc.VarType.BF16
-    ):
+
+    if core.is_compiled_with_xpu():
         square = paddle.square(x)
         sum_square = paddle.sum(square)
         return sum_square
@@ -220,7 +217,7 @@ def _squared_l2_norm(x):
         return _C_ops.squared_l2_norm(x)
 
     op_type = 'squared_l2_norm'
-    check_variable_and_dtype(x, 'x', ['float32', 'float64'], op_type)
+    check_variable_and_dtype(x, 'x', ['float32', 'float64', 'float16'], op_type)
     helper = LayerHelper(op_type, **locals())
     out = helper.create_variable_for_type_inference(x.dtype)
 
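Usage note (not part of the patch): the sketch below shows one way to exercise the new float16 GPU path from dygraph Python, assuming a CUDA build of Paddle with this change applied. It mirrors the added TestSquaredL2NormF16Op test, except that it explicitly casts the input to the dtype under test and compares against the float32 result with a loose tolerance; the run helper and the 1e-3 tolerance are illustrative assumptions, not code from this PR.

# Illustrative sketch only: assumes a CUDA build of Paddle with this patch applied.
import numpy as np
import paddle
from paddle import _legacy_C_ops

x_np = np.random.uniform(-0.1, 0.1, (8, 5, 10)).astype('float32')


def run(dtype):
    # Cast the input to the dtype under test before calling the op.
    x = paddle.to_tensor(x_np, dtype=dtype, stop_gradient=False)
    y = _legacy_C_ops.squared_l2_norm(x)  # forward: sum(x * x) as a scalar
    (x_grad,) = paddle.grad(y, [x])       # backward: 2 * x * grad_out
    return y.astype('float32').numpy(), x_grad.astype('float32').numpy()


y_fp32, g_fp32 = run('float32')
y_fp16, g_fp16 = run('float16')
# fp16 rounds more aggressively, so compare with a loose relative tolerance.
np.testing.assert_allclose(y_fp16, y_fp32, rtol=1e-3)
np.testing.assert_allclose(g_fp16, g_fp32, rtol=1e-3)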