未验证 提交 270699e6 编写于 作者: R ronnywang 提交者: GitHub

[ROCM] fix test_matmul_v2_op (#31802)

上级 1eb927f9
...@@ -160,7 +160,7 @@ struct DotGradFunction<DeviceContext, T, math::DisableComplex<T>> { ...@@ -160,7 +160,7 @@ struct DotGradFunction<DeviceContext, T, math::DisableComplex<T>> {
const Tensor* tensor_dout, Tensor* tensor_dx, const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy, Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx) { const paddle::framework::ExecutionContext& ctx) {
#ifdef __NVCC__ #if defined(__NVCC__) || defined(__HIPCC__)
if (1 == tensor_dout->dims().size()) { if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout); auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
......
...@@ -67,7 +67,7 @@ class TestMatMulV2Op(OpTest): ...@@ -67,7 +67,7 @@ class TestMatMulV2Op(OpTest):
self.trans_y = False self.trans_y = False
def init_kernel_type(self): def init_kernel_type(self):
self.dtype = "float64" self.dtype = "float32" if core.is_compiled_with_rocm() else "float64"
def setUp(self): def setUp(self):
self.init_kernel_type() self.init_kernel_type()
...@@ -91,7 +91,10 @@ class TestMatMulV2Op(OpTest): ...@@ -91,7 +91,10 @@ class TestMatMulV2Op(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X', 'Y'], 'Out') if core.is_compiled_with_rocm():
self.check_grad(['X', 'Y'], 'Out', max_relative_error=1e-2)
else:
self.check_grad(['X', 'Y'], 'Out')
class TestMatMuklOp2(TestMatMulV2Op): class TestMatMuklOp2(TestMatMulV2Op):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册