未验证 提交 626d7bcb 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim] Make auto parallel judge dim more strict (#47961)

上级 8fece428
...@@ -164,7 +164,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -164,7 +164,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
ctx->SetOutputDim("SavedMean", {C}); ctx->SetOutputDim("SavedMean", {C});
ctx->SetOutputDim("SavedVariance", {C}); ctx->SetOutputDim("SavedVariance", {C});
ctx->ShareLoD("X", "Y"); ctx->ShareLoD("X", "Y");
if (ctx->HasInput("ReserveSpace")) { if (ctx->HasOutput("ReserveSpace")) {
ctx->SetOutputDim("ReserveSpace", {-1}); ctx->SetOutputDim("ReserveSpace", {-1});
} }
} }
......
...@@ -1239,7 +1239,7 @@ class Completer: ...@@ -1239,7 +1239,7 @@ class Completer:
input_var input_var
).dims_mapping ).dims_mapping
else: else:
if fwd_op_dist_attr.get_input_dims_mapping(input_name): if input_name in forward_op.input_arg_names:
ref_dims_mapping = ( ref_dims_mapping = (
fwd_op_dist_attr.get_input_dims_mapping( fwd_op_dist_attr.get_input_dims_mapping(
input_name input_name
...@@ -1544,7 +1544,7 @@ class Completer: ...@@ -1544,7 +1544,7 @@ class Completer:
input_var input_var
).dims_mapping ).dims_mapping
else: else:
if fwd_op_dist_attr.get_input_dims_mapping(input_name): if input_name in forward_op.input_arg_names:
ref_dims_mapping = ( ref_dims_mapping = (
fwd_op_dist_attr.get_input_dims_mapping( fwd_op_dist_attr.get_input_dims_mapping(
input_name input_name
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册