提交 9dbea788 编写于 作者: F ForFishes

fix the scatternd/scatterndadd

上级 6b727e08
......@@ -8555,6 +8555,11 @@ def scatter_nd_add(ref, index, updates, name=None):
output = fluid.layers.scatter_nd_add(ref, index, updates)
"""
if in_dygraph_mode():
op = getattr(core.ops, 'scatter_nd_add')
return op(ref, updates, output)
if ref.dtype != updates.dtype:
raise ValueError("ref and updates must have same data type.")
......@@ -8577,34 +8582,38 @@ def scatter_nd(index, updates, shape, name=None):
Output is obtained by scattering the :attr:`updates` in a new tensor according
to :attr:`index` . This op is similar to :code:`scatter_nd_add`, except the
tensor of :attr:`shape` is zero-initialized. Correspondingly, :code:`scatter_nd(index, updates, shape)`
is equal to :code:`scatter_nd_add(fluid.layers.zeros(shape, updates.dtype), index, updates)` .
is equal to :code:`scatter_nd_add(paddle.zeros(shape, updates.dtype), index, updates)` .
If :attr:`index` has repeated elements, then the corresponding updates are accumulated.
Because of the numerical approximation issues, the different order of repeated elements
in :attr:`index` may cause different results. The specific calculation method can be
seen :code:`scatter_nd_add` . This op is the inverse of the :code:`gather_nd` op.
Args:
index (Variable): The index input with rank > 1 and index.shape[-1] <= len(shape).
index (Tensor): The index input with rank > 1 and index.shape[-1] <= len(shape).
Its dtype should be int32 or int64 as it is used as indexes.
updates (Variable): The updated value of scatter_nd op. Its dtype should be float32, float64.
updates (Tensor): The updated value of scatter_nd op. Its dtype should be float32, float64.
It must have the shape index.shape[:-1] + shape[index.shape[-1]:]
shape(tuple|list): Shape of output tensor.
name (str|None): The output variable name. If set None, the layer will be named automatically.
name (str|None): The output Tensor name. If set None, the layer will be named automatically.
Returns:
output (Variable): The output is a tensor with the same type as :attr:`updates` .
output (Tensor): The output is a tensor with the same type as :attr:`updates` .
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle
import numpy as np
index = fluid.data(name='index', shape=[3, 2], dtype='int64')
updates = fluid.data(name='update', shape=[3, 9, 10], dtype='float32')
index_data = np.array([[1, 1],
[0, 1],
[1, 3]]).astype(np.int64)
index = paddle.to_tensor(index_data)
updates = paddle.rand(shape=[3, 9, 10], dtype='float32')
shape = [3, 5, 9, 10]
output = fluid.layers.scatter_nd(index, updates, shape)
output = paddle.scatter_nd(index, updates, shape)
"""
return scatter_nd_add(zeros(shape, updates.dtype), index, updates, name)
......
......@@ -284,5 +284,24 @@ class TestScatterNdOpRaise(unittest.TestCase):
self.assertRaises(ValueError, check_raise_is_test)
class TestDygraph(unittest.TestCase):
def test_dygraph1(self):
paddle.disable_static()
index_data = np.array([[1, 1], [0, 1], [1, 3]]).astype(np.int64)
index = paddle.to_tensor(index_data)
updates = paddle.rand(shape=[3, 9, 10], dtype='float32')
shape = [3, 5, 9, 10]
output = paddle.scatter_nd(index, updates, shape)
def test_dygraph2(self):
paddle.disable_static()
x = paddle.rand(shape=[3, 5, 9, 10], dtype='float32')
updates = paddle.rand(shape=[3, 9, 10], dtype='float32')
index_data = np.array([[1, 1], [0, 1], [1, 3]]).astype(np.int64)
index = paddle.to_tensor(index_data)
output = paddle.scatter_nd_add(x, index, updates)
if __name__ == "__main__":
unittest.main()
......@@ -29,7 +29,6 @@ from ..fluid.layers import strided_slice #DEFINE_ALIAS
from ..fluid.layers import transpose #DEFINE_ALIAS
from ..fluid.layers import unstack #DEFINE_ALIAS
from ..fluid.layers import scatter_nd_add #DEFINE_ALIAS
from ..fluid.layers import scatter_nd #DEFINE_ALIAS
from ..fluid.layers import shard_index #DEFINE_ALIAS
from ..fluid.layers import unique_with_counts #DEFINE_ALIAS
......@@ -974,6 +973,78 @@ def scatter(x, index, updates, overwrite=True, name=None):
return out
def scatter_nd_add(x, index, updates, name=None):
"""
**Scatter_nd_add Layer**
Output is obtained by applying sparse addition to a single value
or slice in a Tensor.
:attr:`x` is a Tensor with rank :math:`R`
and :attr:`index` is a Tensor with rank :math:`K` . Thus, :attr:`index`
has shape :math:`[i_0, i_1, ..., i_{K-2}, Q]` where :math:`Q \leq R` . :attr:`updates`
is a Tensor with rank :math:`K - 1 + R - Q` and its
shape is :math:`index.shape[:-1] + x.shape[index.shape[-1]:]` .
According to the :math:`[i_0, i_1, ..., i_{K-2}]` of :attr:`index` ,
add the corresponding :attr:`updates` slice to the :attr:`x` slice
which is obtained by the last one dimension of :attr:`index` .
.. code-block:: text
Given:
* Case 1:
x = [0, 1, 2, 3, 4, 5]
index = [[1], [2], [3], [1]]
updates = [9, 10, 11, 12]
we get:
output = [0, 22, 12, 14, 4, 5]
* Case 2:
x = [[65, 17], [-14, -25]]
index = [[], []]
updates = [[[-1, -2], [1, 2]],
[[3, 4], [-3, -4]]]
x.shape = (2, 2)
index.shape = (2, 0)
updates.shape = (2, 2, 2)
we get:
output = [[67, 19], [-16, -27]]
Args:
x (Tensor): The x input. Its dtype should be float32, float64.
index (Tensor): The index input with rank > 1 and index.shape[-1] <= x.rank.
Its dtype should be int32 or int64 as it is used as indexes.
updates (Tensor): The updated value of scatter_nd_add op, and it must have the same dtype
as x. It must have the shape index.shape[:-1] + x.shape[index.shape[-1]:].
name (str|None): The output tensor name. If set None, the layer will be named automatically.
Returns:
output (Tensor): The output is a tensor with the same shape and dtype as x.
Examples:
.. code-block:: python
import paddle
import numpy as np
x = paddle.rand(shape=[3, 5, 9, 10], dtype='float32')
updates = paddle.rand(shape=[3, 9, 10], dtype='float32')
index_data = np.array([[1, 1],
[0, 1],
[1, 3]]).astype(np.int64)
index = paddle.to_tensor(index_data)
output = paddle.scatter_nd_add(x, index, updates)
"""
return layers.scatter_nd_add(x, index, updates, name=None)
def chunk(x, chunks, axis=0, name=None):
"""
Split the input tensor into multiple sub-Tensors.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册