From c7027d9e67279323638d737fbe8e551353650f56 Mon Sep 17 00:00:00 2001
From: xiongkun
Date: Fri, 19 Aug 2022 19:12:43 +0800
Subject: [PATCH] fix bugs in prepare_gradient_aggregation (#45268)

---
 .../fluid/dygraph/dygraph_to_static/partial_program.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
index 5011cb7534f..23744db61a1 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
@@ -287,7 +287,8 @@ class PartialProgramLayer:
 
         return main_program
 
-    def prepare_gradient_aggregation(self, main_program, target_program):
+    def prepare_gradient_aggregation(self, start_idx, main_program,
+                                     target_program):
         """
         Why we need add gradient aggregation operation ?
         In some cases, if non leaf nodes are used as output, gradient overwriting will occur, such as
@@ -323,7 +324,7 @@ class PartialProgramLayer:
             new_grad_name = var.name + suffix + "@GRAD"
             finded_ops = list(
                 filter(
-                    lambda x: any([
+                    lambda x: x[0] >= start_idx and any([
                         out_arg == var_grad_name
                         for out_arg in x[1].output_arg_names
                     ]), enumerate(target_program.block(0).ops)))
@@ -367,7 +368,10 @@ class PartialProgramLayer:
 
         if targets and self._params:
             backward.gradients(targets=targets, inputs=[])
 
-        self.prepare_gradient_aggregation(main_program, program)
+        start_idx = len(
+            main_program.block(0).ops) + 2 * len(self._outputs.tolist())
+
+        self.prepare_gradient_aggregation(start_idx, main_program, program)
 
         return program
--
GitLab
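
Editor's note (not part of the patch): below is a minimal standalone sketch of
the two ideas this change touches. It is plain Python with hypothetical op
names and values, not PaddlePaddle API calls. Part 1 illustrates the
gradient-overwrite problem named in the docstring: when a non-leaf node x is
also returned as an output, two gradients flow into x@GRAD, and a plain write
loses one of them unless they are accumulated. Part 2 illustrates the
start_idx guard added by the patch: when scanning target_program for the op
that writes var@GRAD, only ops appended after the forward section
(op index >= start_idx) should match, so forward-section ops that happen to
name the same variable are skipped.

    # Hypothetical sketch; plain Python, no Paddle imports.

    # 1) Overwrite vs. aggregation for a non-leaf output.
    #    Suppose loss = x + y with y = x + 3, so dloss/dx = 2
    #    (1 directly, 1 through y).
    grads_into_x = [1.0, 1.0]

    overwritten = None
    for g in grads_into_x:
        overwritten = g             # a later write clobbers the earlier one
    aggregated = sum(grads_into_x)  # accumulation keeps both contributions

    assert overwritten == 1.0       # wrong: half the gradient is lost
    assert aggregated == 2.0        # correct

    # 2) The start_idx filter: ignore ops in the forward section when
    #    locating the backward op that outputs "x@GRAD".
    ops = [
        ("scale",                ["x"]),       # forward op
        ("fill_constant",        ["x@GRAD"]),  # forward-section op, skipped
        ("elementwise_add_grad", ["x@GRAD"]),  # real backward writer
    ]
    start_idx = 2  # in the patch: len(forward ops) + 2 * len(outputs)

    finded_ops = list(
        filter(
            lambda x: x[0] >= start_idx and any(
                out == "x@GRAD" for out in x[1][1]), enumerate(ops)))

    assert finded_ops == [(2, ("elementwise_add_grad", ["x@GRAD"]))]

The sketch hardcodes start_idx; in the patch itself it is computed as the
number of forward ops plus two extra ops per program output, a count specific
to how this file builds the whole program.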