diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 0a4ae0866b1ab70150af311d9922fe587596ab1c..ce60e97f4b263102661f5689774fa07e6ac78164 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -151,6 +151,15 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { + // TODO(jczaja): Add support for NHWC + const std::string data_format = ctx.Attr<std::string>("data_format"); + PADDLE_ENFORCE_NE(data_format, "NHWC", + platform::errors::Unimplemented( + "Conv MKLDNN does not support NHWC data format yet")); + PADDLE_ENFORCE_NE( + data_format, "NDHWC", + platform::errors::Unimplemented( + "Conv MKLDNN does not support NDHWC data format yet")); library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; customized_type_value = @@ -524,6 +533,16 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { + // TODO(jczaja): Add support for NHWC + const std::string data_format = ctx.Attr<std::string>("data_format"); + PADDLE_ENFORCE_NE( + data_format, "NHWC", + platform::errors::Unimplemented( + "Conv MKLDNN grad does not support NHWC data format yet")); + PADDLE_ENFORCE_NE( + data_format, "NDHWC", + platform::errors::Unimplemented( + "Conv MKLDNN grad does not support NDHWC data format yet")); library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; customized_type_value = kConvMKLDNNFP32; @@ -706,14 +725,6 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType( if (platform::CanCUDNNBeUsed(ctx)) { library_ = framework::LibraryType::kCUDNN; } -#endif -#ifdef PADDLE_WITH_MKLDNN - if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { - library_ = framework::LibraryType::kMKLDNN; - layout_ = 
framework::DataLayout::kMKLDNN; - customized_type_value = kConvMKLDNNFP32; - } #endif auto type = framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc index 930d48735405015f2f5aefd876e7590837348cc4..25c8450140044fe5c73dd4a64cb52f033fa0ae9d 100644 --- a/paddle/fluid/operators/conv_transpose_op.cc +++ b/paddle/fluid/operators/conv_transpose_op.cc @@ -145,6 +145,11 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { + // TODO(jczaja): Add support for NHWC + const std::string data_format = ctx.Attr<std::string>("data_format"); + PADDLE_ENFORCE_NE( + data_format, "NHWC", + platform::errors::Unimplemented("Conv Transpose MKLDNN does not support NHWC data format yet")); library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc index 33a418f9f44f352f02256c9a2a9fa17512069771..cfded0370b0d3d2930d3cef379cdfbbb894a514d 100644 --- a/paddle/fluid/operators/lrn_op.cc +++ b/paddle/fluid/operators/lrn_op.cc @@ -193,6 +193,12 @@ class LRNOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { + // TODO(jczaja): Add support for NHWC + const std::string data_format = ctx.Attr<std::string>("data_format"); + PADDLE_ENFORCE_NE( + data_format, "NHWC", + platform::errors::Unimplemented( + "LRN MKLDNN does not support NHWC data format yet")); library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } @@ -311,6 +317,12 @@ class LRNOpGrad : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { + // TODO(jczaja): Add support for NHWC + const std::string 
data_format = ctx.Attr<std::string>("data_format"); + PADDLE_ENFORCE_NE( + data_format, "NHWC", + platform::errors::Unimplemented( + "LRN MKLDNN grad does not support NHWC data format yet")); library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index 469e123d9b044c81d47a552e262a7c74d6e5f5a9..069f2339f40cf041331e9ce78500986af01ed3d7 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -146,6 +146,12 @@ framework::OpKernelType PoolOp::GetExpectedKernelType( #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { + // TODO(jczaja): Add support for NHWC + const std::string data_format = ctx.Attr<std::string>("data_format"); + PADDLE_ENFORCE_NE( + data_format, "NHWC", + platform::errors::Unimplemented( + "Pool MKLDNN does not support NHWC data format yet")); library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } @@ -177,6 +183,12 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType( #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { + // TODO(jczaja): Add support for NHWC + const std::string data_format = ctx.Attr<std::string>("data_format"); + PADDLE_ENFORCE_NE( + data_format, "NHWC", + platform::errors::Unimplemented( + "Pool MKLDNN grad does not support NHWC data format yet")); library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py index 37951ac8e711ade971ba6cdf39beb05d275ad6d4..9478719578821c4fb3a3e5da8b57b4fcda87767b 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py @@ -16,6 +16,7 @@ from __future__ import print_function 
import unittest from paddle.fluid.tests.unittests.test_lrn_op import TestLRNOp +import paddle.fluid as fluid class TestLRNMKLDNNOp(TestLRNOp): @@ -54,5 +55,20 @@ class TestLRNMKLDNNOpWithIsTest(TestLRNMKLDNNOp): self.assertRaises(AttributeError, check_raise_is_test) +# TODO(jczaja): Once mkl-dnn integration support NHWC input +# then those tests should be changed to actual functional positive tests +class TestLRNMKLDNNOpNHWC(TestLRNMKLDNNOp): + def init_test_case(self): + self.data_format = 'NHWC' + + def test_check_output(self): + pass + + # Grad tests both FWD and BWD ops kernels creation + def test_check_grad_normal(self): + with self.assertRaises(fluid.core_avx.EnforceNotMet): + self.check_grad(['X'], 'Out', max_relative_error=0.01) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py index 5a9c10073a4f217198fd185aaa09a7911cebdad7..b8403bc3c6f46981108fbe7b43bce93350211a96 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py @@ -141,5 +141,26 @@ class TestAsymPadValid(TestAsymPad): self.padding_algorithm = "VALID" +# Designed to Fail +# TODO(jczaja): Once mkl-dnn integration support NHWC input +# then those tests should be changed to actual functional positive tests +class TestAsymPadValidNHWC(TestAsymPadValid): + def init_data_format(self): + self.data_format = "NHWC" + + def init_shape(self): + self.shape = [2, 7, 7, 3] + + def test_check_output(self): + pass + + # Grad tests both FWD and BWD ops kernels creation + # GetExpectedKernelType should throw an exception on lack of support + # to NHWC inputs in pool mkldnn kernel + def test_check_grad(self): + with self.assertRaises(fluid.core_avx.EnforceNotMet): + super(TestAsymPadValidNHWC, self).test_check_grad() + + if __name__ == '__main__': unittest.main()