# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from .. import core
from ..layers import utils
from ..layers import nn as F
from .. import dygraph_utils
from . import layers
from ..framework import (
    Variable,
    _non_static_mode,
    OpProtoHolder,
    Parameter,
    _dygraph_tracer,
    _varbase_creator,
    default_main_program,
    _global_flags,
    in_dygraph_mode,
    _in_legacy_dygraph,
)
from ..data_feeder import (
    convert_dtype,
    check_variable_and_dtype,
    check_type,
    check_dtype,
)
from ..param_attr import ParamAttr
from ..initializer import Normal, Constant, NumpyArrayInitializer
from .. import unique_name
from .layer_object_helper import LayerObjectHelper
from ..data_feeder import check_variable_and_dtype, check_type
import numpy as np
import numbers
import logging
import os
import paddle.utils.deprecated as deprecated
from paddle import _C_ops, _legacy_C_ops

__all__ = [
    'BatchNorm',
    'Embedding',
]


def _has_zero_learning_rate(attr):
    """Return True when ``attr`` exposes a ``learning_rate`` equal to 0.

    ``attr`` may be ``None`` or any object without a ``learning_rate``
    attribute (for example a plain string name); in those cases the
    answer is ``False`` instead of an ``AttributeError``.
    """
    return getattr(attr, 'learning_rate', None) == 0.0


class BatchNorm(layers.Layer):
    r"""
    This interface is used to construct a callable object of the ``BatchNorm`` class.
    For more details, refer to code examples.
    It implements the function of the Batch Normalization Layer and can be used
    as a normalizer function for conv2d and fully connected operations.
    The data is normalized by the mean and variance of the channel based on the current batch data.
    Refer to `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/pdf/1502.03167.pdf>`_
    for more details.

    When use_global_stats = False, the :math:`\mu_{\beta}`
    and :math:`\sigma_{\beta}^{2}` are the statistics of one mini-batch.
    Calculated as follows:

    ..  math::

        \mu_{\beta} &\gets \frac{1}{m} \sum_{i=1}^{m} x_i \qquad &
        //\ mini-batch\ mean \\
        \sigma_{\beta}^{2} &\gets \frac{1}{m} \sum_{i=1}^{m}(x_i - \mu_{\beta})^2 \qquad &
        //\ mini-batch\ variance \\

    - :math:`x` : mini-batch data
    - :math:`m` : the size of the mini-batch data

    When use_global_stats = True, the :math:`\mu_{\beta}`
    and :math:`\sigma_{\beta}^{2}` are not the statistics of one mini-batch.
    They are global or running statistics (moving_mean and moving_variance).
    They are usually obtained from a pre-trained model. Calculated as follows:

    .. math::
        moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global mean \\
        moving\_variance = moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \quad &// global variance \\

    The normalization function formula is as follows:

    ..  math::

        \hat{x_i} &\gets \frac{x_i - \mu_\beta} {\sqrt{\
        \sigma_{\beta}^{2} + \epsilon}} \qquad &//\ normalize \\
        y_i &\gets \gamma \hat{x_i} + \beta \qquad &//\ scale\ and\ shift

    - :math:`\epsilon` : add a smaller value to the variance to prevent division by zero
    - :math:`\gamma` : trainable proportional parameter
    - :math:`\beta` : trainable deviation parameter

    Parameters:
        num_channels(int): Indicate the number of channels of the input ``Tensor``.
        act(str, optional): Activation to be applied to the output of batch normalization. Default: None.
        is_test (bool, optional): A flag indicating whether it is in test phase or not.
             This flag only has effect on static graph mode. For dygraph mode, please use ``eval()``.
             Default: False.
        momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
        epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
        param_attr(ParamAttr, optional): The parameter attribute for Parameter `scale`
             of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm
             will create ParamAttr as param_attr. If the Initializer of the param_attr
             is not set, the parameter is initialized with the constant 1.0. Default: None.
        bias_attr(ParamAttr, optional): The parameter attribute for the bias of batch_norm.
             If it is set to None or one attribute of ParamAttr, batch_norm
             will create ParamAttr as bias_attr. If the Initializer of the bias_attr
             is not set, the bias is initialized zero. It must not be ``False``. Default: None.
        dtype(str, optional): Indicate the data type of the input ``Tensor``,
             which can be float32 or float64. When "float16" is given, the layer's
             parameters are created as float32. Default: float32.
        data_layout(str, optional): Specify the input data format, the data format can be "NCHW" or "NHWC". Default: NCHW.
        in_place(bool, optional): Make the input and output of batch norm reuse memory. Default: False.
        moving_mean_name(str, optional): The name of moving_mean which store the global Mean. Default: None.
        moving_variance_name(str, optional): The name of the moving_variance which store the global Variance. Default: None.
        do_model_average_for_mean_and_var(bool, optional): Whether parameter mean and variance should do model
            average when model average is enabled. Default: True.
        use_global_stats(bool, optional): Whether to use global mean and
            variance. In inference or test mode, set use_global_stats to true
            or is_test to true, and the behavior is equivalent.
            In train mode, when setting use_global_stats True, the global mean
            and variance are also used during train period. When it is True, a
            scale/bias parameter whose own attribute has ``learning_rate == 0``
            additionally gets ``stop_gradient = True``. Default: False.
        trainable_statistics(bool, optional): Whether to calculate mean and var in eval mode. In eval mode, when
            setting trainable_statistics True, mean and variance will be calculated by current batch statistics.
            Default: False.

    Returns:
        None

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          from paddle.fluid.dygraph.base import to_variable
          import numpy as np

          x = np.random.random(size=(3, 10, 3, 7)).astype('float32')
          with fluid.dygraph.guard():
              x = to_variable(x)
              batch_norm = fluid.BatchNorm(10)
              hidden1 = batch_norm(x)
    """

    def __init__(
        self,
        num_channels,
        act=None,
        is_test=False,
        momentum=0.9,
        epsilon=1e-05,
        param_attr=None,
        bias_attr=None,
        dtype='float32',
        data_layout='NCHW',
        in_place=False,
        moving_mean_name=None,
        moving_variance_name=None,
        do_model_average_for_mean_and_var=True,
        use_global_stats=False,
        trainable_statistics=False,
    ):
        super().__init__()
        self._param_attr = param_attr
        self._bias_attr = bias_attr
        self._act = act
        self._use_mkldnn = _global_flags()["FLAGS_use_mkldnn"]

        assert (
            bias_attr is not False
        ), "bias_attr should not be False in batch_norm."

        # Parameters and running statistics are kept in float32 when the
        # requested dtype is float16.
        if dtype == "float16":
            self._dtype = "float32"
        else:
            self._dtype = dtype

        param_shape = [num_channels]

        # create parameter
        self.weight = self.create_parameter(
            attr=self._param_attr,
            shape=param_shape,
            dtype=self._dtype,
            default_initializer=Constant(1.0),
        )
        # The attribute may be None or a non-ParamAttr value, so the
        # learning-rate lookup must not assume ``.learning_rate`` exists.
        self.weight.stop_gradient = bool(
            use_global_stats and _has_zero_learning_rate(self._param_attr)
        )

        self.bias = self.create_parameter(
            attr=self._bias_attr,
            shape=param_shape,
            dtype=self._dtype,
            is_bias=True,
        )
        # The bias is frozen according to its own attribute, not the
        # scale's attribute.
        self.bias.stop_gradient = bool(
            use_global_stats and _has_zero_learning_rate(self._bias_attr)
        )

        self._mean = self.create_parameter(
            attr=ParamAttr(
                name=moving_mean_name,
                initializer=Constant(0.0),
                trainable=False,
                do_model_average=do_model_average_for_mean_and_var,
            ),
            shape=param_shape,
            dtype=self._dtype,
        )
        self._mean.stop_gradient = True

        self._variance = self.create_parameter(
            attr=ParamAttr(
                name=moving_variance_name,
                initializer=Constant(1.0),
                trainable=False,
                do_model_average=do_model_average_for_mean_and_var,
            ),
            shape=param_shape,
            dtype=self._dtype,
        )
        self._variance.stop_gradient = True

        self._in_place = in_place
        self._data_layout = data_layout
        self._momentum = momentum
        self._epsilon = epsilon
        self._is_test = is_test
        self._fuse_with_relu = False
        self._use_global_stats = use_global_stats
        self._trainable_statistics = trainable_statistics

    def forward(self, input):
        """Apply batch normalization (and the optional activation) to ``input``.

        In dynamic-graph modes the op is executed immediately and
        ``self.training`` decides the test flag; in static-graph mode a
        ``batch_norm`` op is appended to the program and ``is_test`` from
        the constructor is used instead.
        """
        # create output
        # mean and mean_out share the same memory
        mean_out = self._mean
        # variance and variance out share the same memory
        variance_out = self._variance

        if _non_static_mode():
            if in_dygraph_mode():
                # Only the normalized output is needed; the remaining
                # outputs of the op are discarded.
                batch_norm_out, _, _, _, _, _ = _C_ops.batch_norm(
                    input,
                    self._mean,
                    self._variance,
                    self.weight,
                    self.bias,
                    not self.training,
                    self._momentum,
                    self._epsilon,
                    self._data_layout,
                    self._use_global_stats,
                    self._trainable_statistics,
                )
                return dygraph_utils._append_activation_in_dygraph(
                    batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn
                )

            elif _in_legacy_dygraph():
                attrs = (
                    "momentum",
                    self._momentum,
                    "epsilon",
                    self._epsilon,
                    "is_test",
                    not self.training,
                    "data_layout",
                    self._data_layout,
                    "use_mkldnn",
                    self._use_mkldnn,
                    "fuse_with_relu",
                    self._fuse_with_relu,
                    "use_global_stats",
                    self._use_global_stats,
                    'trainable_statistics',
                    self._trainable_statistics,
                )
                batch_norm_out, _, _, _, _, _ = _legacy_C_ops.batch_norm(
                    input,
                    self.weight,
                    self.bias,
                    self._mean,
                    self._variance,
                    None,
                    mean_out,
                    variance_out,
                    *attrs
                )

                return dygraph_utils._append_activation_in_dygraph(
                    batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn
                )

        check_variable_and_dtype(
            input, 'input', ['float16', 'float32', 'float64'], 'BatchNorm'
        )

        attrs = {
            "momentum": self._momentum,
            "epsilon": self._epsilon,
            "is_test": self._is_test,
            "data_layout": self._data_layout,
            "use_mkldnn": False,
            "fuse_with_relu": self._fuse_with_relu,
            "use_global_stats": self._use_global_stats,
            "trainable_statistics": self._trainable_statistics,
        }

        inputs = {
            "X": [input],
            "Scale": [self.weight],
            "Bias": [self.bias],
            "Mean": [self._mean],
            "Variance": [self._variance],
        }

        saved_mean = self._helper.create_variable_for_type_inference(
            dtype=self._dtype, stop_gradient=True
        )
        saved_variance = self._helper.create_variable_for_type_inference(
            dtype=self._dtype, stop_gradient=True
        )
        reserve_space = self._helper.create_variable_for_type_inference(
            dtype=self._helper.input_dtype(input), stop_gradient=True
        )

        # With in_place=True the op writes its result into the input variable.
        batch_norm_out = (
            input
            if self._in_place
            else self._helper.create_variable_for_type_inference(self._dtype)
        )

        outputs = {
            "Y": [batch_norm_out],
            "MeanOut": [mean_out],
            "VarianceOut": [variance_out],
            "SavedMean": [saved_mean],
            "SavedVariance": [saved_variance],
            "ReserveSpace": [reserve_space],
        }

        self._helper.append_op(
            type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs
        )

        # Currently, we don't support inplace in dygraph mode
        return self._helper.append_activation(batch_norm_out, self._act)


class Embedding(layers.Layer):
    r"""
    :alias_main: paddle.nn.Embedding
    :alias: paddle.nn.Embedding,paddle.nn.layer.Embedding,paddle.nn.layer.common.Embedding
    :old_api: paddle.fluid.dygraph.Embedding

    **Embedding Layer**

    This interface is used to construct a callable object of the ``Embedding`` class.
    For specific usage, refer to code examples. It implements the function of the Embedding Layer.
    This layer is used to lookup embeddings vector of ids provided by :attr:`input` .
    It automatically constructs a 2D embedding matrix based on the
    input :attr:`size` (vocab_size, emb_size) and :attr:`dtype` .

    The shape of output Tensor is generated by appending an emb_size dimension to the
    last dimension of the input Tensor shape.

    **Note:** The id in :attr:`input` must satisfy :math:`0 <= id < size[0]` ,
    otherwise the program will throw an exception and exit.

    .. code-block:: text

        Case 1:

        input is a Tensor. padding_idx = -1
            input.data = [[1, 3], [2, 4], [4, 127]]
            input.shape = [3, 2]
        Given size = [128, 16]
        output is a Tensor:
            out.shape = [3, 2, 16]
            out.data = [[[0.129435295, 0.244512452, ..., 0.436322452],
                        [0.345421456, 0.524563927, ..., 0.144534654]],

                        [[0.345249859, 0.124939536, ..., 0.194353745],
                        [0.945345345, 0.435394634, ..., 0.435345365]],

                        [[0.945345345, 0.435394634, ..., 0.435345365],
                        [0.0,         0.0,         ..., 0.0        ]]]  # padding data
        The input padding_idx is less than 0, it is automatically converted to padding_idx = -1 + 128 = 127
        It will pad all-zero data when ids is 127.

    Parameters:
        size(tuple|list): The shape of the look up table parameter. It should have two elements which indicate the size
            of the dictionary of embeddings and the size of each embedding vector respectively.
        is_sparse(bool): The flag indicating whether to use sparse update. This parameter only
            affects the performance of the backwards gradient update. It is recommended to set
            True because sparse update is faster. But some optimizer does not support sparse update,
            such as :ref:`api_fluid_optimizer_AdadeltaOptimizer` , :ref:`api_fluid_optimizer_AdamaxOptimizer` ,
            :ref:`api_fluid_optimizer_DecayedAdagradOptimizer` , :ref:`api_fluid_optimizer_FtrlOptimizer` ,
            :ref:`api_fluid_optimizer_LambOptimizer` and :ref:`api_fluid_optimizer_LarsMomentumOptimizer` .
            In these case, is_sparse must be False. Default: False.
        is_distributed(bool): Whether to store the embedding matrix in a distributed manner. Only used
            in multi-machine distributed CPU training. Default: False.
        padding_idx(int|long|None): padding_idx needs to be in the interval [-vocab_size, vocab_size).
            If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted
            to :math:`vocab\_size + padding\_idx` . It will output all-zero padding data whenever lookup
            encounters :math:`padding\_idx` in id. And the padding data will not be updated while training.
            If set None, it makes no effect to output. Default: None.
        param_attr(ParamAttr): To specify the weight parameter property. Default: None, which means the
            default weight parameter property is used. See usage for details in :ref:`api_fluid_ParamAttr` . In addition,
            user-defined or pre-trained word vectors can be loaded with the :attr:`param_attr` parameter.
            The local word vector needs to be transformed into numpy format, and the shape of local word
            vector should be consistent with :attr:`size` . Then :ref:`api_fluid_initializer_NumpyArrayInitializer`
            is used to load custom or pre-trained word vectors. See code example 2 for details.
        dtype(np.dtype|core.VarDesc.VarType|str): It refers to the data type of output Tensor.
            It must be "float32" or "float64". Default: "float32".

    Attribute:
        **weight** (Parameter): the learnable weights of this layer.

    Returns:
        Variable: Embedding Tensor or LoDTensor mapped by input. The data type is the same as :attr:`dtype` .

    Examples:

        .. code-block:: python

          import paddle.fluid as fluid
          import paddle.fluid.dygraph.base as base
          import numpy as np

          # example 1
          inp_word = np.array([[2, 3, 5], [4, 2, 1]]).astype('int64')
          inp_word.shape  # [2, 3]
          dict_size = 20
          with fluid.dygraph.guard():
              emb = fluid.dygraph.Embedding(
                  size=[dict_size, 32],
                  param_attr='emb.w',
                  is_sparse=False)
              static_rlt3 = emb(base.to_variable(inp_word))
              static_rlt3.shape  # [2, 3, 32]

          # example 2: load custom or pre-trained word vectors
          weight_data = np.random.random(size=(128, 100))  # word vectors with numpy format
          w_param_attrs = fluid.ParamAttr(
              name="emb_weight",
              learning_rate=0.5,
              initializer=fluid.initializer.NumpyArrayInitializer(weight_data),
              trainable=True)
          with fluid.dygraph.guard():
              emb = fluid.dygraph.Embedding(
                  size=[128, 100],
                  param_attr= w_param_attrs,
                  is_sparse=False)
              static_rlt3 = emb(base.to_variable(inp_word))
    """

    def __init__(
        self,
        size,
        is_sparse=False,
        is_distributed=False,
        padding_idx=None,
        param_attr=None,
        dtype='float32',
    ):
        super().__init__()
        self._size = size
        self._is_sparse = is_sparse
        self._is_distributed = is_distributed

        # Normalize padding_idx to the value passed to the op: -1 when no
        # padding is requested, otherwise a non-negative row index (a
        # negative value is counted from the end of the vocabulary).
        if padding_idx is None:
            self._padding_idx = -1
        elif padding_idx >= 0:
            self._padding_idx = padding_idx
        else:
            self._padding_idx = size[0] + padding_idx

        self._param_attr = param_attr
        self._dtype = dtype
        self._remote_prefetch = self._is_sparse and (not self._is_distributed)
        if self._remote_prefetch:
            assert self._is_sparse is True and self._is_distributed is False

        self.weight = self.create_parameter(
            attr=self._param_attr,
            shape=self._size,
            dtype=self._dtype,
            is_bias=False,
        )

    def forward(self, input):
        """Look up the rows of ``self.weight`` selected by the ids in ``input``.

        In dynamic-graph mode the ``lookup_table_v2`` op runs immediately;
        otherwise the input dtype is checked and the op is appended to the
        program.
        """
        if _non_static_mode():
            return _legacy_C_ops.lookup_table_v2(
                self.weight,
                input,
                'is_sparse',
                self._is_sparse,
                'is_distributed',
                self._is_distributed,
                'remote_prefetch',
                self._remote_prefetch,
                'padding_idx',
                self._padding_idx,
            )

        check_variable_and_dtype(
            input,
            'input',
            ['uint8', 'int8', 'int16', 'int32', 'int64'],
            'Embedding',
        )
        attrs = {
            'is_sparse': self._is_sparse,
            'is_distributed': self._is_distributed,
            'remote_prefetch': self._remote_prefetch,
            'padding_idx': self._padding_idx,
        }

        out = self._helper.create_variable_for_type_inference(self._dtype)
        self._helper.append_op(
            type='lookup_table_v2',
            inputs={'Ids': input, 'W': self.weight},
            outputs={'Out': out},
            attrs=attrs,
        )

        return out


class RowConv(layers.Layer):
    """
    ***Row-convolution operator***

    The row convolution is called lookahead convolution.  This operator was introduced in the following paper for DeepSpeech2:
    http://www.cs.cmu.edu/~dyogatam/papers/wang+etal.iclrworkshop2016.pdf

    The main motivation is that a bidirectional RNN, useful in DeepSpeech like speech models, learns representation for a sequence by performing a
    forward and a backward pass through the entire sequence. However, unlike
    unidirectional RNNs, bidirectional RNNs are challenging to deploy in an online
    and low-latency setting. The lookahead convolution incorporates information
    from future subsequences in a computationally efficient manner to improve
    unidirectional recurrent neural networks. The row convolution operator is
    different from the 1D sequence convolution, and is computed as follows:

    Given an input sequence X of length t and input dimension D, and a filter (W) of size context * D.

    More details about row_conv please refer to the design document https://github.com/PaddlePaddle/Paddle/issues/2228#issuecomment-303903645 .

    Parameters:
        name_scope(str): The name of this class.
        future_context_size (int): Future context size. Please note, the shape
            of convolution kernel is [future_context_size + 1, D].
        param_attr (ParamAttr): Attributes of parameters, including
            name, initializer etc. Default: None.
        act (str): Non-linear activation to be applied to output variable. Default: None.

    Attributes:
        weight (Parameter): the learnable weights of this layer.

    Returns:
        the output(Out) is a LodTensor, which supports variable time-length input sequences.
        The underlying tensor in this LodTensor is a matrix with shape T x N, i.e., the same shape as X.

    Note:
        The constructor asserts that the dynamic graph mode is not active.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import numpy

          with fluid.dygraph.guard():
              x = numpy.random.random((16)).astype('float32')
              rowConv = fluid.dygraph.nn.RowConv(
                    'RowConv', future_context_size=2)
              ret = rowConv(fluid.dygraph.base.to_variable(x))

    """

    def __init__(
        self, name_scope, future_context_size, param_attr=None, act=None
    ):
        assert (
            not _non_static_mode()
        ), "RowConv is not supported by dynamic graph mode yet!"
        super().__init__(name_scope)
        self._act = act
        self._param_attr = param_attr
        self._future_context_size = future_context_size

    def _build_once(self, input):
        """Create the filter parameter from the dtype and shape of ``input``.

        The filter has ``future_context_size + 1`` rows and
        ``input.shape[1]`` columns, so ``input`` must have at least two
        dimensions.
        """
        self._dtype = self._helper.input_dtype(input)
        filter_shape = [self._future_context_size + 1, input.shape[1]]
        self.weight = self.create_parameter(
            attr=self._param_attr,
            shape=filter_shape,
            dtype=self._dtype,
            is_bias=False,
        )

    def forward(self, input):
        """Append a ``row_conv`` op on ``input`` and apply the activation."""
        out = self._helper.create_variable_for_type_inference(self._dtype)
        self._helper.append_op(
            type='row_conv',
            inputs={'X': [input], 'Filter': [self.weight]},
            outputs={'Out': [out]},
        )
        return self._helper.append_activation(out, act=self._act)