提交 94206405 编写于 作者: F ForFishes

fix the doc of scatter

上级 9dbea788
...@@ -8558,7 +8558,7 @@ def scatter_nd_add(ref, index, updates, name=None): ...@@ -8558,7 +8558,7 @@ def scatter_nd_add(ref, index, updates, name=None):
if in_dygraph_mode(): if in_dygraph_mode():
op = getattr(core.ops, 'scatter_nd_add') op = getattr(core.ops, 'scatter_nd_add')
return op(ref, updates, output) return op(ref, index, updates)
if ref.dtype != updates.dtype: if ref.dtype != updates.dtype:
raise ValueError("ref and updates must have same data type.") raise ValueError("ref and updates must have same data type.")
......
...@@ -18,6 +18,7 @@ import unittest ...@@ -18,6 +18,7 @@ import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle
def numpy_scatter_nd(ref, index, updates, fun): def numpy_scatter_nd(ref, index, updates, fun):
...@@ -285,22 +286,21 @@ class TestScatterNdOpRaise(unittest.TestCase): ...@@ -285,22 +286,21 @@ class TestScatterNdOpRaise(unittest.TestCase):
class TestDygraph(unittest.TestCase): class TestDygraph(unittest.TestCase):
def test_dygraph1(self): def test_dygraph(self):
paddle.disable_static() with fluid.dygraph.guard(fluid.CPUPlace()):
index_data = np.array([[1, 1], [0, 1], [1, 3]]).astype(np.int64) index_data = np.array([[1, 1], [0, 1], [1, 3]]).astype(np.int64)
index = paddle.to_tensor(index_data) index = fluid.dygraph.to_variable(index_data)
updates = paddle.rand(shape=[3, 9, 10], dtype='float32') updates = paddle.rand(shape=[3, 9, 10], dtype='float32')
shape = [3, 5, 9, 10] shape = [3, 5, 9, 10]
output = paddle.scatter_nd(index, updates, shape)
output = paddle.scatter_nd(index, updates, shape)
def test_dygraph(self):
def test_dygraph2(self): with fluid.dygraph.guard(fluid.CPUPlace()):
paddle.disable_static() x = paddle.rand(shape=[3, 5, 9, 10], dtype='float32')
x = paddle.rand(shape=[3, 5, 9, 10], dtype='float32') updates = paddle.rand(shape=[3, 9, 10], dtype='float32')
updates = paddle.rand(shape=[3, 9, 10], dtype='float32') index_data = np.array([[1, 1], [0, 1], [1, 3]]).astype(np.int64)
index_data = np.array([[1, 1], [0, 1], [1, 3]]).astype(np.int64) index = fluid.dygraph.to_variable(index_data)
index = paddle.to_tensor(index_data) output = paddle.scatter_nd_add(x, index, updates)
output = paddle.scatter_nd_add(x, index, updates)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册