Unverified commit 637f27c6, authored by ceci3, committed by GitHub

fix bn grad compute when x.stop_gradient=True (#34102)

* fix bn

* fix

* add unittest

* fix cpu
Parent commit: ff97dea4
...@@ -464,11 +464,9 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -464,11 +464,9 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
"BatchNormGrad"); "BatchNormGrad");
// check output // check output
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "BatchNormGrad");
const bool has_scale_grad = ctx->HasOutput(framework::GradVarName("Scale")); const bool has_scale_grad = ctx->HasOutput(framework::GradVarName("Scale"));
const bool has_bias_grad = ctx->HasOutput(framework::GradVarName("Bias")); const bool has_bias_grad = ctx->HasOutput(framework::GradVarName("Bias"));
const bool has_x_grad = ctx->HasOutput(framework::GradVarName("X"));
PADDLE_ENFORCE_EQ((has_scale_grad == has_bias_grad), true, PADDLE_ENFORCE_EQ((has_scale_grad == has_bias_grad), true,
platform::errors::NotFound( platform::errors::NotFound(
...@@ -496,12 +494,14 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -496,12 +494,14 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
? x_dims[1] ? x_dims[1]
: x_dims[x_dims.size() - 1]); : x_dims[x_dims.size() - 1]);
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
// has_scale_grad == has_bias_grad, judge has_scale_grad is enough // has_scale_grad == has_bias_grad, judge has_scale_grad is enough
if (has_scale_grad) { if (has_scale_grad) {
ctx->SetOutputDim(framework::GradVarName("Scale"), {C}); ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
ctx->SetOutputDim(framework::GradVarName("Bias"), {C}); ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
} }
if (has_x_grad) {
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
} }
framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
...@@ -596,16 +596,21 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -596,16 +596,21 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
if (ctx.HasInput("Y")) { if (ctx.HasInput("Y")) {
x = ctx.Input<Tensor>("Y"); x = ctx.Input<Tensor>("Y");
is_inplace = true; is_inplace = true;
// if the input of batch norm is stop_gradient, d_x is null.
if (d_x) {
PADDLE_ENFORCE_EQ(d_x, d_y, PADDLE_ENFORCE_EQ(d_x, d_y,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"X@GRAD and Y@GRAD not inplace in inplace mode")); "X@GRAD and Y@GRAD not inplace in inplace mode"));
}
} else { } else {
x = ctx.Input<Tensor>("X"); x = ctx.Input<Tensor>("X");
is_inplace = false; is_inplace = false;
PADDLE_ENFORCE_NE(d_x, d_y, if (d_x) {
platform::errors::InvalidArgument( PADDLE_ENFORCE_NE(
d_x, d_y, platform::errors::InvalidArgument(
"X@GRAD and Y@GRAD inplaced in non-inplace mode")); "X@GRAD and Y@GRAD inplaced in non-inplace mode"));
} }
}
// Get the size for each dimension. // Get the size for each dimension.
// NCHW [batch_size, in_channels, in_height, in_width] // NCHW [batch_size, in_channels, in_height, in_width]
...@@ -629,7 +634,9 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -629,7 +634,9 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
const int sample_size = x->numel() / N / C; const int sample_size = x->numel() / N / C;
// init output // init output
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace()); d_x->mutable_data<T>(ctx.GetPlace());
}
const T *mean_data = saved_mean->data<T>(); const T *mean_data = saved_mean->data<T>();
const T *inv_var_data = saved_inv_variance->data<T>(); const T *inv_var_data = saved_inv_variance->data<T>();
...@@ -673,7 +680,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -673,7 +680,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
d_scale_arr.setZero(); d_scale_arr.setZero();
} }
if ((N * sample_size) == 1 && !use_global_stats) { if (d_x && (N * sample_size) == 1 && !use_global_stats) {
framework::TensorCopy(*d_y, ctx.GetPlace(), d_x); framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
return; return;
} }
...@@ -718,8 +725,6 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -718,8 +725,6 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
} }
ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C); ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C);
ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), sample_size, N * C); ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), sample_size, N * C);
EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()),
sample_size, N * C);
for (int nc = 0; nc < N * C; ++nc) { for (int nc = 0; nc < N * C; ++nc) {
int c = nc % C; int c = nc % C;
...@@ -734,6 +739,9 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -734,6 +739,9 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr; d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr;
} }
if (d_x) {
EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()),
sample_size, N * C);
if (!use_global_stats) { if (!use_global_stats) {
for (int nc = 0; nc < N * C; ++nc) { for (int nc = 0; nc < N * C; ++nc) {
int c = nc % C; int c = nc % C;
...@@ -741,7 +749,8 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -741,7 +749,8 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
scale_inv_var_nhw(c) * scale_inv_var_nhw(c) *
(d_y_arr.col(nc) * N * sample_size - dy_sum_arr(c) - (d_y_arr.col(nc) * N * sample_size - dy_sum_arr(c) -
(x_arr.col(nc) - mean_arr[c]) * (x_arr.col(nc) - mean_arr[c]) *
dy_mul_x_sub_mean_mul_invstd_sum_arr(c) * inv_var_arr(c)); dy_mul_x_sub_mean_mul_invstd_sum_arr(c) *
inv_var_arr(c));
} }
} else { } else {
for (int nc = 0; nc < N * C; ++nc) { for (int nc = 0; nc < N * C; ++nc) {
...@@ -749,6 +758,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -749,6 +758,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
d_x_arr.col(nc) = scale_inv_var_nhw(c) * d_y_arr.col(nc); d_x_arr.col(nc) = scale_inv_var_nhw(c) * d_y_arr.col(nc);
} }
} }
}
break; break;
} }
case DataLayout::kNHWC: { case DataLayout::kNHWC: {
...@@ -765,8 +775,6 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -765,8 +775,6 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
} }
ConstEigenArrayMap<T> x_arr(x->data<T>(), C, N * sample_size); ConstEigenArrayMap<T> x_arr(x->data<T>(), C, N * sample_size);
ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), C, N * sample_size); ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), C, N * sample_size);
EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()), C,
N * sample_size);
for (int nhw = 0; nhw < N * sample_size; ++nhw) { for (int nhw = 0; nhw < N * sample_size; ++nhw) {
dy_sum_arr += d_y_arr.col(nhw); dy_sum_arr += d_y_arr.col(nhw);
...@@ -779,6 +787,9 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -779,6 +787,9 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr; d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr;
} }
if (d_x) {
EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()), C,
N * sample_size);
if (!use_global_stats) { if (!use_global_stats) {
for (int nhw = 0; nhw < N * sample_size; ++nhw) { for (int nhw = 0; nhw < N * sample_size; ++nhw) {
d_x_arr.col(nhw) = d_x_arr.col(nhw) =
...@@ -792,6 +803,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -792,6 +803,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
d_x_arr.col(nhw) = scale_inv_var_nhw * d_y_arr.col(nhw); d_x_arr.col(nhw) = scale_inv_var_nhw * d_y_arr.col(nhw);
} }
} }
}
break; break;
} }
default: default:
......
...@@ -840,16 +840,20 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -840,16 +840,20 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
if (ctx.HasInput("Y")) { if (ctx.HasInput("Y")) {
x = ctx.Input<Tensor>("Y"); x = ctx.Input<Tensor>("Y");
is_inplace = true; is_inplace = true;
if (d_x) {
PADDLE_ENFORCE_EQ(d_x, d_y, PADDLE_ENFORCE_EQ(d_x, d_y,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"X@GRAD and Y@GRAD not inplace in inplace mode")); "X@GRAD and Y@GRAD not inplace in inplace mode"));
}
} else { } else {
x = ctx.Input<Tensor>("X"); x = ctx.Input<Tensor>("X");
is_inplace = false; is_inplace = false;
PADDLE_ENFORCE_NE(d_x, d_y, if (d_x) {
platform::errors::InvalidArgument( PADDLE_ENFORCE_NE(
d_x, d_y, platform::errors::InvalidArgument(
"X@GRAD and Y@GRAD inplaced in non-inplace mode")); "X@GRAD and Y@GRAD inplaced in non-inplace mode"));
} }
}
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
use_global_stats = is_test || use_global_stats; use_global_stats = is_test || use_global_stats;
...@@ -867,7 +871,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -867,7 +871,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D); ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
// init output // init output
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace()); d_x->mutable_data<T>(ctx.GetPlace());
}
if (d_scale && d_bias) { if (d_scale && d_bias) {
d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace()); d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
...@@ -908,7 +914,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -908,7 +914,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
Tensor transformed_x(x->type()); Tensor transformed_x(x->type());
Tensor transformed_d_y(d_y->type()); Tensor transformed_d_y(d_y->type());
Tensor transformed_d_x(d_x->type()); Tensor transformed_d_x;
if (data_layout == DataLayout::kNHWC && if (data_layout == DataLayout::kNHWC &&
compute_format == DataLayout::kNCHW) { compute_format == DataLayout::kNCHW) {
VLOG(3) << "Transform input tensor from NHWC to NCHW."; VLOG(3) << "Transform input tensor from NHWC to NCHW.";
...@@ -920,13 +926,17 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -920,13 +926,17 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
&transformed_d_y); &transformed_d_y);
TransToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_y, TransToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_y,
&transformed_d_y); &transformed_d_y);
if (d_x) {
ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_x, ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_x,
&transformed_d_x); &transformed_d_x);
}
} else { } else {
transformed_x.ShareDataWith(*x); transformed_x.ShareDataWith(*x);
transformed_d_y.ShareDataWith(*d_y); transformed_d_y.ShareDataWith(*d_y);
if (d_x) {
transformed_d_x.ShareDataWith(*d_x); transformed_d_x.ShareDataWith(*d_x);
} }
}
std::vector<int> dims; std::vector<int> dims;
std::vector<int> strides; std::vector<int> strides;
...@@ -954,7 +964,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -954,7 +964,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
if (!use_global_stats) { if (!use_global_stats) {
if ((N * H * W * D) == 1) { if ((N * H * W * D) == 1) {
if (d_x) {
framework::TensorCopy(*d_y, ctx.GetPlace(), d_x); framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
}
math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>> math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
functor; functor;
functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0)); functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
...@@ -1042,7 +1054,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -1042,7 +1054,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
} }
// This branch calls CUDNN APIs // This branch calls CUDNN APIs
if (d_scale && d_bias) { if (d_x && d_scale && d_bias) {
bool called = false; bool called = false;
#if CUDNN_VERSION_MIN(7, 4, 1) #if CUDNN_VERSION_MIN(7, 4, 1)
called = true; called = true;
...@@ -1187,6 +1199,15 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -1187,6 +1199,15 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
saved_mean_data, x->data<T>(), saved_var_data, C, N, H * W * D, saved_mean_data, x->data<T>(), saved_var_data, C, N, H * W * D,
d_x->data<T>()); d_x->data<T>());
} }
if (d_scale && d_bias) {
KeBNBackwardScaleBias<
T, block,
framework::DataLayout::kNCHW><<<grid2, block, 0, stream>>>(
d_y->data<T>(), x->data<T>(), saved_mean_data, saved_var_data,
epsilon, N, C, H * W * D,
d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>());
}
} else { } else {
if (d_x) { if (d_x) {
BNBackwardData<T, block, framework::DataLayout::kNHWC><<< BNBackwardData<T, block, framework::DataLayout::kNHWC><<<
...@@ -1195,6 +1216,15 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -1195,6 +1216,15 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
saved_mean_data, x->data<T>(), saved_var_data, C, N, H * W * D, saved_mean_data, x->data<T>(), saved_var_data, C, N, H * W * D,
d_x->data<T>()); d_x->data<T>());
} }
if (d_scale && d_bias) {
KeBNBackwardScaleBias<
T, block,
framework::DataLayout::kNHWC><<<grid2, block, 0, stream>>>(
d_y->data<T>(), x->data<T>(), saved_mean_data, saved_var_data,
epsilon, N, C, H * W * D,
d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>());
}
} }
} }
......
...@@ -515,6 +515,13 @@ class TestBatchNormOpTrainingCase2(TestBatchNormOpTraining): ...@@ -515,6 +515,13 @@ class TestBatchNormOpTrainingCase2(TestBatchNormOpTraining):
os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = "1" os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = "1"
class TestBatchNormOpTrainingCase3(TestBatchNormOpTraining):
    # Regression case for x.stop_gradient=True: X@GRAD is placed in the
    # no-grad set, so only the scale/bias gradients are produced and checked.

    def init_test_case(self):
        # Train with mini-batch statistics rather than the running stats.
        self.use_global_stats = False
        # Exclude the input gradient, mimicking x.stop_gradient=True.
        self.no_grad_set = {'x@GRAD'}
        # Outputs to fetch and verify; 'x@GRAD' is intentionally absent.
        self.fetch_list = ['y', 'mean', 'variance', 'scale@GRAD', 'bias@GRAD']
class TestBatchNormOpTrainingMomentumVariable(TestBatchNormOpTraining): class TestBatchNormOpTrainingMomentumVariable(TestBatchNormOpTraining):
def init_test_case(self): def init_test_case(self):
self.use_momentum_variable = True self.use_momentum_variable = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册