diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 6456a3603dba0d58f3132958e8dea248caf91b1f..07d08f9b2f6406b22e0bda07b018220a7e6e1402 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -33,7 +33,6 @@ class _BatchNorm(Cell): @cell_attr_register def __init__(self, num_features, - group=1, eps=1e-5, momentum=0.9, affine=True, @@ -41,7 +40,8 @@ class _BatchNorm(Cell): beta_init='zeros', moving_mean_init='zeros', moving_var_init='ones', - use_batch_statistics=True): + use_batch_statistics=True, + group=1): super(_BatchNorm, self).__init__() if num_features < 1: raise ValueError("num_features must be at least 1") @@ -214,6 +214,25 @@ class BatchNorm1d(_BatchNorm): >>> input = Tensor(np.random.randint(0, 255, [3, 16]), mindspore.float32) >>> net(input) """ + def __init__(self, + num_features, + eps=1e-5, + momentum=0.9, + affine=True, + gamma_init='ones', + beta_init='zeros', + moving_mean_init='zeros', + moving_var_init='ones', + use_batch_statistics=True): + super(BatchNorm1d, self).__init__(num_features, + eps, + momentum, + affine, + gamma_init, + beta_init, + moving_mean_init, + moving_var_init, + use_batch_statistics) def _check_data_dim(self, x): if x.dim() != 2: pass @@ -266,6 +285,25 @@ class BatchNorm2d(_BatchNorm): >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32) >>> net(input) """ + def __init__(self, + num_features, + eps=1e-5, + momentum=0.9, + affine=True, + gamma_init='ones', + beta_init='zeros', + moving_mean_init='zeros', + moving_var_init='ones', + use_batch_statistics=True): + super(BatchNorm2d, self).__init__(num_features, + eps, + momentum, + affine, + gamma_init, + beta_init, + moving_mean_init, + moving_var_init, + use_batch_statistics) def _check_data_dim(self, x): if x.dim() != 4: pass @@ -316,6 +354,30 @@ class GlobalBatchNorm(_BatchNorm): >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32) >>> global_bn_op(input) """ + def __init__(self, + num_features, + eps=1e-5, + momentum=0.9, + affine=True, + gamma_init='ones', + beta_init='zeros', + moving_mean_init='zeros', + moving_var_init='ones', + use_batch_statistics=True, + group=1): + super(GlobalBatchNorm, self).__init__(num_features, + eps, + momentum, + affine, + gamma_init, + beta_init, + moving_mean_init, + moving_var_init, + use_batch_statistics, + group) + self.group = check_int_positive(group) + if self.group <=1: + raise ValueError("the number of group must be greater than 1.") def _check_data_dim(self, x): if x.dim == 0: pass