From 7004f65c53acf3cdeef99cd3c53e20b22fa0f1ac Mon Sep 17 00:00:00 2001 From: piotrekobi <48731682+piotrekobi@users.noreply.github.com> Date: Wed, 16 Mar 2022 11:46:03 +0100 Subject: [PATCH] Refactor elementwise op grad classes (#40187) * Refactor elementwise op grad classes * Add more refactor changes * Revert set layout and format deletion * Fix failing elementwise test --- .../mkldnn/elementwise_add_mkldnn_op.cc | 102 +------- .../mkldnn/elementwise_div_mkldnn_op.cc | 174 +++----------- .../mkldnn/elementwise_mkldnn_op.h | 223 ++++++++++++++++-- .../mkldnn/elementwise_mul_mkldnn_op.cc | 142 ++--------- .../mkldnn/elementwise_sub_mkldnn_op.cc | 117 +-------- .../unittests/test_elementwise_add_op.py | 2 +- .../unittests/test_elementwise_mul_op.py | 2 +- 7 files changed, 266 insertions(+), 496 deletions(-) diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc index 838df2e1625..f9347d28104 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,100 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h" -namespace paddle { -namespace framework { -class ExecutionContext; -} // namespace framework -namespace platform { -class CPUDeviceContext; -} // namespace platform -} // namespace paddle - -namespace paddle { -namespace operators { -template -class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - ElemwiseGradKernel::Compute(ctx); - using Tensor = framework::Tensor; - - auto& dev_ctx = - ctx.template device_context(); - const auto& onednn_engine = dev_ctx.GetEngine(); - - auto* dout = ctx.Input(framework::GradVarName("Out")); - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dy = ctx.Output(framework::GradVarName("Y")); - - auto tz = phi::vectorize(dout->dims()); - memory::data_type dout_type = framework::ToMKLDNNDataType( - framework::TransToProtoVarType(dout->dtype())); - platform::ReorderMKLDNNHandler handler( - tz, framework::TransToProtoVarType(dout->dtype()), dout_type, - onednn_engine); - - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - auto reorder_src_memory_p = handler.AcquireSrcMemory( - dout->format(), platform::to_void_cast(dout->data())); - - if (dx) { - auto reorder_dst_memory_p = - handler.AcquireDstMemory(dx, dout->format(), ctx.GetPlace()); - auto reorder_p = - handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p); - platform::RecordEvent record_reorder( - "int_reorder", platform::TracerEventType::UserDefined, 2, - platform::EventRole::kUniqueOp); - reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); - astream.wait(); - - dx->set_layout(DataLayout::kMKLDNN); - dx->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p)); - } - - if (dy) { - // Direct copy - if (dout->dims() == dy->dims()) { - auto reorder_dst_memory_p = - handler.AcquireDstMemory(dy, dout->format(), ctx.GetPlace()); - auto reorder_p = - handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p); - platform::RecordEvent record_reorder( - "int_reorder", platform::TracerEventType::UserDefined, 2, - platform::EventRole::kUniqueOp); - reorder_p->execute(astream, *reorder_src_memory_p, - *reorder_dst_memory_p); - astream.wait(); - - dy->set_layout(DataLayout::kMKLDNN); - dy->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p)); - } else { - // Broadcasting - platform::ReductionMKLDNNHandler handler_sum( - dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine, - ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy)); - auto dy_memory_p = handler_sum.AcquireDstMemory(dy); - auto reduction_p = handler_sum.AcquireForwardPrimitive(); - reduction_p->execute(astream, {{DNNL_ARG_SRC, *reorder_src_memory_p}, - {DNNL_ARG_DST, *dy_memory_p}}); - astream.wait(); - - dy->set_layout(DataLayout::kMKLDNN); - dy->set_format( - platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape( - phi::vectorize(dy->dims())))); - } - } - } -}; - -} // namespace operators -} // namespace paddle - namespace ops = paddle::operators; REGISTER_OP_KERNEL( @@ -116,6 +24,8 @@ REGISTER_OP_KERNEL( ops::EltwiseMKLDNNKernel, ops::EltwiseMKLDNNKernel) -REGISTER_OP_KERNEL(elementwise_add_grad, MKLDNN, ::paddle::platform::CPUPlace, - ops::EltwiseAddMKLDNNGradKernel, - ops::EltwiseAddMKLDNNGradKernel) +REGISTER_OP_KERNEL( + elementwise_add_grad, MKLDNN, ::paddle::platform::CPUPlace, + ops::EltwiseMKLDNNGradKernel, + ops::EltwiseMKLDNNGradKernel) diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_div_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_div_mkldnn_op.cc index 367d602f590..c68aa8d3d1b 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_div_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_div_mkldnn_op.cc @@ -1,146 +1,28 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h" - -namespace paddle { -namespace framework { -class ExecutionContext; -} // namespace framework -namespace platform { -class CPUDeviceContext; -} // namespace platform -} // namespace paddle - -namespace paddle { -namespace operators { -template -class EltwiseDivMKLDNNGradKernel : public ElemwiseGradKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - ElemwiseGradKernel::Compute(ctx); - - auto& dev_ctx = - ctx.template device_context(); - const auto& mkldnn_engine = dev_ctx.GetEngine(); - - auto* y = ctx.Input("Y"); - auto* out = ctx.Input("Out"); - auto* dout = ctx.Input(framework::GradVarName("Out")); - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dy = ctx.Output(framework::GradVarName("Y")); - int axis = ctx.Attr("axis"); - - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - - if (dx) { - // dx = dout / y - - platform::BinaryMKLDNNHandler handler( - dnnl::algorithm::binary_div, axis, mkldnn_engine, ctx.GetPlace(), - dout, y, dx, 1.0f, 1.0f, 1.0f); - - const auto src_dout_memory = handler.AcquireSrcMemory(dout); - const auto src_y_memory = handler.AcquireSecondSrcMemory(y); - const auto dst_dx_memory = handler.AcquireDstMemory(dx); - - const auto binary_prim = handler.AcquireForwardPrimitive(); - - const std::unordered_map args = { - {DNNL_ARG_SRC_0, *src_dout_memory}, - {DNNL_ARG_SRC_1, *src_y_memory}, - {DNNL_ARG_DST, *dst_dx_memory}}; - - binary_prim->execute(astream, args); - astream.wait(); - - dx->set_layout(framework::DataLayout::kMKLDNN); - dx->set_format(platform::GetMKLDNNFormat(*dst_dx_memory)); - } - - if (dy) { - // dy = -dout * out / y - - platform::BinaryMKLDNNHandler y_handler( - dnnl::algorithm::binary_div, axis, mkldnn_engine, ctx.GetPlace(), y, - y, nullptr, 1.0f, 1.0f, 1.0f); - - const auto y_memory = y_handler.AcquireSrcMemory(y); - - dnnl::post_ops po; - po.append_binary(dnnl::algorithm::binary_div, y_memory->get_desc()); - - platform::BinaryMKLDNNHandler handler( - dnnl::algorithm::binary_mul, axis, mkldnn_engine, ctx.GetPlace(), - dout, out, nullptr, -1.0f, 1.0f, 1.0f, po); - - const auto src_dout_memory = handler.AcquireSrcMemory(dout); - const auto src_out_memory = handler.AcquireSecondSrcMemory(out); - - // If broadcasting is in use then let's write to temporary - // buffer allocated by oneDNN - const auto dst_dy_memory = (dout->dims() == dy->dims()) - ? handler.AcquireDstMemory(dy) - : handler.AcquireDstMemory(); - - const auto binary_prim = handler.AcquireForwardPrimitive(); - - const std::unordered_map args = { - {DNNL_ARG_SRC_0, *src_dout_memory}, - {DNNL_ARG_SRC_1, *src_out_memory}, - {DNNL_ARG_DST, *dst_dy_memory}, - {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, *y_memory}}; - - binary_prim->execute(astream, args); - astream.wait(); - - dy->set_layout(framework::DataLayout::kMKLDNN); - - // Reduction is needed for broadcasting scenario - if (dout->dims() != dy->dims()) { - platform::ReductionMKLDNNHandler handler_sum( - dnnl::algorithm::reduction_sum, 0.0f, 0.0f, mkldnn_engine, - ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy)); - auto dy_memory_p = handler_sum.AcquireDstMemory(dy); - auto reduction_p = handler_sum.AcquireForwardPrimitive(); - - // As source we use mem object with results from binary operation - reduction_p->execute(astream, {{DNNL_ARG_SRC, *dst_dy_memory}, - {DNNL_ARG_DST, *dy_memory_p}}); - astream.wait(); - dy->set_format( - platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape( - phi::vectorize(dy->dims())))); - - } else { - dy->set_format(platform::GetMKLDNNFormat(*dst_dy_memory)); - } - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -// TODO(piotrekobi) add int8, uint8 support -REGISTER_OP_KERNEL(elementwise_div, MKLDNN, paddle::platform::CPUPlace, - ops::EltwiseMKLDNNKernel, - ops::EltwiseMKLDNNKernel) - -REGISTER_OP_KERNEL(elementwise_div_grad, MKLDNN, paddle::platform::CPUPlace, - ops::EltwiseDivMKLDNNGradKernel, - ops::EltwiseDivMKLDNNGradKernel) +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_KERNEL(elementwise_div, MKLDNN, paddle::platform::CPUPlace, + ops::EltwiseMKLDNNKernel, + ops::EltwiseMKLDNNKernel) + +REGISTER_OP_KERNEL( + elementwise_div_grad, MKLDNN, paddle::platform::CPUPlace, + ops::EltwiseMKLDNNGradKernel, + ops::EltwiseMKLDNNGradKernel) diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h index ad8fd317013..761b401ca9a 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h @@ -15,20 +15,35 @@ #pragma once #include #include -#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/framework/data_layout_transform.h" +#include "paddle/fluid/operators/elementwise/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/platform/mkldnn_reuse.h" namespace paddle { namespace operators { -using framework::DataLayout; -using framework::Tensor; using dnnl::memory; using dnnl::primitive; using dnnl::stream; +using framework::DataLayout; +using framework::Tensor; + +inline std::vector CalculateBroadcastedDims(const Tensor* x, + const Tensor* y) { + const auto src_tz = phi::vectorize(x->dims()); + const auto dst_tz = phi::vectorize(y->dims()); + + size_t j = 0; + std::vector dst_tz_ex(src_tz.size(), 1); + for (size_t i = 0; i < src_tz.size(); ++i) { + dst_tz_ex[i] = (src_tz[i] != dst_tz[j]) ? 1 : dst_tz[j++]; + if (j == dst_tz.size()) break; + } + + return dst_tz_ex; +} template class EltwiseMKLDNNKernel : public framework::OpKernel { @@ -103,7 +118,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel { // operation. const bool reuse_x_memopry = x->numel() == z->numel() && x->IsSharedBufferWith(*z); - std::shared_ptr dst_memory = nullptr; + std::shared_ptr dst_memory; if (reuse_x_memopry) { dst_memory = src_x_memory; // NOTE(chenfeiyu): when the output reuses memory from other tensor rather @@ -135,19 +150,193 @@ class EltwiseMKLDNNKernel : public framework::OpKernel { } }; -inline std::vector CalculateBroadcastedDims(const Tensor* x, - const Tensor* y) { - const auto src_tz = phi::vectorize(x->dims()); - const auto dst_tz = phi::vectorize(y->dims()); +template +class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); + using Tensor = framework::Tensor; - size_t j = 0; - std::vector dst_tz_ex(src_tz.size(), 1); - for (size_t i = 0; i < src_tz.size(); ++i) { - dst_tz_ex[i] = (src_tz[i] != dst_tz[j]) ? 1 : dst_tz[j++]; - if (j == dst_tz.size()) break; - } + auto& dev_ctx = + ctx.template device_context(); + const auto& onednn_engine = dev_ctx.GetEngine(); - return dst_tz_ex; -} + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* out = ctx.Input("Out"); + + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + auto* dout = ctx.Input(framework::GradVarName("Out")); + + int axis = ctx.Attr("axis"); + + auto tz = phi::vectorize(dout->dims()); + auto proto_type_dout = framework::TransToProtoVarType(dout->dtype()); + + platform::ReorderMKLDNNHandler reorder_handler( + tz, proto_type_dout, framework::ToMKLDNNDataType(proto_type_dout), + onednn_engine); + + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + dout->format(), platform::to_void_cast(dout->data())); + + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + + if (dx) { + std::shared_ptr dst_memory; + + // elementwise_add & elementwise_sub + if (BINARY_OP == dnnl::algorithm::binary_add || + BINARY_OP == dnnl::algorithm::binary_sub) { + dst_memory = reorder_handler.AcquireDstMemory(dx, dout->format(), + ctx.GetPlace()); + auto reorder_p = + reorder_handler.AcquireReorder(dst_memory, reorder_src_memory_p); + platform::RecordEvent record_reorder( + "int_reorder", platform::TracerEventType::UserDefined, 2, + platform::EventRole::kUniqueOp); + + reorder_p->execute(astream, *reorder_src_memory_p, *dst_memory); + } + + // elementwise_mul & elementwise_div + else { + platform::BinaryMKLDNNHandler binary_handler( + BINARY_OP, axis, onednn_engine, ctx.GetPlace(), dout, y, dx, 1.0f, + 1.0f, 1.0f); + + const auto src_dout_memory = binary_handler.AcquireSrcMemory(dout); + const auto src_y_memory = binary_handler.AcquireSecondSrcMemory(y); + dst_memory = binary_handler.AcquireDstMemory(dx); + + const auto binary_prim = binary_handler.AcquireForwardPrimitive(); + + const std::unordered_map args = { + {DNNL_ARG_SRC_0, *src_dout_memory}, + {DNNL_ARG_SRC_1, *src_y_memory}, + {DNNL_ARG_DST, *dst_memory}}; + + binary_prim->execute(astream, args); + } + astream.wait(); + + dx->set_layout(framework::DataLayout::kMKLDNN); + dx->set_format(platform::GetMKLDNNFormat(*dst_memory)); + } + + if (dy) { + dnnl::primitive_attr broadcast_reduction_attr; + std::shared_ptr broadcast_src_memory; + std::shared_ptr dst_memory; + + // elementwise_add & elementwise_sub + if (BINARY_OP == dnnl::algorithm::binary_add || + BINARY_OP == dnnl::algorithm::binary_sub) { + if (dout->dims() == dy->dims()) { + auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( + dy, dout->format(), ctx.GetPlace()); + + dnnl::primitive_attr reorder_attr; + std::vector scales(1); + scales[0] = (BINARY_OP == dnnl::algorithm::binary_add) ? 1 : -1; + reorder_attr.set_output_scales(0, scales); + auto reorder_p = std::make_shared( + *(reorder_src_memory_p), *(reorder_dst_memory_p), reorder_attr); + platform::RecordEvent record_reorder( + "int_reorder", platform::TracerEventType::UserDefined, 2, + platform::EventRole::kUniqueOp); + reorder_p->execute(astream, *reorder_src_memory_p, + *reorder_dst_memory_p); + + dst_memory = reorder_dst_memory_p; + } else { + broadcast_src_memory = reorder_src_memory_p; + } + } + + // elementwise_mul & elementwise_div + else { + std::unordered_map args; + std::shared_ptr binary_prim; + std::shared_ptr post_op_memory; + std::shared_ptr src_0_memory; + std::shared_ptr src_1_memory; + + platform::BinaryMKLDNNHandler binary_handler( + dnnl::algorithm::binary_mul, axis, onednn_engine, ctx.GetPlace(), + dout, x, nullptr, 1.0f, 1.0f, 1.0f); + + src_1_memory = binary_handler.AcquireSecondSrcMemory(x); + + if (BINARY_OP == dnnl::algorithm::binary_div) { + platform::BinaryMKLDNNHandler post_op_binary_handler( + dnnl::algorithm::binary_div, axis, onednn_engine, ctx.GetPlace(), + y, y, nullptr, 1.0f, 1.0f, 1.0f); + + post_op_memory = post_op_binary_handler.AcquireSrcMemory(y); + + dnnl::post_ops po; + po.append_binary(dnnl::algorithm::binary_div, + post_op_memory->get_desc()); + + binary_handler = platform::BinaryMKLDNNHandler( + dnnl::algorithm::binary_mul, axis, onednn_engine, ctx.GetPlace(), + dout, out, nullptr, -1.0f, 1.0f, 1.0f, po); + + src_1_memory = binary_handler.AcquireSecondSrcMemory(out); + } + + src_0_memory = binary_handler.AcquireSrcMemory(dout); + + const auto dst_dy_memory = (dout->dims() == dy->dims()) + ? binary_handler.AcquireDstMemory(dy) + : binary_handler.AcquireDstMemory(); + + binary_prim = binary_handler.AcquireForwardPrimitive(); + args = {{DNNL_ARG_SRC_0, *src_0_memory}, + {DNNL_ARG_SRC_1, *src_1_memory}, + {DNNL_ARG_DST, *dst_dy_memory}}; + + if (BINARY_OP == dnnl::algorithm::binary_div) + args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, + *post_op_memory}); + + binary_prim->execute(astream, args); + broadcast_src_memory = dst_dy_memory; + dst_memory = dst_dy_memory; + } + astream.wait(); + dy->set_layout(DataLayout::kMKLDNN); + + if (dout->dims() != dy->dims()) { + // Broadcasting + if (BINARY_OP == dnnl::algorithm::binary_sub) { + dnnl::post_ops po; + po.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, -1.0f, 0); + broadcast_reduction_attr.set_post_ops(po); + } + + platform::ReductionMKLDNNHandler reduction_handler( + dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine, + ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy), + broadcast_reduction_attr); + dst_memory = reduction_handler.AcquireDstMemory(dy); + + auto reduction_p = reduction_handler.AcquireForwardPrimitive(); + + reduction_p->execute(astream, { + {DNNL_ARG_SRC, *broadcast_src_memory}, + {DNNL_ARG_DST, *dst_memory}, + }); + astream.wait(); + dy->set_format(platform::GetMKLDNNFormat(dst_memory->get_desc().reshape( + phi::vectorize(dy->dims())))); + } else { + dy->set_format(platform::GetMKLDNNFormat(*dst_memory)); + } + } + } +}; } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc index c03794012ff..0ef5c5e628c 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc @@ -1,127 +1,19 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h" -namespace paddle { -namespace framework { -class ExecutionContext; -} // namespace framework -namespace platform { -class CPUDeviceContext; -} // namespace platform -} // namespace paddle - -namespace paddle { -namespace operators { -template -class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - ElemwiseGradKernel::Compute(ctx); - - auto& dev_ctx = - ctx.template device_context(); - const auto& mkldnn_engine = dev_ctx.GetEngine(); - - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* dout = ctx.Input(framework::GradVarName("Out")); - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dy = ctx.Output(framework::GradVarName("Y")); - int axis = ctx.Attr("axis"); - - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - - if (dx) { - // dx = dout*y - platform::BinaryMKLDNNHandler handler( - dnnl::algorithm::binary_mul, axis, mkldnn_engine, ctx.GetPlace(), - dout, y, dx, 1.0f, 1.0f, 1.0f); - - const auto src_dout_memory = handler.AcquireSrcMemory(dout); - const auto src_y_memory = handler.AcquireSecondSrcMemory(y); - const auto dst_dx_memory = handler.AcquireDstMemory(dx); - - const auto binary_prim = handler.AcquireForwardPrimitive(); - - const std::unordered_map args = { - {DNNL_ARG_SRC_0, *src_dout_memory}, - {DNNL_ARG_SRC_1, *src_y_memory}, - {DNNL_ARG_DST, *dst_dx_memory}}; - - binary_prim->execute(astream, args); - astream.wait(); - - dx->set_layout(framework::DataLayout::kMKLDNN); - dx->set_format(platform::GetMKLDNNFormat(*dst_dx_memory)); - } - - if (dy) { - // dy = dout*x - // Handler is having nullptr passed instead of output tensor as - // we want Dst buffer to be allocated by oneDNN not to use Tensor - platform::BinaryMKLDNNHandler handler( - dnnl::algorithm::binary_mul, axis, mkldnn_engine, ctx.GetPlace(), - dout, x, nullptr, 1.0f, 1.0f, 1.0f); - - const auto src_dout_memory = handler.AcquireSrcMemory(dout); - const auto src_x_memory = handler.AcquireSecondSrcMemory(x); - - // If broadcasting is in use then let's write to temporary - // buffer allocated by oneDNN - const auto dst_dy_memory = (dout->dims() == dy->dims()) - ? handler.AcquireDstMemory(dy) - : handler.AcquireDstMemory(); - - const auto binary_prim = handler.AcquireForwardPrimitive(); - - const std::unordered_map args = { - {DNNL_ARG_SRC_0, *src_dout_memory}, - {DNNL_ARG_SRC_1, *src_x_memory}, - {DNNL_ARG_DST, *dst_dy_memory}}; - - binary_prim->execute(astream, args); - astream.wait(); - - dy->set_layout(framework::DataLayout::kMKLDNN); - - // Reduction is needed for broadcasting scenario - if (dout->dims() != dy->dims()) { - platform::ReductionMKLDNNHandler handler_sum( - dnnl::algorithm::reduction_sum, 0.0f, 0.0f, mkldnn_engine, - ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy)); - auto dy_memory_p = handler_sum.AcquireDstMemory(dy); - auto reduction_p = handler_sum.AcquireForwardPrimitive(); - // As source we use mem object with results from binary operation - reduction_p->execute(astream, {{DNNL_ARG_SRC, *dst_dy_memory}, - {DNNL_ARG_DST, *dy_memory_p}}); - astream.wait(); - dy->set_format( - platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape( - phi::vectorize(dy->dims())))); - - } else { - dy->set_format(platform::GetMKLDNNFormat(*dst_dy_memory)); - } - } - } -}; - -} // namespace operators -} // namespace paddle - namespace ops = paddle::operators; REGISTER_OP_KERNEL( @@ -132,6 +24,8 @@ REGISTER_OP_KERNEL( ops::EltwiseMKLDNNKernel, ops::EltwiseMKLDNNKernel) -REGISTER_OP_KERNEL(elementwise_mul_grad, MKLDNN, ::paddle::platform::CPUPlace, - ops::EltwiseMulMKLDNNGradKernel, - ops::EltwiseMulMKLDNNGradKernel) +REGISTER_OP_KERNEL( + elementwise_mul_grad, MKLDNN, ::paddle::platform::CPUPlace, + ops::EltwiseMKLDNNGradKernel, + ops::EltwiseMKLDNNGradKernel) diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_sub_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_sub_mkldnn_op.cc index 3c799008a2a..510373831eb 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_sub_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_sub_mkldnn_op.cc @@ -1,5 +1,4 @@ - -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,113 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h" -namespace paddle { -namespace framework { -class ExecutionContext; -} // namespace framework -namespace platform { -class CPUDeviceContext; -} // namespace platform -} // namespace paddle - -namespace paddle { -namespace operators { -template -class EltwiseSubMKLDNNGradKernel : public ElemwiseGradKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - ElemwiseGradKernel::Compute(ctx); - using Tensor = framework::Tensor; - - auto& dev_ctx = - ctx.template device_context(); - const auto& onednn_engine = dev_ctx.GetEngine(); - - auto* dout = ctx.Input(framework::GradVarName("Out")); - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dy = ctx.Output(framework::GradVarName("Y")); - - auto tz = phi::vectorize(dout->dims()); - memory::data_type dout_type = framework::ToMKLDNNDataType( - framework::TransToProtoVarType(dout->dtype())); - platform::ReorderMKLDNNHandler handler( - tz, framework::TransToProtoVarType(dout->dtype()), dout_type, - onednn_engine); - - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - auto reorder_src_memory_p = handler.AcquireSrcMemory( - dout->format(), platform::to_void_cast(dout->data())); - - if (dx) { - auto reorder_dst_memory_p = - handler.AcquireDstMemory(dx, dout->format(), ctx.GetPlace()); - auto reorder_p = - handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p); - platform::RecordEvent record_reorder( - "int_reorder", platform::TracerEventType::UserDefined, 2, - platform::EventRole::kUniqueOp); - - reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); - astream.wait(); - - dx->set_layout(DataLayout::kMKLDNN); - dx->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p)); - } - - if (dy) { - // Direct copy - if (dout->dims() == dy->dims()) { - auto reorder_dst_memory_p = - handler.AcquireDstMemory(dy, dout->format(), ctx.GetPlace()); - - dnnl::primitive_attr reorder_attr; - std::vector scales = {-1}; - reorder_attr.set_output_scales(0, scales); - auto reorder_p = std::make_shared( - *(reorder_src_memory_p), *(reorder_dst_memory_p), reorder_attr); - platform::RecordEvent record_reorder( - "int_reorder", platform::TracerEventType::UserDefined, 2, - platform::EventRole::kUniqueOp); - reorder_p->execute(astream, *reorder_src_memory_p, - *reorder_dst_memory_p); - astream.wait(); - - dy->set_layout(DataLayout::kMKLDNN); - dy->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p)); - } else { - // Broadcasting - - dnnl::post_ops po; - po.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, -1.0f, 0); - dnnl::primitive_attr attr; - attr.set_post_ops(po); - - platform::ReductionMKLDNNHandler handler_sum( - dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine, - ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy), attr); - - auto dy_memory_p = handler_sum.AcquireDstMemory(dy); - auto reduction_p = handler_sum.AcquireForwardPrimitive(); - - reduction_p->execute(astream, { - {DNNL_ARG_SRC, *reorder_src_memory_p}, - {DNNL_ARG_DST, *dy_memory_p}, - }); - astream.wait(); - - dy->set_layout(DataLayout::kMKLDNN); - dy->set_format( - platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape( - phi::vectorize(dy->dims())))); - } - } - } -}; - -} // namespace operators -} // namespace paddle namespace ops = paddle::operators; @@ -131,6 +24,8 @@ REGISTER_OP_KERNEL( ops::EltwiseMKLDNNKernel, ops::EltwiseMKLDNNKernel) -REGISTER_OP_KERNEL(elementwise_sub_grad, MKLDNN, ::paddle::platform::CPUPlace, - ops::EltwiseSubMKLDNNGradKernel, - ops::EltwiseSubMKLDNNGradKernel) +REGISTER_OP_KERNEL( + elementwise_sub_grad, MKLDNN, ::paddle::platform::CPUPlace, + ops::EltwiseMKLDNNGradKernel, + ops::EltwiseMKLDNNGradKernel) diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py index d1d391a3949..318e826058f 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py @@ -17,7 +17,7 @@ import unittest import numpy as np import paddle import paddle.fluid.core as core -from op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16 +from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16 import paddle.fluid as fluid from paddle.fluid import compiler, Program, program_guard diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py index 00967cb503f..b35b2840ed3 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py @@ -23,7 +23,7 @@ import paddle.fluid.core as core from paddle.fluid import Program, compiler, program_guard from paddle.fluid.op import Operator -from op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16 +from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16 class ElementwiseMulOp(OpTest): -- GitLab