From fa286b105265f1e99ef9c5fc26eab169139e2bd5 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Wed, 23 Jan 2019 03:27:24 -0800 Subject: [PATCH] LRN reengineering Added reading dst mem pd from lrn pd coding style fixes test=develop --- paddle/fluid/operators/lrn_mkldnn_op.cc | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/lrn_mkldnn_op.cc b/paddle/fluid/operators/lrn_mkldnn_op.cc index 4e4f977fc..d4325b2c0 100644 --- a/paddle/fluid/operators/lrn_mkldnn_op.cc +++ b/paddle/fluid/operators/lrn_mkldnn_op.cc @@ -78,10 +78,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel { auto dims = paddle::framework::vectorize2int(x->dims()); auto src_md = paddle::platform::MKLDNNMemDesc( - dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw); - - auto dst_md = paddle::platform::MKLDNNMemDesc( - dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw); + dims, mkldnn::memory::data_type::f32, x->format()); auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward, mkldnn::lrn_across_channels, @@ -92,8 +89,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel { k}; auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine}; - auto dst_memory = mkldnn::memory{{dst_md, mkldnn_engine}, - static_cast(output_data)}; if (!is_test) { const std::string key = ctx.op().Output("Out"); @@ -110,11 +105,16 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel { src_memory->set_data_handle( static_cast(const_cast(input_data))); + auto dst_memory = mkldnn::memory(forward_pd->dst_primitive_desc(), + static_cast(output_data)); auto workspace_memory = insert_to_context( key_workspace_memory, dev_ctx, forward_pd->workspace_primitive_desc()); run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory); + + out->set_layout(framework::DataLayout::kMKLDNN); + out->set_format(platform::GetMKLDNNFormat(dst_memory)); } else { auto forward_pd = mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine}; @@ -122,8 +122,13 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel { src_memory_pd, static_cast(const_cast(input_data))}; auto workspace_memory = mkldnn::memory{forward_pd.workspace_primitive_desc()}; + auto dst_memory = mkldnn::memory(forward_pd.dst_primitive_desc(), + static_cast(output_data)); run_primitive(forward_pd, src_memory, workspace_memory, dst_memory); + + out->set_layout(framework::DataLayout::kMKLDNN); + out->set_format(platform::GetMKLDNNFormat(dst_memory)); } } }; -- GitLab