diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py index 97ff881ef95bfdc7e07b8da64a18aa417f493f60..3eb8437db6b1f1b3157eb512fb55bb748bc670fb 100644 --- a/python/paddle/distributed/auto_parallel/partitioner.py +++ b/python/paddle/distributed/auto_parallel/partitioner.py @@ -179,6 +179,16 @@ class Partitioner(object): partitioned_main_prog.current_block_idx = 0 + # should reconnect the block_attr ptr to the correct block + for block_id in range(self._dist_context.block_state.nblock): + block = partitioned_main_prog.block(block_id) + for op in block.ops: + for attr_name in op.all_attrs(): + if op.attr_type(attr_name) == core.AttrType.BLOCK: + relative_id = op._block_attr_id(attr_name) + op._set_attr(attr_name, + partitioned_main_prog.block(relative_id)) + partitioned_params_and_grads = [] for p, g in params_and_grads: assert p.name in self._serial2dist_varname_mapping