未验证 提交 36791fdd 编写于 作者: R ronnywang 提交者: GitHub

[ROCM] bugfix for arg_min_max (#36098)

上级 97d30602
...@@ -89,22 +89,25 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input, ...@@ -89,22 +89,25 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
const int64_t n) { const int64_t n) {
auto cu_stream = ctx.stream(); auto cu_stream = ctx.stream();
// Picks a block size from the column count `col`.
//
// Returns the smallest power of two in [8, 1024] that is >= `col`; any
// `col` <= 8 yields 8 and any `col` > 512 yields 1024. When compiled with
// __HIPCC__ defined, the result is additionally capped at 256 (this cap is
// the change introduced by the "[ROCM] bugfix for arg_min_max" commit).
// NOTE(review): presumably the value is used as the thread-block size of a
// kernel launch further down in ComputeFullArg -- confirm against the part
// of the function not shown here.
auto ComputeBlockSize = [](int64_t col) {
  // Default covers every col <= 8.
  auto block_size = 8;
  if (col > 512)
    block_size = 1024;
  else if (col > 256)
    block_size = 512;
  else if (col > 128)
    block_size = 256;
  else if (col > 64)
    block_size = 128;
  else if (col > 32)
    block_size = 64;
  else if (col > 16)
    block_size = 32;
  else if (col > 8)
    block_size = 16;
#ifdef __HIPCC__
  // HIP builds never use more than 256, regardless of col.
  block_size = std::min(block_size, 256);
#endif
  return block_size;
};
int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize().x; int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize().x;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册