diff --git a/paddle/fluid/operators/arg_min_max_op_base.cu.h b/paddle/fluid/operators/arg_min_max_op_base.cu.h index b19ba1e1590fe1a732043f1eb9e250c2381c2133..2c34d6f8300a7476992ed8645012b685413e9e8a 100644 --- a/paddle/fluid/operators/arg_min_max_op_base.cu.h +++ b/paddle/fluid/operators/arg_min_max_op_base.cu.h @@ -89,22 +89,25 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input, const int64_t n) { auto cu_stream = ctx.stream(); auto ComputeBlockSize = [](int64_t col) { + auto block_size = 8; if (col > 512) - return 1024; + block_size = 1024; else if (col > 256) - return 512; + block_size = 512; else if (col > 128) - return 256; + block_size = 256; else if (col > 64) - return 128; + block_size = 128; else if (col > 32) - return 64; + block_size = 64; else if (col > 16) - return 32; + block_size = 32; else if (col > 8) - return 16; - else - return 8; + block_size = 16; +#ifdef __HIPCC__ + block_size = std::min(block_size, 256); +#endif + return block_size; }; int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize().x;