norm.py 72.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16 17 18 19 20 21 22 23 24 25 26 27
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

28
# TODO: define normalization api
29

30 31
import numbers
import warnings
C
ceci3 已提交
32

33
import numpy as np
34

35 36 37
from paddle import _C_ops, _legacy_C_ops, in_dynamic_mode
from paddle.device import get_all_custom_device_type
from paddle.fluid.framework import _in_legacy_dygraph, in_dygraph_mode
38

39
from ...fluid import dygraph_utils
40
from ...fluid.data_feeder import check_variable_and_dtype
41 42 43 44 45 46 47
from ...framework import (
    ParamAttr,
    _global_flags,
    _non_static_mode,
    get_default_dtype,
    no_grad,
)
Z
zhiboniu 已提交
48
from .. import Layer
49 50
from .. import functional as F
from ..functional import batch_norm, instance_norm, layer_norm
51
from ..initializer import Constant, Normal
52

53 54
__all__ = []

C
ceci3 已提交
55

Z
zhiboniu 已提交
56
class _InstanceNormBase(Layer):
    """
    Shared base class for InstanceNorm1D, InstanceNorm2D and InstanceNorm3D.

    See InstanceNorm1D, InstanceNorm2D or InstanceNorm3D for more details.
    Subclasses must override ``_check_input_dim``.
    """

    def __init__(
        self,
        num_features,
        epsilon=1e-5,
        momentum=0.9,
        weight_attr=None,
        bias_attr=None,
        data_format="NCHW",
        name=None,
    ):
        # NOTE: momentum, data_format and name are accepted for interface
        # compatibility but are not read anywhere in this base class.
        super().__init__()

        scale_disabled = weight_attr is False
        bias_disabled = bias_attr is False

        # Scale and bias can only be switched off together.
        if scale_disabled or bias_disabled:
            assert (
                weight_attr == bias_attr
            ), "weight_attr and bias_attr must be set to False at the same time in InstanceNorm"

        self._epsilon = epsilon
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        self._num_features = num_features

        if scale_disabled or bias_disabled:
            self.scale = None
            self.bias = None
        else:
            # One learnable scale (init 1.0) and bias (init 0.0) per channel.
            self.scale = self.create_parameter(
                attr=self._weight_attr,
                shape=[num_features],
                default_initializer=Constant(1.0),
                is_bias=False,
            )
            self.bias = self.create_parameter(
                attr=self._bias_attr,
                shape=[num_features],
                default_initializer=Constant(0.0),
                is_bias=True,
            )

    def _check_input_dim(self, input):
        """Validate the rank of ``input``; must be overridden by subclasses."""
        raise NotImplementedError("InstanceNorm Base error")

    def forward(self, input):
        """Check the input rank, then apply ``instance_norm`` to ``input``."""
        self._check_input_dim(input)

        return instance_norm(
            input, weight=self.scale, bias=self.bias, eps=self._epsilon
        )

    def extra_repr(self):
        """Return the ``num_features``/``epsilon`` summary used in the repr."""
        return f'num_features={self._num_features}, epsilon={self._epsilon}'
115

116

C
cnn 已提交
117
class InstanceNorm1D(_InstanceNormBase):
    r"""
    Create a callable object of `InstanceNorm1D`. Applies Instance Normalization over a 3D input (a mini-batch of 1D inputs with additional channel dimension) as described in the paper Instance Normalization: The Missing Ingredient for Fast Stylization .

    DataLayout: NCL `[batch, in_channels, length]`

    :math:`input` is the input features over a mini-batch.

    ..  math::

        \mu_{\beta} &\gets \frac{1}{HW} \sum_{i=1}^{HW} x_i \qquad &//\
        \ mean\ of\ one\  feature\ map\ in\ mini-batch \\
        \sigma_{\beta}^{2} &\gets \frac{1}{HW} \sum_{i=1}^{HW}(x_i - \
        \mu_{\beta})^2 \qquad &//\ variance\ of\ one\ feature\ map\ in\ mini-batch \\
        \hat{x_i} &\gets \frac{x_i - \mu_\beta} {\sqrt{\
        \sigma_{\beta}^{2} + \epsilon}} \qquad &//\ normalize \\
        y_i &\gets \gamma \hat{x_i} + \beta \qquad &//\ scale\ and\ shift

    Where `H` means height of feature map, `W` means width of feature map.

    Parameters:
        num_features(int): Indicate the number of channels of the input ``Tensor``.
        epsilon(float, optional): A value added to the denominator for
            numerical stability. Default is 1e-5.
        momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
        weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale` of instance_norm.
            If it is set to None or one attribute of ParamAttr, instance_norm
            will create ParamAttr as weight_attr, the name of scale can be set in ParamAttr.
            If the Initializer of the weight_attr is not set, the parameter is initialized
            to one. If it is set to False, no scale is created. Default: None. For more information, please refer to :ref:`api_paddle_ParamAttr` .
        bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of instance_norm.
            If it is set to None or one attribute of ParamAttr, instance_norm
            will create ParamAttr as bias_attr, the name of bias can be set in ParamAttr.
            If the Initializer of the bias_attr is not set, the bias is initialized to zero.
            If it is set to False, no bias is created. Default: None. For more information, please refer to :ref:`api_paddle_ParamAttr` .
        data_format(str, optional): Specify the input data format, may be "NC", "NCL". Default "NCL".
        name(str, optional): Name for the InstanceNorm, default is None. For more information, please refer to :ref:`api_guide_Name` .


    Shape:
        - x: 2-D or 3-D tensor with shape: (batch, num_features) or (batch, num_features, length).
        - output: tensor with same shape as input x.

    Returns:
        None.


    Examples:

        .. code-block:: python

          import paddle

          x = paddle.rand((2, 2, 3))
          instance_norm = paddle.nn.InstanceNorm1D(2)
          instance_norm_out = instance_norm(x)

          print(instance_norm_out)

    """

    def __init__(
        self,
        num_features,
        epsilon=0.00001,
        momentum=0.9,
        weight_attr=None,
        bias_attr=None,
        data_format="NCL",
        name=None,
    ):
        # Forward everything to the shared base; only the defaults differ.
        super().__init__(
            num_features=num_features,
            epsilon=epsilon,
            momentum=momentum,
            weight_attr=weight_attr,
            bias_attr=bias_attr,
            data_format=data_format,
            name=name,
        )

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 2-D or 3-D."""
        rank = len(input.shape)
        if rank not in (2, 3):
            raise ValueError(f'expected 2D or 3D input (got {rank}D input)')
205 206


C
cnn 已提交
207
class InstanceNorm2D(_InstanceNormBase):
    r"""
    Create a callable object of `InstanceNorm2D`. Applies Instance Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension) as described in the paper Instance Normalization: The Missing Ingredient for Fast Stylization .

    DataLayout: NCHW `[batch, in_channels, in_height, in_width]`


    :math:`input` is the input features over a mini-batch.

    ..  math::

        \mu_{\beta} &\gets \frac{1}{HW} \sum_{i=1}^{HW} x_i \qquad &//\
        \ mean\ of\ one\  feature\ map\ in\ mini-batch \\
        \sigma_{\beta}^{2} &\gets \frac{1}{HW} \sum_{i=1}^{HW}(x_i - \
        \mu_{\beta})^2 \qquad &//\ variance\ of\ one\ feature\ map\ in\ mini-batch \\
        \hat{x_i} &\gets \frac{x_i - \mu_\beta} {\sqrt{\
        \sigma_{\beta}^{2} + \epsilon}} \qquad &//\ normalize \\
        y_i &\gets \gamma \hat{x_i} + \beta \qquad &//\ scale\ and\ shift

    Where `H` means height of feature map, `W` means width of feature map.

    Parameters:
        num_features(int): Indicate the number of channels of the input ``Tensor``.
        epsilon(float, optional): A value added to the denominator for
            numerical stability. Default is 1e-5.
        momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
        weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale`
            of instance_norm. If it is set to None or one attribute of ParamAttr, instance_norm
            will create ParamAttr as weight_attr, the name of scale can be set in ParamAttr.
            If the Initializer of the weight_attr is not set, the parameter is initialized
            to one. If it is set to False, no scale is created. Default: None. For more information, please refer to :ref:`api_paddle_ParamAttr` .
        bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of instance_norm.
            If it is set to None or one attribute of ParamAttr, instance_norm
            will create ParamAttr as bias_attr, the name of bias can be set in ParamAttr.
            If the Initializer of the bias_attr is not set, the bias is initialized to zero.
            If it is set to False, no bias is created. Default: None. For more information, please refer to :ref:`api_paddle_ParamAttr` .
        data_format(str, optional): Specify the input data format, could be "NCHW". Default: NCHW.
        name(str, optional): Name for the InstanceNorm, default is None. For more information, please refer to :ref:`api_guide_Name` .

    Shape:
        - x: 4-D tensor with shape: (batch, num_features, height, width).
        - output: 4-D tensor with same shape as input x.

    Returns:
        None.


    Examples:

        .. code-block:: python

            import paddle

            x = paddle.rand((2, 2, 2, 3))
            instance_norm = paddle.nn.InstanceNorm2D(2)
            instance_norm_out = instance_norm(x)

            print(instance_norm_out)
    """

    def __init__(
        self,
        num_features,
        epsilon=0.00001,
        momentum=0.9,
        weight_attr=None,
        bias_attr=None,
        data_format="NCHW",
        name=None,
    ):
        # Forward everything to the shared base unchanged.
        super().__init__(
            num_features=num_features,
            epsilon=epsilon,
            momentum=momentum,
            weight_attr=weight_attr,
            bias_attr=bias_attr,
            data_format=data_format,
            name=name,
        )

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 4-D."""
        rank = len(input.shape)
        if rank != 4:
            raise ValueError(f'expected 4D input (got {rank}D input)')
292 293


C
cnn 已提交
294
class InstanceNorm3D(_InstanceNormBase):
    r"""
    Create a callable object of `InstanceNorm3D`. Applies Instance Normalization over a 5D input (a mini-batch of 3D inputs with additional channel dimension) as described in the paper Instance Normalization: The Missing Ingredient for Fast Stylization .

    DataLayout: NCDHW `[batch, in_channels, D, in_height, in_width]`


    :math:`input` is the input features over a mini-batch.

    ..  math::

        \mu_{\beta} &\gets \frac{1}{HW} \sum_{i=1}^{HW} x_i \qquad &//\
        \ mean\ of\ one\  feature\ map\ in\ mini-batch \\
        \sigma_{\beta}^{2} &\gets \frac{1}{HW} \sum_{i=1}^{HW}(x_i - \
        \mu_{\beta})^2 \qquad &//\ variance\ of\ one\ feature\ map\ in\ mini-batch \\
        \hat{x_i} &\gets \frac{x_i - \mu_\beta} {\sqrt{\
        \sigma_{\beta}^{2} + \epsilon}} \qquad &//\ normalize \\
        y_i &\gets \gamma \hat{x_i} + \beta \qquad &//\ scale\ and\ shift

    Where `H` means height of feature map, `W` means width of feature map.

    Parameters:
        num_features(int): Indicate the number of channels of the input ``Tensor``.
        epsilon(float, optional): A value added to the denominator for
            numerical stability. Default is 1e-5.
        momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
        weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale`
            of instance_norm. If it is set to None or one attribute of ParamAttr, instance_norm
            will create ParamAttr as weight_attr, the name of scale can be set in ParamAttr.
            If the Initializer of the weight_attr is not set, the parameter is initialized
            to one. If it is set to False, no scale is created. Default: None. For more information, please refer to :ref:`api_paddle_ParamAttr` .
        bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of instance_norm.
            If it is set to None or one attribute of ParamAttr, instance_norm
            will create ParamAttr as bias_attr, the name of bias can be set in ParamAttr.
            If the Initializer of the bias_attr is not set, the bias is initialized to zero.
            If it is set to False, no bias is created. Default: None. For more information, please refer to :ref:`api_paddle_ParamAttr` .
        data_format(str, optional): Specify the input data format, could be "NCDHW". Default: NCDHW.
        name(str, optional): Name for the InstanceNorm, default is None. For more information, please refer to :ref:`api_guide_Name` .

    Shape:
        - x: 5-D tensor with shape: (batch, num_features, dims, height, width).
        - output: 5-D tensor with same shape as input x.

    Returns:
        None.


    Examples:

        .. code-block:: python

            import paddle

            x = paddle.rand((2, 2, 2, 2, 3))
            instance_norm = paddle.nn.InstanceNorm3D(2)
            instance_norm_out = instance_norm(x)

            print(instance_norm_out)
    """

    def __init__(
        self,
        num_features,
        epsilon=0.00001,
        momentum=0.9,
        weight_attr=None,
        bias_attr=None,
        data_format="NCDHW",
        name=None,
    ):
        # Forward everything to the shared base unchanged.
        super().__init__(
            num_features=num_features,
            epsilon=epsilon,
            momentum=momentum,
            weight_attr=weight_attr,
            bias_attr=bias_attr,
            data_format=data_format,
            name=name,
        )

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 5-D."""
        rank = len(input.shape)
        if rank != 5:
            raise ValueError(f'expected 5D input (got {rank}D input)')
379 380


Z
zhiboniu 已提交
381
class GroupNorm(Layer):
    """

    This interface is used to construct a callable object of the ``GroupNorm`` class.
    For more details, refer to code examples.
    It implements the function of the Group Normalization Layer.
    Refer to `Group Normalization <https://arxiv.org/abs/1803.08494>`_ .

    Parameters:
        num_groups(int): The number of groups that divided from channels.
        num_channels(int): The number of channels of input.
        epsilon(float, optional): The small value added to the variance to prevent
            division by zero. Default: 1e-05.
        weight_attr(ParamAttr|bool, optional): The parameter attribute for the learnable
            scale :math:`g`. If it is set to False, the scale is fixed to one and is not
            trained. If it is set to None, the weight is initialized to one. Default: None.
        bias_attr(ParamAttr|bool, optional): The parameter attribute for the learnable
            bias :math:`b`. If it is set to False, the bias is fixed to zero and is not
            trained. If it is set to None, the bias is initialized to zero. Default: None.
        data_format(str, optional): Specify the input data format, "NCHW" or "NHWC". Default: NCHW.
        name(str, optional): Name for the GroupNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - x: Tensor with shape: attr:`(batch, num_features, *)`.
        - output: The same shape as input x.

    Returns:
        None

    Raises:
        ValueError: If ``data_format`` is neither "NCHW" nor "NHWC".

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.arange(48, dtype="float32").reshape((2, 6, 2, 2))
            group_norm = paddle.nn.GroupNorm(num_channels=6, num_groups=6)
            group_norm_out = group_norm(x)

            print(group_norm_out)
    """

    def __init__(
        self,
        num_groups,
        num_channels,
        epsilon=1e-05,
        weight_attr=None,
        bias_attr=None,
        data_format='NCHW',
        name=None,
    ):
        super().__init__()
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        self._epsilon = epsilon
        self._num_channels = num_channels
        self._num_groups = num_groups
        if data_format not in ['NCHW', 'NHWC']:
            # f-string keeps the message identical for strings and avoids a
            # TypeError from `str + non-str` when a non-string is passed.
            raise ValueError(f"unsupported data layout:{data_format}")
        self._data_format = data_format

        param_shape = [self._num_channels]

        if weight_attr is False:
            # Scale disabled: keep a frozen all-ones parameter so the op still
            # receives a Scale input.
            self.weight = self.create_parameter(
                attr=None, shape=param_shape, default_initializer=Constant(1.0)
            )
            self.weight.stop_gradient = True
        else:
            self.weight = self.create_parameter(
                attr=self._weight_attr,
                shape=param_shape,
                default_initializer=Constant(1.0),
            )
            # A zero learning rate means the parameter is effectively frozen.
            self.weight.stop_gradient = self._weight_attr is not None and (
                hasattr(self._weight_attr, "learning_rate")
                and self._weight_attr.learning_rate == 0.0
            )

        if bias_attr is False:
            # Bias disabled: keep a frozen all-zeros parameter.
            self.bias = self.create_parameter(
                attr=None,
                shape=param_shape,
                default_initializer=Constant(0.0),
                is_bias=True,
            )
            self.bias.stop_gradient = True
        else:
            self.bias = self.create_parameter(
                attr=self._bias_attr, shape=param_shape, is_bias=True
            )
            self.bias.stop_gradient = self._bias_attr is not None and (
                hasattr(self._bias_attr, "learning_rate")
                and self._bias_attr.learning_rate == 0.0
            )

    def forward(self, input):
        """Apply group normalization to ``input`` and return the result.

        The configured data layout is forwarded on every execution path
        (dygraph, legacy dygraph and static graph) so that ``NHWC`` is
        honoured consistently.
        """
        if in_dygraph_mode():
            return _C_ops.group_norm(
                input,
                self.weight,
                self.bias,
                self._epsilon,
                self._num_groups,
                self._data_format,
            )

        mean_out = self._helper.create_variable_for_type_inference(
            dtype=input.dtype, stop_gradient=True
        )
        variance_out = self._helper.create_variable_for_type_inference(
            dtype=input.dtype, stop_gradient=True
        )

        if _in_legacy_dygraph():
            # Legacy ops take attributes as flat (name, value) pairs.
            pre_act, _, _ = _legacy_C_ops.group_norm(
                input,
                self.weight,
                self.bias,
                mean_out,
                variance_out,
                'epsilon',
                self._epsilon,
                'groups',
                self._num_groups,
                'data_layout',
                self._data_format,
            )
            return pre_act

        inputs = {'X': input}
        if self.bias is not None:
            inputs['Bias'] = self.bias
        if self.weight is not None:
            inputs['Scale'] = self.weight

        # create output
        group_norm_out = self._helper.create_variable_for_type_inference(
            dtype=input.dtype
        )

        self._helper.append_op(
            type="group_norm",
            inputs=inputs,
            outputs={
                "Y": group_norm_out,
                "Mean": mean_out,
                "Variance": variance_out,
            },
            attrs={
                "epsilon": self._epsilon,
                "groups": self._num_groups,
                "data_layout": self._data_format,
            },
        )

        return self._helper.append_activation(group_norm_out, None)

    def extra_repr(self):
        """Return the ``num_groups``/``num_channels``/``epsilon`` summary."""
        return 'num_groups={}, num_channels={}, epsilon={}'.format(
            self._num_groups, self._num_channels, self._epsilon
        )
537

538

Z
zhiboniu 已提交
539
class LayerNorm(Layer):
    r"""
    Construct a callable object of the ``LayerNorm`` class.
    For more details, refer to code examples.
    It implements the function of the Layer Normalization Layer and can be applied to mini-batch input data.
    Refer to `Layer Normalization <https://arxiv.org/pdf/1607.06450v1.pdf>`_

    The formula is as follows:

    ..  math::

        \mu & = \frac{1}{H}\sum_{i=1}^{H} x_i

        \sigma & = \sqrt{\frac{1}{H}\sum_{i=1}^{H}{(x_i - \mu)^2} + \epsilon}

        y & = f(\frac{g}{\sigma}(x - \mu) + b)

    - :math:`x`: the vector representation of the summed inputs to the neurons in that layer.
    - :math:`H`: the number of hidden units in a layer
    - :math:`\epsilon`: the small value added to the variance to prevent division by zero.
    - :math:`g`: the trainable scale parameter.
    - :math:`b`: the trainable bias parameter.

    Parameters:
        normalized_shape(int|list|tuple): Input shape from an expected input of
            size :math:`[*, normalized_shape[0], normalized_shape[1], ..., normalized_shape[-1]]`.
            If it is a single integer, this module will normalize over the last dimension
            which is expected to be of that specific size.
        epsilon(float, optional): The small value added to the variance to prevent
            division by zero. Default: 1e-05.
        weight_attr(ParamAttr|bool, optional): The parameter attribute for the learnable
            gain :math:`g`. If False, weight is None. If None, a default :code:`ParamAttr` is used for the scale. The
            :attr:`weight_attr` is initialized as 1 if it is added. Default: None. For more information, please refer to :ref:`api_paddle_ParamAttr` .
        bias_attr(ParamAttr|bool, optional): The parameter attribute for the learnable
            bias :math:`b`. If False, bias is None. If None, a default :code:`ParamAttr` is used for the bias. The
            :attr:`bias_attr` is initialized as 0 if it is added. Default: None. For more information, please refer to :ref:`api_paddle_ParamAttr` .
        name(str, optional): Name for the LayerNorm, default is None. For more information, please refer to :ref:`api_guide_Name` .

    Shape:
        - x: 2-D, 3-D, 4-D or 5-D tensor.
        - output: same shape as input x.

    Returns:
        None

    Examples:

        .. code-block:: python

          import paddle

          x = paddle.rand((2, 2, 2, 3))
          layer_norm = paddle.nn.LayerNorm(x.shape[1:])
          layer_norm_out = layer_norm(x)

          print(layer_norm_out)
    """

    def __init__(
        self,
        normalized_shape,
        epsilon=1e-05,
        weight_attr=None,
        bias_attr=None,
        name=None,
    ):
        super().__init__()

        # A bare integer is shorthand for a one-element shape.
        shape_spec = (
            [normalized_shape]
            if isinstance(normalized_shape, numbers.Integral)
            else normalized_shape
        )

        self._normalized_shape = list(shape_spec)
        self._epsilon = epsilon
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr

        # Scale and bias are stored flattened: one element per normalized entry.
        flat_shape = [np.prod(self._normalized_shape)]

        self.weight = (
            None
            if weight_attr is False
            else self.create_parameter(
                attr=self._weight_attr,
                shape=flat_shape,
                default_initializer=Constant(1.0),
            )
        )

        self.bias = (
            None
            if bias_attr is False
            else self.create_parameter(
                attr=self._bias_attr, shape=flat_shape, is_bias=True
            )
        )

    def forward(self, input):
        """Apply ``layer_norm`` to ``input`` with this layer's parameters."""
        return layer_norm(
            input,
            normalized_shape=self._normalized_shape,
            weight=self.weight,
            bias=self.bias,
            epsilon=self._epsilon,
        )

    def extra_repr(self):
        """Return the ``normalized_shape``/``epsilon`` summary for the repr."""
        return (
            f'normalized_shape={self._normalized_shape}, '
            f'epsilon={self._epsilon}'
        )
644

645

Z
zhiboniu 已提交
646
class _BatchNormBase(Layer):
    """
    Shared base class for the batch normalization layers.

    It creates the scale/bias parameters and the moving mean/variance
    buffers, and runs ``batch_norm`` in ``forward``. Subclasses must override
    ``_check_input_dim`` and ``_check_data_format``.
    """

    def __init__(
        self,
        num_features,
        momentum=0.9,
        epsilon=1e-05,
        weight_attr=None,
        bias_attr=None,
        data_format='NCHW',
        use_global_stats=None,
        name=None,
    ):
        super().__init__()
        self._num_features = num_features
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        self._use_global_stats = use_global_stats

        # Parameters and statistics are kept in float32 when the default
        # dtype is float16.
        default_dtype = get_default_dtype()
        self._dtype = 'float32' if default_dtype == 'float16' else default_dtype

        param_shape = [num_features]

        # create parameter
        if weight_attr is False:
            # Scale disabled: keep a frozen all-ones parameter.
            self.weight = self.create_parameter(
                attr=None,
                shape=param_shape,
                dtype=self._dtype,
                default_initializer=Constant(1.0),
            )
            self.weight.stop_gradient = True
        else:
            self.weight = self.create_parameter(
                attr=self._weight_attr,
                shape=param_shape,
                dtype=self._dtype,
                default_initializer=Constant(1.0),
            )
            # getattr guard: an attribute object without `learning_rate`
            # must not raise AttributeError here (same guard as GroupNorm).
            self.weight.stop_gradient = (
                getattr(self._weight_attr, "learning_rate", None) == 0.0
            )

        if bias_attr is False:
            # Bias disabled: keep a frozen all-zeros parameter.
            self.bias = self.create_parameter(
                attr=None,
                shape=param_shape,
                dtype=self._dtype,
                default_initializer=Constant(0.0),
                is_bias=True,
            )
            self.bias.stop_gradient = True
        else:
            self.bias = self.create_parameter(
                attr=self._bias_attr,
                shape=param_shape,
                dtype=self._dtype,
                is_bias=True,
            )
            self.bias.stop_gradient = (
                getattr(self._bias_attr, "learning_rate", None) == 0.0
            )

        moving_mean_name = None
        moving_variance_name = None

        if name is not None:
            moving_mean_name = name + "_mean"
            moving_variance_name = name + "_variance"

        # Running statistics are stored as non-trainable parameters.
        self._mean = self.create_parameter(
            dtype=self._dtype,
            attr=ParamAttr(
                name=moving_mean_name,
                initializer=Constant(0.0),
                trainable=False,
                do_model_average=True,
            ),
            shape=param_shape,
        )
        self._mean.stop_gradient = True

        self._variance = self.create_parameter(
            dtype=self._dtype,
            attr=ParamAttr(
                name=moving_variance_name,
                initializer=Constant(1.0),
                trainable=False,
                do_model_average=True,
            ),
            shape=param_shape,
        )
        self._variance.stop_gradient = True

        # TODO(qili93): temporary for Ascend NPU performance, to be removed along with the npu_identity op
        if (
            _global_flags()['FLAGS_npu_storage_format']
            and 'npu' in get_all_custom_device_type()
        ):
            with no_grad():
                weight_trans = _C_ops.npu_identity(
                    self.weight, 3
                )  # ACL_FORMAT_NC1HWC0 = 3
                bias_trans = _C_ops.npu_identity(
                    self.bias, 3
                )  # ACL_FORMAT_NC1HWC0 = 3
                mean_trans = _C_ops.npu_identity(
                    self._mean, 3
                )  # ACL_FORMAT_NC1HWC0 = 3
                var_trans = _C_ops.npu_identity(
                    self._variance, 3
                )  # ACL_FORMAT_NC1HWC0 = 3
                weight_trans._share_underline_tensor_to(self.weight)
                bias_trans._share_underline_tensor_to(self.bias)
                mean_trans._share_underline_tensor_to(self._mean)
                var_trans._share_underline_tensor_to(self._variance)

        self._data_format = data_format
        self._in_place = False
        self._momentum = momentum
        self._epsilon = epsilon
        self._fuse_with_relu = False
        self._name = name

    def _check_input_dim(self, input):
        """Validate the rank of ``input``; must be overridden by subclasses."""
        raise NotImplementedError("BatchNorm Base error")

    def _check_data_format(self, input):
        """Validate a data format; must be overridden by subclasses.

        NOTE: ``forward`` passes ``self._data_format`` (not the input tensor)
        as the ``input`` argument.
        """
        raise NotImplementedError("BatchNorm Base data format error")

    def forward(self, input):
        """Validate the configuration and input, then apply ``batch_norm``."""

        self._check_data_format(self._data_format)

        self._check_input_dim(input)

        if self.training:
            warnings.warn(
                "When training, we now always track global mean and variance."
            )

        return batch_norm(
            input,
            self._mean,
            self._variance,
            weight=self.weight,
            bias=self.bias,
            training=self.training,
            momentum=self._momentum,
            epsilon=self._epsilon,
            data_format=self._data_format,
            use_global_stats=self._use_global_stats,
        )

    def extra_repr(self):
        """Return a summary string; data_format and name appear only when non-default."""
        main_str = (
            f'num_features={self._num_features}, '
            f'momentum={self._momentum}, epsilon={self._epsilon}'
        )
        if self._data_format != 'NCHW':
            main_str += f', data_format={self._data_format}'
        if self._name is not None:
            main_str += f', name={self._name}'
        return main_str

818

819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883
class BatchNorm(Layer):
    r"""
    This interface is used to construct a callable object of the ``BatchNorm`` class.
    For more details, refer to code examples.
    It implements the function of the Batch Normalization Layer and can be used
    as a normalizer function for conv2d and fully connected operations.
    The data is normalized by the mean and variance of the channel based on the current batch data.
    Refer to `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/pdf/1502.03167.pdf>`_
    for more details.

    When use_global_stats = False, the :math:`\mu_{\beta}`
    and :math:`\sigma_{\beta}^{2}` are the statistics of one mini-batch.
    Calculated as follows:

    ..  math::

        \mu_{\beta} &\gets \frac{1}{m} \sum_{i=1}^{m} x_i \qquad &//\ mini-batch\ mean \\
        \sigma_{\beta}^{2} &\gets \frac{1}{m} \sum_{i=1}^{m}(x_i - \mu_{\beta})^2 \qquad &//\ mini-batch\ variance \\

    - :math:`x` : mini-batch data
    - :math:`m` : the size of the mini-batch data

    When use_global_stats = True, the :math:`\mu_{\beta}`
    and :math:`\sigma_{\beta}^{2}` are not the statistics of one mini-batch.
    They are global or running statistics (moving_mean and moving_variance), usually obtained from a
    pre-trained model. Calculated as follows:

    .. math::

        moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global\ mean \\
        moving\_variance = moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \quad &// global\ variance \\

    The normalization function formula is as follows:

    ..  math::

        \hat{x_i} &\gets \frac{x_i - \mu_\beta} {\sqrt{\sigma_{\beta}^{2} + \epsilon}} \qquad &//\ normalize \\
        y_i &\gets \gamma \hat{x_i} + \beta \qquad &//\ scale\ and\ shift

    - :math:`\epsilon` : add a smaller value to the variance to prevent division by zero
    - :math:`\gamma` : trainable proportional parameter
    - :math:`\beta` : trainable deviation parameter

    Parameters:
        num_channels(int): Indicate the number of channels of the input ``Tensor``.
        act(str, optional): Activation to be applied to the output of batch normalization. Default: None.
        is_test (bool, optional): A flag indicating whether it is in test phase or not.
             This flag only has effect on static graph mode. For dygraph mode, please use ``eval()``.
             Default: False.
        momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
        epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
        param_attr(ParamAttr, optional): The parameter attribute for Parameter `scale`
             of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm
             will create ParamAttr as param_attr. If the Initializer of the param_attr
             is not set, the parameter is initialized with ones. Default: None.
        bias_attr(ParamAttr, optional): The parameter attribute for the bias of batch_norm.
             If it is set to None or one attribute of ParamAttr, batch_norm
             will create ParamAttr as bias_attr. If the Initializer of the bias_attr
             is not set, the bias is initialized zero. Must not be ``False``. Default: None.
        dtype(str, optional): Indicate the data type of the input ``Tensor``,
             which can be float32 or float64. When "float16" is given, the layer's
             parameters are still created as float32. Default: float32.
        data_layout(str, optional): Specify the input data format, the data format can be "NCHW" or "NHWC", where `N` is batch size, `C` is the number of the feature map, `H` is the height of the feature map, `W` is the width of the feature map. Default: NCHW.
        in_place(bool, optional): Make the input and output of batch norm reuse memory. Default: False.
        moving_mean_name(str, optional): The name of moving_mean which store the global Mean. Default: None.
        moving_variance_name(str, optional): The name of the moving_variance which store the global Variance. Default: None.
        do_model_average_for_mean_and_var(bool, optional): Whether parameter mean and variance should do model
            average when model average is enabled. Default: True.
        use_global_stats(bool, optional): Whether to use global mean and
            variance. In inference or test mode, set use_global_stats to true
            or is_test to true, and the behavior is equivalent.
            In train mode, when setting use_global_stats True, the global mean
            and variance are also used during train period. Default: False.
        trainable_statistics(bool, optional): Whether to calculate mean and var in eval mode. In eval mode, when
            setting trainable_statistics True, mean and variance will be calculated by current batch statistics.
            Default: False.

    Returns:
        None

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle.nn as nn
          from paddle.fluid.dygraph.base import to_variable
          import numpy as np


          x = np.random.random(size=(3, 10, 3, 7)).astype('float32')
          with fluid.dygraph.guard():
              x = to_variable(x)
              batch_norm = nn.layer.norm.BatchNorm(10)
              hidden1 = batch_norm(x)
    """

    def __init__(
        self,
        num_channels,
        act=None,
        is_test=False,
        momentum=0.9,
        epsilon=1e-05,
        param_attr=None,
        bias_attr=None,
        dtype='float32',
        data_layout='NCHW',
        in_place=False,
        moving_mean_name=None,
        moving_variance_name=None,
        do_model_average_for_mean_and_var=True,
        use_global_stats=False,
        trainable_statistics=False,
    ):
        super().__init__()
        self._param_attr = param_attr
        self._bias_attr = bias_attr
        self._act = act
        self._use_mkldnn = _global_flags()["FLAGS_use_mkldnn"]

        assert (
            bias_attr is not False
        ), "bias_attr should not be False in batch_norm."

        # Parameters are kept in float32 even when float16 is requested.
        self._dtype = "float32" if dtype == "float16" else dtype

        param_shape = [num_channels]

        # Scale and shift are frozen only when global statistics are used AND
        # the scale attribute carries a zero learning rate. ``param_attr`` may
        # be None (the default) or an object without ``learning_rate``, so the
        # attribute is looked up defensively instead of being dereferenced
        # unconditionally, which used to raise AttributeError.
        # NOTE(review): the bias is keyed off ``param_attr`` as well, exactly
        # as before -- confirm whether ``bias_attr`` was intended.
        zero_lr = getattr(self._param_attr, "learning_rate", None) == 0.0
        freeze_affine = bool(use_global_stats) and bool(zero_lr)

        # create parameter
        self.weight = self.create_parameter(
            attr=self._param_attr,
            shape=param_shape,
            dtype=self._dtype,
            default_initializer=Constant(1.0),
        )
        self.weight.stop_gradient = freeze_affine

        self.bias = self.create_parameter(
            attr=self._bias_attr,
            shape=param_shape,
            dtype=self._dtype,
            is_bias=True,
        )
        self.bias.stop_gradient = freeze_affine

        # Running statistics are stored as non-trainable parameters.
        self._mean = self.create_parameter(
            attr=ParamAttr(
                name=moving_mean_name,
                initializer=Constant(0.0),
                trainable=False,
                do_model_average=do_model_average_for_mean_and_var,
            ),
            shape=param_shape,
            dtype=self._dtype,
        )
        self._mean.stop_gradient = True

        self._variance = self.create_parameter(
            attr=ParamAttr(
                name=moving_variance_name,
                initializer=Constant(1.0),
                trainable=False,
                do_model_average=do_model_average_for_mean_and_var,
            ),
            shape=param_shape,
            dtype=self._dtype,
        )
        self._variance.stop_gradient = True

        self._in_place = in_place
        self._data_layout = data_layout
        self._momentum = momentum
        self._epsilon = epsilon
        self._is_test = is_test
        self._fuse_with_relu = False
        self._use_global_stats = use_global_stats
        self._trainable_statistics = trainable_statistics

    def forward(self, input):
        """Apply batch normalization followed by the optional activation.

        In dynamic-graph mode the op is invoked eagerly and ``self.training``
        decides the test flag; in static-graph mode a ``batch_norm`` op is
        appended to the program and ``is_test`` from the constructor is used.

        Args:
            input: The tensor to normalize.

        Returns:
            The normalized (and, if ``act`` was given, activated) output.
        """
        # mean and mean_out share the same memory
        mean_out = self._mean
        # variance and variance out share the same memory
        variance_out = self._variance

        if _non_static_mode():
            if in_dygraph_mode():
                batch_norm_out, _, _, _, _, _ = _C_ops.batch_norm(
                    input,
                    self._mean,
                    self._variance,
                    self.weight,
                    self.bias,
                    not self.training,
                    self._momentum,
                    self._epsilon,
                    self._data_layout,
                    self._use_global_stats,
                    self._trainable_statistics,
                )
            elif _in_legacy_dygraph():
                attrs = (
                    "momentum",
                    self._momentum,
                    "epsilon",
                    self._epsilon,
                    "is_test",
                    not self.training,
                    "data_layout",
                    self._data_layout,
                    "use_mkldnn",
                    self._use_mkldnn,
                    "fuse_with_relu",
                    self._fuse_with_relu,
                    "use_global_stats",
                    self._use_global_stats,
                    'trainable_statistics',
                    self._trainable_statistics,
                )
                batch_norm_out, _, _, _, _, _ = _legacy_C_ops.batch_norm(
                    input,
                    self.weight,
                    self.bias,
                    self._mean,
                    self._variance,
                    None,
                    mean_out,
                    variance_out,
                    *attrs
                )

            # Both dynamic-graph flavours finish with the same activation step.
            return dygraph_utils._append_activation_in_dygraph(
                batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn
            )

        # ---- static-graph path ----
        check_variable_and_dtype(
            input, 'input', ['float16', 'float32', 'float64'], 'BatchNorm'
        )

        attrs = {
            "momentum": self._momentum,
            "epsilon": self._epsilon,
            "is_test": self._is_test,
            "data_layout": self._data_layout,
            "use_mkldnn": False,
            "fuse_with_relu": self._fuse_with_relu,
            "use_global_stats": self._use_global_stats,
            "trainable_statistics": self._trainable_statistics,
        }

        inputs = {
            "X": [input],
            "Scale": [self.weight],
            "Bias": [self.bias],
            "Mean": [self._mean],
            "Variance": [self._variance],
        }

        saved_mean = self._helper.create_variable_for_type_inference(
            dtype=self._dtype, stop_gradient=True
        )
        saved_variance = self._helper.create_variable_for_type_inference(
            dtype=self._dtype, stop_gradient=True
        )
        reserve_space = self._helper.create_variable_for_type_inference(
            dtype=self._helper.input_dtype(input), stop_gradient=True
        )

        # With in_place the op writes its result straight into ``input``.
        batch_norm_out = (
            input
            if self._in_place
            else self._helper.create_variable_for_type_inference(self._dtype)
        )

        outputs = {
            "Y": [batch_norm_out],
            "MeanOut": [mean_out],
            "VarianceOut": [variance_out],
            "SavedMean": [saved_mean],
            "SavedVariance": [saved_variance],
            "ReserveSpace": [reserve_space],
        }

        self._helper.append_op(
            type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs
        )

        # Currently, we don't support inplace in dygraph mode
        return self._helper.append_activation(batch_norm_out, self._act)


C
cnn 已提交
1125
class BatchNorm1D(_BatchNormBase):
    r"""
    Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D inputs with additional channel dimension) as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift .

    When use_global_stats = False, the :math:`\mu_{\beta}`
    and :math:`\sigma_{\beta}^{2}` are the statistics of one mini-batch.
    Calculated as follows:

    ..  math::

        \mu_{\beta} &\gets \frac{1}{m} \sum_{i=1}^{m} x_i \qquad &//\ mini-batch\ mean \\
        \sigma_{\beta}^{2} &\gets \frac{1}{m} \sum_{i=1}^{m}(x_i - \mu_{\beta})^2 \qquad &//\ mini-batch\ variance \\

    When use_global_stats = True, the :math:`\mu_{\beta}`
    and :math:`\sigma_{\beta}^{2}` are not the statistics of one mini-batch.
    They are global or running statistics (moving_mean and moving_variance), usually obtained from a
    pre-trained model. Calculated as follows:

    .. math::

        moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global\ mean \\
        moving\_variance = moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \quad &// global\ variance \\

    The normalization function formula is as follows:

    ..  math::

        \hat{x_i} &\gets \frac{x_i - \mu_\beta} {\sqrt{\sigma_{\beta}^{2} + \epsilon}} \qquad &//\ normalize \\
        y_i &\gets \gamma \hat{x_i} + \beta \qquad &//\ scale\ and\ shift

    - :math:`\epsilon` : add a smaller value to the variance to prevent division by zero
    - :math:`\gamma` : trainable proportional parameter
    - :math:`\beta` : trainable deviation parameter

    Parameters:
        num_features(int): Indicate the number of channels of the input ``Tensor``.
        momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
        epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
        weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale`
            of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm
            will create ParamAttr as weight_attr. If it is set to False, the weight is not learnable.
            If the Initializer of the weight_attr is not set, the parameter is initialized with ones. Default: None.
        bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of batch_norm.
            If it is set to None or one attribute of ParamAttr, batch_norm
            will create ParamAttr as bias_attr. If it is set to False, the bias is not learnable.
            If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None.
        data_format(str, optional): Specify the input data format, may be "NC", "NCL" or "NLC", where `N` is batch size, `C` is the number of the feature map, `L` is the length of the feature map. Default "NCL".
        use_global_stats(bool|None, optional): Whether to use global mean and variance. If set to False, use the statistics of one mini-batch, if set to True, use the global statistics, if set to None, use global statistics in the test phase and use the statistics of one mini-batch in the training phase. Default: None.
        name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - x: 2-D or 3-D tensor with shape: (batch, num_features) or (batch, num_features, length) when data_format is "NC" or "NCL",
            (batch, length, num_features) when data_format is "NLC".
        - output: 2-D or 3-D tensor with same shape as input x.

    Returns:
        None.

    Examples:
        .. code-block:: python

          import paddle

          x = paddle.rand((2, 1, 3))
          batch_norm = paddle.nn.BatchNorm1D(1)
          batch_norm_out = batch_norm(x)

          print(batch_norm_out)
    """

    def __init__(
        self,
        num_features,
        momentum=0.9,
        epsilon=1e-05,
        weight_attr=None,
        bias_attr=None,
        data_format='NCL',
        use_global_stats=None,
        name=None,
    ):
        super().__init__(
            num_features,
            momentum,
            epsilon,
            weight_attr,
            bias_attr,
            data_format,
            use_global_stats,
            name,
        )

    def _check_data_format(self, input):
        """Map a 1-D layout name onto the layout stored in ``_data_format``.

        Channel-first names ('NC', 'NCL', 'NCHW') are stored as 'NCHW' and
        channel-last names ('NLC', 'NHWC') as 'NHWC'. Because the stored value
        is itself accepted, repeated calls from ``forward`` are idempotent.

        Raises:
            ValueError: If ``input`` is not one of the accepted layout names
                (``None`` is not accepted).
        """
        if input in ('NCHW', 'NC', 'NCL'):
            self._data_format = 'NCHW'
        elif input in ('NHWC', 'NLC'):
            self._data_format = 'NHWC'
        else:
            raise ValueError(
                'expected NC, NCL or NLC for data_format input, '
                'but got {!r}'.format(input)
            )

    def _check_input_dim(self, input):
        """Require a 2-D or 3-D input.

        Raises:
            ValueError: If ``input`` has any other number of dimensions.
        """
        ndim = len(input.shape)
        if ndim not in (2, 3):
            raise ValueError(
                'expected 2D or 3D input (got {}D input)'.format(ndim)
            )
1236 1237


C
cnn 已提交
1238
class BatchNorm2D(_BatchNormBase):
    r"""
    Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension) as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift .

    When use_global_stats = False, the :math:`\mu_{\beta}`
    and :math:`\sigma_{\beta}^{2}` are the statistics of one mini-batch.
    Calculated as follows:

    ..  math::

        \mu_{\beta} &\gets \frac{1}{m} \sum_{i=1}^{m} x_i \qquad &//\ mini-batch\ mean \\
        \sigma_{\beta}^{2} &\gets \frac{1}{m} \sum_{i=1}^{m}(x_i - \mu_{\beta})^2 \qquad &//\ mini-batch\ variance \\

    When use_global_stats = True, the :math:`\mu_{\beta}`
    and :math:`\sigma_{\beta}^{2}` are not the statistics of one mini-batch.
    They are global or running statistics (moving_mean and moving_variance), usually obtained from a
    pre-trained model. Calculated as follows:

    .. math::

        moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global\ mean \\
        moving\_variance = moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \quad &// global\ variance \\

    The normalization function formula is as follows:

    ..  math::

        \hat{x_i} &\gets \frac{x_i - \mu_\beta} {\sqrt{\sigma_{\beta}^{2} + \epsilon}} \qquad &//\ normalize \\
        y_i &\gets \gamma \hat{x_i} + \beta \qquad &//\ scale\ and\ shift

    - :math:`\epsilon` : add a smaller value to the variance to prevent division by zero
    - :math:`\gamma` : trainable proportional parameter
    - :math:`\beta` : trainable deviation parameter

    Parameters:
        num_features(int): Indicate the number of channels of the input ``Tensor``.
        epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
        momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
        weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale`
            of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm
            will create ParamAttr as weight_attr. If it is set to False, the weight is not learnable.
            If the Initializer of the weight_attr is not set, the parameter is initialized with ones. Default: None.
        bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of batch_norm.
            If it is set to None or one attribute of ParamAttr, batch_norm
            will create ParamAttr as bias_attr. If it is set to False, the bias is not learnable.
            If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None.
        data_format(str, optional): Specify the input data format, the data format can be "NCHW" or "NHWC", where `N` is batch size, `C` is the number of the feature map, `H` is the height of the feature map, `W` is the width of the feature map. Default: NCHW.
        use_global_stats(bool|None, optional): Whether to use global mean and variance. If set to False, use the statistics of one mini-batch, if set to True, use the global statistics, if set to None, use global statistics in the test phase and use the statistics of one mini-batch in the training phase. Default: None.
        name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - x: 4-D tensor with shape: (batch, num_features, height, width) when data_format is "NCHW",
            or (batch, height, width, num_features) when data_format is "NHWC".
        - output: 4-D tensor with same shape as input x.

    Returns:
        None

    Examples:
        .. code-block:: python

          import paddle

          x = paddle.rand((2, 1, 2, 3))
          batch_norm = paddle.nn.BatchNorm2D(1)
          batch_norm_out = batch_norm(x)

          print(batch_norm_out)
    """

    def _check_data_format(self, input):
        """Accept only 'NCHW' or 'NHWC' and store the value in ``_data_format``.

        Raises:
            ValueError: If ``input`` is any other value.
        """
        if not (input == 'NCHW' or input == "NHWC"):
            raise ValueError('expected NCHW or NHWC for data_format input')
        self._data_format = input

    def _check_input_dim(self, input):
        """Require a 4-D input.

        Raises:
            ValueError: If ``input`` has any other number of dimensions.
        """
        if len(input.shape) == 4:
            return
        raise ValueError(
            'expected 4D input (got {}D input)'.format(len(input.shape))
        )
1322 1323


C
cnn 已提交
1324
class BatchNorm3D(_BatchNormBase):
    r"""
    Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs with additional channel dimension) as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift .

    When use_global_stats = False, the :math:`\mu_{\beta}`
    and :math:`\sigma_{\beta}^{2}` are the statistics of one mini-batch.
    Calculated as follows:

    ..  math::

        \mu_{\beta} &\gets \frac{1}{m} \sum_{i=1}^{m} x_i \qquad &//\ mini-batch\ mean \\
        \sigma_{\beta}^{2} &\gets \frac{1}{m} \sum_{i=1}^{m}(x_i - \mu_{\beta})^2 \qquad &//\ mini-batch\ variance \\

    When use_global_stats = True, the :math:`\mu_{\beta}`
    and :math:`\sigma_{\beta}^{2}` are not the statistics of one mini-batch.
    They are global or running statistics (moving_mean and moving_variance), usually obtained from a
    pre-trained model. Calculated as follows:

    .. math::

        moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global\ mean \\
        moving\_variance = moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \quad &// global\ variance \\

    The normalization function formula is as follows:

    ..  math::

        \hat{x_i} &\gets \frac{x_i - \mu_\beta} {\sqrt{\sigma_{\beta}^{2} + \epsilon}} \qquad &//\ normalize \\
        y_i &\gets \gamma \hat{x_i} + \beta \qquad &//\ scale\ and\ shift

    - :math:`\epsilon` : add a smaller value to the variance to prevent division by zero
    - :math:`\gamma` : trainable proportional parameter
    - :math:`\beta` : trainable deviation parameter

    Parameters:
        num_features(int): Indicate the number of channels of the input ``Tensor``.
        momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
        epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
        weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale`
            of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm
            will create ParamAttr as weight_attr. If it is set to False, the weight is not learnable.
            If the Initializer of the weight_attr is not set, the parameter is initialized with ones. Default: None.
        bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of batch_norm.
            If it is set to None or one attribute of ParamAttr, batch_norm
            will create ParamAttr as bias_attr. If it is set to False, the bias is not learnable.
            If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None.
        data_format(str, optional): Specify the input data format, the data format can be "NCDHW" or "NDHWC", where `N` is batch size, `C` is the number of the feature map, `D` is the depth of the feature, `H` is the height of the feature map, `W` is the width of the feature map. Default: NCDHW.
        use_global_stats(bool|None, optional): Whether to use global mean and variance. If set to False, use the statistics of one mini-batch, if set to True, use the global statistics, if set to None, use global statistics in the test phase and use the statistics of one mini-batch in the training phase. Default: None.
        name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - x: 5-D tensor with shape: (batch, num_features, dims, height, width) when data_format is "NCDHW",
            or (batch, dims, height, width, num_features) when data_format is "NDHWC".
        - output: 5-D tensor with same shape as input x.

    Returns:
        None

    Examples:
        .. code-block:: python

          import paddle

          x = paddle.rand((2, 1, 2, 2, 3))
          batch_norm = paddle.nn.BatchNorm3D(1)
          batch_norm_out = batch_norm(x)

          print(batch_norm_out)
    """

    def __init__(
        self,
        num_features,
        momentum=0.9,
        epsilon=1e-05,
        weight_attr=None,
        bias_attr=None,
        data_format='NCDHW',
        use_global_stats=None,
        name=None,
    ):
        super().__init__(
            num_features,
            momentum,
            epsilon,
            weight_attr,
            bias_attr,
            data_format,
            use_global_stats,
            name,
        )

    def _check_data_format(self, input):
        """Map a 3-D layout name onto the layout stored in ``_data_format``.

        Channel-first names ('NCDHW', 'NCHW') are stored as 'NCHW' and
        channel-last names ('NDHWC', 'NHWC') as 'NHWC'. Because the stored
        value is itself accepted, repeated calls from ``forward`` are
        idempotent.

        Raises:
            ValueError: If ``input`` is not one of the accepted layout names
                (``None`` is not accepted).
        """
        if input in ('NCHW', 'NCDHW'):
            self._data_format = 'NCHW'
        elif input in ('NHWC', 'NDHWC'):
            self._data_format = 'NHWC'
        else:
            raise ValueError(
                'expected NCDHW or NDHWC for data_format input, '
                'but got {!r}'.format(input)
            )

    def _check_input_dim(self, input):
        """Require a 5-D input.

        Raises:
            ValueError: If ``input`` has any other number of dimensions.
        """
        ndim = len(input.shape)
        if ndim != 5:
            raise ValueError('expected 5D input (got {}D input)'.format(ndim))
1432 1433


1434
class SyncBatchNorm(_BatchNormBase):
    r"""

    This interface is used to construct a callable object of the ``SyncBatchNorm`` class.
    It implements the function of the Cross-GPU Synchronized Batch Normalization Layer, and can
    be used as a normalizer function for other operations, such as conv2d and fully connected
    operations.
    The data is normalized by the mean and variance of the channel based on the whole mini-batch,
    which includes the data on all GPUs.
    Refer to `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/pdf/1502.03167.pdf>`_
    for more details.

    When the model is in training mode, :math:`\mu_{\beta}`
    and :math:`\sigma_{\beta}^{2}` are the statistics of the whole mini-batch data on all GPUs.
    Calculated as follows:

    ..  math::

        \mu_{\beta} &\gets \frac{1}{m} \sum_{i=1}^{m} x_i \qquad &//\
        \ mini-batch\ mean \\
        \sigma_{\beta}^{2} &\gets \frac{1}{m} \sum_{i=1}^{m}(x_i - \
        \mu_{\beta})^2 \qquad &//\ mini-batch\ variance \\

    - :math:`x` : whole mini-batch data in all gpus
    - :math:`m` : the size of the whole mini-batch data

    When the model is in evaluation mode, :math:`\mu_{\beta}`
    and :math:`\sigma_{\beta}^{2}` are global statistics (moving_mean and moving_variance,
    which are usually obtained from the pre-trained model). Global statistics are calculated as follows:

    .. math::
        moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global \ mean \\
        moving\_variance = moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \quad &// global \ variance \\

    The formula of normalization is as follows:

    ..  math::

        \hat{x_i} &\gets \frac{x_i - \mu_\beta} {\sqrt{\
        \sigma_{\beta}^{2} + \epsilon}} \qquad &//\ normalize \\
        y_i &\gets \gamma \hat{x_i} + \beta \qquad &//\ scale\ and\ shift

    - :math:`\epsilon` : add a smaller value to the variance to prevent division by zero
    - :math:`\gamma` : trainable scale parameter vector
    - :math:`\beta` : trainable shift parameter vector

    Note:
        If you want to use container to pack your model and has :ref:`api_paddle_nn_SyncBatchNorm` in the
        evaluation phase, please use :ref:`api_paddle_nn_LayerList` or :ref:`api_paddle_nn_Sequential` instead of
        :ref:`api_paddle_hub_list` to pack the model.

    Parameters:
        num_features(int): Indicate the number of channels of the input ``Tensor``.
        momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
        epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
        weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale`
             of this layer. If it is set to None or one attribute of ParamAttr, this layer
             will create ParamAttr as param_attr. If the Initializer of the param_attr
             is not set, the parameter is initialized with ones. If it is set to False,
             this layer will not have trainable scale parameter. Default: None.
        bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of this layer.
             If it is set to None or one attribute of ParamAttr, this layer
             will create ParamAttr as bias_attr. If the Initializer of the bias_attr
             is not set, the bias is initialized zero. If it is set to False, this layer will not
             have trainable bias parameter. Default: None.
        data_format(str, optional): Layout of the input. One of ``'NC'``, ``'NCL'``, ``'NCHW'``,
             ``'NCDHW'`` (channel first) or ``'NLC'``, ``'NHWC'``, ``'NDHWC'`` (channel last).
             Default: ``'NCHW'``.
        name(str, optional): Name forwarded to the base class. Default: None.

    Shapes:
        - input: Tensor that the dimension from 2 to 5.
        - output: Tensor with the same shape as input.

    Examples:
        .. code-block:: python

            # required: gpu

            import paddle
            import paddle.nn as nn

            x = paddle.to_tensor([[[[0.3, 0.4], [0.3, 0.07]], [[0.83, 0.37], [0.18, 0.93]]]]).astype('float32')

            if paddle.is_compiled_with_cuda():
                sync_batch_norm = nn.SyncBatchNorm(2)
                hidden1 = sync_batch_norm(x)
                print(hidden1)
                # Tensor(shape=[1, 2, 2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=False,
                #        [[[[ 0.26824948,  1.09363246],
                #           [ 0.26824948, -1.63013160]],

                #          [[ 0.80956620, -0.66528702],
                #           [-1.27446556,  1.13018656]]]])

    """

    def __init__(
        self,
        num_features,
        momentum=0.9,
        epsilon=1e-05,
        weight_attr=None,
        bias_attr=None,
        data_format='NCHW',
        name=None,
    ):
        """Initialise the base class, always passing ``None`` for ``use_global_stats``."""
        super().__init__(
            num_features,
            momentum,
            epsilon,
            weight_attr,
            bias_attr,
            data_format,
            None,
            name,
        )

    def _check_data_format(self):
        """Collapse ``self._data_format`` to ``'NCHW'`` or ``'NHWC'`` in place.

        Raises:
            ValueError: If ``self._data_format`` is not one of the accepted
                layout strings.
        """
        if self._data_format in ['NCHW', 'NCDHW', 'NC', 'NCL']:
            self._data_format = 'NCHW'
        elif self._data_format in ["NHWC", "NDHWC", 'NLC']:
            self._data_format = 'NHWC'
        else:
            raise ValueError(
                "expected 'NCDHW', 'NDHWC', 'NCL', 'NLC', 'NC', 'NCHW', 'NHWC' for data_format"
            )

    def forward(self, x):
        """Apply synchronized batch normalization to ``x``.

        Three execution paths are used, selected in this order:
        ``in_dygraph_mode()`` calls ``_C_ops.sync_batch_norm_``;
        ``in_dynamic_mode()`` calls ``_legacy_C_ops.sync_batch_norm``;
        otherwise a ``sync_batch_norm`` op is appended through ``self._helper``.

        Parameters:
            x: The input to normalize.

        Returns:
            The normalized output produced by the selected path.
        """
        self._check_data_format()
        # create output
        # mean and mean_out share the same memory
        mean_out = self._mean
        # variance and variance out share the same memory
        variance_out = self._variance

        # train mode: use mini-batch stats, eval mode: use global stats
        # use_global_stats only support False in sync_batch_norm
        if in_dygraph_mode():
            sync_batch_norm_out, _, _, _, _, _ = _C_ops.sync_batch_norm_(
                x,
                self._mean,
                self._variance,
                self.weight,
                self.bias,
                not self.training,
                self._momentum,
                self._epsilon,
                self._data_format,
                False,
                False,
            )
            return sync_batch_norm_out

        elif in_dynamic_mode():
            # The legacy op takes its attributes as a flat (name, value, ...) sequence.
            attrs = (
                "momentum",
                self._momentum,
                "epsilon",
                self._epsilon,
                "is_test",
                not self.training,
                "data_layout",
                self._data_format,
                "use_mkldnn",
                False,
                "fuse_with_relu",
                False,
                "use_global_stats",
                False,
                'trainable_statistics',
                False,
            )
            sync_batch_norm_out, _, _, _, _, _ = _legacy_C_ops.sync_batch_norm(
                x,
                self.weight,
                self.bias,
                self._mean,
                self._variance,
                mean_out,
                variance_out,
                *attrs
            )
            return sync_batch_norm_out

        check_variable_and_dtype(
            x, 'input', ['float16', 'float32', 'float64'], 'SyncBatchNorm'
        )

        attrs = {
            "momentum": self._momentum,
            "epsilon": self._epsilon,
            "is_test": not self.training,
            "data_layout": self._data_format,
            "use_mkldnn": False,
            "fuse_with_relu": False,
            "use_global_stats": False,
            "trainable_statistics": False,
        }

        inputs = {
            "X": [x],
            "Scale": [self.weight],
            "Bias": [self.bias],
            "Mean": [self._mean],
            "Variance": [self._variance],
        }

        saved_mean = self._helper.create_variable_for_type_inference(
            dtype=self._dtype, stop_gradient=True
        )
        saved_variance = self._helper.create_variable_for_type_inference(
            dtype=self._dtype, stop_gradient=True
        )
        sync_batch_norm_out = self._helper.create_variable_for_type_inference(
            self._dtype
        )

        outputs = {
            "Y": [sync_batch_norm_out],
            "MeanOut": [mean_out],
            "VarianceOut": [variance_out],
            "SavedMean": [saved_mean],
            "SavedVariance": [saved_variance],
        }

        self._helper.append_op(
            type="sync_batch_norm", inputs=inputs, outputs=outputs, attrs=attrs
        )
        return sync_batch_norm_out

    @classmethod
    def convert_sync_batchnorm(cls, layer):
        """
        Helper function to convert :class:`paddle.nn.BatchNorm*d` layers in the model to :class:`paddle.nn.SyncBatchNorm` layers.

        The conversion is recursive over ``layer.named_children()``. When ``layer``
        is itself a ``_BatchNormBase`` instance, a new ``SyncBatchNorm`` is built from
        its configuration and receives the original layer's ``_mean`` and
        ``_variance``. ``weight`` is reused unless ``weight_attr`` is ``False`` and
        ``bias`` is reused unless ``bias_attr`` is ``False``; the two are handled
        independently.

        Note that a non-``None`` ``name`` on the original layer's ``_weight_attr`` /
        ``_bias_attr`` is modified in place by appending ``'_sync'``.

        Parameters:
            layer(paddle.nn.Layer): model containing one or more `BatchNorm*d` layers.

        Returns:
            ``layer`` itself with its converted sublayers re-registered, or a new
            ``SyncBatchNorm`` when ``layer`` is a ``_BatchNormBase`` instance.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.nn as nn

                model = nn.Sequential(nn.Conv2D(3, 5, 3), nn.BatchNorm2D(5))
                sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

        """
        layer_output = layer
        if isinstance(layer, _BatchNormBase):
            if (
                layer._weight_attr is not None
                and not isinstance(layer._weight_attr, bool)
                and layer._weight_attr.name is not None
            ):
                layer._weight_attr.name = layer._weight_attr.name + '_sync'
            if (
                layer._bias_attr is not None
                and not isinstance(layer._bias_attr, bool)
                and layer._bias_attr.name is not None
            ):
                layer._bias_attr.name = layer._bias_attr.name + '_sync'

            layer_output = SyncBatchNorm(
                layer._num_features,
                layer._momentum,
                layer._epsilon,
                layer._weight_attr,
                layer._bias_attr,
                layer._data_format,
                layer._name,
            )

            # Carry each parameter over on its own: disabling only one of
            # scale/shift must not prevent the other from being copied.
            with no_grad():
                if layer._weight_attr is not False:
                    layer_output.weight = layer.weight
                if layer._bias_attr is not False:
                    layer_output.bias = layer.bias
            layer_output._mean = layer._mean
            layer_output._variance = layer._variance

        for name, sublayer in layer.named_children():
            layer_output.add_sublayer(
                name, cls.convert_sync_batchnorm(sublayer)
            )
        return layer_output
1724 1725


Z
zhiboniu 已提交
1726
class LocalResponseNorm(Layer):
    """
    Local Response Normalization performs a type of "lateral inhibition" by normalizing over local input regions.
    For more information, please refer to `ImageNet Classification with Deep Convolutional Neural Networks <https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf>`_

    See more details in :ref:`api_paddle_nn_functional_local_response_norm` .

    Parameters:
        size (int): The number of channels to sum over.
        alpha (float, optional): The scaling parameter, positive. Default:1e-4
        beta (float, optional): The exponent, positive. Default:0.75
        k (float, optional): An offset, positive. Default: 1.0
        data_format (str, optional): Specify the data format of the input, and the data format of the output
            will be consistent with that of the input. An optional string from:
            If input is 3-D Tensor, the string could be `"NCL"` or `"NLC"` . When it is `"NCL"`,
            the data is stored in the order of: `[batch_size, input_channels, feature_length]`.
            If input is 4-D Tensor, the string could be  `"NCHW"`, `"NHWC"`. When it is `"NCHW"`,
            the data is stored in the order of: `[batch_size, input_channels, input_height, input_width]`.
            If input is 5-D Tensor, the string could be  `"NCDHW"`, `"NDHWC"` . When it is `"NCDHW"`,
            the data is stored in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`.
        name (str, optional): Name for the operation (optional, default is None). For more information,
            please refer to :ref:`api_guide_Name`.

    Shape:
        - input: 3-D/4-D/5-D tensor.
        - output: 3-D/4-D/5-D tensor, the same shape as input.

    Examples:

        .. code-block:: python

            import paddle

            x = paddle.rand(shape=(3, 3, 112, 112), dtype="float32")
            m = paddle.nn.LocalResponseNorm(size=5)
            y = m(x)
            print(y.shape)  # [3, 3, 112, 112]
    """

    def __init__(
        self,
        size,
        alpha=0.0001,
        beta=0.75,
        k=1.0,
        data_format="NCHW",
        name=None,
    ):
        """Store the normalization hyper-parameters on the layer."""
        super().__init__()
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k
        self.data_format = data_format
        self.name = name

    def forward(self, input):
        """Return ``F.local_response_norm`` of ``input`` using the stored settings."""
        out = F.local_response_norm(
            input,
            self.size,
            self.alpha,
            self.beta,
            self.k,
            self.data_format,
            self.name,
        )
        return out

    def extra_repr(self):
        """Build the hyper-parameter summary string for this layer.

        ``data_format`` is included only when it differs from ``'NCHW'`` and
        ``name`` only when it is not ``None``.
        """
        main_str = 'size={}, alpha={}, beta={}, k={}'.format(
            self.size, self.alpha, self.beta, self.k
        )
        if self.data_format != 'NCHW':
            main_str += ', data_format={}'.format(self.data_format)
        if self.name is not None:
            main_str += ', name={}'.format(self.name)
        return main_str
1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936


class SpectralNorm(Layer):
    r"""
    This interface is used to construct a callable object of the ``SpectralNorm`` class.
    For more details, refer to code examples. It implements the function of the Spectral Normalization Layer.
    This layer calculates the spectral normalization value of weight parameters of
    fc, conv1d, conv2d, conv3d layers which should be 2-D, 3-D, 4-D, 5-D
    Parameters. Calculations are shown as follows.

    Step 1:
    Generate vector U in shape of [H], and V in shape of [W].
    While H is the :attr:`axis` th dimension of the input weights,
    and W is the product result of remaining dimensions.

    Step 2:
    :attr:`power_iters` should be a positive integer, do following
    calculations with U and V for :attr:`power_iters` rounds.

    .. math::

        \mathbf{v} := \frac{\mathbf{W}^{T} \mathbf{u}}{\|\mathbf{W}^{T} \mathbf{u}\|_2}

        \mathbf{u} := \frac{\mathbf{W} \mathbf{v}}{\|\mathbf{W} \mathbf{v}\|_2}

    Step 3:
    Calculate :math:`\sigma(\mathbf{W})` and normalize weight values.

    .. math::

        \sigma(\mathbf{W}) = \mathbf{u}^{T} \mathbf{W} \mathbf{v}

        \mathbf{W} = \frac{\mathbf{W}}{\sigma(\mathbf{W})}


    Refer to `Spectral Normalization <https://arxiv.org/abs/1802.05957>`_ .

    Parameters:
        weight_shape(list or tuple): The shape of weight parameter.
        axis(int, optional): The index of dimension which should be permuted to the first before reshaping Input(Weight) to matrix, it should be set as 0 if Input(Weight) is the weight of fc layer, and should be set as 1 if Input(Weight) is the weight of conv layer. Default: 0.
        power_iters(int, optional): The number of power iterations to calculate spectral norm. Default: 1.
        epsilon(float, optional): The epsilon for numerical stability in calculating norms. Default: 1e-12.
        dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".

    Returns:
        None

    Examples:
       .. code-block:: python

            import paddle
            x = paddle.rand((2,8,32,32))

            spectral_norm = paddle.nn.SpectralNorm(x.shape, axis=1, power_iters=2)
            spectral_norm_out = spectral_norm(x)

            print(spectral_norm_out.shape) # [2, 8, 32, 32]

    """

    def __init__(
        self,
        weight_shape,
        axis=0,
        power_iters=1,
        epsilon=1e-12,
        dtype='float32',
    ):
        """Validate ``weight_shape`` / ``axis`` and create the ``weight_u`` and ``weight_v`` vectors."""
        super().__init__()
        self._power_iters = power_iters
        self._epsilon = epsilon
        self._dim = axis
        self._dtype = dtype

        self._weight_shape = list(weight_shape)
        # Total element count; reused below to derive the length of V.
        numel = np.prod(self._weight_shape)
        assert (
            numel > 0
        ), "Any dimension of `weight_shape` cannot be equal to 0."
        assert axis < len(self._weight_shape), (
            "The input `axis` should be less than the "
            "length of `weight_shape`, but received axis="
            "{}".format(axis)
        )
        # h: extent of the `axis` dimension; w: product of all other extents.
        h = self._weight_shape[self._dim]
        w = numel // h

        self.weight_u = self.create_parameter(
            attr=ParamAttr(),
            shape=[h],
            dtype=self._dtype,
            default_initializer=Normal(0.0, 1.0),
        )
        # U is excluded from gradient computation.
        self.weight_u.stop_gradient = True

        self.weight_v = self.create_parameter(
            attr=ParamAttr(),
            shape=[w],
            dtype=self._dtype,
            default_initializer=Normal(0.0, 1.0),
        )
        # V is excluded from gradient computation.
        self.weight_v.stop_gradient = True

    def forward(self, x):
        """Return the spectral-normalized version of the weight tensor ``x``.

        Under ``in_dygraph_mode()`` this calls ``_C_ops.spectral_norm`` directly;
        otherwise the input dtype is checked and a ``spectral_norm`` op is
        appended through ``self._helper``.
        """
        weight = x
        if in_dygraph_mode():
            return _C_ops.spectral_norm(
                weight,
                self.weight_u,
                self.weight_v,
                self._dim,
                self._power_iters,
                self._epsilon,
            )

        check_variable_and_dtype(
            weight, "weight", ['float32', 'float64'], 'SpectralNorm'
        )
        inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v}
        out = self._helper.create_variable_for_type_inference(self._dtype)
        self._helper.append_op(
            type="spectral_norm",
            inputs=inputs,
            outputs={
                "Out": out,
            },
            attrs={
                "dim": self._dim,
                "power_iters": self._power_iters,
                "eps": self._epsilon,
            },
        )

        return out