From 8f5a93a07bb3e67615b547dac3cbe2c821d82dd3 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Fri, 6 Dec 2019 03:04:18 +0100 Subject: [PATCH] - Fix to regression in performance of ResNet-50 training (#21588) test=develop --- paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index de6dea91ea2..52391077eee 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -775,8 +775,23 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { * ('any') which lets a primitive (conv backward in this case) choose * the memory format preferred for best performance */ - auto chosen_memory_format = MKLDNNMemoryFormat::any; + + // TODO(jczaja): Once GRAD NHWC is working then format 'any' + // should be used exclusively. But till forward pass enforce + // NCHW for training we need to have NCHW here as well + // to avoid performance degradation in relu_grad and pool2d_grad + std::string data_format = ctx.Attr("data_format"); + auto chosen_memory_format = + platform::data_format_to_memory_format(data_format); + weights_format = MKLDNNMemoryFormat::any; + // Check the format for user's special output + if (chosen_memory_format != MKLDNNMemoryFormat::any) { + if (is_conv3d) { + chosen_memory_format = + platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format); + } + } auto src_md = platform::MKLDNNMemDesc( src_tz, platform::MKLDNNGetDataType(), chosen_memory_format); -- GitLab