提交 f4cf028a 编写于 作者: J Jacek Czaja 提交者: Tao Luo

[MKL-DNN] Error throwing for NHWC layout for MKL-DNN ops (#21207)

上级 ed9ceb9f
......@@ -151,6 +151,15 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(data_format, "NHWC",
platform::errors::Unimplemented(
"Conv MKLDNN does not support NHWC data format yet"));
PADDLE_ENFORCE_NE(
data_format, "NDHWC",
platform::errors::Unimplemented(
"Conv MKLDNN does not support NDHWC data format yet"));
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
customized_type_value =
......@@ -524,6 +533,16 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format, "NHWC",
platform::errors::Unimplemented(
"Conv MKLDNN grad does not support NHWC data format yet"));
PADDLE_ENFORCE_NE(
data_format, "NDHWC",
platform::errors::Unimplemented(
"Conv MKLDNN Grad does not support NDHWC data format yet"));
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
customized_type_value = kConvMKLDNNFP32;
......@@ -706,14 +725,6 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
customized_type_value = kConvMKLDNNFP32;
}
#endif
auto type = framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
......
......@@ -145,6 +145,11 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format, "NHWC",
"Conv Transpose MKLDNN does not support NHWC data format yet");
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
......
......@@ -193,6 +193,12 @@ class LRNOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format, "NHWC",
platform::errors::Unimplemented(
"LRN MKLDNN does not support NHWC data format yet"));
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
......@@ -311,6 +317,12 @@ class LRNOpGrad : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format, "NHWC",
platform::errors::Unimplemented(
"LRN MKLDNN grad does not support NHWC data format yet"));
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
......
......@@ -146,6 +146,12 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format, "NHWC",
platform::errors::Unimplemented(
"Pool MKLDNN grad does not support NHWC data format yet"));
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
......@@ -177,6 +183,12 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format, "NHWC",
platform::errors::Unimplemented(
"Pool MKLDNN grad does not support NHWC data format yet"));
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
......
......@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest
from paddle.fluid.tests.unittests.test_lrn_op import TestLRNOp
import paddle.fluid as fluid
class TestLRNMKLDNNOp(TestLRNOp):
......@@ -54,5 +55,20 @@ class TestLRNMKLDNNOpWithIsTest(TestLRNMKLDNNOp):
self.assertRaises(AttributeError, check_raise_is_test)
# TODO(jczaja): Once mkl-dnn integration support NHWC input
# then those tests should be changed to actual functional positive tests
class TestLRNMKLDNNOpNHWC(TestLRNMKLDNNOp):
def init_test_case(self):
self.data_format = 'NHWC'
def test_check_output(self):
pass
# Grad tests both FWD and BWD ops kernels creation
def test_check_grad_normal(self):
with self.assertRaises(fluid.core_avx.EnforceNotMet):
self.check_grad(['X'], 'Out', max_relative_error=0.01)
if __name__ == "__main__":
unittest.main()
......@@ -141,5 +141,26 @@ class TestAsymPadValid(TestAsymPad):
self.padding_algorithm = "VALID"
# Designed to Fail
# TODO(jczaja): Once mkl-dnn integration support NHWC input
# then those tests should be changed to actual functional positive tests
class TestAsymPadValidNHWC(TestAsymPadValid):
def init_data_format(self):
self.data_format = "NHWC"
def init_shape(self):
self.shape = [2, 7, 7, 3]
def test_check_output(self):
pass
# Grad tests both FWD and BWD ops kernels creation
# GetExpectedKernelType should throw an exception on lack of support
# to NHWC inputs in pool mkldnn kernel
def test_check_grad(self):
with self.assertRaises(fluid.core_avx.EnforceNotMet):
super(TestAsymPadValidNHWC, self).test_check_grad()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册