From 5fd7b5c3092bef2e48817da8849c267835a43890 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Thu, 24 Feb 2022 10:42:13 +0800 Subject: [PATCH] fix bug for block state (#39854) --- python/paddle/distributed/auto_parallel/engine.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 98b76056a15..8efb9eb7192 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -139,6 +139,9 @@ class Engine: self._completer = Completer(self._dist_contexts[self.mode]) self._completer.complete_forward_annotation(serial_main_prog) # TODO: add auto planner process + # parse forward sub block + self._dist_contexts[self.mode].block_state.parse_forward_blocks( + serial_main_prog) def _parallel(self, rank): serial_main_program = self._serial_main_progs[self.mode] @@ -177,6 +180,8 @@ class Engine: loss, distop_context=self._dist_contexts[self.mode].dist_op_context) self._completer.complete_backward_annotation(main_program) + self._dist_contexts[self.mode].block_state.parse_backward_blocks( + main_program) return params_grads def _generate_optimizer(self, main_program, startup_program, params_grads): -- GitLab