diff --git a/doc/api/v2/fluid/layers.rst b/doc/api/v2/fluid/layers.rst
index e24613b94b422b7cdf9c6383c359fa92a4faf6ff..58c493fd7412cf9dbe507c9622d67dae33a5fb25 100644
--- a/doc/api/v2/fluid/layers.rst
+++ b/doc/api/v2/fluid/layers.rst
@@ -323,6 +323,12 @@ batch_norm
 .. autofunction:: paddle.v2.fluid.layers.batch_norm
     :noindex:
 
+layer_norm
+----------
+
+.. autofunction:: paddle.v2.fluid.layers.layer_norm
+    :noindex:
+
 beam_search_decode
 ------------------
 
diff --git a/paddle/operators/layer_norm_op.cc b/paddle/operators/layer_norm_op.cc
index 76d5d571c31c0cdec207cd171291da1f58d29b61..d9b774272cb7c9d87140bf30d2eabb44f49b2b7c 100644
--- a/paddle/operators/layer_norm_op.cc
+++ b/paddle/operators/layer_norm_op.cc
@@ -116,8 +116,6 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
     // check input
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of LayerNormOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Scale"),
-                   "Input(Scale) of LayerNormOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Mean"),
                    "Input(Mean) of LayerNormOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Variance"),
diff --git a/python/paddle/v2/fluid/__init__.py b/python/paddle/v2/fluid/__init__.py
index 3ee58393c72c0b6f9bec96be51ad3946752a35dd..73acbf3e009965f9eaaade77d2fe4cf4f99d4379 100644
--- a/python/paddle/v2/fluid/__init__.py
+++ b/python/paddle/v2/fluid/__init__.py
@@ -29,7 +29,7 @@ import optimizer
 import learning_rate_decay
 import backward
 import regularizer
-from param_attr import ParamAttr
+from param_attr import ParamAttr, WeightNormParamAttr
 from data_feeder import DataFeeder
 from core import LoDTensor, CPUPlace, CUDAPlace
 from distribute_transpiler import DistributeTranspiler
@@ -41,11 +41,26 @@ import profiler
 
 Tensor = LoDTensor
 
 __all__ = framework.__all__ + executor.__all__ + [
-    'io', 'initializer', 'layers', 'nets', 'optimizer', 'learning_rate_decay',
-    'backward', 'regularizer', 'LoDTensor', 'CPUPlace', 'CUDAPlace', 'Tensor',
-    'ParamAttr'
-    'DataFeeder', 'clip', 'SimpleDistributeTranspiler', 'DistributeTranspiler',
-    'memory_optimize', 'profiler'
+    'io',
+    'initializer',
+    'layers',
+    'nets',
+    'optimizer',
+    'learning_rate_decay',
+    'backward',
+    'regularizer',
+    'LoDTensor',
+    'CPUPlace',
+    'CUDAPlace',
+    'Tensor',
+    'ParamAttr',
+    'WeightNormParamAttr',
+    'DataFeeder',
+    'clip',
+    'SimpleDistributeTranspiler',
+    'DistributeTranspiler',
+    'memory_optimize',
+    'profiler',
 ]
 
diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py
index 99168ecc228045a0206aff1b7de5fc17c1438fe2..0b64e09cd359fc89ddc868ae87c1afdbfface541 100644
--- a/python/paddle/v2/fluid/layers/nn.py
+++ b/python/paddle/v2/fluid/layers/nn.py
@@ -65,6 +65,7 @@ __all__ = [
     'beam_search',
     'row_conv',
     'multiplex',
+    'layer_norm',
 ]
 
 
@@ -641,8 +642,8 @@ def dynamic_gru(input,
             Choices = ["sigmoid", "tanh", "relu", "identity"], default "tanh".
 
     Returns:
-        Variable: The hidden state of GRU. The shape is (T \\times D), and lod \
-            is the same with the input.
+        Variable: The hidden state of GRU. The shape is :math:`(T \\times D)`, \
+            and lod is the same with the input.
 
     Examples:
         .. code-block:: python
@@ -990,7 +991,7 @@ def square_error_cost(input, label, **kwargs):
         label(Variable): Label tensor, has target labels.
 
     Returns:
-        Variable: The tensor variable storing the element-wise squared error
+        Variable: The tensor variable storing the element-wise squared error \
                   difference of input and label.
 
     Examples:
@@ -1214,7 +1215,7 @@ def conv2d(input,
         act(str): Activation type. Default: None
 
     Returns:
-        Variable: The tensor variable storing the convolution and
+        Variable: The tensor variable storing the convolution and \
                   non-linearity activation result.
 
     Raises:
@@ -1565,6 +1566,102 @@ def batch_norm(input,
     return helper.append_activation(batch_norm_out)
 
 
+def layer_norm(input,
+               scale=True,
+               shift=True,
+               begin_norm_axis=1,
+               epsilon=1e-05,
+               param_attr=None,
+               bias_attr=None,
+               act=None,
+               name=None):
+    """
+    **Layer Normalization**
+
+    Assume feature vectors exist on dimensions
+    :attr:`begin_norm_axis ... rank(input)` and calculate the moment statistics
+    along these dimensions for each feature vector :math:`a` with size
+    :math:`H`, then normalize each feature vector using the corresponding
+    statistics. After that, apply learnable gain and bias on the normalized
+    tensor to scale and shift if :attr:`scale` and :attr:`shift` are set.
+
+    Refer to `Layer Normalization <https://arxiv.org/abs/1607.06450>`_
+
+    The formula is as follows:
+
+    .. math::
+
+        \\mu & = \\frac{1}{H}\\sum_{i=1}^{H} a_i
+
+        \\sigma & = \\sqrt{\\frac{1}{H}\\sum_{i=1}^{H}(a_i - \\mu)^2}
+
+        h & = f(\\frac{g}{\\sigma}(a - \\mu) + b)
+
+    Args:
+        input(Variable): The input tensor variable.
+        scale(bool): Whether to learn the adaptive gain :math:`g` after
+            normalization.
+        shift(bool): Whether to learn the adaptive bias :math:`b` after
+            normalization.
+        begin_norm_axis(int): The normalization will be performed along
+            dimensions from :attr:`begin_norm_axis` to :attr:`rank(input)`.
+        epsilon(float): The small value added to the variance to prevent
+            division by zero.
+        param_attr(ParamAttr|None): The parameter attribute for the learnable
+            gain :math:`g`.
+        bias_attr(ParamAttr|None): The parameter attribute for the learnable
+            bias :math:`b`.
+        act(str): Activation to be applied to the output of layer normalization.
+
+    Returns:
+        Variable: A tensor variable with the same shape as the input.
+
+    Examples:
+        .. code-block:: python
+
+            data = fluid.layers.data(
+                name='data', shape=[3, 32, 32], dtype='float32')
+            x = fluid.layers.layer_norm(input=data, begin_norm_axis=1)
+    """
+    helper = LayerHelper('layer_norm', **locals())
+    dtype = helper.input_dtype()
+
+    # create input and parameters
+    inputs = {'X': input}
+    input_shape = input.shape
+    param_shape = [reduce(lambda x, y: x * y, input_shape[begin_norm_axis:])]
+    if scale:
+        scale = helper.create_parameter(
+            attr=helper.param_attr,
+            shape=param_shape,
+            dtype=dtype,
+            default_initializer=Constant(1.0))
+        inputs['Scale'] = scale
+    if shift:
+        assert bias_attr is not False
+        bias = helper.create_parameter(
+            attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True)
+        inputs['Bias'] = bias
+
+    # create output
+    mean_out = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
+    variance_out = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
+    layer_norm_out = helper.create_tmp_variable(dtype)
+
+    helper.append_op(
+        type="layer_norm",
+        inputs=inputs,
+        outputs={
+            "Y": layer_norm_out,
+            "Mean": mean_out,
+            "Variance": variance_out,
+        },
+        attrs={"epsilon": epsilon,
+               "begin_norm_axis": begin_norm_axis})
+
+    return helper.append_activation(layer_norm_out)
+
+
 def beam_search_decode(ids, scores, name=None):
     helper = LayerHelper('beam_search_decode', **locals())
     sentence_ids = helper.create_tmp_variable(dtype=ids.dtype)
diff --git a/python/paddle/v2/fluid/nets.py b/python/paddle/v2/fluid/nets.py
index cb63d43709e23ae04c4d23457bbb79e6f7f0ce3c..be7878f869b509fa1117e305aee662cc0123bbcc 100644
--- a/python/paddle/v2/fluid/nets.py
+++ b/python/paddle/v2/fluid/nets.py
@@ -194,7 +194,7 @@ def scaled_dot_product_attention(queries,
 
     Returns:
 
-        Variable: A 3-D Tensor computed by multi-head scaled dot product
+        Variable: A 3-D Tensor computed by multi-head scaled dot product \
                   attention.
 
     Raises:
@@ -333,6 +333,7 @@ def scaled_dot_product_attention(queries,
             x=product, shape=[-1, product.shape[-1]], act="softmax"),
         shape=product.shape)
     if dropout_rate:
-        weights = layers.dropout(x, dropout_prob=dropout_rate, is_test=False)
+        weights = layers.dropout(
+            weights, dropout_prob=dropout_rate, is_test=False)
     ctx_multiheads = layers.matmul(weights, v)
     return __combine_heads(ctx_multiheads)
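
For reference, the transformation applied by the new layer_norm layer can be checked against a plain NumPy rendition of the docstring formula. The sketch below is illustrative only (the helper name layer_norm_ref and the example shapes are ours, not part of the Paddle API): it flattens the dimensions from begin_norm_axis onward into one feature vector of size H, normalizes each vector with its own mean and standard deviation, and applies the gain g and bias b, with the optional activation f omitted.

    import numpy as np

    def layer_norm_ref(x, gain, bias, begin_norm_axis=1, epsilon=1e-5):
        # Reference for y = g / sigma * (x - mu) + b, where mu and sigma are
        # computed over the trailing dimensions starting at begin_norm_axis.
        n = int(np.prod(x.shape[:begin_norm_axis]))  # number of feature vectors
        h = int(np.prod(x.shape[begin_norm_axis:]))  # H, size of each feature vector
        x2d = x.reshape(n, h)
        mu = x2d.mean(axis=1, keepdims=True)
        sigma = np.sqrt(x2d.var(axis=1, keepdims=True) + epsilon)
        y = (x2d - mu) / sigma * gain.reshape(1, h) + bias.reshape(1, h)
        return y.reshape(x.shape)

    # Same layout as the docstring example: normalize each sample over H = 3*32*32.
    x = np.random.rand(2, 3, 32, 32).astype('float32')
    g = np.ones(3 * 32 * 32, dtype='float32')   # gain, constant 1.0 as in the layer's Scale default
    b = np.zeros(3 * 32 * 32, dtype='float32')  # bias, set to 0 here
    out = layer_norm_ref(x, g, b, begin_norm_axis=1)
    print(out.shape)  # (2, 3, 32, 32) -- same shape as the input

The parameter shape here mirrors what the Python wrapper creates: a single vector of length prod(input_shape[begin_norm_axis:]) for both Scale and Bias.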