未验证 提交 5915d3aa 编写于 作者: Y yaoxuefeng 提交者: GitHub

fix add grad bug test=develop (#22924) (#23024)

上级 e3b28d5b
...@@ -189,6 +189,12 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -189,6 +189,12 @@ class DistributedAdam(DistributedOptimizerImplBase):
sparse_table_index = 0 sparse_table_index = 0
for loss in losses: for loss in losses:
prog_id = str(id(loss.block.program)) prog_id = str(id(loss.block.program))
# param_grads of program
params_grads = sorted(
fluid.backward.append_backward(loss, parameter_list,
no_grad_set),
key=lambda x: x[0].name)
if prog_id not in program_id_set: if prog_id not in program_id_set:
program_id_set.add(prog_id) program_id_set.add(prog_id)
sparse_table = self._find_multi_distributed_lookup_table([loss]) sparse_table = self._find_multi_distributed_lookup_table([loss])
...@@ -215,11 +221,6 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -215,11 +221,6 @@ class DistributedAdam(DistributedOptimizerImplBase):
loss.block.program, sparse_table) loss.block.program, sparse_table)
prog_id_to_sparse_grads[prog_id] = grads_dict prog_id_to_sparse_grads[prog_id] = grads_dict
# param_grads of program
params_grads = sorted(
fluid.backward.append_backward(loss, parameter_list,
no_grad_set),
key=lambda x: x[0].name)
if prog_id not in prog_id_to_param_grads: if prog_id not in prog_id_to_param_grads:
prog_id_to_param_grads[prog_id] = [] prog_id_to_param_grads[prog_id] = []
prog_id_to_param_grads[prog_id].append(params_grads) prog_id_to_param_grads[prog_id].append(params_grads)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册