From 36791fddea73f23337d5a6cf77441af0507fce09 Mon Sep 17 00:00:00 2001 From: ronnywang <524019753@qq.com> Date: Tue, 28 Sep 2021 16:18:01 +0800 Subject: [PATCH] [ROCM] bugfix for arg_min_max (#36098) --- .../fluid/operators/arg_min_max_op_base.cu.h | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/arg_min_max_op_base.cu.h b/paddle/fluid/operators/arg_min_max_op_base.cu.h index b19ba1e1590..2c34d6f8300 100644 --- a/paddle/fluid/operators/arg_min_max_op_base.cu.h +++ b/paddle/fluid/operators/arg_min_max_op_base.cu.h @@ -89,22 +89,25 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input, const int64_t n) { auto cu_stream = ctx.stream(); auto ComputeBlockSize = [](int64_t col) { + auto block_size = 8; if (col > 512) - return 1024; + block_size = 1024; else if (col > 256) - return 512; + block_size = 512; else if (col > 128) - return 256; + block_size = 256; else if (col > 64) - return 128; + block_size = 128; else if (col > 32) - return 64; + block_size = 64; else if (col > 16) - return 32; + block_size = 32; else if (col > 8) - return 16; - else - return 8; + block_size = 16; +#ifdef __HIPCC__ + block_size = std::min(block_size, 256); +#endif + return block_size; }; int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize().x; -- GitLab