未验证 提交 9774f965 编写于 作者: Z Zhangjingyu06 提交者: GitHub

modify batch_norm and batch_norm_grad. *test=kunlun (#41976)

上级 c3b0b680
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
...@@ -38,15 +37,25 @@ class BatchNormXPUKernel : public framework::OpKernel<T> { ...@@ -38,15 +37,25 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
bool global_stats = test_mode || use_global_stats; bool global_stats = test_mode || use_global_stats;
const auto &data_layout_str = ctx.Attr<std::string>("data_layout"); const auto &data_layout_str = ctx.Attr<std::string>("data_layout");
const auto data_layout = framework::StringToDataLayout(data_layout_str); const auto data_layout = framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE_EQ(data_layout_str == "NCHW" || data_layout_str == "NHWC",
true,
platform::errors::InvalidArgument(
"The 'data_layout' attribute must be NCHW or NHWC. "
"But recevived 'data_layout' is [%s].",
data_layout_str));
const auto *x = ctx.Input<Tensor>("X"); const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims(); const auto &x_dims = x->dims();
int temp = x_dims[3]; PADDLE_ENFORCE_EQ(
temp = (x_dims.size() != 4) ? 1 : temp; x_dims.size() >= 2 && x_dims.size() <= 5, true,
bool is_nchw = (data_layout == DataLayout::kNCHW); platform::errors::InvalidArgument(
const int N = x_dims[0]; "The size of input's dimensions should be between 2 and 5"
const int C = is_nchw ? x_dims[1] : temp; "But received: the size of input's dimensions is [%d]",
const int H = is_nchw ? x_dims[2] : x_dims[1]; x_dims.size()));
const int W = is_nchw ? temp : x_dims[2];
int N, C, H, W, D;
ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
const auto *scale = ctx.Input<Tensor>("Scale"); const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias"); const auto *bias = ctx.Input<Tensor>("Bias");
const auto *x_data = x->data<T>(); const auto *x_data = x->data<T>();
...@@ -67,6 +76,7 @@ class BatchNormXPUKernel : public framework::OpKernel<T> { ...@@ -67,6 +76,7 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
saved_variance->mutable_data<float>(ctx.GetPlace()); saved_variance->mutable_data<float>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<DeviceContext>(); auto &dev_ctx = ctx.template device_context<DeviceContext>();
bool is_nchw = data_layout_str == "NCHW";
if (!global_stats) { if (!global_stats) {
auto *mean_out_data = mean_out->data<float>(); auto *mean_out_data = mean_out->data<float>();
...@@ -83,35 +93,29 @@ class BatchNormXPUKernel : public framework::OpKernel<T> { ...@@ -83,35 +93,29 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
&mom_cpu); &mom_cpu);
momentum = mom_tensor->data<float>()[0]; momentum = mom_tensor->data<float>()[0];
} }
if (C == 1) {
int r = xpu::batch_norm<T>(dev_ctx.x_context(), x_data, y_data, N, 1, H, int r = xpu::batch_norm<T>(dev_ctx.x_context(), x_data, y_data, N, C, H,
W, epsilon, momentum, scale_data, bias_data, W, epsilon, momentum, scale_data, bias_data,
saved_mean_data, saved_variance_data, saved_mean_data, saved_variance_data,
mean_out_data, variance_out_data, true); mean_out_data, variance_out_data, is_nchw);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
r, xpu::Error_t::SUCCESS, platform::errors::External(
platform::errors::External( "The batch_norm XPU API return wrong value[%d %s]",
"The batch_norm XPU API return wrong value[%d %s]", r, r, XPUAPIErrorMsg[r]));
XPUAPIErrorMsg[r]));
} else {
int r = xpu::batch_norm<T>(dev_ctx.x_context(), x_data, y_data, N, C, H,
W, epsilon, momentum, scale_data, bias_data,
saved_mean_data, saved_variance_data,
mean_out_data, variance_out_data, is_nchw);
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
platform::errors::External(
"The batch_norm XPU API return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
}
} else { } else {
PADDLE_ENFORCE_EQ(
data_layout_str == "NCHW", true,
platform::errors::InvalidArgument(
"The batch_norm_infer 'data_layout' attribute must be NCHW. "
"But recevived 'data_layout' is [%s].",
data_layout_str));
const auto *mean = ctx.Input<Tensor>("Mean"); const auto *mean = ctx.Input<Tensor>("Mean");
const auto *variance = ctx.Input<Tensor>("Variance"); const auto *variance = ctx.Input<Tensor>("Variance");
const auto *mean_data = mean->data<float>(); const auto *mean_data = mean->data<float>();
const auto *variance_data = variance->data<float>(); const auto *variance_data = variance->data<float>();
int r = xpu::batch_norm_infer(dev_ctx.x_context(), x_data, y_data, N, C, int r = xpu::batch_norm_infer(dev_ctx.x_context(), x_data, y_data, N, C,
H, W, epsilon, scale_data, bias_data, H, W, epsilon, scale_data, bias_data,
mean_data, variance_data, true); mean_data, variance_data, is_nchw);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS, r, xpu::Error_t::SUCCESS,
platform::errors::External( platform::errors::External(
...@@ -172,6 +176,13 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> { ...@@ -172,6 +176,13 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
const float epsilon = ctx.Attr<float>("epsilon"); const float epsilon = ctx.Attr<float>("epsilon");
const auto data_layout = framework::StringToDataLayout(data_layout_str); const auto data_layout = framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE_EQ(data_layout_str == "NCHW" || data_layout_str == "NHWC",
true,
platform::errors::InvalidArgument(
"The 'data_layout' attribute must be NCHW or NHWC. "
"But recevived 'data_layout' is [%s].",
data_layout_str));
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X")); auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale")); auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias")); auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
...@@ -204,13 +215,15 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> { ...@@ -204,13 +215,15 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
} }
const auto &x_dims = x->dims(); const auto &x_dims = x->dims();
int temp = x_dims[3]; PADDLE_ENFORCE_EQ(
temp = (x_dims.size() != 4) ? 1 : temp; x_dims.size() >= 2 && x_dims.size() <= 5, true,
bool is_nchw = (data_layout == DataLayout::kNCHW); platform::errors::InvalidArgument(
const int N = x_dims[0]; "The size of input's dimensions should be between 2 and 5"
const int C = is_nchw ? x_dims[1] : temp; "But received: the size of input's dimensions is [%d]",
const int H = is_nchw ? x_dims[2] : x_dims[1]; x_dims.size()));
const int W = is_nchw ? temp : x_dims[2];
int N, C, H, W, D;
ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
const auto *x_data = x->data<T>(); const auto *x_data = x->data<T>();
const auto *d_y_data = d_y->data<T>(); const auto *d_y_data = d_y->data<T>();
...@@ -235,42 +248,45 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> { ...@@ -235,42 +248,45 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
"the size of scale's dimensions is [%d], the dimensions of scale " "the size of scale's dimensions is [%d], the dimensions of scale "
"is [%s].", "is [%s].",
scale->dims().size(), scale->dims())); scale->dims().size(), scale->dims()));
PADDLE_ENFORCE_EQ(
scale->dims()[0], C,
platform::errors::InvalidArgument(
"The first dimension of scale must equal to Channels[%d]. But "
"received: the first dimension of scale is [%d]",
C, scale->dims()[0]));
auto &dev_ctx = ctx.template device_context<DeviceContext>(); auto &dev_ctx = ctx.template device_context<DeviceContext>();
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
const T *mean_data = nullptr; const auto *batch_mean = ctx.Input<Tensor>("SavedMean");
const T *inv_var_data = nullptr; const auto *batch_inv_std = ctx.Input<Tensor>("SavedVariance");
const auto *global_mean = ctx.Input<Tensor>("Mean");
const auto *global_var = ctx.Input<Tensor>("Variance");
// TODO(guozibin): hadle the situation case of N * H * W = 1 // TODO(guozibin): hadle the situation case of N * H * W = 1
if (!use_global_stats) {
const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
// SavedVariance have been reverted in forward operator
const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
mean_data = saved_mean->data<float>();
inv_var_data = saved_inv_variance->data<float>();
} else {
const auto *running_mean = ctx.Input<Tensor>("Mean");
const auto *running_variance = ctx.Input<Tensor>("Variance");
mean_data = running_mean->data<float>();
inv_var_data = running_variance->data<float>();
float *running_inv_var_data =
RAII_GUARD.alloc_l3_or_gm<float>(running_variance->numel());
float *epsilon_data = RAII_GUARD.alloc_l3_or_gm<float>(1);
int r1 = calculate_inv_var(dev_ctx.x_context(), inv_var_data, epsilon, C,
epsilon_data, running_inv_var_data);
PADDLE_ENFORCE_EQ(r1, XPU_SUCCESS, platform::errors::External(
"XPU API(batch_norm_grad "
"calculate_inv_var function) "
"return wrong value[%d %s]",
r1, XPUAPIErrorMsg[r1]));
inv_var_data = running_inv_var_data;
}
if (is_inplace) { if (is_inplace) {
float *global_inv_std_data = nullptr;
if (use_global_stats) {
global_inv_std_data =
RAII_GUARD.alloc_l3_or_gm<float>(global_var->numel());
float *epsilon_data = RAII_GUARD.alloc_l3_or_gm<float>(1);
int r1 =
calculate_inv_var(dev_ctx.x_context(), global_var->data<float>(),
epsilon, C, epsilon_data, global_inv_std_data);
PADDLE_ENFORCE_EQ(r1, XPU_SUCCESS, platform::errors::External(
"XPU API(batch_norm_grad "
"calculate_inv_var function) "
"return wrong value[%d %s]",
r1, XPUAPIErrorMsg[r1]));
}
auto px = *x; auto px = *x;
auto *inv_std_data =
use_global_stats ? global_inv_std_data : batch_inv_std->data<float>();
auto mean_data = use_global_stats ? global_mean->data<float>()
: batch_mean->data<float>();
int r2 = calculate_inv_BN_Y( int r2 = calculate_inv_BN_Y(
dev_ctx.x_context(), px.mutable_data<T>(ctx.GetPlace()), dev_ctx.x_context(), px.mutable_data<T>(ctx.GetPlace()),
scale->data<float>(), bias->data<float>(), mean_data, inv_var_data, N, scale->data<float>(), bias->data<float>(), mean_data, inv_std_data, N,
C, H * W, x->data<T>()); C, H * W, x->data<T>());
PADDLE_ENFORCE_EQ(r2, XPU_SUCCESS, platform::errors::External( PADDLE_ENFORCE_EQ(r2, XPU_SUCCESS, platform::errors::External(
"XPU API(batch_norm_grad " "XPU API(batch_norm_grad "
...@@ -278,19 +294,29 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> { ...@@ -278,19 +294,29 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
"return wrong value[%d %s]", "return wrong value[%d %s]",
r2, XPUAPIErrorMsg[r2])); r2, XPUAPIErrorMsg[r2]));
} }
if (!d_x) {
d_x_data = RAII_GUARD.alloc_l3_or_gm<T>(x->numel());
}
if (!d_scale) {
d_scale_data = RAII_GUARD.alloc_l3_or_gm<float>(C);
}
if (!d_bias_data) {
d_bias_data = RAII_GUARD.alloc_l3_or_gm<float>(C);
}
int r3 = xpu::batch_norm_grad<T>( int r3;
dev_ctx.x_context(), x_data, d_y_data, d_x_data, N, C, H, W, scale_data, bool is_nchw = data_layout_str == "NCHW";
mean_data, inv_var_data, d_scale_data, d_bias_data, is_nchw); if (use_global_stats) {
r3 = xpu::batch_norm_grad<T>(
dev_ctx.x_context(), x_data, d_y_data, d_x_data, N, C, H, W,
scale_data, nullptr, nullptr, d_scale_data, d_bias_data, is_nchw,
global_mean->data<float>(), global_var->data<float>(), epsilon);
} else {
if (!d_x) {
d_x_data = RAII_GUARD.alloc_l3_or_gm<T>(x->numel());
}
if (!d_scale) {
d_scale_data = RAII_GUARD.alloc_l3_or_gm<float>(C);
}
if (!d_bias_data) {
d_bias_data = RAII_GUARD.alloc_l3_or_gm<float>(C);
}
r3 = xpu::batch_norm_grad<T>(
dev_ctx.x_context(), x_data, d_y_data, d_x_data, N, C, H, W,
scale_data, batch_mean->data<float>(), batch_inv_std->data<float>(),
d_scale_data, d_bias_data, is_nchw);
}
PADDLE_ENFORCE_EQ(r3, XPU_SUCCESS, platform::errors::External( PADDLE_ENFORCE_EQ(r3, XPU_SUCCESS, platform::errors::External(
"XPU API(batch_norm_grad) return " "XPU API(batch_norm_grad) return "
"wrong value[%d %s]", "wrong value[%d %s]",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册