From 561b727834529fe5613a31edc2170f66bd4b8add Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C5=82awomir=20Siwek?= Date: Thu, 24 Nov 2022 03:29:20 +0100 Subject: [PATCH] [PHI] Migrate batch_norm_grad kernel (#48288) --- paddle/fluid/operators/batch_norm_op.cc | 2 +- paddle/fluid/operators/inplace_abn_op.cc | 2 +- .../operators/mkldnn/batch_norm_mkldnn_op.cc | 199 ------------------ paddle/fluid/operators/unity_build_rule.cmake | 1 - paddle/phi/backends/onednn/onednn_reuse.h | 54 +++-- .../phi/kernels/cpu/batch_norm_grad_kernel.cc | 1 - .../kernels/onednn/batch_norm_grad_kernel.cc | 134 ++++++++++++ 7 files changed, 178 insertions(+), 215 deletions(-) delete mode 100644 paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc create mode 100644 paddle/phi/kernels/onednn/batch_norm_grad_kernel.cc diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 7452c64f6fc..a20b2ad21d3 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -350,7 +350,7 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const { true, platform::errors::InvalidArgument( "Using global stats during training is not supported " - "in gradient op kernel of batch_norm_mkldnn_op now.")); + "in oneDNN version of batch_norm_gradient kernel now.")); } OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BatchNormGrad"); diff --git a/paddle/fluid/operators/inplace_abn_op.cc b/paddle/fluid/operators/inplace_abn_op.cc index f87d7effcae..61379a3d893 100644 --- a/paddle/fluid/operators/inplace_abn_op.cc +++ b/paddle/fluid/operators/inplace_abn_op.cc @@ -113,7 +113,7 @@ class InplaceABNGradOp : public paddle::operators::BatchNormGradOp { true, platform::errors::InvalidArgument( "Using global stats during training is not supported " - "in gradient op kernel of batch_norm_mkldnn_op now.")); + "in oneDNN version of batch_norm_gradient kernel now.")); } OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "InplaceABNGrad"); diff --git a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc deleted file mode 100644 index aeba1e0ae63..00000000000 --- a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc +++ /dev/null @@ -1,199 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/batch_norm_op.h" -#include "paddle/fluid/platform/mkldnn_reuse.h" - -namespace phi { -class DenseTensor; -} // namespace phi - -namespace paddle { -namespace operators { - -using dnnl::memory; -using dnnl::primitive; -using dnnl::stream; -using paddle::platform::MKLDNNDeviceContext; - -template -class BatchNormMKLDNNHandler : public phi::funcs::OneDNNHandlerNoCachingT< - T, - dnnl::batch_normalization_forward, - dnnl::batch_normalization_backward> { - public: - BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx, - const dnnl::engine mkldnn_engine, - const Tensor *in_x, - const Tensor *scale, - const Tensor *out_grad) - : phi::funcs::OneDNNHandlerNoCachingT( - mkldnn_engine, ctx.GetPlace()) { - auto scale_tz = phi::vectorize(scale->dims()); - PADDLE_ENFORCE_EQ( - scale_tz.size(), - 1, - platform::errors::InvalidArgument( - "Dims of scale tensor must be 1, but received scale's size is %d", - scale_tz.size())); - - const float epsilon = ctx.Attr("epsilon"); - - this->AcquireForwardPrimitiveDescriptor( - dnnl::prop_kind::forward_training, - in_x->mem_desc(), - epsilon, - dnnl::normalization_flags::use_scale_shift); - this->AcquireBackwardPrimitiveDescriptor( - dnnl::prop_kind::backward, - out_grad->mem_desc(), - in_x->mem_desc(), - epsilon, - dnnl::normalization_flags::use_scale_shift); - } - - std::shared_ptr AcquireScaleShiftMemory(const Tensor *scale, - const Tensor *shift) { - auto scale_tz = phi::vectorize(scale->dims()); - const unsigned int C = scale_tz[0]; - PADDLE_ENFORCE_EQ( - scale_tz.size(), - 1, - platform::errors::InvalidArgument( - "Dims of scale tensor must be 1, but received scale's size is %d", - scale_tz.size())); - - auto scaleshift_memory = - this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc()); - - // MKLDNN requires a single piece of memory for scale and shift/bias data - auto mem_data_handle = - reinterpret_cast(scaleshift_memory->get_data_handle()); - std::copy(scale->data(), scale->data() + C, mem_data_handle); - std::copy(shift->data(), shift->data() + C, mem_data_handle + C); - return scaleshift_memory; - } - - std::shared_ptr AcquireDiffScaleShiftMemory( - T *diff_scaleshift_data) { - return this->AcquireMemoryFromPrimitive(this->bwd_pd_->diff_weights_desc(), - diff_scaleshift_data); - } - - std::shared_ptr AcquireMeanMemory( - const phi::DenseTensor *mean) { - const T *mean_data = mean->data(); - return this->AcquireMemoryFromPrimitive( - this->fwd_pd_->mean_desc(), phi::funcs::to_void_cast(mean_data)); - } - - std::shared_ptr AcquireMeanMemory(phi::DenseTensor *mean) { - T *mean_data = mean->mutable_data(this->place_, - this->fwd_pd_->mean_desc().get_size()); - return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(), - mean_data); - } - - std::shared_ptr AcquireVarianceMemory( - const phi::DenseTensor *variance) { - const T *variance_data = variance->data(); - return this->AcquireMemoryFromPrimitive( - this->fwd_pd_->variance_desc(), - phi::funcs::to_void_cast(variance_data)); - } - - std::shared_ptr AcquireVarianceMemory( - phi::DenseTensor *variance) { - T *variance_data = variance->mutable_data( - this->place_, this->fwd_pd_->variance_desc().get_size()); - return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(), - variance_data); - } -}; - -template -class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel { - public: - void Compute(const paddle::framework::ExecutionContext &ctx) const override { - auto &dev_ctx = ctx.template device_context(); - auto mkldnn_engine = dev_ctx.GetEngine(); - - const auto *x = ctx.Input("X"); - const auto *scale = ctx.Input("Scale"); - const auto *shift = ctx.Input("Bias"); - const auto *batch_mean = ctx.Input("SavedMean"); - const auto *batch_variance = ctx.Input("SavedVariance"); - const auto *diff_y = - ctx.Input(framework::GradVarName("Y")); - auto *diff_x = ctx.Output(framework::GradVarName("X")); - auto *diff_scale = - ctx.Output(framework::GradVarName("Scale")); - auto *diff_shift = - ctx.Output(framework::GradVarName("Bias")); - - BatchNormMKLDNNHandler handler(ctx, mkldnn_engine, x, scale, diff_y); - - // MKLDNN requires a single piece of memory for scale and shift/bias data - const unsigned int C = phi::vectorize(scale->dims())[0]; - const size_t scaleshift_size = 2 * C; - std::vector diff_scaleshift_data; - diff_scaleshift_data.reserve(scaleshift_size); - - auto src_memory = handler.AcquireSrcMemory(x); - auto mean_memory = handler.AcquireMeanMemory(batch_mean); - auto variance_memory = handler.AcquireVarianceMemory(batch_variance); - auto diff_dst_memory = handler.AcquireDiffDstMemory(diff_y); - auto scaleshift_memory = handler.AcquireScaleShiftMemory(scale, shift); - auto diff_src_memory = handler.AcquireDiffSrcMemory(diff_x); - auto diff_scaleshift_memory = - handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data()); - // finally create batch_norm backward primitive - auto batch_norm_bwd_p = handler.AcquireBackwardPrimitive(); - - auto &astream = platform::MKLDNNDeviceContext::tls().get_stream(); - batch_norm_bwd_p->execute( - astream, - {{DNNL_ARG_SRC, *src_memory}, - {DNNL_ARG_MEAN, *mean_memory}, - {DNNL_ARG_VARIANCE, *variance_memory}, - {DNNL_ARG_DIFF_DST, *diff_dst_memory}, - {DNNL_ARG_SCALE_SHIFT, *scaleshift_memory}, - {DNNL_ARG_DIFF_SRC, *diff_src_memory}, - {DNNL_ARG_DIFF_SCALE_SHIFT, *diff_scaleshift_memory}}); - astream.wait(); - - T *diff_scale_data = diff_scale->mutable_data(ctx.GetPlace()); - T *diff_shift_data = diff_shift->mutable_data(ctx.GetPlace()); - - // copy back diff scale/shift to output tensors (diff scale/shift) - diff_scaleshift_data.resize(scaleshift_size); - auto it = std::begin(diff_scaleshift_data); - std::copy(it, std::next(it, C), diff_scale_data); - std::copy( - std::next(it, C), std::end(diff_scaleshift_data), diff_shift_data); - - // set memory descriptor of out tensor - diff_x->set_mem_desc(diff_src_memory->get_desc()); - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_KERNEL(batch_norm_grad, - MKLDNN, - ::paddle::platform::CPUPlace, - ops::BatchNormMKLDNNGradOpKernel); diff --git a/paddle/fluid/operators/unity_build_rule.cmake b/paddle/fluid/operators/unity_build_rule.cmake index 97fe4d620cb..891cb40ab28 100644 --- a/paddle/fluid/operators/unity_build_rule.cmake +++ b/paddle/fluid/operators/unity_build_rule.cmake @@ -27,7 +27,6 @@ register_unity_group( bilateral_slice_op.cc) register_unity_group( cc - mkldnn/batch_norm_mkldnn_op.cc bilinear_tensor_product_op.cc bmm_op.cc bpr_loss_op.cc diff --git a/paddle/phi/backends/onednn/onednn_reuse.h b/paddle/phi/backends/onednn/onednn_reuse.h index bc88fef443d..f4577dab5aa 100644 --- a/paddle/phi/backends/onednn/onednn_reuse.h +++ b/paddle/phi/backends/onednn/onednn_reuse.h @@ -47,7 +47,7 @@ bool constexpr is_int8() { template constexpr bool is_bfloat16() { - return std::is_same::value; + return std::is_same::value; } static void AppendActivation(const OneDNNContext& dev_ctx, @@ -102,7 +102,7 @@ static void AppendActivation(const OneDNNContext& dev_ctx, PADDLE_ENFORCE_NE( activation_type, activation_map.end(), - phi::errors::InvalidArgument( + errors::InvalidArgument( "Activation '%s' not found in oneDNN algorithms mapper", fuse_activation)); @@ -810,7 +810,7 @@ class SoftmaxOneDNNHandler PADDLE_ENFORCE_EQ( x->dims(), out->dims(), - phi::errors::InvalidArgument( + errors::InvalidArgument( "The shape of input and output tensor must be identical.")); const int canonical_axis = funcs::CanonicalAxis(axis, x->dims().size()); @@ -1145,7 +1145,7 @@ class PReluOneDNNHandler const bool is_test) : OneDNNHandlerNoCachingT( engine, cpu_place) { - auto weights_dims = phi::vectorize(weights.dims()); + auto weights_dims = vectorize(weights.dims()); // weights must have same size as X only for "element" case if (weights.dims().size() != x.dims().size()) { auto new_weights_dims = std::vector(x.dims().size(), 1); @@ -1304,21 +1304,52 @@ class BatchNormOneDNNHandler flags); } + BatchNormOneDNNHandler(const dnnl::engine engine, + Place cpu_place, + const float epsilon, + const DenseTensor* in_x, + const DenseTensor* scale, + const DenseTensor* out_grad) + : OneDNNHandlerNoCachingT(engine, + cpu_place) { + auto scale_tz = vectorize(scale->dims()); + PADDLE_ENFORCE_EQ( + scale_tz.size(), + 1, + errors::InvalidArgument( + "Dims of scale tensor must be 1, but received scale's size is %d", + scale_tz.size())); + + this->AcquireForwardPrimitiveDescriptor( + dnnl::prop_kind::forward_training, + in_x->mem_desc(), + epsilon, + dnnl::normalization_flags::use_scale_shift); + this->AcquireBackwardPrimitiveDescriptor( + dnnl::prop_kind::backward, + out_grad->mem_desc(), + in_x->mem_desc(), + epsilon, + dnnl::normalization_flags::use_scale_shift); + } + std::shared_ptr AcquireScaleShiftMemory( const DenseTensor* scale, const DenseTensor* shift) { - auto scale_tz = phi::vectorize(scale->dims()); + auto scale_tz = vectorize(scale->dims()); const unsigned int C = scale_tz[0]; PADDLE_ENFORCE_EQ( scale_tz.size(), 1, - phi::errors::InvalidArgument( + errors::InvalidArgument( "Dims of scale tensor must be 1, but received scale's size is %d", scale_tz.size())); auto scaleshift_memory = this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc()); - // MKLDNN requires a single piece of memory for scale and shift/bias data + // oneDNN requires a single piece of memory for scale and shift/bias data auto mem_data_handle = reinterpret_cast(scaleshift_memory->get_data_handle()); std::copy(scale->data(), scale->data() + C, mem_data_handle); @@ -1692,7 +1723,7 @@ static std::vector GetInputStrides(const OneDNNContext& dev_ctx, auto& MatrixDimsFromVector = input_name == "X" ? RowMatrixDimsFromVector : ColumnMatrixDimsFromVector; - phi::funcs::MatDescriptor mat_dim = phi::funcs::CreateMatrixDescriptor( + MatDescriptor mat_dim = CreateMatrixDescriptor( MatrixDimsFromVector(new_dims), 0, transpose_input); std::vector strides; @@ -1728,8 +1759,7 @@ static bool IsOutputFused(const OneDNNContext& dev_ctx) { } template -class MatmulOneDNNHandler - : public phi::funcs::OneDNNHandlerNoCachingT { +class MatmulOneDNNHandler : public OneDNNHandlerNoCachingT { public: MatmulOneDNNHandler(const OneDNNContext& dev_ctx, const std::vector& x_org_dims, @@ -1739,8 +1769,8 @@ class MatmulOneDNNHandler const std::vector& x_strides_override, const std::vector& y_strides_override, bool is_output_fused) - : phi::funcs::OneDNNHandlerNoCachingT( - dev_ctx.GetEngine(), dev_ctx.GetPlace()) { + : OneDNNHandlerNoCachingT(dev_ctx.GetEngine(), + dev_ctx.GetPlace()) { // M X K * K X N std::vector x_dims(x_org_dims); std::vector y_dims(y_org_dims); diff --git a/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc index efd55dee88c..8d0ae7e08d7 100644 --- a/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc @@ -36,7 +36,6 @@ using ConstEigenVectorArrayMap = template void BatchNormGradRawKernel(const Context& ctx, - const DenseTensor& x, const DenseTensor& scale, const DenseTensor& bias, diff --git a/paddle/phi/kernels/onednn/batch_norm_grad_kernel.cc b/paddle/phi/kernels/onednn/batch_norm_grad_kernel.cc new file mode 100644 index 00000000000..503dd6416b4 --- /dev/null +++ b/paddle/phi/kernels/onednn/batch_norm_grad_kernel.cc @@ -0,0 +1,134 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/batch_norm_grad_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void BatchNormGradRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& scale, + const DenseTensor& bias, + const paddle::optional& mean, + const paddle::optional& variance, + const DenseTensor& saved_mean, + const DenseTensor& saved_variance, + const paddle::optional& reserve_space, + const DenseTensor& y_grad, + float momentum, + float epsilon, + const std::string& data_layout, + bool is_test, + bool use_global_stats, + bool trainable_statistics, + bool is_inplace, + DenseTensor* x_grad, + DenseTensor* scale_grad, + DenseTensor* bias_grad) { + funcs::BatchNormOneDNNHandler handler( + dev_ctx.GetEngine(), dev_ctx.GetPlace(), epsilon, &x, &scale, &y_grad); + + const unsigned int C = vectorize(scale.dims())[0]; + const size_t scaleshift_size = 2 * C; + std::vector diff_scaleshift_data; + diff_scaleshift_data.reserve(scaleshift_size); + + auto src_memory = handler.AcquireSrcMemory(&x); + auto mean_memory = handler.AcquireMeanMemory(&saved_mean); + auto variance_memory = handler.AcquireVarianceMemory(&saved_variance); + auto diff_dst_memory = handler.AcquireDiffDstMemory(&y_grad); + auto scaleshift_memory = handler.AcquireScaleShiftMemory(&scale, &bias); + auto diff_src_memory = handler.AcquireDiffSrcMemory(x_grad); + auto diff_scaleshift_memory = + handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data()); + + auto batch_norm_bwd_p = handler.AcquireBackwardPrimitive(); + + auto& astream = OneDNNContext::tls().get_stream(); + batch_norm_bwd_p->execute( + astream, + {{DNNL_ARG_SRC, *src_memory}, + {DNNL_ARG_MEAN, *mean_memory}, + {DNNL_ARG_VARIANCE, *variance_memory}, + {DNNL_ARG_DIFF_DST, *diff_dst_memory}, + {DNNL_ARG_SCALE_SHIFT, *scaleshift_memory}, + {DNNL_ARG_DIFF_SRC, *diff_src_memory}, + {DNNL_ARG_DIFF_SCALE_SHIFT, *diff_scaleshift_memory}}); + astream.wait(); + + T* diff_scale_data = dev_ctx.template Alloc(scale_grad); + T* diff_shift_data = dev_ctx.template Alloc(bias_grad); + + // copy back diff scale/shift to output tensors (diff scale/shift) + diff_scaleshift_data.resize(scaleshift_size); + auto it = std::begin(diff_scaleshift_data); + std::copy(it, std::next(it, C), diff_scale_data); + std::copy(std::next(it, C), std::end(diff_scaleshift_data), diff_shift_data); + + // set memory descriptor of out tensor + x_grad->set_mem_desc(diff_src_memory->get_desc()); +} + +template +void BatchNormGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& scale, + const DenseTensor& bias, + const paddle::optional& mean, + const paddle::optional& variance, + const DenseTensor& saved_mean, + const DenseTensor& saved_variance, + const paddle::optional& reserve_space, + const DenseTensor& y_grad, + float momentum, + float epsilon, + const std::string& data_layout, + bool is_test, + bool use_global_stats, + bool trainable_statistics, + DenseTensor* x_grad, + DenseTensor* scale_grad, + DenseTensor* bias_grad) { + BatchNormGradRawKernel(dev_ctx, + x, + scale, + bias, + mean, + variance, + saved_mean, + saved_variance, + reserve_space, + y_grad, + momentum, + epsilon, + data_layout, + is_test, + use_global_stats, + trainable_statistics, + /*is_inplace*/ false, + x_grad, + scale_grad, + bias_grad); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + batch_norm_grad, OneDNN, ONEDNN, phi::BatchNormGradKernel, float) {} +PD_REGISTER_KERNEL( + batch_norm_grad_raw, OneDNN, ONEDNN, phi::BatchNormGradRawKernel, float) {} -- GitLab