From 0cc5eddb47c260535d37141df5e9a1f2ee17c7d7 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Thu, 17 Sep 2020 11:47:02 +0200 Subject: [PATCH] - condidate fix to issue #25537 test=develop --- .../fluid/framework/data_layout_transform.cc | 2 +- paddle/fluid/operators/transpose_op.cc | 13 +++++++++++ paddle/fluid/platform/mkldnn_helper.h | 22 +++++++++++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/data_layout_transform.cc b/paddle/fluid/framework/data_layout_transform.cc index f757e244e38..291e3cda8a8 100644 --- a/paddle/fluid/framework/data_layout_transform.cc +++ b/paddle/fluid/framework/data_layout_transform.cc @@ -203,7 +203,7 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, // As MKL-DNN description was in NCHW and paddle is expecting NHWC platform::MatchShapeToLayout(out, in_layout, out_layout); - out->set_layout(out_layout); + out->set_layout(DataLayout::kNCHW); // reset format since the out tensor will be feed to non-MKLDNN OPkernel out->set_format(MKLDNNMemoryFormat::undef); } diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index 946fa6305d7..0e870937ec1 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -61,6 +61,19 @@ class TransposeOp : public framework::OperatorWithKernel { } framework::DDim out_dims(x_dims); +#ifdef PADDLE_WITH_MKLDNN + // Here we need to match dims to paddle layout + // as we are producing non-oneDNN result + if ((x_dims.size() >= 3) && + (paddle::platform::MKLDNNDeviceContext::tls() + .get_cur_paddle_data_layout() == framework::DataLayout::kNHWC)) { + auto dims = framework::vectorize(x_dims); + std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end()); + x_dims = x_dims.reshape(dims); + VLOG(3) + << "Rotating Shape in Transpose from: kMKLDNN to: kNHWC output_shape"; + } +#endif for (size_t i = 0; i < axis_size; i++) { out_dims[i] = x_dims[axis[i]]; } diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index b012a103ea3..d8dd166f325 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -14,7 +14,9 @@ limitations under the License. */ #pragma once #include +#include #include +#include #include #include #include @@ -81,12 +83,30 @@ inline void MatchShapeToLayout(framework::Tensor* tensor_in, return; } + auto print_dims = [](const std::vector& dims) { + std::ostringstream oss; + + if (!dims.empty()) { + oss << "["; + // Convert all but the last element to avoid a trailing "," + std::copy(dims.begin(), dims.end() - 1, + std::ostream_iterator(oss, ",")); + + // Now add the last element with no delimiter + oss << dims.back() << "]"; + } + + return oss.str(); + }; + switch (from) { case framework::DataLayout::kMKLDNN: if (to == framework::DataLayout::kNHWC) { auto dims = framework::vectorize(tensor_in->dims()); std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end()); tensor_in->Resize(framework::make_ddim(dims)); + VLOG(3) << "Rotating Shape from: kMKLDNN to: kNHWC output_shape" + << print_dims(dims); } break; case framework::DataLayout::kNHWC: @@ -94,6 +114,8 @@ inline void MatchShapeToLayout(framework::Tensor* tensor_in, auto dims = framework::vectorize(tensor_in->dims()); std::rotate(dims.begin() + 1, dims.end() - 1, dims.end()); tensor_in->Resize(framework::make_ddim(dims)); + VLOG(3) << "Rotating Shape from: kNHWC to: kMKLDNN output_shape" + << print_dims(dims); } break; default: -- GitLab