提交 0cc5eddb 编写于 作者: J Jacek Czaja

- condidate fix to issue #25537

test=develop
上级 da583edf
......@@ -203,7 +203,7 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
// As MKL-DNN description was in NCHW and paddle is expecting NHWC
platform::MatchShapeToLayout(out, in_layout, out_layout);
out->set_layout(out_layout);
out->set_layout(DataLayout::kNCHW);
// reset format since the out tensor will be feed to non-MKLDNN OPkernel
out->set_format(MKLDNNMemoryFormat::undef);
}
......
......@@ -61,6 +61,19 @@ class TransposeOp : public framework::OperatorWithKernel {
}
framework::DDim out_dims(x_dims);
#ifdef PADDLE_WITH_MKLDNN
// Here we need to match dims to paddle layout
// as we are producing non-oneDNN result
if ((x_dims.size() >= 3) &&
(paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout() == framework::DataLayout::kNHWC)) {
auto dims = framework::vectorize<int>(x_dims);
std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
x_dims = x_dims.reshape(dims);
VLOG(3)
<< "Rotating Shape in Transpose from: kMKLDNN to: kNHWC output_shape";
}
#endif
for (size_t i = 0; i < axis_size; i++) {
out_dims[i] = x_dims[axis[i]];
}
......
......@@ -14,7 +14,9 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
......@@ -81,12 +83,30 @@ inline void MatchShapeToLayout(framework::Tensor* tensor_in,
return;
}
auto print_dims = [](const std::vector<int>& dims) {
std::ostringstream oss;
if (!dims.empty()) {
oss << "[";
// Convert all but the last element to avoid a trailing ","
std::copy(dims.begin(), dims.end() - 1,
std::ostream_iterator<int>(oss, ","));
// Now add the last element with no delimiter
oss << dims.back() << "]";
}
return oss.str();
};
switch (from) {
case framework::DataLayout::kMKLDNN:
if (to == framework::DataLayout::kNHWC) {
auto dims = framework::vectorize<int>(tensor_in->dims());
std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
tensor_in->Resize(framework::make_ddim(dims));
VLOG(3) << "Rotating Shape from: kMKLDNN to: kNHWC output_shape"
<< print_dims(dims);
}
break;
case framework::DataLayout::kNHWC:
......@@ -94,6 +114,8 @@ inline void MatchShapeToLayout(framework::Tensor* tensor_in,
auto dims = framework::vectorize<int>(tensor_in->dims());
std::rotate(dims.begin() + 1, dims.end() - 1, dims.end());
tensor_in->Resize(framework::make_ddim(dims));
VLOG(3) << "Rotating Shape from: kNHWC to: kMKLDNN output_shape"
<< print_dims(dims);
}
break;
default:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册