From 1f598dfa18f25a40969f88b92dc2165d9c81b03e Mon Sep 17 00:00:00 2001
From: bingyanghuang <33643817+bingyanghuang@users.noreply.github.com>
Date: Fri, 6 Dec 2019 15:47:13 +0800
Subject: [PATCH] cherry-pick MKL-DNN NHWC FWD support fix (#21593)

---
 paddle/fluid/operators/batch_norm_op.cc       | 37 ++++++++++++++++++-
 paddle/fluid/operators/batch_norm_op.h        |  4 ++
 .../fluid/operators/mkldnn/conv_mkldnn_op.cc  | 17 ++++++++-
 .../mkldnn/test_batch_norm_mkldnn_op.py       |  7 ++++
 .../tests/unittests/test_batch_norm_op.py     | 15 ++++++++
 5 files changed, 77 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc
index e835e61897d..ee5dffc3f05 100644
--- a/paddle/fluid/operators/batch_norm_op.cc
+++ b/paddle/fluid/operators/batch_norm_op.cc
@@ -79,8 +79,9 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
                     x_dims, x_dims.size());
 
   const int64_t C =
-      (data_layout == DataLayout::kNCHW ? x_dims[1]
-                                        : x_dims[x_dims.size() - 1]);
+      ((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
+           ? x_dims[1]
+           : x_dims[x_dims.size() - 1]);
 
   auto scale_dim = ctx->GetInputDim("Scale");
   auto bias_dim = ctx->GetInputDim("Bias");
@@ -154,6 +155,32 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType(
       library);
 }
 
+framework::OpKernelType BatchNormOp::GetKernelTypeForVar(
+    const std::string &var_name, const Tensor &tensor,
+    const framework::OpKernelType &expected_kernel_type) const {
+#ifdef PADDLE_WITH_MKLDNN
+  // Only the input requires reshaping; weights and
+  // bias keep their shape in NCHW order
+  if ((var_name == "X") &&
+      (expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
+      (tensor.layout() != framework::DataLayout::kMKLDNN)) {
+    auto attrs = Attrs();
+    auto ar = paddle::framework::AttrReader(attrs);
+    const std::string data_layout = ar.Get<std::string>("data_layout");
+    auto dl = framework::StringToDataLayout(data_layout);
+    // Some models may have intentionally set "AnyLayout" for pool
+    // op. Treat this as NCHW (the default data_format value)
+    if (dl != framework::DataLayout::kAnyLayout) {
+      return framework::OpKernelType(
+          expected_kernel_type.data_type_, tensor.place(),
+          framework::StringToDataLayout(data_layout));
+    }
+  }
+#endif
+  return framework::OpKernelType(expected_kernel_type.data_type_,
+                                 tensor.place(), tensor.layout());
+}
+
 void BatchNormOpMaker::Make() {
   AddAttr<bool>("is_test",
                 "(bool, default false) Set to true for inference only, false "
@@ -446,6 +473,12 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
 #ifdef PADDLE_WITH_MKLDNN
   if (library == framework::LibraryType::kPlain &&
       platform::CanMKLDNNBeUsed(ctx)) {
+    // TODO(jczaja): Add support for NHWC
+    const std::string data_layout = ctx.Attr<std::string>("data_layout");
+    PADDLE_ENFORCE_NE(
+        data_layout, "NHWC",
+        platform::errors::Unimplemented(
+            "Batch Norm MKLDNN grad does not support NHWC data format yet"));
     library = framework::LibraryType::kMKLDNN;
     layout = framework::DataLayout::kMKLDNN;
   }
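The crux of the forward-pass fix above: an MKL-DNN tensor reports its dims in NCHW order even when the logical layout is NHWC, so `InferShape` must read the channel count from `x_dims[1]` whenever `IsMKLDNNType()` is true, while `GetKernelTypeForVar` hands the declared `data_layout` back to the framework so the input is reordered rather than silently misread. A minimal Python sketch of the channel-selection rule (the `infer_channel` helper is illustrative, not part of the patch):

```python
def infer_channel(x_dims, data_layout, is_mkldnn_type):
    """Mirror of the channel-dim selection in BatchNormOp::InferShape.

    MKL-DNN tensors keep their dims in NCHW order regardless of the
    logical data_layout, so the channel count always sits at index 1.
    """
    if is_mkldnn_type or data_layout == "NCHW":
        return x_dims[1]
    return x_dims[-1]  # plain NHWC tensor: channel is the last dim

assert infer_channel([2, 3, 4, 5], "NCHW", False) == 3
assert infer_channel([2, 4, 5, 3], "NHWC", False) == 3
# NHWC input already held by an MKL-DNN kernel: dims are NCHW-ordered
assert infer_channel([2, 3, 4, 5], "NHWC", True) == 3
```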
diff --git a/paddle/fluid/operators/batch_norm_op.h b/paddle/fluid/operators/batch_norm_op.h
index 6c7dbe0db4e..0bf81bb4559 100644
--- a/paddle/fluid/operators/batch_norm_op.h
+++ b/paddle/fluid/operators/batch_norm_op.h
@@ -47,6 +47,10 @@ class BatchNormOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override;
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string &var_name, const Tensor &tensor,
+      const framework::OpKernelType &expected_kernel_type) const override;
 };
 
 class BatchNormGradOp : public framework::OperatorWithKernel {
diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
index 124171e4cf0..611704badc7 100644
--- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -775,8 +775,23 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
      * ('any') which lets a primitive (conv backward in this case) choose
      * the memory format preferred for best performance
      */
-    auto chosen_memory_format = MKLDNNMemoryFormat::any;
+
+    // TODO(jczaja): Once GRAD NHWC is working then format 'any'
+    // should be used exclusively. But until the forward pass enforces
+    // NCHW for training, we need to have NCHW here as well
+    // to avoid performance degradation in relu_grad and pool2d_grad
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    auto chosen_memory_format =
+        platform::data_format_to_memory_format(data_format);
+
     weights_format = MKLDNNMemoryFormat::any;
+    // Check the format for user's special output
+    if (chosen_memory_format != MKLDNNMemoryFormat::any) {
+      if (is_conv3d) {
+        chosen_memory_format =
+            platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
+      }
+    }
 
     auto src_md = platform::MKLDNNMemDesc(
         src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
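In the conv grad kernel, format 'any' is kept only for the weights; the data memory format now follows the user's `data_format` attribute so training graphs stay NCHW and relu_grad/pool2d_grad keep their performance. A rough Python sketch of that selection, under the assumption that `data_format_to_memory_format` maps layout strings to memory-format tags and `MKLDNNFormatForSize` promotes 4-D tags to 5-D ones (the dict-based helpers below are illustrative stand-ins, not Paddle API):

```python
def choose_memory_formats(data_format, is_conv3d):
    """Sketch of the memory-format choice in the conv grad hunk above."""
    # Assumed stand-in for platform::data_format_to_memory_format
    to_memory_format = {"NCHW": "nchw", "NHWC": "nhwc", "AnyLayout": "any"}
    chosen = to_memory_format.get(data_format, "any")
    # Weights always use 'any' so the primitive picks its preferred format
    weights = "any"
    # Assumed stand-in for platform::MKLDNNFormatForSize on 5-D inputs
    if chosen != "any" and is_conv3d:
        chosen = {"nchw": "ncdhw", "nhwc": "ndhwc"}[chosen]
    return chosen, weights

print(choose_memory_formats("NCHW", is_conv3d=False))  # ('nchw', 'any')
print(choose_memory_formats("NCHW", is_conv3d=True))   # ('ncdhw', 'any')
```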
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_batch_norm_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_batch_norm_mkldnn_op.py
index eb12470789a..f802325d634 100644
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_batch_norm_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_batch_norm_mkldnn_op.py
@@ -84,6 +84,13 @@ class TestMKLDNNBatchNormOpInference(TestBatchNormOpInference):
         self.check_with_place(place, data_format, self.dtype, [2, 3, 4, 5])
 
 
+class TestMKLDNNBatchNormOpInference_NHWC(TestMKLDNNBatchNormOpInference):
+    def test_check_output(self):
+        place = core.CPUPlace()
+        data_format = "NHWC"
+        self.check_with_place(place, data_format, self.dtype, [2, 4, 5, 3])
+
+
 class TestMKLDNNBatchNormOpWithReluInference(TestBatchNormOpInference):
     def init_kernel_type(self):
         self.use_mkldnn = True
diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py
index 2d9f38acb89..64a3551ede9 100644
--- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py
@@ -259,6 +259,21 @@ class TestBatchNormOpInference(unittest.TestCase):
 
         batch_norm_op.run(scope, place)
 
+        # When the op is called without an Executor, an
+        # MKL-DNN Tensor is returned. For the NHWC data layout
+        # dims will be in NCHW order, as that is the MKL-DNN way
+        # of describing memory. So we need to convert NCHW
+        # dims into NHWC.
+        if data_layout == "NHWC" and self.use_mkldnn == True:
+            # Create executor to have MKL-DNN cache
+            # cleared after NHWC unit test
+            place = core.CPUPlace()
+            exe = fluid.Executor(place)
+            dims = y_tensor.shape()
+            c = dims.pop(1)
+            dims.append(c)
+            y_tensor._set_dims(dims)
+
         # check inference result
         self.__assert_close(
             y_tensor,
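The `_set_dims` fix-up in the last hunk rotates the channel dim back to the end, because the MKL-DNN tensor reports NCHW-ordered dims even for an NHWC run. The rotation in isolation:

```python
# NCHW-ordered dims as reported by the MKL-DNN tensor for an NHWC run
dims = [2, 3, 4, 5]          # N, C, H, W
c = dims.pop(1)              # take the channel out of position 1 ...
dims.append(c)               # ... and move it to the end
assert dims == [2, 4, 5, 3]  # N, H, W, C, the shape the NHWC test expects
```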