未验证 提交 a579e523 编写于 作者: J Jacek Czaja 提交者: GitHub

Requantize to use Memory Desc in Tensors (#46608)

* - some more MD changes

* - lint

* - compilation fixes

* - compilation fixes

* - lint

* - fix
上级 ecae7b31
...@@ -202,8 +202,6 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, ...@@ -202,8 +202,6 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout,
platform::MatchShapeToLayout(out, in_layout, out_layout); platform::MatchShapeToLayout(out, in_layout, out_layout);
out->set_layout(DataLayout::kNCHW); out->set_layout(DataLayout::kNCHW);
// reset format since the out tensor will be feed to non-MKLDNN OPkernel
out->set_format(MKLDNNMemoryFormat::undef);
} }
#endif #endif
......
...@@ -471,7 +471,9 @@ void TensorCopySync(const phi::DenseTensor& src, ...@@ -471,7 +471,9 @@ void TensorCopySync(const phi::DenseTensor& src,
dst->Resize(src.dims()); dst->Resize(src.dims());
dst->set_layout(src.layout()); dst->set_layout(src.layout());
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
dst->set_format(src.format()); if (src.layout() == DataLayout::kMKLDNN) {
dst->set_mem_desc(src.mem_desc());
}
#endif #endif
auto src_place = src.place(); auto src_place = src.place();
auto src_ptr = src.data(); auto src_ptr = src.data();
......
...@@ -146,8 +146,7 @@ class ReQuantOpKernel : public framework::OpKernel<T> { ...@@ -146,8 +146,7 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
reorder_p->execute(astream, *src_memory, *dst_memory); reorder_p->execute(astream, *src_memory, *dst_memory);
astream.wait(); astream.wait();
output->set_layout(framework::DataLayout::kMKLDNN); output->set_mem_desc(dst_memory->get_desc());
output->set_format(platform::GetMKLDNNFormat(*dst_memory));
} }
}; };
......
...@@ -93,8 +93,13 @@ class TransferLayoutFunctor { ...@@ -93,8 +93,13 @@ class TransferLayoutFunctor {
paddle::platform::MKLDNNDeviceContext::tls() paddle::platform::MKLDNNDeviceContext::tls()
.set_cur_paddle_data_layout(in_layout); .set_cur_paddle_data_layout(in_layout);
} }
out_tensor.set_layout(DataLayout::kMKLDNN);
out_tensor.set_format(out_format); auto out_tz = phi::vectorize<int64_t>(out_tensor.dims());
dnnl::memory::data_type in_type = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(in_tensor.dtype()));
dnnl::memory::desc out_mem_desc(out_tz, in_type, out_format);
out_tensor.set_mem_desc(out_mem_desc);
} else { } else {
auto target_layout = paddle::platform::MKLDNNDeviceContext::tls() auto target_layout = paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout(); .get_cur_paddle_data_layout();
......
...@@ -115,8 +115,6 @@ void innerTransDataLayoutFromOneDNN(DataLayout in_layout, ...@@ -115,8 +115,6 @@ void innerTransDataLayoutFromOneDNN(DataLayout in_layout,
out->set_layout(DataLayout::kNCHW); out->set_layout(DataLayout::kNCHW);
VLOG(10) << "out->layout: " << out->layout() << " in->dims: " << in.dims() VLOG(10) << "out->layout: " << out->layout() << " in->dims: " << in.dims()
<< " out->dims: " << out->dims(); << " out->dims: " << out->dims();
// reset format since the out tensor will be feed to non-MKLDNN OPkernel
out->set_format(OneDNNMemoryFormat::undef);
} }
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册