diff --git a/paddle/fluid/platform/cuda_device_function.h b/paddle/fluid/platform/cuda_device_function.h index dde9531e59144218c91d789a8fe668d3fffb70f2..5a86bb46e6ac4de6c644326c696a6ddff9ffe801 100644 --- a/paddle/fluid/platform/cuda_device_function.h +++ b/paddle/fluid/platform/cuda_device_function.h @@ -32,6 +32,7 @@ namespace platform { #endif inline static int RoundToPowerOfTwo(int dim) { +#ifdef PADDLE_WITH_CUDA if (dim > 512) { return 1024; } else if (dim > 256) { @@ -45,6 +46,17 @@ inline static int RoundToPowerOfTwo(int dim) { } else { return 32; } +#else // HIP results in error or nan if > 256 + if (dim > 128) { + return 256; + } else if (dim > 64) { + return 128; + } else if (dim > 32) { + return 64; + } else { + return 32; + } +#endif } #define CUDA_LAUNCH_KERNEL_BASE(dim, ...) \