提交 1f0ef42e 编写于 作者: M minqiyang

Change atol of numpy allclose

上级 f53e1d5c
......@@ -1195,7 +1195,7 @@ class Block(object):
if not var.persistable:
del self.vars[name]
self.ops.clear()
del self.ops[:]
def all_parameters(self):
return list(self.iter_parameters())
......
......@@ -114,11 +114,7 @@ class TestImperativeMnist(unittest.TestCase):
dy_param_init_value = {}
for epoch in range(epoch_num):
print("epoch", epoch)
for batch_id, data in enumerate(train_reader()):
# if batch_id >= batch_num:
# break
dy_x_data = np.array(
[x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
......@@ -186,9 +182,6 @@ class TestImperativeMnist(unittest.TestCase):
for epoch in range(epoch_num):
for batch_id, data in enumerate(train_reader()):
# if batch_id >= batch_num:
# break
static_x_data = np.array(
[x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
......@@ -209,13 +202,15 @@ class TestImperativeMnist(unittest.TestCase):
static_param_value[static_param_name_list[i - 1]] = out[
i]
self.assertTrue(np.allclose(dy_x_data.all(), static_x_data.all()))
for key, value in six.iteritems(static_param_init_value):
self.assertTrue(np.allclose(value, dy_param_init_value[key]))
self.assertTrue(np.allclose(static_out, dy_out))
for key, value in six.iteritems(static_param_value):
self.assertTrue(np.allclose(value, dy_param_value[key]))
self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-6))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册