未验证 提交 40a29186 编写于 作者: R ronnywang 提交者: GitHub

[ROCM] fixbug for arg_min_max (#36113)

ATT (as the title says); cherry-pick of #36098
上级 fe5cddf2
......@@ -89,22 +89,25 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
                     const int64_t n) {
   auto cu_stream = ctx.stream();
   auto ComputeBlockSize = [](int64_t col) {
+    auto block_size = 8;
     if (col > 512)
-      return 1024;
+      block_size = 1024;
     else if (col > 256)
-      return 512;
+      block_size = 512;
     else if (col > 128)
-      return 256;
+      block_size = 256;
     else if (col > 64)
-      return 128;
+      block_size = 128;
     else if (col > 32)
-      return 64;
+      block_size = 64;
     else if (col > 16)
-      return 32;
+      block_size = 32;
     else if (col > 8)
-      return 16;
-    else
-      return 8;
+      block_size = 16;
+#ifdef __HIPCC__
+    block_size = std::min(block_size, 256);
+#endif
+    return block_size;
   };
   int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize().x;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册