From 17299b8d217c0872408cc9146a58f0769d8b05ba Mon Sep 17 00:00:00 2001 From: WangXi Date: Thu, 19 Dec 2019 11:01:37 +0800 Subject: [PATCH] fix batch_norm_grad infer shape=0 & add allreduce enforce shape, test=develop (#21801) --- .../framework/details/all_reduce_op_handle.cc | 5 +++++ paddle/fluid/operators/batch_norm_op.cc | 19 +++++++++++++------ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc index 8deacf4e7dd..6c3b0923ed0 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc @@ -84,6 +84,11 @@ void AllReduceOpHandle::AllReduceImpl( if (i == 0) { numel = static_cast(lod_tensor.numel()); + // only enforce place0, we will enforce other place numel == place0 numel + PADDLE_ENFORCE_GT( + numel, 0, platform::errors::InvalidArgument( + "The numel of tensor=[%s] must be > 0. But now numel=[%d]", + in_var_handles[i]->name(), numel)); dtype = lod_tensor.type(); is_gpu_place = platform::is_gpu_place(lod_tensor.place()); } diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 0708e104f2d..81526608265 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -444,11 +444,17 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const { // check output PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), ""); - if (ctx->HasOutput(framework::GradVarName("Scale"))) { - PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")), - "Output(Scale@GRAD) and Output(Bias@GRAD) should not be " - "null at same time"); - } + + const bool has_scale_grad = ctx->HasOutput(framework::GradVarName("Scale")); + const bool has_bias_grad = ctx->HasOutput(framework::GradVarName("Bias")); + + PADDLE_ENFORCE_EQ((has_scale_grad == has_bias_grad), true, + platform::errors::InvalidArgument(
+ "Output(Scale@GRAD) and Output(Bias@GRAD) must be null " "or not be null at the same time. But now, " "has Scale@GRAD=[%d], has Bias@GRAD=[%d]", has_scale_grad, has_bias_grad)); + const bool use_global_stats = ctx->Attrs().Get("use_global_stats"); if (use_global_stats) { PADDLE_ENFORCE(!ctx->Attrs().Get("use_mkldnn"), @@ -463,7 +469,8 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const { : x_dims[x_dims.size() - 1]); ctx->SetOutputDim(framework::GradVarName("X"), x_dims); - if (ctx->HasOutput(framework::GradVarName("Scale"))) { + // has_scale_grad == has_bias_grad, so checking has_scale_grad is enough + if (has_scale_grad) { ctx->SetOutputDim(framework::GradVarName("Scale"), {C}); ctx->SetOutputDim(framework::GradVarName("Bias"), {C}); } -- GitLab