# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""normalization"""
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr
import mindspore.context as context
from mindspore._checkparam import check_bool, check_typename
from mindspore._extends import cell_attr_register
from mindspore.communication.management import get_group_size, get_rank
from mindspore.communication import management
from mindspore._checkparam import check_int_positive
from mindspore.ops import _selected_ops
from ..cell import Cell


__all__ = ['BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm', 'GlobalBatchNorm']

class _BatchNorm(Cell):
    """Batch Normalization base class."""
    @cell_attr_register
    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.9,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None,
                 device_num_each_group=1,
                 input_dims='2d'):
        super(_BatchNorm, self).__init__()
        if num_features < 1:
            raise ValueError("num_features must be at least 1")

        if momentum < 0 or momentum > 1:
            raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))

        self.use_batch_statistics = use_batch_statistics
        self.num_features = num_features
        self.eps = eps
        self.input_dims = input_dims
        self.moving_mean = Parameter(initializer(
            moving_mean_init, num_features), name="mean", requires_grad=False)
        self.moving_variance = Parameter(initializer(
            moving_var_init, num_features), name="variance", requires_grad=False)
        self.gamma = Parameter(initializer(
            gamma_init, num_features), name="gamma", requires_grad=affine)
        self.beta = Parameter(initializer(
            beta_init, num_features), name="beta", requires_grad=affine)
        self.group = check_int_positive(device_num_each_group)
        self.is_global = False
        if self.group != 1:
            self.rank_id = get_rank()
            self.rank_size = get_group_size()
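            # Partition all visible devices into groups of `device_num_each_group` ranks
            # and create a communication group for the cross-device all-reduce below.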
            self.device_list = [i for i in range(0, self.rank_size)]
            self.rank_list = self.list_group(self.device_list, self.group)
            self.rank_list_idx = len(self.rank_list)
            for i in range(self.rank_list_idx):
                if self.rank_id in self.rank_list[i] and self.group != 1:
                    self.is_global = True
                    management.create_group('group' + str(i), self.rank_list[i])
                    self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1)
        self.shape = P.Shape()
        self.reduce_mean = P.ReduceMean(keep_dims=True)
        self.square = P.Square()
        self.sqrt = P.Sqrt()
        self.cast = P.Cast()
        self.dtype = P.DType()
        self.reshape = P.Reshape()
        self.is_ascend = context.get_context("device_target") == "Ascend"
        self.is_gpu = context.get_context("device_target") == "GPU"
        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
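        # Store (1 - momentum): the running statistics are updated as
        # moving_stat -= self.momentum * (moving_stat - batch_stat).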
        self.momentum = 1.0 - momentum
        if context.get_context("enable_ge"):
            self.is_ge_backend = True
        else:
            self.is_ge_backend = False

        if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
            self.bn_train = P.BatchNorm(is_training=True,
                                        epsilon=self.eps)
        elif self.is_gpu:
            self.bn_train = P.FusedBatchNormEx(mode=1,
                                               epsilon=self.eps,
                                               momentum=self.momentum)
        else:
            self.bn_train = P.FusedBatchNorm(mode=1,
                                             epsilon=self.eps,
                                             momentum=self.momentum)
        self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps)
        self.enable_global_sync = self.is_global and (self.is_ge_backend or (self.is_graph_mode and self.is_ascend))
        self.enable_default_train = self.is_graph_mode and not self.is_global and \
                                    (self.is_ge_backend or self.is_ascend)

        data_parallel_strategy = ((1,), (1,))
        data_parallel_strategy_one = ((1,), ())
        self.sub_mean = P.Sub().set_strategy(data_parallel_strategy)
        self.sub_var = P.Sub().set_strategy(data_parallel_strategy)
        self.mul_mean = P.Mul().set_strategy(data_parallel_strategy_one)
        self.mul_var = P.Mul().set_strategy(data_parallel_strategy_one)
        self.assign_sub_mean = P.AssignSub().set_strategy(data_parallel_strategy)
        self.assign_sub_var = P.AssignSub().set_strategy(data_parallel_strategy)

    def _check_data_dim(self, x):
        raise NotImplementedError

    def list_group(self, world_rank, group_size):
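        """Divide `world_rank` into consecutive groups of `group_size` ranks.

        For example (hypothetical ranks), list_group([0, 1, 2, 3], 2) returns [[0, 1], [2, 3]].
        """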
        if group_size > get_group_size():
            raise ValueError("group size can not be greater than local rank size, group size is {}, "
                             "local_rank_size is {}".format(group_size, get_group_size()))
        if len(world_rank) % group_size != 0:
            raise ValueError("please make your group size correct.")
        world_rank_list = zip(*(iter(world_rank),) * group_size)
        group_list = [list(i) for i in world_rank_list]
        return group_list

    def _global_sync(self, x, axes, re_shape):
        """calculate global batch normalization output"""
        x_mean = self.reduce_mean(x, axes)
        x_mean_square = self.reduce_mean(self.square(x), axes)
        global_batch_mean = self.all_reduce(x_mean) / self.group
        global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
        global_mean = global_batch_mean
        global_var = global_batch_mean_square - self.square(global_mean)
        var_sqrt = self.sqrt(global_var + self.eps)
        mean_first = (x - global_mean) / var_sqrt
        y = mean_first * self.reshape(self.gamma, re_shape) + self.reshape(self.beta, re_shape)
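
        # Update the running statistics toward the synchronized cross-device statistics.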
        mean_sub = self.sub_mean(self.reshape(self.moving_mean, re_shape), global_mean)
        tmp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub)))
        mean_sub2 = self.sub_var(self.reshape(self.moving_variance, re_shape), global_var)
        tmp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2)))
        y = F.depend(y, self.assign_sub_mean(self.moving_mean, self.reshape(tmp_mean, self.shape(self.moving_mean))))
        y = F.depend(y, self.assign_sub_var(self.moving_variance,
                                            self.reshape(tmp_variance, self.shape(self.moving_variance))))
        return y

    def construct(self, x):
        _shape_check_bn(self.shape(x), self.input_dims)
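        # When use_batch_statistics is None, fall back to the training flag to decide
        # whether the current batch statistics or the moving statistics are used.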
        if self.use_batch_statistics is None:
            flag = self.training
        else:
            flag = self.use_batch_statistics

        if flag:
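            # Training path: synchronize statistics across devices when global batch
            # normalization is enabled; otherwise use the selected BatchNorm kernel.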
            if self.enable_global_sync:
                axes, re_shape = _shape_infer(F.shape(x), self.num_features)
                return self._global_sync(x, axes, re_shape)

            if self.enable_default_train:
                y, batch_mean, batch_var, _, _ = self.bn_train(x,
                                                               self.gamma,
                                                               self.beta,
                                                               None,
                                                               None)

                mean_sub = self.sub_mean(self.moving_mean, batch_mean)
                temp_mean = self.mul_mean(mean_sub, self.momentum)
                mean_sub2 = self.sub_var(self.moving_variance, batch_var)
                temp_variance = self.mul_var(mean_sub2, self.momentum)
                y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
                y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
                return y

            return self.bn_train(x,
                                 self.gamma,
                                 self.beta,
                                 self.moving_mean,
                                 self.moving_variance)[0]

        return self.bn_infer(x,
                             self.gamma,
                             self.beta,
                             self.moving_mean,
                             self.moving_variance)[0]

    def extend_repr(self):
        return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format(
            self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance)


@constexpr
def _channel_check(channel, num_channel):
    if channel != num_channel:
        raise ValueError("the input channel is not equal with num_channel")


@constexpr
def _shape_check(in_shape):
    if len(in_shape) != 4:
        raise ValueError("The input must have 4 dims.")


@constexpr
def _shape_check_bn(in_shape, in_dims):
    dim = len(in_shape)
    if in_dims == '1d' and dim != 2:
        raise ValueError("The input must have 2 dims.")
    if in_dims == '2d' and dim != 4:
        raise ValueError("The input must have 4 dims.")
    if in_dims == 'both' and dim != 2 and dim != 4:
        raise ValueError("The input must have 2 dims or 4 dims.")


@constexpr
def _shape_infer(x_shape, num_feature):
    """global batch normalization shape and axes infer"""
    if len(x_shape) == 4:
        axes = (0, 2, 3)
        re_shape = (1, num_feature, 1, 1)
    else:
        axes = (0,)
        re_shape = (1, num_feature)
    return axes, re_shape


class BatchNorm1d(_BatchNorm):
    r"""
    Batch normalization layer over a 2D input.

    Batch Normalization is widely used in convolutional networks. This layer
    applies Batch Normalization over a 2D input (a mini-batch of 1D inputs) to
    reduce internal covariate shift as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by
    Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It
    rescales and recenters the feature using a mini-batch of data and
    the learned parameters which can be described in the following formula.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Note:
        The implementation of BatchNorm is different in graph mode and pynative mode, therefore it is not
        recommended to change the mode after the network has been initialized.

    Args:
        num_features (int): `C` from an expected input of size (N, C).
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        momentum (float): A floating hyperparameter of the momentum for the
            running_mean and running_var computation. Default: 0.9.
        affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        use_batch_statistics (bool): If true, use the mean value and variance value of the current batch data. If
            false, use the moving mean and moving variance values. If None, the training process will use the mean
            and variance of the current batch data and track the running mean and variance, and the evaluation
            process will use the running mean and variance. Default: None.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in})`.

    Outputs:
        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out})`.

    Examples:
        >>> net = nn.BatchNorm1d(num_features=16)
        >>> input = Tensor(np.random.randint(0, 255, [3, 16]), mindspore.float32)
        >>> net(input)
    """

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.9,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None):
        super(BatchNorm1d, self).__init__(num_features,
                                          eps,
                                          momentum,
                                          affine,
                                          gamma_init,
                                          beta_init,
                                          moving_mean_init,
                                          moving_var_init,
                                          use_batch_statistics,
                                          input_dims='1d')

    def _check_data_dim(self, x):
        if x.dim() != 2:
            pass


class BatchNorm2d(_BatchNorm):
    r"""
    Batch normalization layer over a 4D input.

    Batch Normalization is widely used in convolutional networks. This layer
    applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with
    additional channel dimension) to avoid internal covariate shift as described
    in the paper `Batch Normalization: Accelerating Deep Network Training by
    Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It
    rescales and recenters the feature using a mini-batch of data and
    the learned parameters which can be described in the following formula.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Note:
        The implementation of BatchNorm is different in graph mode and pynative mode, therefore the mode cannot be
        changed after the network has been initialized.

    Args:
        num_features (int): `C` from an expected input of size (N, C, H, W).
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        momentum (float): A floating hyperparameter of the momentum for the
            running_mean and running_var computation. Default: 0.9.
        affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        use_batch_statistics (bool): If true, use the mean value and variance value of the current batch data. If
            false, use the moving mean and moving variance values. If None, the training process will use the mean
            and variance of the current batch data and track the running mean and variance, and the evaluation
            process will use the running mean and variance. Default: None.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Examples:
        >>> net = nn.BatchNorm2d(num_features=3)
        >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
        >>> net(input)
    """

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.9,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None):
        super(BatchNorm2d, self).__init__(num_features,
                                          eps,
                                          momentum,
                                          affine,
                                          gamma_init,
                                          beta_init,
                                          moving_mean_init,
                                          moving_var_init,
                                          use_batch_statistics,
                                          input_dims='2d')

    def _check_data_dim(self, x):
        if x.dim() != 4:
            pass


class GlobalBatchNorm(_BatchNorm):
    r"""
    Global normalization layer over a N-dimension input.

    Global Normalization is cross device synchronized batch normalization. The implementation of Batch Normalization
    only normalizes the data within each device. Global normalization will normalize the input within the group.
    It has been described in the paper `Batch Normalization: Accelerating Deep Network Training by
    Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
    feature using a mini-batch of data and the learned parameters which can be described in the following formula.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Note:
        Currently, GlobalBatchNorm only supports 2D and 4D inputs.

    Args:
        num_features (int): `C` from an expected input of size (N, C, H, W).
        device_num_each_group (int): The number of devices in each group. Default: 1.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        momentum (float): A floating hyperparameter of the momentum for the
            running_mean and running_var computation. Default: 0.9.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        use_batch_statistics (bool): If true, use the mean value and variance value of the current batch data. If
            false, use the moving mean and moving variance values. If None, the training process will use the mean
            and variance of the current batch data and track the running mean and variance, and the evaluation
            process will use the running mean and variance. Default: None.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Examples:
        >>> global_bn_op = nn.GlobalBatchNorm(num_features=3, device_num_each_group=4)
        >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
        >>> global_bn_op(input)
    """

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.9,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None,
                 device_num_each_group=1):
        super(GlobalBatchNorm, self).__init__(num_features,
                                              eps,
                                              momentum,
                                              affine,
                                              gamma_init,
                                              beta_init,
                                              moving_mean_init,
                                              moving_var_init,
                                              use_batch_statistics,
                                              device_num_each_group,
                                              input_dims='both')
        self.group = check_int_positive(device_num_each_group)
        if self.group <= 1:
            raise ValueError("the number of group must be greater than 1.")

    def _check_data_dim(self, x):
        if x.dim == 0:
            pass


class LayerNorm(Cell):
    r"""
    Applies Layer Normalization over a mini-batch of inputs.

    Layer normalization is widely used in recurrent neural networks. It applies
    normalization on a mini-batch of inputs for each single training case as described
    in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch
    normalization, layer normalization performs exactly the same computation at training and
    testing time. It can be described using the following formula. It is applied across all channels
    and pixels of a single sample rather than across the batch.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Args:
        normalized_shape (Union(tuple[int], list[int])): The normalization is performed over axes
            `begin_norm_axis ... R - 1`.
        begin_norm_axis (int): The first normalization dimension: normalization will be performed along dimensions
            `begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
        begin_params_axis (int): The first parameter (beta, gamma) dimension: scale and centering parameters
            will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
            the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        epsilon (float): A value added to the denominator for numerical stability. Default: 1e-7.

    Inputs:
        - **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
          and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.

    Outputs:
        Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.

    Examples:
        >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
        >>> shape1 = x.shape[1:]
        >>> m = nn.LayerNorm(shape1,  begin_norm_axis=1, begin_params_axis=1)
        >>> m(x)
    """

    def __init__(self,
                 normalized_shape,
                 begin_norm_axis=-1,
                 begin_params_axis=-1,
                 gamma_init='ones',
                 beta_init='zeros',
                 epsilon=1e-7
                 ):
        super(LayerNorm, self).__init__()
        if not isinstance(normalized_shape, (tuple, list)):
            raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
                            .format(normalized_shape, type(normalized_shape)))
        self.normalized_shape = normalized_shape
        self.begin_norm_axis = begin_norm_axis
        self.begin_params_axis = begin_params_axis
        self.epsilon = epsilon
        self.gamma = Parameter(initializer(
            gamma_init, normalized_shape), name="gamma")
        self.beta = Parameter(initializer(
            beta_init, normalized_shape), name="beta")
        self.layer_norm = _selected_ops.LayerNorm(begin_norm_axis=self.begin_norm_axis,
                                                  begin_params_axis=self.begin_params_axis)

    def construct(self, input_x):
        y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
        return y

    def extend_repr(self):
        """Display instance object as string."""
        s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma={}, beta={}'.format(
            self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
        return s


class GroupNorm(Cell):
    r"""
    Group Normalization over a mini-batch of inputs.

    Group normalization is widely used in vision tasks and does not depend on the batch size. It applies
    normalization on a mini-batch of inputs as described
    in the paper `Group Normalization <https://arxiv.org/pdf/1803.08494.pdf>`_. Group normalization
    divides the channels into groups and computes within each group the mean and variance for normalization,
    and it remains stable over a wide range of batch sizes. It can be described using the following formula.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Args:
        num_groups (int): The number of groups into which the channels are divided along the channel dimension.
        num_channels (int): The number of input channels, which must be divisible by `num_groups`.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        affine (bool): A bool value, this layer will have learnable affine parameters when set to true. Default: True.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.

    Inputs:
        - **input_x** (Tensor) - The input feature with shape [N, C, H, W].

    Outputs:
        Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.

    Examples:
        >>> group_norm_op = nn.GroupNorm(2, 2)
        >>> x = Tensor(np.ones([1, 2, 4, 4], np.float32))
        >>> group_norm_op(x)
        [[[[0. 0. 0. 0.]
           [0. 0. 0. 0.]
           [0. 0. 0. 0.]
           [0. 0. 0. 0.]]

          [[0. 0. 0. 0.]
           [0. 0. 0. 0.]
           [0. 0. 0. 0.]
           [0. 0. 0. 0.]]]]
    """

    def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'):
        super(GroupNorm, self).__init__()
        self.num_groups = check_int_positive(num_groups)
        self.num_channels = check_int_positive(num_channels)
        if num_channels % num_groups != 0:
            raise ValueError("num_channels should be divided by num_groups")
        self.eps = check_typename('eps', eps, (float,))
        self.affine = check_bool(affine)

        gamma = initializer(gamma_init, [num_channels, 1, 1])
        beta = initializer(beta_init, [num_channels, 1, 1])
        if self.affine:
            self.gamma = Parameter(gamma, name='gamma')
            self.beta = Parameter(beta, name='beta')
        else:
            self.gamma = gamma
            self.beta = beta
        self.shape = F.shape
        self.reshape = F.reshape
        self.reduce_mean = P.ReduceMean(keep_dims=True)
        self.square = F.square
        self.reduce_sum = P.ReduceSum(keep_dims=True)
        self.sqrt = P.Sqrt()

    def _cal_output(self, x):
        """calculate groupnorm output"""
        batch, channel, height, width = self.shape(x)
        _channel_check(channel, self.num_channels)
        x = self.reshape(x, (batch, self.num_groups, -1))
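        # Compute mean and variance within each group of channels, normalize, then
        # restore the original layout and apply gamma and beta.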
        mean = self.reduce_mean(x, 2)
        var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups - 1)
        std = self.sqrt(var + self.eps)
        x = (x - mean) / std
        x = self.reshape(x, (batch, channel, height, width))
        output = x * self.gamma + self.beta
        return output

    def construct(self, x):
        _shape_check(self.shape(x))
        output = self._cal_output(x)
        return output

    def extend_repr(self):
        """Display instance object as string."""
        s = 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels)
        return s