diff --git a/paddle/fluid/operators/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/batch_norm_mkldnn_op.cc index 199a3d4b04125e5b75c42e5a181f7b3877999a58..9ab2179b5fe689762704039c5f67dd080e530aa5 100644 --- a/paddle/fluid/operators/batch_norm_mkldnn_op.cc +++ b/paddle/fluid/operators/batch_norm_mkldnn_op.cc @@ -115,12 +115,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu; // create mkldnn memory from input x tensor - mkldnn::memory::format input_format = x->format(); - if (src_tz.size() == 1) { - input_format = mkldnn::memory::format::x; - } else if (src_tz.size() == 2) { - input_format = mkldnn::memory::format::nc; - } + mkldnn::memory::format input_format = + platform::MKLDNNFormatForSize(src_tz.size(), x->format()); auto src_memory = memory( {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine}, @@ -259,23 +255,16 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel { // create mkldnn memory from input diff_y tensor - mkldnn::memory::format dst_format = x->format(); - if (diff_dst_tz.size() == 1) { - dst_format = mkldnn::memory::format::x; - } else if (diff_dst_tz.size() == 2) { - dst_format = mkldnn::memory::format::nc; - } + mkldnn::memory::format dst_format = + platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format()); + auto user_diff_dst_memory = memory( {{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine}, to_void_cast(diff_y_data)); // create mkldnn memory from input x tensor - mkldnn::memory::format input_format = x->format(); - if (src_tz.size() == 1) { - input_format = mkldnn::memory::format::x; - } else if (src_tz.size() == 2) { - input_format = mkldnn::memory::format::nc; - } + mkldnn::memory::format input_format = + platform::MKLDNNFormatForSize(src_tz.size(), x->format()); auto src_memory = memory( {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine}, diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index ed99932546446eb877c9701de15e2d37d29b5f88..a6cccc31219104767ac38bdebeb1d4c0e8c2ac01 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -228,7 +228,7 @@ class MKLDNNHandler { return dstr; }; return dims2str(operand_dims) + suffix; - }; + } protected: const MKLDNNDeviceContext& dev_ctx_; @@ -237,5 +237,15 @@ class MKLDNNHandler { bool is_reusing_; }; +inline mkldnn::memory::format MKLDNNFormatForSize( + size_t dims_size, mkldnn::memory::format data_format) { + if (dims_size == 1) { + return mkldnn::memory::format::x; + } else if (dims_size == 2) { + return mkldnn::memory::format::nc; + } + return data_format; +} + } // namespace platform } // namespace paddle