diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
index 5011cb7534f67c55fa7489eabd41f199c984be3b..23744db61a11daee7fe1600befba8b825f2880e6 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
@@ -287,7 +287,8 @@ class PartialProgramLayer:
 
         return main_program
 
-    def prepare_gradient_aggregation(self, main_program, target_program):
+    def prepare_gradient_aggregation(self, start_idx, main_program,
+                                     target_program):
         """
         Why we need add gradient aggregation operation ?
         In some cases, if non leaf nodes are used as output, gradient overwriting will occur, such as
@@ -323,7 +324,7 @@ class PartialProgramLayer:
             new_grad_name = var.name + suffix + "@GRAD"
             finded_ops = list(
                 filter(
-                    lambda x: any([
+                    lambda x: x[0] >= start_idx and any([
                         out_arg == var_grad_name
                         for out_arg in x[1].output_arg_names
                     ]), enumerate(target_program.block(0).ops)))
@@ -367,7 +368,10 @@ class PartialProgramLayer:
         if targets and self._params:
             backward.gradients(targets=targets, inputs=[])
 
-        self.prepare_gradient_aggregation(main_program, program)
+        start_idx = len(
+            main_program.block(0).ops) + 2 * len(self._outputs.tolist())
+
+        self.prepare_gradient_aggregation(start_idx, main_program, program)
 
         return program
 
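
As far as the diff shows, the intent is that gradient-aggregation rewriting should only touch the gradient ops appended by backward.gradients, not ops that already existed in the program; start_idx marks where those appended ops begin (the op count of main_program plus two extra ops per output). Below is a minimal, self-contained sketch of the index-based filter added by this patch. FakeOp, the op list, and the counts are hypothetical stand-ins, not Paddle's real OpDesc objects, which the patch obtains from target_program.block(0).ops.

# Illustrative sketch of the `x[0] >= start_idx` filter (hypothetical data).
from collections import namedtuple

FakeOp = namedtuple("FakeOp", ["type", "output_arg_names"])

forward_op_count = 3        # assumed number of ops already in main_program
num_outputs = 2             # assumed len(self._outputs.tolist())
start_idx = forward_op_count + 2 * num_outputs   # = 7, as in the patch's formula

ops = (
    [FakeOp("some_op", ["x@GRAD"])]               # op before start_idx that also writes x@GRAD
    + [FakeOp("some_op", ["y"])] * 6
    + [FakeOp("matmul_grad", ["x@GRAD"]),         # ops appended by backward.gradients
       FakeOp("elementwise_add_grad", ["x@GRAD"])]
)

var_grad_name = "x@GRAD"
finded_ops = list(
    filter(
        lambda x: x[0] >= start_idx and any(
            out_arg == var_grad_name for out_arg in x[1].output_arg_names),
        enumerate(ops)))

# Only indices 7 and 8 survive; index 0 is skipped even though it writes x@GRAD.
print([idx for idx, _ in finded_ops])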