Unverified commit 1f26dce6 authored by Cao Ying, committed by GitHub

Merge pull request #8302 from guoshengCS/add-python-layernorm

Add python wrapper for layer normalization.
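The new `fluid.layers.layer_norm` wrapper can be called as in the docstring example added in this diff. A minimal sketch (assuming the `paddle.v2.fluid` import layout of this branch):

.. code-block:: python

    import paddle.v2.fluid as fluid

    # Normalize each [3, 32, 32] sample over dimensions begin_norm_axis..rank(input).
    data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
    hidden = fluid.layers.layer_norm(input=data, begin_norm_axis=1)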
@@ -323,6 +323,12 @@ batch_norm
.. autofunction:: paddle.v2.fluid.layers.batch_norm
:noindex:
layer_norm
----------
.. autofunction:: paddle.v2.fluid.layers.layer_norm
:noindex:
beam_search_decode
------------------
@@ -116,8 +116,6 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
// check input
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of LayerNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Scale"),
"Input(Scale) of LayerNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Mean"),
"Input(Mean) of LayerNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Variance"),
@@ -29,7 +29,7 @@ import optimizer
import learning_rate_decay
import backward
import regularizer
-from param_attr import ParamAttr
+from param_attr import ParamAttr, WeightNormParamAttr
from data_feeder import DataFeeder
from core import LoDTensor, CPUPlace, CUDAPlace
from distribute_transpiler import DistributeTranspiler
@@ -41,11 +41,26 @@ import profiler
Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + [
-    'io', 'initializer', 'layers', 'nets', 'optimizer', 'learning_rate_decay',
-    'backward', 'regularizer', 'LoDTensor', 'CPUPlace', 'CUDAPlace', 'Tensor',
-    'ParamAttr'
-    'DataFeeder', 'clip', 'SimpleDistributeTranspiler', 'DistributeTranspiler',
-    'memory_optimize', 'profiler'
+    'io',
+    'initializer',
+    'layers',
+    'nets',
+    'optimizer',
+    'learning_rate_decay',
+    'backward',
+    'regularizer',
+    'LoDTensor',
+    'CPUPlace',
+    'CUDAPlace',
+    'Tensor',
+    'ParamAttr',
+    'WeightNormParamAttr',
+    'DataFeeder',
+    'clip',
+    'SimpleDistributeTranspiler',
+    'DistributeTranspiler',
+    'memory_optimize',
+    'profiler',
]
@@ -65,6 +65,7 @@ __all__ = [
'beam_search',
'row_conv',
'multiplex',
'layer_norm',
]
@@ -641,8 +642,8 @@ def dynamic_gru(input,
Choices = ["sigmoid", "tanh", "relu", "identity"], default "tanh".
Returns:
-        Variable: The hidden state of GRU. The shape is (T \\times D), and lod \
-            is the same with the input.
+        Variable: The hidden state of GRU. The shape is :math:`(T \\times D)`, \
+            and lod is the same with the input.
Examples:
.. code-block:: python
@@ -990,7 +991,7 @@ def square_error_cost(input, label, **kwargs):
label(Variable): Label tensor, has target labels.
Returns:
-        Variable: The tensor variable storing the element-wise squared error
+        Variable: The tensor variable storing the element-wise squared error \
difference of input and label.
Examples:
@@ -1214,7 +1215,7 @@ def conv2d(input,
act(str): Activation type. Default: None
Returns:
-        Variable: The tensor variable storing the convolution and
+        Variable: The tensor variable storing the convolution and \
non-linearity activation result.
Raises:
@@ -1565,6 +1566,102 @@ def batch_norm(input,
return helper.append_activation(batch_norm_out)
def layer_norm(input,
scale=True,
shift=True,
begin_norm_axis=1,
epsilon=1e-05,
param_attr=None,
bias_attr=None,
act=None,
name=None):
"""
**Layer Normalization**
Assume feature vectors exist on dimensions
:attr:`begin_norm_axis ... rank(input)` and calculate the moment statistics
along these dimensions for each feature vector :math:`a` with size
:math:`H`, then normalize each feature vector using the corresponding
statistics. After that, apply learnable gain and bias on the normalized
tensor to scale and shift if :attr:`scale` and :attr:`shift` are set.
Refer to `Layer Normalization <https://arxiv.org/pdf/1607.06450v1.pdf>`_
The formula is as follows:
.. math::
\\mu & = \\frac{1}{H}\\sum_{i=1}^{H} a_i
\\sigma & = \\sqrt{\\frac{1}{H}\\sum_{i=1}^{H}(a_i - \\mu)^2}
h & = f(\\frac{g}{\\sigma}(a - \\mu) + b)
Args:
input(Variable): The input tensor variable.
scale(bool): Whether to learn the adaptive gain :math:`g` after
normalization.
shift(bool): Whether to learn the adaptive bias :math:`b` after
normalization.
begin_norm_axis(int): The normalization will be performed along
dimensions from :attr:`begin_norm_axis` to :attr:`rank(input)`.
epsilon(float): The small value added to the variance to prevent
division by zero.
param_attr(ParamAttr|None): The parameter attribute for the learnable
gain :math:`g`.
bias_attr(ParamAttr|None): The parameter attribute for the learnable
bias :math:`b`.
act(str): Activation to be applied to the output of layer normalization.
Returns:
Variable: A tensor variable with the same shape as the input.
Examples:
.. code-block:: python
data = fluid.layers.data(
name='data', shape=[3, 32, 32], dtype='float32')
x = fluid.layers.layer_norm(input=data, begin_norm_axis=1)
"""
helper = LayerHelper('layer_norm', **locals())
dtype = helper.input_dtype()
# create input and parameters
inputs = {'X': input}
input_shape = input.shape
param_shape = [reduce(lambda x, y: x * y, input_shape[begin_norm_axis:])]
if scale:
scale = helper.create_parameter(
attr=helper.param_attr,
shape=param_shape,
dtype=dtype,
default_initializer=Constant(1.0))
inputs['Scale'] = scale
if shift:
assert bias_attr is not False
bias = helper.create_parameter(
attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True)
inputs['Bias'] = bias
# create output
mean_out = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
variance_out = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
layer_norm_out = helper.create_tmp_variable(dtype)
helper.append_op(
type="layer_norm",
inputs=inputs,
outputs={
"Y": layer_norm_out,
"Mean": mean_out,
"Variance": variance_out,
},
attrs={"epsilon": epsilon,
"begin_norm_axis": begin_norm_axis})
return helper.append_activation(layer_norm_out)
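For reference, a NumPy sketch of the statistics the wrapper requests from the op, flattening the dimensions from `begin_norm_axis` onward into one axis of size H exactly as `param_shape` does above (a hypothetical helper, not part of this commit; following the docstring, `epsilon` is assumed to be added to the variance):

.. code-block:: python

    import numpy as np

    def layer_norm_ref(x, g, b, begin_norm_axis=1, epsilon=1e-05):
        # Flatten into (N, H) with H = prod(x.shape[begin_norm_axis:]),
        # matching param_shape in the wrapper above.
        shape = x.shape
        n = int(np.prod(shape[:begin_norm_axis]))
        h = int(np.prod(shape[begin_norm_axis:]))
        a = x.reshape(n, h)
        mu = a.mean(axis=1, keepdims=True)                 # per-vector mean
        var = ((a - mu) ** 2).mean(axis=1, keepdims=True)  # per-vector variance
        normed = (a - mu) / np.sqrt(var + epsilon)
        return (g * normed + b).reshape(shape)             # gain g, bias b of size H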
def beam_search_decode(ids, scores, name=None):
helper = LayerHelper('beam_search_decode', **locals())
sentence_ids = helper.create_tmp_variable(dtype=ids.dtype)
@@ -194,7 +194,7 @@ def scaled_dot_product_attention(queries,
Returns:
-        Variable: A 3-D Tensor computed by multi-head scaled dot product
+        Variable: A 3-D Tensor computed by multi-head scaled dot product \
attention.
Raises:
@@ -333,6 +333,7 @@ def scaled_dot_product_attention(queries,
x=product, shape=[-1, product.shape[-1]], act="softmax"),
shape=product.shape)
if dropout_rate:
-        weights = layers.dropout(x, dropout_prob=dropout_rate, is_test=False)
+        weights = layers.dropout(
+            weights, dropout_prob=dropout_rate, is_test=False)
ctx_multiheads = layers.matmul(weights, v)
return __combine_heads(ctx_multiheads)