diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index edad20435b41c9eb59c3df793c00ab3bfe96771b..b2cffc3f9063c1fd3b33baa9c740c2402fa00080 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -464,11 +464,9 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const { "BatchNormGrad"); // check output - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", - framework::GradVarName("X"), "BatchNormGrad"); - const bool has_scale_grad = ctx->HasOutput(framework::GradVarName("Scale")); const bool has_bias_grad = ctx->HasOutput(framework::GradVarName("Bias")); + const bool has_x_grad = ctx->HasOutput(framework::GradVarName("X")); PADDLE_ENFORCE_EQ((has_scale_grad == has_bias_grad), true, platform::errors::NotFound( @@ -496,12 +494,14 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const { ? x_dims[1] : x_dims[x_dims.size() - 1]); - ctx->SetOutputDim(framework::GradVarName("X"), x_dims); // has_scale_grad == has_bias_grad, judge has_scale_grad is enough if (has_scale_grad) { ctx->SetOutputDim(framework::GradVarName("Scale"), {C}); ctx->SetOutputDim(framework::GradVarName("Bias"), {C}); } + if (has_x_grad) { + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + } } framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( @@ -596,15 +596,20 @@ class BatchNormGradKernel if (ctx.HasInput("Y")) { x = ctx.Input("Y"); is_inplace = true; - PADDLE_ENFORCE_EQ(d_x, d_y, - platform::errors::InvalidArgument( - "X@GRAD and Y@GRAD not inplace in inplace mode")); + // if the input of batch norm is stop_gradient, d_x is null. + if (d_x) { + PADDLE_ENFORCE_EQ(d_x, d_y, + platform::errors::InvalidArgument( + "X@GRAD and Y@GRAD not inplace in inplace mode")); + } } else { x = ctx.Input("X"); is_inplace = false; - PADDLE_ENFORCE_NE(d_x, d_y, - platform::errors::InvalidArgument( - "X@GRAD and Y@GRAD inplaced in non-inplace mode")); + if (d_x) { + PADDLE_ENFORCE_NE( + d_x, d_y, platform::errors::InvalidArgument( + "X@GRAD and Y@GRAD inplaced in non-inplace mode")); + } } // Get the size for each dimension. @@ -629,7 +634,9 @@ class BatchNormGradKernel const int sample_size = x->numel() / N / C; // init output - d_x->mutable_data(ctx.GetPlace()); + if (d_x) { + d_x->mutable_data(ctx.GetPlace()); + } const T *mean_data = saved_mean->data(); const T *inv_var_data = saved_inv_variance->data(); @@ -673,7 +680,7 @@ class BatchNormGradKernel d_scale_arr.setZero(); } - if ((N * sample_size) == 1 && !use_global_stats) { + if (d_x && (N * sample_size) == 1 && !use_global_stats) { framework::TensorCopy(*d_y, ctx.GetPlace(), d_x); return; } @@ -718,8 +725,6 @@ class BatchNormGradKernel } ConstEigenArrayMap x_arr(x->data(), sample_size, N * C); ConstEigenArrayMap d_y_arr(d_y->data(), sample_size, N * C); - EigenArrayMap d_x_arr(d_x->mutable_data(ctx.GetPlace()), - sample_size, N * C); for (int nc = 0; nc < N * C; ++nc) { int c = nc % C; @@ -734,19 +739,24 @@ class BatchNormGradKernel d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr; } - if (!use_global_stats) { - for (int nc = 0; nc < N * C; ++nc) { - int c = nc % C; - d_x_arr.col(nc) = - scale_inv_var_nhw(c) * - (d_y_arr.col(nc) * N * sample_size - dy_sum_arr(c) - - (x_arr.col(nc) - mean_arr[c]) * - dy_mul_x_sub_mean_mul_invstd_sum_arr(c) * inv_var_arr(c)); - } - } else { - for (int nc = 0; nc < N * C; ++nc) { - int c = nc % C; - d_x_arr.col(nc) = scale_inv_var_nhw(c) * d_y_arr.col(nc); + if (d_x) { + EigenArrayMap d_x_arr(d_x->mutable_data(ctx.GetPlace()), + sample_size, N * C); + if (!use_global_stats) { + for (int nc = 0; nc < N * C; ++nc) { + int c = nc % C; + d_x_arr.col(nc) = + scale_inv_var_nhw(c) * + (d_y_arr.col(nc) * N * sample_size - dy_sum_arr(c) - + (x_arr.col(nc) - mean_arr[c]) * + dy_mul_x_sub_mean_mul_invstd_sum_arr(c) * + inv_var_arr(c)); + } + } else { + for (int nc = 0; nc < N * C; ++nc) { + int c = nc % C; + d_x_arr.col(nc) = scale_inv_var_nhw(c) * d_y_arr.col(nc); + } } } break; @@ -765,8 +775,6 @@ class BatchNormGradKernel } ConstEigenArrayMap x_arr(x->data(), C, N * sample_size); ConstEigenArrayMap d_y_arr(d_y->data(), C, N * sample_size); - EigenArrayMap d_x_arr(d_x->mutable_data(ctx.GetPlace()), C, - N * sample_size); for (int nhw = 0; nhw < N * sample_size; ++nhw) { dy_sum_arr += d_y_arr.col(nhw); @@ -779,17 +787,21 @@ class BatchNormGradKernel d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr; } - if (!use_global_stats) { - for (int nhw = 0; nhw < N * sample_size; ++nhw) { - d_x_arr.col(nhw) = - scale_inv_var_nhw * - (d_y_arr.col(nhw) * N * sample_size - dy_sum_arr - - (x_arr.col(nhw) - mean_arr) * - dy_mul_x_sub_mean_mul_invstd_sum_arr * inv_var_arr); - } - } else { - for (int nhw = 0; nhw < N * sample_size; ++nhw) { - d_x_arr.col(nhw) = scale_inv_var_nhw * d_y_arr.col(nhw); + if (d_x) { + EigenArrayMap d_x_arr(d_x->mutable_data(ctx.GetPlace()), C, + N * sample_size); + if (!use_global_stats) { + for (int nhw = 0; nhw < N * sample_size; ++nhw) { + d_x_arr.col(nhw) = + scale_inv_var_nhw * + (d_y_arr.col(nhw) * N * sample_size - dy_sum_arr - + (x_arr.col(nhw) - mean_arr) * + dy_mul_x_sub_mean_mul_invstd_sum_arr * inv_var_arr); + } + } else { + for (int nhw = 0; nhw < N * sample_size; ++nhw) { + d_x_arr.col(nhw) = scale_inv_var_nhw * d_y_arr.col(nhw); + } } } break; diff --git a/paddle/fluid/operators/batch_norm_op.cu b/paddle/fluid/operators/batch_norm_op.cu index 42e1e2e7463c7753fbf205c88442db63733754ea..3d26c2c570858e11771fc27afabdcce5c0fb9443 100644 --- a/paddle/fluid/operators/batch_norm_op.cu +++ b/paddle/fluid/operators/batch_norm_op.cu @@ -840,15 +840,19 @@ class BatchNormGradKernel if (ctx.HasInput("Y")) { x = ctx.Input("Y"); is_inplace = true; - PADDLE_ENFORCE_EQ(d_x, d_y, - platform::errors::InvalidArgument( - "X@GRAD and Y@GRAD not inplace in inplace mode")); + if (d_x) { + PADDLE_ENFORCE_EQ(d_x, d_y, + platform::errors::InvalidArgument( + "X@GRAD and Y@GRAD not inplace in inplace mode")); + } } else { x = ctx.Input("X"); is_inplace = false; - PADDLE_ENFORCE_NE(d_x, d_y, - platform::errors::InvalidArgument( - "X@GRAD and Y@GRAD inplaced in non-inplace mode")); + if (d_x) { + PADDLE_ENFORCE_NE( + d_x, d_y, platform::errors::InvalidArgument( + "X@GRAD and Y@GRAD inplaced in non-inplace mode")); + } } const bool is_test = ctx.Attr("is_test"); @@ -867,7 +871,9 @@ class BatchNormGradKernel ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D); // init output - d_x->mutable_data(ctx.GetPlace()); + if (d_x) { + d_x->mutable_data(ctx.GetPlace()); + } if (d_scale && d_bias) { d_scale->mutable_data>(ctx.GetPlace()); @@ -908,7 +914,7 @@ class BatchNormGradKernel Tensor transformed_x(x->type()); Tensor transformed_d_y(d_y->type()); - Tensor transformed_d_x(d_x->type()); + Tensor transformed_d_x; if (data_layout == DataLayout::kNHWC && compute_format == DataLayout::kNCHW) { VLOG(3) << "Transform input tensor from NHWC to NCHW."; @@ -920,12 +926,16 @@ class BatchNormGradKernel &transformed_d_y); TransToChannelFirst(ctx, d_y, &transformed_d_y); - ResizeToChannelFirst(ctx, d_x, - &transformed_d_x); + if (d_x) { + ResizeToChannelFirst(ctx, d_x, + &transformed_d_x); + } } else { transformed_x.ShareDataWith(*x); transformed_d_y.ShareDataWith(*d_y); - transformed_d_x.ShareDataWith(*d_x); + if (d_x) { + transformed_d_x.ShareDataWith(*d_x); + } } std::vector dims; @@ -954,7 +964,9 @@ class BatchNormGradKernel if (!use_global_stats) { if ((N * H * W * D) == 1) { - framework::TensorCopy(*d_y, ctx.GetPlace(), d_x); + if (d_x) { + framework::TensorCopy(*d_y, ctx.GetPlace(), d_x); + } math::SetConstant> functor; functor(dev_ctx, d_scale, static_cast>(0)); @@ -1042,7 +1054,7 @@ class BatchNormGradKernel } // This branch calls CUDNN APIs - if (d_scale && d_bias) { + if (d_x && d_scale && d_bias) { bool called = false; #if CUDNN_VERSION_MIN(7, 4, 1) called = true; @@ -1187,6 +1199,15 @@ class BatchNormGradKernel saved_mean_data, x->data(), saved_var_data, C, N, H * W * D, d_x->data()); } + if (d_scale && d_bias) { + KeBNBackwardScaleBias< + T, block, + framework::DataLayout::kNCHW><<>>( + d_y->data(), x->data(), saved_mean_data, saved_var_data, + epsilon, N, C, H * W * D, + d_scale->data>(), + d_bias->data>()); + } } else { if (d_x) { BNBackwardData<<< @@ -1195,6 +1216,15 @@ class BatchNormGradKernel saved_mean_data, x->data(), saved_var_data, C, N, H * W * D, d_x->data()); } + if (d_scale && d_bias) { + KeBNBackwardScaleBias< + T, block, + framework::DataLayout::kNHWC><<>>( + d_y->data(), x->data(), saved_mean_data, saved_var_data, + epsilon, N, C, H * W * D, + d_scale->data>(), + d_bias->data>()); + } } } diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index 2eb334d09563129fb6555186d7e27497415e740f..9eaa69ce644285725987dac8a2d939e4ce798eca 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -515,6 +515,13 @@ class TestBatchNormOpTrainingCase2(TestBatchNormOpTraining): os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = "1" +class TestBatchNormOpTrainingCase3(TestBatchNormOpTraining): + def init_test_case(self): + self.use_global_stats = False + self.no_grad_set = set(['x@GRAD']) + self.fetch_list = ['y', 'mean', 'variance', 'scale@GRAD', 'bias@GRAD'] + + class TestBatchNormOpTrainingMomentumVariable(TestBatchNormOpTraining): def init_test_case(self): self.use_momentum_variable = True