未验证 提交 9b40cb87 编写于 作者: Q Qi Li 提交者: GitHub

[ROCM] fix test_matmul_v2_op (#31802) (#31828)

Co-authored-by: Nronnywang <524019753@qq.com>
上级 3f66e7de
......@@ -160,7 +160,7 @@ struct DotGradFunction<DeviceContext, T, math::DisableComplex<T>> {
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx) {
#ifdef __NVCC__
#if defined(__NVCC__) || defined(__HIPCC__)
if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
......
......@@ -67,7 +67,7 @@ class TestMatMulV2Op(OpTest):
self.trans_y = False
def init_kernel_type(self):
self.dtype = "float64"
self.dtype = "float32" if core.is_compiled_with_rocm() else "float64"
def setUp(self):
self.init_kernel_type()
......@@ -91,6 +91,9 @@ class TestMatMulV2Op(OpTest):
self.check_output()
def test_check_grad(self):
if core.is_compiled_with_rocm():
self.check_grad(['X', 'Y'], 'Out', max_relative_error=1e-2)
else:
self.check_grad(['X', 'Y'], 'Out')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册