diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
index de6dea91ea20d453cc50ab55c5e9be60c99c652e..52391077eee54c0004f663eb3f7ba664e5796a0e 100644
--- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -775,8 +775,23 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
      * ('any') which lets a primitive (conv backward in this case) choose
      * the memory format preferred for best performance
      */
-    auto chosen_memory_format = MKLDNNMemoryFormat::any;
+
+    // TODO(jczaja): Once GRAD NHWC is working then format 'any'
+    // should be used exclusively. But till the forward pass enforces
+    // NCHW for training we need to have NCHW here as well
+    // to avoid performance degradation in relu_grad and pool2d_grad
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    auto chosen_memory_format =
+        platform::data_format_to_memory_format(data_format);
+    weights_format = MKLDNNMemoryFormat::any;
+    // Check the format for user's special output
+    if (chosen_memory_format != MKLDNNMemoryFormat::any) {
+      if (is_conv3d) {
+        chosen_memory_format =
+            platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
+      }
+    }
 
     auto src_md = platform::MKLDNNMemDesc(
         src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
 
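
For context, the hunk replaces a hard-coded `MKLDNNMemoryFormat::any` with a format derived from the op's `data_format` attribute, specializing it for 5-D (conv3d) tensors only when a concrete layout was requested. Below is a minimal, self-contained sketch of that control flow. The `MemoryFormat` enum and both helper bodies are assumptions standing in for Paddle's `MKLDNNMemoryFormat`, `platform::data_format_to_memory_format`, and `platform::MKLDNNFormatForSize`; only the call pattern mirrors the diff.

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for MKLDNNMemoryFormat (an alias of
// mkldnn::memory::format in the real codebase).
enum class MemoryFormat { any, nchw, nhwc, ncdhw };

// Assumed behavior of platform::data_format_to_memory_format: map the
// op's "data_format" attribute to a concrete layout, or to 'any' when
// no fixed layout is requested.
MemoryFormat data_format_to_memory_format(const std::string& data_format) {
  if (data_format == "NCHW") return MemoryFormat::nchw;
  if (data_format == "NHWC") return MemoryFormat::nhwc;
  return MemoryFormat::any;  // e.g. "AnyLayout": let the primitive choose
}

// Assumed behavior of platform::MKLDNNFormatForSize: promote a 4-D
// format to its 5-D counterpart when the tensor rank indicates conv3d.
MemoryFormat MKLDNNFormatForSize(std::size_t rank, MemoryFormat fmt) {
  if (rank == 5 && fmt == MemoryFormat::nchw) return MemoryFormat::ncdhw;
  return fmt;
}

int main() {
  // N, C, D, H, W: a 5-D source tensor, so is_conv3d would be true.
  std::vector<int64_t> src_tz = {8, 3, 16, 56, 56};
  bool is_conv3d = src_tz.size() == 5;
  std::string data_format = "NCHW";  // the op attribute read via ctx.Attr

  // Same flow as the hunk: derive the format from the attribute, then
  // only adjust it for conv3d when a concrete layout was chosen.
  MemoryFormat chosen_memory_format = data_format_to_memory_format(data_format);
  if (chosen_memory_format != MemoryFormat::any) {
    if (is_conv3d) {
      chosen_memory_format =
          MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
    }
  }
  std::cout << static_cast<int>(chosen_memory_format) << "\n";  // ncdhw
}
```

The guard around the conv3d adjustment matters: when the attribute resolves to `any`, the format is left untouched so the backward primitive can still pick whatever layout performs best, and only an explicitly requested NCHW layout is promoted to NCDHW for conv3d.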