未验证 提交 4c80385a 编写于 作者: Y Yulong Ao 提交者: GitHub

Adjust the relative error of QR's grad (#42221)

上级 acca0352
......@@ -27,7 +27,7 @@ from op_test import OpTest
class TestQrOp(OpTest):
def setUp(self):
paddle.enable_static()
np.random.seed(4)
np.random.seed(7)
self.op_type = "qr"
a, q, r = self.get_input_and_output()
self.inputs = {"X": a}
......@@ -74,7 +74,8 @@ class TestQrOp(OpTest):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X'], ['Q', 'R'])
self.check_grad(
['X'], ['Q', 'R'], numeric_grad_delta=1e-5, max_relative_error=1e-6)
class TestQrOpCase1(TestQrOp):
......@@ -116,6 +117,7 @@ class TestQrOpCase6(TestQrOp):
class TestQrAPI(unittest.TestCase):
def test_dygraph(self):
paddle.disable_static()
np.random.seed(7)
def run_qr_dygraph(shape, mode, dtype):
if dtype == "float32":
......@@ -180,6 +182,7 @@ class TestQrAPI(unittest.TestCase):
def test_static(self):
paddle.enable_static()
np.random.seed(7)
def run_qr_static(shape, mode, dtype):
if dtype == "float32":
......
......@@ -51,6 +51,7 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [
'matrix_power', \
'cholesky_solve', \
'solve', \
'qr', \
]
NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp',\
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册