提交 6d3709eb 编写于 作者: Z zhaojichen

fix batchnorm bug

上级 b6408fec
......@@ -33,7 +33,6 @@ class _BatchNorm(Cell):
@cell_attr_register
def __init__(self,
num_features,
group=1,
eps=1e-5,
momentum=0.9,
affine=True,
......@@ -41,7 +40,8 @@ class _BatchNorm(Cell):
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=True):
use_batch_statistics=True,
group=1):
super(_BatchNorm, self).__init__()
if num_features < 1:
raise ValueError("num_features must be at least 1")
......@@ -214,6 +214,25 @@ class BatchNorm1d(_BatchNorm):
>>> input = Tensor(np.random.randint(0, 255, [3, 16]), mindspore.float32)
>>> net(input)
"""
def __init__(self,
num_features,
eps=1e-5,
momentum=0.9,
affine=True,
gamma_init='ones',
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=True):
super(BatchNorm1d, self).__init__(num_features,
eps,
momentum,
affine,
gamma_init,
beta_init,
moving_mean_init,
moving_var_init,
use_batch_statistics)
def _check_data_dim(self, x):
if x.dim() != 2:
pass
......@@ -266,6 +285,25 @@ class BatchNorm2d(_BatchNorm):
>>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
>>> net(input)
"""
def __init__(self,
num_features,
eps=1e-5,
momentum=0.9,
affine=True,
gamma_init='ones',
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=True):
super(BatchNorm2d, self).__init__(num_features,
eps,
momentum,
affine,
gamma_init,
beta_init,
moving_mean_init,
moving_var_init,
use_batch_statistics)
def _check_data_dim(self, x):
if x.dim() != 4:
pass
......@@ -316,6 +354,30 @@ class GlobalBatchNorm(_BatchNorm):
>>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
>>> global_bn_op(input)
"""
def __init__(self,
num_features,
eps=1e-5,
momentum=0.9,
affine=True,
gamma_init='ones',
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=True,
group=1):
super(GlobalBatchNorm, self).__init__(num_features,
eps,
momentum,
affine,
gamma_init,
beta_init,
moving_mean_init,
moving_var_init,
use_batch_statistics,
group)
self.group = check_int_positive(group)
if self.group <=1:
raise ValueError("the number of group must be greater than 1.")
def _check_data_dim(self, x):
if x.dim == 0:
pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册