未验证 提交 e65cdaee 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParallel] fix sharding (#46572)

上级 e87f65c3
......@@ -396,7 +396,7 @@ class ShardingPass(PassBase):
dp_ring_ids = [group.id for group in self.dp_groups]
for idx, op in reversed(list(enumerate(main_block.ops))):
if is_data_parallel_reduce_op(op):
if _is_param_grad_allreduce_op(op, main_block):
input_name = op.input_arg_names[0]
base_name = _get_base_name_from_grad_name(input_name)
sharding_info = self.varname_to_sharding_info[base_name]
......@@ -653,6 +653,20 @@ def _get_base_name_from_grad_name(grad_name):
return base_name
def _is_param_grad_allreduce_op(op, block):
    """Tell whether ``op`` is a data-parallel reduce op tied to a parameter.

    The first output name of ``op`` is mapped to a base name through
    ``_get_base_name_from_grad_name``; the op qualifies only when ``block``
    holds a variable under that base name and that variable is a parameter.

    Args:
        op: The operator to inspect. Only its first output name is used.
        block: The block whose variables are searched for the base name.

    Returns:
        ``False`` when ``op`` is not a data-parallel reduce op or when
        ``block`` has no variable with the derived base name; otherwise the
        ``is_parameter`` attribute of that variable.
    """
    if is_data_parallel_reduce_op(op):
        # Only the first output is consulted to recover the base name.
        reduced_name = op.output_arg_names[0]
        param_name = _get_base_name_from_grad_name(reduced_name)
        if block.has_var(param_name):
            return block.var(param_name).is_parameter
    # Either not a data-parallel reduce op, or the base var is not in block.
    return False
def _is_param_grad_sum_op(op, block):
if not is_backward_op(op):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册