From 3b128337a147264c37d48486637a1f0b865b8ad7 Mon Sep 17 00:00:00 2001 From: mozga-intel Date: Fri, 22 Jun 2018 12:20:58 +0200 Subject: [PATCH] The mkldnn batch norm supports other data format --- .../fluid/operators/batch_norm_mkldnn_op.cc | 40 ++++++++++++++----- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/batch_norm_mkldnn_op.cc index 6ecb43c49c..199a3d4b04 100644 --- a/paddle/fluid/operators/batch_norm_mkldnn_op.cc +++ b/paddle/fluid/operators/batch_norm_mkldnn_op.cc @@ -115,9 +115,16 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu; // create mkldnn memory from input x tensor - auto src_memory = - memory({{{src_tz}, memory::data_type::f32, x->format()}, mkldnn_engine}, - to_void_cast(x_data)); + mkldnn::memory::format input_format = x->format(); + if (src_tz.size() == 1) { + input_format = mkldnn::memory::format::x; + } else if (src_tz.size() == 2) { + input_format = mkldnn::memory::format::nc; + } + + auto src_memory = memory( + {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine}, + to_void_cast(x_data)); // create primitive descriptor for batch norm forward using bn_fwd_types = bn_type_traits; @@ -251,15 +258,28 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel { using bn_bwd_types = bn_type_traits; // create mkldnn memory from input diff_y tensor - auto user_diff_dst_memory = - memory({{{diff_dst_tz}, memory::data_type::f32, diff_y->format()}, - mkldnn_engine}, - to_void_cast(diff_y_data)); + + mkldnn::memory::format dst_format = x->format(); + if (diff_dst_tz.size() == 1) { + dst_format = mkldnn::memory::format::x; + } else if (diff_dst_tz.size() == 2) { + dst_format = mkldnn::memory::format::nc; + } + auto user_diff_dst_memory = memory( + {{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine}, + to_void_cast(diff_y_data)); // create mkldnn memory from input x tensor - auto src_memory = - memory({{{src_tz}, memory::data_type::f32, x->format()}, mkldnn_engine}, - to_void_cast(x_data)); + mkldnn::memory::format input_format = x->format(); + if (src_tz.size() == 1) { + input_format = mkldnn::memory::format::x; + } else if (src_tz.size() == 2) { + input_format = mkldnn::memory::format::nc; + } + + auto src_memory = memory( + {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine}, + to_void_cast(x_data)); // for diff_dst, try to use same format as dst in forward pass auto diff_dst_pd = batch_norm_fwd_pd.get()->dst_primitive_desc(); -- GitLab