From dc229b4195052c67ceb28cba4859f10d7a49aa64 Mon Sep 17 00:00:00 2001 From: lvmengsi Date: Fri, 18 Oct 2019 22:57:42 +0800 Subject: [PATCH] fix_depthwise_conv_cudnn, test=develop (#20712) --- paddle/fluid/operators/conv_cudnn_op.cu | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/paddle/fluid/operators/conv_cudnn_op.cu b/paddle/fluid/operators/conv_cudnn_op.cu index 274da9abf0..c6af6da582 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu +++ b/paddle/fluid/operators/conv_cudnn_op.cu @@ -265,6 +265,16 @@ class CUDNNConvOpKernel : public framework::OpKernel { algo = search::Find(args, exhaustive_search, false, 0, ctx); workspace_size = search::GetWorkspaceSize(args, algo); +#if CUDNN_VERSION_MIN(7, 0, 1) + // when groups > 1, SearchAlgorithm find algo is CUDNN_CONVOLUTION_\ + // FWD_ALGO_WINOGRAD_NONFUSED, but this kind of algorithm is unstable + // in forward computation, so change the algorithm to CUDNN_CONVOLUTION_\ + // FWD_ALGO_IMPLICIT_GEMM manually. + if (ctx.Attr("groups") > 1) { + algo = static_cast(0); + } +#endif + // ------------------- cudnn conv forward --------------------- ScalingParamType alpha = 1.0f, beta = 0.0f; for (int i = 0; i < groups; i++) { @@ -805,6 +815,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { #if CUDNN_VERSION_MIN(7, 0, 1) iwo_group = 1; c_group = groups; + groups = 1; #endif auto dtype = platform::CudnnDataType::type; -- GitLab