未验证 提交 02dfd18d 编写于 作者: V Vvsmile 提交者: GitHub

replace scatter_nd and scatter_nd_add with paddle.scatter_nd and (#47960)

paddle.scatter_nd_add
上级 87388d59
...@@ -117,8 +117,6 @@ __all__ = [ ...@@ -117,8 +117,6 @@ __all__ = [
'resize_nearest', 'resize_nearest',
'gather_nd', 'gather_nd',
'scatter', 'scatter',
'scatter_nd_add',
'scatter_nd',
'random_crop', 'random_crop',
'mean_iou', 'mean_iou',
'relu', 'relu',
...@@ -8695,138 +8693,6 @@ def scatter(input, index, updates, name=None, overwrite=True): ...@@ -8695,138 +8693,6 @@ def scatter(input, index, updates, name=None, overwrite=True):
return out return out
def scatter_nd_add(ref, index, updates, name=None):
r"""
**Scatter_nd_add Layer**
Output is obtained by applying sparse addition to a single value
or slice in a Variable.
:attr:`ref` is a Tensor with rank :math:`R`
and :attr:`index` is a Tensor with rank :math:`K` . Thus, :attr:`index`
has shape :math:`[i_0, i_1, ..., i_{K-2}, Q]` where :math:`Q \leq R` . :attr:`updates`
is a Tensor with rank :math:`K - 1 + R - Q` and its
shape is :math:`index.shape[:-1] + ref.shape[index.shape[-1]:]` .
According to the :math:`[i_0, i_1, ..., i_{K-2}]` of :attr:`index` ,
add the corresponding :attr:`updates` slice to the :attr:`ref` slice
which is obtained by the last one dimension of :attr:`index` .
.. code-block:: text
Given:
* Case 1:
ref = [0, 1, 2, 3, 4, 5]
index = [[1], [2], [3], [1]]
updates = [9, 10, 11, 12]
we get:
output = [0, 22, 12, 14, 4, 5]
* Case 2:
ref = [[65, 17], [-14, -25]]
index = [[], []]
updates = [[[-1, -2], [1, 2]],
[[3, 4], [-3, -4]]]
ref.shape = (2, 2)
index.shape = (2, 0)
updates.shape = (2, 2, 2)
we get:
output = [[67, 19], [-16, -27]]
Args:
ref (Variable): The ref input. Its dtype should be int32, int64, float32, float64.
index (Variable): The index input with rank > 1 and index.shape[-1] <= ref.rank.
Its dtype should be int32 or int64 as it is used as indexes.
updates (Variable): The updated value of scatter_nd_add op, and it must have the same dtype
as ref. It must have the shape index.shape[:-1] + ref.shape[index.shape[-1]:].
name (str|None): The output variable name. If set None, the layer will be named automatically.
Returns:
output (Variable): The output is a tensor with the same shape and dtype as ref.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle
paddle.enable_static()
ref = fluid.data(name='ref', shape=[3, 5, 9, 10], dtype='float32')
index = fluid.data(name='index', shape=[3, 2], dtype='int32')
updates = fluid.data(name='update', shape=[3, 9, 10], dtype='float32')
output = fluid.layers.scatter_nd_add(ref, index, updates)
"""
if in_dygraph_mode():
return _C_ops.scatter_nd_add(ref, index, updates)
else:
if _in_legacy_dygraph():
op = getattr(_legacy_C_ops, 'scatter_nd_add')
return op(ref, index, updates)
else:
if ref.dtype != updates.dtype:
raise ValueError("ref and updates must have same data type.")
helper = LayerHelper('scatter_nd_add', **locals())
dtype = helper.input_dtype(input_param_name='ref')
output = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="scatter_nd_add",
inputs={"X": ref, "Index": index, "Updates": updates},
outputs={"Out": output},
)
return output
def scatter_nd(index, updates, shape, name=None):
"""
**Scatter_nd Layer**
Output is obtained by scattering the :attr:`updates` in a new tensor according
to :attr:`index` . This op is similar to :code:`scatter_nd_add`, except the
tensor of :attr:`shape` is zero-initialized. Correspondingly, :code:`scatter_nd(index, updates, shape)`
is equal to :code:`scatter_nd_add(paddle.zeros(shape, updates.dtype), index, updates)` .
If :attr:`index` has repeated elements, then the corresponding updates are accumulated.
Because of the numerical approximation issues, the different order of repeated elements
in :attr:`index` may cause different results. The specific calculation method can be
seen :code:`scatter_nd_add` . This op is the inverse of the :code:`gather_nd` op.
Args:
index (Tensor): The index input with ndim > 1 and index.shape[-1] <= len(shape).
Its dtype should be int32 or int64 as it is used as indexes.
updates (Tensor): The updated value of scatter_nd op. Its dtype should be float32, float64.
It must have the shape index.shape[:-1] + shape[index.shape[-1]:]
shape(tuple|list): Shape of output tensor.
name (str|None): The output Tensor name. If set None, the layer will be named automatically.
Returns:
output (Tensor): The output is a tensor with the same type as :attr:`updates` .
Examples:
.. code-block:: python
import paddle
import numpy as np
index_data = np.array([[1, 1],
[0, 1],
[1, 3]]).astype(np.int64)
index = paddle.to_tensor(index_data)
updates = paddle.rand(shape=[3, 9, 10], dtype='float32')
shape = [3, 5, 9, 10]
output = paddle.scatter_nd(index, updates, shape)
"""
return scatter_nd_add(zeros(shape, updates.dtype), index, updates, name)
@templatedoc() @templatedoc()
def random_crop(x, shape, seed=None): def random_crop(x, shape, seed=None):
""" """
......
...@@ -183,7 +183,7 @@ class TestScatterNdOpAPI(unittest.TestCase): ...@@ -183,7 +183,7 @@ class TestScatterNdOpAPI(unittest.TestCase):
dtype='float32', dtype='float32',
append_batch_size=False, append_batch_size=False,
) )
output1 = fluid.layers.scatter_nd_add(ref1, index1, updates1) output1 = paddle.scatter_nd_add(ref1, index1, updates1)
def testcase2(self): def testcase2(self):
ref2 = fluid.layers.data( ref2 = fluid.layers.data(
...@@ -204,7 +204,7 @@ class TestScatterNdOpAPI(unittest.TestCase): ...@@ -204,7 +204,7 @@ class TestScatterNdOpAPI(unittest.TestCase):
dtype='double', dtype='double',
append_batch_size=False, append_batch_size=False,
) )
output2 = fluid.layers.scatter_nd_add( output2 = paddle.scatter_nd_add(
ref2, index2, updates2, name="scatter_nd_add" ref2, index2, updates2, name="scatter_nd_add"
) )
...@@ -222,7 +222,7 @@ class TestScatterNdOpAPI(unittest.TestCase): ...@@ -222,7 +222,7 @@ class TestScatterNdOpAPI(unittest.TestCase):
dtype='float32', dtype='float32',
append_batch_size=False, append_batch_size=False,
) )
output3 = fluid.layers.scatter_nd(index3, updates3, shape3) output3 = paddle.scatter_nd(index3, updates3, shape3)
def testcase4(self): def testcase4(self):
shape4 = [10, 9, 8, 1, 3] shape4 = [10, 9, 8, 1, 3]
...@@ -238,9 +238,7 @@ class TestScatterNdOpAPI(unittest.TestCase): ...@@ -238,9 +238,7 @@ class TestScatterNdOpAPI(unittest.TestCase):
dtype='double', dtype='double',
append_batch_size=False, append_batch_size=False,
) )
output4 = fluid.layers.scatter_nd( output4 = paddle.scatter_nd(index4, updates4, shape4, name='scatter_nd')
index4, updates4, shape4, name='scatter_nd'
)
def testcase5(self): def testcase5(self):
if not fluid.core.is_compiled_with_cuda(): if not fluid.core.is_compiled_with_cuda():
...@@ -307,7 +305,7 @@ class TestScatterNdOpRaise(unittest.TestCase): ...@@ -307,7 +305,7 @@ class TestScatterNdOpRaise(unittest.TestCase):
updates5 = fluid.layers.data( updates5 = fluid.layers.data(
name='updates5', shape=[2, 10], dtype='float32' name='updates5', shape=[2, 10], dtype='float32'
) )
output5 = fluid.layers.scatter_nd_add(ref5, index5, updates5) output5 = paddle.scatter_nd_add(ref5, index5, updates5)
except Exception as e: except Exception as e:
t = "The last dimension of Input(Index)'s shape should be no greater " t = "The last dimension of Input(Index)'s shape should be no greater "
if t in str(e): if t in str(e):
...@@ -335,7 +333,7 @@ class TestScatterNdOpRaise(unittest.TestCase): ...@@ -335,7 +333,7 @@ class TestScatterNdOpRaise(unittest.TestCase):
dtype='float32', dtype='float32',
append_batch_size=False, append_batch_size=False,
) )
output6 = fluid.layers.scatter_nd_add(ref6, index6, updates6) output6 = paddle.scatter_nd_add(ref6, index6, updates6)
def test_check_raise3(self): def test_check_raise3(self):
def check_raise_is_test(): def check_raise_is_test():
...@@ -347,7 +345,7 @@ class TestScatterNdOpRaise(unittest.TestCase): ...@@ -347,7 +345,7 @@ class TestScatterNdOpRaise(unittest.TestCase):
updates7 = fluid.layers.data( updates7 = fluid.layers.data(
name='updates7', shape=[2, 4, 5, 20], dtype='float32' name='updates7', shape=[2, 4, 5, 20], dtype='float32'
) )
output7 = fluid.layers.scatter_nd(index7, updates7, shape) output7 = paddle.scatter_nd(index7, updates7, shape)
except Exception as e: except Exception as e:
t = "Updates has wrong shape" t = "Updates has wrong shape"
if t in str(e): if t in str(e):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册