未验证 提交 55c4eb8a 编写于 作者: M mhy-666 提交者: GitHub

【prim】scatter_nd_add_grad (#52469)

* add scatter_nd_add comp

* add scatter_nd_add prim

* fix

* fix

* add public_python_api in TestScatterNdAddSimpleOp setup function

* fix composite_backward_api.h

* fix composite_backward

* add test cases

* fix composite_backward_api.h, unittest
上级 1164626c
...@@ -1805,5 +1805,21 @@ void roll_grad(const Tensor& x, ...@@ -1805,5 +1805,21 @@ void roll_grad(const Tensor& x,
set_output<T>(x_grad_output, x_grad); set_output<T>(x_grad_output, x_grad);
} }
} }
template <typename T>
void scatter_nd_add_grad(const Tensor& index,
                         const Tensor& updates,
                         const Tensor& out_grad,
                         Tensor* x_grad,
                         Tensor* updates_grad) {
  // Composite backward for scatter_nd_add.
  // Since out = x + scatter(updates, index), the gradient w.r.t. x is the
  // output gradient passed through unchanged: dX = dOut.
  if (x_grad != nullptr) {
    by_pass<T>(out_grad, x_grad);
  }
  // Gradient w.r.t. updates is gathered from dOut at the scattered
  // positions: dUpdates = gather_nd(dOut, index).
  if (updates_grad != nullptr) {
    set_output<T>(gather_nd<T>(out_grad, index), updates_grad);
  }
}
} // namespace prim } // namespace prim
} // namespace paddle } // namespace paddle
...@@ -1454,6 +1454,7 @@ ...@@ -1454,6 +1454,7 @@
kernel : kernel :
func : scatter_nd_add_grad func : scatter_nd_add_grad
no_need_buffer : updates no_need_buffer : updates
composite: scatter_nd_add_grad(index, updates, out_grad, x_grad, updates_grad)
- backward_op : segment_pool_grad - backward_op : segment_pool_grad
forward : segment_pool (Tensor x, Tensor segment_ids, str pooltype="SUM") -> Tensor(out), Tensor(summed_ids) forward : segment_pool (Tensor x, Tensor segment_ids, str pooltype="SUM") -> Tensor(out), Tensor(summed_ids)
......
...@@ -69,6 +69,8 @@ class TestScatterNdAddSimpleOp(OpTest): ...@@ -69,6 +69,8 @@ class TestScatterNdAddSimpleOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "scatter_nd_add" self.op_type = "scatter_nd_add"
self.python_api = paddle.scatter_nd_add self.python_api = paddle.scatter_nd_add
self.public_python_api = paddle.scatter_nd_add
self.prim_op_type = "prim"
self._set_dtype() self._set_dtype()
if self.dtype == np.float64: if self.dtype == np.float64:
target_dtype = "float64" target_dtype = "float64"
...@@ -94,7 +96,7 @@ class TestScatterNdAddSimpleOp(OpTest): ...@@ -94,7 +96,7 @@ class TestScatterNdAddSimpleOp(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out') self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
class TestScatterNdAddSimpleFP16Op(TestScatterNdAddSimpleOp): class TestScatterNdAddSimpleFP16Op(TestScatterNdAddSimpleOp):
...@@ -127,7 +129,9 @@ class TestScatterNdAddSimpleBF16Op(TestScatterNdAddSimpleOp): ...@@ -127,7 +129,9 @@ class TestScatterNdAddSimpleBF16Op(TestScatterNdAddSimpleOp):
def test_check_grad(self): def test_check_grad(self):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out') self.check_grad_with_place(
place, ['X', 'Updates'], 'Out', check_prim=True
)
class TestScatterNdAddWithEmptyIndex(OpTest): class TestScatterNdAddWithEmptyIndex(OpTest):
...@@ -138,6 +142,8 @@ class TestScatterNdAddWithEmptyIndex(OpTest): ...@@ -138,6 +142,8 @@ class TestScatterNdAddWithEmptyIndex(OpTest):
def setUp(self): def setUp(self):
self.op_type = "scatter_nd_add" self.op_type = "scatter_nd_add"
self.python_api = paddle.scatter_nd_add self.python_api = paddle.scatter_nd_add
self.public_python_api = paddle.scatter_nd_add
self.prim_op_type = "prim"
self._set_dtype() self._set_dtype()
if self.dtype == np.float64: if self.dtype == np.float64:
target_dtype = "float64" target_dtype = "float64"
...@@ -166,7 +172,7 @@ class TestScatterNdAddWithEmptyIndex(OpTest): ...@@ -166,7 +172,7 @@ class TestScatterNdAddWithEmptyIndex(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out') self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
class TestScatterNdAddWithEmptyIndexFP16(TestScatterNdAddWithEmptyIndex): class TestScatterNdAddWithEmptyIndexFP16(TestScatterNdAddWithEmptyIndex):
...@@ -199,7 +205,9 @@ class TestScatterNdAddWithEmptyIndexBF16(TestScatterNdAddWithEmptyIndex): ...@@ -199,7 +205,9 @@ class TestScatterNdAddWithEmptyIndexBF16(TestScatterNdAddWithEmptyIndex):
def test_check_grad(self): def test_check_grad(self):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out') self.check_grad_with_place(
place, ['X', 'Updates'], 'Out', check_prim=True
)
class TestScatterNdAddWithHighRankSame(OpTest): class TestScatterNdAddWithHighRankSame(OpTest):
...@@ -210,6 +218,8 @@ class TestScatterNdAddWithHighRankSame(OpTest): ...@@ -210,6 +218,8 @@ class TestScatterNdAddWithHighRankSame(OpTest):
def setUp(self): def setUp(self):
self.op_type = "scatter_nd_add" self.op_type = "scatter_nd_add"
self.python_api = paddle.scatter_nd_add self.python_api = paddle.scatter_nd_add
self.public_python_api = paddle.scatter_nd_add
self.prim_op_type = "prim"
self._set_dtype() self._set_dtype()
if self.dtype == np.float64: if self.dtype == np.float64:
target_dtype = "float64" target_dtype = "float64"
...@@ -241,7 +251,7 @@ class TestScatterNdAddWithHighRankSame(OpTest): ...@@ -241,7 +251,7 @@ class TestScatterNdAddWithHighRankSame(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out') self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
class TestScatterNdAddWithHighRankSameFP16(TestScatterNdAddWithHighRankSame): class TestScatterNdAddWithHighRankSameFP16(TestScatterNdAddWithHighRankSame):
...@@ -274,7 +284,9 @@ class TestScatterNdAddWithHighRankSameBF16(TestScatterNdAddWithHighRankSame): ...@@ -274,7 +284,9 @@ class TestScatterNdAddWithHighRankSameBF16(TestScatterNdAddWithHighRankSame):
def test_check_grad(self): def test_check_grad(self):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out') self.check_grad_with_place(
place, ['X', 'Updates'], 'Out', check_prim=True
)
class TestScatterNdAddWithHighRankDiff(OpTest): class TestScatterNdAddWithHighRankDiff(OpTest):
...@@ -285,6 +297,8 @@ class TestScatterNdAddWithHighRankDiff(OpTest): ...@@ -285,6 +297,8 @@ class TestScatterNdAddWithHighRankDiff(OpTest):
def setUp(self): def setUp(self):
self.op_type = "scatter_nd_add" self.op_type = "scatter_nd_add"
self.python_api = paddle.scatter_nd_add self.python_api = paddle.scatter_nd_add
self.public_python_api = paddle.scatter_nd_add
self.prim_op_type = "prim"
shape = (8, 2, 2, 1, 10) shape = (8, 2, 2, 1, 10)
ref_np = np.random.rand(*shape).astype("double") ref_np = np.random.rand(*shape).astype("double")
index = np.vstack([np.random.randint(0, s, size=500) for s in shape]).T index = np.vstack([np.random.randint(0, s, size=500) for s in shape]).T
...@@ -300,7 +314,7 @@ class TestScatterNdAddWithHighRankDiff(OpTest): ...@@ -300,7 +314,7 @@ class TestScatterNdAddWithHighRankDiff(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out') self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
# Test Python API # Test Python API
...@@ -310,6 +324,7 @@ class TestScatterNdOpAPI(unittest.TestCase): ...@@ -310,6 +324,7 @@ class TestScatterNdOpAPI(unittest.TestCase):
""" """
def testcase1(self): def testcase1(self):
with paddle.fluid.framework._static_guard():
ref1 = paddle.static.data( ref1 = paddle.static.data(
name='ref1', name='ref1',
shape=[10, 9, 8, 1, 3], shape=[10, 9, 8, 1, 3],
...@@ -328,6 +343,7 @@ class TestScatterNdOpAPI(unittest.TestCase): ...@@ -328,6 +343,7 @@ class TestScatterNdOpAPI(unittest.TestCase):
output1 = paddle.scatter_nd_add(ref1, index1, updates1) output1 = paddle.scatter_nd_add(ref1, index1, updates1)
def testcase2(self): def testcase2(self):
with paddle.fluid.framework._static_guard():
ref2 = paddle.static.data( ref2 = paddle.static.data(
name='ref2', name='ref2',
shape=[10, 9, 8, 1, 3], shape=[10, 9, 8, 1, 3],
...@@ -348,6 +364,7 @@ class TestScatterNdOpAPI(unittest.TestCase): ...@@ -348,6 +364,7 @@ class TestScatterNdOpAPI(unittest.TestCase):
) )
def testcase3(self): def testcase3(self):
with paddle.fluid.framework._static_guard():
shape3 = [10, 9, 8, 1, 3] shape3 = [10, 9, 8, 1, 3]
index3 = paddle.static.data( index3 = paddle.static.data(
name='index3', name='index3',
...@@ -362,6 +379,7 @@ class TestScatterNdOpAPI(unittest.TestCase): ...@@ -362,6 +379,7 @@ class TestScatterNdOpAPI(unittest.TestCase):
output3 = paddle.scatter_nd(index3, updates3, shape3) output3 = paddle.scatter_nd(index3, updates3, shape3)
def testcase4(self): def testcase4(self):
with paddle.fluid.framework._static_guard():
shape4 = [10, 9, 8, 1, 3] shape4 = [10, 9, 8, 1, 3]
index4 = paddle.static.data( index4 = paddle.static.data(
name='index4', name='index4',
...@@ -373,7 +391,9 @@ class TestScatterNdOpAPI(unittest.TestCase): ...@@ -373,7 +391,9 @@ class TestScatterNdOpAPI(unittest.TestCase):
shape=[5, 5, 8], shape=[5, 5, 8],
dtype='double', dtype='double',
) )
output4 = paddle.scatter_nd(index4, updates4, shape4, name='scatter_nd') output4 = paddle.scatter_nd(
index4, updates4, shape4, name='scatter_nd'
)
def testcase5(self): def testcase5(self):
if not fluid.core.is_compiled_with_cuda(): if not fluid.core.is_compiled_with_cuda():
...@@ -430,6 +450,7 @@ class TestScatterNdOpAPI(unittest.TestCase): ...@@ -430,6 +450,7 @@ class TestScatterNdOpAPI(unittest.TestCase):
class TestScatterNdOpRaise(unittest.TestCase): class TestScatterNdOpRaise(unittest.TestCase):
def test_check_raise(self): def test_check_raise(self):
def check_raise_is_test(): def check_raise_is_test():
with paddle.fluid.framework._static_guard():
try: try:
ref5 = paddle.static.data( ref5 = paddle.static.data(
name='ref5', shape=[-1, 3, 4, 5], dtype='float32' name='ref5', shape=[-1, 3, 4, 5], dtype='float32'
...@@ -450,6 +471,7 @@ class TestScatterNdOpRaise(unittest.TestCase): ...@@ -450,6 +471,7 @@ class TestScatterNdOpRaise(unittest.TestCase):
def test_check_raise2(self): def test_check_raise2(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
with paddle.fluid.framework._static_guard():
ref6 = paddle.static.data( ref6 = paddle.static.data(
name='ref6', name='ref6',
shape=[10, 9, 8, 1, 3], shape=[10, 9, 8, 1, 3],
...@@ -469,13 +491,16 @@ class TestScatterNdOpRaise(unittest.TestCase): ...@@ -469,13 +491,16 @@ class TestScatterNdOpRaise(unittest.TestCase):
def test_check_raise3(self): def test_check_raise3(self):
def check_raise_is_test(): def check_raise_is_test():
with paddle.fluid.framework._static_guard():
try: try:
shape = [3, 4, 5] shape = [3, 4, 5]
index7 = paddle.static.data( index7 = paddle.static.data(
name='index7', shape=[-1, 2, 1], dtype='int32' name='index7', shape=[-1, 2, 1], dtype='int32'
) )
updates7 = paddle.static.data( updates7 = paddle.static.data(
name='updates7', shape=[-1, 2, 4, 5, 20], dtype='float32' name='updates7',
shape=[-1, 2, 4, 5, 20],
dtype='float32',
) )
output7 = paddle.scatter_nd(index7, updates7, shape) output7 = paddle.scatter_nd(index7, updates7, shape)
except Exception as e: except Exception as e:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册