提交 e8b26dbd 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3151 Fix verification that the input of BatchNorm2d is 4D.

Merge pull request !3151 from liuxiao93/fix-BatchNorm2d
...@@ -44,7 +44,8 @@ class _BatchNorm(Cell): ...@@ -44,7 +44,8 @@ class _BatchNorm(Cell):
moving_mean_init='zeros', moving_mean_init='zeros',
moving_var_init='ones', moving_var_init='ones',
use_batch_statistics=None, use_batch_statistics=None,
device_num_each_group=1): device_num_each_group=1,
input_dims='1d'):
super(_BatchNorm, self).__init__() super(_BatchNorm, self).__init__()
if num_features < 1: if num_features < 1:
raise ValueError("num_features must be at least 1") raise ValueError("num_features must be at least 1")
...@@ -55,6 +56,7 @@ class _BatchNorm(Cell): ...@@ -55,6 +56,7 @@ class _BatchNorm(Cell):
self.use_batch_statistics = use_batch_statistics self.use_batch_statistics = use_batch_statistics
self.num_features = num_features self.num_features = num_features
self.eps = eps self.eps = eps
self.input_dims = input_dims
self.moving_mean = Parameter(initializer( self.moving_mean = Parameter(initializer(
moving_mean_init, num_features), name="mean", requires_grad=False) moving_mean_init, num_features), name="mean", requires_grad=False)
self.moving_variance = Parameter(initializer( self.moving_variance = Parameter(initializer(
...@@ -145,6 +147,8 @@ class _BatchNorm(Cell): ...@@ -145,6 +147,8 @@ class _BatchNorm(Cell):
return y return y
def construct(self, x): def construct(self, x):
if self.input_dims == '2d':
_shape_check(self.shape(x))
if self.use_batch_statistics is None: if self.use_batch_statistics is None:
flag = self.training flag = self.training
else: else:
...@@ -253,10 +257,10 @@ class BatchNorm1d(_BatchNorm): ...@@ -253,10 +257,10 @@ class BatchNorm1d(_BatchNorm):
mean and variance. Default: None. mean and variance. Default: None.
Inputs: Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in})`.
Outputs: Outputs:
Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`. Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out})`.
Examples: Examples:
>>> net = nn.BatchNorm1d(num_features=16) >>> net = nn.BatchNorm1d(num_features=16)
...@@ -282,7 +286,8 @@ class BatchNorm1d(_BatchNorm): ...@@ -282,7 +286,8 @@ class BatchNorm1d(_BatchNorm):
beta_init, beta_init,
moving_mean_init, moving_mean_init,
moving_var_init, moving_var_init,
use_batch_statistics) use_batch_statistics,
input_dims='1d')
def _check_data_dim(self, x): def _check_data_dim(self, x):
if x.dim() != 2: if x.dim() != 2:
...@@ -357,7 +362,8 @@ class BatchNorm2d(_BatchNorm): ...@@ -357,7 +362,8 @@ class BatchNorm2d(_BatchNorm):
beta_init, beta_init,
moving_mean_init, moving_mean_init,
moving_var_init, moving_var_init,
use_batch_statistics) use_batch_statistics,
input_dims='2d')
def _check_data_dim(self, x): def _check_data_dim(self, x):
if x.dim() != 4: if x.dim() != 4:
......
...@@ -2931,7 +2931,7 @@ class Round(PrimitiveWithInfer): ...@@ -2931,7 +2931,7 @@ class Round(PrimitiveWithInfer):
class Tan(PrimitiveWithInfer): class Tan(PrimitiveWithInfer):
""" """
Computes tan of `input_x` element-wise. Computes tangent of `input_x` element-wise.
Inputs: Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
......
...@@ -46,7 +46,7 @@ class GradWrap(nn.Cell): ...@@ -46,7 +46,7 @@ class GradWrap(nn.Cell):
def bn_with_initialize(out_channels): def bn_with_initialize(out_channels):
bn = nn.BatchNorm2d(out_channels, momentum=0.1, eps=1e-5) bn = nn.BatchNorm1d(out_channels, momentum=0.1, eps=1e-5)
return bn return bn
......
...@@ -40,7 +40,7 @@ class NetWithLoss(nn.Cell): ...@@ -40,7 +40,7 @@ class NetWithLoss(nn.Cell):
class Blockcell(nn.Cell): class Blockcell(nn.Cell):
def __init__(self): def __init__(self):
super(Blockcell, self).__init__() super(Blockcell, self).__init__()
self.bn = nn.BatchNorm2d(64, momentum=0.9) self.bn = nn.BatchNorm1d(64, momentum=0.9)
def construct(self, x): def construct(self, x):
out = self.bn(x) out = self.bn(x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册