未验证 提交 627e5bd5 编写于 作者: Y Yulong Ao 提交者: GitHub

Adjust the relative error of QR's grad (#44785)

* Adjust the relative error of QR's grad (#42221)

* Fix the format
上级 cd59df5f
...@@ -25,9 +25,10 @@ from op_test import OpTest ...@@ -25,9 +25,10 @@ from op_test import OpTest
class TestQrOp(OpTest): class TestQrOp(OpTest):
def setUp(self): def setUp(self):
paddle.enable_static() paddle.enable_static()
np.random.seed(4) np.random.seed(7)
self.op_type = "qr" self.op_type = "qr"
a, q, r = self.get_input_and_output() a, q, r = self.get_input_and_output()
self.inputs = {"X": a} self.inputs = {"X": a}
...@@ -74,30 +75,37 @@ class TestQrOp(OpTest): ...@@ -74,30 +75,37 @@ class TestQrOp(OpTest):
self.check_output() self.check_output()
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad(['X'], ['Q', 'R']) self.check_grad(['X'], ['Q', 'R'],
numeric_grad_delta=1e-5,
max_relative_error=1e-6)
class TestQrOpCase1(TestQrOp): class TestQrOpCase1(TestQrOp):
def get_shape(self): def get_shape(self):
return (10, 12) return (10, 12)
class TestQrOpCase2(TestQrOp): class TestQrOpCase2(TestQrOp):
def get_shape(self): def get_shape(self):
return (16, 15) return (16, 15)
class TestQrOpCase3(TestQrOp): class TestQrOpCase3(TestQrOp):
def get_shape(self): def get_shape(self):
return (2, 12, 16) return (2, 12, 16)
class TestQrOpCase4(TestQrOp): class TestQrOpCase4(TestQrOp):
def get_shape(self): def get_shape(self):
return (3, 16, 15) return (3, 16, 15)
class TestQrOpCase5(TestQrOp): class TestQrOpCase5(TestQrOp):
def get_mode(self): def get_mode(self):
return "complete" return "complete"
...@@ -106,6 +114,7 @@ class TestQrOpCase5(TestQrOp): ...@@ -106,6 +114,7 @@ class TestQrOpCase5(TestQrOp):
class TestQrOpCase6(TestQrOp): class TestQrOpCase6(TestQrOp):
def get_mode(self): def get_mode(self):
return "complete" return "complete"
...@@ -114,8 +123,10 @@ class TestQrOpCase6(TestQrOp): ...@@ -114,8 +123,10 @@ class TestQrOpCase6(TestQrOp):
class TestQrAPI(unittest.TestCase): class TestQrAPI(unittest.TestCase):
def test_dygraph(self): def test_dygraph(self):
paddle.disable_static() paddle.disable_static()
np.random.seed(7)
def run_qr_dygraph(shape, mode, dtype): def run_qr_dygraph(shape, mode, dtype):
if dtype == "float32": if dtype == "float32":
...@@ -174,12 +185,13 @@ class TestQrAPI(unittest.TestCase): ...@@ -174,12 +185,13 @@ class TestQrAPI(unittest.TestCase):
] ]
modes = ["reduced", "complete", "r"] modes = ["reduced", "complete", "r"]
dtypes = ["float32", "float64"] dtypes = ["float32", "float64"]
for tensor_shape, mode, dtype in itertools.product(tensor_shapes, modes, for tensor_shape, mode, dtype in itertools.product(
dtypes): tensor_shapes, modes, dtypes):
run_qr_dygraph(tensor_shape, mode, dtype) run_qr_dygraph(tensor_shape, mode, dtype)
def test_static(self): def test_static(self):
paddle.enable_static() paddle.enable_static()
np.random.seed(7)
def run_qr_static(shape, mode, dtype): def run_qr_static(shape, mode, dtype):
if dtype == "float32": if dtype == "float32":
...@@ -216,29 +228,27 @@ class TestQrAPI(unittest.TestCase): ...@@ -216,29 +228,27 @@ class TestQrAPI(unittest.TestCase):
tmp_q, tmp_r = np.linalg.qr(a[coord], mode=mode) tmp_q, tmp_r = np.linalg.qr(a[coord], mode=mode)
np_q[coord] = tmp_q np_q[coord] = tmp_q
np_r[coord] = tmp_r np_r[coord] = tmp_r
x = paddle.fluid.data( x = paddle.fluid.data(name="input",
name="input", shape=shape, dtype=dtype) shape=shape,
dtype=dtype)
if mode == "r": if mode == "r":
r = paddle.linalg.qr(x, mode=mode) r = paddle.linalg.qr(x, mode=mode)
exe = fluid.Executor(place) exe = fluid.Executor(place)
fetches = exe.run(fluid.default_main_program(), fetches = exe.run(fluid.default_main_program(),
feed={"input": a}, feed={"input": a},
fetch_list=[r]) fetch_list=[r])
self.assertTrue( self.assertTrue(np.allclose(fetches[0], np_r,
np.allclose( atol=1e-5))
fetches[0], np_r, atol=1e-5))
else: else:
q, r = paddle.linalg.qr(x, mode=mode) q, r = paddle.linalg.qr(x, mode=mode)
exe = fluid.Executor(place) exe = fluid.Executor(place)
fetches = exe.run(fluid.default_main_program(), fetches = exe.run(fluid.default_main_program(),
feed={"input": a}, feed={"input": a},
fetch_list=[q, r]) fetch_list=[q, r])
self.assertTrue( self.assertTrue(np.allclose(fetches[0], np_q,
np.allclose( atol=1e-5))
fetches[0], np_q, atol=1e-5)) self.assertTrue(np.allclose(fetches[1], np_r,
self.assertTrue( atol=1e-5))
np.allclose(
fetches[1], np_r, atol=1e-5))
tensor_shapes = [ tensor_shapes = [
(3, 5), (3, 5),
...@@ -253,8 +263,8 @@ class TestQrAPI(unittest.TestCase): ...@@ -253,8 +263,8 @@ class TestQrAPI(unittest.TestCase):
] ]
modes = ["reduced", "complete", "r"] modes = ["reduced", "complete", "r"]
dtypes = ["float32", "float64"] dtypes = ["float32", "float64"]
for tensor_shape, mode, dtype in itertools.product(tensor_shapes, modes, for tensor_shape, mode, dtype in itertools.product(
dtypes): tensor_shapes, modes, dtypes):
run_qr_static(tensor_shape, mode, dtype) run_qr_static(tensor_shape, mode, dtype)
......
...@@ -51,6 +51,7 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [ ...@@ -51,6 +51,7 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [
'matrix_power', \ 'matrix_power', \
'cholesky_solve', \ 'cholesky_solve', \
'solve', \ 'solve', \
'qr', \
] ]
NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp',\ NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp',\
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册