Unverified commit 74582aaa, authored by Yuang Liu, committed by GitHub

0d tensor for scatter_ and scatter_nd (#49072)

Parent commit: 69536892
...@@ -1039,54 +1039,69 @@ void ScatterNdAddInferMeta(const MetaTensor& x, ...@@ -1039,54 +1039,69 @@ void ScatterNdAddInferMeta(const MetaTensor& x,
const auto& updates_dims = updates.dims(); const auto& updates_dims = updates.dims();
auto updates_dims_size = updates_dims.size(); auto updates_dims_size = updates_dims.size();
PADDLE_ENFORCE_LE( if (updates_dims_size == 0) {
index_dims[index_dims_size - 1], // check for 0d updates
ref_dims_size, PADDLE_ENFORCE_EQ(
phi::errors::InvalidArgument( index_dims_size,
"The last dimension of Input(Index)'s shape should be no greater " 1,
"than the rank of Input(X), but received the last dimension of " phi::errors::InvalidArgument("When the updates is a 0d tensor, the "
"Input(Index)'s shape is %d, the rank of Input(X) is %d.", "index should be a 1d tensor."));
index_dims[index_dims_size - 1],
ref_dims_size));
PADDLE_ENFORCE_GE(index_dims_size,
2UL,
phi::errors::InvalidArgument(
"The rank of Input(Index) should be greater than 1, "
"but received the rank of Input(Index) is %d.",
index_dims_size));
// update.shape = index.shape[:-1] + output.shape[index.shape[-1]:]
std::vector<int64_t> r_updates_dims;
for (int64_t i = 0; i < index_dims_size - 1; ++i) {
r_updates_dims.emplace_back(index_dims[i]);
}
for (int64_t i = index_dims[index_dims_size - 1]; i < ref_dims_size; ++i) {
r_updates_dims.emplace_back(ref_dims[i]);
}
PADDLE_ENFORCE_EQ(
r_updates_dims.size(),
updates_dims_size,
phi::errors::InvalidArgument(
"Updates has wrong shape. The shape of Updates and Input(Updates) "
"should be same, but received the shape of Updates is %d, "
"the shape of Input(Updates) is %d.",
r_updates_dims.size(),
updates_dims_size));
for (int64_t i = 0; i < updates_dims_size; ++i) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
r_updates_dims[i], index_dims[index_dims_size - 1],
updates_dims[i], ref_dims_size,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"Updates has wrong shape. The dimensions of Updates and " "When the update is a 0d tensor, The last dimension of "
"Input(Updates) should match, but received Updates's" "Input(Index)'s shape should be equal with the rank of Input(X)."));
"%d-th dimension is %d, Input(Updates)'s %d-th " } else {
"dimension is %d.", PADDLE_ENFORCE_LE(
i, index_dims[index_dims_size - 1],
r_updates_dims[i], ref_dims_size,
i, phi::errors::InvalidArgument(
updates_dims[i])); "The last dimension of Input(Index)'s shape should be no greater "
"than the rank of Input(X), but received the last dimension of "
"Input(Index)'s shape is %d, the rank of Input(X) is %d.",
index_dims[index_dims_size - 1],
ref_dims_size));
PADDLE_ENFORCE_GE(index_dims_size,
2UL,
phi::errors::InvalidArgument(
"The rank of Input(Index) should be greater than 1, "
"but received the rank of Input(Index) is %d.",
index_dims_size));
// update.shape = index.shape[:-1] + output.shape[index.shape[-1]:]
std::vector<int64_t> r_updates_dims;
for (int64_t i = 0; i < index_dims_size - 1; ++i) {
r_updates_dims.emplace_back(index_dims[i]);
}
for (int64_t i = index_dims[index_dims_size - 1]; i < ref_dims_size; ++i) {
r_updates_dims.emplace_back(ref_dims[i]);
}
// check for non-0d updates
PADDLE_ENFORCE_EQ(
r_updates_dims.size(),
updates_dims_size,
phi::errors::InvalidArgument(
"Updates has wrong shape. The shape of Updates and Input(Updates) "
"should be same, but received the shape of Updates is %d, "
"the shape of Input(Updates) is %d.",
r_updates_dims.size(),
updates_dims_size));
for (int64_t i = 0; i < updates_dims_size; ++i) {
PADDLE_ENFORCE_EQ(
r_updates_dims[i],
updates_dims[i],
phi::errors::InvalidArgument(
"Updates has wrong shape. The dimensions of Updates and "
"Input(Updates) should match, but received Updates's"
"%d-th dimension is %d, Input(Updates)'s %d-th "
"dimension is %d.",
i,
r_updates_dims[i],
i,
updates_dims[i]));
}
} }
out->set_dims(ref_dims); out->set_dims(ref_dims);
out->share_lod(x); out->share_lod(x);
......
...@@ -682,6 +682,36 @@ class TestSundryAPI(unittest.TestCase): ...@@ -682,6 +682,36 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x2.grad.shape, []) self.assertEqual(x2.grad.shape, [])
self.assertEqual(x3.grad.shape, []) self.assertEqual(x3.grad.shape, [])
def test_scatter__1D(self):
x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0])
index = paddle.full([], 2, 'int64')
updates = paddle.full([], 4.0)
out = paddle.scatter_(x, index, updates)
self.assertEqual(out.numpy()[2], 4)
def test_scatter__XD(self):
x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
index = paddle.full([], 1, 'int64')
updates = paddle.to_tensor([1.0, 2.0, 3.0])
out = paddle.scatter_(x, index, updates)
for i in range(3):
self.assertEqual(out.numpy()[1][i], updates.numpy()[i])
def test_scatter_nd(self):
index = paddle.to_tensor([3], dtype="int64", stop_gradient=False)
updates = paddle.full([], 2, dtype='float32')
updates.stop_gradient = False
shape = [5]
out = paddle.scatter_nd(index, updates, shape)
out.backward()
self.assertEqual(out.shape, [5])
self.assertEqual(out.numpy()[3], 2)
self.assertEqual(out.grad.shape, [5])
class TestSundryAPIStatic(unittest.TestCase): class TestSundryAPIStatic(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -845,6 +875,45 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -845,6 +875,45 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res2.shape, (2, 2)) self.assertEqual(res2.shape, (2, 2))
self.assertEqual(res3.shape, (1, 1)) self.assertEqual(res3.shape, (1, 1))
@prog_scope()
def test_scatter__1D(self):
x = paddle.full([10], 1.0, 'float32')
index = paddle.full([], 2, 'int64')
updates = paddle.full([], 4, 'float32')
out = paddle.scatter_(x, index, updates)
paddle.static.append_backward(out)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out])
self.assertEqual(res[0][2], 4)
@prog_scope()
def test_scatter__XD(self):
x = paddle.full([2, 3], 1.0, 'float32')
index = paddle.full([], 1, 'int64')
updates = paddle.full([3], 4, 'float32')
out = paddle.scatter_(x, index, updates)
paddle.static.append_backward(out)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out])
for i in range(3):
self.assertEqual(res[0][1][i], 4)
@prog_scope()
def test_scatter_nd(self):
index = paddle.static.data(name='index', shape=[1], dtype='int64')
updates = paddle.full([], 2, 'float32')
shape = [5]
index_data = np.array([3], dtype=np.longlong)
out = paddle.scatter_nd(index, updates, shape)
paddle.static.append_backward(out)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, feed={'index': index_data}, fetch_list=[out])
self.assertEqual(res[0].shape, (5,))
self.assertEqual(res[0][3], 2)
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase): class TestNoBackwardAPI(unittest.TestCase):
......
...@@ -504,6 +504,23 @@ class TestSundryAPI(unittest.TestCase): ...@@ -504,6 +504,23 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x2.grad.shape, []) self.assertEqual(x2.grad.shape, [])
self.assertEqual(x3.grad.shape, []) self.assertEqual(x3.grad.shape, [])
def test_scatter__1D(self):
x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0])
index = paddle.full([], 2, 'int64')
updates = paddle.full([], 4.0)
out = paddle.scatter_(x, index, updates)
self.assertEqual(out.numpy()[2], 4)
def test_scatter__XD(self):
x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
index = paddle.full([], 1, 'int64')
updates = paddle.to_tensor([1.0, 2.0, 3.0])
out = paddle.scatter_(x, index, updates)
for i in range(3):
self.assertEqual(out.numpy()[1][i], updates.numpy()[i])
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase): class TestNoBackwardAPI(unittest.TestCase):
......
...@@ -3078,7 +3078,7 @@ def scatter_nd(index, updates, shape, name=None): ...@@ -3078,7 +3078,7 @@ def scatter_nd(index, updates, shape, name=None):
seen :code:`scatter_nd_add` . This op is the inverse of the :code:`gather_nd` op. seen :code:`scatter_nd_add` . This op is the inverse of the :code:`gather_nd` op.
Args: Args:
index (Tensor): The index input with ndim > 1 and index.shape[-1] <= len(shape). index (Tensor): The index input with ndim >= 1 and index.shape[-1] <= len(shape).
Its dtype should be int32 or int64 as it is used as indexes. Its dtype should be int32 or int64 as it is used as indexes.
updates (Tensor): The updated value of scatter_nd op. Its dtype should be float32, float64. updates (Tensor): The updated value of scatter_nd op. Its dtype should be float32, float64.
It must have the shape index.shape[:-1] + shape[index.shape[-1]:] It must have the shape index.shape[:-1] + shape[index.shape[-1]:]
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register or sign in.