From f4cf028a8c69de3622ff9c2c5435a37405c11468 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Tue, 26 Nov 2019 03:07:02 +0100 Subject: [PATCH] [MKL-DNN] Error throwing for NHWC layout for MKL-DNN ops (#21207) --- paddle/fluid/operators/conv_op.cc | 27 +++++++++++++------ paddle/fluid/operators/conv_transpose_op.cc | 5 ++++ paddle/fluid/operators/lrn_op.cc | 12 +++++++++ paddle/fluid/operators/pool_op.cc | 12 +++++++++ .../unittests/mkldnn/test_lrn_mkldnn_op.py | 16 +++++++++++ .../unittests/mkldnn/test_pool2d_mkldnn_op.py | 21 +++++++++++++++ 6 files changed, 85 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 0a4ae0866b1..ce60e97f4b2 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -151,6 +151,15 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { + // TODO(jczaja): Add support for NHWC + const std::string data_format = ctx.Attr("data_format"); + PADDLE_ENFORCE_NE(data_format, "NHWC", + platform::errors::Unimplemented( + "Conv MKLDNN does not support NHWC data format yet")); + PADDLE_ENFORCE_NE( + data_format, "NDHWC", + platform::errors::Unimplemented( + "Conv MKLDNN does not support NDHWC data format yet")); library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; customized_type_value = @@ -524,6 +533,16 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { + // TODO(jczaja): Add support for NHWC + const std::string data_format = ctx.Attr("data_format"); + PADDLE_ENFORCE_NE( + data_format, "NHWC", + platform::errors::Unimplemented( + "Conv MKLDNN grad does not support NHWC data format yet")); + PADDLE_ENFORCE_NE( + data_format, "NDHWC", + platform::errors::Unimplemented( + "Conv MKLDNN Grad does not support NDHWC data format yet")); library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; customized_type_value = kConvMKLDNNFP32; @@ -706,14 +725,6 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType( if (platform::CanCUDNNBeUsed(ctx)) { library_ = framework::LibraryType::kCUDNN; } -#endif -#ifdef PADDLE_WITH_MKLDNN - if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { - library_ = framework::LibraryType::kMKLDNN; - layout_ = framework::DataLayout::kMKLDNN; - customized_type_value = kConvMKLDNNFP32; - } #endif auto type = framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc index 930d4873540..25c84501400 100644 --- a/paddle/fluid/operators/conv_transpose_op.cc +++ b/paddle/fluid/operators/conv_transpose_op.cc @@ -145,6 +145,11 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { + // TODO(jczaja): Add support for NHWC + const std::string data_format = ctx.Attr("data_format"); + PADDLE_ENFORCE_NE( + data_format, "NHWC", + "Conv Transpose MKLDNN does not support NHWC data format yet"); library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc index 33a418f9f44..cfded0370b0 100644 --- a/paddle/fluid/operators/lrn_op.cc +++ b/paddle/fluid/operators/lrn_op.cc @@ -193,6 +193,12 @@ class LRNOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { + // TODO(jczaja): Add support for NHWC + const std::string data_format = ctx.Attr("data_format"); + PADDLE_ENFORCE_NE( + data_format, "NHWC", + platform::errors::Unimplemented( + "LRN MKLDNN does not support NHWC data format yet")); library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } @@ -311,6 +317,12 @@ class LRNOpGrad : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { + // TODO(jczaja): Add support for NHWC + const std::string data_format = ctx.Attr("data_format"); + PADDLE_ENFORCE_NE( + data_format, "NHWC", + platform::errors::Unimplemented( + "LRN MKLDNN grad does not support NHWC data format yet")); library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index 469e123d9b0..069f2339f40 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -146,6 +146,12 @@ framework::OpKernelType PoolOp::GetExpectedKernelType( #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { + // TODO(jczaja): Add support for NHWC + const std::string data_format = ctx.Attr("data_format"); + PADDLE_ENFORCE_NE( + data_format, "NHWC", + platform::errors::Unimplemented( + "Pool MKLDNN grad does not support NHWC data format yet")); library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } @@ -177,6 +183,12 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType( #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { + // TODO(jczaja): Add support for NHWC + const std::string data_format = ctx.Attr("data_format"); + PADDLE_ENFORCE_NE( + data_format, "NHWC", + platform::errors::Unimplemented( + "Pool MKLDNN grad does not support NHWC data format yet")); library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py index 37951ac8e71..94787195788 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py @@ -16,6 +16,7 @@ from __future__ import print_function import unittest from paddle.fluid.tests.unittests.test_lrn_op import TestLRNOp +import paddle.fluid as fluid class TestLRNMKLDNNOp(TestLRNOp): @@ -54,5 +55,20 @@ class TestLRNMKLDNNOpWithIsTest(TestLRNMKLDNNOp): self.assertRaises(AttributeError, check_raise_is_test) +# TODO(jczaja): Once mkl-dnn integration support NHWC input +# then those tests should be changed to actual functional positive tests +class TestLRNMKLDNNOpNHWC(TestLRNMKLDNNOp): + def init_test_case(self): + self.data_format = 'NHWC' + + def test_check_output(self): + pass + + # Grad tests both FWD and BWD ops kernels creation + def test_check_grad_normal(self): + with self.assertRaises(fluid.core_avx.EnforceNotMet): + self.check_grad(['X'], 'Out', max_relative_error=0.01) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py index 5a9c10073a4..b8403bc3c6f 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py @@ -141,5 +141,26 @@ class TestAsymPadValid(TestAsymPad): self.padding_algorithm = "VALID" +# Designed to Fail +# TODO(jczaja): Once mkl-dnn integration support NHWC input +# then those tests should be changed to actual functional positive tests +class TestAsymPadValidNHWC(TestAsymPadValid): + def init_data_format(self): + self.data_format = "NHWC" + + def init_shape(self): + self.shape = [2, 7, 7, 3] + + def test_check_output(self): + pass + + # Grad tests both FWD and BWD ops kernels creation + # GetExpectedKernelType should throw an exception on lack of support + # to NHWC inputs in pool mkldnn kernel + def test_check_grad(self): + with self.assertRaises(fluid.core_avx.EnforceNotMet): + super(TestAsymPadValidNHWC, self).test_check_grad() + + if __name__ == '__main__': unittest.main() -- GitLab