From b8a04c2fa10ba1ecc447379306c5ab1481078346 Mon Sep 17 00:00:00 2001 From: mozga-intel Date: Tue, 26 Jun 2018 10:16:59 +0200 Subject: [PATCH] Duplicated code was moved to common function --- .../fluid/operators/batch_norm_mkldnn_op.cc | 25 ++++++------------- paddle/fluid/platform/mkldnn_helper.h | 12 ++++++++- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/batch_norm_mkldnn_op.cc index 199a3d4b041..9ab2179b5fe 100644 --- a/paddle/fluid/operators/batch_norm_mkldnn_op.cc +++ b/paddle/fluid/operators/batch_norm_mkldnn_op.cc @@ -115,12 +115,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu; // create mkldnn memory from input x tensor - mkldnn::memory::format input_format = x->format(); - if (src_tz.size() == 1) { - input_format = mkldnn::memory::format::x; - } else if (src_tz.size() == 2) { - input_format = mkldnn::memory::format::nc; - } + mkldnn::memory::format input_format = + platform::MKLDNNFormatForSize(src_tz.size(), x->format()); auto src_memory = memory( {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine}, @@ -259,23 +255,16 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel { // create mkldnn memory from input diff_y tensor - mkldnn::memory::format dst_format = x->format(); - if (diff_dst_tz.size() == 1) { - dst_format = mkldnn::memory::format::x; - } else if (diff_dst_tz.size() == 2) { - dst_format = mkldnn::memory::format::nc; - } + mkldnn::memory::format dst_format = + platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format()); + auto user_diff_dst_memory = memory( {{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine}, to_void_cast(diff_y_data)); // create mkldnn memory from input x tensor - mkldnn::memory::format input_format = x->format(); - if (src_tz.size() == 1) { - input_format = mkldnn::memory::format::x; - } else if (src_tz.size() == 2) { - input_format = mkldnn::memory::format::nc; - } + mkldnn::memory::format input_format = + platform::MKLDNNFormatForSize(src_tz.size(), x->format()); auto src_memory = memory( {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine}, diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index ed999325464..a6cccc31219 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -228,7 +228,7 @@ class MKLDNNHandler { return dstr; }; return dims2str(operand_dims) + suffix; - }; + } protected: const MKLDNNDeviceContext& dev_ctx_; @@ -237,5 +237,15 @@ class MKLDNNHandler { bool is_reusing_; }; +inline mkldnn::memory::format MKLDNNFormatForSize( + size_t dims_size, mkldnn::memory::format data_format) { + if (dims_size == 1) { + return mkldnn::memory::format::x; + } else if (dims_size == 2) { + return mkldnn::memory::format::nc; + } + return data_format; +} + } // namespace platform } // namespace paddle -- GitLab