diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h index da1daac8b8873e2f03d60338811e7b426bc8d1e3..bcd6f459b8dc37accc3922a02d835c14567248a8 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -59,7 +59,11 @@ void gather_grad(const Tensor& x, std::vector reverse_perm(tmp_perm); // make origin ranks for (int i = 0; i < static_cast(tmp_perm.size()); ++i) { - reverse_perm[tmp_perm[i]] = i; + if (tmp_perm[i] >= 0) { + reverse_perm[tmp_perm[i]] = i; + } else { + reverse_perm[tmp_perm[i] + tmp_perm.size()] = i; + } } // transpose out_grad and zero grad to target rank. diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index a355b2545dd615d6990f52032cb4af81d7329998..8495858904189ae13e47309793d6b61b8199b81f 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -982,7 +982,7 @@ set_tests_properties(test_matmul_op PROPERTIES TIMEOUT 120) set_tests_properties(test_nearest_interp_v2_op PROPERTIES TIMEOUT 120) set_tests_properties(test_trilinear_interp_op PROPERTIES TIMEOUT 120) set_tests_properties(test_bicubic_interp_v2_op PROPERTIES TIMEOUT 120) -set_tests_properties(test_gather_op PROPERTIES TIMEOUT 180) +set_tests_properties(test_gather_op PROPERTIES TIMEOUT 200) set_tests_properties(test_static_save_load PROPERTIES TIMEOUT 250) set_tests_properties(test_pylayer_op PROPERTIES TIMEOUT 120) set_tests_properties(test_paddle_save_load_binary PROPERTIES TIMEOUT 120) @@ -1217,7 +1217,8 @@ set(TEST_CINN_OPS test_transpose_op test_reshape_op test_mean_op - test_unsqueeze2_op) + test_unsqueeze2_op + test_gather_op) foreach(TEST_CINN_OPS ${TEST_CINN_OPS}) if(WITH_CINN) diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py index 5844d7b51da143b840486b974d8dc1bac843a5ba..ae3704dcc1032bfb8746f4288f78f7857351241e 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_op.py @@ -35,6 +35,7 @@ class TestGatherOp(OpTest): self.op_type = "gather" self.python_api = paddle.gather self.config() + self.prim_op_type = "prim" xnp = np.random.random(self.x_shape).astype(self.x_type) self.inputs = { 'X': xnp, @@ -46,7 +47,7 @@ class TestGatherOp(OpTest): self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_eager=True) + self.check_grad(['X'], 'Out', check_eager=True, check_prim=True) def config(self): """