diff --git a/paddle/fluid/framework/data_layout_transform.cc b/paddle/fluid/framework/data_layout_transform.cc index f757e244e38ec965d62d673e63ed082ca70c63c7..291e3cda8a812f85d5114c65eaace1112f619356 100644 --- a/paddle/fluid/framework/data_layout_transform.cc +++ b/paddle/fluid/framework/data_layout_transform.cc @@ -203,7 +203,7 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, // As MKL-DNN description was in NCHW and paddle is expecting NHWC platform::MatchShapeToLayout(out, in_layout, out_layout); - out->set_layout(out_layout); + out->set_layout(DataLayout::kNCHW); // reset format since the out tensor will be feed to non-MKLDNN OPkernel out->set_format(MKLDNNMemoryFormat::undef); } diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index 946fa6305d737363dab3cea2e2b581f2e5659cfd..0e870937ec1a51d577d92aca7da7c6853a68f786 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -61,6 +61,19 @@ class TransposeOp : public framework::OperatorWithKernel { } framework::DDim out_dims(x_dims); +#ifdef PADDLE_WITH_MKLDNN + // Here we need to match dims to paddle layout + // as we are producing non-oneDNN result + if ((x_dims.size() >= 3) && + (paddle::platform::MKLDNNDeviceContext::tls() + .get_cur_paddle_data_layout() == framework::DataLayout::kNHWC)) { + auto dims = framework::vectorize(x_dims); + std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end()); + x_dims = x_dims.reshape(dims); + VLOG(3) + << "Rotating Shape in Transpose from: kMKLDNN to: kNHWC output_shape"; + } +#endif for (size_t i = 0; i < axis_size; i++) { out_dims[i] = x_dims[axis[i]]; } diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index b012a103ea3031efb381d7039b15e82b2af52bf7..d8dd166f325c8dbb0d3262e72dca32279f4a1c33 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -14,7 +14,9 @@ limitations under the License. */ #pragma once #include +#include #include +#include #include #include #include @@ -81,12 +83,30 @@ inline void MatchShapeToLayout(framework::Tensor* tensor_in, return; } + auto print_dims = [](const std::vector& dims) { + std::ostringstream oss; + + if (!dims.empty()) { + oss << "["; + // Convert all but the last element to avoid a trailing "," + std::copy(dims.begin(), dims.end() - 1, + std::ostream_iterator(oss, ",")); + + // Now add the last element with no delimiter + oss << dims.back() << "]"; + } + + return oss.str(); + }; + switch (from) { case framework::DataLayout::kMKLDNN: if (to == framework::DataLayout::kNHWC) { auto dims = framework::vectorize(tensor_in->dims()); std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end()); tensor_in->Resize(framework::make_ddim(dims)); + VLOG(3) << "Rotating Shape from: kMKLDNN to: kNHWC output_shape" + << print_dims(dims); } break; case framework::DataLayout::kNHWC: @@ -94,6 +114,8 @@ inline void MatchShapeToLayout(framework::Tensor* tensor_in, auto dims = framework::vectorize(tensor_in->dims()); std::rotate(dims.begin() + 1, dims.end() - 1, dims.end()); tensor_in->Resize(framework::make_ddim(dims)); + VLOG(3) << "Rotating Shape from: kNHWC to: kMKLDNN output_shape" + << print_dims(dims); } break; default: