From 1f0ef42e6029e29f9ca46e81de74787a181a5280 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 21 Feb 2019 10:41:55 +0800 Subject: [PATCH] Change atol of numpy allclose --- python/paddle/fluid/framework.py | 2 +- .../tests/unittests/test_imperative_optimizer.py | 11 +++-------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 14b8339df..4ff769dd4 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1195,7 +1195,7 @@ class Block(object): if not var.persistable: del self.vars[name] - self.ops.clear() + del self.ops[:] def all_parameters(self): return list(self.iter_parameters()) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py index 3bcfdac6c..bde691652 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py @@ -114,11 +114,7 @@ class TestImperativeMnist(unittest.TestCase): dy_param_init_value = {} for epoch in range(epoch_num): - print("epoch", epoch) for batch_id, data in enumerate(train_reader()): - # if batch_id >= batch_num: - # break - dy_x_data = np.array( [x[0].reshape(1, 28, 28) for x in data]).astype('float32') @@ -186,9 +182,6 @@ class TestImperativeMnist(unittest.TestCase): for epoch in range(epoch_num): for batch_id, data in enumerate(train_reader()): - # if batch_id >= batch_num: - # break - static_x_data = np.array( [x[0].reshape(1, 28, 28) for x in data]).astype('float32') @@ -209,13 +202,15 @@ class TestImperativeMnist(unittest.TestCase): static_param_value[static_param_name_list[i - 1]] = out[ i] + self.assertTrue(np.allclose(dy_x_data.all(), static_x_data.all())) + for key, value in six.iteritems(static_param_init_value): self.assertTrue(np.allclose(value, dy_param_init_value[key])) self.assertTrue(np.allclose(static_out, dy_out)) for key, value in six.iteritems(static_param_value): - self.assertTrue(np.allclose(value, dy_param_value[key])) + self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-6)) if __name__ == '__main__': -- GitLab