// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/batch_norm_kernel.h"

#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"

namespace phi {

template <typename T>
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;

template <typename T, typename Context>
void BatchNormKernel(const Context &dev_ctx,
                     const DenseTensor &x,
                     const DenseTensor &mean,
                     const DenseTensor &variance,
                     const DenseTensor &scale,
                     const DenseTensor &bias,
                     bool is_test,
                     float momentum,
                     float epsilon,
                     const std::string &data_layout,
                     bool use_global_stats,
                     bool trainable_statistics,
                     DenseTensor *y,
                     DenseTensor *mean_out,
                     DenseTensor *variance_out,
                     DenseTensor *saved_mean,
                     DenseTensor *saved_variance,
                     DenseTensor *reserve_space) {
  // In pure inference mode (or when use_global_stats is set) the kernel
  // normalizes with the provided running statistics instead of computing
  // batch statistics.
  const bool test_mode = is_test && (!trainable_statistics);
  const bool global_stats = test_mode || use_global_stats;
  const bool fuse_with_relu =
      dev_ctx.HasDnnAttr("fuse_with_relu")
          ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("fuse_with_relu"))
          : false;

  funcs::BatchNormOneDNNHandler<T> handler(dev_ctx.GetEngine(),
                                           dev_ctx.GetPlace(),
                                           &x,
                                           epsilon,
                                           fuse_with_relu,
                                           global_stats,
                                           test_mode);

  auto src_memory = handler.AcquireSrcMemory(&x);
  auto scaleshift_memory = handler.AcquireScaleShiftMemory(&scale, &bias);
  auto dst_memory = handler.AcquireDstMemory(y);
  auto batch_norm_p = handler.AcquireForwardPrimitive();

  std::shared_ptr<dnnl::memory> mean_memory;
  std::shared_ptr<dnnl::memory> variance_memory;

  // Mean and variance can be taken either from the input or the output
  // tensors: with global stats they are read from the inputs, otherwise
  // oneDNN writes the batch statistics into saved_mean/saved_variance.
  if (global_stats) {
    mean_memory = handler.AcquireMeanMemory(&mean);
    variance_memory = handler.AcquireVarianceMemory(&variance);
  } else {
    mean_memory = handler.AcquireMeanMemory(saved_mean);
    variance_memory = handler.AcquireVarianceMemory(saved_variance);
  }

  y->set_mem_desc(dst_memory->get_desc());

  auto &astream = OneDNNContext::tls().get_stream();
  batch_norm_p->execute(astream,
                        {{DNNL_ARG_SRC, *src_memory},
                         {DNNL_ARG_SCALE_SHIFT, *scaleshift_memory},
                         {DNNL_ARG_MEAN, *mean_memory},
                         {DNNL_ARG_VARIANCE, *variance_memory},
                         {DNNL_ARG_DST, *dst_memory}});
  astream.wait();

  if (!global_stats) {
    const unsigned int C = phi::vectorize(scale.dims())[0];

    // oneDNN only computes statistics for the current batch, so the
    // momentum-weighted running statistics have to be updated here via Eigen.
    EigenVectorArrayMap<T> batch_mean_e(dev_ctx.template Alloc<T>(saved_mean),
                                        C);
    EigenVectorArrayMap<T> batch_variance_e(
        dev_ctx.template Alloc<T>(saved_variance), C);
    EigenVectorArrayMap<T> running_mean_e(dev_ctx.template Alloc<T>(mean_out),
                                          C);
    EigenVectorArrayMap<T> running_variance_e(
        dev_ctx.template Alloc<T>(variance_out), C);

    // running_stat = running_stat * momentum + batch_stat * (1 - momentum)
    running_mean_e = running_mean_e * momentum + batch_mean_e * (1. - momentum);
    running_variance_e =
        running_variance_e * momentum + batch_variance_e * (1. - momentum);
  }
}
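// A minimal standalone sketch of the running-statistics update performed at
// the end of BatchNormKernel, assuming only Eigen. The function and variable
// names below are hypothetical and shown purely for illustration; they are
// not part of this kernel:
//
//   #include <Eigen/Core>
//
//   // Per-channel exponential moving average:
//   // running <- running * momentum + batch * (1 - momentum)
//   void UpdateRunningStats(Eigen::ArrayXf &running_mean,
//                           Eigen::ArrayXf &running_variance,
//                           const Eigen::ArrayXf &batch_mean,
//                           const Eigen::ArrayXf &batch_variance,
//                           float momentum) {
//     running_mean =
//         running_mean * momentum + batch_mean * (1.f - momentum);
//     running_variance =
//         running_variance * momentum + batch_variance * (1.f - momentum);
//   }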
template <typename T, typename Context>
void BatchNormInferKernel(const Context &dev_ctx,
                          const DenseTensor &x,
                          const DenseTensor &mean,
                          const DenseTensor &variance,
                          const DenseTensor &scale,
                          const DenseTensor &bias,
                          float momentum,
                          float epsilon,
                          const std::string &data_layout,
                          DenseTensor *y,
                          DenseTensor *mean_out,
                          DenseTensor *variance_out) {
  // Inference is expressed through the training kernel: is_test=true with
  // trainable_statistics=false selects normalization with the given running
  // statistics, and the saved/reserve outputs are not needed.
  BatchNormKernel<T, Context>(dev_ctx,
                              x,
                              mean,
                              variance,
                              scale,
                              bias,
                              /*is_test=*/true,
                              momentum,
                              epsilon,
                              data_layout,
                              /*use_global_stats=*/false,
                              /*trainable_statistics=*/false,
                              y,
                              mean_out,
                              variance_out,
                              /*saved_mean*/ nullptr,
                              /*saved_variance*/ nullptr,
                              /*reserve_space=*/nullptr);
}

}  // namespace phi

PD_REGISTER_KERNEL(batch_norm, OneDNN, ONEDNN, phi::BatchNormKernel, float) {}
PD_REGISTER_KERNEL(
    batch_norm_infer, OneDNN, ONEDNN, phi::BatchNormInferKernel, float) {}
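// Note on the batch_norm_infer entry point registered above: is_test=true and
// trainable_statistics=false resolve to test_mode=true and hence
// global_stats=true in BatchNormKernel, so inference always normalizes with
// the provided running mean/variance and skips the Eigen running-statistics
// update.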