未验证 提交 9900b42b 作者: xiaoxiaohehe001 提交者: GitHub

conv_fusion_fp16 (#44173)

上级 2fc93f39
......@@ -315,9 +315,14 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
cudnnConvolutionFwdAlgo_t algo;
auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto dtype = platform::CudnnDataType<T>::type;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_DEFAULT_MATH));
if (dtype == CUDNN_DATA_HALF) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_TENSOR_OP_MATH));
}
#if CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000
if (!platform::allow_tf32_cudnn) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
......@@ -414,7 +419,6 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
algo = algo_cache.GetAlgorithm(
x_dims[2] * x_dims[3], search_times, 0, search_func);
} else {
auto dtype = platform::CudnnDataType<T>::type;
algo = algo_cache.GetAlgorithm(x_dims,
f_dims,
strides,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册