From 12547fb456bca29e2f62c89cb7ca396f79cb8aef Mon Sep 17 00:00:00 2001 From: ZZK <359521840@qq.com> Date: Wed, 16 Aug 2023 20:23:05 +0800 Subject: [PATCH] Refine FusedNorm comment (#56305) * refine static op return val --- paddle/phi/api/yaml/fused_ops.yaml | 11 + paddle/phi/api/yaml/ops.yaml | 10 - paddle/phi/infermeta/multiary.cc | 20 +- .../fusion/gpu/fused_layernorm_kernel.cu | 1 - .../fusion/gpu/fused_layernorm_kernel.h | 43 ---- .../nn/functional/fused_layer_norm.py | 4 +- .../incubate/nn/functional/fused_rms_norm.py | 2 +- test/legacy_test/test_fused_layernorm_op.py | 200 ++++++++++++++++-- test/legacy_test/test_rms_norm_op.py | 1 - 9 files changed, 208 insertions(+), 84 deletions(-) delete mode 100644 paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.h diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 46f257f6511..a45759b4000 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -102,6 +102,17 @@ optional : bias, dequant_scales, shift, smooth support_dygraph_mode : true +- op : fused_bias_residual_layernorm + args : (Tensor x, Tensor bias, Tensor residual, Tensor norm_weight, Tensor norm_bias, float epsilon, float residual_alpha, int begin_norm_axis, float quant_scale, int quant_round_type, float quant_max_bound, float quant_min_bound) + output : Tensor(out), Tensor(residual_out), Tensor(mean), Tensor(variance) + infer_meta : + func : FusedLayerNormInferMeta + kernel : + func : fused_bias_residual_layernorm + data_type : x + optional : bias, residual, norm_weight, norm_bias, residual_out + support_dygraph_mode : true + - op : fused_dropout_add args : (Tensor x, Tensor y, Tensor seed_tensor, Scalar p, bool is_test, str mode, int seed = 0, bool fix_seed = false) optional : seed_tensor diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 0241f2c24df..3937ccc21a3 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1017,16 +1017,6 @@ data_type : dtype backend : place -- op : fused_bias_residual_layernorm - args : (Tensor x, Tensor bias, Tensor residual, Tensor norm_weight, Tensor norm_bias, float epsilon, float residual_alpha, int begin_norm_axis, float quant_scale, int quant_round_type, float quant_max_bound, float quant_min_bound) - output : Tensor(out), Tensor(residual_out), Tensor(mean), Tensor(variance) - infer_meta : - func : FusedLayerNormInferMeta - kernel : - func : fused_bias_residual_layernorm - data_type : x - optional : bias, residual, norm_weight, norm_bias, residual_out - - op : gather args : (Tensor x, Tensor index, Scalar axis=0) output : Tensor(out) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 481e788a0ed..06232c06907 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1535,15 +1535,17 @@ void FusedLayerNormInferMeta(const MetaTensor& x, rows *= x.dims()[i]; } - PADDLE_ENFORCE_EQ(normalized_dims, - norm_weight.dims()[0], - phi::errors::InvalidArgument( - "The normalized size of Input(X) must equal to be" - "the size of Weight, but received" - "normalized size of Input(X) is [%d], received size" - "of Weight is [%d]", - normalized_dims, - norm_weight.dims()[0])); + if (norm_weight) { + PADDLE_ENFORCE_EQ(normalized_dims, + norm_weight.dims()[0], + phi::errors::InvalidArgument( + "The normalized size of Input(X) must equal to be" + "the size of Weight, but received" + "normalized size of Input(X) is [%d], received size" + "of Weight is [%d]", + normalized_dims, + norm_weight.dims()[0])); + } auto out_dims = phi::make_ddim(x_dims_vec); diff --git a/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu index 138e5583a3a..61e88906648 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu @@ -34,7 +34,6 @@ limitations under the License. // The following code modified from OneFlow's implementation, and change to use // single Pass algorithm. Support Int8 quant, dequant Load/Store implementation. -#include "paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.h" #include #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/amp_type_traits.h" diff --git a/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.h b/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.h deleted file mode 100644 index ad6531b5403..00000000000 --- a/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.h +++ /dev/null @@ -1,43 +0,0 @@ - -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/phi/core/dense_tensor.h" - -namespace phi { -namespace fusion { - -template -void FusedLayerNormKernel(const Context& dev_ctx, - const DenseTensor& x, - const paddle::optional& bias, - const paddle::optional& residual, - const paddle::optional& norm_weight, - const paddle::optional& norm_bias, - const float epsilon, - const float residual_alpha, - const int begin_norm_axis, - const float quant_scale, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - DenseTensor* out, - DenseTensor* residual_out, - DenseTensor* mean, - DenseTensor* variance); - -} // namespace fusion -} // namespace phi diff --git a/python/paddle/incubate/nn/functional/fused_layer_norm.py b/python/paddle/incubate/nn/functional/fused_layer_norm.py index d158e8b9191..a315ee2ee04 100644 --- a/python/paddle/incubate/nn/functional/fused_layer_norm.py +++ b/python/paddle/incubate/nn/functional/fused_layer_norm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -122,4 +122,4 @@ def fused_layer_norm( }, outputs=outputs_dict, ) - return out + return (out, residual_out) if residual is not None else out diff --git a/python/paddle/incubate/nn/functional/fused_rms_norm.py b/python/paddle/incubate/nn/functional/fused_rms_norm.py index 1fdf4a31431..54c6e1dfba0 100644 --- a/python/paddle/incubate/nn/functional/fused_rms_norm.py +++ b/python/paddle/incubate/nn/functional/fused_rms_norm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/test/legacy_test/test_fused_layernorm_op.py b/test/legacy_test/test_fused_layernorm_op.py index e3fc12a7c3b..a50f216c673 100644 --- a/test/legacy_test/test_fused_layernorm_op.py +++ b/test/legacy_test/test_fused_layernorm_op.py @@ -34,6 +34,10 @@ def quant_helper( ) +def naive_residual_bias_add(x, residual, bias, residual_alpha): + return x + residual_alpha * residual + bias + + def naive_layer_norm(x, gamma, beta, epsilon): x_float = paddle.cast(x, dtype=paddle.float32) mean = paddle.mean(x_float, axis=-1, keepdim=True) @@ -66,13 +70,14 @@ def naive_residual_biasadd_layer_norm( x, residual, bias, gamma, beta, epsilon, residual_alpha ): x = x + residual * residual_alpha + bias + residual_out = x mean = paddle.mean(x, axis=-1, keepdim=True) var = paddle.var(x, axis=-1, keepdim=True) sqrt_var = paddle.rsqrt(var + epsilon) out = ((x - mean) * sqrt_var) * paddle.cast(gamma, x.dtype) + paddle.cast( beta, x.dtype ) - return out + return out, residual_out def naive_residual_biasadd_layer_norm_int8( @@ -88,13 +93,13 @@ def naive_residual_biasadd_layer_norm_int8( quant_max_bound, quant_min_bound, ): - out = naive_residual_biasadd_layer_norm( + out, residual_out = naive_residual_biasadd_layer_norm( x, residual, bias, gamma, beta, epsilon, residual_alpha ) out = quant_helper( out, in_scale, quant_round_type, quant_max_bound, quant_min_bound ) - return out + return out, residual_out @unittest.skipIf( @@ -165,6 +170,29 @@ class TestlayernormOp(unittest.TestCase): paddle.enable_static() return paddle_layernorm_out, paddle_naive_layernorm_out + def check_residual_bias_add(self, x_np, residual_np, bias_np, dtype): + paddle.disable_static() + x = paddle.to_tensor(x_np.astype(dtype)) + residual = paddle.to_tensor(residual_np.astype(dtype)) + bias = paddle.to_tensor(bias_np.astype(dtype)) + + paddle_layernorm_out = paddle.incubate.nn.functional.fused_layer_norm( + x, + None, + None, + self.epsilon, + begin_norm_axis=1, + bias=bias, + residual=residual, + residual_alpha=self.residual_alpha, + ) + + paddle_naive_residual_out = naive_residual_bias_add( + x, residual, bias, self.residual_alpha + ) + paddle.enable_static() + return (paddle_layernorm_out, paddle_naive_residual_out) + def check_residual_bias_layernorm( self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype ): @@ -186,11 +214,18 @@ class TestlayernormOp(unittest.TestCase): residual_alpha=self.residual_alpha, ) - paddle_naive_layernorm_out = naive_residual_biasadd_layer_norm( + ( + paddle_naive_layernorm_out, + paddle_naive_residual_out, + ) = naive_residual_biasadd_layer_norm( x, residual, bias, gamma, beta, self.epsilon, self.residual_alpha ) paddle.enable_static() - return paddle_layernorm_out, paddle_naive_layernorm_out + return ( + paddle_layernorm_out, + paddle_naive_layernorm_out, + paddle_naive_residual_out, + ) def check_residual_bias_layernorm_int8( self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype @@ -217,7 +252,10 @@ class TestlayernormOp(unittest.TestCase): quant_min_bound=self.quant_min_bound, ) - paddle_naive_layernorm_out = naive_residual_biasadd_layer_norm_int8( + ( + paddle_naive_layernorm_out, + paddle_naive_residual_out, + ) = naive_residual_biasadd_layer_norm_int8( x, residual, bias, @@ -231,7 +269,28 @@ class TestlayernormOp(unittest.TestCase): self.quant_min_bound, ) paddle.enable_static() - return paddle_layernorm_out, paddle_naive_layernorm_out + return ( + paddle_layernorm_out, + paddle_naive_layernorm_out, + paddle_naive_residual_out, + ) + + def test_residual_bias_add(self): + if not paddle.is_compiled_with_cuda(): + return + ( + paddle_residual_bias_out, + paddle_naive_residual_bias_out, + ) = self.check_residual_bias_add( + self.x_np, self.residual_np, self.bias_np, 'float16' + ) + + np.testing.assert_allclose( + paddle_residual_bias_out[0].numpy(), + paddle_naive_residual_bias_out.numpy(), + rtol=1e-3, + atol=1e-3, + ) def test_layernorm_fp16(self): if not paddle.is_compiled_with_cuda(): @@ -266,6 +325,7 @@ class TestlayernormOp(unittest.TestCase): ( paddle_layernorm, paddle_naive_layernorm, + paddle_naive_residual_out, ) = self.check_residual_bias_layernorm( self.x_np, self.norm_weight_np, @@ -282,12 +342,20 @@ class TestlayernormOp(unittest.TestCase): atol=1e-3, ) + np.testing.assert_allclose( + paddle_layernorm[1].numpy(), + paddle_naive_residual_out.numpy(), + rtol=1e-3, + atol=1e-3, + ) + def test_residual_bias_add_layernorm_int8(self): if not paddle.is_compiled_with_cuda(): return ( paddle_layernorm, paddle_naive_layernorm, + paddle_naive_residual_out, ) = self.check_residual_bias_layernorm_int8( self.x_np, self.norm_weight_np, @@ -304,6 +372,13 @@ class TestlayernormOp(unittest.TestCase): atol=2, ) + np.testing.assert_allclose( + paddle_layernorm[1].numpy(), + paddle_naive_residual_out.numpy(), + rtol=1e-3, + atol=1e-3, + ) + @unittest.skipIf( not core.is_compiled_with_cuda(), "core is not compiled with CUDA " @@ -367,7 +442,7 @@ class TestlayernormStaticOp(unittest.TestCase): }, fetch_list=[outs], ) - return out_s[0], paddle_naive_layernorm_out + return out_s, paddle_naive_layernorm_out def check_layernorm_int8(self, x_np, gamma_np, beta_np, dtype): paddle.disable_static() @@ -417,7 +492,56 @@ class TestlayernormStaticOp(unittest.TestCase): }, fetch_list=[outs], ) - return out_s[0], paddle_naive_layernorm_out + return out_s, paddle_naive_layernorm_out + + def check_residual_bias_add(self, x_np, residual_np, bias_np, dtype): + paddle.disable_static() + x = paddle.to_tensor(x_np.astype(dtype)) + residual = paddle.to_tensor(residual_np.astype(dtype)) + bias = paddle.to_tensor(bias_np.astype(dtype)) + + paddle_naive_residual_out = naive_residual_bias_add( + x, residual, bias, self.residual_alpha + ) + paddle.enable_static() + + with paddle.static.program_guard(paddle.static.Program()): + x_static = paddle.static.data( + name="x_static", shape=[self.batch, self.cols], dtype=dtype + ) + residual_static = paddle.static.data( + name="residual_static", + shape=[self.batch, self.cols], + dtype=dtype, + ) + bias_static = paddle.static.data( + name="bias_static", shape=[self.cols], dtype=dtype + ) + outs = paddle.incubate.nn.functional.fused_layer_norm( + x_static, + None, + None, + self.epsilon, + begin_norm_axis=1, + bias=bias_static, + residual=residual_static, + residual_alpha=self.residual_alpha, + quant_scale=self.quant_scale, + quant_round_type=self.quant_round_type, + quant_max_bound=self.quant_max_bound, + quant_min_bound=self.quant_min_bound, + ) + + exe = fluid.Executor(self.place) + out_s = exe.run( + feed={ + "x_static": x_np.astype(dtype), + "residual_static": residual_np.astype(dtype), + "bias_static": bias_np.astype(dtype), + }, + fetch_list=[outs], + ) + return out_s, paddle_naive_residual_out def check_residual_bias_layernorm( self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype @@ -429,7 +553,10 @@ class TestlayernormStaticOp(unittest.TestCase): residual = paddle.to_tensor(residual_np.astype(dtype)) bias = paddle.to_tensor(bias_np.astype(dtype)) - paddle_naive_layernorm_out = naive_residual_biasadd_layer_norm( + ( + paddle_naive_layernorm_out, + paddle_naive_residual_out, + ) = naive_residual_biasadd_layer_norm( x, residual, bias, gamma, beta, self.epsilon, self.residual_alpha ) paddle.enable_static() @@ -474,7 +601,7 @@ class TestlayernormStaticOp(unittest.TestCase): }, fetch_list=[outs], ) - return out_s[0], paddle_naive_layernorm_out + return out_s, paddle_naive_layernorm_out, paddle_naive_residual_out def check_residual_bias_layernorm_int8( self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype @@ -486,7 +613,10 @@ class TestlayernormStaticOp(unittest.TestCase): residual = paddle.to_tensor(residual_np.astype(dtype)) bias = paddle.to_tensor(bias_np.astype(dtype)) - paddle_naive_layernorm_out = naive_residual_biasadd_layer_norm_int8( + ( + paddle_naive_layernorm_out, + paddle_naive_residual_out, + ) = naive_residual_biasadd_layer_norm_int8( x, residual, bias, @@ -545,7 +675,7 @@ class TestlayernormStaticOp(unittest.TestCase): }, fetch_list=[outs], ) - return out_s[0], paddle_naive_layernorm_out + return out_s, paddle_naive_layernorm_out, paddle_naive_residual_out def test_layernorm_fp16(self): if not paddle.is_compiled_with_cuda(): @@ -555,7 +685,7 @@ class TestlayernormStaticOp(unittest.TestCase): ) np.testing.assert_allclose( - paddle_layernorm, + paddle_layernorm[0], paddle_naive_layernorm.numpy(), rtol=1e-3, atol=1e-3, @@ -568,18 +698,39 @@ class TestlayernormStaticOp(unittest.TestCase): self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16' ) np.testing.assert_allclose( - paddle_layernorm, + paddle_layernorm[0], paddle_naive_layernorm.numpy(), rtol=2, atol=2, ) + def test_residual_bias_add(self): + if not paddle.is_compiled_with_cuda(): + return + ( + paddle_layernorm, + paddle_naive_residual_out, + ) = self.check_residual_bias_add( + self.x_np, + self.residual_np, + self.bias_np, + 'float16', + ) + + np.testing.assert_allclose( + paddle_layernorm[0], + paddle_naive_residual_out.numpy(), + rtol=1e-3, + atol=1e-3, + ) + def test_residual_bias_add_layernorm_fp16(self): if not paddle.is_compiled_with_cuda(): return ( paddle_layernorm, paddle_naive_layernorm, + paddle_naive_residual_out, ) = self.check_residual_bias_layernorm( self.x_np, self.norm_weight_np, @@ -590,18 +741,26 @@ class TestlayernormStaticOp(unittest.TestCase): ) np.testing.assert_allclose( - paddle_layernorm, + paddle_layernorm[0], paddle_naive_layernorm.numpy(), rtol=1e-3, atol=1e-3, ) + np.testing.assert_allclose( + paddle_layernorm[1], + paddle_naive_residual_out.numpy(), + rtol=1e-3, + atol=1e-3, + ) + def test_residual_bias_add_layernorm_int8(self): if not paddle.is_compiled_with_cuda(): return ( paddle_layernorm, paddle_naive_layernorm, + paddle_naive_residual_out, ) = self.check_residual_bias_layernorm_int8( self.x_np, self.norm_weight_np, @@ -612,12 +771,19 @@ class TestlayernormStaticOp(unittest.TestCase): ) np.testing.assert_allclose( - paddle_layernorm, + paddle_layernorm[0], paddle_naive_layernorm.numpy(), rtol=2, atol=2, ) + np.testing.assert_allclose( + paddle_layernorm[1], + paddle_naive_residual_out.numpy(), + rtol=2, + atol=2, + ) + if __name__ == "__main__": unittest.main() diff --git a/test/legacy_test/test_rms_norm_op.py b/test/legacy_test/test_rms_norm_op.py index 465bc8672b7..d5b6530ed51 100644 --- a/test/legacy_test/test_rms_norm_op.py +++ b/test/legacy_test/test_rms_norm_op.py @@ -487,7 +487,6 @@ class TestRMSNormStaticOp(unittest.TestCase): paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm_int8( self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16' ) - print("1111") np.testing.assert_allclose( paddle_rmsnorm, paddle_naive_rmsnorm.numpy(), -- GitLab