From f4a6c53978db4cfa399f028788e2d8774dc04b6a Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Mon, 26 Sep 2022 17:12:18 +0200 Subject: [PATCH] Conv grad to use set_mem_desc() (#46459) * - Conv grad changed for MD * - lint * - compilation fix * yet another lint --- .../fluid/operators/mkldnn/conv_mkldnn_op.cc | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index fc8f2991309..09cdd70a661 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -1048,14 +1048,12 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel { {DNNL_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}}); astream.wait(); - filter_grad->set_layout(framework::DataLayout::kMKLDNN); - // in OneDNN groups in convolution are treated as separate dimension - // which is not the case in paddlepaddle - auto filter_fmt = platform::GetMKLDNNFormat(*diff_weights_memory_p); - // For convolution with groups convert from blocked to NCHW // otherwise there will be problems in next operators working on this data if (g > 1) { + // in OneDNN groups in convolution are treated as separate dimension + // which is not the case in paddlepaddle + dnnl::memory::data_type in_type = framework::ToMKLDNNDataType( framework::TransToProtoVarType(filter->dtype())); // for 3d conv with groups (six dimensional data reorder to goidhw) @@ -1094,9 +1092,12 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel { dnnl::memory::format_tag target_format = weights_tz.size() == 6 ? dnnl::memory::format_tag::oidhw : dnnl::memory::format_tag::oihw; - filter_grad->set_format(target_format); + filter_grad->set_mem_desc( + dnnl::memory::desc(phi::vectorize(filter_grad->dims()), + in_type, + target_format)); } else { - filter_grad->set_format(filter_fmt); + filter_grad->set_mem_desc(diff_weights_memory_p->get_desc()); } } if (input_grad) { @@ -1119,8 +1120,7 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel { {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}}); astream.wait(); - input_grad->set_layout(framework::DataLayout::kMKLDNN); - input_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory_p)); + input_grad->set_mem_desc(diff_src_memory_p->get_desc()); } } }; -- GitLab