diff --git a/paddle/fluid/operators/sum_mkldnn_op.cc b/paddle/fluid/operators/sum_mkldnn_op.cc index 1f0c3ab02381297a2f02bf7806cbe02dbb04e5d7..0e201420cecf803e10ad551c40e3fa494975d834 100644 --- a/paddle/fluid/operators/sum_mkldnn_op.cc +++ b/paddle/fluid/operators/sum_mkldnn_op.cc @@ -118,19 +118,18 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel { auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_mpd); std::shared_ptr dst_mem; - if (in_place) + if (in_place) { dst_mem.reset(new memory(sum_pd.dst_primitive_desc())); - else + } else { dst_mem.reset(new memory(sum_pd.dst_primitive_desc(), output_data)); - + } std::vector inputs; for (size_t i = 0; i < srcs_mem.size(); ++i) { inputs.push_back(srcs_mem[i]); } auto sum_prim = mkldnn::sum(sum_pd, inputs, *dst_mem); - output_format = - (memory::format)sum_pd.dst_primitive_desc().desc().data.format; + output_format = (memory::format)platform::GetMKLDNNFormat(sum_pd); primitive reorder_prim; std::shared_ptr target_mem; diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index de711b7d23ef01d57a62087c552ea090f01f0386..2689d5e0787e0164bfb8e539399d8a378964e50a 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -99,5 +99,11 @@ inline mkldnn::memory::format GetMKLDNNFormat(const mkldnn::memory memory) { memory.get_primitive_desc().desc().data.format); } +inline mkldnn::memory::format GetMKLDNNFormat( + const mkldnn::sum::primitive_desc& memory) { + return static_cast( + memory.dst_primitive_desc().desc().data.format); +} + } // namespace platform } // namespace paddle