提交 8705545b 编写于 作者: L liuxiao93

Fix bug about output of AddNGrad.

上级 32921ea3
...@@ -149,6 +149,8 @@ class _BatchNorm(Cell): ...@@ -149,6 +149,8 @@ class _BatchNorm(Cell):
def construct(self, x): def construct(self, x):
if self.input_dims == '2d': if self.input_dims == '2d':
_shape_check(self.shape(x)) _shape_check(self.shape(x))
if self.input_dims == '1d':
_shape_check_2d(self.shape(x))
if self.use_batch_statistics is None: if self.use_batch_statistics is None:
flag = self.training flag = self.training
else: else:
...@@ -200,6 +202,12 @@ def _channel_check(channel, num_channel): ...@@ -200,6 +202,12 @@ def _channel_check(channel, num_channel):
raise ValueError("the input channel is not equal with num_channel") raise ValueError("the input channel is not equal with num_channel")
@constexpr
def _shape_check_2d(input_shape):
if len(input_shape) != 2:
raise ValueError("The input must has 2 dims.")
@constexpr @constexpr
def _shape_check(in_shape): def _shape_check(in_shape):
if len(in_shape) != 4: if len(in_shape) != 4:
......
...@@ -980,7 +980,8 @@ def get_bprop_scalar_accumulatenv2(self): ...@@ -980,7 +980,8 @@ def get_bprop_scalar_accumulatenv2(self):
dx = () dx = ()
for _ in range(len(x)): for _ in range(len(x)):
dx = dx + (dout,) dx = dx + (dout,)
return dx return (dx,)
return bprop return bprop
...@@ -992,7 +993,7 @@ def get_bprop_scalar_addn(self): ...@@ -992,7 +993,7 @@ def get_bprop_scalar_addn(self):
dx = () dx = ()
for _ in range(len(x)): for _ in range(len(x)):
dx = dx + (dout,) dx = dx + (dout,)
return dx return (dx,)
return bprop return bprop
......
...@@ -1671,13 +1671,11 @@ test_case_array_ops = [ ...@@ -1671,13 +1671,11 @@ test_case_array_ops = [
('AddN', { ('AddN', {
'block': NetForTupleInput(P.AddN()), 'block': NetForTupleInput(P.AddN()),
'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]], 'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]],
'desc_bprop': [[2, 3, 3, 5]], 'desc_bprop': [[2, 3, 3, 5]]}),
'skip': ['backward']}),
('AccumulateNV2', { ('AccumulateNV2', {
'block': NetForTupleInput(P.AccumulateNV2()), 'block': NetForTupleInput(P.AccumulateNV2()),
'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]], 'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]],
'desc_bprop': [[2, 3, 3, 5]], 'desc_bprop': [[2, 3, 3, 5]]}),
'skip': ['backward']}),
('Shape', { ('Shape', {
'block': P.Shape(), 'block': P.Shape(),
'desc_inputs': [[3, 3, 2, 2]], 'desc_inputs': [[3, 3, 2, 2]],
......
...@@ -67,10 +67,10 @@ def test_bn2d(): ...@@ -67,10 +67,10 @@ def test_bn2d():
def test_bn1d(): def test_bn1d():
"""ut of nn.BatchNorm1d""" """ut of nn.BatchNorm1d"""
bn = nn.BatchNorm1d(3) bn = nn.BatchNorm1d(3)
input_data = Tensor(np.random.randint(0, 1, [1, 3, 100, 100]).astype(np.float32)) input_data = Tensor(np.random.randint(0, 1, [1, 3]).astype(np.float32))
output = bn(input_data) output = bn(input_data)
output_np = output.asnumpy() output_np = output.asnumpy()
assert isinstance(output_np[0][0][0][0], (np.float32, np.float64)) assert isinstance(output_np[0][0], (np.float32, np.float64))
def test_bn2d_train(): def test_bn2d_train():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册