diff --git a/paddle/fluid/framework/data_layout_transform.cc b/paddle/fluid/framework/data_layout_transform.cc index bc48fd3b479157d4aea390cd5f4dc61ea46dca4b..cd00b7de7338982308acfa1f1e8c38e010c6a43b 100644 --- a/paddle/fluid/framework/data_layout_transform.cc +++ b/paddle/fluid/framework/data_layout_transform.cc @@ -147,9 +147,9 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, "Input tensor type is not supported: ", in.type().name()); memory::data_type out_type = in_type; - auto in_format = MKLDNNFormatForSize(in_tz.size(), in.format()); + auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format()); auto out_format = - MKLDNNFormatForSize(in_tz.size(), ToMKLDNNFormat(out_layout)); + platform::MKLDNNFormatForSize(in_tz.size(), ToMKLDNNFormat(out_layout)); void* in_data = GetDataFromTensor(in, in_type); diff --git a/paddle/fluid/framework/data_layout_transform.h b/paddle/fluid/framework/data_layout_transform.h index 67f91e4e48d3e11ed493c5e6943cb9071aff60c4..90bb206ec6b698bc23ad1a5c9609a25186ec6de8 100644 --- a/paddle/fluid/framework/data_layout_transform.h +++ b/paddle/fluid/framework/data_layout_transform.h @@ -62,12 +62,6 @@ inline MKLDNNDataType ToMKLDNNDataType(const std::type_index type) { return MKLDNNDataType::data_undef; } -inline MKLDNNFormat MKLDNNFormatForSize(size_t dims_size, - MKLDNNFormat default_format) { - return (dims_size == 1 - ? mkldnn::memory::format::x - : dims_size == 2 ? mkldnn::memory::format::nc : default_format); -} #endif void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, diff --git a/paddle/fluid/framework/data_transform.cc b/paddle/fluid/framework/data_transform.cc index 635eb404d4a94e9a765f511bb249ce0cce22c477..82872224501709080ff02a13464d58543a0abda8 100644 --- a/paddle/fluid/framework/data_transform.cc +++ b/paddle/fluid/framework/data_transform.cc @@ -18,6 +18,10 @@ limitations under the License. */ #include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/framework/data_type_transform.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + namespace paddle { namespace framework { @@ -48,8 +52,8 @@ void TransformData(const OpKernelType &expected_kernel_type, // Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel // Just set layout/format. No real transform occur - auto out_format = - MKLDNNFormatForSize(in.dims().size(), ToMKLDNNFormat(lin)); + auto out_format = platform::MKLDNNFormatForSize(in.dims().size(), + ToMKLDNNFormat(lin)); out.ShareDataWith(input_tensor); out.set_layout(DataLayout::kMKLDNN); diff --git a/paddle/fluid/operators/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/batch_norm_mkldnn_op.cc index 6ecb43c49c30f9da2a273d506f7b85c0a4f5fa2c..9ab2179b5fe689762704039c5f67dd080e530aa5 100644 --- a/paddle/fluid/operators/batch_norm_mkldnn_op.cc +++ b/paddle/fluid/operators/batch_norm_mkldnn_op.cc @@ -115,9 +115,12 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu; // create mkldnn memory from input x tensor - auto src_memory = - memory({{{src_tz}, memory::data_type::f32, x->format()}, mkldnn_engine}, - to_void_cast(x_data)); + mkldnn::memory::format input_format = + platform::MKLDNNFormatForSize(src_tz.size(), x->format()); + + auto src_memory = memory( + {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine}, + to_void_cast(x_data)); // create primitive descriptor for batch norm forward using bn_fwd_types = bn_type_traits; @@ -251,15 +254,21 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel { using bn_bwd_types = bn_type_traits; // create mkldnn memory from input diff_y tensor - auto user_diff_dst_memory = - memory({{{diff_dst_tz}, memory::data_type::f32, diff_y->format()}, - mkldnn_engine}, - to_void_cast(diff_y_data)); + + mkldnn::memory::format dst_format = + platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format()); + + auto user_diff_dst_memory = memory( + {{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine}, + to_void_cast(diff_y_data)); // create mkldnn memory from input x tensor - auto src_memory = - memory({{{src_tz}, memory::data_type::f32, x->format()}, mkldnn_engine}, - to_void_cast(x_data)); + mkldnn::memory::format input_format = + platform::MKLDNNFormatForSize(src_tz.size(), x->format()); + + auto src_memory = memory( + {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine}, + to_void_cast(x_data)); // for diff_dst, try to use same format as dst in forward pass auto diff_dst_pd = batch_norm_fwd_pd.get()->dst_primitive_desc(); diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index ed99932546446eb877c9701de15e2d37d29b5f88..a6cccc31219104767ac38bdebeb1d4c0e8c2ac01 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -228,7 +228,7 @@ class MKLDNNHandler { return dstr; }; return dims2str(operand_dims) + suffix; - }; + } protected: const MKLDNNDeviceContext& dev_ctx_; @@ -237,5 +237,15 @@ class MKLDNNHandler { bool is_reusing_; }; +inline mkldnn::memory::format MKLDNNFormatForSize( + size_t dims_size, mkldnn::memory::format data_format) { + if (dims_size == 1) { + return mkldnn::memory::format::x; + } else if (dims_size == 2) { + return mkldnn::memory::format::nc; + } + return data_format; +} + } // namespace platform } // namespace paddle