未验证 提交 f8b8811d 编写于 作者: L lvmengsi 提交者: GitHub

fix_depthwise_conv_cudnn, test=develop (#20712) (#20727)

上级 5baf1b23
......@@ -265,6 +265,16 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
algo = search::Find<T>(args, exhaustive_search, false, 0, ctx);
workspace_size = search::GetWorkspaceSize(args, algo);
#if CUDNN_VERSION_MIN(7, 0, 1)
// when groups > 1, SearchAlgorithm find algo is CUDNN_CONVOLUTION_\
// FWD_ALGO_WINOGRAD_NONFUSED, but this kind of algorithm is unstable
// in forward computation, so change the algorithm to CUDNN_CONVOLUTION_\
// FWD_ALGO_IMPLICIT_GEMM manually.
if (ctx.Attr<int>("groups") > 1) {
algo = static_cast<cudnnConvolutionFwdAlgo_t>(0);
}
#endif
// ------------------- cudnn conv forward ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
for (int i = 0; i < groups; i++) {
......@@ -805,6 +815,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
#if CUDNN_VERSION_MIN(7, 0, 1)
iwo_group = 1;
c_group = groups;
groups = 1;
#endif
auto dtype = platform::CudnnDataType<T>::type;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册