diff --git a/paddle/fluid/operators/fused/conv_fusion_op.cu b/paddle/fluid/operators/fused/conv_fusion_op.cu index 33d408582ff48504ed7fce2950934fcd43cabc90..c9ba7a61e0907f53888b7088a1fa203d10c569e0 100644 --- a/paddle/fluid/operators/fused/conv_fusion_op.cu +++ b/paddle/fluid/operators/fused/conv_fusion_op.cu @@ -200,13 +200,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( cudnn_conv_desc, CUDNN_DEFAULT_MATH)); -#if CUDNN_VERSION >= 11000 +#if CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000 if (!platform::allow_tf32_cudnn) { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnSetConvolutionMathType(cudnn_conv_desc, CUDNN_FMA_MATH)); } -#endif // CUDA_VERSION >= 11000 +#endif // CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000 auto x_dims = framework::vectorize(transformed_input.dims()); auto f_dims = framework::vectorize(filter->dims()); diff --git a/paddle/fluid/operators/fused/fusion_conv_inception_op.cu b/paddle/fluid/operators/fused/fusion_conv_inception_op.cu index c448c529f569158835020eec78d9092845247cdc..b3796f1df5fdf207b26bebfd89704d4387f0d256 100644 --- a/paddle/fluid/operators/fused/fusion_conv_inception_op.cu +++ b/paddle/fluid/operators/fused/fusion_conv_inception_op.cu @@ -153,13 +153,13 @@ class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnSetConvolutionMathType(conv_desc[i], CUDNN_DEFAULT_MATH)); -#if CUDNN_VERSION >= 11000 +#if CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000 if (!platform::allow_tf32_cudnn) { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnSetConvolutionMathType(conv_desc[i], CUDNN_FMA_MATH)); } -#endif // CUDA_VERSION >= 11000 +#endif // CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000 } in_dims[2][1] *= 2; in_strides[2][0] = oc * h * w; diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index face048f28e8340b9f6e2b3c3d0fcbc4552eea52..2578c9b6cdea5a48967b39f2e4cfbf569680b282 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -85,6 +85,7 @@ namespace platform { void SetAllowTF32Cublas(bool active); /*Get the global variable allow_tf32_cublas value*/ bool AllowTF32Cublas(); +extern bool allow_tf32_cudnn; /*Set the value of the global variable allow_tf32_cudnn*/ void SetAllowTF32Cudnn(bool active); /*Get the global variable allow_tf32_cudnn value*/