diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 98b76056a15a4229cbe8d7d7b6419dfaf1b2ee3b..8efb9eb719237fd433b9fb02b0772eb6581319e4 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -139,6 +139,9 @@ class Engine: self._completer = Completer(self._dist_contexts[self.mode]) self._completer.complete_forward_annotation(serial_main_prog) # TODO: add auto planner process + # parse forward sub block + self._dist_contexts[self.mode].block_state.parse_forward_blocks( + serial_main_prog) def _parallel(self, rank): serial_main_program = self._serial_main_progs[self.mode] @@ -177,6 +180,8 @@ class Engine: loss, distop_context=self._dist_contexts[self.mode].dist_op_context) self._completer.complete_backward_annotation(main_program) + self._dist_contexts[self.mode].block_state.parse_backward_blocks( + main_program) return params_grads def _generate_optimizer(self, main_program, startup_program, params_grads):