未验证 提交 3f366fee 编写于 作者: Q Qi Li 提交者: GitHub

[ROCM] fix fused_fc_elementwise_layernorm, test=develop (#33281)

上级 ae93d9c2
...@@ -31,6 +31,7 @@ namespace platform { ...@@ -31,6 +31,7 @@ namespace platform {
#endif #endif
inline static int RoundToPowerOfTwo(int dim) { inline static int RoundToPowerOfTwo(int dim) {
#ifdef PADDLE_WITH_CUDA
if (dim > 512) { if (dim > 512) {
return 1024; return 1024;
} else if (dim > 256) { } else if (dim > 256) {
...@@ -44,6 +45,17 @@ inline static int RoundToPowerOfTwo(int dim) { ...@@ -44,6 +45,17 @@ inline static int RoundToPowerOfTwo(int dim) {
} else { } else {
return 32; return 32;
} }
#else // HIP results in error or nan if > 256
if (dim > 128) {
return 256;
} else if (dim > 64) {
return 128;
} else if (dim > 32) {
return 64;
} else {
return 32;
}
#endif
} }
#define CUDA_LAUNCH_KERNEL_BASE(dim, ...) \ #define CUDA_LAUNCH_KERNEL_BASE(dim, ...) \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册