未验证 提交 a4b30a12 编写于 作者: Q Qi Li 提交者: GitHub

[ROCM] fix depthwise conv failure on ROCM, test=develop (#31998)

上级 68e7de26
...@@ -613,6 +613,9 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T, ...@@ -613,6 +613,9 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
thread = (output_width - 1) / 2 + 1; thread = (output_width - 1) / 2 + 1;
else if (output_width > 512 && output_width <= 1024) else if (output_width > 512 && output_width <= 1024)
thread = output_width; thread = output_width;
#ifdef __HIPCC__
thread = std::min(thread, 256);
#endif
int blocks = std::min(std::max(thread / output_width, 1), output_height); int blocks = std::min(std::max(thread / output_width, 1), output_height);
dim3 threads(std::min(output_width, thread), blocks, 1); dim3 threads(std::min(output_width, thread), blocks, 1);
dim3 grid(output_channels, batch_size, 1); dim3 grid(output_channels, batch_size, 1);
...@@ -620,7 +623,13 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T, ...@@ -620,7 +623,13 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
int nums_output = int nums_output =
batch_size * output_channels * output_height * output_width; batch_size * output_channels * output_height * output_width;
#ifdef __HIPCC__
int block_size = 256;
int grid_size = std::min((nums_output + block_size - 1) / block_size, 256);
#else
int block_size = 512; int block_size = 512;
int grid_size = (nums_output + block_size - 1) / block_size;
#endif
#define check_case(c_filter_multiplier, c_stride, c_filter) \ #define check_case(c_filter_multiplier, c_stride, c_filter) \
if (c_filter_multiplier == 0 || \ if (c_filter_multiplier == 0 || \
...@@ -630,7 +639,7 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T, ...@@ -630,7 +639,7 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
c_filter == -1)) { \ c_filter == -1)) { \
if (c_filter == -1) { \ if (c_filter == -1) { \
threads.x = block_size; \ threads.x = block_size; \
grid.x = (nums_output + block_size - 1) / block_size; \ grid.x = grid_size; \
threads.y = threads.z = grid.y = grid.z = 1; \ threads.y = threads.z = grid.y = grid.z = 1; \
} \ } \
KernelDepthwiseConvSp< \ KernelDepthwiseConvSp< \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册