diff --git a/paddle/fluid/operators/fused/conv_fusion_op.cu b/paddle/fluid/operators/fused/conv_fusion_op.cu index 2ee63c93642218fddc331f4205d8b07575e921ec..121cbc909b8129c3ceec66dbe75c8eedbf10c296 100644 --- a/paddle/fluid/operators/fused/conv_fusion_op.cu +++ b/paddle/fluid/operators/fused/conv_fusion_op.cu @@ -315,9 +315,14 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { cudnnConvolutionFwdAlgo_t algo; auto handle = dev_ctx.cudnn_handle(); auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + auto dtype = platform::CudnnDataType::type; PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( cudnn_conv_desc, CUDNN_DEFAULT_MATH)); + if (dtype == CUDNN_DATA_HALF) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( + cudnn_conv_desc, CUDNN_TENSOR_OP_MATH)); + } #if CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000 if (!platform::allow_tf32_cudnn) { PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( @@ -414,7 +419,6 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { algo = algo_cache.GetAlgorithm( x_dims[2] * x_dims[3], search_times, 0, search_func); } else { - auto dtype = platform::CudnnDataType::type; algo = algo_cache.GetAlgorithm(x_dims, f_dims, strides,