Unverified · Commit c7027d9e · authored by xiongkun · committed by GitHub

fix bugs in prepare_gradient_aggregation (#45268)

Parent 6fb34e74
@@ -287,7 +287,8 @@ class PartialProgramLayer:
         return main_program
 
-    def prepare_gradient_aggregation(self, main_program, target_program):
+    def prepare_gradient_aggregation(self, start_idx, main_program,
+                                     target_program):
         """
         Why we need add gradient aggregation operation ?
         In some cases, if non leaf nodes are used as output, gradient overwriting will occur, such as
@@ -323,7 +324,7 @@
             new_grad_name = var.name + suffix + "@GRAD"
             finded_ops = list(
                 filter(
-                    lambda x: any([
+                    lambda x: x[0] >= start_idx and any([
                         out_arg == var_grad_name
                         for out_arg in x[1].output_arg_names
                     ]), enumerate(target_program.block(0).ops)))
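The essence of the fix is the new `x[0] >= start_idx` guard: `enumerate` pairs every op in block 0 with its index, so the search for the op that writes `var_grad_name` now skips everything before `start_idx`, i.e. the forward section, where an op's output could otherwise share the gradient variable's name and be matched by mistake. Below is a minimal, self-contained sketch of the same pattern; the `FakeOp` class and the op list are made up for illustration, and only the filtering shape mirrors the patch.

```python
# FakeOp stands in for Paddle's operator type; only .output_arg_names matters here.
class FakeOp:
    def __init__(self, type, output_arg_names):
        self.type = type
        self.output_arg_names = output_arg_names


block_ops = [
    FakeOp("matmul", ["out"]),              # index 0: forward op
    FakeOp("elementwise_add", ["x@GRAD"]),  # index 1: forward op whose output
                                            # name collides with the grad name
    FakeOp("fill_constant", ["out@GRAD"]),  # index 2: backward section begins
    FakeOp("matmul_grad", ["x@GRAD"]),      # index 3: the op we actually want
]
start_idx = 2
var_grad_name = "x@GRAD"

# Same shape as the patched filter (the misspelling `finded_ops` is kept from
# the original source). Without the `x[0] >= start_idx` guard, index 1 would
# match as well and the aggregation ops would target the wrong op.
finded_ops = list(
    filter(
        lambda x: x[0] >= start_idx and any(
            out_arg == var_grad_name for out_arg in x[1].output_arg_names),
        enumerate(block_ops)))

print([(idx, op.type) for idx, op in finded_ops])  # [(3, 'matmul_grad')]
```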
@@ -367,7 +368,10 @@
         if targets and self._params:
             backward.gradients(targets=targets, inputs=[])
-            self.prepare_gradient_aggregation(main_program, program)
+            start_idx = len(
+                main_program.block(0).ops) + 2 * len(self._outputs.tolist())
+            self.prepare_gradient_aggregation(start_idx, main_program, program)
 
         return program
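For context, `start_idx` marks where the backward section begins in the freshly built `program`: every op of the forward `main_program` plus two additional ops per program output (this diff does not show which two ops per output `backward.gradients` contributes, so that detail is left unspecified here). A toy calculation with made-up sizes:

```python
# Made-up sizes, purely to make the start_idx arithmetic concrete.
num_forward_ops = 120   # stands in for len(main_program.block(0).ops)
num_outputs = 3         # stands in for len(self._outputs.tolist())

start_idx = num_forward_ops + 2 * num_outputs
assert start_idx == 126  # ops at index >= 126 are treated as backward ops
```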