未验证 提交 55c4eb8a 编写于 作者: M mhy-666 提交者: GitHub

【prim】scatter_nd_add_grad (#52469)

* add scatter_nd_add comp

* add scatter_nd_add prim

* fix

* fix

* add public_python_api in TestScatterNdAddSimpleOp setup function

* fix composite_backward_api.h

* fix composite_backward

* add test cases

* fix composite_backward_api.h, unittest
上级 1164626c
......@@ -1805,5 +1805,21 @@ void roll_grad(const Tensor& x,
set_output<T>(x_grad_output, x_grad);
}
}
template <typename T>
void scatter_nd_add_grad(const Tensor& index,
const Tensor& updates,
const Tensor& out_grad,
Tensor* x_grad,
Tensor* updates_grad) {
if (x_grad) {
by_pass<T>(out_grad, x_grad);
}
if (updates_grad) {
// Gradient by Gather: dUpdates = dO[Ids]
auto tmp_updates_grad = gather_nd<T>(out_grad, index);
set_output<T>(tmp_updates_grad, updates_grad);
}
}
} // namespace prim
} // namespace paddle
......@@ -1454,6 +1454,7 @@
kernel :
func : scatter_nd_add_grad
no_need_buffer : updates
composite: scatter_nd_add_grad(index, updates, out_grad, x_grad, updates_grad)
- backward_op : segment_pool_grad
forward : segment_pool (Tensor x, Tensor segment_ids, str pooltype="SUM") -> Tensor(out), Tensor(summed_ids)
......
......@@ -69,6 +69,8 @@ class TestScatterNdAddSimpleOp(OpTest):
def setUp(self):
self.op_type = "scatter_nd_add"
self.python_api = paddle.scatter_nd_add
self.public_python_api = paddle.scatter_nd_add
self.prim_op_type = "prim"
self._set_dtype()
if self.dtype == np.float64:
target_dtype = "float64"
......@@ -94,7 +96,7 @@ class TestScatterNdAddSimpleOp(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out')
self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
class TestScatterNdAddSimpleFP16Op(TestScatterNdAddSimpleOp):
......@@ -127,7 +129,9 @@ class TestScatterNdAddSimpleBF16Op(TestScatterNdAddSimpleOp):
def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
self.check_grad_with_place(
place, ['X', 'Updates'], 'Out', check_prim=True
)
class TestScatterNdAddWithEmptyIndex(OpTest):
......@@ -138,6 +142,8 @@ class TestScatterNdAddWithEmptyIndex(OpTest):
def setUp(self):
self.op_type = "scatter_nd_add"
self.python_api = paddle.scatter_nd_add
self.public_python_api = paddle.scatter_nd_add
self.prim_op_type = "prim"
self._set_dtype()
if self.dtype == np.float64:
target_dtype = "float64"
......@@ -166,7 +172,7 @@ class TestScatterNdAddWithEmptyIndex(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out')
self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
class TestScatterNdAddWithEmptyIndexFP16(TestScatterNdAddWithEmptyIndex):
......@@ -199,7 +205,9 @@ class TestScatterNdAddWithEmptyIndexBF16(TestScatterNdAddWithEmptyIndex):
def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
self.check_grad_with_place(
place, ['X', 'Updates'], 'Out', check_prim=True
)
class TestScatterNdAddWithHighRankSame(OpTest):
......@@ -210,6 +218,8 @@ class TestScatterNdAddWithHighRankSame(OpTest):
def setUp(self):
self.op_type = "scatter_nd_add"
self.python_api = paddle.scatter_nd_add
self.public_python_api = paddle.scatter_nd_add
self.prim_op_type = "prim"
self._set_dtype()
if self.dtype == np.float64:
target_dtype = "float64"
......@@ -241,7 +251,7 @@ class TestScatterNdAddWithHighRankSame(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out')
self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
class TestScatterNdAddWithHighRankSameFP16(TestScatterNdAddWithHighRankSame):
......@@ -274,7 +284,9 @@ class TestScatterNdAddWithHighRankSameBF16(TestScatterNdAddWithHighRankSame):
def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
self.check_grad_with_place(
place, ['X', 'Updates'], 'Out', check_prim=True
)
class TestScatterNdAddWithHighRankDiff(OpTest):
......@@ -285,6 +297,8 @@ class TestScatterNdAddWithHighRankDiff(OpTest):
def setUp(self):
self.op_type = "scatter_nd_add"
self.python_api = paddle.scatter_nd_add
self.public_python_api = paddle.scatter_nd_add
self.prim_op_type = "prim"
shape = (8, 2, 2, 1, 10)
ref_np = np.random.rand(*shape).astype("double")
index = np.vstack([np.random.randint(0, s, size=500) for s in shape]).T
......@@ -300,7 +314,7 @@ class TestScatterNdAddWithHighRankDiff(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out')
self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
# Test Python API
......@@ -310,70 +324,76 @@ class TestScatterNdOpAPI(unittest.TestCase):
"""
def testcase1(self):
ref1 = paddle.static.data(
name='ref1',
shape=[10, 9, 8, 1, 3],
dtype='float32',
)
index1 = paddle.static.data(
name='index1',
shape=[5, 5, 8, 5],
dtype='int32',
)
updates1 = paddle.static.data(
name='update1',
shape=[5, 5, 8],
dtype='float32',
)
output1 = paddle.scatter_nd_add(ref1, index1, updates1)
with paddle.fluid.framework._static_guard():
ref1 = paddle.static.data(
name='ref1',
shape=[10, 9, 8, 1, 3],
dtype='float32',
)
index1 = paddle.static.data(
name='index1',
shape=[5, 5, 8, 5],
dtype='int32',
)
updates1 = paddle.static.data(
name='update1',
shape=[5, 5, 8],
dtype='float32',
)
output1 = paddle.scatter_nd_add(ref1, index1, updates1)
def testcase2(self):
ref2 = paddle.static.data(
name='ref2',
shape=[10, 9, 8, 1, 3],
dtype='double',
)
index2 = paddle.static.data(
name='index2',
shape=[5, 8, 5],
dtype='int32',
)
updates2 = paddle.static.data(
name='update2',
shape=[5, 8],
dtype='double',
)
output2 = paddle.scatter_nd_add(
ref2, index2, updates2, name="scatter_nd_add"
)
with paddle.fluid.framework._static_guard():
ref2 = paddle.static.data(
name='ref2',
shape=[10, 9, 8, 1, 3],
dtype='double',
)
index2 = paddle.static.data(
name='index2',
shape=[5, 8, 5],
dtype='int32',
)
updates2 = paddle.static.data(
name='update2',
shape=[5, 8],
dtype='double',
)
output2 = paddle.scatter_nd_add(
ref2, index2, updates2, name="scatter_nd_add"
)
def testcase3(self):
shape3 = [10, 9, 8, 1, 3]
index3 = paddle.static.data(
name='index3',
shape=[5, 5, 8, 5],
dtype='int32',
)
updates3 = paddle.static.data(
name='update3',
shape=[5, 5, 8],
dtype='float32',
)
output3 = paddle.scatter_nd(index3, updates3, shape3)
with paddle.fluid.framework._static_guard():
shape3 = [10, 9, 8, 1, 3]
index3 = paddle.static.data(
name='index3',
shape=[5, 5, 8, 5],
dtype='int32',
)
updates3 = paddle.static.data(
name='update3',
shape=[5, 5, 8],
dtype='float32',
)
output3 = paddle.scatter_nd(index3, updates3, shape3)
def testcase4(self):
shape4 = [10, 9, 8, 1, 3]
index4 = paddle.static.data(
name='index4',
shape=[5, 5, 8, 5],
dtype='int32',
)
updates4 = paddle.static.data(
name='update4',
shape=[5, 5, 8],
dtype='double',
)
output4 = paddle.scatter_nd(index4, updates4, shape4, name='scatter_nd')
with paddle.fluid.framework._static_guard():
shape4 = [10, 9, 8, 1, 3]
index4 = paddle.static.data(
name='index4',
shape=[5, 5, 8, 5],
dtype='int32',
)
updates4 = paddle.static.data(
name='update4',
shape=[5, 5, 8],
dtype='double',
)
output4 = paddle.scatter_nd(
index4, updates4, shape4, name='scatter_nd'
)
def testcase5(self):
if not fluid.core.is_compiled_with_cuda():
......@@ -430,60 +450,65 @@ class TestScatterNdOpAPI(unittest.TestCase):
class TestScatterNdOpRaise(unittest.TestCase):
def test_check_raise(self):
def check_raise_is_test():
try:
ref5 = paddle.static.data(
name='ref5', shape=[-1, 3, 4, 5], dtype='float32'
)
index5 = paddle.static.data(
name='index5', shape=[-1, 2, 10], dtype='int32'
)
updates5 = paddle.static.data(
name='updates5', shape=[-1, 2, 10], dtype='float32'
)
output5 = paddle.scatter_nd_add(ref5, index5, updates5)
except Exception as e:
t = "The last dimension of Input(Index)'s shape should be no greater "
if t in str(e):
raise IndexError
with paddle.fluid.framework._static_guard():
try:
ref5 = paddle.static.data(
name='ref5', shape=[-1, 3, 4, 5], dtype='float32'
)
index5 = paddle.static.data(
name='index5', shape=[-1, 2, 10], dtype='int32'
)
updates5 = paddle.static.data(
name='updates5', shape=[-1, 2, 10], dtype='float32'
)
output5 = paddle.scatter_nd_add(ref5, index5, updates5)
except Exception as e:
t = "The last dimension of Input(Index)'s shape should be no greater "
if t in str(e):
raise IndexError
self.assertRaises(IndexError, check_raise_is_test)
def test_check_raise2(self):
with self.assertRaises(ValueError):
ref6 = paddle.static.data(
name='ref6',
shape=[10, 9, 8, 1, 3],
dtype='double',
)
index6 = paddle.static.data(
name='index6',
shape=[5, 8, 5],
dtype='int32',
)
updates6 = paddle.static.data(
name='update6',
shape=[5, 8],
dtype='float32',
)
output6 = paddle.scatter_nd_add(ref6, index6, updates6)
def test_check_raise3(self):
def check_raise_is_test():
try:
shape = [3, 4, 5]
index7 = paddle.static.data(
name='index7', shape=[-1, 2, 1], dtype='int32'
with paddle.fluid.framework._static_guard():
ref6 = paddle.static.data(
name='ref6',
shape=[10, 9, 8, 1, 3],
dtype='double',
)
index6 = paddle.static.data(
name='index6',
shape=[5, 8, 5],
dtype='int32',
)
updates7 = paddle.static.data(
name='updates7', shape=[-1, 2, 4, 5, 20], dtype='float32'
updates6 = paddle.static.data(
name='update6',
shape=[5, 8],
dtype='float32',
)
output7 = paddle.scatter_nd(index7, updates7, shape)
except Exception as e:
t = "Updates has wrong shape"
if t in str(e):
raise ValueError
output6 = paddle.scatter_nd_add(ref6, index6, updates6)
self.assertRaises(ValueError, check_raise_is_test)
def test_check_raise3(self):
def check_raise_is_test():
with paddle.fluid.framework._static_guard():
try:
shape = [3, 4, 5]
index7 = paddle.static.data(
name='index7', shape=[-1, 2, 1], dtype='int32'
)
updates7 = paddle.static.data(
name='updates7',
shape=[-1, 2, 4, 5, 20],
dtype='float32',
)
output7 = paddle.scatter_nd(index7, updates7, shape)
except Exception as e:
t = "Updates has wrong shape"
if t in str(e):
raise ValueError
self.assertRaises(ValueError, check_raise_is_test)
class TestDygraph(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册