未验证 提交 f884edb9 编写于 作者: J Jacek Czaja 提交者: GitHub

Fix to #38126 (#39097)

* - 38126 potential fix

* - fix

* - build fix

* - another candidate fix

* - compilation fix

* - another fix

* - Fix to activation of NHWC being first oneDNN op in a chain of oneDNN ops

* - compilation fix

* - added NHWC rotating for elementwise being first op

* - compilation fix

* - compilation fix

* - Added UT

* - cosmetic fixes
上级 ba882657
...@@ -63,9 +63,13 @@ void TransformData(const OpKernelType &expected_kernel_type, ...@@ -63,9 +63,13 @@ void TransformData(const OpKernelType &expected_kernel_type,
out.ShareDataWith(input_tensor); out.ShareDataWith(input_tensor);
// For NHWC data we need reshape of tensors as MKL-DNN // For NHWC data we need reshape of tensors as MKL-DNN
// is expecting NHWC dims description order // is expecting NHWC dims description order
if (lin == DataLayout::kNHWC) {
platform::MatchShapeToLayout(&out, lin, lout); platform::MatchShapeToLayout(&out, lin, lout);
paddle::platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout( // We register only NHWC assuming that model is consistent e.g. either
lin); // NHWC or NCHW
paddle::platform::MKLDNNDeviceContext::tls()
.set_cur_paddle_data_layout(lin);
}
out.set_layout(DataLayout::kMKLDNN); out.set_layout(DataLayout::kMKLDNN);
out.set_format(out_format); out.set_format(out_format);
} else { } else {
......
...@@ -128,6 +128,26 @@ class ActivationOp : public framework::OperatorWithKernel { ...@@ -128,6 +128,26 @@ class ActivationOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "X"); return GetKernelType(ctx, *this, "X");
} }
// Chooses the kernel type for a single input variable of the activation op.
// Special case: when this activation is the first oneDNN op in the graph,
// its input was produced by a non-oneDNN op and therefore is not yet in
// kMKLDNN layout. If the model's recorded data layout is NHWC, report the
// variable as NHWC so the framework performs the NHWC -> NCHW shape
// rotation before the oneDNN kernel runs (see MatchShapeToLayout in the
// data-transform path — TODO confirm that is the consumer of this hint).
framework::OpKernelType GetKernelTypeForVar(
    const std::string& var_name, const Tensor& tensor,
    const framework::OpKernelType& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
  // When activation is the first oneDNN op (there was some non-oneDNN op
  // previously) then we also need to rotate the shape NHWC -> NCHW.
  // Conditions: the selected kernel is oneDNN, the tensor has not yet been
  // converted to oneDNN layout, and the model-wide layout recorded in TLS
  // is NHWC.
  if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
      (tensor.layout() != framework::DataLayout::kMKLDNN) &&
      paddle::platform::MKLDNNDeviceContext::tls()
              .get_cur_paddle_data_layout() == framework::DataLayout::kNHWC) {
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(),
                                   framework::DataLayout::kNHWC);
  }
#endif
  // Default behavior: keep the variable's own layout unchanged.
  return framework::OpKernelType(expected_kernel_type.data_type_,
                                 tensor.place(), tensor.layout());
}
}; };
class ActivationOpInferVarType class ActivationOpInferVarType
......
...@@ -101,9 +101,37 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -101,9 +101,37 @@ class ElementwiseOp : public framework::OperatorWithKernel {
std::vector<int> x_dims_array(max_dim); std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim); std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim); std::vector<int> out_dims_array(max_dim);
#ifdef PADDLE_WITH_MKLDNN
// (jczaja): Broadcasting of dims has to be done on Paddle shapes (NHWC)
// if model is using NHWC.
bool should_rotate =
ctx->IsRunMKLDNNKernel() &&
(platform::MKLDNNDeviceContext::tls().get_cur_paddle_data_layout() ==
framework::DataLayout::kNHWC);
if (should_rotate) {
// Pick bigger shape and rotate this one
bool x_over_y = (x_dims.size() > y_dims.size());
auto vdims = x_over_y ? framework::vectorize<int>(x_dims)
: framework::vectorize<int>(y_dims);
std::rotate(vdims.begin() + 1, vdims.begin() + 2, vdims.end());
if (x_over_y) {
x_dims = framework::make_ddim(vdims);
} else {
y_dims = framework::make_ddim(vdims);
}
}
#endif
GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(), GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
y_dims_array.data(), out_dims_array.data(), y_dims_array.data(), out_dims_array.data(),
max_dim, axis); max_dim, axis);
#ifdef PADDLE_WITH_MKLDNN
// Now rotate shape back if needed (NHWC -> NCHW)
if (should_rotate) {
std::rotate(out_dims_array.begin() + 1, out_dims_array.end() - 1,
out_dims_array.end());
}
#endif
ctx->SetOutputDim("Out", framework::make_ddim(out_dims_array)); ctx->SetOutputDim("Out", framework::make_ddim(out_dims_array));
// to do // to do
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
...@@ -133,6 +161,21 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -133,6 +161,21 @@ class ElementwiseOp : public framework::OperatorWithKernel {
return framework::OpKernelType(tensor.type(), tensor.place(), return framework::OpKernelType(tensor.type(), tensor.place(),
tensor.layout()); tensor.layout());
} else { } else {
#ifdef PADDLE_WITH_MKLDNN
// When elementwise is first oneDNN op (there was some non oneDNN op
// previously)
// then we also need to rotate shape NHWC -> NCWH
if ((expected_kernel_type.data_layout_ ==
framework::DataLayout::kMKLDNN) &&
(tensor.layout() != framework::DataLayout::kMKLDNN) &&
paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout() ==
framework::DataLayout::kNHWC) {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(),
framework::DataLayout::kNHWC);
}
#endif
return framework::OpKernelType(expected_kernel_type.data_type_, return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout()); tensor.place(), tensor.layout());
} }
......
cc_test(test_mkldnn_op_nhwc SRCS mkldnn/test_mkldnn_op_nhwc.cc DEPS op_registry pool_op pooling transpose_op scope device_context enforce executor) cc_test(test_mkldnn_op_nhwc SRCS mkldnn/test_mkldnn_op_nhwc.cc DEPS op_registry pool_op activation_op pooling transpose_op scope device_context enforce executor)
...@@ -27,6 +27,8 @@ ...@@ -27,6 +27,8 @@
USE_OP(pool2d); USE_OP(pool2d);
USE_OP_DEVICE_KERNEL(pool2d, MKLDNN); USE_OP_DEVICE_KERNEL(pool2d, MKLDNN);
USE_OP(relu);
USE_OP_DEVICE_KERNEL(relu, MKLDNN);
USE_OP(transpose); USE_OP(transpose);
USE_OP_DEVICE_KERNEL(transpose, MKLDNN); USE_OP_DEVICE_KERNEL(transpose, MKLDNN);
...@@ -90,5 +92,63 @@ TEST(test_pool2d_transpose_nhwc, cpu_place) { ...@@ -90,5 +92,63 @@ TEST(test_pool2d_transpose_nhwc, cpu_place) {
"Computed shape does not match expected shape")); "Computed shape does not match expected shape"));
} }
// Regression test for the chain pool2d(oneDNN) -> relu(plain CPU) ->
// relu(oneDNN) on an NHWC model: the second relu is the first oneDNN op
// after a non-oneDNN op, so it must itself trigger the NHWC -> NCHW shape
// rotation of its input (the fix under test).
TEST(test_pool2d_relu_relu_nhwc, cpu_place) {
  framework::DDim dims({1, 4, 8, 512});  // input shape, NHWC order
  // Expected output shape. Note the order here is NCHW (channels second):
  // the final oneDNN relu rotates the shape back to Paddle's NCHW ordering.
  framework::DDim expected_dims({1, 512, 3, 7});
  platform::CPUPlace p;
  framework::Scope scope;

  InputVars input_name = {"x",
                          scope.Var("x")->GetMutable<framework::LoDTensor>()};
  // Initialize input data with random floats in [10, 20) (values irrelevant;
  // only the resulting shape is checked).
  std::uniform_real_distribution<float> dist(static_cast<float>(10.0),
                                             static_cast<float>(20.0));
  std::mt19937 engine;
  size_t numel = static_cast<size_t>(framework::product(dims));
  input_name.tensor->Resize(dims);
  auto data_ptr = input_name.tensor->mutable_data<float>(p);
  for (size_t i = 0; i < numel; ++i) {
    data_ptr[i] = dist(engine);
  }

  scope.Var("y")->GetMutable<framework::LoDTensor>();
  scope.Var("u")->GetMutable<framework::LoDTensor>();
  auto *z = scope.Var("z")->GetMutable<framework::LoDTensor>();

  auto &pool = platform::DeviceContextPool::Instance();

  // Make pool2d(oneDNN) followed by relu(CPU paddle) followed by
  // relu(oneDNN). Second relu should make a shape rotation to NCHW.
  auto ksize = std::vector<int>(2, 2);
  auto op_pool = framework::OpRegistry::CreateOp(
      "pool2d", {{"X", {"x"}}}, {{"Out", {"y"}}},
      {{"pooling_type", {std::string("max")}},
       {"ksize", {ksize}},
       {"data_format", {std::string("NHWC")}},
       {"use_mkldnn", {true}}});

  // NOTE(review): relu does not define an "axis" attribute; this looks like
  // a leftover from a transpose-based draft of the test — confirm the attr
  // is silently ignored by the op registry.
  auto axis = std::vector<int>(4, 0);
  axis[1] = 2;
  axis[2] = 3;
  axis[3] = 1;
  auto op_relu1 = framework::OpRegistry::CreateOp(
      "relu", {{"X", {"y"}}}, {{"Out", {"u"}}},
      {{"axis", {axis}}, {"use_mkldnn", {false}}});

  auto op_relu2 = framework::OpRegistry::CreateOp(
      "relu", {{"X", {"u"}}}, {{"Out", {"z"}}}, {{"use_mkldnn", {true}}});

  op_pool->Run(scope, p);
  op_relu1->Run(scope, p);
  op_relu2->Run(scope, p);

  pool.Get(p)->Wait();

  // Verify shape of output
  PADDLE_ENFORCE_EQ(z->dims(), expected_dims,
                    platform::errors::InvalidArgument(
                        "Computed shape does not match expected shape"));
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -67,9 +67,11 @@ class TransferLayoutFunctor { ...@@ -67,9 +67,11 @@ class TransferLayoutFunctor {
out_tensor.ShareDataWith(in_tensor); out_tensor.ShareDataWith(in_tensor);
// For NHWC data we need reshape of tensors as MKL-DNN // For NHWC data we need reshape of tensors as MKL-DNN
// is expecting NHWC dims description order // is expecting NHWC dims description order
if (in_layout == DataLayout::kNHWC) {
platform::MatchShapeToLayout(&out_tensor, in_layout, out_layout); platform::MatchShapeToLayout(&out_tensor, in_layout, out_layout);
paddle::platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout( paddle::platform::MKLDNNDeviceContext::tls()
in_layout); .set_cur_paddle_data_layout(in_layout);
}
out_tensor.set_layout(DataLayout::kMKLDNN); out_tensor.set_layout(DataLayout::kMKLDNN);
out_tensor.set_format(out_format); out_tensor.set_format(out_format);
} else { } else {
......
...@@ -651,11 +651,21 @@ class BinaryMKLDNNHandler ...@@ -651,11 +651,21 @@ class BinaryMKLDNNHandler
std::vector<int64_t> dims1_ex(rankdiff, 1); std::vector<int64_t> dims1_ex(rankdiff, 1);
dims1_ex.insert(next(dims1_ex.begin(), (axis == -1 ? rankdiff : axis)), dims1_ex.insert(next(dims1_ex.begin(), (axis == -1 ? rankdiff : axis)),
src_y_tz.begin(), src_y_tz.end()); src_y_tz.begin(), src_y_tz.end());
// For broadcasting for NHWC we need rotate extended shape
if (MKLDNNDeviceContext::tls().get_cur_paddle_data_layout() ==
framework::DataLayout::kNHWC) {
std::rotate(dims1_ex.begin() + 1, dims1_ex.end() - 1, dims1_ex.end());
}
src1_md = src1_md.reshape(dims1_ex); src1_md = src1_md.reshape(dims1_ex);
} else if (rankdiff < 0) { // First input is of smaller than second } else if (rankdiff < 0) { // First input is of smaller than second
std::vector<int64_t> dims0_ex(-rankdiff, 1); std::vector<int64_t> dims0_ex(-rankdiff, 1);
dims0_ex.insert(next(dims0_ex.begin(), (axis == -1 ? -rankdiff : axis)), dims0_ex.insert(next(dims0_ex.begin(), (axis == -1 ? -rankdiff : axis)),
src_x_tz.begin(), src_x_tz.end()); src_x_tz.begin(), src_x_tz.end());
// For broadcasting for NHWC we need rotate extended shape
if (MKLDNNDeviceContext::tls().get_cur_paddle_data_layout() ==
framework::DataLayout::kNHWC) {
std::rotate(dims0_ex.begin() + 1, dims0_ex.end() - 1, dims0_ex.end());
}
src0_md = src0_md.reshape(dims0_ex); src0_md = src0_md.reshape(dims0_ex);
} }
const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(), const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册