未验证 提交 631c1f30 编写于 作者: J Jacek Czaja 提交者: GitHub

- Fix to 27398 (#27770)

test=develop

- compilation fix

test=develop
上级 061240b3
...@@ -117,6 +117,9 @@ void SetTensorToVariable(const Variable &in_var, const Tensor &tensor, ...@@ -117,6 +117,9 @@ void SetTensorToVariable(const Variable &in_var, const Tensor &tensor,
auto *tran_lod_tensor = out_var->GetMutable<LoDTensor>(); auto *tran_lod_tensor = out_var->GetMutable<LoDTensor>();
tran_lod_tensor->set_lod(in_lod_tensor.lod()); tran_lod_tensor->set_lod(in_lod_tensor.lod());
tran_lod_tensor->set_layout(in_lod_tensor.layout()); tran_lod_tensor->set_layout(in_lod_tensor.layout());
#ifdef PADDLE_WITH_MKLDNN
tran_lod_tensor->set_format(in_lod_tensor.format());
#endif
tran_lod_tensor->ShareDataWith(tensor); tran_lod_tensor->ShareDataWith(tensor);
} else if (in_var.IsType<SelectedRows>()) { } else if (in_var.IsType<SelectedRows>()) {
auto &in_selected_rows = in_var.Get<SelectedRows>(); auto &in_selected_rows = in_var.Get<SelectedRows>();
......
...@@ -38,6 +38,9 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, ...@@ -38,6 +38,9 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
dst->Resize(src.dims()); dst->Resize(src.dims());
dst->set_layout(src.layout()); dst->set_layout(src.layout());
#ifdef PADDLE_WITH_MKLDNN
dst->set_format(src.format());
#endif
auto src_place = src.place(); auto src_place = src.place();
auto src_ptr = src.data<void>(); auto src_ptr = src.data<void>();
auto dst_ptr = dst->mutable_data(dst_place, src.type()); auto dst_ptr = dst->mutable_data(dst_place, src.type());
...@@ -237,6 +240,9 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, ...@@ -237,6 +240,9 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
src.check_memory_size(); src.check_memory_size();
dst->Resize(src.dims()); dst->Resize(src.dims());
dst->set_layout(src.layout()); dst->set_layout(src.layout());
#ifdef PADDLE_WITH_MKLDNN
dst->set_format(src.format());
#endif
auto src_place = src.place(); auto src_place = src.place();
auto src_ptr = src.data<void>(); auto src_ptr = src.data<void>();
auto dst_ptr = dst->mutable_data(dst_place, src.type()); auto dst_ptr = dst->mutable_data(dst_place, src.type());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册