未验证 提交 7e06541b 编写于 作者: M MayYouBeProsperous 提交者: GitHub

[Hackathon No.10] Add unit tests for Normal (#47070)

* add test for rsample

* add assert in test_backpropagation

* fix bug
上级 c1077ae8
...@@ -516,6 +516,14 @@ class NormalTest10(NormalTest): ...@@ -516,6 +516,14 @@ class NormalTest10(NormalTest):
) )
def kstest(loc, scale, samples):
# Uses the Kolmogorov-Smirnov test for goodness of fit.
ks, _ = scipy.stats.kstest(
samples, scipy.stats.norm(loc=loc, scale=scale).cdf
)
return ks < 0.02
@place(config.DEVICES) @place(config.DEVICES)
@parameterize_cls( @parameterize_cls(
(TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand((4,)), xrand((4,)))] (TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand((4,)), xrand((4,)))]
...@@ -526,9 +534,7 @@ class TestNormalSampleDygraph(unittest.TestCase): ...@@ -526,9 +534,7 @@ class TestNormalSampleDygraph(unittest.TestCase):
self.paddle_normal = Normal(loc=self.loc, scale=self.scale) self.paddle_normal = Normal(loc=self.loc, scale=self.scale)
n = 100000 n = 100000
self.sample_shape = (n,) self.sample_shape = (n,)
self.rsample_shape = (n,)
self.samples = self.paddle_normal.sample(self.sample_shape) self.samples = self.paddle_normal.sample(self.sample_shape)
self.rsamples = self.paddle_normal.rsample(self.rsample_shape)
def test_sample(self): def test_sample(self):
samples_mean = self.samples.mean(axis=0) samples_mean = self.samples.mean(axis=0)
...@@ -540,38 +546,16 @@ class TestNormalSampleDygraph(unittest.TestCase): ...@@ -540,38 +546,16 @@ class TestNormalSampleDygraph(unittest.TestCase):
samples_var, self.paddle_normal.variance, rtol=0.1, atol=0 samples_var, self.paddle_normal.variance, rtol=0.1, atol=0
) )
rsamples_mean = self.rsamples.mean(axis=0)
rsamples_var = self.rsamples.var(axis=0)
np.testing.assert_allclose(
rsamples_mean, self.paddle_normal.mean, rtol=0.1, atol=0
)
np.testing.assert_allclose(
rsamples_var, self.paddle_normal.variance, rtol=0.1, atol=0
)
batch_shape = (self.loc + self.scale).shape batch_shape = (self.loc + self.scale).shape
self.assertEqual( self.assertEqual(
self.samples.shape, list(self.sample_shape + batch_shape) self.samples.shape, list(self.sample_shape + batch_shape)
) )
self.assertEqual(
self.rsamples.shape, list(self.rsample_shape + batch_shape)
)
for i in range(len(self.scale)): for i in range(len(self.scale)):
self.assertTrue( self.assertTrue(
self._kstest(self.loc[i], self.scale[i], self.samples[:, i]) kstest(self.loc[i], self.scale[i], self.samples[:, i])
)
self.assertTrue(
self._kstest(self.loc[i], self.scale[i], self.rsamples[:, i])
) )
def _kstest(self, loc, scale, samples):
# Uses the Kolmogorov-Smirnov test for goodness of fit.
ks, _ = scipy.stats.kstest(
samples, scipy.stats.norm(loc=loc, scale=scale).cdf
)
return ks < 0.02
@place(config.DEVICES) @place(config.DEVICES)
@parameterize_cls( @parameterize_cls(
...@@ -590,17 +574,15 @@ class TestNormalSampleStaic(unittest.TestCase): ...@@ -590,17 +574,15 @@ class TestNormalSampleStaic(unittest.TestCase):
) )
n = 100000 n = 100000
self.sample_shape = (n,) self.sample_shape = (n,)
self.rsample_shape = (n,)
self.paddle_normal = Normal(loc=loc, scale=scale) self.paddle_normal = Normal(loc=loc, scale=scale)
mean = self.paddle_normal.mean mean = self.paddle_normal.mean
variance = self.paddle_normal.variance variance = self.paddle_normal.variance
samples = self.paddle_normal.sample(self.sample_shape) samples = self.paddle_normal.sample(self.sample_shape)
rsamples = self.paddle_normal.rsample(self.rsample_shape) fetch_list = [mean, variance, samples]
fetch_list = [mean, variance, samples, rsamples]
self.feeds = {'loc': self.loc, 'scale': self.scale} self.feeds = {'loc': self.loc, 'scale': self.scale}
executor.run(startup_program) executor.run(startup_program)
[self.mean, self.variance, self.samples, self.rsamples] = executor.run( [self.mean, self.variance, self.samples] = executor.run(
main_program, feed=self.feeds, fetch_list=fetch_list main_program, feed=self.feeds, fetch_list=fetch_list
) )
...@@ -610,31 +592,102 @@ class TestNormalSampleStaic(unittest.TestCase): ...@@ -610,31 +592,102 @@ class TestNormalSampleStaic(unittest.TestCase):
np.testing.assert_allclose(samples_mean, self.mean, rtol=0.1, atol=0) np.testing.assert_allclose(samples_mean, self.mean, rtol=0.1, atol=0)
np.testing.assert_allclose(samples_var, self.variance, rtol=0.1, atol=0) np.testing.assert_allclose(samples_var, self.variance, rtol=0.1, atol=0)
batch_shape = (self.loc + self.scale).shape
self.assertEqual(self.samples.shape, self.sample_shape + batch_shape)
for i in range(len(self.scale)):
self.assertTrue(
kstest(self.loc[i], self.scale[i], self.samples[:, i])
)
@place(config.DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'loc', 'scale'), [('rsample', xrand((4,)), xrand((4,)))]
)
class TestNormalRSampleDygraph(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.loc = paddle.to_tensor(self.loc)
self.scale = paddle.to_tensor(self.scale)
self.loc.stop_gradient = False
self.scale.stop_gradient = False
self.paddle_normal = Normal(loc=self.loc, scale=self.scale)
n = 100000
self.rsample_shape = [n]
self.rsamples = self.paddle_normal.rsample(self.rsample_shape)
def test_rsample(self):
rsamples_mean = self.rsamples.mean(axis=0) rsamples_mean = self.rsamples.mean(axis=0)
rsamples_var = self.rsamples.var(axis=0) rsamples_var = self.rsamples.var(axis=0)
np.testing.assert_allclose(rsamples_mean, self.mean, rtol=0.1, atol=0)
np.testing.assert_allclose( np.testing.assert_allclose(
rsamples_var, self.variance, rtol=0.1, atol=0 rsamples_mean, self.paddle_normal.mean, rtol=0.1, atol=0
)
np.testing.assert_allclose(
rsamples_var, self.paddle_normal.variance, rtol=0.1, atol=0
) )
batch_shape = (self.loc + self.scale).shape batch_shape = (self.loc + self.scale).shape
self.assertEqual(self.samples.shape, self.sample_shape + batch_shape)
self.assertEqual(self.rsamples.shape, self.rsample_shape + batch_shape) self.assertEqual(self.rsamples.shape, self.rsample_shape + batch_shape)
for i in range(len(self.scale)): for i in range(len(self.scale)):
self.assertTrue( self.assertTrue(
self._kstest(self.loc[i], self.scale[i], self.samples[:, i]) kstest(self.loc[i], self.scale[i], self.rsamples[:, i])
) )
self.assertTrue(
self._kstest(self.loc[i], self.scale[i], self.rsamples[:, i]) def test_backpropagation(self):
grads = paddle.grad([self.rsamples], [self.loc, self.scale])
self.assertEqual(len(grads), 2)
self.assertEqual(grads[0].dtype, self.loc.dtype)
self.assertEqual(grads[0].shape, self.loc.shape)
self.assertEqual(grads[1].dtype, self.scale.dtype)
self.assertEqual(grads[1].shape, self.scale.shape)
@place(config.DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'loc', 'scale'), [('rsample', xrand((4,)), xrand((4,)))]
)
class TestNormalRSampleStaic(unittest.TestCase):
def setUp(self):
paddle.enable_static()
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(main_program, startup_program):
loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype)
scale = paddle.static.data(
'scale', self.scale.shape, self.scale.dtype
) )
n = 100000
self.rsample_shape = (n,)
self.paddle_normal = Normal(loc=loc, scale=scale)
mean = self.paddle_normal.mean
variance = self.paddle_normal.variance
rsamples = self.paddle_normal.rsample(self.rsample_shape)
fetch_list = [mean, variance, rsamples]
self.feeds = {'loc': self.loc, 'scale': self.scale}
executor.run(startup_program)
[self.mean, self.variance, self.rsamples] = executor.run(
main_program, feed=self.feeds, fetch_list=fetch_list
)
def _kstest(self, loc, scale, samples): def test_rsample(self):
# Uses the Kolmogorov-Smirnov test for goodness of fit. rsamples_mean = self.rsamples.mean(axis=0)
ks, _ = scipy.stats.kstest( rsamples_var = self.rsamples.var(axis=0)
samples, scipy.stats.norm(loc=loc, scale=scale).cdf np.testing.assert_allclose(rsamples_mean, self.mean, rtol=0.1, atol=0)
np.testing.assert_allclose(
rsamples_var, self.variance, rtol=0.1, atol=0
) )
return ks < 0.02
batch_shape = (self.loc + self.scale).shape
self.assertEqual(self.rsamples.shape, self.rsample_shape + batch_shape)
for i in range(len(self.scale)):
self.assertTrue(
kstest(self.loc[i], self.scale[i], self.rsamples[:, i])
)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册