From 75d247b72c928e1ee68bfb934184bf9a4596df57 Mon Sep 17 00:00:00 2001 From: WangXi Date: Mon, 5 Jul 2021 16:08:26 +0800 Subject: [PATCH] optimize grad add device (#33946) --- python/paddle/fluid/backward.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 9ce5f851846..0b3efefd28e 100755 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -474,11 +474,13 @@ def _addup_repetitive_outputs_(op_descs, block_idx): continue if len(renamed_vars[var_name]) > 1: if len(renamed_vars[var_name]) > _MAX_ADD_NUM_: - _accumulate_gradients_by_sum_op_( - var_name, renamed_vars, pending_sum_ops, idx, op_device) + _accumulate_gradients_by_sum_op_(var_name, renamed_vars, + pending_sum_ops, idx, + var_device[var_name]) else: - _accumulate_gradients_by_add_ops_( - var_name, renamed_vars, pending_sum_ops, idx, op_device) + _accumulate_gradients_by_add_ops_(var_name, renamed_vars, + pending_sum_ops, idx, + var_device[var_name]) for param_idx, param_name in enumerate(op_desc.output_names()): arg_names = op_desc.output(param_name) @@ -529,7 +531,7 @@ def _addup_repetitive_outputs_(op_descs, block_idx): arg_names[arg_idx] = new_name op_desc.set_output(param_name, arg_names) renamed_vars[var_name].append(new_name) - # record the latest device, for shared param + # record the latest device var_device[var_name] = op_device for var_name, inputs in six.iteritems(renamed_vars): -- GitLab