From 9900b42bdf5014d283cfdd05b34f8832068f1831 Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com> Date: Fri, 8 Jul 2022 17:17:23 +0800 Subject: [PATCH] conv_fusion_fp16 (#44173) --- paddle/fluid/operators/fused/conv_fusion_op.cu | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/fused/conv_fusion_op.cu b/paddle/fluid/operators/fused/conv_fusion_op.cu index 2ee63c93642..121cbc909b8 100644 --- a/paddle/fluid/operators/fused/conv_fusion_op.cu +++ b/paddle/fluid/operators/fused/conv_fusion_op.cu @@ -315,9 +315,14 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { cudnnConvolutionFwdAlgo_t algo; auto handle = dev_ctx.cudnn_handle(); auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + auto dtype = platform::CudnnDataType::type; PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( cudnn_conv_desc, CUDNN_DEFAULT_MATH)); + if (dtype == CUDNN_DATA_HALF) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( + cudnn_conv_desc, CUDNN_TENSOR_OP_MATH)); + } #if CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000 if (!platform::allow_tf32_cudnn) { PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( @@ -414,7 +419,6 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { algo = algo_cache.GetAlgorithm( x_dims[2] * x_dims[3], search_times, 0, search_func); } else { - auto dtype = platform::CudnnDataType::type; algo = algo_cache.GetAlgorithm(x_dims, f_dims, strides, -- GitLab