提交 94206405 编写于 作者: F ForFishes

fix the doc of scatter

上级 9dbea788
......@@ -8558,7 +8558,7 @@ def scatter_nd_add(ref, index, updates, name=None):
if in_dygraph_mode():
op = getattr(core.ops, 'scatter_nd_add')
return op(ref, updates, output)
return op(ref, index, updates)
if ref.dtype != updates.dtype:
raise ValueError("ref and updates must have same data type.")
......
......@@ -18,6 +18,7 @@ import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
import paddle
def numpy_scatter_nd(ref, index, updates, fun):
......@@ -285,21 +286,20 @@ class TestScatterNdOpRaise(unittest.TestCase):
class TestDygraph(unittest.TestCase):
def test_dygraph1(self):
paddle.disable_static()
def test_dygraph(self):
with fluid.dygraph.guard(fluid.CPUPlace()):
index_data = np.array([[1, 1], [0, 1], [1, 3]]).astype(np.int64)
index = paddle.to_tensor(index_data)
index = fluid.dygraph.to_variable(index_data)
updates = paddle.rand(shape=[3, 9, 10], dtype='float32')
shape = [3, 5, 9, 10]
output = paddle.scatter_nd(index, updates, shape)
def test_dygraph2(self):
paddle.disable_static()
def test_dygraph(self):
with fluid.dygraph.guard(fluid.CPUPlace()):
x = paddle.rand(shape=[3, 5, 9, 10], dtype='float32')
updates = paddle.rand(shape=[3, 9, 10], dtype='float32')
index_data = np.array([[1, 1], [0, 1], [1, 3]]).astype(np.int64)
index = paddle.to_tensor(index_data)
index = fluid.dygraph.to_variable(index_data)
output = paddle.scatter_nd_add(x, index, updates)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册