diff --git a/paddle/fluid/framework/data_transform.cc b/paddle/fluid/framework/data_transform.cc index 3a40de6988f294314f83ba0308e9d57de84d60f7..70693a5df2609ec64a4b5732d310dec68c036c92 100644 --- a/paddle/fluid/framework/data_transform.cc +++ b/paddle/fluid/framework/data_transform.cc @@ -117,6 +117,9 @@ void SetTensorToVariable(const Variable &in_var, const Tensor &tensor, auto *tran_lod_tensor = out_var->GetMutable(); tran_lod_tensor->set_lod(in_lod_tensor.lod()); tran_lod_tensor->set_layout(in_lod_tensor.layout()); +#ifdef PADDLE_WITH_MKLDNN + tran_lod_tensor->set_format(in_lod_tensor.format()); +#endif tran_lod_tensor->ShareDataWith(tensor); } else if (in_var.IsType()) { auto &in_selected_rows = in_var.Get(); diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index a073dbd733f0b1f1bb1b94221b99d477cb8dfadb..4fe01aff79e521195b297afd17df1893d6abffd8 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -38,6 +38,9 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, dst->Resize(src.dims()); dst->set_layout(src.layout()); +#ifdef PADDLE_WITH_MKLDNN + dst->set_format(src.format()); +#endif auto src_place = src.place(); auto src_ptr = src.data(); auto dst_ptr = dst->mutable_data(dst_place, src.type()); @@ -237,6 +240,9 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, src.check_memory_size(); dst->Resize(src.dims()); dst->set_layout(src.layout()); +#ifdef PADDLE_WITH_MKLDNN + dst->set_format(src.format()); +#endif auto src_place = src.place(); auto src_ptr = src.data(); auto dst_ptr = dst->mutable_data(dst_place, src.type());