Unverified commit 934d8b89, authored by zhangyikun02, committed via GitHub

[XPU] batch_norm_grad support float16 for xpu (#53977)

Parent: f71c805e
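This commit adds float16 coverage for batch_norm_grad on XPU: the KL2 op list gains a FLOAT16 entry, the kernel is reworked around XPUTypeTrait so data pointers can be handed to XDNN as the device fp16 type while scale/bias gradients stay float32, the kernel is registered for phi::dtype::float16, and the unit test relaxes its tolerances for half precision. As a rough usage sketch (not part of the diff; it assumes a Paddle build with XPU support and an attached XPU device), this is the kind of backward pass the new kernel serves:

    # Minimal sketch: float16 batch_norm forward + backward on an XPU device.
    # Assumes paddle was built with XPU support; the "xpu" device string is an assumption.
    import paddle
    import paddle.nn.functional as F

    paddle.set_device("xpu")

    x = paddle.randn([4, 8, 16, 16]).astype("float16")
    x.stop_gradient = False

    # Statistics and affine parameters stay float32, matching the kernel's
    # float buffers for scale_grad / bias_grad in this diff.
    running_mean = paddle.zeros([8], dtype="float32")
    running_var = paddle.ones([8], dtype="float32")
    weight = paddle.ones([8], dtype="float32")
    bias = paddle.zeros([8], dtype="float32")

    y = F.batch_norm(x, running_mean, running_var, weight, bias, training=True)
    y.sum().backward()   # dispatches batch_norm_grad; now valid for float16 on XPU
    print(x.grad.dtype)  # paddle.float16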
@@ -57,7 +57,8 @@ XPUOpMap& get_kl2_ops() {
     {"atan", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"atan_grad",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
-    {"batch_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
+    {"batch_norm_grad",
+     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"batch_norm",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"bmm", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
...
@@ -89,6 +89,7 @@ void BatchNormGradKernel(const Context &dev_ctx,
                          DenseTensor *x_grad,
                          DenseTensor *scale_grad,
                          DenseTensor *bias_grad) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
   const auto *d_y = &y_grad;
   PADDLE_ENFORCE_EQ(data_layout == "NCHW" || data_layout == "NHWC",
                     true,
@@ -132,20 +133,21 @@ void BatchNormGradKernel(const Context &dev_ctx,
   W = W * D;
-  const auto *x_data = x.data<T>();
-  const auto *d_y_data = y_grad.data<T>();
+  const auto *x_data = reinterpret_cast<const XPUType *>(x.data<T>());
+  const auto *d_y_data = reinterpret_cast<const XPUType *>(y_grad.data<T>());
   const auto *scale_data = scale.data<float>();
   // init output
-  T *x_grad_data = nullptr;
-  T *bias_grad_data = nullptr;
-  T *scale_grad_data = nullptr;
+  XPUType *x_grad_data = nullptr;
+  float *bias_grad_data = nullptr;
+  float *scale_grad_data = nullptr;
   if (x_grad) {
-    x_grad_data = dev_ctx.template Alloc<T>(x_grad);
+    x_grad_data =
+        reinterpret_cast<XPUType *>(dev_ctx.template Alloc<T>(x_grad));
   }
   if (scale_grad && bias_grad) {
-    scale_grad_data = dev_ctx.template Alloc<T>(scale_grad);
-    bias_grad_data = dev_ctx.template Alloc<T>(bias_grad);
+    scale_grad_data = dev_ctx.template Alloc<float>(scale_grad);
+    bias_grad_data = dev_ctx.template Alloc<float>(bias_grad);
   }
   PADDLE_ENFORCE_EQ(
@@ -172,65 +174,68 @@ void BatchNormGradKernel(const Context &dev_ctx,
   const auto *global_var = variance.get_ptr();
   // TODO(guozibin): hadle the situation case of N * H * W = 1
+  int r = 0;
   if (is_inplace) {
     float *global_inv_std_data = nullptr;
     if (use_global_stats) {
       global_inv_std_data =
           RAII_GUARD.alloc_l3_or_gm<float>(global_var->numel());
       float *epsilon_data = RAII_GUARD.alloc_l3_or_gm<float>(1);
-      int r1 = CalculateInvVar(dev_ctx.x_context(),
-                               global_var->data<float>(),
-                               epsilon,
-                               C,
-                               epsilon_data,
-                               global_inv_std_data);
-      PADDLE_ENFORCE_XDNN_SUCCESS(r1,
-                                  "batch_norm_grad CalculateInvVar function");
+      r = CalculateInvVar(dev_ctx.x_context(),
+                          global_var->data<float>(),
+                          epsilon,
+                          C,
+                          epsilon_data,
+                          global_inv_std_data);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r,
+                                  "batch_norm_grad CalculateInvVar function");
     }
     // Here is a trick, x is a const input,
     // but trans to a non-const var, is it risky?
-    auto px = x;
+    float *x_fp32_data = RAII_GUARD.alloc_l3_or_gm<float>(x.numel());
+    r = xpu::cast<XPUType, float>(
+        dev_ctx.x_context(), x_data, x_fp32_data, x.numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
     auto *inv_std_data =
         use_global_stats ? global_inv_std_data : saved_variance.data<float>();
     auto *mean_data = use_global_stats ? global_mean->data<float>()
                                        : saved_mean.data<float>();
-    int r2 = CalculateInvBNY(dev_ctx.x_context(),
-                             px.data<T>(),
-                             scale.data<float>(),
-                             bias.data<float>(),
-                             mean_data,
-                             inv_std_data,
-                             N,
-                             C,
-                             H * W,
-                             x.data<T>());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r2, "batch_norm_grad CalculateInvBNY function");
+    r = CalculateInvBNY(dev_ctx.x_context(),
+                        x_fp32_data,
+                        scale.data<float>(),
+                        bias.data<float>(),
+                        mean_data,
+                        inv_std_data,
+                        N,
+                        C,
+                        H * W,
+                        x_fp32_data);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_grad CalculateInvBNY function");
   }
-  int r3;
   bool is_nchw = data_layout == "NCHW";
   if (use_global_stats) {
-    r3 = xpu::batch_norm_grad<T>(dev_ctx.x_context(),
-                                 x_data,
-                                 d_y_data,
-                                 x_grad_data,
-                                 N,
-                                 C,
-                                 H,
-                                 W,
-                                 scale_data,
-                                 nullptr,
-                                 nullptr,
-                                 scale_grad_data,
-                                 bias_grad_data,
-                                 is_nchw,
-                                 global_mean->data<float>(),
-                                 global_var->data<float>(),
-                                 epsilon);
+    r = xpu::batch_norm_grad<XPUType>(dev_ctx.x_context(),
+                                      x_data,
+                                      d_y_data,
+                                      x_grad_data,
+                                      N,
+                                      C,
+                                      H,
+                                      W,
+                                      scale_data,
+                                      nullptr,
+                                      nullptr,
+                                      scale_grad_data,
+                                      bias_grad_data,
+                                      is_nchw,
+                                      global_mean->data<float>(),
+                                      global_var->data<float>(),
+                                      epsilon);
   } else {
     if (!x_grad) {
-      x_grad_data = RAII_GUARD.alloc_l3_or_gm<T>(x.numel());
+      x_grad_data = RAII_GUARD.alloc_l3_or_gm<XPUType>(x.numel());
    }
    if (!scale_grad) {
      scale_grad_data = RAII_GUARD.alloc_l3_or_gm<float>(C);
@@ -238,25 +243,29 @@ void BatchNormGradKernel(const Context &dev_ctx,
     if (!bias_grad_data) {
       bias_grad_data = RAII_GUARD.alloc_l3_or_gm<float>(C);
     }
-    r3 = xpu::batch_norm_grad<T>(dev_ctx.x_context(),
-                                 x_data,
-                                 d_y_data,
-                                 x_grad_data,
-                                 N,
-                                 C,
-                                 H,
-                                 W,
-                                 scale_data,
-                                 saved_mean.data<float>(),
-                                 saved_variance.data<float>(),
-                                 scale_grad_data,
-                                 bias_grad_data,
-                                 is_nchw);
+    r = xpu::batch_norm_grad<XPUType>(dev_ctx.x_context(),
+                                      x_data,
+                                      d_y_data,
+                                      x_grad_data,
+                                      N,
+                                      C,
+                                      H,
+                                      W,
+                                      scale_data,
+                                      saved_mean.data<float>(),
+                                      saved_variance.data<float>(),
+                                      scale_grad_data,
+                                      bias_grad_data,
+                                      is_nchw);
   }
-  PADDLE_ENFORCE_XDNN_SUCCESS(r3, "batch_norm_grad");
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_grad");
 }
 }  // namespace phi
-PD_REGISTER_KERNEL(
-    batch_norm_grad, XPU, ALL_LAYOUT, phi::BatchNormGradKernel, float) {}
+PD_REGISTER_KERNEL(batch_norm_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::BatchNormGradKernel,
+                   float,
+                   phi::dtype::float16) {}
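A note on the is_inplace path above (an illustrative reading, not asserted by the diff itself): when the forward pass ran in place, the x buffer holds the normalized output, and CalculateInvBNY, judging by its name and arguments, inverts the affine normalization y = (x - mean) * inv_std * scale + bias to recover x = (y - bias) / (scale * inv_std) + mean. With float16 inputs the kernel now performs this step in a float32 scratch buffer (x_fp32_data) filled via xpu::cast<XPUType, float>. A small NumPy sketch of that inversion, with layout and shapes assumed purely for illustration:

    import numpy as np

    # NCHW data with per-channel statistics (shapes are assumptions for the sketch).
    N, C, H, W = 2, 3, 4, 4
    rng = np.random.default_rng(0)
    x = rng.standard_normal((N, C, H, W)).astype(np.float32)
    scale = rng.standard_normal(C).astype(np.float32)
    bias = rng.standard_normal(C).astype(np.float32)

    mean = x.mean(axis=(0, 2, 3))
    inv_std = 1.0 / np.sqrt(x.var(axis=(0, 2, 3)) + 1e-5)

    def per_channel(v):
        # Broadcast a (C,) vector over an NCHW tensor.
        return v.reshape(1, C, 1, 1)

    # Forward affine normalization, then its inverse.
    y = (x - per_channel(mean)) * per_channel(inv_std) * per_channel(scale) + per_channel(bias)
    x_rec = (y - per_channel(bias)) / (per_channel(scale) * per_channel(inv_std)) + per_channel(mean)
    print(np.allclose(x, x_rec, atol=1e-5))  # True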
@@ -317,6 +317,11 @@ class XPUTestBatchNormGradOp(XPUOpTestWrapper):
             self.init_dtype()
             self.set_xpu()
             self.set_attrs()
+            self.rtol = 1e-5
+            self.atol = 1e-4
+            if self.dtype == np.float16:
+                self.rtol = 1e-2
+                self.atol = 1e-3
             if self.data_layout == "NHWC":
                 channel_size = self.shape[3]
@@ -451,7 +456,7 @@ class XPUTestBatchNormGradOp(XPUOpTestWrapper):
             outs = exe.run(program, feed=inputs, fetch_list=fetch_list)
             for id, name in enumerate(fetch_list):
                 np.testing.assert_allclose(
-                    outputs[name], outs[id], rtol=1e-05, atol=1e-4
+                    outputs[name], outs[id], rtol=self.rtol, atol=self.atol
                 )
...
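The test hunks above keep rtol=1e-5 / atol=1e-4 for float32 but switch to rtol=1e-2 / atol=1e-3 when self.dtype is np.float16, since half precision carries only about three significant decimal digits. A quick NumPy check of that scale (illustrative only):

    import numpy as np

    # float16 machine epsilon is ~9.77e-4, i.e. roughly 3 significant decimal digits,
    # which is why the float16 tests compare with rtol=1e-2 / atol=1e-3.
    print(np.finfo(np.float16).eps)        # 0.000977
    x = np.float32(1.2345678)
    print(abs(np.float16(x) - x) / x)      # relative rounding error on the order of 1e-4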