提交 17299b8d 编写于 作者: W WangXi 提交者: gongweibao

fix batch_norm_grad infer shape=0 & add allreduce enforce shape, test=develop (#21801)

上级 a3a3558d
...@@ -84,6 +84,11 @@ void AllReduceOpHandle::AllReduceImpl( ...@@ -84,6 +84,11 @@ void AllReduceOpHandle::AllReduceImpl(
if (i == 0) { if (i == 0) {
numel = static_cast<int64_t>(lod_tensor.numel()); numel = static_cast<int64_t>(lod_tensor.numel());
// only enforce place0, we will enforce other place numel == place0 numel
PADDLE_ENFORCE_GT(
numel, 0, platform::errors::InvalidArgument(
"The numel of tensos=[%s] must > 0. But now numel=[%d]",
in_var_handles[i]->name(), numel));
dtype = lod_tensor.type(); dtype = lod_tensor.type();
is_gpu_place = platform::is_gpu_place(lod_tensor.place()); is_gpu_place = platform::is_gpu_place(lod_tensor.place());
} }
......
...@@ -444,11 +444,17 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -444,11 +444,17 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
// check output // check output
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), ""); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), "");
if (ctx->HasOutput(framework::GradVarName("Scale"))) {
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")), const bool has_scale_grad = ctx->HasOutput(framework::GradVarName("Scale"));
"Output(Scale@GRAD) and Output(Bias@GRAD) should not be " const bool has_bias_grad = ctx->HasOutput(framework::GradVarName("Bias"));
"null at same time");
} PADDLE_ENFORCE_EQ((has_scale_grad == has_bias_grad), true,
platform::errors::InvalidArgument(
"Output(Scale@GRAD) and Output(Bias@GRAD) must be null "
"or not be null at same time. But now, "
"has Scale@Grad=[%d], has Bias@GRAD=[%d]",
has_scale_grad, has_bias_grad));
const bool use_global_stats = ctx->Attrs().Get<bool>("use_global_stats"); const bool use_global_stats = ctx->Attrs().Get<bool>("use_global_stats");
if (use_global_stats) { if (use_global_stats) {
PADDLE_ENFORCE(!ctx->Attrs().Get<bool>("use_mkldnn"), PADDLE_ENFORCE(!ctx->Attrs().Get<bool>("use_mkldnn"),
...@@ -463,7 +469,8 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -463,7 +469,8 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
: x_dims[x_dims.size() - 1]); : x_dims[x_dims.size() - 1]);
ctx->SetOutputDim(framework::GradVarName("X"), x_dims); ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
if (ctx->HasOutput(framework::GradVarName("Scale"))) { // has_scale_grad == has_bias_grad, judge has_scale_grad is enough
if (has_scale_grad) {
ctx->SetOutputDim(framework::GradVarName("Scale"), {C}); ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
ctx->SetOutputDim(framework::GradVarName("Bias"), {C}); ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册