未验证 提交 822e42d7 编写于 作者: Y Yuang Liu 提交者: GitHub

[auto parallel] bug fix for op has sub_block attr created with copy_from (#44664)

上级 8fc1cf60
......@@ -179,6 +179,16 @@ class Partitioner(object):
partitioned_main_prog.current_block_idx = 0
# should reconnect the block_attr ptr to the correct block
for block_id in range(self._dist_context.block_state.nblock):
block = partitioned_main_prog.block(block_id)
for op in block.ops:
for attr_name in op.all_attrs():
if op.attr_type(attr_name) == core.AttrType.BLOCK:
relative_id = op._block_attr_id(attr_name)
op._set_attr(attr_name,
partitioned_main_prog.block(relative_id))
partitioned_params_and_grads = []
for p, g in params_and_grads:
assert p.name in self._serial2dist_varname_mapping
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册