From 2b8fd704d0ec555b5b27d50fca261741a7fbbf28 Mon Sep 17 00:00:00 2001
From: pangyoki
Date: Wed, 29 Sep 2021 14:50:43 +0800
Subject: [PATCH] fix bug of top_k npu op (#36175)

---
 paddle/fluid/operators/top_k_op_npu.cc        |  4 ++-
 .../tests/unittests/npu/test_top_k_op_npu.py  | 36 +++++++++++++++++++
 2 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/operators/top_k_op_npu.cc b/paddle/fluid/operators/top_k_op_npu.cc
index ca3a5f95768..a7d8fe01edd 100644
--- a/paddle/fluid/operators/top_k_op_npu.cc
+++ b/paddle/fluid/operators/top_k_op_npu.cc
@@ -51,7 +51,9 @@ class TopkNPUKernel : public framework::OpKernel {
     indices->mutable_data(ctx.GetPlace());

     // prepare assit
-    auto dim = input->dims().size();
+    auto size = input->dims().size();
+    // dim is the last dimension of input
+    auto dim = input->dims()[size - 1];
     framework::Tensor assist_seq_tensor;
     assist_seq_tensor.Resize({2 * dim});
     assist_seq_tensor.mutable_data(ctx.GetPlace());
diff --git a/python/paddle/fluid/tests/unittests/npu/test_top_k_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_top_k_op_npu.py
index b735adf76d6..c8a620d9dbb 100644
--- a/python/paddle/fluid/tests/unittests/npu/test_top_k_op_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_top_k_op_npu.py
@@ -22,6 +22,7 @@ from op_test import OpTest
 import paddle
 import paddle.fluid as fluid
 from paddle.fluid import core
+from test_top_k_v2_op_npu import numpy_topk

 paddle.enable_static()
 SEED = 2021
@@ -87,5 +88,40 @@ class TestTopkV2(OpTest):
         self.check_output_with_place(self.place)


+class TestTopkV3(OpTest):
+    def setUp(self):
+        self.set_npu()
+        self.place = paddle.NPUPlace(0)
+        self.op_type = "top_k"
+
+        self.init_dtype()
+        self.set_input_data()
+        self.set_attrs()
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=True)
+
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis}
+        self.outputs = {'Out': output, 'Indices': indices}
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+        self.__class__.no_need_check_grad = True
+
+    def init_dtype(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def set_attrs(self):
+        self.k = 3
+        self.axis = 1
+
+    def set_input_data(self):
+        self.input_data = np.random.choice(
+            10000, size=(10, 20), replace=False).astype(self.dtype)
+
+
 if __name__ == '__main__':
     unittest.main()
--
GitLab