From 3e2c6a561bd533e092f6d622c1efcb4b0fc83776 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Fri, 28 Jul 2023 15:30:56 +0800 Subject: [PATCH] [bug fix] fix scatter 0d index grad error (#55738) --- paddle/phi/infermeta/ternary.cc | 13 +++++++++- paddle/phi/kernels/funcs/scatter.cu.h | 8 ++----- paddle/phi/kernels/funcs/scatter.h | 2 +- test/legacy_test/test_zero_dim_tensor.py | 30 ++++++++++++++++++++++++ 4 files changed, 45 insertions(+), 8 deletions(-) diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index b3f17ab91b9..bc41536fd9f 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -1060,7 +1060,7 @@ void ScatterInferMeta(const MetaTensor& x, (ref_dims.size() == updates_dims.size()), true, phi::errors::InvalidArgument( - "When the Input(Updates) is not a 0D tensor, the " + "When the Input(Index) is not a 0D tensor, the " "Input(X) and Input(Updates) should have the same shape size, " "but received the size of Input(x)'s shape is %d, the size of " "Input(Updates)'s shape is %d.", @@ -1075,6 +1075,17 @@ void ScatterInferMeta(const MetaTensor& x, "batch-size is %d.", updates_dims[0], index_dims[0])); + } else { + PADDLE_ENFORCE_EQ( + (ref_dims.size() - 1 == updates_dims.size()), + true, + phi::errors::InvalidArgument( + "When the Input(Index) is a 0D tensor, the " + "Input(Updates) should have the shape size as Input(X)'s " + "shape size - 1. But received the size of Input(x)'s shape is %d, " + " the size of Input(Updates)'s shape is %d.", + ref_dims.size(), + updates_dims.size())); } out->set_dims(ref_dims); out->share_lod(x); diff --git a/paddle/phi/kernels/funcs/scatter.cu.h b/paddle/phi/kernels/funcs/scatter.cu.h index 19a391ea150..c3f0cf61986 100644 --- a/paddle/phi/kernels/funcs/scatter.cu.h +++ b/paddle/phi/kernels/funcs/scatter.cu.h @@ -195,12 +195,8 @@ void GPUScatterGradForX(const phi::GPUContext& ctx, int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0]; auto dst_dims = output->dims(); // slice size - int64_t slice_size = 1; // slice size - if (index.dims().size() != 0) { - for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i]; - } else { - for (int i = 0; i < dst_dims.size(); ++i) slice_size *= dst_dims[i]; - } + int64_t slice_size = 1; + for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i]; const IndexT* p_index = index.data(); T* p_output = output->data(); const size_t& slice_bytes = slice_size * sizeof(T); diff --git a/paddle/phi/kernels/funcs/scatter.h b/paddle/phi/kernels/funcs/scatter.h index 64bca648251..5934f57b47d 100644 --- a/paddle/phi/kernels/funcs/scatter.h +++ b/paddle/phi/kernels/funcs/scatter.h @@ -244,7 +244,7 @@ template void CPUScatterGradForX(const phi::CPUContext& ctx UNUSED, const DenseTensor& index, DenseTensor* output) { - int64_t index_size = index.dims()[0]; + int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0]; auto dst_dims = output->dims(); const IndexT* p_index = index.data(); T* p_output = output->data(); diff --git a/test/legacy_test/test_zero_dim_tensor.py b/test/legacy_test/test_zero_dim_tensor.py index 6f47f2d46b5..7c814a83927 100644 --- a/test/legacy_test/test_zero_dim_tensor.py +++ b/test/legacy_test/test_zero_dim_tensor.py @@ -1916,6 +1916,36 @@ class TestSundryAPI(unittest.TestCase): np.testing.assert_array_equal(out.numpy()[1], [1.0, 2.0, 3.0]) self.assertEqual(out.grad.shape, [2, 3]) + def test_scatter_shape_check(self): + x = paddle.to_tensor([1.0, 2.0, 3.0]) + index = paddle.to_tensor(1) + updates = paddle.to_tensor([3.0]) + with self.assertRaises(ValueError): + out = paddle.scatter(x, index, updates) + + x = paddle.to_tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]) + index = paddle.to_tensor(1) + updates = paddle.to_tensor([[5.0, 5.0]]) + with self.assertRaises(ValueError): + out = paddle.scatter(x, index, updates) + + def test_scatter_0D_index(self): + x = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False) + index = paddle.to_tensor(1) + updates = paddle.to_tensor(3.0) + out = paddle.scatter(x, index, updates) + out.backward() + np.testing.assert_array_equal(x.grad.numpy()[1], 0.0) + + x = paddle.to_tensor( + [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], stop_gradient=False + ) + index = paddle.to_tensor(1) + updates = paddle.to_tensor([5.0, 5.0]) + out = paddle.scatter(x, index, updates) + out.backward() + np.testing.assert_array_equal(x.grad.numpy()[1], [0.0, 0.0]) + def test_diagflat(self): x1 = paddle.rand([]) x2 = paddle.rand([]) -- GitLab