From 657abd517f3930b37c2a665dc1ef5c8140252504 Mon Sep 17 00:00:00 2001
From: jakpiase
Date: Wed, 25 May 2022 14:51:52 +0200
Subject: [PATCH] OneDNN md-in-tensor refactoring part 4: Memory descriptor
 enabled for more ops (#42946)

* added support for md in more ops

* fixed typo
---
 .../mkldnn/fill_constant_mkldnn_op.cc         |  6 ++-
 .../fluid/operators/mkldnn/lrn_mkldnn_op.cc   |  8 ++--
 .../fluid/operators/mkldnn/slice_mkldnn_op.cc | 13 ++---
 .../fluid/operators/mkldnn/stack_mkldnn_op.cc | 10 ++--
 .../fluid/operators/mkldnn/sum_mkldnn_op.cc   | 47 +++++--------------
 5 files changed, 29 insertions(+), 55 deletions(-)

diff --git a/paddle/fluid/operators/mkldnn/fill_constant_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fill_constant_mkldnn_op.cc
index cfc320da47..73e7830683 100644
--- a/paddle/fluid/operators/mkldnn/fill_constant_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/fill_constant_mkldnn_op.cc
@@ -79,8 +79,10 @@ class FillConstantMKLDNNKernel : public framework::OpKernel<T> {
                          {DNNL_ARG_DST, *src0_memory_p}});
     astream.wait();
 
-    out->set_layout(framework::DataLayout::kMKLDNN);
-    out->set_format(platform::GetPlainMKLDNNFormat(out->dims().size()));
+    // src0_memory_p's md was just to allow the usage of a binary
+    // primitive as a memset, and now we need to create a real one
+    out->set_mem_desc({phi::vectorize(shape), platform::MKLDNNGetDataType<T>(),
+                       platform::GetPlainMKLDNNFormat(shape.size())});
   }
 
   T CalculateFillValue(const framework::ExecutionContext& ctx) const {
diff --git a/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
index d3a36555c3..245ae2196c 100644
--- a/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
@@ -124,7 +124,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
 
     if (!workspace_memory->get_desc().is_zero()) {
-      mid->set_format(platform::GetMKLDNNFormat(*workspace_memory));
+      mid->set_mem_desc(workspace_memory->get_desc());
       lrn_p->execute(astream, {{DNNL_ARG_SRC, *src_memory},
                                {DNNL_ARG_DST, *dst_memory},
                                {DNNL_ARG_WORKSPACE, *workspace_memory}});
@@ -134,8 +134,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     }
     astream.wait();
 
-    out->set_layout(framework::DataLayout::kMKLDNN);
-    out->set_format(platform::GetMKLDNNFormat(*dst_memory));
+    out->set_mem_desc(dst_memory->get_desc());
   }
 };
 
@@ -177,8 +176,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
                                  {DNNL_ARG_WORKSPACE, *workspace}});
     astream.wait();
 
-    in_x_grad->set_layout(framework::DataLayout::kMKLDNN);
-    in_x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory));
+    in_x_grad->set_mem_desc(diff_src_memory->get_desc());
   }
 };
 }  // namespace operators
diff --git a/paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc
index 2a8627b803..2df9e5c20f 100644
--- a/paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc
@@ -175,19 +175,17 @@ class SliceGradMKLDNNKernel : public framework::OpKernel<T> {
 
     dnnl::memory::data_type dout_type = framework::ToMKLDNNDataType(
        framework::TransToProtoVarType(dout->dtype()));
-    dnnl::memory::desc md(dout_vec_dims, platform::MKLDNNGetDataType<T>(),
-                          dout->format());
-    dnnl::memory::format_tag reorder_format_tag =
-        platform::GetMKLDNNFormat(md.reshape(slice_dims));
 
     platform::ReorderMKLDNNHandler reorder_handler(
         slice_dims, framework::TransToProtoVarType(dout->dtype()), dout_type,
        onednn_engine);
 
     auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
-        reorder_format_tag, platform::to_void_cast(dout->data<T>()));
+        dout->mem_desc().reshape(slice_dims),
+        platform::to_void_cast(dout->data<T>()));
     auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
-        dx, dx_vec_dims, reorder_format_tag, ctx.GetPlace());
+        dx, dx_vec_dims, platform::GetPlainMKLDNNFormat(dx_vec_dims.size()),
+        ctx.GetPlace());
     memset(dx->data<T>(), 0, reorder_dst_memory_p->get_desc().get_size());
 
     auto slice_mem_p = reorder_handler.AcquireSubmemory(slice_dims, offsets,
@@ -199,8 +197,7 @@ class SliceGradMKLDNNKernel : public framework::OpKernel<T> {
     reorder_p->execute(astream, *reorder_src_memory_p, *slice_mem_p);
     astream.wait();
 
-    dx->set_layout(framework::DataLayout::kMKLDNN);
-    dx->set_format(reorder_format_tag);
+    dx->set_mem_desc(reorder_dst_memory_p->get_desc());
   }
 };
 }  // namespace operators
diff --git a/paddle/fluid/operators/mkldnn/stack_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/stack_mkldnn_op.cc
index 36be1681b0..28a00be5fa 100644
--- a/paddle/fluid/operators/mkldnn/stack_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/stack_mkldnn_op.cc
@@ -59,7 +59,7 @@ class StackMKLDNNHandler
     // wrong output format deduction and suboptimal performance as a result
     if (stack_axis != ndims) {
       for (size_t i = 0; i < inputs.size(); ++i) {
-        srcs_md.emplace_back(memory::desc(input_dims, dt, inputs[i]->format()));
+        srcs_md.push_back(inputs[i]->mem_desc());
       }
 
       input_dims[stack_axis] *= inputs.size();
@@ -69,8 +69,7 @@ class StackMKLDNNHandler
       extended_input_dims[stack_axis] = 1;
 
       for (size_t i = 0; i < inputs.size(); ++i) {
-        srcs_md.emplace_back(memory::desc(input_dims, dt, inputs[i]->format())
-                                 .reshape(extended_input_dims));
+        srcs_md.push_back(inputs[i]->mem_desc().reshape(extended_input_dims));
       }
 
       // concat primitive choses suboptimal format tag because it cannot
@@ -130,9 +129,8 @@ class StackMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     concat_p->execute(astream, args);
     astream.wait();
 
-    output->set_layout(DataLayout::kMKLDNN);
-    output->set_format(platform::GetMKLDNNFormat(
-        dst_mem->get_desc().reshape(phi::vectorize(output->dims()))));
+    output->set_mem_desc(
+        dst_mem->get_desc().reshape(phi::vectorize(output->dims())));
   }
 };
 }  // namespace operators
diff --git a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
index 99f957f573..de21c2687b 100644
--- a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
@@ -60,17 +60,16 @@ class SumMKLDNNHandler
     auto src_tz = dst_tz;
 
     std::vector<dnnl::memory::desc> srcs_md;
+    srcs_md.reserve(in_vars.size());
     for (size_t i = 0; i < in_vars.size(); i++) {
       auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
       if (input_it.numel() == 0) {
         continue;
       }
-      MKLDNNMemoryFormat input_format = input_it.format();
-      srcs_md.push_back(dnnl::memory::desc(
-          src_tz, platform::MKLDNNGetDataType<T>(), input_format));
+      srcs_md.push_back(input_it.mem_desc());
       ++num_inputs_;
     }
-    std::vector<float> scales(num_inputs_, 1.0);
+    std::vector<float> scales(num_inputs_, 1.0f);
 
     auto dst_md = dnnl::memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
                                      MKLDNNMemoryFormat::any);
@@ -139,47 +138,27 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       ++input_index;
     }
 
-    std::shared_ptr<dnnl::memory> dst_mem = nullptr;
+    std::unordered_map<int, dnnl::memory> args;
+    std::shared_ptr<dnnl::memory> dst_mem;
+
+    for (size_t i = 0; i < srcs_mem.size(); ++i) {
+      args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])});
+    }
+
     if (in_place) {
-      dst_mem = handler.AcquireDstMemory();
-      output->mutable_data<T>(ctx.GetPlace());
+      dst_mem = srcs_mem[0];
     } else {
       dst_mem = handler.AcquireDstMemory(output);
     }
+    args.insert({DNNL_ARG_DST, *dst_mem});
 
     auto sum_p = handler.AcquireForwardPrimitive();
 
-    std::unordered_map<int, dnnl::memory> args;
-    for (size_t i = 0; i < srcs_mem.size(); ++i) {
-      args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])});
-    }
-    args.insert({DNNL_ARG_DST, *dst_mem});
-
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     sum_p->execute(astream, args);
     astream.wait();
 
-    // For in-place execution which sum does not have we need to fake it
-    // so from oneDNN dst memory we reorder data into input
-    if (in_place) {
-      auto& in_out = in_vars[0]->Get<framework::LoDTensor>();
-      auto output_tz = phi::vectorize(output->dims());
-      platform::ReorderMKLDNNHandler reorder_handler(
-          output_tz, framework::TransToProtoVarType(output->dtype()),
-          framework::ToMKLDNNDataType(
-              framework::TransToProtoVarType(in_out.dtype())),
-          dev_ctx.GetEngine());
-
-      auto target_mem = reorder_handler.AcquireDstMemory(
-          output, in_out.format(), ctx.GetPlace());
-
-      auto reorder_p = reorder_handler.AcquireReorder(target_mem, dst_mem);
-
-      reorder_p->execute(astream, *dst_mem, *target_mem);
-      astream.wait();
-    }
-    output->set_layout(framework::DataLayout::kMKLDNN);
-    output->set_format(platform::GetMKLDNNFormat(*dst_mem));
+    output->set_mem_desc(dst_mem->get_desc());
   }
 };
 
--
GitLab
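
Illustration (not part of the patch above): every hunk follows the same pattern.
Instead of tagging the output Tensor with DataLayout::kMKLDNN plus a
dnnl::memory::format_tag, the kernel stores the full dnnl::memory::desc of the
primitive's destination memory via set_mem_desc(), and consumers read it back
with mem_desc(), reshaping it where needed (slice, stack) rather than
re-deriving a format tag. Below is a minimal standalone sketch of that idea
using only the public oneDNN C++ API; ToyTensor and its members are
hypothetical stand-ins for the framework Tensor, not Paddle code.

// Toy model of "md-in-tensor": the tensor carries a complete memory
// descriptor instead of a (layout, format_tag) pair.
#include "dnnl.hpp"

struct ToyTensor {
  dnnl::memory::desc md;
  void set_mem_desc(const dnnl::memory::desc& d) { md = d; }
  const dnnl::memory::desc& mem_desc() const { return md; }
};

int main() {
  using tag = dnnl::memory::format_tag;
  using dt = dnnl::memory::data_type;

  // Producer side: describe a 2x3x4x5 f32 buffer in plain (abcd) layout.
  ToyTensor t;
  t.set_mem_desc(dnnl::memory::desc({2, 3, 4, 5}, dt::f32, tag::abcd));

  // Consumer side (cf. the slice and stack hunks): reuse the stored
  // descriptor and reshape it, instead of rebuilding a desc from dims
  // plus a re-derived format tag.
  dnnl::memory::desc reshaped = t.mem_desc().reshape({2, 3, 20});

  // The reshape preserves the number of bytes described.
  return reshaped.get_size() == t.mem_desc().get_size() ? 0 : 1;
}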
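
A second sketch (also not part of the patch; oneDNN 2.x API assumed): the
rewritten sum kernel drops the reorder that used to fake in-place execution
and instead binds the first source memory as DST (dst_mem = srcs_mem[0]),
relying on the sum primitive's in-place support, where the destination refers
to the same memory as the first source. A standalone CPU example of that call
pattern follows; the buffer contents and shapes are made up for illustration.

#include <vector>
#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  using dt = dnnl::memory::data_type;
  using tag = dnnl::memory::format_tag;
  dnnl::memory::desc md({4}, dt::f32, tag::a);

  std::vector<float> a{1, 2, 3, 4}, b{10, 20, 30, 40};
  dnnl::memory mem_a(md, eng, a.data());
  dnnl::memory mem_b(md, eng, b.data());

  // Two sources, both scaled by 1.0f, summed into a single destination.
  std::vector<float> scales{1.0f, 1.0f};
  auto pd = dnnl::sum::primitive_desc(md, scales, {md, md}, eng);

  // DST aliases SRC0: the first input buffer is overwritten with the sum,
  // mirroring `dst_mem = srcs_mem[0]` in the patched kernel.
  dnnl::sum(pd).execute(strm, {{DNNL_ARG_MULTIPLE_SRC + 0, mem_a},
                               {DNNL_ARG_MULTIPLE_SRC + 1, mem_b},
                               {DNNL_ARG_DST, mem_a}});
  strm.wait();

  // a now holds {11, 22, 33, 44}; b is unchanged.
  return 0;
}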