From fe581b0e8aac9ae68014ed3b463457c16bbb0c6d Mon Sep 17 00:00:00 2001 From: Adam <38704900+grygielski@users.noreply.github.com> Date: Sat, 28 Sep 2019 15:50:27 +0200 Subject: [PATCH] Minor GetMKLDNNFormat changes (#20055) test=develop --- .../mkldnn/elementwise_add_mkldnn_op.cc | 4 +--- paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc | 7 +------ paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc | 7 ++++--- paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc | 14 ++------------ paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc | 7 ++++--- paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc | 11 ++--------- 6 files changed, 14 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc index 97b1f3831c..1f4a4fb0e1 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc @@ -167,9 +167,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel { stream(stream::kind::eager).submit(pipeline).wait(); z->set_layout(DataLayout::kMKLDNN); - z->set_format((MKLDNNMemoryFormat)dst_memory->get_primitive_desc() - .desc() - .data.format); + z->set_format(platform::GetMKLDNNFormat(*dst_memory)); } } }; diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc index 8823e08655..8010b52a1d 100644 --- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc @@ -47,11 +47,6 @@ static memory::primitive_desc CreateMemPrimDesc(const Tensor& input, return mem_prim_desc; } -static MKLDNNMemoryFormat GetDstMemFormat( - const concat::primitive_desc& concat_pd) { - return (MKLDNNMemoryFormat)concat_pd.dst_primitive_desc().desc().data.format; -} - static platform::CPUPlace GetCpuPlace( const paddle::framework::ExecutionContext& ctx) { auto place = ctx.GetPlace(); @@ -198,7 +193,7 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { stream(stream::kind::eager).submit({*concat_p}).wait(); output->set_layout(DataLayout::kMKLDNN); - output->set_format(GetDstMemFormat(*concat_pd)); + output->set_format(platform::GetMKLDNNFormat(*dst_mem)); } }; } // namespace operators diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc index 01837cfe36..a910deef52 100644 --- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc @@ -71,7 +71,7 @@ class FCPrimitiveFactory { input_->set_data_handle(const_cast(in->data())); output_->set_data_handle(out->mutable_data(ctx.GetPlace())); if (out->format() == MKLDNNMemoryFormat::format_undef) { - auto output_format = output_->get_primitive_desc().desc().data.format; + auto output_format = platform::GetMKLDNNFormat(*output_); out->set_format((MKLDNNMemoryFormat)output_format); } } @@ -199,8 +199,9 @@ class FCPrimitiveFactory { auto dst_prim_desc = fc_prim_desc.dst_primitive_desc(); auto buffer_size = dst_prim_desc.get_size(); T* output_data = output->mutable_data(ctx.GetPlace(), buffer_size); - output->set_format((MKLDNNMemoryFormat)dst_prim_desc.desc().data.format); - return memory(dst_prim_desc, to_void_cast(output_data)); + memory dst_mem(dst_prim_desc, to_void_cast(output_data)); + output->set_format(platform::GetMKLDNNFormat(dst_mem)); + return dst_mem; } void RecomputeOutputDims(const ExecutionContext& ctx, const LoDTensor* input, diff --git a/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc index fe1ead8fed..ef922c35b8 100644 --- a/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc @@ -77,13 +77,8 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel { std::vector pipeline = {*lrn_p}; mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); - auto output_format = - (mkldnn::memory::format)dst_memory->get_primitive_desc() - .desc() - .data.format; - out->set_layout(framework::DataLayout::kMKLDNN); - out->set_format(output_format); + out->set_format(platform::GetMKLDNNFormat(*dst_memory)); } }; @@ -129,13 +124,8 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel { std::vector pipeline = {*lrn_bwd}; mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); - auto output_format = - (mkldnn::memory::format)diff_src_memory->get_primitive_desc() - .desc() - .data.format; - x_grad->set_layout(framework::DataLayout::kMKLDNN); - x_grad->set_format(output_format); + x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory)); } }; } // namespace operators diff --git a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc index 5c635e58ec..4bdd93d08e 100644 --- a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc @@ -107,7 +107,7 @@ class MulPrimitiveFactory { output_->set_data_handle(out->mutable_data(ctx.GetPlace())); if (out->format() == MKLDNNMemoryFormat::format_undef) { - auto output_format = output_->get_primitive_desc().desc().data.format; + auto output_format = platform::GetMKLDNNFormat(*output_); out->set_format((MKLDNNMemoryFormat)output_format); } } @@ -139,8 +139,9 @@ class MulPrimitiveFactory { auto buffer_size = dst_prim_desc.get_size(); OT *output_data = output->mutable_data(ctx.GetPlace(), buffer_size); - output->set_format((MKLDNNMemoryFormat)dst_prim_desc.desc().data.format); - return memory(dst_prim_desc, to_void_cast(output_data)); + memory dst_mem(dst_prim_desc, to_void_cast(output_data)); + output->set_format(platform::GetMKLDNNFormat(dst_mem)); + return dst_mem; } memory Reorder(const memory::desc &src_desc, const memory::desc &dst_desc, diff --git a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc index 9c0893456a..a7f1bd018c 100644 --- a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc @@ -95,11 +95,8 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel { std::vector pipeline{*pool_p}; stream(stream::kind::eager).submit(pipeline).wait(); - auto output_format = - (MKLDNNMemoryFormat)dst_memory->get_primitive_desc().desc().data.format; - output->set_layout(DataLayout::kMKLDNN); - output->set_format(output_format); + output->set_format(platform::GetMKLDNNFormat(*dst_memory)); } }; @@ -179,12 +176,8 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel { pipeline.push_back(*pool_bwd_p); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); - auto in_x_grad_format = - (MKLDNNMemoryFormat)diff_src_memory->get_primitive_desc() - .desc() - .data.format; in_x_grad->set_layout(DataLayout::kMKLDNN); - in_x_grad->set_format(in_x_grad_format); + in_x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory)); } // Compute() }; -- GitLab