Unverified Commit b333d7ed authored by zqw_1997, committed by GitHub

remove paddle.fluid.layers.layer_norm (#49174)

* remove paddle.fluid.layers.layer_norm

* templatedoc import from paddle.fluid.layers.layer_function_generator

* del import of fluid.layers.layer_norm in __init__.py

* add import of ..common.layer_norm in __init__.py

* fix bug in UT

* fix doc
Parent aa40d80d
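For downstream code, this is a one-line call-site migration. A minimal before/after sketch, assuming a static-graph program (the tensor name and shape follow the docstring example below):

import paddle

paddle.enable_static()
x = paddle.static.data(name='x', shape=[8, 32, 32], dtype='float32')

# Before this commit (removed API):
#   out = paddle.fluid.layers.layer_norm(x, begin_norm_axis=1)
# After this commit:
out = paddle.static.nn.layer_norm(x, begin_norm_axis=1)
print(out.shape)  # [8, 32, 32]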
......@@ -66,7 +66,6 @@ __all__ = [
'fc',
'embedding',
'row_conv',
'layer_norm',
'spectral_norm',
'one_hot',
'autoincreased_step_counter',
......@@ -742,142 +741,6 @@ def _pull_box_sparse(
return outs
@templatedoc()
def layer_norm(
input,
scale=True,
shift=True,
begin_norm_axis=1,
epsilon=1e-05,
param_attr=None,
bias_attr=None,
act=None,
name=None,
):
r"""
:api_attr: Static Graph
**Layer Normalization Layer**
The API implements the function of the Layer Normalization Layer and can be applied to mini-batch input data.
Refer to `Layer Normalization <https://arxiv.org/pdf/1607.06450v1.pdf>`_
The formula is as follows:
.. math::
\\mu & = \\frac{1}{H}\\sum_{i=1}^{H} x_i
\\sigma & = \\sqrt{\\frac{1}{H}\sum_{i=1}^{H}{(x_i - \\mu)^2} + \\epsilon}
y & = f(\\frac{g}{\\sigma}(x - \\mu) + b)
- :math:`x`: the vector representation of the summed inputs to the neurons in that layer.
- :math:`H`: the number of hidden units in a layers
- :math:`\\epsilon`: the small value added to the variance to prevent division by zero.
- :math:`g`: the trainable scale parameter.
- :math:`b`: the trainable bias parameter.
Args:
input(Tensor): A multi-dimension ``Tensor`` , and the data type is float32 or float64.
scale(bool, optional): Whether to learn the adaptive gain :math:`g` after
normalization. Default: True.
shift(bool, optional): Whether to learn the adaptive bias :math:`b` after
normalization. Default: True.
begin_norm_axis(int, optional): The normalization will be performed along
dimensions from :attr:`begin_norm_axis` to :attr:`rank(input)`.
Default: 1.
epsilon(float, optional): The small value added to the variance to prevent
division by zero. Default: 1e-05.
param_attr(ParamAttr, optional): The parameter attribute for the learnable
gain :math:`g`. If :attr:`scale` is False, :attr:`param_attr` is
omitted. If :attr:`scale` is True and :attr:`param_attr` is None,
a default :code:`ParamAttr` would be added as scale. The
:attr:`param_attr` is initialized as 1 if it is added. Default: None.
bias_attr(ParamAttr, optional): The parameter attribute for the learnable
bias :math:`b`. If :attr:`shift` is False, :attr:`bias_attr` is
omitted. If :attr:`shift` is True and :attr:`param_attr` is None,
a default :code:`ParamAttr` would be added as bias. The
:attr:`bias_attr` is initialized as 0 if it is added. Default: None.
act(str, optional): Activation to be applied to the output of layer normalization.
Default: None.
name(str): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` .
Returns:
Tensor: ``Tensor`` indicating the normalized result, the data type is the same as ``input`` , and the return dimension is the same as ``input`` .
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
x = paddle.static.data(name='x', shape=[8, 32, 32], dtype='float32')
output = paddle.static.nn.layer_norm(input=x, begin_norm_axis=1)
print(output.shape) # [8, 32, 32]
"""
assert (
_non_static_mode() is not True
), "please use LayerNorm instead of layer_norm in dygraph mode!"
helper = LayerHelper('layer_norm', **locals())
check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'layer_norm'
)
dtype = helper.input_dtype()
# create intput and parameters
inputs = {'X': input}
input_shape = input.shape
param_shape = [reduce(lambda x, y: x * y, input_shape[begin_norm_axis:])]
if scale:
assert (
param_attr is not False
), "param_attr should not be False when using scale."
scale = helper.create_parameter(
attr=helper.param_attr,
shape=param_shape,
dtype=dtype,
default_initializer=Constant(1.0),
)
inputs['Scale'] = scale
else:
if param_attr:
warnings.warn("param_attr is only available with scale is True.")
if shift:
assert (
bias_attr is not False
), "bias_attr should not be False when using shift."
bias = helper.create_parameter(
attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True
)
inputs['Bias'] = bias
else:
if bias_attr:
warnings.warn("bias_attr is only available with shift is True.")
# create output
mean_out = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True
)
variance_out = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True
)
layer_norm_out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="layer_norm",
inputs=inputs,
outputs={
"Y": layer_norm_out,
"Mean": mean_out,
"Variance": variance_out,
},
attrs={"epsilon": epsilon, "begin_norm_axis": begin_norm_axis},
)
return helper.append_activation(layer_norm_out)
@templatedoc()
def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
r"""
......
......@@ -1248,7 +1248,7 @@ def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.0):
if cmd == "a": # add residual connection
out = out + prev_out if prev_out else out
elif cmd == "n": # add layer normalization
out = layers.layer_norm(
out = paddle.static.nn.layer_norm(
out,
begin_norm_axis=len(out.shape) - 1,
param_attr=fluid.initializer.Constant(1.0),
......
......@@ -66,7 +66,7 @@ class TestBase(IPUOpTest):
)
scale = paddle.ParamAttr(trainable=True)
bias = paddle.ParamAttr(trainable=True)
out = paddle.fluid.layers.nn.layer_norm(
out = paddle.static.nn.layer_norm(
conv1, param_attr=scale, bias_attr=bias, **self.attrs
)
loss = paddle.mean(out)
......@@ -74,7 +74,7 @@ class TestBase(IPUOpTest):
else:
scale = self.attrs['scale']
bias = self.attrs['shift']
out = paddle.fluid.layers.nn.layer_norm(
out = paddle.static.nn.layer_norm(
x, param_attr=scale, bias_attr=bias, **self.attrs
)
self.fetch_list = [out.name]
......
......@@ -21,7 +21,6 @@ from inference_pass_test import InferencePassTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.static.nn as nn
from paddle.fluid.core import AnalysisConfig, PassVersionChecker
......@@ -44,7 +43,7 @@ class TensorRTSubgraphPassFcTest(InferencePassTest):
self.fetch_list = [reshape_out]
def test_check_output(self):
if core.is_compiled_with_cuda():
if paddle.is_compiled_with_cuda():
use_gpu = True
# TRT output shape of fc is (1, 1000, 1, 1). To compare the output value only, flatten the results.
self.check_output_with_option(use_gpu, flatten=True)
......@@ -75,7 +74,7 @@ class TensorRTSubgraphPassConcatTest(InferencePassTest):
self.fetch_list = [out]
def test_check_output(self):
if core.is_compiled_with_cuda():
if paddle.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
......@@ -101,7 +100,7 @@ class TensorRTSubgraphPassSplitTest(InferencePassTest):
self.fetch_list = [out]
def test_check_output(self):
if core.is_compiled_with_cuda():
if paddle.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
......@@ -127,7 +126,7 @@ class TensorRTSubgraphPassSplitSerializeTest(InferencePassTest):
self.fetch_list = [out]
def test_check_output(self):
if core.is_compiled_with_cuda():
if paddle.is_compiled_with_cuda():
use_gpu = True
if os.path.exists(self.path + "_opt_cache"):
shutil.rmtree(self.path + "_opt_cache")
......@@ -163,7 +162,7 @@ class TensorRTSubgraphPassDynamicSplitFp16SerializeTest(InferencePassTest):
self.fetch_list = [out]
def test_check_output(self):
if core.is_compiled_with_cuda():
if paddle.is_compiled_with_cuda():
use_gpu = True
if os.path.exists(self.path + "_opt_cache"):
shutil.rmtree(self.path + "_opt_cache")
......@@ -202,7 +201,7 @@ class TensorRTSubgraphPassInstanceNormTest(InferencePassTest):
self.fetch_list = [out]
def test_check_output(self):
if core.is_compiled_with_cuda():
if paddle.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu, atol=1e-4, flatten=True)
self.assertTrue(
......@@ -231,7 +230,7 @@ class TensorRTSubgraphPassTransposeTest(InferencePassTest):
return paddle.transpose(data, [0, 3, 1, 2])
def test_check_output(self):
if core.is_compiled_with_cuda():
if paddle.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
......@@ -246,7 +245,7 @@ class TensorRTSubgraphPassLayerNormTest(InferencePassTest):
data = fluid.data(
name="data", shape=[-1, 3, 64, 64], dtype="float32"
)
out = fluid.layers.layer_norm(
out = paddle.static.nn.layer_norm(
data, begin_norm_axis=self.begin_norm_axis
)
self.feeds = {
......@@ -262,7 +261,7 @@ class TensorRTSubgraphPassLayerNormTest(InferencePassTest):
self.begin_norm_axis = 1
def test_check_output(self):
if core.is_compiled_with_cuda():
if paddle.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
......@@ -277,7 +276,7 @@ class TensorRTSubgraphPassLayerNormDynamicTest(InferencePassTest):
data = fluid.data(
name="data", shape=[-1, 3, 64, 64], dtype="float32"
)
out = fluid.layers.layer_norm(
out = paddle.static.nn.layer_norm(
data, begin_norm_axis=self.begin_norm_axis
)
self.feeds = {
......@@ -316,7 +315,7 @@ class TensorRTSubgraphPassLayerNormDynamicTest(InferencePassTest):
def test_check_output(self):
if os.path.exists(self.path + "_opt_cache"):
shutil.rmtree(self.path + "_opt_cache")
if core.is_compiled_with_cuda():
if paddle.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
......@@ -335,7 +334,7 @@ class TensorRTSubgraphPassLayerNormDynamicFP16Test(
def test_check_output(self):
if os.path.exists(self.path + "_opt_cache"):
shutil.rmtree(self.path + "_opt_cache")
if core.is_compiled_with_cuda():
if paddle.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu, atol=0.01, rtol=0.01)
self.assertTrue(
......@@ -382,7 +381,7 @@ class TensorRTSubgraphPassElementwiseTest(InferencePassTest):
return paddle.add(x=data1, y=data2)
def test_check_output(self):
if core.is_compiled_with_cuda():
if paddle.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
......@@ -445,7 +444,7 @@ class TensorRTSubgraphPassElementwiseBroadcastDynamicTest(InferencePassTest):
def test_check_output(self):
if os.path.exists(self.path + "_opt_cache"):
shutil.rmtree(self.path + "_opt_cache")
if core.is_compiled_with_cuda():
if paddle.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
......
......@@ -54,7 +54,7 @@ class EmbEltwiseLayerNormFusePassTest(PassTest):
)
add1 = paddle.add(word_emb, pos_emb)
add2 = paddle.add(add1, sent_emb)
hidden1 = fluid.layers.layer_norm(input=add2, begin_norm_axis=2)
hidden1 = paddle.static.nn.layer_norm(input=add2, begin_norm_axis=2)
id1 = fluid.layers.data(
name="id1",
......@@ -95,7 +95,9 @@ class EmbEltwiseLayerNormFusePassTest(PassTest):
add_1 = paddle.add(emb1, emb2)
add_2 = paddle.add(add_1, emb3)
add_3 = paddle.add(add_2, emb4)
hidden_1 = fluid.layers.layer_norm(input=add_3, begin_norm_axis=2)
hidden_1 = paddle.static.nn.layer_norm(
input=add_3, begin_norm_axis=2
)
self.feeds = {
"word_id": np.random.randint(
......
......@@ -32,7 +32,7 @@ class SkipLayerNormFusePassTest(PassTest):
name="y", shape=[128, 768], dtype="float32", lod_level=0
)
elementwise_out = paddle.add(x=x, y=y)
out = fluid.layers.layer_norm(input=elementwise_out)
out = paddle.static.nn.layer_norm(input=elementwise_out)
self.fetch_list = [out]
self.pass_names = "skip_layernorm_fuse_pass"
......
......@@ -17,7 +17,6 @@ import numpy as np
import paddle
from operator import mul
import paddle.fluid.core as core
import paddle.fluid as fluid
import paddle.nn.functional as F
from functools import reduce
......@@ -142,7 +141,7 @@ class TestLayerNormOp(unittest.TestCase):
},
)
# generate backward op_desc
grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(
grad_op_desc_list, op_grad_to_var = paddle.get_grad_op_desc(
layer_norm_op.desc, set(), []
)
grad_op_desc = grad_op_desc_list[0]
......@@ -154,7 +153,7 @@ class TestLayerNormOp(unittest.TestCase):
grad_op_desc.infer_shape(block.desc)
for arg in grad_op_desc.output_arg_names():
grad_var = block.desc.find_var(arg.encode("ascii"))
grad_var.set_dtype(core.VarDesc.VarType.FP32)
grad_var.set_dtype(paddle.VarDesc.VarType.FP32)
program._sync_with_cpp()
exe = fluid.Executor(place)
......@@ -252,7 +251,7 @@ class TestLayerNormAPI(unittest.TestCase):
dtype='float32',
append_batch_size=False,
)
x = fluid.layers.layer_norm(
x = paddle.static.nn.layer_norm(
x,
scale=True,
shift=True,
......@@ -261,7 +260,7 @@ class TestLayerNormAPI(unittest.TestCase):
param_attr=None,
bias_attr=None,
)
x = fluid.layers.layer_norm(
x = paddle.static.nn.layer_norm(
x,
scale=False,
shift=False,
......@@ -270,7 +269,7 @@ class TestLayerNormAPI(unittest.TestCase):
param_attr=None,
bias_attr=None,
)
x = fluid.layers.layer_norm(
x = paddle.static.nn.layer_norm(
x,
scale=False,
shift=False,
......
......@@ -62,8 +62,8 @@ class TestDygraphLoadStatic(unittest.TestCase):
emb_out_2 = fluid.embedding(emb_in, [2000, 200])
layernorm = fluid.data(name="ln", shape=[None, 10], dtype='float32')
layernorm_1 = fluid.layers.layer_norm(layernorm)
layernorm_2 = fluid.layers.layer_norm(layernorm)
layernorm_1 = paddle.static.nn.layer_norm(layernorm)
layernorm_2 = paddle.static.nn.layer_norm(layernorm)
nce_in = fluid.data(name="nce_in", shape=[None, 100], dtype='float32')
nce_label = fluid.data(
......
......@@ -341,7 +341,7 @@ class TestLayerNormAPI(unittest.TestCase):
dtype='float32',
append_batch_size=False,
)
x = fluid.layers.layer_norm(
x = paddle.static.nn.layer_norm(
x,
scale=True,
shift=True,
......@@ -350,7 +350,7 @@ class TestLayerNormAPI(unittest.TestCase):
param_attr=None,
bias_attr=None,
)
x = fluid.layers.layer_norm(
x = paddle.static.nn.layer_norm(
x,
scale=False,
shift=False,
......@@ -359,7 +359,7 @@ class TestLayerNormAPI(unittest.TestCase):
param_attr=None,
bias_attr=None,
)
x = fluid.layers.layer_norm(
x = paddle.static.nn.layer_norm(
x,
scale=False,
shift=False,
......
......@@ -232,7 +232,7 @@ def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.0):
if cmd == "a": # add residual connection
out = out + prev_out if prev_out else out
elif cmd == "n": # add layer normalization
out = layers.layer_norm(
out = paddle.static.nn.layer_norm(
out,
begin_norm_axis=len(out.shape) - 1,
param_attr=fluid.initializer.Constant(1.0),
......
......@@ -31,9 +31,9 @@ from .control_flow import (
from .common import bilinear_tensor_product # noqa: F401
from .common import py_func # noqa: F401
from ...tensor.creation import create_parameter # noqa: F401
from ...fluid.layers import layer_norm # noqa: F401
from .loss import nce # noqa: F401
from .common import prelu # noqa: F401
from .common import layer_norm # noqa: F401
from ...fluid.layers import row_conv # noqa: F401
from ...fluid.layers import spectral_norm # noqa: F401
......
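A quick way to confirm the re-export now resolves to the new location (a sketch; the expected module path is an assumption based on the import swap above, since the implementation was moved into paddle/static/nn/common.py):

import paddle.static.nn as nn

# After this change, layer_norm should be served from paddle.static.nn.common
# rather than paddle.fluid.layers (the expected value below is an assumption).
print(nn.layer_norm.__module__)  # expected: 'paddle.static.nn.common'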
......@@ -13,6 +13,8 @@
# limitations under the License.
import inspect
import warnings
from functools import reduce
import numpy as np
......@@ -3234,3 +3236,138 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
# For debug usage
py_func.registered_func = PyFuncRegistry.registered_func
py_func.registered_func_num = PyFuncRegistry.registered_func_num
@templatedoc()
def layer_norm(
input,
scale=True,
shift=True,
begin_norm_axis=1,
epsilon=1e-05,
param_attr=None,
bias_attr=None,
act=None,
name=None,
):
r"""
**Layer Normalization Layer**
This API implements the Layer Normalization layer and can be applied to mini-batch input data.
Refer to `Layer Normalization <https://arxiv.org/pdf/1607.06450v1.pdf>`_
The formula is as follows:
.. math::
\mu & = \frac{1}{H}\sum_{i=1}^{H} x_i
\sigma & = \sqrt{\frac{1}{H}\sum_{i=1}^{H}{(x_i - \mu)^2} + \epsilon}
y & = f(\frac{g}{\sigma}(x - \mu) + b)
- :math:`x`: the vector representation of the summed inputs to the neurons in that layer.
- :math:`H`: the number of hidden units in a layer
- :math:`\epsilon`: the small value added to the variance to prevent division by zero.
- :math:`g`: the trainable scale parameter.
- :math:`b`: the trainable bias parameter.
Args:
input(Tensor): A multi-dimension ``Tensor`` , and the data type is float32 or float64.
scale(bool, optional): Whether to learn the adaptive gain :math:`g` after
normalization. Default: True.
shift(bool, optional): Whether to learn the adaptive bias :math:`b` after
normalization. Default: True.
begin_norm_axis(int, optional): The normalization will be performed along
dimensions from :attr:`begin_norm_axis` to :attr:`rank(input)`.
Default: 1.
epsilon(float, optional): The small value added to the variance to prevent
division by zero. Default: 1e-05.
param_attr(ParamAttr, optional): The parameter attribute for the learnable
gain :math:`g`. If :attr:`scale` is False, :attr:`param_attr` is
omitted. If :attr:`scale` is True and :attr:`param_attr` is None,
a default :code:`ParamAttr` would be added as scale. The
:attr:`param_attr` is initialized as 1 if it is added. Default: None.
bias_attr(ParamAttr, optional): The parameter attribute for the learnable
bias :math:`b`. If :attr:`shift` is False, :attr:`bias_attr` is
omitted. If :attr:`shift` is True and :attr:`bias_attr` is None,
a default :code:`ParamAttr` would be added as bias. The
:attr:`bias_attr` is initialized as 0 if it is added. Default: None.
act(str, optional): Activation to be applied to the output of layer normalization.
Default: None.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` .
Returns:
Tensor: ``Tensor`` indicating the normalized result, the data type is the same as ``input`` , and the return dimension is the same as ``input`` .
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
x = paddle.static.data(name='x', shape=[8, 32, 32], dtype='float32')
output = paddle.static.nn.layer_norm(input=x, begin_norm_axis=1)
print(output.shape) # [8, 32, 32]
"""
assert (
_non_static_mode() is not True
), "please use LayerNorm instead of layer_norm in dygraph mode!"
helper = LayerHelper('layer_norm', **locals())
check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'layer_norm'
)
dtype = helper.input_dtype()
# create input and parameters
inputs = {'X': input}
input_shape = input.shape
param_shape = [reduce(lambda x, y: x * y, input_shape[begin_norm_axis:])]
if scale:
assert (
param_attr is not False
), "param_attr should not be False when using scale."
scale = helper.create_parameter(
attr=helper.param_attr,
shape=param_shape,
dtype=dtype,
default_initializer=Constant(1.0),
)
inputs['Scale'] = scale
else:
if param_attr:
warnings.warn("param_attr is only available with scale is True.")
if shift:
assert (
bias_attr is not False
), "bias_attr should not be False when using shift."
bias = helper.create_parameter(
attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True
)
inputs['Bias'] = bias
else:
if bias_attr:
warnings.warn("bias_attr is only available with shift is True.")
# create output
mean_out = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True
)
variance_out = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True
)
layer_norm_out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="layer_norm",
inputs=inputs,
outputs={
"Y": layer_norm_out,
"Mean": mean_out,
"Variance": variance_out,
},
attrs={"epsilon": epsilon, "begin_norm_axis": begin_norm_axis},
)
return helper.append_activation(layer_norm_out)
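To tie the moved implementation back to the docstring formula, a small self-check sketch against a NumPy reference is shown below; it assumes the default initialization (scale g = 1, bias b = 0), so the op reduces to plain normalization over the dimensions from begin_norm_axis onward:

import numpy as np
import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[4, 8, 8], dtype='float32')
    y = paddle.static.nn.layer_norm(x, begin_norm_axis=1)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)
x_np = np.random.rand(4, 8, 8).astype('float32')
(y_np,) = exe.run(main_prog, feed={'x': x_np}, fetch_list=[y])

# NumPy reference: normalize each sample over all dims from begin_norm_axis on,
# using the biased variance and the default epsilon of 1e-05.
flat = x_np.reshape(4, -1)
mu = flat.mean(axis=1, keepdims=True)
sigma = np.sqrt(flat.var(axis=1, keepdims=True) + 1e-05)
ref = ((flat - mu) / sigma).reshape(x_np.shape)
print(np.allclose(y_np, ref, atol=1e-4))  # expected: True with freshly initialized scale/bias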