From 0fd6e2a1b4692b5334980a55d053a5a4557ab568 Mon Sep 17 00:00:00 2001 From: wangshengxiang <121413869+shengxiangwang@users.noreply.github.com> Date: Thu, 2 Mar 2023 15:01:38 +0800 Subject: [PATCH] [XPU] add smallest mode for top_k (#51053) --- paddle/phi/kernels/xpu/top_k_kernel.cc | 12 +-- .../unittests/xpu/test_top_k_v2_op_xpu.py | 96 +++++++++++++++++++ 2 files changed, 100 insertions(+), 8 deletions(-) diff --git a/paddle/phi/kernels/xpu/top_k_kernel.cc b/paddle/phi/kernels/xpu/top_k_kernel.cc index 1ed20b0ddf..fca852a086 100644 --- a/paddle/phi/kernels/xpu/top_k_kernel.cc +++ b/paddle/phi/kernels/xpu/top_k_kernel.cc @@ -43,12 +43,6 @@ void TopkKernel(const Context& dev_ctx, errors::External( "XPU API does not support unsorted topk operation currently." " Operator will be supported in future update.")); - PADDLE_ENFORCE_EQ( - largest, - true, - errors::External( - "XPU API does not support smallest topk operation currently." - " Operator will be supported in future update.")); if (in_dims.size() == 0) { int r = xpu::copy(dev_ctx.x_context(), reinterpret_cast(x.data()), @@ -77,7 +71,8 @@ void TopkKernel(const Context& dev_ctx, indices_int_data, row, col, - k); + k, + largest); PADDLE_ENFORCE_XDNN_SUCCESS(r, "sorted_topk"); r = xpu::cast(dev_ctx.x_context(), @@ -140,7 +135,8 @@ void TopkKernel(const Context& dev_ctx, trans_idx_int32_data, row, col, - k); + k, + largest); PADDLE_ENFORCE_XDNN_SUCCESS(r, "sorted_topk"); r = xpu::cast(dev_ctx.x_context(), diff --git a/python/paddle/fluid/tests/unittests/xpu/test_top_k_v2_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_top_k_v2_op_xpu.py index 3ddfc115ae..eaad700192 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_top_k_v2_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_top_k_v2_op_xpu.py @@ -193,6 +193,102 @@ class XPUTestTopKV2Op(XPUOpTestWrapper): self.largest = True self.input_data_shape = (10, 10, 5) + class TestTopkSmallestOp1(TestTopkOp): + def init_args(self): + self.k = 3 + self.axis = 1 + self.largest = False + # too many values for fp16 will lead to failure in random_unique_float function + if self.dtype == np.float16: + self.input_data_shape = (100, 55) + else: + self.input_data_shape = (100, 155) + + class TestTopkSmallestOp2(TestTopkOp): + def init_args(self): + self.k = 3 + self.axis = 1 + self.largest = False + self.input_data_shape = (10, 10, 5) + + class TestTopkSmallestOp3(TestTopkOp): + def init_args(self): + self.k = 5 + self.axis = 1 + self.largest = False + self.input_data_shape = (10, 10, 5) + + class TestTopkSmallestOp4(TestTopkOp): + def init_args(self): + self.k = 1 + self.axis = 1 + self.largest = False + self.input_data_shape = (10, 10, 5) + + class TestTopkSmallestOp5(TestTopkOp): + def init_args(self): + self.k = 3 + self.axis = 2 + self.largest = False + self.input_data_shape = (10, 10, 5) + + class TestTopkSmallestOp6(TestTopkOp): + def init_args(self): + self.k = 5 + self.axis = 1 + self.largest = False + # too many values for fp16 will lead to failure in random_unique_float function + if self.dtype == np.float16: + self.input_data_shape = (8, 32, 32) + else: + self.input_data_shape = (8, 32, 64) + + class TestTopkSmallestOp7(TestTopkOp): + def init_args(self): + self.k = 10 + self.axis = 2 + self.largest = False + self.input_data_shape = (8, 5, 10, 16) + + class TestTopkSmallestOp8(TestTopkOp): + def init_args(self): + self.k = 1 + self.axis = 1 + self.largest = False + # too many values for fp16 will lead to failure in random_unique_float function + if self.dtype == np.float16: + self.input_data_shape = (8, 32, 32) + else: + self.input_data_shape = (8, 32, 64) + + class TestTopkSmallestOp9(TestTopkOp): + def init_args(self): + self.k = 3 + self.axis = 1 + self.largest = False + self.input_data_shape = (10, 10, 5) + + class TestTopkSmallestOp10(TestTopkOp): + def init_args(self): + self.k = 3 + self.axis = 1 + self.largest = False + self.input_data_shape = (10, 10, 5) + + class TestTopkSmallestOp11(TestTopkOp): + def init_args(self): + self.k = 5 + self.axis = 1 + self.largest = False + self.input_data_shape = (10, 10, 5) + + class TestTopkSmallestOp12(TestTopkOp): + def init_args(self): + self.k = 1 + self.axis = 1 + self.largest = False + self.input_data_shape = (10, 10, 5) + support_types = get_xpu_op_support_types('top_k_v2') for stype in support_types: -- GitLab