提交 6d3709eb 编写于 作者: Z zhaojichen

fix batchnorm bug

上级 b6408fec
...@@ -33,7 +33,6 @@ class _BatchNorm(Cell): ...@@ -33,7 +33,6 @@ class _BatchNorm(Cell):
@cell_attr_register @cell_attr_register
def __init__(self, def __init__(self,
num_features, num_features,
group=1,
eps=1e-5, eps=1e-5,
momentum=0.9, momentum=0.9,
affine=True, affine=True,
...@@ -41,7 +40,8 @@ class _BatchNorm(Cell): ...@@ -41,7 +40,8 @@ class _BatchNorm(Cell):
beta_init='zeros', beta_init='zeros',
moving_mean_init='zeros', moving_mean_init='zeros',
moving_var_init='ones', moving_var_init='ones',
use_batch_statistics=True): use_batch_statistics=True,
group=1):
super(_BatchNorm, self).__init__() super(_BatchNorm, self).__init__()
if num_features < 1: if num_features < 1:
raise ValueError("num_features must be at least 1") raise ValueError("num_features must be at least 1")
...@@ -214,6 +214,25 @@ class BatchNorm1d(_BatchNorm): ...@@ -214,6 +214,25 @@ class BatchNorm1d(_BatchNorm):
>>> input = Tensor(np.random.randint(0, 255, [3, 16]), mindspore.float32) >>> input = Tensor(np.random.randint(0, 255, [3, 16]), mindspore.float32)
>>> net(input) >>> net(input)
""" """
def __init__(self,
num_features,
eps=1e-5,
momentum=0.9,
affine=True,
gamma_init='ones',
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=True):
super(BatchNorm1d, self).__init__(num_features,
eps,
momentum,
affine,
gamma_init,
beta_init,
moving_mean_init,
moving_var_init,
use_batch_statistics)
def _check_data_dim(self, x): def _check_data_dim(self, x):
if x.dim() != 2: if x.dim() != 2:
pass pass
...@@ -266,6 +285,25 @@ class BatchNorm2d(_BatchNorm): ...@@ -266,6 +285,25 @@ class BatchNorm2d(_BatchNorm):
>>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32) >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
>>> net(input) >>> net(input)
""" """
def __init__(self,
num_features,
eps=1e-5,
momentum=0.9,
affine=True,
gamma_init='ones',
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=True):
super(BatchNorm2d, self).__init__(num_features,
eps,
momentum,
affine,
gamma_init,
beta_init,
moving_mean_init,
moving_var_init,
use_batch_statistics)
def _check_data_dim(self, x): def _check_data_dim(self, x):
if x.dim() != 4: if x.dim() != 4:
pass pass
...@@ -316,6 +354,30 @@ class GlobalBatchNorm(_BatchNorm): ...@@ -316,6 +354,30 @@ class GlobalBatchNorm(_BatchNorm):
>>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32) >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
>>> global_bn_op(input) >>> global_bn_op(input)
""" """
def __init__(self,
num_features,
eps=1e-5,
momentum=0.9,
affine=True,
gamma_init='ones',
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=True,
group=1):
super(GlobalBatchNorm, self).__init__(num_features,
eps,
momentum,
affine,
gamma_init,
beta_init,
moving_mean_init,
moving_var_init,
use_batch_statistics,
group)
self.group = check_int_positive(group)
if self.group <=1:
raise ValueError("the number of group must be greater than 1.")
def _check_data_dim(self, x): def _check_data_dim(self, x):
if x.dim == 0: if x.dim == 0:
pass pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册