未验证 提交 0581d74d 编写于 作者: C Chen Weihang 提交者: GitHub

try to fix test imperative se resnet, test=develop (#23700)

上级 a7b8d46f
......@@ -286,7 +286,6 @@ class SeResNeXt(fluid.dygraph.Layer):
for bottleneck_block in self.bottleneck_block_list:
y = bottleneck_block(y)
y = self.pool2d_avg(y)
y = fluid.layers.dropout(y, dropout_prob=0.2)
y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output])
y = self.out(y)
return y
......@@ -352,12 +351,13 @@ class TestImperativeResneXt(unittest.TestCase):
dy_param_init_value[param.name] = param.numpy()
avg_loss.backward()
#dy_grad_value = {}
#for param in se_resnext.parameters():
# if param.trainable:
# np_array = np.array(param._grad_ivar().value()
# .get_tensor())
# dy_grad_value[param.name + core.grad_var_suffix()] = np_array
dy_grad_value = {}
for param in se_resnext.parameters():
if param.trainable:
np_array = np.array(param._grad_ivar().value()
.get_tensor())
dy_grad_value[param.name + core.grad_var_suffix(
)] = np_array
optimizer.minimize(avg_loss)
se_resnext.clear_gradients()
......@@ -442,6 +442,7 @@ class TestImperativeResneXt(unittest.TestCase):
len(static_grad_name_list) + grad_start_pos):
static_grad_value[static_grad_name_list[
i - grad_start_pos]] = out[i]
self.assertTrue(np.allclose(static_out, dy_out))
self.assertEqual(len(dy_param_init_value), len(static_param_init_value))
......@@ -450,12 +451,13 @@ class TestImperativeResneXt(unittest.TestCase):
self.assertTrue(np.allclose(value, dy_param_init_value[key]))
self.assertTrue(np.isfinite(value.all()))
self.assertFalse(np.isnan(value.any()))
# FIXME(Yancey1989): np.array(_ivar.value().get_tensor()) leads to memory lake
#self.assertEqual(len(dy_grad_value), len(static_grad_value))
#for key, value in six.iteritems(static_grad_value):
# self.assertTrue(np.allclose(value, dy_grad_value[key]))
# self.assertTrue(np.isfinite(value.all()))
# self.assertFalse(np.isnan(value.any()))
self.assertEqual(len(dy_grad_value), len(static_grad_value))
for key, value in six.iteritems(static_grad_value):
self.assertTrue(np.allclose(value, dy_grad_value[key]))
self.assertTrue(np.isfinite(value.all()))
self.assertFalse(np.isnan(value.any()))
self.assertEqual(len(dy_param_value), len(static_param_value))
for key, value in six.iteritems(static_param_value):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册