diff --git a/paddle/fluid/operators/top_k_op_npu.cc b/paddle/fluid/operators/top_k_op_npu.cc index ca3a5f957685d98bfdc3a008ab71d5806814b1eb..a7d8fe01edd4cdc75ac637d8b560e40dd21c3b0b 100644 --- a/paddle/fluid/operators/top_k_op_npu.cc +++ b/paddle/fluid/operators/top_k_op_npu.cc @@ -51,7 +51,9 @@ class TopkNPUKernel : public framework::OpKernel { indices->mutable_data(ctx.GetPlace()); // prepare assit - auto dim = input->dims().size(); + auto size = input->dims().size(); + // dim is the last dimension of input + auto dim = input->dims()[size - 1]; framework::Tensor assist_seq_tensor; assist_seq_tensor.Resize({2 * dim}); assist_seq_tensor.mutable_data(ctx.GetPlace()); diff --git a/python/paddle/fluid/tests/unittests/npu/test_top_k_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_top_k_op_npu.py index b735adf76d6c1296197960b084aac660c688c5cd..c8a620d9dbb3517ae022f51ed2134ee484376086 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_top_k_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_top_k_op_npu.py @@ -22,6 +22,7 @@ from op_test import OpTest import paddle import paddle.fluid as fluid from paddle.fluid import core +from test_top_k_v2_op_npu import numpy_topk paddle.enable_static() SEED = 2021 @@ -87,5 +88,40 @@ class TestTopkV2(OpTest): self.check_output_with_place(self.place) +class TestTopkV3(OpTest): + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "top_k" + + self.init_dtype() + self.set_input_data() + self.set_attrs() + output, indices = numpy_topk( + self.input_data, axis=self.axis, k=self.k, largest=True) + + self.inputs = {'X': self.input_data} + self.attrs = {'k': self.k, 'axis': self.axis} + self.outputs = {'Out': output, 'Indices': indices} + + def set_npu(self): + self.__class__.use_npu = True + self.__class__.no_need_check_grad = True + + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output_with_place(self.place) + + def set_attrs(self): + self.k = 3 + self.axis = 1 + + def set_input_data(self): + self.input_data = np.random.choice( + 10000, size=(10, 20), replace=False).astype(self.dtype) + + if __name__ == '__main__': unittest.main()