From b8c6e180a40ed8897bbb71d3cd72ae8e62286137 Mon Sep 17 00:00:00 2001 From: zhulei <563755780@qq.com> Date: Mon, 13 Sep 2021 10:24:54 +0800 Subject: [PATCH] [ROCM] fix top_k_v2 with large shape (#33783) * [ROCM] fix top_k_v2 with large shape * [ROCM] fix top_k_v2 with large shape --- paddle/fluid/operators/top_k_v2_op.cu | 18 ++++++++++++++++++ .../fluid/tests/unittests/test_top_k_v2_op.py | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/paddle/fluid/operators/top_k_v2_op.cu b/paddle/fluid/operators/top_k_v2_op.cu index 0f2da4b8f6f..6e74ca46d2c 100644 --- a/paddle/fluid/operators/top_k_v2_op.cu +++ b/paddle/fluid/operators/top_k_v2_op.cu @@ -99,12 +99,21 @@ class TopkV2OpCUDAKernel : public framework::OpKernel { const int kMaxHeight = 2048; int gridx = input_height < kMaxHeight ? input_height : kMaxHeight; switch (GetDesiredBlockDim(input_width)) { +#ifdef PADDLE_WITH_HIP + FIXED_BLOCK_DIM( + KeMatrixTopK<<>>( + output_data, k, indices_data, input_data, input_width, + input_width, static_cast(k), gridx, input_height, + largest)); +#else FIXED_BLOCK_DIM( KeMatrixTopK<<>>( output_data, k, indices_data, input_data, input_width, input_width, static_cast(k), gridx, input_height, largest)); +#endif default: PADDLE_THROW(platform::errors::Fatal( "the input data shape has error in the topk cuda kernel.")); @@ -169,12 +178,21 @@ class TopkV2OpCUDAKernel : public framework::OpKernel { const int kMaxHeight = 2048; int gridx = input_height < kMaxHeight ? input_height : kMaxHeight; switch (GetDesiredBlockDim(input_width)) { +#ifdef PADDLE_WITH_HIP + FIXED_BLOCK_DIM( + KeMatrixTopK<<>>( + trans_out.data(), k, trans_ind.data(), + trans_input.data(), input_width, input_width, + static_cast(k), gridx, input_height, largest)); +#else FIXED_BLOCK_DIM( KeMatrixTopK<<>>( trans_out.data(), k, trans_ind.data(), trans_input.data(), input_width, input_width, static_cast(k), gridx, input_height, largest)); +#endif default: PADDLE_THROW(platform::errors::Fatal( "the input data shape has error in the topk cuda kernel.")); diff --git a/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py b/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py index 94dcf151150..4be53304733 100644 --- a/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py @@ -131,6 +131,24 @@ class TestTopkOp5(TestTopkOp): self.outputs = {'Out': output, 'Indices': indices} +class TestTopkOp6(OpTest): + def init_args(self): + self.k = 100 + self.axis = 1 + self.largest = True + + def setUp(self): + self.op_type = "top_k_v2" + self.dtype = np.float64 + self.input_data = np.random.rand(80, 16384) + self.init_args() + self.inputs = {'X': self.input_data} + self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest} + output, indices = numpy_topk( + self.input_data, axis=self.axis, k=self.k, largest=self.largest) + self.outputs = {'Out': output, 'Indices': indices} + + class TestTopKAPI(unittest.TestCase): def setUp(self): np.random.seed(123) -- GitLab