From 94206405979a7190b992a95345ca6c79aeec22b7 Mon Sep 17 00:00:00 2001 From: ForFishes <2282912238@qq.com> Date: Sun, 27 Sep 2020 21:47:52 +0800 Subject: [PATCH] fix the doc of scatter --- python/paddle/fluid/layers/nn.py | 2 +- .../tests/unittests/test_scatter_nd_op.py | 32 +++++++++---------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 7bde1d10be4..0990a7a5e20 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -8558,7 +8558,7 @@ def scatter_nd_add(ref, index, updates, name=None): if in_dygraph_mode(): op = getattr(core.ops, 'scatter_nd_add') - return op(ref, updates, output) + return op(ref, index, updates) if ref.dtype != updates.dtype: raise ValueError("ref and updates must have same data type.") diff --git a/python/paddle/fluid/tests/unittests/test_scatter_nd_op.py b/python/paddle/fluid/tests/unittests/test_scatter_nd_op.py index 09b87fd8a5a..90aae939a61 100644 --- a/python/paddle/fluid/tests/unittests/test_scatter_nd_op.py +++ b/python/paddle/fluid/tests/unittests/test_scatter_nd_op.py @@ -18,6 +18,7 @@ import unittest import numpy as np from op_test import OpTest import paddle.fluid as fluid +import paddle def numpy_scatter_nd(ref, index, updates, fun): @@ -285,22 +286,21 @@ class TestScatterNdOpRaise(unittest.TestCase): class TestDygraph(unittest.TestCase): - def test_dygraph1(self): - paddle.disable_static() - index_data = np.array([[1, 1], [0, 1], [1, 3]]).astype(np.int64) - index = paddle.to_tensor(index_data) - updates = paddle.rand(shape=[3, 9, 10], dtype='float32') - shape = [3, 5, 9, 10] - - output = paddle.scatter_nd(index, updates, shape) - - def test_dygraph2(self): - paddle.disable_static() - x = paddle.rand(shape=[3, 5, 9, 10], dtype='float32') - updates = paddle.rand(shape=[3, 9, 10], dtype='float32') - index_data = np.array([[1, 1], [0, 1], [1, 3]]).astype(np.int64) - index = paddle.to_tensor(index_data) - output = paddle.scatter_nd_add(x, index, updates) + def test_dygraph(self): + with fluid.dygraph.guard(fluid.CPUPlace()): + index_data = np.array([[1, 1], [0, 1], [1, 3]]).astype(np.int64) + index = fluid.dygraph.to_variable(index_data) + updates = paddle.rand(shape=[3, 9, 10], dtype='float32') + shape = [3, 5, 9, 10] + output = paddle.scatter_nd(index, updates, shape) + + def test_dygraph(self): + with fluid.dygraph.guard(fluid.CPUPlace()): + x = paddle.rand(shape=[3, 5, 9, 10], dtype='float32') + updates = paddle.rand(shape=[3, 9, 10], dtype='float32') + index_data = np.array([[1, 1], [0, 1], [1, 3]]).astype(np.int64) + index = fluid.dygraph.to_variable(index_data) + output = paddle.scatter_nd_add(x, index, updates) if __name__ == "__main__": -- GitLab