diff --git a/paddle/fluid/operators/top_k_v2_op.cu b/paddle/fluid/operators/top_k_v2_op.cu index 0f2da4b8f6fbb9122ed5a08e5883293147e7ecc3..6e74ca46d2cd2796346b7dc2acbea355058826e2 100644 --- a/paddle/fluid/operators/top_k_v2_op.cu +++ b/paddle/fluid/operators/top_k_v2_op.cu @@ -99,12 +99,21 @@ class TopkV2OpCUDAKernel : public framework::OpKernel { const int kMaxHeight = 2048; int gridx = input_height < kMaxHeight ? input_height : kMaxHeight; switch (GetDesiredBlockDim(input_width)) { +#ifdef PADDLE_WITH_HIP + FIXED_BLOCK_DIM( + KeMatrixTopK<<>>( + output_data, k, indices_data, input_data, input_width, + input_width, static_cast(k), gridx, input_height, + largest)); +#else FIXED_BLOCK_DIM( KeMatrixTopK<<>>( output_data, k, indices_data, input_data, input_width, input_width, static_cast(k), gridx, input_height, largest)); +#endif default: PADDLE_THROW(platform::errors::Fatal( "the input data shape has error in the topk cuda kernel.")); @@ -169,12 +178,21 @@ class TopkV2OpCUDAKernel : public framework::OpKernel { const int kMaxHeight = 2048; int gridx = input_height < kMaxHeight ? input_height : kMaxHeight; switch (GetDesiredBlockDim(input_width)) { +#ifdef PADDLE_WITH_HIP + FIXED_BLOCK_DIM( + KeMatrixTopK<<>>( + trans_out.data(), k, trans_ind.data(), + trans_input.data(), input_width, input_width, + static_cast(k), gridx, input_height, largest)); +#else FIXED_BLOCK_DIM( KeMatrixTopK<<>>( trans_out.data(), k, trans_ind.data(), trans_input.data(), input_width, input_width, static_cast(k), gridx, input_height, largest)); +#endif default: PADDLE_THROW(platform::errors::Fatal( "the input data shape has error in the topk cuda kernel.")); diff --git a/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py b/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py index 94dcf151150ff2e84d2666d4fcca8d66824b4568..4be53304733cbfda55e2e1ab97c4ef5be8951b2f 100644 --- a/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py @@ -131,6 +131,24 @@ class TestTopkOp5(TestTopkOp): self.outputs = {'Out': output, 'Indices': indices} +class TestTopkOp6(OpTest): + def init_args(self): + self.k = 100 + self.axis = 1 + self.largest = True + + def setUp(self): + self.op_type = "top_k_v2" + self.dtype = np.float64 + self.input_data = np.random.rand(80, 16384) + self.init_args() + self.inputs = {'X': self.input_data} + self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest} + output, indices = numpy_topk( + self.input_data, axis=self.axis, k=self.k, largest=self.largest) + self.outputs = {'Out': output, 'Indices': indices} + + class TestTopKAPI(unittest.TestCase): def setUp(self): np.random.seed(123)