From 008ea1ae2e1ccf88dfec76ff1b4fbd0e5fbe0e66 Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Thu, 16 Jan 2020 11:26:22 +0800 Subject: [PATCH] =?UTF-8?q?Fix=20test=20layers=20=E5=8D=95=E6=B5=8B?= =?UTF-8?q?=E9=9A=8F=E6=9C=BA=E6=8C=82=E9=97=AE=E9=A2=98=20(#22278)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix test_layers compare static graph and dygraph result; test=develop * fix test_layers random error; test=develop --- python/paddle/fluid/tests/unittests/test_layers.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 384acc2360a..8e2928f93da 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1558,6 +1558,7 @@ class TestBook(LayerTest): "make_sampled_softmax_with_cross_entropy", "make_sampling_id", "make_uniform_random_batch_size_like" }) + self.all_close_compare = set({"make_spectral_norm"}) def test_all_layers(self): attrs = (getattr(self, name) for name in dir(self)) @@ -1594,9 +1595,18 @@ class TestBook(LayerTest): dy_result = dy_result[0] dy_result_value = dy_result.numpy() + if method.__name__ in self.all_close_compare: + self.assertTrue( + np.allclose( + static_result[0], dy_result_value, atol=0, rtol=1e-05), + "Result of function [{}] compare failed".format( + method.__name__)) + continue + if method.__name__ not in self.not_compare_static_dygraph_set: self.assertTrue( - np.array_equal(static_result[0], dy_result_value)) + np.array_equal(static_result[0], dy_result_value), + "Result of function [{}] not equal".format(method.__name__)) def _get_np_data(self, shape, dtype, append_batch_size=True): np.random.seed(self.seed) -- GitLab