Unverified commit 934d8b89, authored by zhangyikun02, committed by GitHub

[XPU] batch_norm_grad support float16 for xpu (#53977)

Parent f71c805e
@@ -57,7 +57,8 @@ XPUOpMap& get_kl2_ops() {
       {"atan", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"atan_grad",
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
-      {"batch_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
+      {"batch_norm_grad",
+       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"batch_norm",
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"bmm", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
......
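Note on the hunk above: get_kl2_ops() is the XPU (KL2) op registry that maps each operator name to the set of data types its kernel accepts, so listing phi::DataType::FLOAT16 for "batch_norm_grad" is what allows the framework to dispatch float16 gradients to the XPU kernel in the first place. The snippet below is a minimal, self-contained sketch of that lookup pattern; DataType, KernelSet, OpMap, and SupportsDType are illustrative stand-ins, not the real Paddle types.

#include <iostream>
#include <map>
#include <set>
#include <string>

// Stand-ins for phi::DataType / XPUKernelSet / XPUOpMap (illustration only).
enum class DataType { FLOAT32, FLOAT16 };
using KernelSet = std::set<DataType>;
using OpMap = std::map<std::string, KernelSet>;

// Returns true if the op has an XPU kernel registered for the requested dtype.
bool SupportsDType(const OpMap& ops, const std::string& op, DataType dtype) {
  auto it = ops.find(op);
  return it != ops.end() && it->second.count(dtype) > 0;
}

int main() {
  OpMap kl2_ops;
  // After this commit, batch_norm_grad advertises both FP32 and FP16.
  kl2_ops["batch_norm_grad"] = {DataType::FLOAT32, DataType::FLOAT16};
  std::cout << SupportsDType(kl2_ops, "batch_norm_grad", DataType::FLOAT16)
            << "\n";  // prints 1
  return 0;
}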
@@ -89,6 +89,7 @@ void BatchNormGradKernel(const Context &dev_ctx,
                          DenseTensor *x_grad,
                          DenseTensor *scale_grad,
                          DenseTensor *bias_grad) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
   const auto *d_y = &y_grad;
   PADDLE_ENFORCE_EQ(data_layout == "NCHW" || data_layout == "NHWC",
                     true,
@@ -132,20 +133,21 @@ void BatchNormGradKernel(const Context &dev_ctx,
   W = W * D;

-  const auto *x_data = x.data<T>();
-  const auto *d_y_data = y_grad.data<T>();
+  const auto *x_data = reinterpret_cast<const XPUType *>(x.data<T>());
+  const auto *d_y_data = reinterpret_cast<const XPUType *>(y_grad.data<T>());
   const auto *scale_data = scale.data<float>();

   // init output
-  T *x_grad_data = nullptr;
-  T *bias_grad_data = nullptr;
-  T *scale_grad_data = nullptr;
+  XPUType *x_grad_data = nullptr;
+  float *bias_grad_data = nullptr;
+  float *scale_grad_data = nullptr;
   if (x_grad) {
-    x_grad_data = dev_ctx.template Alloc<T>(x_grad);
+    x_grad_data =
+        reinterpret_cast<XPUType *>(dev_ctx.template Alloc<T>(x_grad));
   }
   if (scale_grad && bias_grad) {
-    scale_grad_data = dev_ctx.template Alloc<T>(scale_grad);
-    bias_grad_data = dev_ctx.template Alloc<T>(bias_grad);
+    scale_grad_data = dev_ctx.template Alloc<float>(scale_grad);
+    bias_grad_data = dev_ctx.template Alloc<float>(bias_grad);
   }

   PADDLE_ENFORCE_EQ(
@@ -172,65 +174,68 @@ void BatchNormGradKernel(const Context &dev_ctx,
   const auto *global_var = variance.get_ptr();

   // TODO(guozibin): hadle the situation case of N * H * W = 1
+  int r = 0;
   if (is_inplace) {
     float *global_inv_std_data = nullptr;
     if (use_global_stats) {
       global_inv_std_data =
           RAII_GUARD.alloc_l3_or_gm<float>(global_var->numel());
       float *epsilon_data = RAII_GUARD.alloc_l3_or_gm<float>(1);
-      int r1 = CalculateInvVar(dev_ctx.x_context(),
-                               global_var->data<float>(),
-                               epsilon,
-                               C,
-                               epsilon_data,
-                               global_inv_std_data);
-      PADDLE_ENFORCE_XDNN_SUCCESS(r1,
+      r = CalculateInvVar(dev_ctx.x_context(),
+                          global_var->data<float>(),
+                          epsilon,
+                          C,
+                          epsilon_data,
+                          global_inv_std_data);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r,
                                   "batch_norm_grad CalculateInvVar function");
     }
-    // Here is a trick, x is a const input,
-    // but trans to a non-const var, is it risky?
-    auto px = x;
+    float *x_fp32_data = RAII_GUARD.alloc_l3_or_gm<float>(x.numel());
+    r = xpu::cast<XPUType, float>(
+        dev_ctx.x_context(), x_data, x_fp32_data, x.numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
     auto *inv_std_data =
         use_global_stats ? global_inv_std_data : saved_variance.data<float>();
     auto *mean_data = use_global_stats ? global_mean->data<float>()
                                        : saved_mean.data<float>();
-    int r2 = CalculateInvBNY(dev_ctx.x_context(),
-                             px.data<T>(),
-                             scale.data<float>(),
-                             bias.data<float>(),
-                             mean_data,
-                             inv_std_data,
-                             N,
-                             C,
-                             H * W,
-                             x.data<T>());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r2, "batch_norm_grad CalculateInvBNY function");
+    r = CalculateInvBNY(dev_ctx.x_context(),
+                        x_fp32_data,
+                        scale.data<float>(),
+                        bias.data<float>(),
+                        mean_data,
+                        inv_std_data,
+                        N,
+                        C,
+                        H * W,
+                        x_fp32_data);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_grad CalculateInvBNY function");
   }

-  int r3;
   bool is_nchw = data_layout == "NCHW";
   if (use_global_stats) {
-    r3 = xpu::batch_norm_grad<T>(dev_ctx.x_context(),
-                                 x_data,
-                                 d_y_data,
-                                 x_grad_data,
-                                 N,
-                                 C,
-                                 H,
-                                 W,
-                                 scale_data,
-                                 nullptr,
-                                 nullptr,
-                                 scale_grad_data,
-                                 bias_grad_data,
-                                 is_nchw,
-                                 global_mean->data<float>(),
-                                 global_var->data<float>(),
-                                 epsilon);
+    r = xpu::batch_norm_grad<XPUType>(dev_ctx.x_context(),
+                                      x_data,
+                                      d_y_data,
+                                      x_grad_data,
+                                      N,
+                                      C,
+                                      H,
+                                      W,
+                                      scale_data,
+                                      nullptr,
+                                      nullptr,
+                                      scale_grad_data,
+                                      bias_grad_data,
+                                      is_nchw,
+                                      global_mean->data<float>(),
+                                      global_var->data<float>(),
+                                      epsilon);
   } else {
     if (!x_grad) {
-      x_grad_data = RAII_GUARD.alloc_l3_or_gm<T>(x.numel());
+      x_grad_data = RAII_GUARD.alloc_l3_or_gm<XPUType>(x.numel());
     }
     if (!scale_grad) {
       scale_grad_data = RAII_GUARD.alloc_l3_or_gm<float>(C);
@@ -238,25 +243,29 @@ void BatchNormGradKernel(const Context &dev_ctx,
     if (!bias_grad_data) {
       bias_grad_data = RAII_GUARD.alloc_l3_or_gm<float>(C);
     }
-    r3 = xpu::batch_norm_grad<T>(dev_ctx.x_context(),
-                                 x_data,
-                                 d_y_data,
-                                 x_grad_data,
-                                 N,
-                                 C,
-                                 H,
-                                 W,
-                                 scale_data,
-                                 saved_mean.data<float>(),
-                                 saved_variance.data<float>(),
-                                 scale_grad_data,
-                                 bias_grad_data,
-                                 is_nchw);
+    r = xpu::batch_norm_grad<XPUType>(dev_ctx.x_context(),
+                                      x_data,
+                                      d_y_data,
+                                      x_grad_data,
+                                      N,
+                                      C,
+                                      H,
+                                      W,
+                                      scale_data,
+                                      saved_mean.data<float>(),
+                                      saved_variance.data<float>(),
+                                      scale_grad_data,
+                                      bias_grad_data,
+                                      is_nchw);
   }
-  PADDLE_ENFORCE_XDNN_SUCCESS(r3, "batch_norm_grad");
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_grad");
 }

 }  // namespace phi

-PD_REGISTER_KERNEL(
-    batch_norm_grad, XPU, ALL_LAYOUT, phi::BatchNormGradKernel, float) {}
+PD_REGISTER_KERNEL(batch_norm_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::BatchNormGradKernel,
+                   float,
+                   phi::dtype::float16) {}
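The kernel hunks above follow the usual pattern for adding FP16 support: the template parameter T is the framework's phi::dtype::float16, while the xpu:: device routines expect their own half type, so XPUTypeTrait<T>::Type selects the device-side type and the data buffers are reinterpret_cast between the two layout-compatible representations; scale, bias, mean, variance, and their gradients stay in float throughout. Below is a hedged, self-contained sketch of that trait idea — FrameworkFloat16, DeviceFloat16, TypeTraitSketch, and LaunchGradSketch are made-up names for illustration, not Paddle or XDNN APIs.

#include <cstdint>
#include <vector>

// Stand-ins for the framework-side and device-side half types (illustration
// only): both are 16-bit wrappers with identical layout, which is what makes
// the reinterpret_cast between them safe in practice.
struct FrameworkFloat16 { std::uint16_t bits; };
struct DeviceFloat16 { std::uint16_t bits; };

// Sketch of the XPUTypeTrait idea: map a framework type to the device type.
template <typename T> struct TypeTraitSketch { using Type = T; };
template <> struct TypeTraitSketch<FrameworkFloat16> {
  using Type = DeviceFloat16;
};

template <typename T>
void LaunchGradSketch(const T* x, T* dx, int n) {
  using XPUType = typename TypeTraitSketch<T>::Type;
  // Reinterpret framework buffers as device buffers before the device call.
  const XPUType* x_dev = reinterpret_cast<const XPUType*>(x);
  XPUType* dx_dev = reinterpret_cast<XPUType*>(dx);
  // A real kernel would hand x_dev/dx_dev to the device library here; this
  // sketch just copies the raw bits element by element.
  for (int i = 0; i < n; ++i) dx_dev[i].bits = x_dev[i].bits;
}

int main() {
  std::vector<FrameworkFloat16> x(4), dx(4);
  LaunchGradSketch(x.data(), dx.data(), 4);
  return 0;
}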
@@ -317,6 +317,11 @@ class XPUTestBatchNormGradOp(XPUOpTestWrapper):
             self.init_dtype()
             self.set_xpu()
             self.set_attrs()
+            self.rtol = 1e-5
+            self.atol = 1e-4
+            if self.dtype == np.float16:
+                self.rtol = 1e-2
+                self.atol = 1e-3

             if self.data_layout == "NHWC":
                 channel_size = self.shape[3]
@@ -451,7 +456,7 @@ class XPUTestBatchNormGradOp(XPUOpTestWrapper):
             outs = exe.run(program, feed=inputs, fetch_list=fetch_list)
             for id, name in enumerate(fetch_list):
                 np.testing.assert_allclose(
-                    outputs[name], outs[id], rtol=1e-05, atol=1e-4
+                    outputs[name], outs[id], rtol=self.rtol, atol=self.atol
                 )
......
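The test hunks above widen the comparison tolerances when the kernel runs in float16: half precision carries only about three decimal digits, so the FP32 defaults (rtol=1e-5, atol=1e-4) would flag ordinary rounding error as a failure. As a hedged illustration of the same dtype-dependent "allclose" idea, the helper below reuses the tolerance values from the test; AllClose itself and the sample numbers are made up.

#include <cmath>
#include <cstdio>

// Dtype-dependent elementwise closeness check (illustration only). The
// thresholds mirror the test above: 1e-5/1e-4 for FP32, 1e-2/1e-3 for FP16.
bool AllClose(const float* actual, const float* expected, int n, bool fp16) {
  const double rtol = fp16 ? 1e-2 : 1e-5;
  const double atol = fp16 ? 1e-3 : 1e-4;
  for (int i = 0; i < n; ++i) {
    if (std::fabs(actual[i] - expected[i]) >
        atol + rtol * std::fabs(expected[i])) {
      return false;
    }
  }
  return true;
}

int main() {
  const float expected[] = {1.0f, 2.0f};
  const float actual[] = {1.004f, 1.996f};
  // Passes under the FP16 tolerances, would fail under the FP32 ones.
  std::printf("%d\n", AllClose(actual, expected, 2, /*fp16=*/true));
  return 0;
}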