Unverified commit e16eb22c, authored by zxcd, committed by GitHub

add scatter composite rule. (#52005)

* add scatter composite rule.

* add public_python_api

* add python uint16 support.

* fix code style.

* add scatter test to the CINN op list in CMakeLists.

* CINN does not support uint16; disable CINN when dtype == uint16.
Parent fb16bdc7
......@@ -23,6 +23,7 @@
- concat
- elementwise_pow
- floor
- gather
- gather_nd
- log
- max
......
......@@ -1222,6 +1222,27 @@ void cos_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
set_output<T>(x_grad_tmp, x_grad);
}
template <typename T>
void scatter_grad(const Tensor& index,
                  const Tensor& updates,
                  const Tensor& out_grad,
                  bool overwrite,
                  Tensor* x_grad,
                  Tensor* updates_grad) {
  if (x_grad) {
    // x_grad: out_grad with the rows written by updates zeroed out, i.e.
    // scatter a zero tensor of updates' shape into out_grad at index.
    auto zero_tensor =
        full<T>(phi::vectorize(updates.dims()), 0.0, updates.dtype());
    auto tmp_grad = scatter<T>(out_grad, index, zero_tensor, false);
    set_output<T>(tmp_grad, x_grad);
  }

  if (updates_grad) {
    // updates_grad: rows of out_grad gathered at index along axis 0.
    Scalar tmp_zero = 0;
    auto tmp_updates_grad = gather<T>(out_grad, index, tmp_zero);
    set_output<T>(tmp_updates_grad, updates_grad);
  }
}
template <typename T>
void batch_norm_grad(const Tensor& x,
const Tensor& scale,
......
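For reference, the following is a minimal NumPy sketch of what the composite scatter_grad rule above computes; the helper name, shapes, and values are illustrative and are not part of the patch:

import numpy as np

def scatter_grad_ref(index, out_grad):
    # Mirrors the composite rule above (illustrative only):
    # x_grad = scatter(out_grad, index, zeros_like(updates), overwrite=False)
    #        -> out_grad with the rows selected by `index` zeroed out.
    # updates_grad = gather(out_grad, index, axis=0)
    #        -> the rows of out_grad at `index`.
    x_grad = out_grad.copy()
    x_grad[index] = 0.0
    updates_grad = out_grad[index]
    return x_grad, updates_grad

# Shapes as in TestScatterOp: x is (3, 50), index is (2,), updates is (2, 50).
out_grad = np.random.rand(3, 50).astype("float32")
index = np.array([1, 2], dtype=np.int64)
x_grad, updates_grad = scatter_grad_ref(index, out_grad)
assert x_grad.shape == (3, 50) and updates_grad.shape == (2, 50)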
......@@ -1318,6 +1318,7 @@
  kernel :
    func : scatter_grad
  no_need_buffer : updates
  composite: scatter_grad(index, updates, out_grad, overwrite, x_grad, updates_grad)

- backward_op : scatter_nd_add_grad
  forward : scatter_nd_add (Tensor x, Tensor index, Tensor updates) -> Tensor(out)
......
......@@ -1197,6 +1197,7 @@ set(TEST_CINN_OPS
test_mean_op
test_unsqueeze2_op
test_meshgrid_op
test_scatter_op
test_gather_op
test_cast_op
test_dropout_op
......
......@@ -28,6 +28,8 @@ class TestScatterOp(OpTest):
def setUp(self):
self.op_type = "scatter"
self.python_api = paddle.scatter
self.public_python_api = paddle.scatter
self.prim_op_type = "prim"
self._set_dtype()
target_dtype = "float16" if self.dtype == np.float16 else "float32"
ref_np = np.ones((3, 50)).astype(target_dtype)
......@@ -46,10 +48,10 @@ class TestScatterOp(OpTest):
self.dtype = np.float32
def test_check_output(self):
self.check_output()
self.check_output(check_prim=True)
def test_check_grad(self):
self.check_grad(["X", "Updates"], "Out")
self.check_grad(["X", "Updates"], "Out", check_prim=True)
class TestScatterFP16Op(TestScatterOp):
......@@ -65,22 +67,30 @@ class TestScatterFP16Op(TestScatterOp):
class TestScatterBF16Op(TestScatterOp):
def _set_dtype(self):
self.dtype = np.uint16
self.enable_cinn = False
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(place, check_prim=True)
def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
self.check_grad_with_place(
place,
['X', 'Updates'],
'Out',
check_prim=True,
)
class TestScatterOp0(OpTest):
def setUp(self):
self.op_type = "scatter"
self.python_api = paddle.scatter
self.public_python_api = paddle.scatter
self.prim_op_type = "prim"
self._set_dtype()
target_dtype = "float16" if self.dtype == np.float16 else "float32"
ref_np = np.ones((3, 3)).astype(target_dtype)
......@@ -100,10 +110,10 @@ class TestScatterOp0(OpTest):
self.dtype = np.float32
def test_check_output(self):
self.check_output()
self.check_output(check_prim=True)
def test_check_grad(self):
self.check_grad(["X", "Updates"], "Out")
self.check_grad(["X", "Updates"], "Out", check_prim=True)
class TestScatterFP16Op0(TestScatterOp0):
......@@ -119,22 +129,30 @@ class TestScatterFP16Op0(TestScatterOp0):
class TestScatterBF16Op0(TestScatterOp0):
def _set_dtype(self):
self.dtype = np.uint16
self.enable_cinn = False
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(place, check_prim=True)
def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
self.check_grad_with_place(
place,
['X', 'Updates'],
'Out',
check_prim=True,
)
class TestScatterOp1(OpTest):
def setUp(self):
self.op_type = "scatter"
self.python_api = paddle.scatter
self.public_python_api = paddle.scatter
self.prim_op_type = "prim"
self._set_dtype()
target_dtype = "float16" if self.dtype == np.float16 else "float32"
ref_np = np.ones((3, 3)).astype(target_dtype)
......@@ -157,10 +175,10 @@ class TestScatterOp1(OpTest):
self.dtype = np.float32
def test_check_output(self):
self.check_output()
self.check_output(check_prim=True)
def test_check_grad(self):
self.check_grad(["X", "Updates"], "Out")
self.check_grad(["X", "Updates"], "Out", check_prim=True)
class TestScatterFP16Op1(TestScatterOp1):
......@@ -176,16 +194,22 @@ class TestScatterFP16Op1(TestScatterOp1):
class TestScatterBF16Op1(TestScatterOp1):
def _set_dtype(self):
self.dtype = np.uint16
self.enable_cinn = False
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(place, check_prim=True)
def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
self.check_grad_with_place(
place,
['X', 'Updates'],
'Out',
check_prim=True,
)
@unittest.skipIf(
......@@ -195,6 +219,8 @@ class TestScatterOp2(OpTest):
def setUp(self):
self.op_type = "scatter"
self.python_api = paddle.scatter
self.public_python_api = paddle.scatter
self.prim_op_type = "prim"
self._set_dtype()
target_dtype = "float16" if self.dtype == np.float16 else "float32"
ref_np = np.ones((3, 3)).astype(target_dtype)
......@@ -215,12 +241,17 @@ class TestScatterOp2(OpTest):
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-3)
self.check_output_with_place(place, atol=1e-3, check_prim=True)
def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
self.check_grad_with_place(
place,
['X', 'Updates'],
'Out',
check_prim=True,
)
@unittest.skipIf(
......@@ -239,6 +270,7 @@ class TestScatterFP16Op2(TestScatterOp2):
class TestScatterBF16Op2(TestScatterOp2):
def _set_dtype(self):
self.dtype = np.uint16
self.enable_cinn = False
@unittest.skipIf(
......@@ -248,6 +280,8 @@ class TestScatterOp3(OpTest):
def setUp(self):
self.op_type = "scatter"
self.python_api = paddle.scatter
self.public_python_api = paddle.scatter
self.prim_op_type = "prim"
self._set_dtype()
target_dtype = "float16" if self.dtype == np.float16 else "float32"
ref_np = np.ones((3, 3)).astype(target_dtype)
......@@ -272,12 +306,17 @@ class TestScatterOp3(OpTest):
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-3)
self.check_output_with_place(place, atol=1e-3, check_prim=True)
def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
self.check_grad_with_place(
place,
['X', 'Updates'],
'Out',
check_prim=True,
)
@unittest.skipIf(
......@@ -296,12 +335,15 @@ class TestScatterFP16Op3(TestScatterOp3):
class TestScatterBF16Op3(TestScatterOp3):
def _set_dtype(self):
self.dtype = np.uint16
self.enable_cinn = False
class TestScatterOp4(OpTest):
def setUp(self):
self.op_type = "scatter"
self.python_api = paddle.scatter
self.public_python_api = paddle.scatter
self.prim_op_type = "prim"
self._set_dtype()
target_dtype = "float16" if self.dtype == np.float16 else "float32"
ref_np = np.ones((3, 3)).astype(target_dtype)
......@@ -320,10 +362,10 @@ class TestScatterOp4(OpTest):
self.dtype = np.float32
def test_check_output(self):
self.check_output()
self.check_output(check_prim=True)
def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out')
self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
class TestScatterFP16Op4(TestScatterOp4):
......@@ -339,16 +381,22 @@ class TestScatterFP16Op4(TestScatterOp4):
class TestScatterBF16Op4(TestScatterOp4):
def _set_dtype(self):
self.dtype = np.uint16
self.enable_cinn = False
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(place, check_prim=True)
def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
self.check_grad_with_place(
place,
['X', 'Updates'],
'Out',
check_prim=True,
)
@unittest.skipIf(
......@@ -358,6 +406,8 @@ class TestScatterOp5(OpTest):
def setUp(self):
self.op_type = "scatter"
self.python_api = paddle.scatter
self.public_python_api = paddle.scatter
self.prim_op_type = "prim"
self._set_dtype()
target_dtype = "float16" if self.dtype == np.float16 else "float32"
ref_np = np.ones((3, 3)).astype(target_dtype)
......@@ -378,12 +428,17 @@ class TestScatterOp5(OpTest):
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-3)
self.check_output_with_place(place, atol=1e-3, check_prim=True)
def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
self.check_grad_with_place(
place,
['X', 'Updates'],
'Out',
check_prim=True,
)
@unittest.skipIf(
......@@ -402,12 +457,16 @@ class TestScatterFP16Op5(TestScatterOp5):
class TestScatterBF16Op5(TestScatterOp5):
def _set_dtype(self):
self.dtype = np.uint16
self.enable_cinn = False
class TestScatterOp6(OpTest):
def setUp(self):
self.op_type = "scatter"
self.python_api = paddle.scatter
self.public_python_api = paddle.scatter
self.prim_op_type = "prim"
self.enable_cinn = False
self._set_dtype()
target_dtype = "float16" if self.dtype == np.float16 else "float32"
ref_np = np.ones((3, 50)).astype(target_dtype)
......@@ -426,10 +485,10 @@ class TestScatterOp6(OpTest):
self.dtype = np.float32
def test_check_output(self):
self.check_output()
self.check_output(check_prim=True)
def test_check_grad(self):
self.check_grad(["X", "Updates"], "Out")
self.check_grad(["X", "Updates"], "Out", check_prim=True)
class TestScatterFP16Op6(TestScatterOp6):
......@@ -449,12 +508,17 @@ class TestScatterBF16Op6(TestScatterOp6):
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(place, check_prim=True)
def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
self.check_grad_with_place(
place,
['X', 'Updates'],
'Out',
check_prim=True,
)
class TestScatterAPI(unittest.TestCase):
......
......@@ -2903,7 +2903,7 @@ def scatter(x, index, updates, overwrite=True, name=None):
check_variable_and_dtype(
x,
'dtype',
['float32', 'float64', 'float16', 'int32', 'int64'],
['float32', 'float64', 'float16', 'int32', 'int64', 'uint16'],
'scatter',
)
check_type(overwrite, 'overwrite', bool, 'scatter')
......
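For context on the dtype-list change above, here is a small usage sketch of paddle.scatter; the tensor values are illustrative and it assumes a working Paddle install:

import paddle

x = paddle.ones([3, 2], dtype="float32")
index = paddle.to_tensor([1, 2], dtype="int64")
updates = paddle.to_tensor([[3.0, 3.0], [4.0, 4.0]])

# overwrite=True: rows of x at `index` are replaced by `updates`.
out = paddle.scatter(x, index, updates, overwrite=True)
# out -> [[1., 1.], [3., 3.], [4., 4.]]

# overwrite=False: the selected rows are zeroed first and updates are
# accumulated, so duplicate indices are summed; with unique indices the
# result matches overwrite=True.
out_acc = paddle.scatter(x, index, updates, overwrite=False)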