提交 1f598dfa 编写于 作者: B bingyanghuang 提交者: Tao Luo

cherry-pick MKL-DNN NHWC FWD support fix (#21593)

上级 f83254d6
...@@ -79,7 +79,8 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -79,7 +79,8 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
x_dims, x_dims.size()); x_dims, x_dims.size());
const int64_t C = const int64_t C =
(data_layout == DataLayout::kNCHW ? x_dims[1] ((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
? x_dims[1]
: x_dims[x_dims.size() - 1]); : x_dims[x_dims.size() - 1]);
auto scale_dim = ctx->GetInputDim("Scale"); auto scale_dim = ctx->GetInputDim("Scale");
...@@ -154,6 +155,32 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType( ...@@ -154,6 +155,32 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType(
library); library);
} }
// Reports the kernel type (data type, place, layout) that input `var_name`
// should be considered to have, given the kernel type expected by this op.
//
// With PADDLE_WITH_MKLDNN defined: if the expected kernel uses the MKL-DNN
// layout while the tensor bound to "X" does not, the layout named by the
// op's "data_layout" attribute is reported (unless it is AnyLayout).
// In every other case the tensor's own layout is reported.
//
// var_name:             name of the input variable being examined.
// tensor:               tensor bound to that variable.
// expected_kernel_type: kernel type selected for this op.
// Returns a kernel type built from the expected data type, the tensor's
// place and the layout chosen as described above.
framework::OpKernelType BatchNormOp::GetKernelTypeForVar(
    const std::string &var_name, const Tensor &tensor,
    const framework::OpKernelType &expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
  // Only the input requires reshaping; weights and bias
  // already have their shape in NCHW order.
  if ((var_name == "X") &&
      (expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
      (tensor.layout() != framework::DataLayout::kMKLDNN)) {
    // Bind by reference: a single attribute is read, so copying the whole
    // attribute container is unnecessary.
    const auto &attrs = Attrs();
    auto ar = paddle::framework::AttrReader(attrs);
    const std::string data_layout = ar.Get<std::string>("data_layout");
    const auto dl = framework::StringToDataLayout(data_layout);
    // Some models may have intentionally set "AnyLayout" for this op.
    // In that case fall through and report the tensor's own layout.
    if (dl != framework::DataLayout::kAnyLayout) {
      // Reuse the layout parsed above instead of converting the string again.
      return framework::OpKernelType(expected_kernel_type.data_type_,
                                     tensor.place(), dl);
    }
  }
#endif
  return framework::OpKernelType(expected_kernel_type.data_type_,
                                 tensor.place(), tensor.layout());
}
void BatchNormOpMaker::Make() { void BatchNormOpMaker::Make() {
AddAttr<bool>("is_test", AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false " "(bool, default false) Set to true for inference only, false "
...@@ -446,6 +473,12 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( ...@@ -446,6 +473,12 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_layout = ctx.Attr<std::string>("data_layout");
PADDLE_ENFORCE_NE(
data_layout, "NHWC",
platform::errors::Unimplemented(
"Batch Norm MKLDNN grad does not support NHWC data format yet"));
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
......
...@@ -47,6 +47,10 @@ class BatchNormOp : public framework::OperatorWithKernel { ...@@ -47,6 +47,10 @@ class BatchNormOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override; const framework::ExecutionContext &ctx) const override;
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override;
}; };
class BatchNormGradOp : public framework::OperatorWithKernel { class BatchNormGradOp : public framework::OperatorWithKernel {
......
...@@ -775,8 +775,23 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -775,8 +775,23 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
* ('any') which lets a primitive (conv backward in this case) choose * ('any') which lets a primitive (conv backward in this case) choose
* the memory format preferred for best performance * the memory format preferred for best performance
*/ */
auto chosen_memory_format = MKLDNNMemoryFormat::any;
// TODO(jczaja): Once GRAD NHWC is working then format 'any'
// should be used exclusively. But till forward pass enforce
// NCHW for training we need to have NCHW here as well
// to avoid performance degradation in relu_grad and pool2d_grad
std::string data_format = ctx.Attr<std::string>("data_format");
auto chosen_memory_format =
platform::data_format_to_memory_format(data_format);
weights_format = MKLDNNMemoryFormat::any; weights_format = MKLDNNMemoryFormat::any;
// Check the format for user's special output
if (chosen_memory_format != MKLDNNMemoryFormat::any) {
if (is_conv3d) {
chosen_memory_format =
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
}
}
auto src_md = platform::MKLDNNMemDesc( auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format); src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
......
...@@ -84,6 +84,13 @@ class TestMKLDNNBatchNormOpInference(TestBatchNormOpInference): ...@@ -84,6 +84,13 @@ class TestMKLDNNBatchNormOpInference(TestBatchNormOpInference):
self.check_with_place(place, data_format, self.dtype, [2, 3, 4, 5]) self.check_with_place(place, data_format, self.dtype, [2, 3, 4, 5])
class TestMKLDNNBatchNormOpInference_NHWC(TestMKLDNNBatchNormOpInference):
    # Variant of the parent inference test that passes "NHWC" as the data
    # format to check_with_place.

    def test_check_output(self):
        # Input shape handed to the checker together with the "NHWC" format.
        input_shape = [2, 4, 5, 3]
        self.check_with_place(core.CPUPlace(), "NHWC", self.dtype, input_shape)
class TestMKLDNNBatchNormOpWithReluInference(TestBatchNormOpInference): class TestMKLDNNBatchNormOpWithReluInference(TestBatchNormOpInference):
def init_kernel_type(self): def init_kernel_type(self):
self.use_mkldnn = True self.use_mkldnn = True
......
...@@ -259,6 +259,21 @@ class TestBatchNormOpInference(unittest.TestCase): ...@@ -259,6 +259,21 @@ class TestBatchNormOpInference(unittest.TestCase):
batch_norm_op.run(scope, place) batch_norm_op.run(scope, place)
# When op is called without Executor then
# MKL-DNN Tensor is returned. For NHWC data layout
# dims will be in NCHW order, as that is how MKL-DNN
# describes memory. So we need to convert NCHW
# dims into NHWC.
if data_layout == "NHWC" and self.use_mkldnn == True:
# Create executor to have MKL-DNN cache
# cleared after NHWC unit test
place = core.CPUPlace()
exe = fluid.Executor(place)
dims = y_tensor.shape()
c = dims.pop(1)
dims.append(c)
y_tensor._set_dims(dims)
# check inference result # check inference result
self.__assert_close( self.__assert_close(
y_tensor, y_tensor,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册