diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 708167a0273996fbb67eddec711ccff2aca5e759..9ce5f851846e8f6798c5fb009be60f2fb3119b11 100755 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -462,6 +462,7 @@ def _addup_repetitive_outputs_(op_descs, block_idx): var_rename_count = collections.defaultdict(int) renamed_vars = collections.defaultdict(list) renamed_var_start_idx = collections.defaultdict(list) + var_device = collections.defaultdict(str) for idx, op_desc in enumerate(op_descs): op_device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName( ) @@ -528,16 +529,19 @@ def _addup_repetitive_outputs_(op_descs, block_idx): arg_names[arg_idx] = new_name op_desc.set_output(param_name, arg_names) renamed_vars[var_name].append(new_name) + # record the latest device, for shared param + var_device[var_name] = op_device for var_name, inputs in six.iteritems(renamed_vars): if len(renamed_vars[var_name]) > 1: if len(renamed_vars[var_name]) > _MAX_ADD_NUM_: - _accumulate_gradients_by_sum_op_(var_name, renamed_vars, - pending_sum_ops, len(op_descs)) + _accumulate_gradients_by_sum_op_( + var_name, renamed_vars, pending_sum_ops, + len(op_descs), var_device[var_name]) else: - _accumulate_gradients_by_add_ops_(var_name, renamed_vars, - pending_sum_ops, - len(op_descs)) + _accumulate_gradients_by_add_ops_( + var_name, renamed_vars, pending_sum_ops, + len(op_descs), var_device[var_name]) # sum_op descs are sorted according to their insert position for key, value in collections.OrderedDict(