未验证 提交 b71833ea 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[UT]fix test_poisson op random fail (#44763)

修复poisson op单测随机挂

原因:由于随机OP的无法直接验证数值正确性,该单测随机采样100万个样本,统计落到直方图各区间的数量,计算出粗略的概率密度函数,与标准概率密度函数对比,这种测试方式会有一定误差。
当采样数量越小,误差越大,因此该PR增大采样样本数量(100万->200万),误差进一步减小在rtol范围内。
上级 684b12ee
...@@ -39,13 +39,14 @@ def output_hist(out, lam, a, b): ...@@ -39,13 +39,14 @@ def output_hist(out, lam, a, b):
class TestPoissonOp1(OpTest): class TestPoissonOp1(OpTest):
def setUp(self): def setUp(self):
self.op_type = "poisson" self.op_type = "poisson"
self.config() self.config()
self.attrs = {} self.attrs = {}
self.inputs = {'X': np.full([1024, 1024], self.lam, dtype=self.dtype)} self.inputs = {'X': np.full([2048, 1024], self.lam, dtype=self.dtype)}
self.outputs = {'Out': np.ones([1024, 1024], dtype=self.dtype)} self.outputs = {'Out': np.ones([2048, 1024], dtype=self.dtype)}
def config(self): def config(self):
self.lam = 10 self.lam = 10
...@@ -55,10 +56,8 @@ class TestPoissonOp1(OpTest): ...@@ -55,10 +56,8 @@ class TestPoissonOp1(OpTest):
def verify_output(self, outs): def verify_output(self, outs):
hist, prob = output_hist(np.array(outs[0]), self.lam, self.a, self.b) hist, prob = output_hist(np.array(outs[0]), self.lam, self.a, self.b)
self.assertTrue( self.assertTrue(np.allclose(hist, prob, rtol=0.01),
np.allclose( "actual: {}, expected: {}".format(hist, prob))
hist, prob, rtol=0.01),
"actual: {}, expected: {}".format(hist, prob))
def test_check_output(self): def test_check_output(self):
self.check_output_customized(self.verify_output) self.check_output_customized(self.verify_output)
...@@ -67,22 +66,23 @@ class TestPoissonOp1(OpTest): ...@@ -67,22 +66,23 @@ class TestPoissonOp1(OpTest):
self.check_grad( self.check_grad(
['X'], ['X'],
'Out', 'Out',
user_defined_grads=[np.zeros( user_defined_grads=[np.zeros([2048, 1024], dtype=self.dtype)],
[1024, 1024], dtype=self.dtype)],
user_defined_grad_outputs=[ user_defined_grad_outputs=[
np.random.rand(1024, 1024).astype(self.dtype) np.random.rand(2048, 1024).astype(self.dtype)
]) ])
class TestPoissonOp2(TestPoissonOp1): class TestPoissonOp2(TestPoissonOp1):
def config(self): def config(self):
self.lam = 5 self.lam = 5
self.a = 1 self.a = 1
self.b = 9 self.b = 8
self.dtype = "float32" self.dtype = "float32"
class TestPoissonAPI(unittest.TestCase): class TestPoissonAPI(unittest.TestCase):
def test_static(self): def test_static(self):
with paddle.static.program_guard(paddle.static.Program(), with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()): paddle.static.Program()):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册