Unverified commit e16eb22c authored by zxcd, committed by GitHub

add scatter composite rule. (#52005)

* add scatter composite rule.

* add public_python_api

* add Python uint16 support.

* fix code style.

* add test_scatter_op to the CINN test list in CMakeLists.

* CINN does not support uint16, so disable CINN when dtype == uint16.
Parent fb16bdc7
@@ -23,6 +23,7 @@
 - concat
 - elementwise_pow
 - floor
+- gather
 - gather_nd
 - log
 - max
......
@@ -1222,6 +1222,27 @@ void cos_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
   set_output<T>(x_grad_tmp, x_grad);
 }
 
+template <typename T>
+void scatter_grad(const Tensor& index,
+                  const Tensor& updates,
+                  const Tensor& out_grad,
+                  bool overwrite,
+                  Tensor* x_grad,
+                  Tensor* updates_grad) {
+  if (x_grad) {
+    auto zero_tensor =
+        full<T>(phi::vectorize(updates.dims()), 0.0, updates.dtype());
+    auto tmp_grad = scatter<T>(out_grad, index, zero_tensor, false);
+    set_output<T>(tmp_grad, x_grad);
+  }
+  if (updates_grad) {
+    Scalar tmp_zero = 0;
+    auto tmp_updates_grad = gather<T>(out_grad, index, tmp_zero);
+    set_output<T>(tmp_updates_grad, updates_grad);
+  }
+}
+
 template <typename T>
 void batch_norm_grad(const Tensor& x,
                      const Tensor& scale,
......
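The composite rule above rewrites scatter's backward in terms of existing primitives: x_grad is out_grad with the rows touched by index zeroed out (a scatter of zeros onto out_grad), and updates_grad is a gather of out_grad at index. A minimal NumPy sketch of that decomposition for 1-D integer indices over a 2-D tensor (the helper name scatter_grad_ref is illustrative, not a Paddle API):

import numpy as np

def scatter_grad_ref(index, out_grad):
    # x_grad: out_grad with the indexed rows zeroed, mirroring
    # scatter<T>(out_grad, index, zero_tensor, false) in the rule above.
    x_grad = out_grad.copy()
    x_grad[index] = 0
    # updates_grad: the indexed rows of out_grad, mirroring
    # gather<T>(out_grad, index, 0).
    updates_grad = out_grad[index]
    return x_grad, updates_grad

out_grad = np.arange(6, dtype=np.float32).reshape(3, 2)
index = np.array([0, 2])
x_grad, updates_grad = scatter_grad_ref(index, out_grad)
# x_grad       -> [[0, 0], [2, 3], [0, 0]]
# updates_grad -> [[0, 1], [4, 5]]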
@@ -1318,6 +1318,7 @@
   kernel :
     func : scatter_grad
   no_need_buffer : updates
+  composite: scatter_grad(index, updates, out_grad, overwrite, x_grad, updates_grad)
 
 - backward_op : scatter_nd_add_grad
   forward : scatter_nd_add (Tensor x, Tensor index, Tensor updates) -> Tensor(out)
......
@@ -1197,6 +1197,7 @@ set(TEST_CINN_OPS
     test_mean_op
     test_unsqueeze2_op
     test_meshgrid_op
+    test_scatter_op
     test_gather_op
     test_cast_op
     test_dropout_op
......
@@ -28,6 +28,8 @@ class TestScatterOp(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
+        self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 50)).astype(target_dtype)
@@ -46,10 +48,10 @@ class TestScatterOp(OpTest):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_prim=True)
 
     def test_check_grad(self):
-        self.check_grad(["X", "Updates"], "Out")
+        self.check_grad(["X", "Updates"], "Out", check_prim=True)
 
 
 class TestScatterFP16Op(TestScatterOp):
@@ -65,22 +67,30 @@ class TestScatterFP16Op(TestScatterOp):
 class TestScatterBF16Op(TestScatterOp):
     def _set_dtype(self):
         self.dtype = np.uint16
+        self.enable_cinn = False
 
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place)
+            self.check_output_with_place(place, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
+            self.check_grad_with_place(
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_prim=True,
+            )
 
 
 class TestScatterOp0(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
+        self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -100,10 +110,10 @@ class TestScatterOp0(OpTest):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_prim=True)
 
     def test_check_grad(self):
-        self.check_grad(["X", "Updates"], "Out")
+        self.check_grad(["X", "Updates"], "Out", check_prim=True)
 
 
 class TestScatterFP16Op0(TestScatterOp0):
@@ -119,22 +129,30 @@ class TestScatterFP16Op0(TestScatterOp0):
 class TestScatterBF16Op0(TestScatterOp0):
     def _set_dtype(self):
         self.dtype = np.uint16
+        self.enable_cinn = False
 
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place)
+            self.check_output_with_place(place, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
+            self.check_grad_with_place(
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_prim=True,
+            )
 
 
 class TestScatterOp1(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
+        self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -157,10 +175,10 @@ class TestScatterOp1(OpTest):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_prim=True)
 
     def test_check_grad(self):
-        self.check_grad(["X", "Updates"], "Out")
+        self.check_grad(["X", "Updates"], "Out", check_prim=True)
 
 
 class TestScatterFP16Op1(TestScatterOp1):
@@ -176,16 +194,22 @@ class TestScatterFP16Op1(TestScatterOp1):
 class TestScatterBF16Op1(TestScatterOp1):
     def _set_dtype(self):
         self.dtype = np.uint16
+        self.enable_cinn = False
 
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place)
+            self.check_output_with_place(place, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
+            self.check_grad_with_place(
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_prim=True,
+            )
 
 
 @unittest.skipIf(
@@ -195,6 +219,8 @@ class TestScatterOp2(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
+        self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -215,12 +241,17 @@ class TestScatterOp2(OpTest):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place, atol=1e-3)
+            self.check_output_with_place(place, atol=1e-3, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
+            self.check_grad_with_place(
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_prim=True,
+            )
 
 
 @unittest.skipIf(
@@ -239,6 +270,7 @@ class TestScatterFP16Op2(TestScatterOp2):
 class TestScatterBF16Op2(TestScatterOp2):
     def _set_dtype(self):
         self.dtype = np.uint16
+        self.enable_cinn = False
 
 
 @unittest.skipIf(
@@ -248,6 +280,8 @@ class TestScatterOp3(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
+        self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -272,12 +306,17 @@ class TestScatterOp3(OpTest):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place, atol=1e-3)
+            self.check_output_with_place(place, atol=1e-3, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
+            self.check_grad_with_place(
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_prim=True,
+            )
 
 
 @unittest.skipIf(
@@ -296,12 +335,15 @@ class TestScatterFP16Op3(TestScatterOp3):
 class TestScatterBF16Op3(TestScatterOp3):
     def _set_dtype(self):
         self.dtype = np.uint16
+        self.enable_cinn = False
 
 
 class TestScatterOp4(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
+        self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -320,10 +362,10 @@ class TestScatterOp4(OpTest):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_prim=True)
 
     def test_check_grad(self):
-        self.check_grad(['X', 'Updates'], 'Out')
+        self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
 
 
 class TestScatterFP16Op4(TestScatterOp4):
@@ -339,16 +381,22 @@ class TestScatterFP16Op4(TestScatterOp4):
 class TestScatterBF16Op4(TestScatterOp4):
     def _set_dtype(self):
         self.dtype = np.uint16
+        self.enable_cinn = False
 
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place)
+            self.check_output_with_place(place, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
+            self.check_grad_with_place(
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_prim=True,
+            )
 
 
 @unittest.skipIf(
@@ -358,6 +406,8 @@ class TestScatterOp5(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
+        self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -378,12 +428,17 @@ class TestScatterOp5(OpTest):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place, atol=1e-3)
+            self.check_output_with_place(place, atol=1e-3, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
+            self.check_grad_with_place(
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_prim=True,
+            )
 
 
 @unittest.skipIf(
@@ -402,12 +457,16 @@ class TestScatterFP16Op5(TestScatterOp5):
 class TestScatterBF16Op5(TestScatterOp5):
     def _set_dtype(self):
         self.dtype = np.uint16
+        self.enable_cinn = False
 
 
 class TestScatterOp6(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
+        self.prim_op_type = "prim"
+        self.enable_cinn = False
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 50)).astype(target_dtype)
@@ -426,10 +485,10 @@ class TestScatterOp6(OpTest):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_prim=True)
 
     def test_check_grad(self):
-        self.check_grad(["X", "Updates"], "Out")
+        self.check_grad(["X", "Updates"], "Out", check_prim=True)
 
 
 class TestScatterFP16Op6(TestScatterOp6):
@@ -449,12 +508,17 @@ class TestScatterBF16Op6(TestScatterOp6):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place)
+            self.check_output_with_place(place, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
+            self.check_grad_with_place(
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_prim=True,
+            )
 
 
 class TestScatterAPI(unittest.TestCase):
......
@@ -2903,7 +2903,7 @@ def scatter(x, index, updates, overwrite=True, name=None):
     check_variable_and_dtype(
         x,
         'dtype',
-        ['float32', 'float64', 'float16', 'int32', 'int64'],
+        ['float32', 'float64', 'float16', 'int32', 'int64', 'uint16'],
         'scatter',
     )
     check_type(overwrite, 'overwrite', bool, 'scatter')
......
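For reference, the gradient behaviour that the composite rule reproduces can be exercised end to end with a short dynamic-graph example (float32; a sketch assuming a Paddle build that includes this change):

import paddle

x = paddle.ones([3, 2])
x.stop_gradient = False
index = paddle.to_tensor([1, 2], dtype='int64')
updates = paddle.to_tensor([[2.0, 2.0], [3.0, 3.0]])
updates.stop_gradient = False

out = paddle.scatter(x, index, updates, overwrite=True)
out.sum().backward()

# Rows 1 and 2 of x were overwritten, so their gradient is zero;
# updates receives the gradient of the rows it wrote.
# x.grad       -> [[1, 1], [0, 0], [0, 0]]
# updates.grad -> [[1, 1], [1, 1]]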