From d1e2c61b22b9675adc3c4a52227d2220babaa001 Mon Sep 17 00:00:00 2001 From: Roc <30228238+sljlp@users.noreply.github.com> Date: Thu, 16 Mar 2023 20:06:10 +0800 Subject: [PATCH] Comp index select (#51215) --- python/paddle/fluid/tests/unittests/CMakeLists.txt | 1 + .../paddle/fluid/tests/unittests/test_index_select_op.py | 7 +++++-- python/paddle/incubate/autograd/composite_rules.py | 9 +++++++++ python/paddle/incubate/autograd/primitives.py | 2 ++ 4 files changed, 17 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index e64b9c131e..718497311d 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -1199,6 +1199,7 @@ set(TEST_CINN_OPS test_stack_op test_activation_op test_full_like_op + test_index_select_op test_fill_any_like_op test_concat_op test_top_k_v2_op diff --git a/python/paddle/fluid/tests/unittests/test_index_select_op.py b/python/paddle/fluid/tests/unittests/test_index_select_op.py index a62f7028a9..f318cfab10 100644 --- a/python/paddle/fluid/tests/unittests/test_index_select_op.py +++ b/python/paddle/fluid/tests/unittests/test_index_select_op.py @@ -28,7 +28,9 @@ class TestIndexSelectOp(OpTest): def setUp(self): self.python_api = paddle.index_select self.op_type = "index_select" + self.prim_op_type = "comp" self.init_dtype_type() + index_np = np.random.randint( low=0, high=self.x_shape[self.dim], size=self.index_size ) @@ -57,10 +59,10 @@ class TestIndexSelectOp(OpTest): self.index_size = 100 def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=True, check_prim=True) def test_check_grad_normal(self): - self.check_grad(['X'], 'Out', check_eager=True) + self.check_grad(['X'], 'Out', check_eager=True, check_prim=True) class TestIndexSelectOpCase2(TestIndexSelectOp): @@ -92,6 +94,7 @@ class TestIndexSelectFP16OP(TestIndexSelectOp): self.index_size = 100 +# no scatter op (the backward op of index_select/gather) for bf16 class TestIndexSelectBF16Op(OpTest): def setUp(self): self.python_api = paddle.index_select diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index ae6143b4cc..8192ccdd2f 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -393,6 +393,15 @@ def hard_swish_composite(x): return res +@REGISTER_COMPOSITE('index_select') +def index_select_composite(x, index, axis): + """define composite rule of op index_select.""" + if axis < 0: + axis = len(x.shape) + axis + res = gather(x, index, axis=axis) + return res + + @REGISTER_COMPOSITE('sigmoid') def sigmoid_composite(x): """ diff --git a/python/paddle/incubate/autograd/primitives.py b/python/paddle/incubate/autograd/primitives.py index 12b86e1452..152d920681 100644 --- a/python/paddle/incubate/autograd/primitives.py +++ b/python/paddle/incubate/autograd/primitives.py @@ -34,6 +34,7 @@ from paddle.tensor import erfinv # noqa: F401 from paddle.tensor import exp # noqa: F401 from paddle.tensor import expm1 # noqa: F401 from paddle.tensor import full # noqa: F401 +from paddle.tensor import gather # noqa: F401 from paddle.tensor import greater_equal # noqa: F401 from paddle.tensor import lgamma # noqa: F401 from paddle.tensor import log # noqa: F401 @@ -124,6 +125,7 @@ others = [ 'cast', 'fill_constant', 'reshape', + 'gather' 'full', 'tile', 'concat', -- GitLab