diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 9ce5f851846e8f6798c5fb009be60f2fb3119b11..0b3efefd28edc79941a5068f38ffbe7addd31519 100755 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -474,11 +474,13 @@ def _addup_repetitive_outputs_(op_descs, block_idx): continue if len(renamed_vars[var_name]) > 1: if len(renamed_vars[var_name]) > _MAX_ADD_NUM_: - _accumulate_gradients_by_sum_op_( - var_name, renamed_vars, pending_sum_ops, idx, op_device) + _accumulate_gradients_by_sum_op_(var_name, renamed_vars, + pending_sum_ops, idx, + var_device[var_name]) else: - _accumulate_gradients_by_add_ops_( - var_name, renamed_vars, pending_sum_ops, idx, op_device) + _accumulate_gradients_by_add_ops_(var_name, renamed_vars, + pending_sum_ops, idx, + var_device[var_name]) for param_idx, param_name in enumerate(op_desc.output_names()): arg_names = op_desc.output(param_name) @@ -529,7 +531,7 @@ def _addup_repetitive_outputs_(op_descs, block_idx): arg_names[arg_idx] = new_name op_desc.set_output(param_name, arg_names) renamed_vars[var_name].append(new_name) - # record the latest device, for shared param + # record the latest device var_device[var_name] = op_device for var_name, inputs in six.iteritems(renamed_vars):