提交 1331c9e1 编写于 作者: L LielinJiang 提交者: qingqing01

fix distributions unittest bug, test=develop (#19012)

上级 77572b70
......@@ -234,38 +234,59 @@ class DistributionTest(unittest.TestCase):
fetch_list=fetch_list)
np.testing.assert_allclose(
output_sample_float.shape, gt_sample_float.shape, rtol=tolerance)
output_sample_float.shape,
gt_sample_float.shape,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_sample_float_np_broadcast.shape,
gt_sample_float_np_broadcast.shape,
rtol=tolerance)
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_sample_np.shape, gt_sample_np.shape, rtol=tolerance)
output_sample_np.shape,
gt_sample_np.shape,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_sample_variable.shape, gt_sample_np.shape, rtol=tolerance)
output_sample_variable.shape,
gt_sample_np.shape,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_entropy_float, gt_entropy_float, rtol=tolerance)
output_entropy_float,
gt_entropy_float,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_entropy_float_np_broadcast,
gt_entropy_float_np_broadcast,
rtol=tolerance)
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_entropy_np, gt_entropy, rtol=tolerance)
output_entropy_np, gt_entropy, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_entropy_variable, gt_entropy, rtol=tolerance)
output_entropy_variable, gt_entropy, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_lp_float_np_broadcast,
gt_lp_float_np_broadcast,
rtol=tolerance)
np.testing.assert_allclose(output_lp_np, gt_lp, rtol=tolerance)
np.testing.assert_allclose(output_lp_variable, gt_lp, rtol=tolerance)
np.testing.assert_allclose(output_kl_float, gt_kl_float, rtol=tolerance)
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_lp_np, gt_lp, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_lp_variable, gt_lp, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_kl_float, gt_kl_float, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_kl_float_np_broadcast,
gt_kl_float_np_broadcast,
rtol=tolerance)
np.testing.assert_allclose(output_kl_np, gt_kl, rtol=tolerance)
np.testing.assert_allclose(output_kl_variable, gt_kl, rtol=tolerance)
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_kl_np, gt_kl, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_kl_variable, gt_kl, rtol=tolerance, atol=tolerance)
def build_uniform_program(self, test_program, batch_size, dims, low_float,
high_float, high_np, low_np, values_np):
......@@ -346,31 +367,48 @@ class DistributionTest(unittest.TestCase):
fetch_list=fetch_list)
np.testing.assert_allclose(
output_sample_float.shape, gt_sample_float.shape, rtol=tolerance)
output_sample_float.shape,
gt_sample_float.shape,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_sample_float_np_broadcast.shape,
gt_sample_float_np_broadcast.shape,
rtol=tolerance)
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_sample_np.shape, gt_sample_np.shape, rtol=tolerance)
output_sample_np.shape,
gt_sample_np.shape,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_sample_variable.shape, gt_sample_np.shape, rtol=tolerance)
output_sample_variable.shape,
gt_sample_np.shape,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_entropy_float, gt_entropy_float, rtol=tolerance)
output_entropy_float,
gt_entropy_float,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_entropy_float_np_broadcast,
gt_entropy_float_np_broadcast,
rtol=tolerance)
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_entropy_np, gt_entropy, rtol=tolerance)
output_entropy_np, gt_entropy, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_entropy_variable, gt_entropy, rtol=tolerance)
output_entropy_variable, gt_entropy, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_lp_float_np_broadcast,
gt_lp_float_np_broadcast,
rtol=tolerance)
np.testing.assert_allclose(output_lp_np, gt_lp, rtol=tolerance)
np.testing.assert_allclose(output_lp_variable, gt_lp, rtol=tolerance)
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_lp_np, gt_lp, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_lp_variable, gt_lp, rtol=tolerance, atol=tolerance)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册