From 626d7bcbce10297ee726d44a407135461484635c Mon Sep 17 00:00:00 2001 From: Zhou Wei <1183042833@qq.com> Date: Tue, 15 Nov 2022 17:30:04 +0800 Subject: [PATCH] [Zero-Dim] Make auto parallel judge dim more strict (#47961) --- paddle/fluid/operators/batch_norm_op.cc | 2 +- python/paddle/distributed/auto_parallel/completion.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 6c6591f34a..878ab18432 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -164,7 +164,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { ctx->SetOutputDim("SavedMean", {C}); ctx->SetOutputDim("SavedVariance", {C}); ctx->ShareLoD("X", "Y"); - if (ctx->HasInput("ReserveSpace")) { + if (ctx->HasOutput("ReserveSpace")) { ctx->SetOutputDim("ReserveSpace", {-1}); } } diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index c0f70f482d..7f5e0fee77 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -1239,7 +1239,7 @@ class Completer: input_var ).dims_mapping else: - if fwd_op_dist_attr.get_input_dims_mapping(input_name): + if input_name in forward_op.input_arg_names: ref_dims_mapping = ( fwd_op_dist_attr.get_input_dims_mapping( input_name @@ -1544,7 +1544,7 @@ class Completer: input_var ).dims_mapping else: - if fwd_op_dist_attr.get_input_dims_mapping(input_name): + if input_name in forward_op.input_arg_names: ref_dims_mapping = ( fwd_op_dist_attr.get_input_dims_mapping( input_name -- GitLab