diff --git a/paddle/fluid/framework/data_layout_transform.cc b/paddle/fluid/framework/data_layout_transform.cc
index 254a7abd66db57044def325170a138b8d4fae9c7..1594272fc5b5e491985da7524a9363967e9484c3 100644
--- a/paddle/fluid/framework/data_layout_transform.cc
+++ b/paddle/fluid/framework/data_layout_transform.cc
@@ -202,8 +202,6 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout,
     platform::MatchShapeToLayout(out, in_layout, out_layout);
     out->set_layout(DataLayout::kNCHW);
-    // reset format since the out tensor will be feed to non-MKLDNN OPkernel
-    out->set_format(MKLDNNMemoryFormat::undef);
 }
 #endif
diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc
index efc7f685bc90b2e9713e780475650c6d4cf6e866..0d8ed3c4eb118ec7f3b1a616c14c27990b302495 100644
--- a/paddle/fluid/framework/tensor_util.cc
+++ b/paddle/fluid/framework/tensor_util.cc
@@ -471,7 +471,9 @@ void TensorCopySync(const phi::DenseTensor& src,
   dst->Resize(src.dims());
   dst->set_layout(src.layout());
 #ifdef PADDLE_WITH_MKLDNN
-  dst->set_format(src.format());
+  if (src.layout() == DataLayout::kMKLDNN) {
+    dst->set_mem_desc(src.mem_desc());
+  }
 #endif
   auto src_place = src.place();
   auto src_ptr = src.data();
diff --git a/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc
index abfef00ae1678d61cbf150ca2b954f2be6588c36..23409db02bec9064a324ad2c6fc779e0888bdf9a 100644
--- a/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc
@@ -146,8 +146,7 @@ class ReQuantOpKernel : public framework::OpKernel {
     reorder_p->execute(astream, *src_memory, *dst_memory);
     astream.wait();
-    output->set_layout(framework::DataLayout::kMKLDNN);
-    output->set_format(platform::GetMKLDNNFormat(*dst_memory));
+    output->set_mem_desc(dst_memory->get_desc());
   }
 };
diff --git a/paddle/fluid/operators/transfer_layout_op.h b/paddle/fluid/operators/transfer_layout_op.h
index a4c7b482ff5968dc268d70949d07fc4d7fe96e60..aee8592842846043b3848a525a0ebd4e641f6283 100644
--- a/paddle/fluid/operators/transfer_layout_op.h
+++ b/paddle/fluid/operators/transfer_layout_op.h
@@ -93,8 +93,13 @@ class TransferLayoutFunctor {
         paddle::platform::MKLDNNDeviceContext::tls()
             .set_cur_paddle_data_layout(in_layout);
       }
-      out_tensor.set_layout(DataLayout::kMKLDNN);
-      out_tensor.set_format(out_format);
+
+      auto out_tz = phi::vectorize(out_tensor.dims());
+      dnnl::memory::data_type in_type = framework::ToMKLDNNDataType(
+          framework::TransToProtoVarType(in_tensor.dtype()));
+
+      dnnl::memory::desc out_mem_desc(out_tz, in_type, out_format);
+      out_tensor.set_mem_desc(out_mem_desc);
     } else {
       auto target_layout = paddle::platform::MKLDNNDeviceContext::tls()
                                .get_cur_paddle_data_layout();
diff --git a/paddle/phi/kernels/funcs/data_layout_transform.cc b/paddle/phi/kernels/funcs/data_layout_transform.cc
index 9d2d0bf3b5c8897fe62e80c8697a53260552d0e8..767566cef2ff9452ecf9d4da4f24329dca91d0bd 100644
--- a/paddle/phi/kernels/funcs/data_layout_transform.cc
+++ b/paddle/phi/kernels/funcs/data_layout_transform.cc
@@ -115,8 +115,6 @@ void innerTransDataLayoutFromOneDNN(DataLayout in_layout,
     out->set_layout(DataLayout::kNCHW);
     VLOG(10) << "out->layout: " << out->layout() << " in->dims: " << in.dims()
              << " out->dims: " << out->dims();
-    // reset format since the out tensor will be feed to non-MKLDNN OPkernel
-    out->set_format(OneDNNMemoryFormat::undef);
 }
 #endif
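
For context: the change in transfer_layout_op.h replaces the old set_layout()/set_format() pair with a single dnnl::memory::desc attached via set_mem_desc(). Below is a minimal standalone sketch, using only the public oneDNN C++ API (not Paddle's DenseTensor), of how such a descriptor is built from dims, a data type, and a format tag; the dims and format tag are illustrative assumptions, not values taken from the diff.

```cpp
// Minimal sketch: build a dnnl::memory::desc the same way the new
// transfer_layout_op.h code does (dims + data type + format tag).
// Requires oneDNN (dnnl.hpp); dims/format below are hypothetical.
#include <iostream>
#include "dnnl.hpp"

int main() {
  dnnl::memory::dims out_tz = {8, 3, 224, 224};  // logical NCHW dims

  // The descriptor records dims, dtype and physical layout in one object,
  // so a tensor no longer needs a separate layout/format pair.
  auto out_mem_desc = dnnl::memory::desc(
      out_tz, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nhwc);

  std::cout << "descriptor size in bytes: " << out_mem_desc.get_size()
            << std::endl;
  return 0;
}
```

Because the descriptor carries shape, dtype and layout together, consumers only need to propagate it as a whole; this is why TensorCopySync above now copies src.mem_desc() when the source layout is kMKLDNN instead of copying a bare format enum.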