From 6a9a7748b64e53c72a3872aebb57c87374f07f6d Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Mon, 21 Mar 2022 10:10:46 +0800 Subject: [PATCH] [Phi] Add batch norm infer kernel and related infermeta (#40688) * add batch norm infer kernel * fix value error * fix is_test error * fix test failed * add fuse false cond * add infermeta * revert mutable_data change --- paddle/phi/infermeta/multiary.cc | 42 +++++++++++- paddle/phi/infermeta/multiary.h | 13 ++++ paddle/phi/kernels/batch_norm_kernel.cc | 90 +++++++++++++++++++++++++ paddle/phi/kernels/batch_norm_kernel.h | 15 +++++ paddle/phi/ops/compat/batch_norm_sig.cc | 44 +++++++----- 5 files changed, 187 insertions(+), 17 deletions(-) create mode 100644 paddle/phi/kernels/batch_norm_kernel.cc diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 3e9da9a217a..3faf42fe1ab 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -305,11 +305,48 @@ void BatchNormInferMeta(const MetaTensor& x, y->set_dims(x_dims); mean_out->set_dims({C}); variance_out->set_dims({C}); - saved_mean->set_dims({C}); - saved_variance->set_dims({C}); + if (saved_mean) { + saved_mean->set_dims({C}); + } + if (saved_variance) { + saved_variance->set_dims({C}); + } y->share_lod(x); } +void BatchNormInferInferMeta(const MetaTensor& x, + const MetaTensor& scale, + const MetaTensor& bias, + const MetaTensor& mean, + const MetaTensor& variance, + float momentum, + float epsilon, + const std::string& data_layout, + MetaTensor* y, + MetaTensor* mean_out, + MetaTensor* variance_out, + MetaConfig config) { + BatchNormInferMeta(x, + scale, + bias, + mean, + variance, + momentum, + epsilon, + data_layout, + /*is_test=*/true, + /*use_global_stats=*/false, + /*trainable_statistics=*/false, + /*fuse_with_relu=*/false, + y, + mean_out, + variance_out, + /*saved_mean=*/nullptr, + /*saved_variance=*/nullptr, + /*reserve_space=*/nullptr, + config); +} + void BilinearTensorProductInferMeta(const MetaTensor& x, const MetaTensor& y, const MetaTensor& weight, @@ -689,3 +726,4 @@ void WhereInferMeta(const MetaTensor& condition, } // namespace phi PD_REGISTER_INFER_META_FN(batch_norm, phi::BatchNormInferMeta); +PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 068766c0e11..e9b5d8c872f 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -92,6 +92,19 @@ void BatchNormInferMeta(const MetaTensor& x, MetaTensor* reserve_space, MetaConfig config = MetaConfig()); +void BatchNormInferInferMeta(const MetaTensor& x, + const MetaTensor& scale, + const MetaTensor& bias, + const MetaTensor& mean, + const MetaTensor& variance, + float momentum, + float epsilon, + const std::string& data_layout, + MetaTensor* y, + MetaTensor* mean_out, + MetaTensor* variance_out, + MetaConfig config = MetaConfig()); + void BilinearTensorProductInferMeta(const MetaTensor& x, const MetaTensor& y, const MetaTensor& weight, diff --git a/paddle/phi/kernels/batch_norm_kernel.cc b/paddle/phi/kernels/batch_norm_kernel.cc new file mode 100644 index 00000000000..a0de7842b9e --- /dev/null +++ b/paddle/phi/kernels/batch_norm_kernel.cc @@ -0,0 +1,90 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/batch_norm_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" + +namespace phi { + +template +void BatchNormInferKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& scale, + const DenseTensor& bias, + const DenseTensor& mean, + const DenseTensor& variance, + float momentum, + float epsilon, + const std::string& data_layout, + DenseTensor* y, + DenseTensor* mean_out, + DenseTensor* variance_out) { + // Since saved_mean and saved_variance are used regardless of whether + // they are in test mode, temporary variables need to be created here + // to be compatible + auto saved_mean = phi::EmptyLike(dev_ctx, *mean_out); + auto saved_variance = phi::EmptyLike(dev_ctx, *variance_out); + BatchNormKernel(dev_ctx, + x, + scale, + bias, + mean, + variance, + momentum, + epsilon, + data_layout, + /*is_test=*/true, + /*use_global_stats=*/false, + /*trainable_statistics=*/false, + /*fuse_with_relu=*/false, + y, + mean_out, + variance_out, + &saved_mean, + &saved_variance, + /*reserve_space=*/nullptr); +} + +} // namespace phi + +PD_REGISTER_KERNEL(batch_norm_infer, + CPU, + ALL_LAYOUT, + phi::BatchNormInferKernel, + float, + double) {} +#ifdef PADDLE_WITH_CUDA +PD_REGISTER_KERNEL(batch_norm_infer, + GPU, + ALL_LAYOUT, + phi::BatchNormInferKernel, + float, + double, + phi::dtype::float16) { + if (kernel_key.dtype() == phi::DataType::FLOAT16) { + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + } +} +#endif +#ifdef PADDLE_WITH_HIP +PD_REGISTER_KERNEL(batch_norm_infer, + GPU, + ALL_LAYOUT, + phi::BatchNormInferKernel, + float, + phi::dtype::float16) {} +#endif diff --git a/paddle/phi/kernels/batch_norm_kernel.h b/paddle/phi/kernels/batch_norm_kernel.h index 7ddf32e27c7..be589e43647 100644 --- a/paddle/phi/kernels/batch_norm_kernel.h +++ b/paddle/phi/kernels/batch_norm_kernel.h @@ -15,6 +15,7 @@ #pragma once #include + #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -40,4 +41,18 @@ void BatchNormKernel(const Context& dev_ctx, DenseTensor* saved_variance, DenseTensor* reserve_space); +template +void BatchNormInferKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& scale, + const DenseTensor& bias, + const DenseTensor& mean, + const DenseTensor& variance, + float momentum, + float epsilon, + const std::string& data_layout, + DenseTensor* y, + DenseTensor* mean_out, + DenseTensor* variance_out); + } // namespace phi diff --git a/paddle/phi/ops/compat/batch_norm_sig.cc b/paddle/phi/ops/compat/batch_norm_sig.cc index 011d4c12ece..fa1fac5d237 100644 --- a/paddle/phi/ops/compat/batch_norm_sig.cc +++ b/paddle/phi/ops/compat/batch_norm_sig.cc @@ -17,21 +17,35 @@ namespace phi { KernelSignature BatchNormOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("batch_norm", - {"X", "Scale", "Bias", "Mean", "Variance"}, - {"momentum", - "epsilon", - "data_layout", - "is_test", - "use_global_stats", - "trainable_statistics", - "fuse_with_relu"}, - {"Y", - "MeanOut", - "VarianceOut", - "SavedMean", - "SavedVariance", - "ReserveSpace"}); + bool is_test = paddle::any_cast(ctx.Attr("is_test")); + bool use_global_stats = paddle::any_cast(ctx.Attr("use_global_stats")); + bool trainable_statistics = + paddle::any_cast(ctx.Attr("trainable_statistics")); + bool fuse_with_relu = paddle::any_cast(ctx.Attr("fuse_with_relu")); + // Dispenable `MomentumTensor` is useless now + if (is_test && !use_global_stats && !trainable_statistics && + !fuse_with_relu) { + return KernelSignature("batch_norm_infer", + {"X", "Scale", "Bias", "Mean", "Variance"}, + {"momentum", "epsilon", "data_layout"}, + {"Y", "MeanOut", "VarianceOut"}); + } else { + return KernelSignature("batch_norm", + {"X", "Scale", "Bias", "Mean", "Variance"}, + {"momentum", + "epsilon", + "data_layout", + "is_test", + "use_global_stats", + "trainable_statistics", + "fuse_with_relu"}, + {"Y", + "MeanOut", + "VarianceOut", + "SavedMean", + "SavedVariance", + "ReserveSpace"}); + } } KernelSignature BatchNormGradOpArgumentMapping( -- GitLab