From 627e5bd52e1b4b0b146b401cca38d681fa89d71c Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Wed, 3 Aug 2022 19:02:16 +0800 Subject: [PATCH] Adjust the relative error of QR's grad (#44785) * Adjust the relative error of QR's grad (#42221) * Fix the format --- .../fluid/tests/unittests/test_qr_op.py | 44 ++++++++++++------- .../white_list/op_threshold_white_list.py | 1 + 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_qr_op.py b/python/paddle/fluid/tests/unittests/test_qr_op.py index 4be46837a67..338b08d1aa5 100644 --- a/python/paddle/fluid/tests/unittests/test_qr_op.py +++ b/python/paddle/fluid/tests/unittests/test_qr_op.py @@ -25,9 +25,10 @@ from op_test import OpTest class TestQrOp(OpTest): + def setUp(self): paddle.enable_static() - np.random.seed(4) + np.random.seed(7) self.op_type = "qr" a, q, r = self.get_input_and_output() self.inputs = {"X": a} @@ -74,30 +75,37 @@ class TestQrOp(OpTest): self.check_output() def test_check_grad_normal(self): - self.check_grad(['X'], ['Q', 'R']) + self.check_grad(['X'], ['Q', 'R'], + numeric_grad_delta=1e-5, + max_relative_error=1e-6) class TestQrOpCase1(TestQrOp): + def get_shape(self): return (10, 12) class TestQrOpCase2(TestQrOp): + def get_shape(self): return (16, 15) class TestQrOpCase3(TestQrOp): + def get_shape(self): return (2, 12, 16) class TestQrOpCase4(TestQrOp): + def get_shape(self): return (3, 16, 15) class TestQrOpCase5(TestQrOp): + def get_mode(self): return "complete" @@ -106,6 +114,7 @@ class TestQrOpCase5(TestQrOp): class TestQrOpCase6(TestQrOp): + def get_mode(self): return "complete" @@ -114,8 +123,10 @@ class TestQrOpCase6(TestQrOp): class TestQrAPI(unittest.TestCase): + def test_dygraph(self): paddle.disable_static() + np.random.seed(7) def run_qr_dygraph(shape, mode, dtype): if dtype == "float32": @@ -174,12 +185,13 @@ class TestQrAPI(unittest.TestCase): ] modes = ["reduced", "complete", "r"] dtypes = ["float32", "float64"] - for tensor_shape, mode, dtype in itertools.product(tensor_shapes, modes, - dtypes): + for tensor_shape, mode, dtype in itertools.product( + tensor_shapes, modes, dtypes): run_qr_dygraph(tensor_shape, mode, dtype) def test_static(self): paddle.enable_static() + np.random.seed(7) def run_qr_static(shape, mode, dtype): if dtype == "float32": @@ -216,29 +228,27 @@ class TestQrAPI(unittest.TestCase): tmp_q, tmp_r = np.linalg.qr(a[coord], mode=mode) np_q[coord] = tmp_q np_r[coord] = tmp_r - x = paddle.fluid.data( - name="input", shape=shape, dtype=dtype) + x = paddle.fluid.data(name="input", + shape=shape, + dtype=dtype) if mode == "r": r = paddle.linalg.qr(x, mode=mode) exe = fluid.Executor(place) fetches = exe.run(fluid.default_main_program(), feed={"input": a}, fetch_list=[r]) - self.assertTrue( - np.allclose( - fetches[0], np_r, atol=1e-5)) + self.assertTrue(np.allclose(fetches[0], np_r, + atol=1e-5)) else: q, r = paddle.linalg.qr(x, mode=mode) exe = fluid.Executor(place) fetches = exe.run(fluid.default_main_program(), feed={"input": a}, fetch_list=[q, r]) - self.assertTrue( - np.allclose( - fetches[0], np_q, atol=1e-5)) - self.assertTrue( - np.allclose( - fetches[1], np_r, atol=1e-5)) + self.assertTrue(np.allclose(fetches[0], np_q, + atol=1e-5)) + self.assertTrue(np.allclose(fetches[1], np_r, + atol=1e-5)) tensor_shapes = [ (3, 5), @@ -253,8 +263,8 @@ class TestQrAPI(unittest.TestCase): ] modes = ["reduced", "complete", "r"] dtypes = ["float32", "float64"] - for tensor_shape, mode, dtype in itertools.product(tensor_shapes, modes, - dtypes): + for tensor_shape, mode, dtype in itertools.product( + tensor_shapes, modes, dtypes): run_qr_static(tensor_shape, mode, dtype) diff --git a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py index 5deca1dc5ac..91731c1dd0b 100644 --- a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py @@ -51,6 +51,7 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [ 'matrix_power', \ 'cholesky_solve', \ 'solve', \ + 'qr', \ ] NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp',\ -- GitLab