未验证 提交 627e5bd5 编写于 作者: Y Yulong Ao 提交者: GitHub

Adjust the relative error of QR's grad (#44785)

* Adjust the relative error of QR's grad (#42221)

* Fix the format
上级 cd59df5f
......@@ -25,9 +25,10 @@ from op_test import OpTest
class TestQrOp(OpTest):
def setUp(self):
paddle.enable_static()
np.random.seed(4)
np.random.seed(7)
self.op_type = "qr"
a, q, r = self.get_input_and_output()
self.inputs = {"X": a}
......@@ -74,30 +75,37 @@ class TestQrOp(OpTest):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X'], ['Q', 'R'])
self.check_grad(['X'], ['Q', 'R'],
numeric_grad_delta=1e-5,
max_relative_error=1e-6)
class TestQrOpCase1(TestQrOp):
def get_shape(self):
return (10, 12)
class TestQrOpCase2(TestQrOp):
def get_shape(self):
return (16, 15)
class TestQrOpCase3(TestQrOp):
def get_shape(self):
return (2, 12, 16)
class TestQrOpCase4(TestQrOp):
def get_shape(self):
return (3, 16, 15)
class TestQrOpCase5(TestQrOp):
def get_mode(self):
return "complete"
......@@ -106,6 +114,7 @@ class TestQrOpCase5(TestQrOp):
class TestQrOpCase6(TestQrOp):
def get_mode(self):
return "complete"
......@@ -114,8 +123,10 @@ class TestQrOpCase6(TestQrOp):
class TestQrAPI(unittest.TestCase):
def test_dygraph(self):
paddle.disable_static()
np.random.seed(7)
def run_qr_dygraph(shape, mode, dtype):
if dtype == "float32":
......@@ -174,12 +185,13 @@ class TestQrAPI(unittest.TestCase):
]
modes = ["reduced", "complete", "r"]
dtypes = ["float32", "float64"]
for tensor_shape, mode, dtype in itertools.product(tensor_shapes, modes,
dtypes):
for tensor_shape, mode, dtype in itertools.product(
tensor_shapes, modes, dtypes):
run_qr_dygraph(tensor_shape, mode, dtype)
def test_static(self):
paddle.enable_static()
np.random.seed(7)
def run_qr_static(shape, mode, dtype):
if dtype == "float32":
......@@ -216,29 +228,27 @@ class TestQrAPI(unittest.TestCase):
tmp_q, tmp_r = np.linalg.qr(a[coord], mode=mode)
np_q[coord] = tmp_q
np_r[coord] = tmp_r
x = paddle.fluid.data(
name="input", shape=shape, dtype=dtype)
x = paddle.fluid.data(name="input",
shape=shape,
dtype=dtype)
if mode == "r":
r = paddle.linalg.qr(x, mode=mode)
exe = fluid.Executor(place)
fetches = exe.run(fluid.default_main_program(),
feed={"input": a},
fetch_list=[r])
self.assertTrue(
np.allclose(
fetches[0], np_r, atol=1e-5))
self.assertTrue(np.allclose(fetches[0], np_r,
atol=1e-5))
else:
q, r = paddle.linalg.qr(x, mode=mode)
exe = fluid.Executor(place)
fetches = exe.run(fluid.default_main_program(),
feed={"input": a},
fetch_list=[q, r])
self.assertTrue(
np.allclose(
fetches[0], np_q, atol=1e-5))
self.assertTrue(
np.allclose(
fetches[1], np_r, atol=1e-5))
self.assertTrue(np.allclose(fetches[0], np_q,
atol=1e-5))
self.assertTrue(np.allclose(fetches[1], np_r,
atol=1e-5))
tensor_shapes = [
(3, 5),
......@@ -253,8 +263,8 @@ class TestQrAPI(unittest.TestCase):
]
modes = ["reduced", "complete", "r"]
dtypes = ["float32", "float64"]
for tensor_shape, mode, dtype in itertools.product(tensor_shapes, modes,
dtypes):
for tensor_shape, mode, dtype in itertools.product(
tensor_shapes, modes, dtypes):
run_qr_static(tensor_shape, mode, dtype)
......
......@@ -51,6 +51,7 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [
'matrix_power', \
'cholesky_solve', \
'solve', \
'qr', \
]
NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp',\
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册