未验证 提交 008ea1ae 编写于 作者: H hong 提交者: GitHub

Fix test layers 单测随机挂问题 (#22278)

* fix test_layers compare static graph and dygraph result; test=develop

* fix test_layers random error; test=develop
上级 35efbe6d
...@@ -1558,6 +1558,7 @@ class TestBook(LayerTest): ...@@ -1558,6 +1558,7 @@ class TestBook(LayerTest):
"make_sampled_softmax_with_cross_entropy", "make_sampling_id", "make_sampled_softmax_with_cross_entropy", "make_sampling_id",
"make_uniform_random_batch_size_like" "make_uniform_random_batch_size_like"
}) })
self.all_close_compare = set({"make_spectral_norm"})
def test_all_layers(self): def test_all_layers(self):
attrs = (getattr(self, name) for name in dir(self)) attrs = (getattr(self, name) for name in dir(self))
...@@ -1594,9 +1595,18 @@ class TestBook(LayerTest): ...@@ -1594,9 +1595,18 @@ class TestBook(LayerTest):
dy_result = dy_result[0] dy_result = dy_result[0]
dy_result_value = dy_result.numpy() dy_result_value = dy_result.numpy()
if method.__name__ in self.all_close_compare:
self.assertTrue(
np.allclose(
static_result[0], dy_result_value, atol=0, rtol=1e-05),
"Result of function [{}] compare failed".format(
method.__name__))
continue
if method.__name__ not in self.not_compare_static_dygraph_set: if method.__name__ not in self.not_compare_static_dygraph_set:
self.assertTrue( self.assertTrue(
np.array_equal(static_result[0], dy_result_value)) np.array_equal(static_result[0], dy_result_value),
"Result of function [{}] not equal".format(method.__name__))
def _get_np_data(self, shape, dtype, append_batch_size=True): def _get_np_data(self, shape, dtype, append_batch_size=True):
np.random.seed(self.seed) np.random.seed(self.seed)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册