diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index f7bae3690991f5cd208489c453e270d132f5c1ea..74391958189186983360df0a14dcd273d211f2b1 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -1039,54 +1039,69 @@ void ScatterNdAddInferMeta(const MetaTensor& x, const auto& updates_dims = updates.dims(); auto updates_dims_size = updates_dims.size(); - PADDLE_ENFORCE_LE( - index_dims[index_dims_size - 1], - ref_dims_size, - phi::errors::InvalidArgument( - "The last dimension of Input(Index)'s shape should be no greater " - "than the rank of Input(X), but received the last dimension of " - "Input(Index)'s shape is %d, the rank of Input(X) is %d.", - index_dims[index_dims_size - 1], - ref_dims_size)); - PADDLE_ENFORCE_GE(index_dims_size, - 2UL, - phi::errors::InvalidArgument( - "The rank of Input(Index) should be greater than 1, " - "but received the rank of Input(Index) is %d.", - index_dims_size)); - - // update.shape = index.shape[:-1] + output.shape[index.shape[-1]:] - std::vector r_updates_dims; - for (int64_t i = 0; i < index_dims_size - 1; ++i) { - r_updates_dims.emplace_back(index_dims[i]); - } - for (int64_t i = index_dims[index_dims_size - 1]; i < ref_dims_size; ++i) { - r_updates_dims.emplace_back(ref_dims[i]); - } - - PADDLE_ENFORCE_EQ( - r_updates_dims.size(), - updates_dims_size, - phi::errors::InvalidArgument( - "Updates has wrong shape. The shape of Updates and Input(Updates) " - "should be same, but received the shape of Updates is %d, " - "the shape of Input(Updates) is %d.", - r_updates_dims.size(), - updates_dims_size)); - - for (int64_t i = 0; i < updates_dims_size; ++i) { + if (updates_dims_size == 0) { + // check for 0d updates + PADDLE_ENFORCE_EQ( + index_dims_size, + 1, + phi::errors::InvalidArgument("When the updates is a 0d tensor, the " + "index should be a 1d tensor.")); PADDLE_ENFORCE_EQ( - r_updates_dims[i], - updates_dims[i], + index_dims[index_dims_size - 1], + ref_dims_size, phi::errors::InvalidArgument( - "Updates has wrong shape. The dimensions of Updates and " - "Input(Updates) should match, but received Updates's" - "%d-th dimension is %d, Input(Updates)'s %d-th " - "dimension is %d.", - i, - r_updates_dims[i], - i, - updates_dims[i])); + "When the update is a 0d tensor, The last dimension of " + "Input(Index)'s shape should be equal with the rank of Input(X).")); + } else { + PADDLE_ENFORCE_LE( + index_dims[index_dims_size - 1], + ref_dims_size, + phi::errors::InvalidArgument( + "The last dimension of Input(Index)'s shape should be no greater " + "than the rank of Input(X), but received the last dimension of " + "Input(Index)'s shape is %d, the rank of Input(X) is %d.", + index_dims[index_dims_size - 1], + ref_dims_size)); + PADDLE_ENFORCE_GE(index_dims_size, + 2UL, + phi::errors::InvalidArgument( + "The rank of Input(Index) should be greater than 1, " + "but received the rank of Input(Index) is %d.", + index_dims_size)); + + // update.shape = index.shape[:-1] + output.shape[index.shape[-1]:] + std::vector r_updates_dims; + for (int64_t i = 0; i < index_dims_size - 1; ++i) { + r_updates_dims.emplace_back(index_dims[i]); + } + for (int64_t i = index_dims[index_dims_size - 1]; i < ref_dims_size; ++i) { + r_updates_dims.emplace_back(ref_dims[i]); + } + // check for non-0d updates + PADDLE_ENFORCE_EQ( + r_updates_dims.size(), + updates_dims_size, + phi::errors::InvalidArgument( + "Updates has wrong shape. The shape of Updates and Input(Updates) " + "should be same, but received the shape of Updates is %d, " + "the shape of Input(Updates) is %d.", + r_updates_dims.size(), + updates_dims_size)); + + for (int64_t i = 0; i < updates_dims_size; ++i) { + PADDLE_ENFORCE_EQ( + r_updates_dims[i], + updates_dims[i], + phi::errors::InvalidArgument( + "Updates has wrong shape. The dimensions of Updates and " + "Input(Updates) should match, but received Updates's" + "%d-th dimension is %d, Input(Updates)'s %d-th " + "dimension is %d.", + i, + r_updates_dims[i], + i, + updates_dims[i])); + } } out->set_dims(ref_dims); out->share_lod(x); diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index eae5528fba244c5ae3539acba4383d3fcee354f9..9fb6017d446d07f9bda45a02d189c14691f05325 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -682,6 +682,36 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(x2.grad.shape, []) self.assertEqual(x3.grad.shape, []) + def test_scatter__1D(self): + x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0]) + index = paddle.full([], 2, 'int64') + updates = paddle.full([], 4.0) + out = paddle.scatter_(x, index, updates) + + self.assertEqual(out.numpy()[2], 4) + + def test_scatter__XD(self): + x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + index = paddle.full([], 1, 'int64') + updates = paddle.to_tensor([1.0, 2.0, 3.0]) + out = paddle.scatter_(x, index, updates) + + for i in range(3): + self.assertEqual(out.numpy()[1][i], updates.numpy()[i]) + + def test_scatter_nd(self): + index = paddle.to_tensor([3], dtype="int64", stop_gradient=False) + updates = paddle.full([], 2, dtype='float32') + updates.stop_gradient = False + shape = [5] + + out = paddle.scatter_nd(index, updates, shape) + out.backward() + + self.assertEqual(out.shape, [5]) + self.assertEqual(out.numpy()[3], 2) + self.assertEqual(out.grad.shape, [5]) + class TestSundryAPIStatic(unittest.TestCase): def setUp(self): @@ -845,6 +875,45 @@ class TestSundryAPIStatic(unittest.TestCase): self.assertEqual(res2.shape, (2, 2)) self.assertEqual(res3.shape, (1, 1)) + @prog_scope() + def test_scatter__1D(self): + x = paddle.full([10], 1.0, 'float32') + index = paddle.full([], 2, 'int64') + updates = paddle.full([], 4, 'float32') + out = paddle.scatter_(x, index, updates) + paddle.static.append_backward(out) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out]) + self.assertEqual(res[0][2], 4) + + @prog_scope() + def test_scatter__XD(self): + x = paddle.full([2, 3], 1.0, 'float32') + index = paddle.full([], 1, 'int64') + updates = paddle.full([3], 4, 'float32') + out = paddle.scatter_(x, index, updates) + paddle.static.append_backward(out) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out]) + for i in range(3): + self.assertEqual(res[0][1][i], 4) + + @prog_scope() + def test_scatter_nd(self): + index = paddle.static.data(name='index', shape=[1], dtype='int64') + updates = paddle.full([], 2, 'float32') + shape = [5] + index_data = np.array([3], dtype=np.longlong) + out = paddle.scatter_nd(index, updates, shape) + paddle.static.append_backward(out) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, feed={'index': index_data}, fetch_list=[out]) + self.assertEqual(res[0].shape, (5,)) + self.assertEqual(res[0][3], 2) + # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. class TestNoBackwardAPI(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py index a6f91e5df4c66ecf9e9fb535aef350b128fb3bd3..018ecc20e7dafe09b57d1068cace9d03e4c140cf 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py @@ -504,6 +504,23 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(x2.grad.shape, []) self.assertEqual(x3.grad.shape, []) + def test_scatter__1D(self): + x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0]) + index = paddle.full([], 2, 'int64') + updates = paddle.full([], 4.0) + out = paddle.scatter_(x, index, updates) + + self.assertEqual(out.numpy()[2], 4) + + def test_scatter__XD(self): + x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + index = paddle.full([], 1, 'int64') + updates = paddle.to_tensor([1.0, 2.0, 3.0]) + out = paddle.scatter_(x, index, updates) + + for i in range(3): + self.assertEqual(out.numpy()[1][i], updates.numpy()[i]) + # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. class TestNoBackwardAPI(unittest.TestCase): diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index df478ddde460a2343de312112feaec53c7cab391..c255a223683a54a7774f58bcd7793a8e6676c64c 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -3078,7 +3078,7 @@ def scatter_nd(index, updates, shape, name=None): seen :code:`scatter_nd_add` . This op is the inverse of the :code:`gather_nd` op. Args: - index (Tensor): The index input with ndim > 1 and index.shape[-1] <= len(shape). + index (Tensor): The index input with ndim >= 1 and index.shape[-1] <= len(shape). Its dtype should be int32 or int64 as it is used as indexes. updates (Tensor): The updated value of scatter_nd op. Its dtype should be float32, float64. It must have the shape index.shape[:-1] + shape[index.shape[-1]:]