diff --git a/paddle/fluid/platform/cuda_device_function.h b/paddle/fluid/platform/cuda_device_function.h index 4095720f71eb7185c474934231220b917a770375..352143302388a9f8169a40a14ccea9bae647cfc6 100644 --- a/paddle/fluid/platform/cuda_device_function.h +++ b/paddle/fluid/platform/cuda_device_function.h @@ -31,6 +31,7 @@ namespace platform { #endif inline static int RoundToPowerOfTwo(int dim) { +#ifdef PADDLE_WITH_CUDA if (dim > 512) { return 1024; } else if (dim > 256) { @@ -44,6 +45,17 @@ inline static int RoundToPowerOfTwo(int dim) { } else { return 32; } +#else // HIP results in an error or NaN if > 256 + if (dim > 128) { + return 256; + } else if (dim > 64) { + return 128; + } else if (dim > 32) { + return 64; + } else { + return 32; + } +#endif } #define CUDA_LAUNCH_KERNEL_BASE(dim, ...) \