normalization.py 30.0 KB
Newer Older
Z
zhunaipan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""normalization"""
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
Z
zhaojichen 已提交
20
from mindspore.ops.primitive import constexpr
Z
zhunaipan 已提交
21
import mindspore.context as context
Z
zhaojichen 已提交
22
from mindspore._checkparam import check_bool, check_typename
Z
zhunaipan 已提交
23
from mindspore._extends import cell_attr_register
Z
zhaojichen 已提交
24
from mindspore.communication.management import get_group_size, get_rank
Z
zhaojichen 已提交
25 26
from mindspore.communication import management
from mindspore._checkparam import check_int_positive
G
gong chen 已提交
27
from mindspore.ops import _selected_ops
Z
zhunaipan 已提交
28 29 30
from ..cell import Cell


G
gong chen 已提交
31

32 33
# Public normalization layers exported by this module.
__all__ = ['BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm', 'GlobalBatchNorm']

Z
zhunaipan 已提交
34 35 36 37 38 39 40 41 42 43 44 45
class _BatchNorm(Cell):
    """Batch Normalization base class.

    Shared implementation behind BatchNorm1d, BatchNorm2d and GlobalBatchNorm. It owns the affine
    parameters (gamma, beta), the running statistics (moving_mean, moving_variance) and the operators
    used for training and inference.

    Args:
        num_features (int): Number of channels `C` of the input. Must be at least 1.
        eps (float): Value added to the variance for numerical stability. Default: 1e-5.
        momentum (float): Momentum in [0, 1] for the running statistics. It is stored internally as
            ``1 - momentum`` in ``self.momentum``. Default: 0.9.
        affine (bool): Whether gamma and beta are trainable. Default: True.
        gamma_init: Initializer for gamma. Default: 'ones'.
        beta_init: Initializer for beta. Default: 'zeros'.
        moving_mean_init: Initializer for the moving mean. Default: 'zeros'.
        moving_var_init: Initializer for the moving variance. Default: 'ones'.
        use_batch_statistics (bool): If None, batch statistics are used exactly when the cell is in
            training mode; otherwise this flag decides. Default: None.
        device_num_each_group (int): Number of devices per synchronization group. Any value other than 1
            partitions the devices into groups of this size. On the GE backend, or in graph mode on Ascend,
            batch statistics are then all-reduced within the group. Default: 1.
        input_dims (str): Input rank check applied in `construct`: '1d' requires a 2-D input, '2d' requires
            a 4-D input, and 'both' accepts either. Default: '2d'.

    Raises:
        ValueError: If `num_features` is less than 1 or `momentum` is outside [0, 1].
    """
    @cell_attr_register
    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.9,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None,
                 device_num_each_group=1,
                 input_dims='2d'):
        super(_BatchNorm, self).__init__()
        if num_features < 1:
            raise ValueError("num_features must be at least 1")

        if momentum < 0 or momentum > 1:
            raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))

        self.use_batch_statistics = use_batch_statistics
        self.num_features = num_features
        self.eps = eps
        self.input_dims = input_dims
        self.moving_mean = Parameter(initializer(
            moving_mean_init, num_features), name="mean", requires_grad=False)
        self.moving_variance = Parameter(initializer(
            moving_var_init, num_features), name="variance", requires_grad=False)
        self.gamma = Parameter(initializer(
            gamma_init, num_features), name="gamma", requires_grad=affine)
        self.beta = Parameter(initializer(
            beta_init, num_features), name="beta", requires_grad=affine)
        self.group = check_int_positive(device_num_each_group)
        self.is_global = False
        if self.group != 1:
            self.rank_id = get_rank()
            self.rank_size = get_group_size()
            self.device_list = list(range(self.rank_size))
            self.rank_list = self.list_group(self.device_list, self.group)
            self.rank_list_idx = len(self.rank_list)
            # The ranks are partitioned into consecutive chunks. Only the chunk that contains this
            # device gets a communication group and an AllReduce operator. The enclosing `if`
            # already guarantees group != 1.
            for i, ranks in enumerate(self.rank_list):
                if self.rank_id in ranks:
                    self.is_global = True
                    management.create_group('group' + str(i), ranks)
                    self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1)
        self.shape = P.Shape()
        self.reduce_mean = P.ReduceMean(keep_dims=True)
        self.square = P.Square()
        self.sqrt = P.Sqrt()
        self.cast = P.Cast()
        self.dtype = P.DType()
        self.reshape = P.Reshape()
        self.is_ascend = context.get_context("device_target") == "Ascend"
        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
        # Stored as (1 - momentum). The manual update `moving -= (moving - batch) * self.momentum`
        # then equals `moving = momentum * moving + (1 - momentum) * batch`.
        self.momentum = 1.0 - momentum
        self.is_ge_backend = bool(context.get_context("enable_ge"))

        if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
            self.bn_train = P.BatchNorm(is_training=True,
                                        epsilon=self.eps)
        else:
            self.bn_train = P.FusedBatchNorm(mode=1,
                                             epsilon=self.eps,
                                             momentum=self.momentum)
        self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps)

        data_parallel_strategy = ((1,), (1,))
        data_parallel_strategy_one = ((1,), ())
        self.sub_mean = P.Sub().set_strategy(data_parallel_strategy)
        self.sub_var = P.Sub().set_strategy(data_parallel_strategy)
        self.mul_mean = P.Mul().set_strategy(data_parallel_strategy_one)
        self.mul_var = P.Mul().set_strategy(data_parallel_strategy_one)
        self.assign_sub_mean = P.AssignSub().set_strategy(data_parallel_strategy)
        self.assign_sub_var = P.AssignSub().set_strategy(data_parallel_strategy)

    def _check_data_dim(self, x):
        raise NotImplementedError

    def list_group(self, world_rank, group_size):
        """Split `world_rank` into consecutive chunks of `group_size` ranks.

        Args:
            world_rank (list[int]): Ranks of all devices.
            group_size (int): Number of ranks per chunk.

        Returns:
            list[list[int]], the consecutive chunks in order.

        Raises:
            ValueError: If `group_size` exceeds the communication group size, or if
                `len(world_rank)` is not a multiple of `group_size`.
        """
        if group_size > get_group_size():
            raise ValueError("group size can not be greater than local rank size, group size is {}, "
                             "local_rank_size is {}".format(group_size, get_group_size()))
        if len(world_rank) % group_size != 0:
            raise ValueError("please make your group size correct.")
        return [list(world_rank[start:start + group_size])
                for start in range(0, len(world_rank), group_size)]

    def _global_sync(self, x, axes, re_shape):
        """calculate global batch normalization output"""
        # Per-device first and second moments, averaged over the group via AllReduce(SUM) / group.
        # NOTE(review): this equals the true global moments only if every device in the group
        # sees the same batch size.
        x_mean = self.reduce_mean(x, axes)
        x_mean_square = self.reduce_mean(self.square(x), axes)
        global_batch_mean = self.all_reduce(x_mean) / self.group
        global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
        global_mean = global_batch_mean
        # Var[x] = E[x^2] - E[x]^2
        global_var = global_batch_mean_square - self.square(global_mean)
        var_sqrt = self.sqrt(global_var + self.eps)
        mean_first = (x - global_mean) / var_sqrt
        y = mean_first * self.reshape(self.gamma, re_shape) + self.reshape(self.beta, re_shape)

        # Running-statistics update: moving -= (moving - batch) * (1 - momentum).
        mean_sub = self.sub_mean(self.reshape(self.moving_mean, re_shape), global_mean)
        tmp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub)))
        # The variance update must start from moving_variance (not moving_mean) so that the
        # running variance actually tracks global_var.
        var_sub = self.sub_var(self.reshape(self.moving_variance, re_shape), global_var)
        tmp_variance = self.mul_var(var_sub, self.cast(self.momentum, self.dtype(var_sub)))
        y = F.depend(y, self.assign_sub_mean(self.moving_mean, self.reshape(tmp_mean, self.shape(self.moving_mean))))
        y = F.depend(y, self.assign_sub_var(self.moving_variance,
                                            self.reshape(tmp_variance, self.shape(self.moving_variance))))
        return y

    def construct(self, x):
        # Rank check selected by input_dims: '2d' -> 4-D input, '1d' -> 2-D input, 'both' -> either.
        if self.input_dims == '2d':
            _shape_check(self.shape(x))
        if self.input_dims == '1d':
            _shape_check_2d(self.shape(x))
        if self.input_dims == 'both':
            _shape_check_2d_or_4d(self.shape(x))
        # Batch statistics follow the training flag unless use_batch_statistics overrides it.
        if self.use_batch_statistics is None:
            flag = self.training
        else:
            flag = self.use_batch_statistics
        if flag:
            if self.is_ge_backend and self.is_global:
                axes, re_shape = _shape_infer(F.shape(x), self.num_features)
                y = self._global_sync(x, axes, re_shape)
            elif self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
                if self.is_global:
                    axes, re_shape = _shape_infer(F.shape(x), self.num_features)
                    y = self._global_sync(x, axes, re_shape)
                else:
                    # The op is called with None instead of the moving statistics. The moving
                    # statistics are therefore updated manually from the returned batch mean/variance.
                    y, batch_mean, batch_var, _, _ = \
                        self.bn_train(x,
                                      self.gamma,
                                      self.beta,
                                      None,
                                      None)

                    mean_sub = self.sub_mean(self.moving_mean, batch_mean)
                    temp_mean = self.mul_mean(mean_sub, self.momentum)
                    mean_sub2 = self.sub_var(self.moving_variance, batch_var)
                    temp_variance = self.mul_var(mean_sub2, self.momentum)
                    y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
                    y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
            else:
                y = self.bn_train(x,
                                  self.gamma,
                                  self.beta,
                                  self.moving_mean,
                                  self.moving_variance)[0]
        else:
            y = self.bn_infer(x,
                              self.gamma,
                              self.beta,
                              self.moving_mean,
                              self.moving_variance)[0]
        return y

    def extend_repr(self):
        """Display instance object as string.

        Note: the `momentum` shown is the internal value ``self.momentum`` (``1 - momentum``
        as passed to the constructor).
        """
        return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format(
            self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance)

G
gong chen 已提交
200

Z
zhaojichen 已提交
201
@constexpr
def _channel_check(channel, num_channel):
    """Raise ValueError unless the input channel count `channel` equals `num_channel`."""
    if channel == num_channel:
        return
    raise ValueError("the input channel is not equal with num_channel")
Z
zhunaipan 已提交
205

G
gong chen 已提交
206

L
liuxiao93 已提交
207 208 209 210 211 212
@constexpr
def _shape_check_2d(input_shape):
    """Raise ValueError unless `input_shape` has exactly 2 entries (a 2-D tensor)."""
    rank = len(input_shape)
    if rank == 2:
        return
    raise ValueError("The input must has 2 dims.")


Z
zhaojichen 已提交
213
@constexpr
def _shape_check(in_shape):
    """Raise ValueError unless `in_shape` has exactly 4 entries (a 4-D tensor)."""
    rank = len(in_shape)
    if rank == 4:
        return
    raise ValueError("The input must has 4 dims.")


@constexpr
def _shape_check_2d_or_4d(in_shape):
    """Raise ValueError unless `in_shape` describes a 2-D or a 4-D tensor."""
    if len(in_shape) not in (2, 4):
        raise ValueError("The input must has 2 dims or 4 dims.")
G
gong chen 已提交
223 224


Z
zhaojichen 已提交
225
@constexpr
def _shape_infer(x_shape, num_feature):
    """global batch normalization shape and axes infer

    Returns a pair ``(axes, re_shape)``:
    - for a 4-D shape: reduce over ``(0, 2, 3)`` and broadcast parameters as ``(1, C, 1, 1)``;
    - for any other rank: reduce over ``(0,)`` and broadcast parameters as ``(1, C)``.
    """
    if len(x_shape) != 4:
        return (0,), (1, num_feature)
    return (0, 2, 3), (1, num_feature, 1, 1)

G
gong chen 已提交
236

Z
zhunaipan 已提交
237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256
class BatchNorm1d(_BatchNorm):
    r"""
    Batch normalization layer over a 2D input.

    Applies Batch Normalization to a 2D input (a mini-batch of 1D inputs) to reduce internal covariate
    shift, following `Batch Normalization: Accelerating Deep Network Training by Reducing Internal
    Covariate Shift <https://arxiv.org/abs/1502.03167>`_. Each feature is re-centered and re-scaled with
    the statistics of a mini-batch and the learned parameters:

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Args:
        num_features (int): `C` from an expected input of size (N, C).
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        momentum (float): A floating hyperparameter in [0, 1] of the momentum for the
            running_mean and running_var computation. Default: 0.9.
        affine (bool): If True, gamma and beta are learnable. Default: True.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        use_batch_statistics (bool): If True, normalize with the mean and variance of the current batch.
            If False, normalize with the stored moving mean and variance. If None, training uses the batch
            statistics and tracks the running mean and variance, while evaluation uses the running mean and
            variance. Default: None.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in})`.

    Outputs:
        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out})`.

    Raises:
        ValueError: If `num_features` is less than 1 or `momentum` is outside [0, 1].

    Examples:
        >>> net = nn.BatchNorm1d(num_features=16)
        >>> input = Tensor(np.random.randint(0, 255, [3, 16]), mindspore.float32)
        >>> net(input)
    """

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.9,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None):
        # input_dims='1d' makes the base construct() require a 2-D (N, C) input.
        super().__init__(num_features,
                         eps,
                         momentum,
                         affine,
                         gamma_init,
                         beta_init,
                         moving_mean_init,
                         moving_var_init,
                         use_batch_statistics,
                         input_dims='1d')

    def _check_data_dim(self, x):
        # Placeholder: evaluates the rank but performs no validation. The actual rank check
        # happens in the base construct() via input_dims.
        if x.dim() != 2:
            pass


class BatchNorm2d(_BatchNorm):
    r"""
    Batch normalization layer over a 4D input.

    Applies Batch Normalization to a 4D input (a mini-batch of 2D inputs with an additional channel
    dimension) to avoid internal covariate shift, following `Batch Normalization: Accelerating Deep Network
    Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. Each feature is
    re-centered and re-scaled with the statistics of a mini-batch and the learned parameters:

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Args:
        num_features (int): `C` from an expected input of size (N, C, H, W).
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        momentum (float): A floating hyperparameter in [0, 1] of the momentum for the
            running_mean and running_var computation. Default: 0.9.
        affine (bool): If True, gamma and beta are learnable. Default: True.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        use_batch_statistics (bool): If True, normalize with the mean and variance of the current batch.
            If False, normalize with the stored moving mean and variance. If None, training uses the batch
            statistics and tracks the running mean and variance, while evaluation uses the running mean and
            variance. Default: None.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Raises:
        ValueError: If `num_features` is less than 1 or `momentum` is outside [0, 1].

    Examples:
        >>> net = nn.BatchNorm2d(num_features=3)
        >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
        >>> net(input)
    """

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.9,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None):
        # input_dims='2d' makes the base construct() require a 4-D (N, C, H, W) input.
        super().__init__(num_features,
                         eps,
                         momentum,
                         affine,
                         gamma_init,
                         beta_init,
                         moving_mean_init,
                         moving_var_init,
                         use_batch_statistics,
                         input_dims='2d')

    def _check_data_dim(self, x):
        # Placeholder: evaluates the rank but performs no validation. The actual rank check
        # happens in the base construct() via input_dims.
        if x.dim() != 4:
            pass


Z
zhaojichen 已提交
389 390 391 392 393 394 395 396 397 398 399 400 401
class GlobalBatchNorm(_BatchNorm):
    r"""
    Global normalization layer over a 2D or 4D input.

    Global Normalization is cross-device synchronized batch normalization. A plain Batch Normalization
    implementation only normalizes the data within each device, whereas Global Normalization normalizes
    the input within a group of devices. It is based on the paper `Batch Normalization: Accelerating Deep
    Network Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales
    and recenters the feature using a mini-batch of data and the learned parameters, as described by the
    following formula.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Note:
        Currently, GlobalBatchNorm only supports 2D and 4D inputs. Cross-device synchronization of the
        batch statistics is performed on the GE backend, or in graph mode on Ascend. On other backends
        each device normalizes with its own batch statistics.

    Args:
        num_features (int): `C` from an expected input of size (N, C) or (N, C, H, W).
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        momentum (float): A floating hyperparameter in [0, 1] of the momentum for the
            running_mean and running_var computation. Default: 0.9.
        affine (bool): If True, gamma and beta are learnable. Default: True.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
            use the mean value and variance value of specified value. If None, training process will use the mean and
            variance of current batch data and track the running mean and variance, eval process will use the running
            mean and variance. Default: None.
        device_num_each_group (int): The number of devices in each group. It must be greater than 1, so the
            default value is rejected and this argument has to be given explicitly. Default: 1.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in})` or :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor, the normalized, scaled, offset tensor, with the same shape as the input.

    Raises:
        ValueError: If `device_num_each_group` is not greater than 1, if `num_features` is less than 1,
            or if `momentum` is outside [0, 1].

    Examples:
        >>> global_bn_op = nn.GlobalBatchNorm(num_features=3, device_num_each_group=4)
        >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
        >>> global_bn_op(input)
    """

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.9,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None,
                 device_num_each_group=1):
        # Validate before the base initializer runs, so an invalid group size is rejected without
        # first building parameters and operators. The base class stores the same value in self.group.
        group = check_int_positive(device_num_each_group)
        if group <= 1:
            raise ValueError("the number of group must be greater than 1.")
        super(GlobalBatchNorm, self).__init__(num_features,
                                              eps,
                                              momentum,
                                              affine,
                                              gamma_init,
                                              beta_init,
                                              moving_mean_init,
                                              moving_var_init,
                                              use_batch_statistics,
                                              device_num_each_group,
                                              input_dims='both')

    def _check_data_dim(self, x):
        # Placeholder with no effect. The rank check happens in the base construct() via input_dims='both'.
        if x.dim == 0:
            pass
Z
zhaojichen 已提交
469

G
gong chen 已提交
470

Z
zhunaipan 已提交
471 472 473 474 475 476 477 478 479 480 481 482 483 484 485
class LayerNorm(Cell):
    r"""
    Applies Layer Normalization over a mini-batch of inputs.

    Layer normalization is widely used in recurrent neural networks. It applies normalization over a
    mini-batch of inputs for each single training case, as described in the paper `Layer Normalization
    <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch normalization, layer normalization performs
    exactly the same computation at training and testing times. It is applied across all channels and
    pixels of each individual sample, and can be described by the following formula.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Args:
        normalized_shape (Union[tuple[int], list[int]]): The normalization is performed over axes
            `begin_norm_axis ... R - 1`. Also used as the shape of gamma and beta.
        begin_norm_axis (int): The first normalization dimension: normalization will be performed along dimensions
            `begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
        begin_params_axis (int): The first parameter (beta, gamma) dimension: scale and centering parameters
            will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
            the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        epsilon (float): A value added to the denominator for numerical stability. In this implementation
            it is stored in `self.epsilon` but is not passed to the underlying operator. Default: 1e-7.

    Inputs:
        - **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
          and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.

    Outputs:
        Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.

    Raises:
        TypeError: If `normalized_shape` is neither a tuple nor a list.

    Examples:
        >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
        >>> shape1 = x.shape[1:]
        >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
        >>> m(x)
    """

    def __init__(self,
                 normalized_shape,
                 begin_norm_axis=-1,
                 begin_params_axis=-1,
                 gamma_init='ones',
                 beta_init='zeros',
                 epsilon=1e-7
                 ):
        super(LayerNorm, self).__init__()
        if not isinstance(normalized_shape, (tuple, list)):
            raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
                            .format(normalized_shape, type(normalized_shape)))
        self.normalized_shape = normalized_shape
        self.begin_norm_axis = begin_norm_axis
        self.begin_params_axis = begin_params_axis
        # NOTE(review): epsilon is stored but not forwarded to the operator below, so it currently
        # has no effect on the output. Confirm whether _selected_ops.LayerNorm accepts an epsilon
        # argument before wiring it through.
        self.epsilon = epsilon
        self.gamma = Parameter(initializer(
            gamma_init, normalized_shape), name="gamma")
        self.beta = Parameter(initializer(
            beta_init, normalized_shape), name="beta")
        self.layer_norm = _selected_ops.LayerNorm(begin_norm_axis=self.begin_norm_axis,
                                                  begin_params_axis=self.begin_params_axis)

    def construct(self, input_x):
        # The operator returns three outputs (presumably output, mean and variance). Only the
        # normalized output is returned.
        y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
        return y

    def extend_repr(self):
        """Display instance object as string."""
        s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma={}, beta={}'.format(
            self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
        return s
Z
zhaojichen 已提交
547

G
gong chen 已提交
548

Z
zhaojichen 已提交
549 550
class GroupNorm(Cell):
    r"""
    Group Normalization over a mini-batch of inputs.

    Group normalization applies normalization over a mini-batch of inputs for each single training case, as
    described in the paper `Group Normalization <https://arxiv.org/pdf/1803.08494.pdf>`_. It divides the
    channels into groups and computes the mean and variance for normalization within each group. This keeps
    its behavior stable over a wide range of batch sizes. It can be described by the following formula.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    where :math:`\mathrm{E}[x]` and :math:`\mathrm{Var}[x]` are the mean and the (biased) variance
    computed over each group of channels of each sample.

    Args:
        num_groups (int): The number of groups to be divided along the channel dimension.
        num_channels (int): The total number of input channels `C`; must be divisible by `num_groups`.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        affine (bool): If True, this layer has learnable affine parameters (gamma and beta). Default: True.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight,
            which has shape (num_channels, 1, 1).
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight,
            which has shape (num_channels, 1, 1).
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.

    Inputs:
        - **input_x** (Tensor) - The input feature with shape [N, C, H, W].

    Outputs:
        Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.

    Raises:
        ValueError: If `num_channels` is not divisible by `num_groups`, if `input_x` is not 4-D,
            or if the channel dimension of `input_x` differs from `num_channels`.

    Examples:
        >>> group_norm_op = nn.GroupNorm(16, 64)
        >>> x = Tensor(np.ones([1, 64, 256, 256], np.float32))
        >>> group_norm_op(x)
    """

    def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'):
        super(GroupNorm, self).__init__()
        self.num_groups = check_int_positive(num_groups)
        self.num_channels = check_int_positive(num_channels)
        if num_channels % num_groups != 0:
            raise ValueError("num_channels should be divided by num_groups")
        self.eps = check_typename('eps', eps, (float,))
        self.affine = check_bool(affine)

        # Shape (C, 1, 1) so gamma/beta broadcast against an (N, C, H, W) tensor.
        gamma = initializer(gamma_init, [num_channels, 1, 1])
        beta = initializer(beta_init, [num_channels, 1, 1])
        if self.affine:
            self.gamma = Parameter(gamma, name='gamma')
            self.beta = Parameter(beta, name='beta')
        else:
            # Fixed, non-trainable scale and shift.
            self.gamma = gamma
            self.beta = beta
        self.shape = F.shape
        self.reshape = F.reshape
        self.reduce_mean = P.ReduceMean(keep_dims=True)
        self.square = F.square
        self.reduce_sum = P.ReduceSum(keep_dims=True)
        self.sqrt = P.Sqrt()

    def _cal_output(self, x):
        """calculate groupnorm output"""
        batch, channel, height, width = self.shape(x)
        _channel_check(channel, self.num_channels)
        # (N, C, H, W) -> (N, G, C*H*W/G): one row of values per group.
        x = self.reshape(x, (batch, self.num_groups, -1))
        mean = self.reduce_mean(x, 2)
        # Biased variance (divide by the group size m, not m - 1), as in the Group Normalization
        # paper. This also avoids a division by zero when a group holds a single element.
        var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups)
        std = self.sqrt(var + self.eps)
        x = (x - mean) / std
        x = self.reshape(x, (batch, channel, height, width))
        output = x * self.gamma + self.beta
        return output

    def construct(self, x):
        _shape_check(self.shape(x))
        output = self._cal_output(x)
        return output

    def extend_repr(self):
        """Display instance object as string."""
        s = 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels)
        return s