未验证 提交 aa4a56fc 编写于 作者: Z zhulei 提交者: GitHub

[Rocm] fix test of random_crop_op & logsumexp (#32824)

* [Rocm] fix test of random_crop_op

* [Rocm] fix test of random_crop_op

* [Rocm] fix test of random_crop_op & simple_rnn_op

* [Rocm] fix test of random_crop_op & simple_rnn_op & logsumexp

* [Rocm] fix test of random_crop_op & simple_rnn_op & logsumexp

* [Rocm] fix test of random_crop_op & simple_rnn_op & logsumexp

* [Rocm] fix test of random_crop_op & logsumexp
上级 67c2700f
...@@ -59,16 +59,6 @@ HOSTDEVICE inline void StridedMemcpy(const T* x, const size_t* x_dims, T* out, ...@@ -59,16 +59,6 @@ HOSTDEVICE inline void StridedMemcpy(const T* x, const size_t* x_dims, T* out,
size_t offset_i = offsets[i]; size_t offset_i = offsets[i];
if (i == rank - 1) { if (i == rank - 1) {
PADDLE_ENFORCE(x_stride == 1,
"When i:%d == rank:%d - 1, x_stride of random_crop_op "
"expected to be 1, but got %ld. Please check input "
"value.",
i, rank, x_stride);
PADDLE_ENFORCE(out_stride == 1,
"When i:%d == rank:%d - 1, out_stride of random_crop_op "
"expected to be 1, but got %ld. Please check input "
"value.",
i, rank, out_stride);
x += offset_i; x += offset_i;
for (size_t j = 0; j < out_dim_i; ++j) { for (size_t j = 0; j < out_dim_i; ++j) {
*out++ = *x++; *out++ = *x++;
......
...@@ -50,15 +50,30 @@ class TestLogsumexp(OpTest): ...@@ -50,15 +50,30 @@ class TestLogsumexp(OpTest):
'keepdim': self.keepdim, 'keepdim': self.keepdim,
'reduce_all': self.reduce_all 'reduce_all': self.reduce_all
} }
self.user_defined_grads = None
self.user_defined_grad_outputs = None
self.set_attrs_addition()
def set_attrs(self): def set_attrs(self):
pass pass
def set_attrs_addition(self):
pass
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X'], ['Out']) self.check_grad(
['X'], ['Out'],
user_defined_grads=self.user_defined_grads,
user_defined_grad_outputs=self.user_defined_grad_outputs)
def calc_grad(self):
dy = np.ones(1, dtype=self.dtype)
x = self.inputs['X']
y = self.outputs['Out']
return dy * np.exp(x - y)
class TestLogsumexp_shape(TestLogsumexp): class TestLogsumexp_shape(TestLogsumexp):
...@@ -75,6 +90,11 @@ class TestLogsumexp_axis_all(TestLogsumexp): ...@@ -75,6 +90,11 @@ class TestLogsumexp_axis_all(TestLogsumexp):
def set_attrs(self): def set_attrs(self):
self.axis = [0, 1, 2, 3] self.axis = [0, 1, 2, 3]
def set_attrs_addition(self):
if paddle.fluid.core.is_compiled_with_rocm():
self.user_defined_grads = [self.calc_grad()]
self.user_defined_grad_outputs = [np.ones(1, dtype=self.dtype)]
class TestLogsumexp_keepdim(TestLogsumexp): class TestLogsumexp_keepdim(TestLogsumexp):
def set_attrs(self): def set_attrs(self):
...@@ -85,6 +105,11 @@ class TestLogsumexp_reduce_all(TestLogsumexp): ...@@ -85,6 +105,11 @@ class TestLogsumexp_reduce_all(TestLogsumexp):
def set_attrs(self): def set_attrs(self):
self.reduce_all = True self.reduce_all = True
def set_attrs_addition(self):
if paddle.fluid.core.is_compiled_with_rocm():
self.user_defined_grads = [self.calc_grad()]
self.user_defined_grad_outputs = [np.ones(1, dtype=self.dtype)]
class TestLogsumexpError(unittest.TestCase): class TestLogsumexpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册