diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc
index c692b6c8fcd13e5254c1d67ca1b8143102ea28be..ae8c7dd61c3bbfca86fd3acea613527474d7901c 100644
--- a/paddle/phi/infermeta/ternary.cc
+++ b/paddle/phi/infermeta/ternary.cc
@@ -259,6 +259,103 @@ void GraphSendRecvInferMeta(const MetaTensor& x,
   }
 }
 
+void LayerNormInferMeta(const MetaTensor& x,
+                        paddle::optional<const MetaTensor&> scale,
+                        paddle::optional<const MetaTensor&> bias,
+                        float epsilon,
+                        int begin_norm_axis,
+                        bool is_test,
+                        MetaTensor* out,
+                        MetaTensor* mean,
+                        MetaTensor* variance,
+                        MetaConfig config) {
+  auto x_dim = x.dims();
+  PADDLE_ENFORCE_LT(
+      begin_norm_axis,
+      x_dim.size(),
+      phi::errors::InvalidArgument(
+          "'begin_norm_axis' must be less than the dimensions of X,"
+          "But received 'begin_norm_axis' is [%d],"
+          "received the dimensions of X is [%d].",
+          begin_norm_axis,
+          x_dim.size()));
+
+  auto matrix_dim = phi::flatten_to_2d(x_dim, begin_norm_axis);
+  int left = static_cast<int>(matrix_dim[0]);
+  int right = static_cast<int>(matrix_dim[1]);
+  if (scale.get_ptr() != nullptr) {
+    PADDLE_ENFORCE_EQ(scale->dims().size(),
+                      1,
+                      phi::errors::InvalidArgument(
+                          "The dimensions of Input(Scale) must be 1, but "
+                          "received dimensions of"
+                          "Input(Scale) is [%d]",
+                          scale->dims().size()));
+  }
+
+  if (config.is_runtime && scale.get_ptr() != nullptr) {
+    PADDLE_ENFORCE_EQ(
+        scale->dims()[0],
+        right,
+        phi::errors::InvalidArgument(
+            "The first dimension value of Input(Scale) must equal to be the"
+            "second dimension value of the flattened 2D matrix of Input(X),"
+            "But received the first dimension value of Input(Scale) is"
+            "[%d], the second dimension value of the flattened 2D matrix of"
+            " Input(Scale) is [%d].",
+            scale->dims()[0],
+            right));
+  }
+  if (bias.get_ptr() != nullptr) {
+    PADDLE_ENFORCE_EQ(bias->dims().size(),
+                      1,
+                      phi::errors::InvalidArgument(
+                          "The dimensions of Input(Bias) must be 1, but "
+                          "received dimensions of"
+                          "Input(Bias) is [%d]",
+                          bias->dims().size()));
+  }
+  if (config.is_runtime && bias.get_ptr() != nullptr) {
+    PADDLE_ENFORCE_EQ(
+        bias->dims()[0],
+        right,
+        phi::errors::InvalidArgument(
+            "The first dimension value of Input(Bias) must equal to be the"
+            "second dimension value of the flattened 2D matrix of Input(X),"
+            "But received the first dimension value of Input(Bias) is"
+            "[%d], the second dimension value of the flattened 2D matrix of"
+            " Input(Bias) is [%d].",
+            bias->dims()[0],
+            right));
+  }
+
+  out->set_dims(x_dim);
+  if (mean) {
+    mean->set_dims({left});
+  }
+  if (variance) {
+    variance->set_dims({left});
+  }
+  out->share_lod(x);
+}
+
+void LayerNormGradInferMeta(const MetaTensor& x,
+                            paddle::optional<const MetaTensor&> y,
+                            paddle::optional<const MetaTensor&> z,
+                            MetaTensor* dx,
+                            MetaTensor* dy,
+                            MetaTensor* dz) {
+  if (dx) {
+    dx->share_meta(x);
+  }
+  if (dy && (y.get_ptr() != nullptr)) {
+    dy->share_meta(*y.get_ptr());
+  }
+  if (dz && (z.get_ptr() != nullptr)) {
+    dz->share_meta(*z.get_ptr());
+  }
+}
+
 void LerpInferMeta(const MetaTensor& x,
                    const MetaTensor& y,
                    const MetaTensor& weight,
diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h
index 83505f2c2fadaef07ae8249edf477fb5295cb907..4f561e0adf19d9443e1404b9858be7a8caa6ae9f 100644
--- a/paddle/phi/infermeta/ternary.h
+++ b/paddle/phi/infermeta/ternary.h
@@ -60,6 +60,24 @@ void GraphSendRecvInferMeta(const MetaTensor& x,
                             MetaTensor* out,
                             MetaTensor* dst_count);
 
+void LayerNormInferMeta(const MetaTensor& x,
+                        paddle::optional<const MetaTensor&> scale,
+                        paddle::optional<const MetaTensor&> bias,
+                        float epsilon,
+                        int begin_norm_axis,
+                        bool is_test,
+                        MetaTensor* out,
+                        MetaTensor* mean,
+                        MetaTensor* variance,
+                        MetaConfig config = MetaConfig());
+
+void LayerNormGradInferMeta(const MetaTensor& x,
+                            paddle::optional<const MetaTensor&> y,
+                            paddle::optional<const MetaTensor&> z,
+                            MetaTensor* dx,
+                            MetaTensor* dy,
+                            MetaTensor* dz);
+
 void LerpInferMeta(const MetaTensor& x,
                    const MetaTensor& y,
                    const MetaTensor& weight,
diff --git a/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc
index cee48ed96db1c60fb77dc7c870cb256b7ce0cb6e..7c1b33f047b61ac9f6d30de4f331ff92ffa78905 100644
--- a/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc
@@ -32,10 +32,10 @@ namespace phi {
 template <typename T, typename Context>
 void LayerNormGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
-                         const DenseTensor& mean,
-                         const DenseTensor& variance,
                          paddle::optional<const DenseTensor&> scale_opt,
                          paddle::optional<const DenseTensor&> bias_opt,
+                         const DenseTensor& mean,
+                         const DenseTensor& variance,
                          const DenseTensor& out_grad,
                          float epsilon,
                          int begin_norm_axis,
diff --git a/paddle/phi/kernels/gpu/layer_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/layer_norm_grad_kernel.cu
index c3f7a5261712a1d33bb4ad47dd080a489b303717..146d307a5938006ff3c62ca68fbc03e7587d8aae 100644
--- a/paddle/phi/kernels/gpu/layer_norm_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/layer_norm_grad_kernel.cu
@@ -24,10 +24,10 @@ namespace phi {
 template <typename T, typename Context>
 void LayerNormGradKernel(const Context &dev_ctx,
                          const DenseTensor &x,
-                         const DenseTensor &mean,
-                         const DenseTensor &variance,
                          paddle::optional<const DenseTensor &> scale_opt,
                          paddle::optional<const DenseTensor &> bias_opt,
+                         const DenseTensor &mean,
+                         const DenseTensor &variance,
                          const DenseTensor &out_grad,
                          float epsilon,
                          int begin_norm_axis,
diff --git a/paddle/phi/kernels/layer_norm_grad_kernel.h b/paddle/phi/kernels/layer_norm_grad_kernel.h
index c32be63db4178f92d9564f357c30bb28fb415516..65f19a11b94d608bb0fb0da300653df891b38d7e 100644
--- a/paddle/phi/kernels/layer_norm_grad_kernel.h
+++ b/paddle/phi/kernels/layer_norm_grad_kernel.h
@@ -21,10 +21,10 @@ namespace phi {
 template <typename T, typename Context>
 void LayerNormGradKernel(const Context& ctx,
                          const DenseTensor& x,
-                         const DenseTensor& mean,
-                         const DenseTensor& variance,
                          paddle::optional<const DenseTensor&> scale,
                          paddle::optional<const DenseTensor&> bias,
+                         const DenseTensor& mean,
+                         const DenseTensor& variance,
                          const DenseTensor& out_grad,
                          float epsilon,
                          int begin_norm_axis,
diff --git a/paddle/phi/ops/compat/layer_norm_sig.cc b/paddle/phi/ops/compat/layer_norm_sig.cc
index 17a81e9ec012f2c116762ff2d653bb96f0e1c4f4..4151b9e94fbdceb4e0ec375061d5efb6bb3ac4d4 100644
--- a/paddle/phi/ops/compat/layer_norm_sig.cc
+++ b/paddle/phi/ops/compat/layer_norm_sig.cc
@@ -27,7 +27,7 @@ KernelSignature LayerNormGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature(
       "layer_norm_grad",
-      {"X", "Mean", "Variance", "Scale", "Bias", GradVarName("Y")},
+      {"X", "Scale", "Bias", "Mean", "Variance", GradVarName("Y")},
       {"epsilon", "begin_norm_axis", "is_test"},
       {GradVarName("X"), GradVarName("Scale"), GradVarName("Bias")});
 }
diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py
index 89fcbe1a5d18da99ce204380eafe315f1f6899ea..a3310f1a46ce44d39b10ceed5a8827eea2fdb287 100644
--- a/python/paddle/fluid/dygraph/nn.py
+++ b/python/paddle/fluid/dygraph/nn.py
@@ -1827,11 +1827,18 @@ class LayerNorm(layers.Layer):
                     1:] + ', but got input shape ' + str(input_shape))
 
         if _non_static_mode():
-            pre_act, _, _ = _C_ops.layer_norm(
-                input, self.weight, self.bias, 'epsilon', self._epsilon,
-                'begin_norm_axis', self._begin_norm_axis)
-            return dygraph_utils._append_activation_in_dygraph(
-                pre_act, act=self._act)
+            if in_dygraph_mode():
+                pre_act, _, _, = _C_ops.final_state_layer_norm(
+                    input, self.weight, self.bias, self._epsilon,
+                    self._begin_norm_axis, False)
+                return dygraph_utils._append_activation_in_dygraph(
+                    pre_act, act=self._act)
+            else:
+                pre_act, _, _ = _C_ops.layer_norm(
+                    input, self.weight, self.bias, 'epsilon', self._epsilon,
+                    'begin_norm_axis', self._begin_norm_axis)
+                return dygraph_utils._append_activation_in_dygraph(
+                    pre_act, act=self._act)
 
         check_variable_and_dtype(input, 'input', ['float32', 'float64'],
                                  'LayerNorm')
diff --git a/python/paddle/fluid/tests/unittests/test_layer_norm_op_v2.py b/python/paddle/fluid/tests/unittests/test_layer_norm_op_v2.py
index 987c3da4dd7be887c00007fb25d88acc3ae69762..85c6694324d25db9fe3be359bdeb092f073caf25 100644
--- a/python/paddle/fluid/tests/unittests/test_layer_norm_op_v2.py
+++ b/python/paddle/fluid/tests/unittests/test_layer_norm_op_v2.py
@@ -19,7 +19,7 @@ import paddle.fluid.core as core
 from paddle.fluid.op import Operator
 import paddle.fluid as fluid
 from op_test import OpTest, _set_use_system_allocator
-from paddle.fluid.framework import grad_var_name
+from paddle.fluid.framework import grad_var_name, _test_eager_guard
 import paddle.fluid as fluid
 from paddle.fluid import Program, program_guard
 import paddle
@@ -36,13 +36,13 @@ class TestDygraphLayerNormv2(unittest.TestCase):
         def compute_v1(x):
             with fluid.dygraph.guard(p):
                 ln = fluid.dygraph.LayerNorm(shape[1:])
-                y = ln(fluid.dygraph.to_variable(x))
+                y = ln(paddle.to_tensor(x))
                 return y.numpy()
 
         def compute_v2(x):
             with fluid.dygraph.guard(p):
                 ln = paddle.nn.LayerNorm(shape[1:])
-                y = ln(fluid.dygraph.to_variable(x))
+                y = ln(paddle.to_tensor(x))
                 return y.numpy()
 
         x = np.random.randn(*shape).astype("float32")
@@ -50,6 +50,38 @@
         y2 = compute_v2(x)
         self.assertTrue(np.allclose(y1, y2))
 
+    def test_eager(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda() and core.op_support_gpu("layer_norm"):
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            shape = [4, 10, 4, 4]
+
+            def compute_v1(x):
+                with fluid.dygraph.guard(p):
+                    ln = fluid.dygraph.LayerNorm(shape[1:])
+                    x1 = paddle.to_tensor(x)
+                    x1.stop_gradient = False
+                    y = ln(x1)
+                    y.backward()
+                    return y.numpy(), x1.gradient()
+
+            def compute_v2(x):
+                with fluid.dygraph.guard(p):
+                    with _test_eager_guard():
+                        ln = paddle.nn.LayerNorm(shape[1:])
+                        x1 = paddle.to_tensor(x)
+                        x1.stop_gradient = False
+                        y = ln(x1)
+                        y.backward()
+                        return y.numpy(), x1.gradient()
+
+            x = np.random.randn(*shape).astype("float32")
+            y1, g1 = compute_v1(x)
+            y2, g2 = compute_v2(x)
+            self.assertTrue(np.allclose(y1, y2))
+            self.assertTrue(np.allclose(g1, g2))
+
     def test_static(self):
         paddle.enable_static()
         places = [fluid.CPUPlace()]
@@ -94,30 +126,30 @@
         def compute_v0(x):
             with fluid.dygraph.guard(p):
                 ln = fluid.dygraph.LayerNorm(shape[1:])
-                y = ln(fluid.dygraph.to_variable(x))
+                y = ln(paddle.to_tensor(x))
                 return y.numpy()
 
         def compute_v1(x):
             with fluid.dygraph.guard(p):
-                x = fluid.dygraph.to_variable(x)
+                x = paddle.to_tensor(x)
                 y = paddle.nn.functional.layer_norm(x, shape[1:])
                 return y.numpy()
 
         def compute_v2(x):
             with fluid.dygraph.guard(p):
-                x = fluid.dygraph.to_variable(x)
+                x = paddle.to_tensor(x)
                 y = paddle.nn.functional.layer_norm(x, tuple(shape[1:]))
                 return y.numpy()
 
         def compute_v3(x):
             with fluid.dygraph.guard(p):
                 ln = fluid.dygraph.LayerNorm(shape[-1])
-                y = ln(fluid.dygraph.to_variable(x))
+                y = ln(paddle.to_tensor(x))
                 return y.numpy()
 
         def compute_v4(x):
             with fluid.dygraph.guard(p):
-                x = fluid.dygraph.to_variable(x)
+                x = paddle.to_tensor(x)
                 y = paddle.nn.functional.layer_norm(x, shape[-1])
                 return y.numpy()
 
@@ -139,4 +171,5 @@
 
 
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py
index 1a5fc109805e05ba71aca967b599b680fa302c9c..e719099b4b39da51406ba9a5305d2f8f522404d7 100644
--- a/python/paddle/nn/functional/norm.py
+++ b/python/paddle/nn/functional/norm.py
@@ -318,7 +318,13 @@
             str_normalized_shape[
                 1:] + ', but got input shape ' + str(input_shape))
 
-    if in_dynamic_mode():
+    if in_dygraph_mode():
+        pre_act, _, _, = _C_ops.final_state_layer_norm(x, weight, bias, epsilon,
+                                                       begin_norm_axis, False)
+
+        return dygraph_utils._append_activation_in_dygraph(pre_act, act=None)
+
+    if _in_legacy_dygraph():
         pre_act, _, _ = _C_ops.layer_norm(x, weight, bias, 'epsilon', epsilon,
                                           'begin_norm_axis', begin_norm_axis)
         return dygraph_utils._append_activation_in_dygraph(pre_act, act=None)
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml
index b20259d3ebd2516bf0e63b7d3cab81d679938baa..e3d8e8f5f47a50d9583f54c51a40f7b21b57b4f7 100644
--- a/python/paddle/utils/code_gen/api.yaml
+++ b/python/paddle/utils/code_gen/api.yaml
@@ -1002,6 +1002,16 @@
   optional : prior_dist
   backward : label_smooth_grad
 
+- api : layer_norm
+  args : (Tensor x, Tensor scale, Tensor bias, float epsilon, int begin_norm_axis, bool is_test)
+  output : Tensor(out), Tensor(mean), Tensor(variance)
+  infer_meta :
+    func : LayerNormInferMeta
+  kernel :
+    func : layer_norm
+  backward : layer_norm_grad
+  optional : scale, bias
+
 # leaky_relu
 - api : leaky_relu
   args : (Tensor x, float alpha)
diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml
index 78f4ac7c985fb67534929f5006840260f42f6e9d..f8366744bdbe6d78e3a0a08c22599643e6ce53d7 100644
--- a/python/paddle/utils/code_gen/backward.yaml
+++ b/python/paddle/utils/code_gen/backward.yaml
@@ -723,6 +723,18 @@
     func : label_smooth_grad
   optional : prior_dist
 
+- backward_api : layer_norm_grad
+  forward : layer_norm (Tensor x, Tensor scale, Tensor bias, float epsilon, int begin_norm_axis, bool is_test) -> Tensor(out), Tensor(mean), Tensor(variance)
+  args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, Tensor out_grad, float epsilon, int begin_norm_axis, bool is_test)
+  output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad)
+  infer_meta :
+    func : LayerNormGradInferMeta
+    param : [x, scale, bias]
+  kernel :
+    func : layer_norm_grad
+    data_type : out_grad
+  optional : scale, bias
+
 - backward_api : leaky_relu_grad
   forward : leaky_relu (Tensor x, float alpha) -> Tensor(out)
   args : (Tensor x, Tensor out_grad, float alpha)
diff --git a/tools/infrt/skipped_phi_api.json b/tools/infrt/skipped_phi_api.json
index 8e2dd0f65d7d5d2b0836491e9d9d5c4b52c7ed8b..b352240c6dcc5cff69ae19bfd0d9d6a6795dfa2a 100644
--- a/tools/infrt/skipped_phi_api.json
+++ b/tools/infrt/skipped_phi_api.json
@@ -1,4 +1,4 @@
 {
-"phi_apis":["conj", "dropout", "expand_as", "nll_loss", "psroi_pool", "roi_align", "roi_pool", "label_smooth"],
+"phi_apis":["conj", "dropout", "expand_as", "nll_loss", "psroi_pool", "roi_align", "roi_pool", "label_smooth", "layer_norm"],
 "phi_kernels":["equal_all"]
 }