Commit 6c7ba81c authored by guosheng

Add python wrapper for layer_norm

Parent 6c3b78b7
@@ -323,6 +323,12 @@ batch_norm
.. autofunction:: paddle.v2.fluid.layers.batch_norm
:noindex:
layer_norm
----------
.. autofunction:: paddle.v2.fluid.layers.layer_norm
:noindex:
beam_search_decode
------------------
......
@@ -29,7 +29,7 @@ import optimizer
import learning_rate_decay
import backward
import regularizer
- from param_attr import ParamAttr
+ from param_attr import ParamAttr, WeightNormParamAttr
from data_feeder import DataFeeder
from core import LoDTensor, CPUPlace, CUDAPlace
from distribute_transpiler import DistributeTranspiler
@@ -43,9 +43,9 @@ Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + [
'io', 'initializer', 'layers', 'nets', 'optimizer', 'learning_rate_decay',
'backward', 'regularizer', 'LoDTensor', 'CPUPlace', 'CUDAPlace', 'Tensor',
- 'ParamAttr'
- 'DataFeeder', 'clip', 'SimpleDistributeTranspiler', 'DistributeTranspiler',
- 'memory_optimize', 'profiler'
+ 'ParamAttr', 'WeightNormParamAttr', 'DataFeeder', 'clip',
+ 'SimpleDistributeTranspiler', 'DistributeTranspiler', 'memory_optimize',
+ 'profiler'
]
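Since WeightNormParamAttr is now re-exported at the package level, it can be passed wherever a ParamAttr is accepted. A minimal sketch, not part of this commit; the layer size and the dim argument are illustrative assumptions:

.. code-block:: python

    import paddle.v2.fluid as fluid

    data = fluid.layers.data(name='data', shape=[32], dtype='float32')
    # Mark the fc weight for weight normalization; dim (assumed here) selects
    # the axis the norm is taken over, None meaning over all elements.
    out = fluid.layers.fc(input=data,
                          size=64,
                          param_attr=fluid.WeightNormParamAttr(dim=None))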
......
@@ -65,6 +65,7 @@ __all__ = [
'beam_search',
'row_conv',
'multiplex',
'layer_norm',
]
@@ -641,8 +642,8 @@ def dynamic_gru(input,
Choices = ["sigmoid", "tanh", "relu", "identity"], default "tanh".
Returns:
- Variable: The hidden state of GRU. The shape is (T \\times D), and lod \
- is the same with the input.
+ Variable: The hidden state of GRU. The shape is :math:`(T \\times D)`, \
+ and lod is the same with the input.
Examples:
.. code-block:: python
@@ -990,7 +991,7 @@ def square_error_cost(input, label, **kwargs):
label(Variable): Label tensor, has target labels.
Returns:
- Variable: The tensor variable storing the element-wise squared error
+ Variable: The tensor variable storing the element-wise squared error \
difference of input and label.
Examples:
@@ -1214,7 +1215,7 @@ def conv2d(input,
act(str): Activation type. Default: None
Returns:
- Variable: The tensor variable storing the convolution and
+ Variable: The tensor variable storing the convolution and \
non-linearity activation result.
Raises:
@@ -1565,6 +1566,102 @@ def batch_norm(input,
return helper.append_activation(batch_norm_out)
def layer_norm(input,
scale=True,
shift=True,
begin_norm_axis=1,
epsilon=1e-05,
param_attr=None,
bias_attr=None,
act=None,
name=None):
"""
**Layer Normalization**
Assume feature vectors exist on dimensions
:attr:`begin_norm_axis ... rank(input)` and calculate the moment statistics
along these dimensions for each feature vector :math:`a` with size
:math:`H`, then normalize each feature vector using the corresponding
statistics. After that, apply learnable gain and bias on the normalized
tensor to scale and shift if :attr:`scale` and :attr:`shift` are set.
Refer to `Layer Normalization <https://arxiv.org/pdf/1607.06450v1.pdf>`_
The formula is as follows:
.. math::
\\mu & = \\frac{1}{H}\\sum_{i=1}^{H} a_i
\\sigma & = \\sqrt{\\frac{1}{H}\\sum_{i=1}^{H}(a_i - \\mu)^2}
h & = f(\\frac{g}{\\sigma}(a - \\mu) + b)
Args:
input(Variable): The input tensor variable.
scale(bool): Whether to learn the adaptive gain :math:`g` after
normalization.
shift(bool): Whether to learn the adaptive bias :math:`b` after
normalization.
begin_norm_axis(bool): The normalization will be performed along
dimensions from :attr:`begin_norm_axis` to :attr:`rank(input)`.
epsilon(float): The small value added to the variance to prevent
division by zero.
param_attr(ParamAttr|None): The parameter attribute for the learnable
gain :math:`g`.
bias_attr(ParamAttr|None): The parameter attribute for the learnable
bias :math:`b`.
act(str): Activation to be applied to the output of layer normalization.
Returns:
Variable: A tensor variable with the same shape as the input.
Examples:
.. code-block:: python
data = fluid.layers.data(
name='data', shape=[3, 32, 32], dtype='float32')
x = fluid.layers.layer_norm(input=data, begin_norm_axis=1)
"""
helper = LayerHelper('layer_norm', **locals())
dtype = helper.input_dtype()
# create input and parameters
inputs = {'X': input}
input_shape = input.shape
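# The learnable gain g and bias b hold one element per normalized feature,
# i.e. the product of all dimensions from begin_norm_axis onwards (H above).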
param_shape = [reduce(lambda x, y: x * y, input_shape[begin_norm_axis:])]
if scale:
scale = helper.create_parameter(
attr=helper.param_attr,
shape=param_shape,
dtype=dtype,
default_initializer=Constant(1.0))
inputs['Scale'] = scale
if shift:
assert bias_attr is not False
bias = helper.create_parameter(
attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True)
inputs['Bias'] = bias
# create output
mean_out = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
variance_out = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
layer_norm_out = helper.create_tmp_variable(dtype)
helper.append_op(
type="layer_norm",
inputs=inputs,
outputs={
"Y": layer_norm_out,
"Mean": mean_out,
"Variance": variance_out,
},
attrs={"epsilon": epsilon,
"begin_norm_axis": begin_norm_axis})
return helper.append_activation(layer_norm_out)
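For reference, the computation the new wrapper configures can be written as a small NumPy sketch (not part of this commit; layer_norm_ref, g and b are illustrative names, and epsilon is added to the variance as described in the argument list):

.. code-block:: python

    import numpy as np

    def layer_norm_ref(x, g, b, begin_norm_axis=1, epsilon=1e-05):
        # Flatten the normalized dimensions so each row is one feature
        # vector a of size H = prod(x.shape[begin_norm_axis:]).
        rows = int(np.prod(x.shape[:begin_norm_axis]))
        h = x.reshape(rows, -1)
        mu = h.mean(axis=1, keepdims=True)
        sigma = np.sqrt(((h - mu) ** 2).mean(axis=1, keepdims=True) + epsilon)
        # g and b have shape [H]; broadcasting scales and shifts per element.
        return ((h - mu) / sigma * g + b).reshape(x.shape)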
def beam_search_decode(ids, scores, name=None):
helper = LayerHelper('beam_search_decode', **locals())
sentence_ids = helper.create_tmp_variable(dtype=ids.dtype)
......
@@ -194,7 +194,7 @@ def scaled_dot_product_attention(queries,
Returns:
- Variable: A 3-D Tensor computed by multi-head scaled dot product
+ Variable: A 3-D Tensor computed by multi-head scaled dot product \
attention.
Raises:
@@ -333,6 +333,7 @@ def scaled_dot_product_attention(queries,
x=product, shape=[-1, product.shape[-1]], act="softmax"),
shape=product.shape)
if dropout_rate:
- weights = layers.dropout(x, dropout_prob=dropout_rate, is_test=False)
+ weights = layers.dropout(
+     weights, dropout_prob=dropout_rate, is_test=False)
ctx_multiheads = layers.matmul(weights, v)
return __combine_heads(ctx_multiheads)
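With this fix, a non-zero dropout_rate drops entries of the softmax-normalized attention weights before they are applied to the values. A hedged usage sketch (shapes, num_heads and dropout_rate are illustrative):

.. code-block:: python

    import paddle.v2.fluid as fluid

    queries = fluid.layers.data(name='q', shape=[8, 64], dtype='float32')
    keys = fluid.layers.data(name='k', shape=[10, 64], dtype='float32')
    values = fluid.layers.data(name='v', shape=[10, 64], dtype='float32')
    # Multi-head scaled dot-product attention with dropout on the weights.
    ctx = fluid.nets.scaled_dot_product_attention(
        queries, keys, values, num_heads=4, dropout_rate=0.1)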