未验证 提交 b8c6e180 编写于 作者: Z zhulei 提交者: GitHub

[ROCM] fix top_k_v2 with large shape (#33783)

* [ROCM] fix top_k_v2 with large shape

* [ROCM] fix top_k_v2 with large shape
上级 1ee237c1
...@@ -99,12 +99,21 @@ class TopkV2OpCUDAKernel : public framework::OpKernel<T> { ...@@ -99,12 +99,21 @@ class TopkV2OpCUDAKernel : public framework::OpKernel<T> {
const int kMaxHeight = 2048; const int kMaxHeight = 2048;
int gridx = input_height < kMaxHeight ? input_height : kMaxHeight; int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
switch (GetDesiredBlockDim(input_width)) { switch (GetDesiredBlockDim(input_width)) {
#ifdef PADDLE_WITH_HIP
FIXED_BLOCK_DIM(
KeMatrixTopK<T, 20,
kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
output_data, k, indices_data, input_data, input_width,
input_width, static_cast<int>(k), gridx, input_height,
largest));
#else
FIXED_BLOCK_DIM( FIXED_BLOCK_DIM(
KeMatrixTopK<T, 5, KeMatrixTopK<T, 5,
kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>( kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
output_data, k, indices_data, input_data, input_width, output_data, k, indices_data, input_data, input_width,
input_width, static_cast<int>(k), gridx, input_height, input_width, static_cast<int>(k), gridx, input_height,
largest)); largest));
#endif
default: default:
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
"the input data shape has error in the topk cuda kernel.")); "the input data shape has error in the topk cuda kernel."));
...@@ -169,12 +178,21 @@ class TopkV2OpCUDAKernel : public framework::OpKernel<T> { ...@@ -169,12 +178,21 @@ class TopkV2OpCUDAKernel : public framework::OpKernel<T> {
const int kMaxHeight = 2048; const int kMaxHeight = 2048;
int gridx = input_height < kMaxHeight ? input_height : kMaxHeight; int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
switch (GetDesiredBlockDim(input_width)) { switch (GetDesiredBlockDim(input_width)) {
#ifdef PADDLE_WITH_HIP
FIXED_BLOCK_DIM(
KeMatrixTopK<T, 20,
kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
trans_out.data<T>(), k, trans_ind.data<int64_t>(),
trans_input.data<T>(), input_width, input_width,
static_cast<int>(k), gridx, input_height, largest));
#else
FIXED_BLOCK_DIM( FIXED_BLOCK_DIM(
KeMatrixTopK<T, 5, KeMatrixTopK<T, 5,
kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>( kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
trans_out.data<T>(), k, trans_ind.data<int64_t>(), trans_out.data<T>(), k, trans_ind.data<int64_t>(),
trans_input.data<T>(), input_width, input_width, trans_input.data<T>(), input_width, input_width,
static_cast<int>(k), gridx, input_height, largest)); static_cast<int>(k), gridx, input_height, largest));
#endif
default: default:
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
"the input data shape has error in the topk cuda kernel.")); "the input data shape has error in the topk cuda kernel."));
......
...@@ -131,6 +131,24 @@ class TestTopkOp5(TestTopkOp): ...@@ -131,6 +131,24 @@ class TestTopkOp5(TestTopkOp):
self.outputs = {'Out': output, 'Indices': indices} self.outputs = {'Out': output, 'Indices': indices}
class TestTopkOp6(OpTest):
    """Large-shape regression case for top_k_v2: an 80 x 16384 input,
    exercising the wide-row code path the ROCM kernel fix targets."""

    def init_args(self):
        # Take the top 100 entries along axis 1, largest-first.
        self.k = 100
        self.axis = 1
        self.largest = True

    def setUp(self):
        # Build the op inputs/attributes and a NumPy reference result
        # (numpy_topk is the module-level reference helper).
        self.op_type = "top_k_v2"
        self.dtype = np.float64
        self.input_data = np.random.rand(80, 16384)
        self.init_args()
        self.inputs = {'X': self.input_data}
        self.attrs = {
            'k': self.k,
            'axis': self.axis,
            'largest': self.largest,
        }
        ref_out, ref_idx = numpy_topk(
            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
        self.outputs = {'Out': ref_out, 'Indices': ref_idx}
class TestTopKAPI(unittest.TestCase): class TestTopKAPI(unittest.TestCase):
def setUp(self): def setUp(self):
np.random.seed(123) np.random.seed(123)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册