From 2337e609c479dc630cc79b7814ef2f885c36a937 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C5=82awomir=20Siwek?=
Date: Mon, 7 Nov 2022 03:13:59 +0100
Subject: [PATCH] [PHI] Migrate batch_norm (#47652)

* init changes

* bnorm

* method signature

* change order

* bnorm

* removed unused args
---
 .../mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc |   2 +-
 .../operators/mkldnn/batch_norm_mkldnn_op.cc  | 118 --------------
 paddle/phi/backends/onednn/onednn_reuse.h     |  90 +++++++++++
 .../phi/kernels/onednn/batch_norm_kernel.cc   | 146 ++++++++++++++++++
 4 files changed, 237 insertions(+), 119 deletions(-)
 create mode 100644 paddle/phi/kernels/onednn/batch_norm_kernel.cc

diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc
index e51073385b..bdb2bef362 100644
--- a/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc
@@ -32,7 +32,7 @@ PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(gelu, CPU, ALL_LAYOUT);
 USE_OP_ITSELF(batch_norm);
-USE_OP_DEVICE_KERNEL(batch_norm, MKLDNN);
+PD_DECLARE_KERNEL(batch_norm, OneDNN, ONEDNN);
 USE_OP_ITSELF(conv2d_transpose);
 USE_OP_DEVICE_KERNEL(conv2d_transpose, MKLDNN);
 USE_OP_ITSELF(elementwise_add);
diff --git a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
index d7575f0ebf..4144608de4 100644
--- a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
@@ -35,38 +35,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
                                    dnnl::batch_normalization_forward,
                                    dnnl::batch_normalization_backward> {
  public:
-  BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
-                         const dnnl::engine mkldnn_engine,
-                         const Tensor *x,
-                         const bool global_stats,
-                         const bool test_mode)
-      : platform::MKLDNNHandlerNoCachingT<T,
-                                          dnnl::batch_normalization_forward,
-                                          dnnl::batch_normalization_backward>(
-            mkldnn_engine, ctx.GetPlace()) {
-    const float epsilon = ctx.Attr<float>("epsilon");
-    const bool fuse_with_relu = ctx.HasAttr("fuse_with_relu")
-                                    ? ctx.Attr<bool>("fuse_with_relu")
-                                    : false;
-
-    std::vector<std::string> DataLayout_error_msg = {
-        "kNHWC", "kNCHW", "kAnyLayout", "kMKLDNN"};
-
-    // Flags are added by bitwise OR operation
-    auto flags = dnnl::normalization_flags::use_scale_shift;  // 001
-    if (global_stats)
-      flags |= dnnl::normalization_flags::use_global_stats;  // 010
-    if (fuse_with_relu && test_mode)
-      flags |= dnnl::normalization_flags::fuse_norm_relu;  // 100
-
-    this->AcquireForwardPrimitiveDescriptor(
-        global_stats == true ? dnnl::prop_kind::forward_scoring
-                             : dnnl::prop_kind::forward_training,
-        x->mem_desc(),
-        epsilon,
-        flags);
-  }
-
   BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
                          const dnnl::engine mkldnn_engine,
                          const Tensor *in_x,
@@ -157,88 +125,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
   }
 };
 
-template <typename T>
-class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    auto &dev_ctx =
-        ctx.template device_context<platform::MKLDNNDeviceContext>();
-    const auto &mkldnn_engine = dev_ctx.GetEngine();
-
-    const bool is_test = ctx.Attr<bool>("is_test");
-    const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
-    const bool trainable_stats = ctx.Attr<bool>("trainable_statistics");
-    const bool test_mode = is_test && (!trainable_stats);
-    const bool global_stats = test_mode || use_global_stats;
-
-    const auto *x = ctx.Input<Tensor>("X");
-    const auto *scale = ctx.Input<Tensor>("Scale");
-    const auto *shift = ctx.Input<Tensor>("Bias");
-
-    auto *y = ctx.Output<Tensor>("Y");
-    auto *batch_mean = ctx.Output<Tensor>("SavedMean");
-    auto *batch_variance = ctx.Output<Tensor>("SavedVariance");
-    BatchNormMKLDNNHandler<T> handler(
-        ctx, mkldnn_engine, x, global_stats, test_mode);
-
-    auto src_memory = handler.AcquireSrcMemory(x);
-    auto scaleshift_memory = handler.AcquireScaleShiftMemory(scale, shift);
-    auto dst_memory = handler.AcquireDstMemory(y);
-    auto batch_norm_p = handler.AcquireForwardPrimitive();
-
-    std::shared_ptr<dnnl::memory> mean_memory;
-    std::shared_ptr<dnnl::memory> variance_memory;
-
-    if (global_stats) {
-      // mean and variance are taken from input Tensor
-      const auto *mean = ctx.Input<Tensor>("Mean");
-      const auto *variance = ctx.Input<Tensor>("Variance");
-
-      mean_memory = handler.AcquireMeanMemory(mean);
-      variance_memory = handler.AcquireVarianceMemory(variance);
-    } else {
-      // mean and variance are calculated and saved in output Tensor
-      mean_memory = handler.AcquireMeanMemory(batch_mean);
-      variance_memory = handler.AcquireVarianceMemory(batch_variance);
-    }
-
-    y->set_mem_desc(dst_memory->get_desc());
-
-    auto &astream = platform::MKLDNNDeviceContext::tls().get_stream();
-    batch_norm_p->execute(astream,
-                          {{DNNL_ARG_SRC, *src_memory},
-                           {DNNL_ARG_SCALE_SHIFT, *scaleshift_memory},
-                           {DNNL_ARG_MEAN, *mean_memory},
-                           {DNNL_ARG_VARIANCE, *variance_memory},
-                           {DNNL_ARG_DST, *dst_memory}});
-    astream.wait();
-
-    if (!global_stats) {
-      auto *mean_out = ctx.Output<Tensor>("MeanOut");
-      auto *variance_out = ctx.Output<Tensor>("VarianceOut");
-      const float momentum = ctx.Attr<float>("momentum");
-
-      const unsigned int C = phi::vectorize(scale->dims())[0];
-
-      // mkldnn only compute stats for current batch
-      // so we need compute momentum stats via Eigen lib
-      EigenVectorArrayMap<T> batch_mean_e(
-          batch_mean->mutable_data<T>(ctx.GetPlace()), C);
-      EigenVectorArrayMap<T> batch_variance_e(
-          batch_variance->mutable_data<T>(ctx.GetPlace()), C);
-
-      EigenVectorArrayMap<T> running_mean_e(
-          mean_out->mutable_data<T>(ctx.GetPlace()), C);
-      EigenVectorArrayMap<T> running_variance_e(
-          variance_out->mutable_data<T>(ctx.GetPlace()), C);
-
-      running_mean_e =
-          running_mean_e * momentum + batch_mean_e * (1. - momentum);
-      running_variance_e =
-          running_variance_e * momentum + batch_variance_e * (1. - momentum);
-    }
-  }
-};
-
 template <typename T>
 class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
  public:
@@ -308,10 +194,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_KERNEL(batch_norm,
-                   MKLDNN,
-                   ::paddle::platform::CPUPlace,
-                   ops::BatchNormMKLDNNOpKernel<float>);
 REGISTER_OP_KERNEL(batch_norm_grad,
                    MKLDNN,
                    ::paddle::platform::CPUPlace,
diff --git a/paddle/phi/backends/onednn/onednn_reuse.h b/paddle/phi/backends/onednn/onednn_reuse.h
index 516ab49180..a574f73a65 100644
--- a/paddle/phi/backends/onednn/onednn_reuse.h
+++ b/paddle/phi/backends/onednn/onednn_reuse.h
@@ -1155,6 +1155,96 @@ class ClipOneDNNHandler
   }
 };
 
+template <typename T>
+class BatchNormOneDNNHandler
+    : public OneDNNHandlerNoCachingT<T,
+                                     dnnl::batch_normalization_forward,
+                                     dnnl::batch_normalization_backward> {
+ public:
+  BatchNormOneDNNHandler(const dnnl::engine engine,
+                         Place cpu_place,
+                         const DenseTensor* x,
+                         const float epsilon,
+                         const bool fuse_with_relu,
+                         const bool global_stats,
+                         const bool test_mode)
+      : OneDNNHandlerNoCachingT<T,
+                                dnnl::batch_normalization_forward,
+                                dnnl::batch_normalization_backward>(engine,
+                                                                    cpu_place) {
+    // Flags are added by bitwise OR operation
+    auto flags = dnnl::normalization_flags::use_scale_shift;  // 001
+    if (global_stats)
+      flags |= dnnl::normalization_flags::use_global_stats;  // 010
+    if (fuse_with_relu && test_mode)
+      flags |= dnnl::normalization_flags::fuse_norm_relu;  // 100
+
+    this->AcquireForwardPrimitiveDescriptor(
+        global_stats ? dnnl::prop_kind::forward_scoring
+                     : dnnl::prop_kind::forward_training,
+        x->mem_desc(),
+        epsilon,
+        flags);
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory(
+      const DenseTensor* scale, const DenseTensor* shift) {
+    auto scale_tz = phi::vectorize(scale->dims());
+    const unsigned int C = scale_tz[0];
+    PADDLE_ENFORCE_EQ(
+        scale_tz.size(),
+        1,
+        phi::errors::InvalidArgument(
+            "Dims of scale tensor must be 1, but received scale's size is %d",
+            scale_tz.size()));
+
+    auto scaleshift_memory =
+        this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc());
+
+    // MKLDNN requires a single piece of memory for scale and shift/bias data
+    auto mem_data_handle =
+        reinterpret_cast<T*>(scaleshift_memory->get_data_handle());
+    std::copy(scale->data<T>(), scale->data<T>() + C, mem_data_handle);
+    std::copy(shift->data<T>(), shift->data<T>() + C, mem_data_handle + C);
+    return scaleshift_memory;
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireDiffScaleShiftMemory(
+      T* diff_scaleshift_data) {
+    return this->AcquireMemoryFromPrimitive(this->bwd_pd_->diff_weights_desc(),
+                                            diff_scaleshift_data);
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireMeanMemory(
+      const phi::DenseTensor* mean) {
+    const T* mean_data = mean->data<T>();
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
+                                            to_void_cast<T>(mean_data));
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireMeanMemory(phi::DenseTensor* mean) {
+    T* mean_data = mean->mutable_data<T>(this->place_,
+                                         this->fwd_pd_->mean_desc().get_size());
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
+                                            mean_data);
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
+      const phi::DenseTensor* variance) {
+    const T* variance_data = variance->data<T>();
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
+                                            to_void_cast<T>(variance_data));
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
+      phi::DenseTensor* variance) {
+    T* variance_data = variance->mutable_data<T>(
+        this->place_, this->fwd_pd_->variance_desc().get_size());
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
+                                            variance_data);
+  }
+};
+
 template <typename T>
 class PoolingOneDNNHandler
     : public OneDNNHandlerNoCachingT<T,
diff --git a/paddle/phi/kernels/onednn/batch_norm_kernel.cc b/paddle/phi/kernels/onednn/batch_norm_kernel.cc
new file mode 100644
--- /dev/null
+++ b/paddle/phi/kernels/onednn/batch_norm_kernel.cc
@@ -0,0 +1,146 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/batch_norm_kernel.h"
+
+#include "paddle/phi/backends/onednn/onednn_reuse.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T>
+using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
+
+template <typename T, typename Context>
+void BatchNormKernel(const Context &dev_ctx,
+                     const DenseTensor &x,
+                     const DenseTensor &mean,
+                     const DenseTensor &variance,
+                     const DenseTensor &scale,
+                     const DenseTensor &bias,
+                     bool is_test,
+                     float momentum,
+                     float epsilon,
+                     const std::string &data_layout,
+                     bool use_global_stats,
+                     bool trainable_statistics,
+                     DenseTensor *y,
+                     DenseTensor *mean_out,
+                     DenseTensor *variance_out,
+                     DenseTensor *saved_mean,
+                     DenseTensor *saved_variance,
+                     DenseTensor *reserve_space) {
+  const bool test_mode = is_test && (!trainable_statistics);
+  const bool global_stats = test_mode || use_global_stats;
+  const bool fuse_with_relu =
+      dev_ctx.HasDnnAttr("fuse_with_relu")
+          ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("fuse_with_relu"))
+          : false;
+
+  funcs::BatchNormOneDNNHandler<T> handler(dev_ctx.GetEngine(),
+                                           dev_ctx.GetPlace(),
+                                           &x,
+                                           epsilon,
+                                           fuse_with_relu,
+                                           global_stats,
+                                           test_mode);
+
+  auto src_memory = handler.AcquireSrcMemory(&x);
+  auto scaleshift_memory = handler.AcquireScaleShiftMemory(&scale, &bias);
+  auto dst_memory = handler.AcquireDstMemory(y);
+  auto batch_norm_p = handler.AcquireForwardPrimitive();
+
+  std::shared_ptr<dnnl::memory> mean_memory;
+  std::shared_ptr<dnnl::memory> variance_memory;
+
+  // mean and variance can be taken either from input or output Tensor
+  if (global_stats) {
+    mean_memory = handler.AcquireMeanMemory(&mean);
+    variance_memory = handler.AcquireVarianceMemory(&variance);
+  } else {
+    mean_memory = handler.AcquireMeanMemory(saved_mean);
+    variance_memory = handler.AcquireVarianceMemory(saved_variance);
+  }
+
+  y->set_mem_desc(dst_memory->get_desc());
+
+  auto &astream = OneDNNContext::tls().get_stream();
+  batch_norm_p->execute(astream,
+                        {{DNNL_ARG_SRC, *src_memory},
+                         {DNNL_ARG_SCALE_SHIFT, *scaleshift_memory},
+                         {DNNL_ARG_MEAN, *mean_memory},
+                         {DNNL_ARG_VARIANCE, *variance_memory},
+                         {DNNL_ARG_DST, *dst_memory}});
+  astream.wait();
+
+  if (!global_stats) {
+    const unsigned int C = phi::vectorize(scale.dims())[0];
+
+    // mkldnn only compute stats for current batch
+    // so we need compute momentum stats via Eigen lib
+    EigenVectorArrayMap<T> batch_mean_e(dev_ctx.template Alloc<T>(saved_mean),
+                                        C);
+    EigenVectorArrayMap<T> batch_variance_e(
+        dev_ctx.template Alloc<T>(saved_variance), C);
+
+    EigenVectorArrayMap<T> running_mean_e(dev_ctx.template Alloc<T>(mean_out),
+                                          C);
+    EigenVectorArrayMap<T> running_variance_e(
+        dev_ctx.template Alloc<T>(variance_out), C);
+
+    running_mean_e = running_mean_e * momentum + batch_mean_e * (1. - momentum);
+    running_variance_e =
+        running_variance_e * momentum + batch_variance_e * (1. - momentum);
+  }
+}
+
+template <typename T, typename Context>
+void BatchNormInferKernel(const Context &dev_ctx,
+                          const DenseTensor &x,
+                          const DenseTensor &mean,
+                          const DenseTensor &variance,
+                          const DenseTensor &scale,
+                          const DenseTensor &bias,
+                          float momentum,
+                          float epsilon,
+                          const std::string &data_layout,
+                          DenseTensor *y,
+                          DenseTensor *mean_out,
+                          DenseTensor *variance_out) {
+  BatchNormKernel<T, Context>(dev_ctx,
+                              x,
+                              mean,
+                              variance,
+                              scale,
+                              bias,
+                              /*is_test=*/true,
+                              momentum,
+                              epsilon,
+                              data_layout,
+                              /*use_global_stats=*/false,
+                              /*trainable_statistics=*/false,
+                              y,
+                              mean_out,
+                              variance_out,
+                              /*saved_mean*/ nullptr,
+                              /*saved_variance*/ nullptr,
+                              /*reserve_space=*/nullptr);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(batch_norm, OneDNN, ONEDNN, phi::BatchNormKernel, float) {}
+PD_REGISTER_KERNEL(
+    batch_norm_infer, OneDNN, ONEDNN, phi::BatchNormInferKernel, float) {}
-- 
GitLab