Unverified commit eee4b8fb, authored by zhangyuqin1998, committed by GitHub

rename BatchNormGradFunctor (#55717)

* rename BatchNormGradFunctor

* Update batch_norm_grad_kernel.cc

* Update batch_norm_grad_kernel.cu

* Update batch_norm_grad_kernel.cc

* fix

* Update batch_norm_grad_kernel.cc
Parent 60e37d17
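Summary of the change: BatchNormGradRawKernel is renamed to BatchNormGradFunctor, the separate batch_norm_grad_raw kernel registrations are dropped, and each backend instead declares explicit template instantiations per dtype through a new PD_DECLARE_BN_GRAD_FUNCTOR macro. Callers such as InplaceABNGradKernel now invoke the functor directly. A minimal sketch of the new calling pattern (the variable names are placeholders; the authoritative parameter list is the one declared in the header hunk below):

// Illustrative only: direct invocation of the renamed functor from a caller.
// Tensor and flag names here are placeholders, not the exact variables in the diff.
phi::BatchNormGradFunctor<T, Context>(dev_ctx,
                                      x, scale, bias,
                                      mean, variance,
                                      saved_mean, saved_variance,
                                      reserve_space, y_grad,
                                      momentum, epsilon, data_layout,
                                      is_test, use_global_stats,
                                      trainable_statistics,
                                      /*is_inplace=*/true,
                                      x_grad, scale_grad, bias_grad);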
@@ -557,7 +557,7 @@ class InplaceABNGradKernel : public framework::OpKernel<T> {
}
auto& dev_ctx = ctx.device_context<DeviceContext>();
- phi::BatchNormGradRawKernel<T>(
+ phi::BatchNormGradFunctor<T>(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*y,
......
@@ -188,7 +188,7 @@ class InplaceABNGradKernel : public framework::OpKernel<T> {
}
auto& dev_ctx = ctx.device_context<DeviceContext>();
- phi::BatchNormGradRawKernel<T>(
+ phi::BatchNormGradFunctor<T>(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*y,
......
@@ -21,7 +21,7 @@
namespace phi {
template <typename T, typename Context>
- void BatchNormGradRawKernel(const Context& dev_ctx,
+ void BatchNormGradFunctor(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
......
@@ -53,5 +53,26 @@ void BatchNormInferKernel(const Context& dev_ctx,
DenseTensor* y,
DenseTensor* mean_out,
DenseTensor* variance_out);
+ #define PD_DECLARE_BN_GRAD_FUNCTOR(dtype, backend) \
+ template void phi::BatchNormGradFunctor<dtype, ::phi::backend##Context>( \
+ const ::phi::backend##Context& dev_ctx, \
+ const DenseTensor& x, \
+ const DenseTensor& scale, \
+ const DenseTensor& bias, \
+ const paddle::optional<DenseTensor>& mean, \
+ const paddle::optional<DenseTensor>& variance, \
+ const DenseTensor& saved_mean, \
+ const DenseTensor& saved_variance, \
+ const paddle::optional<DenseTensor>& reserve_space, \
+ const DenseTensor& y_grad, \
+ float momentum, \
+ float epsilon, \
+ const std::string& data_layout, \
+ bool is_test, \
+ bool use_global_stats, \
+ bool trainable_statistics, \
+ bool is_inplace, \
+ DenseTensor* x_grad, \
+ DenseTensor* scale_grad, \
+ DenseTensor* bias_grad)
} // namespace phi
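Because BatchNormGradFunctor is no longer registered as a standalone kernel, each backend translation unit must emit the instantiations that other files link against; that is what invoking PD_DECLARE_BN_GRAD_FUNCTOR does, one explicit instantiation per dtype/backend pair. A self-contained illustration of the same explicit-instantiation idea, using hypothetical names unrelated to Paddle:

#include <iostream>

// A function template whose definition lives in this translation unit.
template <typename T>
T scale_by_two(T v) { return v + v; }

// Explicit instantiations: force the compiler to emit these specializations
// here, so other translation units can link against them without seeing the
// definition. PD_DECLARE_BN_GRAD_FUNCTOR plays the same role for each
// dtype/backend pair of BatchNormGradFunctor.
template int scale_by_two<int>(int);
template double scale_by_two<double>(double);

int main() {
  std::cout << scale_by_two(21) << " " << scale_by_two(1.5) << "\n";
  return 0;
}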
@@ -36,7 +36,7 @@ using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename T, typename Context>
- void BatchNormGradRawKernel(const Context& ctx,
+ void BatchNormGradFunctor(const Context& ctx,
const DenseTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
@@ -312,7 +312,7 @@ void BatchNormGradKernel(const Context& dev_ctx,
DenseTensor* x_grad,
DenseTensor* scale_grad,
DenseTensor* bias_grad) {
- BatchNormGradRawKernel<T, Context>(dev_ctx,
+ BatchNormGradFunctor<T, Context>(dev_ctx,
x,
scale,
bias,
@@ -653,17 +653,13 @@ void BatchNormDoubleGradKernel(
} // namespace phi
+ PD_DECLARE_BN_GRAD_FUNCTOR(float, CPU);
+ PD_DECLARE_BN_GRAD_FUNCTOR(double, CPU);
PD_REGISTER_KERNEL(
batch_norm_grad, CPU, ALL_LAYOUT, phi::BatchNormGradKernel, float, double) {
}
- PD_REGISTER_KERNEL(batch_norm_grad_raw,
- CPU,
- ALL_LAYOUT,
- phi::BatchNormGradRawKernel,
- float,
- double) {}
PD_REGISTER_KERNEL(batch_norm_double_grad,
CPU,
ALL_LAYOUT,
......
@@ -485,7 +485,7 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData(
}
template <typename T, typename Context>
- void BatchNormGradRawKernel(const Context &ctx,
+ void BatchNormGradFunctor(const Context &ctx,
const DenseTensor &x,
const DenseTensor &scale,
const DenseTensor &bias,
@@ -1279,7 +1279,7 @@ void BatchNormGradKernel(const Context &dev_ctx,
DenseTensor *x_grad,
DenseTensor *scale_grad,
DenseTensor *bias_grad) {
- BatchNormGradRawKernel<T, Context>(dev_ctx,
+ BatchNormGradFunctor<T, Context>(dev_ctx,
x,
scale,
bias,
@@ -1360,22 +1360,23 @@ void BatchNormDoubleGradKernel(
} // namespace phi
#ifdef PADDLE_WITH_HIP
+ PD_DECLARE_BN_GRAD_FUNCTOR(float, GPU);
+ PD_DECLARE_BN_GRAD_FUNCTOR(phi::dtype::float16, GPU);
PD_REGISTER_KERNEL(batch_norm_grad,
GPU,
ALL_LAYOUT,
phi::BatchNormGradKernel,
float,
phi::dtype::float16) {}
- PD_REGISTER_KERNEL(batch_norm_grad_raw,
- GPU,
- ALL_LAYOUT,
- phi::BatchNormGradRawKernel,
- float,
- phi::dtype::float16) {}
#else
#if CUDNN_VERSION_MIN(8, 1, 0)
+ PD_DECLARE_BN_GRAD_FUNCTOR(float, GPU);
+ PD_DECLARE_BN_GRAD_FUNCTOR(double, GPU);
+ PD_DECLARE_BN_GRAD_FUNCTOR(phi::dtype::bfloat16, GPU);
+ PD_DECLARE_BN_GRAD_FUNCTOR(phi::dtype::float16, GPU);
PD_REGISTER_KERNEL(batch_norm_grad,
GPU,
ALL_LAYOUT,
@@ -1391,23 +1392,11 @@ PD_REGISTER_KERNEL(batch_norm_grad,
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); // bias_grad
}
}
- PD_REGISTER_KERNEL(batch_norm_grad_raw,
- GPU,
- ALL_LAYOUT,
- phi::BatchNormGradRawKernel,
- float,
- double,
- phi::dtype::bfloat16,
- phi::dtype::float16) {
- if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
- kernel_key.dtype() == phi::DataType::BFLOAT16) {
- kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); // x_grad
- kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); // scale_grad
- kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); // bias_grad
- }
- }
#else
+ PD_DECLARE_BN_GRAD_FUNCTOR(float, GPU);
+ PD_DECLARE_BN_GRAD_FUNCTOR(double, GPU);
+ PD_DECLARE_BN_GRAD_FUNCTOR(phi::dtype::float16, GPU);
PD_REGISTER_KERNEL(batch_norm_grad,
GPU,
ALL_LAYOUT,
@@ -1421,20 +1410,6 @@ PD_REGISTER_KERNEL(batch_norm_grad,
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); // bias_grad
}
}
- PD_REGISTER_KERNEL(batch_norm_grad_raw,
- GPU,
- ALL_LAYOUT,
- phi::BatchNormGradRawKernel,
- float,
- double,
- phi::dtype::float16) {
- if (kernel_key.dtype() == phi::DataType::FLOAT16) {
- kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); // x_grad
- kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); // scale_grad
- kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); // bias_grad
- }
- }
#endif
#endif
......
@@ -17,10 +17,33 @@
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
+ #define PD_DECLARE_BN_GRAD_FUNCTOR(dtype, backend) \
+ template void phi::BatchNormGradFunctor<dtype, ::phi::backend##Context>( \
+ const ::phi::backend##Context& dev_ctx, \
+ const DenseTensor& x, \
+ const DenseTensor& scale, \
+ const DenseTensor& bias, \
+ const paddle::optional<DenseTensor>& mean, \
+ const paddle::optional<DenseTensor>& variance, \
+ const DenseTensor& saved_mean, \
+ const DenseTensor& saved_variance, \
+ const paddle::optional<DenseTensor>& reserve_space, \
+ const DenseTensor& y_grad, \
+ float momentum, \
+ float epsilon, \
+ const std::string& data_layout, \
+ bool is_test, \
+ bool use_global_stats, \
+ bool trainable_statistics, \
+ bool is_inplace, \
+ DenseTensor* x_grad, \
+ DenseTensor* scale_grad, \
+ DenseTensor* bias_grad)
namespace phi {
template <typename T, typename Context>
- void BatchNormGradRawKernel(const Context& dev_ctx,
+ void BatchNormGradFunctor(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
@@ -94,7 +117,7 @@ void BatchNormGradKernel(const Context& dev_ctx,
DenseTensor* x_grad,
DenseTensor* scale_grad,
DenseTensor* bias_grad) {
- BatchNormGradRawKernel<T, Context>(dev_ctx,
+ BatchNormGradFunctor<T, Context>(dev_ctx,
x,
scale,
bias,
@@ -118,7 +141,7 @@ void BatchNormGradKernel(const Context& dev_ctx,
} // namespace phi
+ PD_DECLARE_BN_GRAD_FUNCTOR(float, OneDNN);
PD_REGISTER_KERNEL(
batch_norm_grad, OneDNN, ONEDNN, phi::BatchNormGradKernel, float) {}
- PD_REGISTER_KERNEL(
- batch_norm_grad_raw, OneDNN, ONEDNN, phi::BatchNormGradRawKernel, float) {}