Unverified commit 637f27c6, authored by ceci3, committed by GitHub

fix bn grad compute when x.stop_gradient=True (#34102)

* fix bn

* fix

* add unittest

* fix cpu
Parent ff97dea4
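For context, the situation this commit handles can be reproduced in a few lines of dygraph code (a minimal sketch, assuming the Paddle 2.x paddle.nn.BatchNorm1D API; the shapes are illustrative and not taken from the PR). When the batch norm input has stop_gradient=True, the framework never allocates X@GRAD, so the grad kernels receive a null d_x and must still produce Scale@GRAD and Bias@GRAD:

# Minimal repro sketch (assumes Paddle 2.x dygraph; shapes are illustrative).
import paddle

x = paddle.rand([4, 8, 16])   # NCL input
x.stop_gradient = True        # no gradient requested for x -> X@GRAD stays null
bn = paddle.nn.BatchNorm1D(8)

y = bn(x)
y.sum().backward()            # batch_norm_grad runs with d_x == nullptr

print(x.grad)                 # None: X@GRAD was never allocated
print(bn.weight.grad.shape)   # [8]: Scale@GRAD is still computed
print(bn.bias.grad.shape)     # [8]: Bias@GRAD is still computed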
@@ -464,11 +464,9 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
"BatchNormGrad");
// check output
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "BatchNormGrad");
const bool has_scale_grad = ctx->HasOutput(framework::GradVarName("Scale"));
const bool has_bias_grad = ctx->HasOutput(framework::GradVarName("Bias"));
const bool has_x_grad = ctx->HasOutput(framework::GradVarName("X"));
PADDLE_ENFORCE_EQ((has_scale_grad == has_bias_grad), true,
platform::errors::NotFound(
@@ -496,12 +494,14 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
? x_dims[1]
: x_dims[x_dims.size() - 1]);
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
// has_scale_grad == has_bias_grad, judge has_scale_grad is enough
if (has_scale_grad) {
ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
}
if (has_x_grad) {
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
}
framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
@@ -596,16 +596,21 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
if (ctx.HasInput("Y")) {
x = ctx.Input<Tensor>("Y");
is_inplace = true;
// if the input of batch norm is stop_gradient, d_x is null.
if (d_x) {
PADDLE_ENFORCE_EQ(d_x, d_y,
platform::errors::InvalidArgument(
"X@GRAD and Y@GRAD not inplace in inplace mode"));
}
} else {
x = ctx.Input<Tensor>("X");
is_inplace = false;
PADDLE_ENFORCE_NE(d_x, d_y,
platform::errors::InvalidArgument(
if (d_x) {
PADDLE_ENFORCE_NE(
d_x, d_y, platform::errors::InvalidArgument(
"X@GRAD and Y@GRAD inplaced in non-inplace mode"));
}
}
// Get the size for each dimension.
// NCHW [batch_size, in_channels, in_height, in_width]
@@ -629,7 +634,9 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
const int sample_size = x->numel() / N / C;
// init output
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace());
}
const T *mean_data = saved_mean->data<T>();
const T *inv_var_data = saved_inv_variance->data<T>();
@@ -673,7 +680,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
d_scale_arr.setZero();
}
if ((N * sample_size) == 1 && !use_global_stats) {
if (d_x && (N * sample_size) == 1 && !use_global_stats) {
framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
return;
}
@@ -718,8 +725,6 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
}
ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C);
ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), sample_size, N * C);
EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()),
sample_size, N * C);
for (int nc = 0; nc < N * C; ++nc) {
int c = nc % C;
@@ -734,6 +739,9 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr;
}
if (d_x) {
EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()),
sample_size, N * C);
if (!use_global_stats) {
for (int nc = 0; nc < N * C; ++nc) {
int c = nc % C;
@@ -741,7 +749,8 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
scale_inv_var_nhw(c) *
(d_y_arr.col(nc) * N * sample_size - dy_sum_arr(c) -
(x_arr.col(nc) - mean_arr[c]) *
dy_mul_x_sub_mean_mul_invstd_sum_arr(c) * inv_var_arr(c));
dy_mul_x_sub_mean_mul_invstd_sum_arr(c) *
inv_var_arr(c));
}
} else {
for (int nc = 0; nc < N * C; ++nc) {
@@ -749,6 +758,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
d_x_arr.col(nc) = scale_inv_var_nhw(c) * d_y_arr.col(nc);
}
}
}
break;
}
case DataLayout::kNHWC: {
@@ -765,8 +775,6 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
}
ConstEigenArrayMap<T> x_arr(x->data<T>(), C, N * sample_size);
ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), C, N * sample_size);
EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()), C,
N * sample_size);
for (int nhw = 0; nhw < N * sample_size; ++nhw) {
dy_sum_arr += d_y_arr.col(nhw);
@@ -779,6 +787,9 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr;
}
if (d_x) {
EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()), C,
N * sample_size);
if (!use_global_stats) {
for (int nhw = 0; nhw < N * sample_size; ++nhw) {
d_x_arr.col(nhw) =
@@ -792,6 +803,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
d_x_arr.col(nhw) = scale_inv_var_nhw * d_y_arr.col(nhw);
}
}
}
break;
}
default:
......
@@ -840,16 +840,20 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
if (ctx.HasInput("Y")) {
x = ctx.Input<Tensor>("Y");
is_inplace = true;
if (d_x) {
PADDLE_ENFORCE_EQ(d_x, d_y,
platform::errors::InvalidArgument(
"X@GRAD and Y@GRAD not inplace in inplace mode"));
}
} else {
x = ctx.Input<Tensor>("X");
is_inplace = false;
PADDLE_ENFORCE_NE(d_x, d_y,
platform::errors::InvalidArgument(
if (d_x) {
PADDLE_ENFORCE_NE(
d_x, d_y, platform::errors::InvalidArgument(
"X@GRAD and Y@GRAD inplaced in non-inplace mode"));
}
}
const bool is_test = ctx.Attr<bool>("is_test");
use_global_stats = is_test || use_global_stats;
@@ -867,7 +871,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
// init output
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace());
}
if (d_scale && d_bias) {
d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
@@ -908,7 +914,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
Tensor transformed_x(x->type());
Tensor transformed_d_y(d_y->type());
Tensor transformed_d_x(d_x->type());
Tensor transformed_d_x;
if (data_layout == DataLayout::kNHWC &&
compute_format == DataLayout::kNCHW) {
VLOG(3) << "Transform input tensor from NHWC to NCHW.";
@@ -920,13 +926,17 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
&transformed_d_y);
TransToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_y,
&transformed_d_y);
if (d_x) {
ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_x,
&transformed_d_x);
}
} else {
transformed_x.ShareDataWith(*x);
transformed_d_y.ShareDataWith(*d_y);
if (d_x) {
transformed_d_x.ShareDataWith(*d_x);
}
}
std::vector<int> dims;
std::vector<int> strides;
@@ -954,7 +964,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
if (!use_global_stats) {
if ((N * H * W * D) == 1) {
if (d_x) {
framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
}
math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
functor;
functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
@@ -1042,7 +1054,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
}
// This branch calls CUDNN APIs
if (d_scale && d_bias) {
if (d_x && d_scale && d_bias) {
bool called = false;
#if CUDNN_VERSION_MIN(7, 4, 1)
called = true;
@@ -1187,6 +1199,15 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
saved_mean_data, x->data<T>(), saved_var_data, C, N, H * W * D,
d_x->data<T>());
}
if (d_scale && d_bias) {
KeBNBackwardScaleBias<
T, block,
framework::DataLayout::kNCHW><<<grid2, block, 0, stream>>>(
d_y->data<T>(), x->data<T>(), saved_mean_data, saved_var_data,
epsilon, N, C, H * W * D,
d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>());
}
} else {
if (d_x) {
BNBackwardData<T, block, framework::DataLayout::kNHWC><<<
@@ -1195,6 +1216,15 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
saved_mean_data, x->data<T>(), saved_var_data, C, N, H * W * D,
d_x->data<T>());
}
if (d_scale && d_bias) {
KeBNBackwardScaleBias<
T, block,
framework::DataLayout::kNHWC><<<grid2, block, 0, stream>>>(
d_y->data<T>(), x->data<T>(), saved_mean_data, saved_var_data,
epsilon, N, C, H * W * D,
d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>());
}
}
}
......
@@ -515,6 +515,13 @@ class TestBatchNormOpTrainingCase2(TestBatchNormOpTraining):
os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = "1"
class TestBatchNormOpTrainingCase3(TestBatchNormOpTraining):
def init_test_case(self):
self.use_global_stats = False
self.no_grad_set = set(['x@GRAD'])
self.fetch_list = ['y', 'mean', 'variance', 'scale@GRAD', 'bias@GRAD']
class TestBatchNormOpTrainingMomentumVariable(TestBatchNormOpTraining):
def init_test_case(self):
self.use_momentum_variable = True
......
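The new TestBatchNormOpTrainingCase3 covers this path in the operator test suite by putting x@GRAD into no_grad_set and fetching only scale@GRAD and bias@GRAD. Outside the test harness, the equivalent check looks roughly like the following static-graph sketch (assumptions: paddle.static APIs under Paddle 2.x; program names and shapes are illustrative, not taken from the PR):

# Static-graph sketch of the scenario the added unit test exercises
# (assumes paddle.static under Paddle 2.x; names/shapes are illustrative).
import numpy as np
import paddle

paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name="x", shape=[4, 8, 6, 6], dtype="float32")
    x.stop_gradient = True            # x@GRAD is not added to the backward program
    y = paddle.static.nn.batch_norm(x)
    loss = paddle.mean(y)
    paddle.static.append_backward(loss)
    # Only the parameter gradients exist; there is no x@GRAD variable to fetch.
    grad_names = [p.name + "@GRAD" for p in main.all_parameters()]

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)
outs = exe.run(main,
               feed={"x": np.random.rand(4, 8, 6, 6).astype("float32")},
               fetch_list=grad_names)
print([o.shape for o in outs])        # [(8,), (8,)]: scale and bias gradients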