提交 32210b08 编写于 作者: P pangyoki

change tolerance of Normal log_prob method

上级 c995a2c6
......@@ -396,11 +396,18 @@ class NormalTest(unittest.TestCase):
np_other_normal = NormalNumpy(self.other_loc_np, self.other_scale_np)
np_kl = np_normal.kl_divergence(np_other_normal)
# Because assign op does not support the input of numpy.ndarray whose dtype is FP64.
# When loc and scale are FP64 numpy.ndarray, we need to use assign op to convert it
# to FP32 Tensor. And then use cast op to convert it to a FP64 Tensor.
# There is a loss of accuracy in this conversion.
# So set the tolerance from 1e-6 to 1e-4.
log_tolerance = 1e-4
np.testing.assert_equal(sample.shape, np_sample.shape)
np.testing.assert_allclose(
entropy, np_entropy, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
log_prob, np_lp, rtol=tolerance, atol=tolerance)
log_prob, np_lp, rtol=log_tolerance, atol=log_tolerance)
np.testing.assert_allclose(probs, np_p, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(kl, np_kl, rtol=tolerance, atol=tolerance)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册