未验证 提交 6a9a7748 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Add batch norm infer kernel and related infermeta (#40688)

* add batch norm infer kernel

* fix value error

* fix is_test error

* fix test failed

* add fuse false cond

* add infermeta

* revert mutable_data change
上级 facda828
......@@ -305,11 +305,48 @@ void BatchNormInferMeta(const MetaTensor& x,
y->set_dims(x_dims);
mean_out->set_dims({C});
variance_out->set_dims({C});
saved_mean->set_dims({C});
saved_variance->set_dims({C});
if (saved_mean) {
saved_mean->set_dims({C});
}
if (saved_variance) {
saved_variance->set_dims({C});
}
y->share_lod(x);
}
void BatchNormInferInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
const MetaTensor& mean,
const MetaTensor& variance,
float momentum,
float epsilon,
const std::string& data_layout,
MetaTensor* y,
MetaTensor* mean_out,
MetaTensor* variance_out,
MetaConfig config) {
BatchNormInferMeta(x,
scale,
bias,
mean,
variance,
momentum,
epsilon,
data_layout,
/*is_test=*/true,
/*use_global_stats=*/false,
/*trainable_statistics=*/false,
/*fuse_with_relu=*/false,
y,
mean_out,
variance_out,
/*saved_mean=*/nullptr,
/*saved_variance=*/nullptr,
/*reserve_space=*/nullptr,
config);
}
void BilinearTensorProductInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
......@@ -689,3 +726,4 @@ void WhereInferMeta(const MetaTensor& condition,
} // namespace phi
PD_REGISTER_INFER_META_FN(batch_norm, phi::BatchNormInferMeta);
PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
......@@ -92,6 +92,19 @@ void BatchNormInferMeta(const MetaTensor& x,
MetaTensor* reserve_space,
MetaConfig config = MetaConfig());
void BatchNormInferInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
const MetaTensor& mean,
const MetaTensor& variance,
float momentum,
float epsilon,
const std::string& data_layout,
MetaTensor* y,
MetaTensor* mean_out,
MetaTensor* variance_out,
MetaConfig config = MetaConfig());
void BilinearTensorProductInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/batch_norm_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
namespace phi {
template <typename T, typename Context>
void BatchNormInferKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const DenseTensor& mean,
const DenseTensor& variance,
float momentum,
float epsilon,
const std::string& data_layout,
DenseTensor* y,
DenseTensor* mean_out,
DenseTensor* variance_out) {
// Since saved_mean and saved_variance are used regardless of whether
// they are in test mode, temporary variables need to be created here
// to be compatible
auto saved_mean = phi::EmptyLike<T, Context>(dev_ctx, *mean_out);
auto saved_variance = phi::EmptyLike<T, Context>(dev_ctx, *variance_out);
BatchNormKernel<T, Context>(dev_ctx,
x,
scale,
bias,
mean,
variance,
momentum,
epsilon,
data_layout,
/*is_test=*/true,
/*use_global_stats=*/false,
/*trainable_statistics=*/false,
/*fuse_with_relu=*/false,
y,
mean_out,
variance_out,
&saved_mean,
&saved_variance,
/*reserve_space=*/nullptr);
}
} // namespace phi
PD_REGISTER_KERNEL(batch_norm_infer,
CPU,
ALL_LAYOUT,
phi::BatchNormInferKernel,
float,
double) {}
#ifdef PADDLE_WITH_CUDA
PD_REGISTER_KERNEL(batch_norm_infer,
GPU,
ALL_LAYOUT,
phi::BatchNormInferKernel,
float,
double,
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
}
}
#endif
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(batch_norm_infer,
GPU,
ALL_LAYOUT,
phi::BatchNormInferKernel,
float,
phi::dtype::float16) {}
#endif
......@@ -15,6 +15,7 @@
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
......@@ -40,4 +41,18 @@ void BatchNormKernel(const Context& dev_ctx,
DenseTensor* saved_variance,
DenseTensor* reserve_space);
template <typename T, typename Context>
void BatchNormInferKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const DenseTensor& mean,
const DenseTensor& variance,
float momentum,
float epsilon,
const std::string& data_layout,
DenseTensor* y,
DenseTensor* mean_out,
DenseTensor* variance_out);
} // namespace phi
......@@ -17,21 +17,35 @@
namespace phi {
KernelSignature BatchNormOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("batch_norm",
{"X", "Scale", "Bias", "Mean", "Variance"},
{"momentum",
"epsilon",
"data_layout",
"is_test",
"use_global_stats",
"trainable_statistics",
"fuse_with_relu"},
{"Y",
"MeanOut",
"VarianceOut",
"SavedMean",
"SavedVariance",
"ReserveSpace"});
bool is_test = paddle::any_cast<bool>(ctx.Attr("is_test"));
bool use_global_stats = paddle::any_cast<bool>(ctx.Attr("use_global_stats"));
bool trainable_statistics =
paddle::any_cast<bool>(ctx.Attr("trainable_statistics"));
bool fuse_with_relu = paddle::any_cast<bool>(ctx.Attr("fuse_with_relu"));
// Dispenable `MomentumTensor` is useless now
if (is_test && !use_global_stats && !trainable_statistics &&
!fuse_with_relu) {
return KernelSignature("batch_norm_infer",
{"X", "Scale", "Bias", "Mean", "Variance"},
{"momentum", "epsilon", "data_layout"},
{"Y", "MeanOut", "VarianceOut"});
} else {
return KernelSignature("batch_norm",
{"X", "Scale", "Bias", "Mean", "Variance"},
{"momentum",
"epsilon",
"data_layout",
"is_test",
"use_global_stats",
"trainable_statistics",
"fuse_with_relu"},
{"Y",
"MeanOut",
"VarianceOut",
"SavedMean",
"SavedVariance",
"ReserveSpace"});
}
}
KernelSignature BatchNormGradOpArgumentMapping(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册