未验证 提交 5fd7b5c3 编写于 作者: J JZ-LIANG 提交者: GitHub

fix bug for block state (#39854)

上级 6b5749eb
...@@ -139,6 +139,9 @@ class Engine: ...@@ -139,6 +139,9 @@ class Engine:
self._completer = Completer(self._dist_contexts[self.mode]) self._completer = Completer(self._dist_contexts[self.mode])
self._completer.complete_forward_annotation(serial_main_prog) self._completer.complete_forward_annotation(serial_main_prog)
# TODO: add auto planner process # TODO: add auto planner process
# parse forward sub block
self._dist_contexts[self.mode].block_state.parse_forward_blocks(
serial_main_prog)
def _parallel(self, rank): def _parallel(self, rank):
serial_main_program = self._serial_main_progs[self.mode] serial_main_program = self._serial_main_progs[self.mode]
...@@ -177,6 +180,8 @@ class Engine: ...@@ -177,6 +180,8 @@ class Engine:
loss, loss,
distop_context=self._dist_contexts[self.mode].dist_op_context) distop_context=self._dist_contexts[self.mode].dist_op_context)
self._completer.complete_backward_annotation(main_program) self._completer.complete_backward_annotation(main_program)
self._dist_contexts[self.mode].block_state.parse_backward_blocks(
main_program)
return params_grads return params_grads
def _generate_optimizer(self, main_program, startup_program, params_grads): def _generate_optimizer(self, main_program, startup_program, params_grads):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册