Unverified commit 637f27c6, authored by ceci3, committed by GitHub

fix bn grad compute when x.stop_gradient=True (#34102)

* fix bn

* fix

* add unittest

* fix cpu
Parent ff97dea4
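
The commit makes the d_x (X@GRAD) output of batch_norm_grad optional, so scale and bias gradients can still be computed when the input itself needs no gradient. A minimal dygraph sketch of the scenario being fixed, assuming the paddle 2.x imperative API (layer sizes and variable names are illustrative only):

import paddle

x = paddle.randn([2, 3, 8, 8])
x.stop_gradient = True               # no X@GRAD is requested for the BN input
bn = paddle.nn.BatchNorm2D(3)

y = bn(x)
y.sum().backward()                   # backward must now run without a d_x output

print(x.grad)                        # None: the input gradient is skipped
print(bn.weight.grad.shape)          # scale gradient is still produced, shape [3]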
@@ -464,11 +464,9 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
                  "BatchNormGrad");
 
   // check output
-  OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
-                 framework::GradVarName("X"), "BatchNormGrad");
-
   const bool has_scale_grad = ctx->HasOutput(framework::GradVarName("Scale"));
   const bool has_bias_grad = ctx->HasOutput(framework::GradVarName("Bias"));
+  const bool has_x_grad = ctx->HasOutput(framework::GradVarName("X"));
 
   PADDLE_ENFORCE_EQ((has_scale_grad == has_bias_grad), true,
                     platform::errors::NotFound(
@@ -496,12 +494,14 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
            ? x_dims[1]
            : x_dims[x_dims.size() - 1]);
 
-  ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
   // has_scale_grad == has_bias_grad, judge has_scale_grad is enough
   if (has_scale_grad) {
     ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
     ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
   }
+  if (has_x_grad) {
+    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+  }
 }
 
 framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
@@ -596,15 +596,20 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
     if (ctx.HasInput("Y")) {
       x = ctx.Input<Tensor>("Y");
       is_inplace = true;
-      PADDLE_ENFORCE_EQ(d_x, d_y,
-                        platform::errors::InvalidArgument(
-                            "X@GRAD and Y@GRAD not inplace in inplace mode"));
+      // if the input of batch norm is stop_gradient, d_x is null.
+      if (d_x) {
+        PADDLE_ENFORCE_EQ(d_x, d_y,
+                          platform::errors::InvalidArgument(
+                              "X@GRAD and Y@GRAD not inplace in inplace mode"));
+      }
     } else {
       x = ctx.Input<Tensor>("X");
       is_inplace = false;
-      PADDLE_ENFORCE_NE(d_x, d_y,
-                        platform::errors::InvalidArgument(
-                            "X@GRAD and Y@GRAD inplaced in non-inplace mode"));
+      if (d_x) {
+        PADDLE_ENFORCE_NE(
+            d_x, d_y, platform::errors::InvalidArgument(
+                          "X@GRAD and Y@GRAD inplaced in non-inplace mode"));
+      }
     }
 
     // Get the size for each dimension.
@@ -629,7 +634,9 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
     const int sample_size = x->numel() / N / C;
 
     // init output
-    d_x->mutable_data<T>(ctx.GetPlace());
+    if (d_x) {
+      d_x->mutable_data<T>(ctx.GetPlace());
+    }
 
     const T *mean_data = saved_mean->data<T>();
     const T *inv_var_data = saved_inv_variance->data<T>();
@@ -673,7 +680,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
       d_scale_arr.setZero();
     }
 
-    if ((N * sample_size) == 1 && !use_global_stats) {
+    if (d_x && (N * sample_size) == 1 && !use_global_stats) {
       framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
       return;
     }
@@ -718,8 +725,6 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
         }
         ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C);
         ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), sample_size, N * C);
-        EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()),
-                                 sample_size, N * C);
 
         for (int nc = 0; nc < N * C; ++nc) {
           int c = nc % C;
@@ -734,19 +739,24 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
           d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr;
         }
 
-        if (!use_global_stats) {
-          for (int nc = 0; nc < N * C; ++nc) {
-            int c = nc % C;
-            d_x_arr.col(nc) =
-                scale_inv_var_nhw(c) *
-                (d_y_arr.col(nc) * N * sample_size - dy_sum_arr(c) -
-                 (x_arr.col(nc) - mean_arr[c]) *
-                     dy_mul_x_sub_mean_mul_invstd_sum_arr(c) * inv_var_arr(c));
-          }
-        } else {
-          for (int nc = 0; nc < N * C; ++nc) {
-            int c = nc % C;
-            d_x_arr.col(nc) = scale_inv_var_nhw(c) * d_y_arr.col(nc);
-          }
-        }
+        if (d_x) {
+          EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()),
+                                   sample_size, N * C);
+          if (!use_global_stats) {
+            for (int nc = 0; nc < N * C; ++nc) {
+              int c = nc % C;
+              d_x_arr.col(nc) =
+                  scale_inv_var_nhw(c) *
+                  (d_y_arr.col(nc) * N * sample_size - dy_sum_arr(c) -
+                   (x_arr.col(nc) - mean_arr[c]) *
+                       dy_mul_x_sub_mean_mul_invstd_sum_arr(c) *
+                       inv_var_arr(c));
+            }
+          } else {
+            for (int nc = 0; nc < N * C; ++nc) {
+              int c = nc % C;
+              d_x_arr.col(nc) = scale_inv_var_nhw(c) * d_y_arr.col(nc);
+            }
+          }
+        }
         break;
@@ -765,8 +775,6 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
         }
         ConstEigenArrayMap<T> x_arr(x->data<T>(), C, N * sample_size);
         ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), C, N * sample_size);
-        EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()), C,
-                                 N * sample_size);
 
         for (int nhw = 0; nhw < N * sample_size; ++nhw) {
           dy_sum_arr += d_y_arr.col(nhw);
@@ -779,17 +787,21 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
           d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr;
         }
 
-        if (!use_global_stats) {
-          for (int nhw = 0; nhw < N * sample_size; ++nhw) {
-            d_x_arr.col(nhw) =
-                scale_inv_var_nhw *
-                (d_y_arr.col(nhw) * N * sample_size - dy_sum_arr -
-                 (x_arr.col(nhw) - mean_arr) *
-                     dy_mul_x_sub_mean_mul_invstd_sum_arr * inv_var_arr);
-          }
-        } else {
-          for (int nhw = 0; nhw < N * sample_size; ++nhw) {
-            d_x_arr.col(nhw) = scale_inv_var_nhw * d_y_arr.col(nhw);
-          }
-        }
+        if (d_x) {
+          EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()), C,
+                                   N * sample_size);
+          if (!use_global_stats) {
+            for (int nhw = 0; nhw < N * sample_size; ++nhw) {
+              d_x_arr.col(nhw) =
+                  scale_inv_var_nhw *
+                  (d_y_arr.col(nhw) * N * sample_size - dy_sum_arr -
+                   (x_arr.col(nhw) - mean_arr) *
+                       dy_mul_x_sub_mean_mul_invstd_sum_arr * inv_var_arr);
+            }
+          } else {
+            for (int nhw = 0; nhw < N * sample_size; ++nhw) {
+              d_x_arr.col(nhw) = scale_inv_var_nhw * d_y_arr.col(nhw);
+            }
+          }
+        }
         break;
......
@@ -840,15 +840,19 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     if (ctx.HasInput("Y")) {
       x = ctx.Input<Tensor>("Y");
       is_inplace = true;
-      PADDLE_ENFORCE_EQ(d_x, d_y,
-                        platform::errors::InvalidArgument(
-                            "X@GRAD and Y@GRAD not inplace in inplace mode"));
+      if (d_x) {
+        PADDLE_ENFORCE_EQ(d_x, d_y,
+                          platform::errors::InvalidArgument(
+                              "X@GRAD and Y@GRAD not inplace in inplace mode"));
+      }
     } else {
       x = ctx.Input<Tensor>("X");
       is_inplace = false;
-      PADDLE_ENFORCE_NE(d_x, d_y,
-                        platform::errors::InvalidArgument(
-                            "X@GRAD and Y@GRAD inplaced in non-inplace mode"));
+      if (d_x) {
+        PADDLE_ENFORCE_NE(
+            d_x, d_y, platform::errors::InvalidArgument(
+                          "X@GRAD and Y@GRAD inplaced in non-inplace mode"));
+      }
     }
 
     const bool is_test = ctx.Attr<bool>("is_test");
@@ -867,7 +871,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
 
     // init output
-    d_x->mutable_data<T>(ctx.GetPlace());
+    if (d_x) {
+      d_x->mutable_data<T>(ctx.GetPlace());
+    }
 
     if (d_scale && d_bias) {
       d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
@@ -908,7 +914,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
 
     Tensor transformed_x(x->type());
     Tensor transformed_d_y(d_y->type());
-    Tensor transformed_d_x(d_x->type());
+    Tensor transformed_d_x;
     if (data_layout == DataLayout::kNHWC &&
         compute_format == DataLayout::kNCHW) {
       VLOG(3) << "Transform input tensor from NHWC to NCHW.";
@@ -920,12 +926,16 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
                                                            &transformed_d_y);
       TransToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_y,
                                                           &transformed_d_y);
-      ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_x,
-                                                           &transformed_d_x);
+      if (d_x) {
+        ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_x,
+                                                             &transformed_d_x);
+      }
     } else {
       transformed_x.ShareDataWith(*x);
       transformed_d_y.ShareDataWith(*d_y);
-      transformed_d_x.ShareDataWith(*d_x);
+      if (d_x) {
+        transformed_d_x.ShareDataWith(*d_x);
+      }
     }
 
     std::vector<int> dims;
@@ -954,7 +964,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
 
     if (!use_global_stats) {
       if ((N * H * W * D) == 1) {
-        framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
+        if (d_x) {
+          framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
+        }
         math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
             functor;
         functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
@@ -1042,7 +1054,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
       }
 
       // This branch calls CUDNN APIs
-      if (d_scale && d_bias) {
+      if (d_x && d_scale && d_bias) {
         bool called = false;
 #if CUDNN_VERSION_MIN(7, 4, 1)
         called = true;
@@ -1187,6 +1199,15 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
                 saved_mean_data, x->data<T>(), saved_var_data, C, N, H * W * D,
                 d_x->data<T>());
           }
+          if (d_scale && d_bias) {
+            KeBNBackwardScaleBias<
+                T, block,
+                framework::DataLayout::kNCHW><<<grid2, block, 0, stream>>>(
+                d_y->data<T>(), x->data<T>(), saved_mean_data, saved_var_data,
+                epsilon, N, C, H * W * D,
+                d_scale->data<BatchNormParamType<T>>(),
+                d_bias->data<BatchNormParamType<T>>());
+          }
         } else {
           if (d_x) {
             BNBackwardData<T, block, framework::DataLayout::kNHWC><<<
@@ -1195,6 +1216,15 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
                 saved_mean_data, x->data<T>(), saved_var_data, C, N, H * W * D,
                 d_x->data<T>());
           }
+          if (d_scale && d_bias) {
+            KeBNBackwardScaleBias<
+                T, block,
+                framework::DataLayout::kNHWC><<<grid2, block, 0, stream>>>(
+                d_y->data<T>(), x->data<T>(), saved_mean_data, saved_var_data,
+                epsilon, N, C, H * W * D,
+                d_scale->data<BatchNormParamType<T>>(),
+                d_bias->data<BatchNormParamType<T>>());
+          }
         }
       }
......
@@ -515,6 +515,13 @@ class TestBatchNormOpTrainingCase2(TestBatchNormOpTraining):
         os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = "1"
 
 
+class TestBatchNormOpTrainingCase3(TestBatchNormOpTraining):
+    def init_test_case(self):
+        self.use_global_stats = False
+        self.no_grad_set = set(['x@GRAD'])
+        self.fetch_list = ['y', 'mean', 'variance', 'scale@GRAD', 'bias@GRAD']
+
+
 class TestBatchNormOpTrainingMomentumVariable(TestBatchNormOpTraining):
     def init_test_case(self):
         self.use_momentum_variable = True
......
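
The new TestBatchNormOpTrainingCase3 exercises the same path from the static graph: 'x@GRAD' is placed in no_grad_set, so only the scale and bias gradients are built and fetched. A rough user-level equivalent, assuming the paddle 2.x static API (shapes and the random feed data are illustrative only):

import numpy as np
import paddle

paddle.enable_static()
x = paddle.static.data(name='x', shape=[2, 3, 8, 8], dtype='float32')
x.stop_gradient = True                    # data vars need no gradient; X@GRAD is omitted
y = paddle.static.nn.batch_norm(x, is_test=False)
loss = paddle.mean(y)
paddle.static.append_backward(loss)       # backward graph only has scale/bias grads for BN

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
loss_val, = exe.run(feed={'x': np.random.rand(2, 3, 8, 8).astype('float32')},
                    fetch_list=[loss])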