未验证 提交 626d7bcb 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim] Make auto parallel judge dim more strict (#47961)

上级 8fece428
......@@ -164,7 +164,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
ctx->SetOutputDim("SavedMean", {C});
ctx->SetOutputDim("SavedVariance", {C});
ctx->ShareLoD("X", "Y");
if (ctx->HasInput("ReserveSpace")) {
if (ctx->HasOutput("ReserveSpace")) {
ctx->SetOutputDim("ReserveSpace", {-1});
}
}
......
......@@ -1239,7 +1239,7 @@ class Completer:
input_var
).dims_mapping
else:
if fwd_op_dist_attr.get_input_dims_mapping(input_name):
if input_name in forward_op.input_arg_names:
ref_dims_mapping = (
fwd_op_dist_attr.get_input_dims_mapping(
input_name
......@@ -1544,7 +1544,7 @@ class Completer:
input_var
).dims_mapping
else:
if fwd_op_dist_attr.get_input_dims_mapping(input_name):
if input_name in forward_op.input_arg_names:
ref_dims_mapping = (
fwd_op_dist_attr.get_input_dims_mapping(
input_name
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册