diff --git a/paddle/phi/kernels/funcs/gather.cu.h b/paddle/phi/kernels/funcs/gather.cu.h index 2b1822ece2627d33d4b834fe32a0c13c6723a390..7be374c6bf98a50249f4b34092dabe2d98768d5c 100644 --- a/paddle/phi/kernels/funcs/gather.cu.h +++ b/paddle/phi/kernels/funcs/gather.cu.h @@ -285,7 +285,8 @@ void GatherV2GradCUDAFunction(const DenseTensor* input, if (input->numel() == 0) return; int axis_index = axis; - int64_t input_index_dim_size = input_dim[axis_index]; + int64_t input_index_dim_size = + index->dims().size() == 0 ? 1 : input_dim[axis_index]; int64_t inner_dim_size = 1; int64_t outer_dim_size = 1; diff --git a/paddle/phi/kernels/funcs/scatter.cu.h b/paddle/phi/kernels/funcs/scatter.cu.h index c03dcba1e2e7f0d7faf28811f48b441e1b1d310e..5e9bba414d3a6ca97c90d721119c8c1958bac900 100644 --- a/paddle/phi/kernels/funcs/scatter.cu.h +++ b/paddle/phi/kernels/funcs/scatter.cu.h @@ -182,11 +182,15 @@ template void GPUScatterGradForX(const phi::GPUContext& ctx, const DenseTensor& index, DenseTensor* output) { - int64_t index_size = index.dims()[0]; + int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0]; auto dst_dims = output->dims(); // slice size - int64_t slice_size = 1; - for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i]; + int64_t slice_size = 1; // slice size + if (index.dims().size() != 0) { + for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i]; + } else { + for (int i = 0; i < dst_dims.size(); ++i) slice_size *= dst_dims[i]; + } const IndexT* p_index = index.data(); T* p_output = output->data(); const size_t& slice_bytes = slice_size * sizeof(T); diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index b00d305895a5c37c382f1479b16c7bf007555120..3170f3f62cc8521cd55783f5c645515ac048b0ba 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -887,7 +887,7 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(x.grad.shape, [2, 3]) self.assertEqual(out.grad.shape, [3]) - def _test_gather_xD_axis_1(self): + def test_gather_xD_axis_1(self): x = paddle.to_tensor( [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False ) @@ -901,7 +901,7 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(x.grad.shape, [2, 3]) self.assertEqual(out.grad.shape, [2]) - def _test_scatter_1D(self): + def test_scatter_1D(self): x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False) index = paddle.full([], 2, 'int64') updates = paddle.full([], 4.0) @@ -913,7 +913,7 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(out.numpy()[2], 4) self.assertEqual(out.grad.shape, [5]) - def _test_scatter_XD(self): + def test_scatter_XD(self): x = paddle.to_tensor( [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False ) @@ -1925,7 +1925,7 @@ class TestSundryAPIStatic(unittest.TestCase): self.assertEqual(res[2].shape, (3,)) @prog_scope() - def _test_gather_XD_axis_1(self): + def test_gather_XD_axis_1(self): x = paddle.full([2, 3], 1.0, 'float32') x.stop_gradient = False index = paddle.full([], 1, 'int64') @@ -1940,7 +1940,7 @@ class TestSundryAPIStatic(unittest.TestCase): self.assertEqual(res[2].shape, (2,)) @prog_scope() - def _test_scatter_1D(self): + def test_scatter_1D(self): x = paddle.full([10], 1.0, 'float32') x.stop_gradient = False index = paddle.full([], 2, 'int64') @@ -1956,7 +1956,7 @@ class TestSundryAPIStatic(unittest.TestCase): self.assertEqual(res[2].shape, (10,)) @prog_scope() - def _test_scatter_XD(self): + def test_scatter_XD(self): x = paddle.full([2, 3], 1.0, 'float32') x.stop_gradient = False index = paddle.full([], 1, 'int64')