提交 9ce0e29d 编写于 作者: J Jacek Czaja 提交者: Tao Luo

[MKL-DNN] Batch norm mkl-dnn NHWC support (#21553)

* - Batch norm mkl-dnn NHWC

test=develop

- compilation fix

test=develop

- UT fix

- cosmetics

test=develop

- Fix to Batch Norm MKL-DNN NHWC UT

test=develop

Conflicts:
	paddle/fluid/operators/batch_norm_op.h

* - Lint fixes

test=develop
上级 f663f34a
...@@ -94,8 +94,9 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -94,8 +94,9 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
x_dims, x_dims.size()); x_dims, x_dims.size());
const int64_t C = const int64_t C =
(data_layout == DataLayout::kNCHW ? x_dims[1] ((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
: x_dims[x_dims.size() - 1]); ? x_dims[1]
: x_dims[x_dims.size() - 1]);
auto scale_dim = ctx->GetInputDim("Scale"); auto scale_dim = ctx->GetInputDim("Scale");
auto bias_dim = ctx->GetInputDim("Bias"); auto bias_dim = ctx->GetInputDim("Bias");
...@@ -169,6 +170,32 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType( ...@@ -169,6 +170,32 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType(
library); library);
} }
// Picks the kernel type for an individual input variable. When the op runs
// with MKL-DNN and the incoming tensor is still in a plain (non-MKL-DNN)
// layout, the tensor's declared data_layout attribute is honored so the
// framework can insert the proper layout transformation for "X".
framework::OpKernelType BatchNormOp::GetKernelTypeForVar(
    const std::string &var_name, const Tensor &tensor,
    const framework::OpKernelType &expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
  // Only the "X" input may require reshaping; the scale and bias
  // weights are always stored with shape in NCHW order.
  if ((var_name == "X") &&
      (expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
      (tensor.layout() != framework::DataLayout::kMKLDNN)) {
    auto attrs = Attrs();
    auto ar = paddle::framework::AttrReader(attrs);
    const std::string data_layout = ar.Get<std::string>("data_layout");
    auto dl = framework::StringToDataLayout(data_layout);
    // Some models may have intentionally set "AnyLayout" for this
    // op. Treat that as NCHW (the default data_format value).
    if (dl != framework::DataLayout::kAnyLayout) {
      // Reuse the already-parsed layout instead of parsing the
      // attribute string a second time.
      return framework::OpKernelType(expected_kernel_type.data_type_,
                                     tensor.place(), dl);
    }
  }
#endif
  return framework::OpKernelType(expected_kernel_type.data_type_,
                                 tensor.place(), tensor.layout());
}
void BatchNormOpMaker::Make() { void BatchNormOpMaker::Make() {
AddAttr<bool>("is_test", AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false " "(bool, default false) Set to true for inference only, false "
...@@ -465,6 +492,12 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( ...@@ -465,6 +492,12 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_layout = ctx.Attr<std::string>("data_layout");
PADDLE_ENFORCE_NE(
data_layout, "NHWC",
platform::errors::Unimplemented(
"Batch Norm MKLDNN grad does not support NHWC data format yet"));
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
......
...@@ -134,6 +134,10 @@ class BatchNormOp : public framework::OperatorWithKernel { ...@@ -134,6 +134,10 @@ class BatchNormOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override; const framework::ExecutionContext& ctx) const override;
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override;
}; };
class BatchNormGradOp : public framework::OperatorWithKernel { class BatchNormGradOp : public framework::OperatorWithKernel {
......
...@@ -86,6 +86,13 @@ class TestMKLDNNBatchNormOpInference(TestBatchNormOpInference): ...@@ -86,6 +86,13 @@ class TestMKLDNNBatchNormOpInference(TestBatchNormOpInference):
self.check_with_place(place, data_format, self.dtype, [2, 3, 4, 5]) self.check_with_place(place, data_format, self.dtype, [2, 3, 4, 5])
class TestMKLDNNBatchNormOpInference_NHWC(TestMKLDNNBatchNormOpInference):
    """MKL-DNN batch-norm inference test with the channel axis last (NHWC)."""

    def test_check_output(self):
        # Same check as the base class, but the input shape places the
        # 3-element channel dimension at the end instead of position 1.
        cpu_place = core.CPUPlace()
        self.check_with_place(cpu_place, "NHWC", self.dtype, [2, 4, 5, 3])
class TestMKLDNNBatchNormOpWithReluInference(TestBatchNormOpInference): class TestMKLDNNBatchNormOpWithReluInference(TestBatchNormOpInference):
def init_kernel_type(self): def init_kernel_type(self):
self.use_mkldnn = True self.use_mkldnn = True
......
...@@ -262,6 +262,21 @@ class TestBatchNormOpInference(unittest.TestCase): ...@@ -262,6 +262,21 @@ class TestBatchNormOpInference(unittest.TestCase):
batch_norm_op.run(scope, place) batch_norm_op.run(scope, place)
            # When the op is called without an Executor, an
            # MKL-DNN Tensor is returned. For the NHWC data layout
            # the dims will be in NCHW order, as that is the MKL-DNN way
            # of describing memory. So we need to convert the NCHW
            # dims into NHWC.
if data_layout == "NHWC" and self.use_mkldnn == True:
# Create executor to have MKL-DNN cache
# cleared after NHWC unit test
place = core.CPUPlace()
exe = fluid.Executor(place)
dims = y_tensor.shape()
c = dims.pop(1)
dims.append(c)
y_tensor._set_dims(dims)
# check inference result # check inference result
self.__assert_close( self.__assert_close(
y_tensor, y_tensor,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册