diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 6c6591f34abcef8e9a8ea686cc1e2f191bc82336..878ab18432cdcf6b5ecdd0cccf88158304ae4219 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -164,7 +164,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { ctx->SetOutputDim("SavedMean", {C}); ctx->SetOutputDim("SavedVariance", {C}); ctx->ShareLoD("X", "Y"); - if (ctx->HasInput("ReserveSpace")) { + if (ctx->HasOutput("ReserveSpace")) { ctx->SetOutputDim("ReserveSpace", {-1}); } } diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index c0f70f482dd17f598c483c44f19e965d914bf7e2..7f5e0fee7752671e270e87b51f57ccb88248c008 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -1239,7 +1239,7 @@ class Completer: input_var ).dims_mapping else: - if fwd_op_dist_attr.get_input_dims_mapping(input_name): + if input_name in forward_op.input_arg_names: ref_dims_mapping = ( fwd_op_dist_attr.get_input_dims_mapping( input_name @@ -1544,7 +1544,7 @@ class Completer: input_var ).dims_mapping else: - if fwd_op_dist_attr.get_input_dims_mapping(input_name): + if input_name in forward_op.input_arg_names: ref_dims_mapping = ( fwd_op_dist_attr.get_input_dims_mapping( input_name