From 5b3c7ee725bdfdc8876bbc976ed52704984f3ed7 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Wed, 15 Mar 2023 11:36:34 +0800 Subject: [PATCH] support gather test on prim and cinn (#51376) * support gather test on prim and cinn * reset timeout for gather --- .../prim/api/composite_backward/composite_backward_api.h | 6 +++++- python/paddle/fluid/tests/unittests/CMakeLists.txt | 5 +++-- python/paddle/fluid/tests/unittests/test_gather_op.py | 3 ++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h index da1daac8b88..bcd6f459b8d 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -59,7 +59,11 @@ void gather_grad(const Tensor& x, std::vector reverse_perm(tmp_perm); // make origin ranks for (int i = 0; i < static_cast(tmp_perm.size()); ++i) { - reverse_perm[tmp_perm[i]] = i; + if (tmp_perm[i] >= 0) { + reverse_perm[tmp_perm[i]] = i; + } else { + reverse_perm[tmp_perm[i] + tmp_perm.size()] = i; + } } // transpose out_grad and zero grad to target rank. diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index a355b2545dd..84958589041 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -982,7 +982,7 @@ set_tests_properties(test_matmul_op PROPERTIES TIMEOUT 120) set_tests_properties(test_nearest_interp_v2_op PROPERTIES TIMEOUT 120) set_tests_properties(test_trilinear_interp_op PROPERTIES TIMEOUT 120) set_tests_properties(test_bicubic_interp_v2_op PROPERTIES TIMEOUT 120) -set_tests_properties(test_gather_op PROPERTIES TIMEOUT 180) +set_tests_properties(test_gather_op PROPERTIES TIMEOUT 200) set_tests_properties(test_static_save_load PROPERTIES TIMEOUT 250) set_tests_properties(test_pylayer_op PROPERTIES TIMEOUT 120) set_tests_properties(test_paddle_save_load_binary PROPERTIES TIMEOUT 120) @@ -1217,7 +1217,8 @@ set(TEST_CINN_OPS test_transpose_op test_reshape_op test_mean_op - test_unsqueeze2_op) + test_unsqueeze2_op + test_gather_op) foreach(TEST_CINN_OPS ${TEST_CINN_OPS}) if(WITH_CINN) diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py index 5844d7b51da..ae3704dcc10 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_op.py @@ -35,6 +35,7 @@ class TestGatherOp(OpTest): self.op_type = "gather" self.python_api = paddle.gather self.config() + self.prim_op_type = "prim" xnp = np.random.random(self.x_shape).astype(self.x_type) self.inputs = { 'X': xnp, @@ -46,7 +47,7 @@ class TestGatherOp(OpTest): self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_eager=True) + self.check_grad(['X'], 'Out', check_eager=True, check_prim=True) def config(self): """ -- GitLab