Unverified · Commit eee4b8fb, authored by zhangyuqin1998 and committed by GitHub

rename BatchNormGradFunctor (#55717)

* rename BatchNormGradFunctor

* Update batch_norm_grad_kernel.cc

* Update batch_norm_grad_kernel.cu

* Update batch_norm_grad_kernel.cc

* fix

* Update batch_norm_grad_kernel.cc
Parent 60e37d17
@@ -557,7 +557,7 @@ class InplaceABNGradKernel : public framework::OpKernel<T> {
     }
     auto& dev_ctx = ctx.device_context<DeviceContext>();
-    phi::BatchNormGradRawKernel<T>(
+    phi::BatchNormGradFunctor<T>(
         static_cast<const typename framework::ConvertToPhiContext<
             DeviceContext>::TYPE&>(dev_ctx),
         *y,
...
@@ -188,7 +188,7 @@ class InplaceABNGradKernel : public framework::OpKernel<T> {
     }
     auto& dev_ctx = ctx.device_context<DeviceContext>();
-    phi::BatchNormGradRawKernel<T>(
+    phi::BatchNormGradFunctor<T>(
        static_cast<const typename framework::ConvertToPhiContext<
            DeviceContext>::TYPE&>(dev_ctx),
        *y,
...
@@ -21,26 +21,26 @@
 namespace phi {

 template <typename T, typename Context>
-void BatchNormGradRawKernel(const Context& dev_ctx,
+void BatchNormGradFunctor(const Context& dev_ctx,
                           const DenseTensor& x,
                           const DenseTensor& scale,
                           const DenseTensor& bias,
                           const paddle::optional<DenseTensor>& mean,
                           const paddle::optional<DenseTensor>& variance,
                           const DenseTensor& saved_mean,
                           const DenseTensor& saved_variance,
                           const paddle::optional<DenseTensor>& reserve_space,
                           const DenseTensor& y_grad,
                           float momentum,
                           float epsilon,
                           const std::string& data_layout,
                           bool is_test,
                           bool use_global_stats,
                           bool trainable_statistics,
                           bool is_inplace,
                           DenseTensor* x_grad,
                           DenseTensor* scale_grad,
                           DenseTensor* bias_grad);

 template <typename T, typename Context>
 void BatchNormGradKernel(const Context& dev_ctx,
...
@@ -53,5 +53,26 @@ void BatchNormInferKernel(const Context& dev_ctx,
                          DenseTensor* y,
                          DenseTensor* mean_out,
                          DenseTensor* variance_out);

+#define PD_DECLARE_BN_GRAD_FUNCTOR(dtype, backend) \
+  template void phi::BatchNormGradFunctor<dtype, ::phi::backend##Context>( \
+      const ::phi::backend##Context& dev_ctx, \
+      const DenseTensor& x, \
+      const DenseTensor& scale, \
+      const DenseTensor& bias, \
+      const paddle::optional<DenseTensor>& mean, \
+      const paddle::optional<DenseTensor>& variance, \
+      const DenseTensor& saved_mean, \
+      const DenseTensor& saved_variance, \
+      const paddle::optional<DenseTensor>& reserve_space, \
+      const DenseTensor& y_grad, \
+      float momentum, \
+      float epsilon, \
+      const std::string& data_layout, \
+      bool is_test, \
+      bool use_global_stats, \
+      bool trainable_statistics, \
+      bool is_inplace, \
+      DenseTensor* x_grad, \
+      DenseTensor* scale_grad, \
+      DenseTensor* bias_grad)
+
 }  // namespace phi
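The PD_DECLARE_BN_GRAD_FUNCTOR macro added above relies on nothing more than token pasting (##) plus an explicit template instantiation, which is what lets translation units outside the phi kernel registry link against BatchNormGradFunctor directly. A minimal, self-contained sketch of the same pattern is shown below; the names here (ToyFunctor, ToyCPUContext, DECLARE_TOY_FUNCTOR) are invented for illustration and are not Paddle code.

    #include <iostream>

    // Stand-ins for the per-backend device contexts.
    struct ToyCPUContext { const char* name = "CPU"; };
    struct ToyGPUContext { const char* name = "GPU"; };

    template <typename T, typename Context>
    void ToyFunctor(const Context& ctx, T value) {
      std::cout << ctx.name << ": " << value << "\n";
    }

    // Mirrors the shape of PD_DECLARE_BN_GRAD_FUNCTOR(dtype, backend):
    // paste the backend name into the context type and emit an explicit
    // instantiation so other files can link against that specialization.
    #define DECLARE_TOY_FUNCTOR(dtype, backend) \
      template void ToyFunctor<dtype, Toy##backend##Context>( \
          const Toy##backend##Context&, dtype)

    DECLARE_TOY_FUNCTOR(float, CPU);   // instantiates ToyFunctor<float, ToyCPUContext>
    DECLARE_TOY_FUNCTOR(double, GPU);  // instantiates ToyFunctor<double, ToyGPUContext>

    int main() {
      ToyFunctor(ToyCPUContext{}, 1.5f);
      ToyFunctor(ToyGPUContext{}, 2.5);
      return 0;
    }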
@@ -36,26 +36,26 @@ using ConstEigenVectorArrayMap =
     Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;

 template <typename T, typename Context>
-void BatchNormGradRawKernel(const Context& ctx,
+void BatchNormGradFunctor(const Context& ctx,
                           const DenseTensor& x,
                           const DenseTensor& scale,
                           const DenseTensor& bias,
                           const paddle::optional<DenseTensor>& mean,
                           const paddle::optional<DenseTensor>& variance,
                           const DenseTensor& saved_mean,
                           const DenseTensor& saved_variance,
                           const paddle::optional<DenseTensor>& reserve_space,
                           const DenseTensor& y_grad,
                           float momentum,
                           float epsilon,
                           const std::string& data_layout_str,
                           bool is_test,
                           bool use_global_stats,
                           bool trainable_statistics,
                           bool is_inplace,
                           DenseTensor* x_grad,
                           DenseTensor* scale_grad,
                           DenseTensor* bias_grad) {
   const auto* d_y = &y_grad;
   DataLayout data_layout = phi::StringToDataLayout(data_layout_str);
@@ -312,26 +312,26 @@ void BatchNormGradKernel(const Context& dev_ctx,
                          DenseTensor* x_grad,
                          DenseTensor* scale_grad,
                          DenseTensor* bias_grad) {
-  BatchNormGradRawKernel<T, Context>(dev_ctx,
+  BatchNormGradFunctor<T, Context>(dev_ctx,
                                    x,
                                    scale,
                                    bias,
                                    mean,
                                    variance,
                                    saved_mean,
                                    saved_variance,
                                    reserve_space,
                                    y_grad,
                                    momentum,
                                    epsilon,
                                    data_layout,
                                    is_test,
                                    use_global_stats,
                                    trainable_statistics,
                                    false,
                                    x_grad,
                                    scale_grad,
                                    bias_grad);
 }

 template <typename T, typename Context>
@@ -653,17 +653,13 @@ void BatchNormDoubleGradKernel(

 }  // namespace phi

+PD_DECLARE_BN_GRAD_FUNCTOR(float, CPU);
+PD_DECLARE_BN_GRAD_FUNCTOR(double, CPU);
+
 PD_REGISTER_KERNEL(
     batch_norm_grad, CPU, ALL_LAYOUT, phi::BatchNormGradKernel, float, double) {
 }
-PD_REGISTER_KERNEL(batch_norm_grad_raw,
-                   CPU,
-                   ALL_LAYOUT,
-                   phi::BatchNormGradRawKernel,
-                   float,
-                   double) {}
 PD_REGISTER_KERNEL(batch_norm_double_grad,
                    CPU,
                    ALL_LAYOUT,
...
@@ -485,26 +485,26 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData(
 }

 template <typename T, typename Context>
-void BatchNormGradRawKernel(const Context &ctx,
+void BatchNormGradFunctor(const Context &ctx,
                           const DenseTensor &x,
                           const DenseTensor &scale,
                           const DenseTensor &bias,
                           const paddle::optional<DenseTensor> &mean,
                           const paddle::optional<DenseTensor> &variance,
                           const DenseTensor &saved_mean,
                           const DenseTensor &saved_variance,
                           const paddle::optional<DenseTensor> &reserve_space,
                           const DenseTensor &y_grad,
                           float momentum,
                           float epsilon_f,
                           const std::string &data_layout_str,
                           bool is_test,
                           bool use_global_stats,
                           bool trainable_statistics,
                           bool is_inplace,
                           DenseTensor *x_grad,
                           DenseTensor *scale_grad,
                           DenseTensor *bias_grad) {
   double epsilon = static_cast<double>(epsilon_f);
   const DataLayout data_layout = phi::StringToDataLayout(data_layout_str);
@@ -1279,26 +1279,26 @@ void BatchNormGradKernel(const Context &dev_ctx,
                          DenseTensor *x_grad,
                          DenseTensor *scale_grad,
                          DenseTensor *bias_grad) {
-  BatchNormGradRawKernel<T, Context>(dev_ctx,
+  BatchNormGradFunctor<T, Context>(dev_ctx,
                                    x,
                                    scale,
                                    bias,
                                    mean,
                                    variance,
                                    saved_mean,
                                    saved_variance,
                                    reserve_space,
                                    y_grad,
                                    momentum,
                                    epsilon,
                                    data_layout,
                                    is_test,
                                    use_global_stats,
                                    trainable_statistics,
                                    false,
                                    x_grad,
                                    scale_grad,
                                    bias_grad);
 }

 template <typename T, typename Context>
@@ -1360,22 +1360,23 @@ void BatchNormDoubleGradKernel(

 }  // namespace phi

 #ifdef PADDLE_WITH_HIP
+PD_DECLARE_BN_GRAD_FUNCTOR(float, GPU);
+PD_DECLARE_BN_GRAD_FUNCTOR(phi::dtype::float16, GPU);
+
 PD_REGISTER_KERNEL(batch_norm_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::BatchNormGradKernel,
                    float,
                    phi::dtype::float16) {}
-PD_REGISTER_KERNEL(batch_norm_grad_raw,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::BatchNormGradRawKernel,
-                   float,
-                   phi::dtype::float16) {}
 #else
 #if CUDNN_VERSION_MIN(8, 1, 0)
+PD_DECLARE_BN_GRAD_FUNCTOR(float, GPU);
+PD_DECLARE_BN_GRAD_FUNCTOR(double, GPU);
+PD_DECLARE_BN_GRAD_FUNCTOR(phi::dtype::bfloat16, GPU);
+PD_DECLARE_BN_GRAD_FUNCTOR(phi::dtype::float16, GPU);
+
 PD_REGISTER_KERNEL(batch_norm_grad,
                    GPU,
                    ALL_LAYOUT,
@@ -1391,23 +1392,11 @@ PD_REGISTER_KERNEL(batch_norm_grad,
     kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);  // bias_grad
   }
 }
-PD_REGISTER_KERNEL(batch_norm_grad_raw,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::BatchNormGradRawKernel,
-                   float,
-                   double,
-                   phi::dtype::bfloat16,
-                   phi::dtype::float16) {
-  if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
-      kernel_key.dtype() == phi::DataType::BFLOAT16) {
-    kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);  // x_grad
-    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);  // scale_grad
-    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);  // bias_grad
-  }
-}
 #else
+PD_DECLARE_BN_GRAD_FUNCTOR(float, GPU);
+PD_DECLARE_BN_GRAD_FUNCTOR(double, GPU);
+PD_DECLARE_BN_GRAD_FUNCTOR(phi::dtype::float16, GPU);
+
 PD_REGISTER_KERNEL(batch_norm_grad,
                    GPU,
                    ALL_LAYOUT,
@@ -1421,20 +1410,6 @@ PD_REGISTER_KERNEL(batch_norm_grad,
     kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);  // bias_grad
   }
 }
-PD_REGISTER_KERNEL(batch_norm_grad_raw,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::BatchNormGradRawKernel,
-                   float,
-                   double,
-                   phi::dtype::float16) {
-  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
-    kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);  // x_grad
-    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);  // scale_grad
-    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);  // bias_grad
-  }
-}
 #endif
 #endif
...
@@ -17,29 +17,52 @@
 #include "paddle/phi/backends/onednn/onednn_reuse.h"
 #include "paddle/phi/core/kernel_registry.h"

+#define PD_DECLARE_BN_GRAD_FUNCTOR(dtype, backend) \
+  template void phi::BatchNormGradFunctor<dtype, ::phi::backend##Context>( \
+      const ::phi::backend##Context& dev_ctx, \
+      const DenseTensor& x, \
+      const DenseTensor& scale, \
+      const DenseTensor& bias, \
+      const paddle::optional<DenseTensor>& mean, \
+      const paddle::optional<DenseTensor>& variance, \
+      const DenseTensor& saved_mean, \
+      const DenseTensor& saved_variance, \
+      const paddle::optional<DenseTensor>& reserve_space, \
+      const DenseTensor& y_grad, \
+      float momentum, \
+      float epsilon, \
+      const std::string& data_layout, \
+      bool is_test, \
+      bool use_global_stats, \
+      bool trainable_statistics, \
+      bool is_inplace, \
+      DenseTensor* x_grad, \
+      DenseTensor* scale_grad, \
+      DenseTensor* bias_grad)
+
 namespace phi {

 template <typename T, typename Context>
-void BatchNormGradRawKernel(const Context& dev_ctx,
+void BatchNormGradFunctor(const Context& dev_ctx,
                           const DenseTensor& x,
                           const DenseTensor& scale,
                           const DenseTensor& bias,
                           const paddle::optional<DenseTensor>& mean,
                           const paddle::optional<DenseTensor>& variance,
                           const DenseTensor& saved_mean,
                           const DenseTensor& saved_variance,
                           const paddle::optional<DenseTensor>& reserve_space,
                           const DenseTensor& y_grad,
                           float momentum,
                           float epsilon,
                           const std::string& data_layout,
                           bool is_test,
                           bool use_global_stats,
                           bool trainable_statistics,
                           bool is_inplace,
                           DenseTensor* x_grad,
                           DenseTensor* scale_grad,
                           DenseTensor* bias_grad) {
   funcs::BatchNormOneDNNHandler<T> handler(
       dev_ctx.GetEngine(), dev_ctx.GetPlace(), epsilon, &x, &scale, &y_grad);
@@ -94,31 +117,31 @@ void BatchNormGradKernel(const Context& dev_ctx,
                          DenseTensor* x_grad,
                          DenseTensor* scale_grad,
                          DenseTensor* bias_grad) {
-  BatchNormGradRawKernel<T, Context>(dev_ctx,
+  BatchNormGradFunctor<T, Context>(dev_ctx,
                                    x,
                                    scale,
                                    bias,
                                    mean,
                                    variance,
                                    saved_mean,
                                    saved_variance,
                                    reserve_space,
                                    y_grad,
                                    momentum,
                                    epsilon,
                                    data_layout,
                                    is_test,
                                    use_global_stats,
                                    trainable_statistics,
                                    /*is_inplace*/ false,
                                    x_grad,
                                    scale_grad,
                                    bias_grad);
 }

 }  // namespace phi

+PD_DECLARE_BN_GRAD_FUNCTOR(float, OneDNN);
+
 PD_REGISTER_KERNEL(
     batch_norm_grad, OneDNN, ONEDNN, phi::BatchNormGradKernel, float) {}
-PD_REGISTER_KERNEL(
-    batch_norm_grad_raw, OneDNN, ONEDNN, phi::BatchNormGradRawKernel, float) {}
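With the batch_norm_grad_raw registrations removed, in-tree callers such as the InplaceABNGradKernel hunks at the top of this commit reach the backward computation by calling phi::BatchNormGradFunctor directly instead of going through the kernel registry. The call-shape sketch below is a rough illustration only: the helper name RunBatchNormGradFP32, the placeholder tensors, the assumed header path, and the attribute values are not taken from this PR.

    // Hypothetical caller inside the Paddle tree (CPU backend, float).
    #include "paddle/phi/kernels/batch_norm_grad_kernel.h"  // assumed header location

    void RunBatchNormGradFP32(const phi::CPUContext& dev_ctx,
                              const phi::DenseTensor& x,
                              const phi::DenseTensor& scale,
                              const phi::DenseTensor& bias,
                              const phi::DenseTensor& saved_mean,
                              const phi::DenseTensor& saved_variance,
                              const phi::DenseTensor& y_grad,
                              phi::DenseTensor* x_grad,
                              phi::DenseTensor* scale_grad,
                              phi::DenseTensor* bias_grad) {
      // Running mean/variance and reserve_space are optional inputs; they are
      // left empty here, which matches the training path where the saved
      // per-batch statistics are used (an assumption for this sketch).
      paddle::optional<phi::DenseTensor> none;
      phi::BatchNormGradFunctor<float, phi::CPUContext>(
          dev_ctx,
          x,
          scale,
          bias,
          /*mean=*/none,
          /*variance=*/none,
          saved_mean,
          saved_variance,
          /*reserve_space=*/none,
          y_grad,
          /*momentum=*/0.9f,            // illustrative value
          /*epsilon=*/1e-5f,            // illustrative value
          /*data_layout=*/"NCHW",
          /*is_test=*/false,
          /*use_global_stats=*/false,
          /*trainable_statistics=*/false,
          /*is_inplace=*/false,
          x_grad,
          scale_grad,
          bias_grad);
    }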