diff --git a/paddle/fluid/framework/data_layout_transform.cc b/paddle/fluid/framework/data_layout_transform.cc index 108cd9ac6d1c0778b7f614116b5739502fcfb0ee..8563b5b6d3695878e4f65c131cff600d08451e4c 100644 --- a/paddle/fluid/framework/data_layout_transform.cc +++ b/paddle/fluid/framework/data_layout_transform.cc @@ -203,7 +203,7 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, // As MKL-DNN description was in NCHW and paddle is expecting NHWC platform::MatchShapeToLayout(out, in_layout, out_layout); - out->set_layout(out_layout); + out->set_layout(DataLayout::kNCHW); // reset format since the out tensor will be feed to non-MKLDNN OPkernel out->set_format(MKLDNNMemoryFormat::undef); } diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 53e6f4aa6e41bb8c02c01b4897e35c103260e167..5fa8f6bab8cca94617f401f8b50b2572d9a55cb3 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -144,4 +144,5 @@ cc_test(op_debug_string_test SRCS op_debug_string_test.cc DEPS elementwise_add_o if(WITH_MKLDNN) include(mkldnn/inplace_op_tests.cmake) +include(mkldnn/nhwc_op_tests.cmake) endif() diff --git a/paddle/fluid/operators/mkldnn/nhwc_op_tests.cmake b/paddle/fluid/operators/mkldnn/nhwc_op_tests.cmake new file mode 100644 index 0000000000000000000000000000000000000000..232626df02e5045c5fe66395255553d8522aa411 --- /dev/null +++ b/paddle/fluid/operators/mkldnn/nhwc_op_tests.cmake @@ -0,0 +1,2 @@ +cc_test(test_mkldnn_op_nhwc SRCS mkldnn/test_mkldnn_op_nhwc.cc DEPS op_registry pool_op pooling transpose_op scope device_context enforce executor) + diff --git a/paddle/fluid/operators/mkldnn/test_mkldnn_op_nhwc.cc b/paddle/fluid/operators/mkldnn/test_mkldnn_op_nhwc.cc new file mode 100644 index 0000000000000000000000000000000000000000..e7caeef85f5f95859096c2d6edb494106b0a0f93 --- /dev/null +++ b/paddle/fluid/operators/mkldnn/test_mkldnn_op_nhwc.cc @@ -0,0 +1,94 @@ +// 
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <memory> +#include <random> +#include <string> +#include <vector> +#include "gtest/gtest.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/place.h" + +USE_OP(pool2d); +USE_OP_DEVICE_KERNEL(pool2d, MKLDNN); +USE_OP(transpose); +USE_OP_DEVICE_KERNEL(transpose, MKLDNN); + +namespace paddle { +namespace operators { + +struct InputVars { + std::string name; + framework::LoDTensor *tensor; +}; + +TEST(test_pool2d_transpose_nhwc, cpu_place) { + framework::DDim dims({1, 4, 8, 512}); // NHWC shape + framework::DDim expected_dims({1, 7, 512, 3}); // NHWC expected shape + platform::CPUPlace p; + framework::Scope scope; + + InputVars input_name = {"x", + scope.Var("x")->GetMutable<framework::LoDTensor>()}; + // Initialize input data + std::uniform_real_distribution<float> dist(static_cast<float>(10.0), + static_cast<float>(20.0)); + std::mt19937 engine; + size_t numel = static_cast<size_t>(framework::product(dims)); + input_name.tensor->Resize(dims); + auto data_ptr = input_name.tensor->mutable_data<float>(p); + for (size_t i = 0; i < numel; ++i) { + data_ptr[i] = dist(engine); + } + + scope.Var("y")->GetMutable<framework::LoDTensor>(); + auto *z = scope.Var("z")->GetMutable<framework::LoDTensor>(); + + auto &pool = 
platform::DeviceContextPool::Instance(); + + // Make pool2d followed by transpose + + auto ksize = std::vector<int>(2, 2); + auto op_pool = framework::OpRegistry::CreateOp( + "pool2d", {{"X", {"x"}}}, {{"Out", {"y"}}}, + {{"pooling_type", {std::string("max")}}, + {"ksize", {ksize}}, + {"data_format", {std::string("NHWC")}}, + {"use_mkldnn", {true}}}); + + auto axis = std::vector<int>(4, 0); + axis[1] = 2; + axis[2] = 3; + axis[3] = 1; + auto op_transpose = framework::OpRegistry::CreateOp( + "transpose", {{"X", {"y"}}}, {{"Out", {"z"}}}, + {{"axis", {axis}}, {"use_mkldnn", {true}}}); + + op_pool->Run(scope, p); + op_transpose->Run(scope, p); + pool.Get(p)->Wait(); + + // Verify shape of output + PADDLE_ENFORCE_EQ(z->dims(), expected_dims, + platform::errors::InvalidArgument( + "Computed shape does not match expected shape")); +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index 946fa6305d737363dab3cea2e2b581f2e5659cfd..0e870937ec1a51d577d92aca7da7c6853a68f786 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -61,6 +61,19 @@ class TransposeOp : public framework::OperatorWithKernel { } framework::DDim out_dims(x_dims); +#ifdef PADDLE_WITH_MKLDNN + // Here we need to match dims to paddle layout + // as we are producing non-oneDNN result + if ((x_dims.size() >= 3) && + (paddle::platform::MKLDNNDeviceContext::tls() + .get_cur_paddle_data_layout() == framework::DataLayout::kNHWC)) { + auto dims = framework::vectorize<int>(x_dims); + std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end()); + x_dims = x_dims.reshape(dims); + VLOG(3) + << "Rotating Shape in Transpose from: kMKLDNN to: kNHWC output_shape"; + } +#endif for (size_t i = 0; i < axis_size; i++) { out_dims[i] = x_dims[axis[i]]; } diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 
b012a103ea3031efb381d7039b15e82b2af52bf7..d8dd166f325c8dbb0d3262e72dca32279f4a1c33 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -14,7 +14,9 @@ limitations under the License. */ #pragma once #include <algorithm> +#include <iterator> #include <memory> +#include <sstream> #include <string> #include <utility> #include <vector> @@ -81,12 +83,30 @@ inline void MatchShapeToLayout(framework::Tensor* tensor_in, return; } + auto print_dims = [](const std::vector<int64_t>& dims) { + std::ostringstream oss; + + if (!dims.empty()) { + oss << "["; + // Convert all but the last element to avoid a trailing "," + std::copy(dims.begin(), dims.end() - 1, + std::ostream_iterator<int64_t>(oss, ",")); + + // Now add the last element with no delimiter + oss << dims.back() << "]"; + } + + return oss.str(); + }; + switch (from) { case framework::DataLayout::kMKLDNN: if (to == framework::DataLayout::kNHWC) { auto dims = framework::vectorize(tensor_in->dims()); std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end()); tensor_in->Resize(framework::make_ddim(dims)); + VLOG(3) << "Rotating Shape from: kMKLDNN to: kNHWC output_shape" + << print_dims(dims); } break; case framework::DataLayout::kNHWC: @@ -94,6 +114,8 @@ inline void MatchShapeToLayout(framework::Tensor* tensor_in, auto dims = framework::vectorize(tensor_in->dims()); std::rotate(dims.begin() + 1, dims.end() - 1, dims.end()); tensor_in->Resize(framework::make_ddim(dims)); + VLOG(3) << "Rotating Shape from: kNHWC to: kMKLDNN output_shape" + << print_dims(dims); } break; default: