From 6512be59ece0a452ea8784ee5f04edacc4881692 Mon Sep 17 00:00:00 2001 From: mozga-intel Date: Fri, 15 Jun 2018 17:06:35 +0200 Subject: [PATCH] MKLDNN layout: the code-review changes --- paddle/fluid/operators/sum_mkldnn_op.cc | 9 ++++----- paddle/fluid/platform/mkldnn_helper.h | 6 ++++++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/sum_mkldnn_op.cc b/paddle/fluid/operators/sum_mkldnn_op.cc index 1f0c3ab02..0e201420c 100644 --- a/paddle/fluid/operators/sum_mkldnn_op.cc +++ b/paddle/fluid/operators/sum_mkldnn_op.cc @@ -118,19 +118,18 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel { auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_mpd); std::shared_ptr dst_mem; - if (in_place) + if (in_place) { dst_mem.reset(new memory(sum_pd.dst_primitive_desc())); - else + } else { dst_mem.reset(new memory(sum_pd.dst_primitive_desc(), output_data)); - + } std::vector inputs; for (size_t i = 0; i < srcs_mem.size(); ++i) { inputs.push_back(srcs_mem[i]); } auto sum_prim = mkldnn::sum(sum_pd, inputs, *dst_mem); - output_format = - (memory::format)sum_pd.dst_primitive_desc().desc().data.format; + output_format = (memory::format)platform::GetMKLDNNFormat(sum_pd); primitive reorder_prim; std::shared_ptr target_mem; diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index de711b7d2..2689d5e07 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -99,5 +99,11 @@ inline mkldnn::memory::format GetMKLDNNFormat(const mkldnn::memory memory) { memory.get_primitive_desc().desc().data.format); } +inline mkldnn::memory::format GetMKLDNNFormat( + const mkldnn::sum::primitive_desc& memory) { + return static_cast( + memory.dst_primitive_desc().desc().data.format); +} + } // namespace platform } // namespace paddle -- GitLab