未验证 提交 a579e523 编写于 作者: J Jacek Czaja 提交者: GitHub

Requantize to use Memory Desc in Tensors (#46608)

* - some more MD changes

* - lint

* - compilation fixes

* - compilation fixes

* - lint

* - fix
上级 ecae7b31
......@@ -202,8 +202,6 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout,
platform::MatchShapeToLayout(out, in_layout, out_layout);
out->set_layout(DataLayout::kNCHW);
// reset format since the out tensor will be feed to non-MKLDNN OPkernel
out->set_format(MKLDNNMemoryFormat::undef);
}
#endif
......
......@@ -471,7 +471,9 @@ void TensorCopySync(const phi::DenseTensor& src,
dst->Resize(src.dims());
dst->set_layout(src.layout());
#ifdef PADDLE_WITH_MKLDNN
dst->set_format(src.format());
if (src.layout() == DataLayout::kMKLDNN) {
dst->set_mem_desc(src.mem_desc());
}
#endif
auto src_place = src.place();
auto src_ptr = src.data();
......
......@@ -146,8 +146,7 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
reorder_p->execute(astream, *src_memory, *dst_memory);
astream.wait();
output->set_layout(framework::DataLayout::kMKLDNN);
output->set_format(platform::GetMKLDNNFormat(*dst_memory));
output->set_mem_desc(dst_memory->get_desc());
}
};
......
......@@ -93,8 +93,13 @@ class TransferLayoutFunctor {
paddle::platform::MKLDNNDeviceContext::tls()
.set_cur_paddle_data_layout(in_layout);
}
out_tensor.set_layout(DataLayout::kMKLDNN);
out_tensor.set_format(out_format);
auto out_tz = phi::vectorize<int64_t>(out_tensor.dims());
dnnl::memory::data_type in_type = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(in_tensor.dtype()));
dnnl::memory::desc out_mem_desc(out_tz, in_type, out_format);
out_tensor.set_mem_desc(out_mem_desc);
} else {
auto target_layout = paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout();
......
......@@ -115,8 +115,6 @@ void innerTransDataLayoutFromOneDNN(DataLayout in_layout,
out->set_layout(DataLayout::kNCHW);
VLOG(10) << "out->layout: " << out->layout() << " in->dims: " << in.dims()
<< " out->dims: " << out->dims();
// reset format since the out tensor will be feed to non-MKLDNN OPkernel
out->set_format(OneDNNMemoryFormat::undef);
}
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册