未验证 提交 190c6957 编写于 作者: S ShenLiang 提交者: GitHub

fix scatter doc (#26248)

* fix the comment of scatter
上级 aa2a9b5d
...@@ -16,6 +16,8 @@ from __future__ import print_function ...@@ -16,6 +16,8 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle
import paddle.fluid as fluid
from op_test import OpTest from op_test import OpTest
import paddle.fluid.core as core import paddle.fluid.core as core
...@@ -173,5 +175,55 @@ class TestScatterOp5(OpTest): ...@@ -173,5 +175,55 @@ class TestScatterOp5(OpTest):
self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True) self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
class TestScatterAPI(unittest.TestCase):
def setUp(self):
self.places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
self.places.append(fluid.CUDAPlace(0))
def check_static_result(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input = fluid.data(name="input", shape=[3, 2], dtype="float64")
index = fluid.data(name="index", shape=[4], dtype="int64")
updates = fluid.data(name="updates", shape=[4, 2], dtype="float64")
result = paddle.scatter(input, index, updates, False)
input_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64)
index_data = np.array([2, 1, 0, 1]).astype(np.int64)
updates_data = np.array(
[[1, 1], [2, 2], [3, 3], [4, 4]]).astype(np.float64)
exe = fluid.Executor(place)
fetches = exe.run(fluid.default_main_program(),
feed={
"input": input_data,
"index": index_data,
"updates": updates_data
},
fetch_list=[result])
self.assertEqual((fetches[0] == \
np.array([[3., 3.],[6., 6.],[1., 1.]])).all(), True)
def test_static(self):
for place in self.places:
self.check_static_result(place=place)
def test_dygraph(self):
for place in self.places:
with fluid.dygraph.guard(place):
x_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64)
index_data = np.array([2, 1, 0, 1]).astype(np.int64)
updates_data = np.array(
[[1, 1], [2, 2], [3, 3], [4, 4]]).astype(np.float64)
x = fluid.dygraph.to_variable(x_data)
index = fluid.dygraph.to_variable(index_data)
updates = fluid.dygraph.to_variable(updates_data)
output1 = paddle.scatter(x, index, updates, overwrite=False)
self.assertEqual((output1.numpy() == \
np.array([[3., 3.],[6., 6.],[1., 1.]])).all(), True)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -24,7 +24,6 @@ import numpy as np ...@@ -24,7 +24,6 @@ import numpy as np
# TODO: define functions to manipulate a tensor # TODO: define functions to manipulate a tensor
from ..fluid.layers import cast #DEFINE_ALIAS from ..fluid.layers import cast #DEFINE_ALIAS
from ..fluid.layers import expand_as #DEFINE_ALIAS from ..fluid.layers import expand_as #DEFINE_ALIAS
from ..fluid.layers import scatter #DEFINE_ALIAS
from ..fluid.layers import slice #DEFINE_ALIAS from ..fluid.layers import slice #DEFINE_ALIAS
from ..fluid.layers import strided_slice #DEFINE_ALIAS from ..fluid.layers import strided_slice #DEFINE_ALIAS
from ..fluid.layers import transpose #DEFINE_ALIAS from ..fluid.layers import transpose #DEFINE_ALIAS
...@@ -790,6 +789,100 @@ def unbind(input, axis=0): ...@@ -790,6 +789,100 @@ def unbind(input, axis=0):
return outs return outs
def scatter(x, index, updates, overwrite=True, name=None):
"""
**Scatter Layer**
Output is obtained by updating the input on selected indices based on updates.
.. code-block:: python
import numpy as np
#input:
x = np.array([[1, 1], [2, 2], [3, 3]])
index = np.array([2, 1, 0, 1])
# shape of updates should be the same as x
# shape of updates with dim > 1 should be the same as input
updates = np.array([[1, 1], [2, 2], [3, 3], [4, 4]])
overwrite = False
# calculation:
if not overwrite:
for i in range(len(index)):
x[index[i]] = np.zeros((2))
for i in range(len(index)):
if (overwrite):
x[index[i]] = updates[i]
else:
x[index[i]] += updates[i]
# output:
out = np.array([[3, 3], [6, 6], [1, 1]])
out.shape # [3, 2]
**NOTICE**: The order in which updates are applied is nondeterministic,
so the output will be nondeterministic if index contains duplicates.
Args:
x (Tensor): The input N-D Tensor with ndim>=1. Data type can be float32, float64.
index (Tensor): The index 1-D Tensor. Data type can be int32, int64. The length of index cannot exceed updates's length, and the value in index cannot exceed input's length.
updates (Tensor): update input with updates parameter based on index. shape should be the same as input, and dim value with dim > 1 should be the same as input.
overwrite (bool): The mode that updating the output when there are same indices.
If True, use the overwrite mode to update the output of the same index,
if False, use the accumulate mode to update the output of the same index.Default value is True.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` .
Returns:
Tensor: The output is a Tensor with the same shape as x.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
x_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float32)
index_data = np.array([2, 1, 0, 1]).astype(np.int64)
updates_data = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype(np.float32)
x = paddle.to_tensor(x_data)
index = paddle.to_tensor(index_data)
updates = paddle.to_tensor(updates_data)
output1 = paddle.scatter(x, index, updates, overwrite=False)
# [[3., 3.],
# [6., 6.],
# [1., 1.]]
output2 = paddle.scatter(x, index, updates, overwrite=True)
# CPU device:
# [[3., 3.],
# [4., 4.],
# [1., 1.]]
# GPU device maybe have two results because of the repeated numbers in index
# result 1:
# [[3., 3.],
# [4., 4.],
# [1., 1.]]
# result 2:
# [[3., 3.],
# [2., 2.],
# [1., 1.]]
"""
if in_dygraph_mode():
return core.ops.scatter(x, index, updates, 'overwrite', overwrite)
check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'scatter')
check_type(overwrite, 'overwrite', bool, 'scatter')
helper = LayerHelper('scatter', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type="scatter",
inputs={"X": x,
"Ids": index,
"Updates": updates},
attrs={'overwrite': overwrite},
outputs={"Out": out})
return out
def chunk(x, chunks, axis=0, name=None): def chunk(x, chunks, axis=0, name=None):
""" """
Split the input tensor into multiple sub-Tensors. Split the input tensor into multiple sub-Tensors.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册