未验证 提交 7b33720c 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Tensor copy fix to oneDNN tensors (#29771)

* - Tensor copy fix to oneDNN tensors

* - Fixes after review
上级 a400b76d
......@@ -43,20 +43,32 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
dst->Resize(src.dims());
dst->set_layout(src.layout());
#ifdef PADDLE_WITH_MKLDNN
dst->set_format(src.format());
#endif
auto src_place = src.place();
auto src_ptr = src.data<void>();
#ifdef PADDLE_WITH_MKLDNN
dst->set_format(src.format());
// oneDNN tensors due to padding may be of bigger size
// than numel()*size(type())
auto dst_ptr =
src.layout() == DataLayout::kMKLDNN
? dst->mutable_data(dst_place, src.type(), src.memory_size())
: dst->mutable_data(dst_place, src.type());
#else
auto dst_ptr = dst->mutable_data(dst_place, src.type());
#endif
if (src_ptr == dst_ptr && src_place == dst_place) {
VLOG(3) << "Skip copy the same data async from " << src_place << " to "
<< dst_place;
return;
}
#ifdef PADDLE_WITH_MKLDNN
auto size = src.layout() == DataLayout::kMKLDNN
? src.memory_size()
: src.numel() * SizeOfType(src.type());
#else
auto size = src.numel() * SizeOfType(src.type());
#endif
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册