Composite (prim) backward rule for topk.

Summary of the changes below:
  * api.yaml: adds put_along_axis to the prim API list.
  * composite_backward_api.h: adds topk_grad, which builds a zero tensor with
    the dims and dtype of x and calls put_along_axis on it with indices and
    out_grad along axis; the result is written to x_grad via set_output.
  * backward.yaml: registers topk_grad(...) as the composite entry of the
    topk_grad backward op.
  * CMakeLists.txt: appends test_top_k_v2_op to TEST_CINN_OPS.
  * test_top_k_v2_op.py: sets prim_op_type = "prim" in the setUp methods,
    passes check_prim=True to check_grad, derives TestTopkOp3/TestTopkOp6 from
    TestTopkOp, shrinks the TestTopkOp6 input and adds TestTopkOp7.

NOTE(review): line breaks, indentation, blank context lines and the C++
template argument lists (<typename T>, <T>) were restored from a
whitespace-collapsed copy of this patch; hunk line counts were used to place
the blank lines. Confirm against the target tree before applying.

diff --git a/paddle/fluid/prim/api/api.yaml b/paddle/fluid/prim/api/api.yaml
index e47c7a45713dcb0b368f770a2cd90f7de33d680a..55831ca02d0828a2b6abb58002131cdb3102229f 100644
--- a/paddle/fluid/prim/api/api.yaml
+++ b/paddle/fluid/prim/api/api.yaml
@@ -26,3 +26,4 @@
 - transpose
 - pad
 - cumsum
+- put_along_axis
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index b18a33d5824361028d387b508b8ecfa7fbb8b77f..83e8975c7afe8b0e704cda5c1ffcdb15c64af8a7 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -744,5 +744,22 @@ void cumsum_grad(const Tensor& x,
   }
 }
 
+template <typename T>
+void topk_grad(const Tensor& x,
+               const Tensor& indices,
+               const Tensor& out_grad,
+               const Scalar& k,
+               const int& axis,
+               const bool& largest,
+               const bool& sorted,
+               Tensor* x_grad) {
+  if (x_grad) {
+    auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
+    auto x_grad_tmp = put_along_axis<T>(zero_tensor, indices, out_grad, axis);
+
+    set_output<T>(x_grad_tmp, x_grad);
+  }
+}
+
 }  // namespace prim
 }  // namespace paddle
diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index ee75d281b97da92b4d0b2738147a046475dc1228..8492da75eb251abbe619906da883db3dce4536f0 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -1524,6 +1524,7 @@
   kernel :
     func : topk_grad
     data_type : out_grad
+  composite : topk_grad(x, indices, out_grad, k, axis, largest, sorted, x_grad)
 
 - backward_op : trace_grad
   forward : trace (Tensor x, int offset, int axis1, int axis2) -> Tensor(out)
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 271dd98250c400436fb96a14e2868aac649f386d..2d4db9df69e81e8543eec4b6f88120085705ac5e 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -1209,7 +1209,8 @@ set(TEST_CINN_OPS
     test_slice_op
     test_activation_op
     test_full_like_op
-    test_fill_any_like_op)
+    test_fill_any_like_op
+    test_top_k_v2_op)
 
 foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
   if(WITH_CINN)
diff --git a/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py b/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py
index 2a8af4d4ad9a6798eb36eabf7b1132a62eb8f87c..9f6e9ad9d736b5798e1c2bbb0b0d037da3a1aae8 100644
--- a/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py
+++ b/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py
@@ -45,6 +45,7 @@ class TestTopkOp(OpTest):
 
     def setUp(self):
         self.op_type = "top_k_v2"
+        self.prim_op_type = "prim"
         self.python_api = paddle.topk
         self.dtype = np.float64
         self.input_data = np.random.rand(10, 20)
@@ -60,7 +61,7 @@ class TestTopkOp(OpTest):
         self.check_output(check_eager=True)
 
     def test_check_grad(self):
-        self.check_grad(set(['X']), 'Out', check_eager=True)
+        self.check_grad(['X'], 'Out', check_eager=True, check_prim=True)
 
 
 class TestTopkOp1(TestTopkOp):
@@ -77,7 +78,7 @@ class TestTopkOp2(TestTopkOp):
         self.largest = False
 
 
-class TestTopkOp3(OpTest):
+class TestTopkOp3(TestTopkOp):
     def init_args(self):
         self.k = 6
         self.axis = 1
@@ -85,6 +86,7 @@ class TestTopkOp3(OpTest):
 
     def setUp(self):
         self.op_type = "top_k_v2"
+        self.prim_op_type = "prim"
         self.python_api = paddle.topk
         self.dtype = np.float64
         self.input_data = np.random.rand(16, 100)
@@ -105,6 +107,7 @@ class TestTopkOp4(TestTopkOp):
 
     def setUp(self):
         self.op_type = "top_k_v2"
+        self.prim_op_type = "prim"
         self.python_api = paddle.topk
         self.dtype = np.float64
         self.input_data = np.random.rand(10, 10, 5)
@@ -125,6 +128,7 @@ class TestTopkOp5(TestTopkOp):
 
     def setUp(self):
         self.op_type = "top_k_v2"
+        self.prim_op_type = "prim"
         self.python_api = paddle.topk
         self.dtype = np.float64
         self.input_data = np.random.rand(10, 10, 5)
@@ -137,17 +141,39 @@ class TestTopkOp5(TestTopkOp):
         self.outputs = {'Out': output, 'Indices': indices}
 
 
-class TestTopkOp6(OpTest):
+class TestTopkOp6(TestTopkOp):
     def init_args(self):
-        self.k = 100
+        self.k = 3
         self.axis = 1
         self.largest = True
 
     def setUp(self):
         self.op_type = "top_k_v2"
+        self.prim_op_type = "prim"
         self.python_api = paddle.topk
-        self.dtype = np.float64
-        self.input_data = np.random.rand(80, 16384)
+        self.dtype = np.float32
+        self.input_data = np.random.rand(10, 10, 5)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest
+        )
+        self.outputs = {'Out': output, 'Indices': indices}
+
+
+class TestTopkOp7(TestTopkOp):
+    def init_args(self):
+        self.k = 10
+        self.axis = 1
+        self.largest = True
+
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.prim_op_type = "prim"
+        self.python_api = paddle.topk
+        self.dtype = np.float16
+        self.input_data = np.random.rand(10, 20, 10)
         self.init_args()
         self.inputs = {'X': self.input_data}
         self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}