未验证 提交 c7027d9e 编写于 作者: X xiongkun 提交者: GitHub

fix bugs in prepare_gradient_aggregation (#45268)

上级 6fb34e74
...@@ -287,7 +287,8 @@ class PartialProgramLayer: ...@@ -287,7 +287,8 @@ class PartialProgramLayer:
return main_program return main_program
def prepare_gradient_aggregation(self, main_program, target_program): def prepare_gradient_aggregation(self, start_idx, main_program,
target_program):
""" """
Why we need add gradient aggregation operation ? Why we need add gradient aggregation operation ?
In some cases, if non leaf nodes are used as output, gradient overwriting will occur, such as In some cases, if non leaf nodes are used as output, gradient overwriting will occur, such as
...@@ -323,7 +324,7 @@ class PartialProgramLayer: ...@@ -323,7 +324,7 @@ class PartialProgramLayer:
new_grad_name = var.name + suffix + "@GRAD" new_grad_name = var.name + suffix + "@GRAD"
finded_ops = list( finded_ops = list(
filter( filter(
lambda x: any([ lambda x: x[0] >= start_idx and any([
out_arg == var_grad_name out_arg == var_grad_name
for out_arg in x[1].output_arg_names for out_arg in x[1].output_arg_names
]), enumerate(target_program.block(0).ops))) ]), enumerate(target_program.block(0).ops)))
...@@ -367,7 +368,10 @@ class PartialProgramLayer: ...@@ -367,7 +368,10 @@ class PartialProgramLayer:
if targets and self._params: if targets and self._params:
backward.gradients(targets=targets, inputs=[]) backward.gradients(targets=targets, inputs=[])
self.prepare_gradient_aggregation(main_program, program) start_idx = len(
main_program.block(0).ops) + 2 * len(self._outputs.tolist())
self.prepare_gradient_aggregation(start_idx, main_program, program)
return program return program
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册