提交 17299b8d 编写于 作者: W WangXi 提交者: gongweibao

fix batch_norm_grad infer shape=0 & add allreduce enforce shape, test=develop (#21801)

上级 a3a3558d
......@@ -84,6 +84,11 @@ void AllReduceOpHandle::AllReduceImpl(
if (i == 0) {
numel = static_cast<int64_t>(lod_tensor.numel());
// only enforce place0, we will enforce other palce numel == place0 numel
PADDLE_ENFORCE_GT(
numel, 0, platform::errors::InvalidArgument(
"The numel of tensos=[%s] must > 0. But now numel=[%d]",
in_var_handles[i]->name(), numel));
dtype = lod_tensor.type();
is_gpu_place = platform::is_gpu_place(lod_tensor.place());
}
......
......@@ -444,11 +444,17 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
// check output
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), "");
if (ctx->HasOutput(framework::GradVarName("Scale"))) {
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")),
"Output(Scale@GRAD) and Output(Bias@GRAD) should not be "
"null at same time");
}
const bool has_scale_grad = ctx->HasOutput(framework::GradVarName("Scale"));
const bool has_bias_grad = ctx->HasOutput(framework::GradVarName("Bias"));
PADDLE_ENFORCE_EQ((has_scale_grad == has_bias_grad), true,
platform::errors::InvalidArgument(
"Output(Scale@GRAD) and Output(Bias@GRAD) must be null "
"or not be null at same time. But now, "
"has Scale@Grad=[%d], has Bias@GRAD=[%d]",
has_scale_grad, has_bias_grad));
const bool use_global_stats = ctx->Attrs().Get<bool>("use_global_stats");
if (use_global_stats) {
PADDLE_ENFORCE(!ctx->Attrs().Get<bool>("use_mkldnn"),
......@@ -463,7 +469,8 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
: x_dims[x_dims.size() - 1]);
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
if (ctx->HasOutput(framework::GradVarName("Scale"))) {
// has_scale_grad == has_bias_grad, judge has_scale_grad is enough
if (has_scale_grad) {
ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册