From 9901f69677e9b85f1a5b8f6ac97ea1f3e2887375 Mon Sep 17 00:00:00 2001 From: mapingshuo Date: Mon, 23 Sep 2019 19:21:43 +0800 Subject: [PATCH] Forward recompute3 (#19913) * add recompute based checkpoints methods for large batch training test=develop * add append_backward_with_forward_recomputation test=develop * refine optimizer test=develop * update backward and optimizer test=develop * make Variable usable test=develop * add recompute code * refine optimizer test=develop * refine addup _append_backward_ops_with_checkpoints_ 1) for recompute part, just cache the grad_op_desc without appending to block 2) before appending grad_op_desc to backward part, addup_repetitive_vars, remove unused branch test=develop * make method private * add recompute strategy into DistributedStrategy test=develop * checkpoint version3 test=develop * remove some print information test=develop * remove unused sumop test=develop * try to fix recompute with graph building modules * add input names to vars should be held * add memory debug tool * backup backward * Fix bugs * add backward desc for op not in any segments * add exception info for sub_block test=develop * modify code style test=develop * modify code style test=develop * remove print functions test=develop * add API spec test=develop test=document_preview * make Recompute a child class of Optimizer test=develop test=document_preview * add API spec test=develop test=document_preview * modify API spec test=develop test=document_preview * add document for Recompute test=develop test=document_preview * change API doc of Rcompute test=develop test=document_preview * code cleaning test=develop test=document_preview * modify API spec * fix bugs when segments hold no element * add testcase for Recompute Optimizer test=develop test=document_preview * add test for apply_gradient, and code cleaning test=develop test=document_preview * add test case for load function * enable CI test=develop test=document * add test case test=develop test=document_preview * add sample code for 4 function of recompute optimizer test=develop test=document_preview --- paddle/fluid/API.spec | 10 +- python/paddle/fluid/backward.py | 415 ++++++++++++++++-- .../incubate/fleet/collective/__init__.py | 14 + python/paddle/fluid/optimizer.py | 296 ++++++++++++- .../fluid/tests/unittests/test_optimizer.py | 150 +++++++ 5 files changed, 850 insertions(+), 35 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 75c68a8e718..ae79dc30e87 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -1012,7 +1012,15 @@ paddle.fluid.optimizer.PipelineOptimizer.minimize (ArgSpec(args=['self', 'loss', paddle.fluid.optimizer.LookaheadOptimizer ('paddle.fluid.optimizer.LookaheadOptimizer', ('document', 'c291cadfa7452c7bf58b9e2f900a3511')) paddle.fluid.optimizer.LookaheadOptimizer.__init__ (ArgSpec(args=['self', 'inner_optimizer', 'alpha', 'k'], varargs=None, keywords=None, defaults=(0.5, 5)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.LookaheadOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.backward.append_backward (ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '52488008103886c793843a3828bacd5e')) +paddle.fluid.optimizer.RecomputeOptimizer ('paddle.fluid.optimizer.RecomputeOptimizer', ('document', 
'05769ba1182270f808f85488a50c8caa')) +paddle.fluid.optimizer.RecomputeOptimizer.__init__ (ArgSpec(args=['self', 'optimizer'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) +paddle.fluid.optimizer.RecomputeOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '7838e157ec5ff4f835f814adf3a2b9cc')) +paddle.fluid.optimizer.RecomputeOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'ec8dfa14fcd958d7c196f3d1a0ce6fa7')) +paddle.fluid.optimizer.RecomputeOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks', 'checkpoints'], varargs=None, keywords=None, defaults=(None, None, None, None, None)), ('document', 'a26b3dbb0f63ee81d847d92e9fb942dc')) +paddle.fluid.optimizer.RecomputeOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) +paddle.fluid.optimizer.RecomputeOptimizer.load (ArgSpec(args=['self', 'stat_dict'], varargs=None, keywords=None, defaults=None), ('document', '7b2b8ae72011bc4decb67e97623f2c56')) +paddle.fluid.optimizer.RecomputeOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'grad_clip'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) +paddle.fluid.backward.append_backward (ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks', 'checkpoints'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', '52488008103886c793843a3828bacd5e')) paddle.fluid.backward.gradients (ArgSpec(args=['targets', 'inputs', 'target_gradients', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None)), ('document', 'e2097e1e0ed84ae44951437bfe269a1b')) paddle.fluid.regularizer.L1DecayRegularizer ('paddle.fluid.regularizer.L1DecayRegularizer', ('document', '34603757e70974d2fcc730643b382925')) paddle.fluid.regularizer.L1DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 6ae36555d77..07d7c9d19df 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -22,7 +22,161 @@ import six from .. import compat as cpt from . 
import unique_name -__all__ = ['append_backward', 'gradients'] +__all__ = [ + 'append_backward', + 'gradients', +] + + +class ProgramStats(object): + def __init__(self, block, ops): + self.block = block + self.ops = ops + self.op_deps = {} # op-> in_ops, out_ops + self.var_op_deps = {} # var as input op, var as output op + + def get_input_nodes(self): + input_names = [] + for name in self.var_op_deps: + if len(self.var_op_deps[name]["var_as_output_ops"]) <= 0 and \ + len(self.var_op_deps[name]["var_as_input_ops"]) > 0: + if self.block.var(name).persistable: + continue + input_names.append(name) + for op in self.ops: + if op.desc.type() == "read": + input_names.extend(op.desc.output_arg_names()) + return input_names + + def get_reserved_vars(self): + var_name = [] + for op in self.ops: + if op.desc.type() == "dropout": + var_name.extend(op.desc.output_arg_names()) + return var_name + + def get_out_of_subgraph_vars(self, begin_op_idx, end_op_idx): + var_name = [] + for i in range(begin_op_idx, end_op_idx, 1): + for name in self.ops[i].desc.output_arg_names(): + if name in self.var_op_deps: + for idx in self.var_op_deps[name]["var_as_input_ops"]: + if idx >= end_op_idx: + var_name.append(name) + return var_name + + def is_subgraph(self, var_group1, var_group2): + # should traverse from var_group1 to var_group2 + # max op idx in var_group2 + # min op idx in var_group1 + min_op_idx = len(self.ops) + max_op_idx = -1 + for name in var_group1: + if name not in self.var_op_deps: + return False, min_op_idx, max_op_idx + for name in var_group2: + if name not in self.var_op_deps: + return False, min_op_idx, max_op_idx + for name in var_group1: + op_idx = self.var_op_deps[name]["var_as_input_ops"] + for idx in op_idx: + min_op_idx = min(min_op_idx, idx) + for name in var_group2: + op_idx = self.var_op_deps[name]["var_as_output_ops"] + for idx in op_idx: + max_op_idx = max(max_op_idx, idx) + if min_op_idx >= max_op_idx: + return False, min_op_idx, max_op_idx + return True, min_op_idx, max_op_idx + + def build_stats(self): + for i, op in enumerate(self.ops): + self.op_deps[i] = {"in_ops": [], "out_ops": []} + for j, name in enumerate(op.desc.input_arg_names()): + if name in self.var_op_deps: + self.op_deps[i]["in_ops"].extend(self.var_op_deps[name][ + "var_as_output_ops"]) + for j, name in enumerate(op.desc.input_arg_names()): + if name in self.var_op_deps: + self.var_op_deps[name]["var_as_input_ops"].extend([i]) + else: + self.var_op_deps[name] = {} + self.var_op_deps[name]["var_as_input_ops"] = [i] + self.var_op_deps[name]["var_as_output_ops"] = [] + + for j, name in enumerate(op.desc.output_arg_names()): + if name in self.var_op_deps: + self.var_op_deps[name]["var_as_output_ops"].extend([i]) + else: + self.var_op_deps[name] = {} + self.var_op_deps[name]["var_as_input_ops"] = [] + self.var_op_deps[name]["var_as_output_ops"] = [i] + + for op_idx in self.op_deps[i]["in_ops"]: + self.op_deps[op_idx]["out_ops"].extend([i]) + + +def _pretty_op_desc_(op_desc, prefix): + out_s = "%s\tname:[%s]\n%s \tinputs:[%s]\n%s \toutputs:[%s]" % \ + (prefix + "_op", str(op_desc.type()), prefix + "_input", " ".join(op_desc.input_arg_names()), + prefix + "_output", " ".join(op_desc.output_arg_names())) + return out_s + + +def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars): + if len(descs) == 0: + return [] + result_descs = [] + op_role_attr_name = \ + core.op_proto_and_checker_maker.kOpRoleAttrName() + backward = core.op_proto_and_checker_maker.OpRole.Backward + for desc in descs: + if isinstance(desc, 
framework.Operator): + desc = desc.desc + if isinstance(desc, tuple): + desc = desc[0] + is_needed = False + for name in desc.output_arg_names(): + if main_block.has_var(name) and main_block.var(name).persistable: + continue + if name not in in_memory_vars: + is_needed = True + if is_needed: + new_op_desc = block.desc.append_op() + new_op_desc.copy_from(desc) + new_op_desc._set_attr(op_role_attr_name, backward) + result_descs.append(new_op_desc) + return result_descs + + +def _add_descs_to_block(descs, block): + if len(descs) == 0: + return [] + result_descs = [] + op_role_attr_name = \ + core.op_proto_and_checker_maker.kOpRoleAttrName() + backward = core.op_proto_and_checker_maker.OpRole.Backward + for desc in descs: + if isinstance(desc, framework.Operator): + desc = desc.desc + if isinstance(desc, tuple): + desc = desc[0] + new_op_desc = block.desc.append_op() + new_op_desc.copy_from(desc) + new_op_desc._set_attr(op_role_attr_name, backward) + result_descs.append(new_op_desc) + return result_descs + + +def _find_loss_op_(loss): + for op in reversed(loss.block.ops): + assert isinstance(op, framework.Operator) + if len(op.output_arg_names) == 1 and op.output_arg_names[ + 0] == loss.name: + loss.op = op + break + if loss.op is None: + raise ValueError("loss.op is None. Should not happend") def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None): @@ -74,6 +228,20 @@ def _create_op_desc_(op_type, inputs, outputs, attrs): return op_desc +def _create_loss_op_desc_(loss): + op_desc = _create_op_desc_( + "fill_constant", {}, {"Out": [_append_grad_suffix_(loss.name)]}, { + "shape": [1], + "value": 1.0, + "dtype": loss.dtype, + "force_cpu": False, + core.op_proto_and_checker_maker.kOpRoleAttrName(): + int(core.op_proto_and_checker_maker.OpRole.Backward) | + int(core.op_proto_and_checker_maker.OpRole.Loss), + }) + return op_desc + + def _infer_var_data_type_(grad_var_name, block): """ Infer the data type of given grad variable @@ -115,7 +283,7 @@ def _some_in_set_(cands, s): def _strip_grad_suffix_(name): """ - Strip the grad suffix from the given varibale name + Strip the grad suffix from the given variable name e.g. 
x@GRAD ==> x y@GRAD@RENAME@1 ==> y """ @@ -145,6 +313,8 @@ def _addup_repetitive_outputs_(op_descs): renamed_var_start_idx = collections.defaultdict(list) for idx, op_desc in enumerate(op_descs): for var_name in op_desc.input_arg_names(): + if "@GRAD" not in var_name: + continue if len(renamed_vars[var_name]) > 1: pending_sum_ops.append((_create_op_desc_( "sum", {"X": renamed_vars[var_name]}, {"Out": [var_name]}, @@ -153,6 +323,10 @@ def _addup_repetitive_outputs_(op_descs): for param_idx, param_name in enumerate(op_desc.output_names()): arg_names = op_desc.output(param_name) for arg_idx, var_name in enumerate(arg_names): + if "@GRAD" not in var_name: + continue + #if "@RENAME@" in var_name: + # continue if var_name == core.empty_var_name( ) or var_name in op_desc.input_arg_names(): # empty variable or inplace op @@ -237,8 +411,11 @@ def _remove_no_grad_branch_(op_descs, no_grad_set): to_insert = [] for idx, op_desc in enumerate(op_descs): for arg in op_desc.input_arg_names(): + # arg is a gradient var name and arg should not have gradient if core.grad_var_suffix() in arg and arg in no_grad_set: x_in = _strip_grad_suffix_(arg) + # the reason should be: arg can be input of another grad op + # and the op is a not-to-remove op to_insert.append((_create_op_desc_( "fill_zeros_like", {"X": [x_in]}, {"Out": [arg]}, {}), idx)) @@ -375,6 +552,170 @@ def serialize_op_decs(op_desc): return proto.__str__() +def _append_backward_ops_with_checkpoints_( + block, ops, target_block, no_grad_dict, grad_to_var, checkpoints): + + checkpoints_name = [x.name for x in checkpoints] + """ + Create grad ops with forward ops, and insert them into given block + + Args: + block(Block): the block where forward ops are + ops(Op): the forward operators whose forward recomputation backward ops need to be added + target_block(Block): the block which is going to hold new generated grad ops + no_grad_dict(dict): + key(int) block index + val(str): corresponding forward variable name + checkpoints: variables that a user defined as checkpoint for forward recomputation + + Algorithms: + 1) go through all forward ops and induct all checkpoint vars + a. input variables can be deduced from forward program + b. input variables are checkpoints + c. variables that are used across segments will be held in memory + 2) find ops between checkpoints, i.e. recompute_segments + 3) go through each recompute_segments, add backward ops with forward recomputation + a. add ops in current recompute_segment as forward recomputation ops + b. rename all non-checkpoint variables in recomputation ops + c. add sum_op to merge gradient if needed + d. 
add backward ops of current recomputation ops + 4) remove no grad branch as it is in _remove_no_grad_branch_ + 5) Note1: all appended ops' OpRole are Backward + 6) Note2: variables that are used across segments will be held in memory + 7) Note3: all variables with new name should be returned so that _append_backward_vars_ can be called + 8) Note4: current forward recomputation backpropagation does not handle programs with subblock + """ + local_block = block.program._create_block() + buffer_block = block.program._create_block() + + program_stat = ProgramStats(block, ops) + program_stat.build_stats() + segments = [] + + if len(checkpoints) == 1: + # only one checkpoint + max_op_idx = -1 + var_group = [checkpoints_name[0]] + for name in var_group: + if name not in program_stat.var_op_deps: + break + op_idx = program_stat.var_op_deps[name]["var_as_output_ops"] + for idx in op_idx: + max_op_idx = max(max_op_idx, idx) + if max_op_idx > 0: + segments.append([0, max_op_idx + 1]) + else: + start_idx = 0 + while True: + if start_idx >= len(checkpoints_name) - 1: + break + flag, min_idx, max_idx = program_stat.is_subgraph( + [checkpoints_name[start_idx]], + [checkpoints_name[start_idx + 1]]) + if flag: + segments.append([min_idx, max_idx + 1]) + start_idx += 1 + + checkpoints_name = list(set(checkpoints_name)) + + if segments != [] and segments[0][0] != 0: + recompute_segments = [[0, segments[0][0]]] + segments + else: + recompute_segments = segments + vars_should_be_hold = [] + for segment in recompute_segments: + vars_should_be_hold.extend( + program_stat.get_out_of_subgraph_vars(segment[0], segment[1])) + vars_should_be_hold.extend(program_stat.get_reserved_vars()) + vars_should_be_hold.extend(program_stat.get_input_nodes()) + vars_should_be_hold = list(set(vars_should_be_hold)) + + # find variables that can not be deleted + grad_should_be_hold = [x + "@GRAD" for x in vars_should_be_hold] + vars_should_be_hold.extend(grad_should_be_hold) + + grad_op_descs = [] + var_name_dict = {} + + vars_in_memory = vars_should_be_hold + checkpoints_name + + max_calculated_op_position = len(ops) + if recompute_segments == []: + gap_ops = ops[0:max_calculated_op_position] + for op in reversed(gap_ops): + if op.has_attr("sub_block"): + raise Exception("Recompute don't support ops with sub_block" + "invoke op: %s" % + _pretty_op_desc_(op.desc, "with_sub_block")) + grad_op_desc, op_grad_to_var = core.get_grad_op_desc( + op.desc, cpt.to_text(no_grad_dict[block.idx]), []) + added_descs = _add_descs_to_block(grad_op_desc, local_block) + grad_op_descs.extend(added_descs) + grad_to_var.update(op_grad_to_var) + + for i, segment in enumerate(recompute_segments[::-1]): + # add grad op for ops not in any segments + gap_ops = ops[segment[1]:max_calculated_op_position] + max_calculated_op_position = segment[0] + for op in reversed(gap_ops): + if op.has_attr("sub_block"): + raise Exception("Recompute don't support ops with sub_block" + "invoke op: %s" % + _pretty_op_desc_(op.desc, "with_sub_block")) + grad_op_desc, op_grad_to_var = core.get_grad_op_desc( + op.desc, cpt.to_text(no_grad_dict[block.idx]), []) + added_descs = _add_descs_to_block(grad_op_desc, local_block) + grad_op_descs.extend(added_descs) + grad_to_var.update(op_grad_to_var) + + ff_ops = ops[segment[0]:segment[1]] + var_suffix = ".subprog_%d" % i + + for op in ff_ops: + if op.has_attr("sub_block"): + raise Exception("Recompute don't support ops with sub_block" + "invoke op: %s" % + _pretty_op_desc_(op.desc, "with_sub_block")) + input_and_output_names = [] + 
input_and_output_names.extend(op.desc.input_arg_names()) + input_and_output_names.extend(op.desc.output_arg_names()) + for name in input_and_output_names: + if block.var(name).persistable or name in checkpoints_name: + continue + if name in vars_should_be_hold: + continue + if name not in var_name_dict: + var_name_dict[name] = name + var_suffix + buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block, + vars_in_memory) + added_descs = _add_descs_to_block(ff_ops, local_block) + + # rename variable names in added_descs + for key in var_name_dict: + _rename_arg_(buffer_descs, key, var_name_dict[key]) + + # added_descs should be in grad_op_descs because it is backward op desc + grad_op_descs.extend(buffer_descs) + + #for op_desc in reversed(buffer_descs): + for op_desc in reversed(added_descs): + + grad_op_desc, op_grad_to_var = core.get_grad_op_desc( + op_desc, cpt.to_text(no_grad_dict[block.idx]), []) + + for key in var_name_dict: + _rename_arg_(grad_op_desc, key, var_name_dict[key]) + + grad_op_descs.extend(grad_op_desc) + grad_to_var.update(op_grad_to_var) + + grad_op_descs = _addup_repetitive_outputs_(grad_op_descs) + grad_op_descs = _remove_no_grad_branch_(grad_op_descs, + no_grad_dict[block.idx]) + added_descs = _add_descs_to_block(grad_op_descs, target_block) + return program_stat, checkpoints_name, vars_should_be_hold, recompute_segments + + def _append_backward_ops_(block, ops, target_block, @@ -459,12 +800,19 @@ def _append_backward_ops_(block, grad_op_descs.extend(grad_op_desc) grad_to_var.update(op_grad_to_var) + # add grad_op_desc by reversed ops + + # sum parameter's gradients' var given multiple var gradient grad_op_descs = _addup_repetitive_outputs_(grad_op_descs) + # if all outputs of the grad op are in no_grad_set, then just remove and fill zero + # if all inputs of the grad op are in no_grad_set, just remove this op grad_op_descs = _remove_no_grad_branch_(grad_op_descs, no_grad_dict[block.idx]) + # remove some backward ops not_need_ops = _find_not_need_ops(grad_op_descs, ops, input_grad_names_set) + grad_op_descs = [ op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops ] @@ -530,6 +878,8 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map): op_desc._rename_input(name, var_map[name]) for name in op_desc.output_arg_names(): + if "@GRAD" not in name: + continue if block.desc.find_var(name.encode("ascii")): new_name = unique_name.generate(name) op_desc._rename_output(name, new_name) @@ -555,8 +905,11 @@ def _get_stop_gradients_(program): return no_grad_dict -def append_backward(loss, parameter_list=None, no_grad_set=None, - callbacks=None): +def append_backward(loss, + parameter_list=None, + no_grad_set=None, + callbacks=None, + checkpoints=None): """ Append backward part to main_program. @@ -629,14 +982,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, if loss.op is None: # the loss is from a cloned program. Find loss op manually. - for op in reversed(loss.block.ops): - assert isinstance(op, framework.Operator) - if len(op.output_arg_names) == 1 and op.output_arg_names[ - 0] == loss.name: - loss.op = op - break - if loss.op is None: - raise ValueError("loss.op is None. 
Should not happend") + _find_loss_op_(loss) loss.op._set_attr(core.op_proto_and_checker_maker.kOpRoleAttrName(), int(core.op_proto_and_checker_maker.OpRole.Forward) | @@ -661,19 +1007,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, current_block_idx = program.current_block_idx grad_to_var = dict() - op_desc = _create_op_desc_( - "fill_constant", - {}, - {"Out": [_append_grad_suffix_(loss.name)]}, - { - "shape": [1], # TODO(panyx0718): This can be loss.shape. - "value": 1.0, - "dtype": loss.dtype, - "force_cpu": False, - core.op_proto_and_checker_maker.kOpRoleAttrName(): - int(core.op_proto_and_checker_maker.OpRole.Backward) | - int(core.op_proto_and_checker_maker.OpRole.Loss), - }) + op_desc = _create_loss_op_desc_(loss) root_block.desc.append_op().copy_from(op_desc) block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0])) @@ -689,14 +1023,29 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, if program._appending_grad_times > 1: input_grad_names_set = set([_append_grad_suffix_(loss.name)]) - _append_backward_ops_( - root_block, - op_path, - root_block, - no_grad_dict, - grad_to_var, - callbacks, - input_grad_names_set=input_grad_names_set) + + if checkpoints != None and \ + isinstance(checkpoints, list) and \ + len(checkpoints) > 0: + program_stat, checkpoint_names, \ + vars_should_be_hold, \ + recompute_segments = \ + _append_backward_ops_with_checkpoints_( + root_block, + op_path, + root_block, + no_grad_dict, + grad_to_var, + checkpoints) + else: + _append_backward_ops_( + root_block, + op_path, + root_block, + no_grad_dict, + grad_to_var, + callbacks, + input_grad_names_set=input_grad_names_set) # Because calc_gradient may be called multiple times, # we need rename the internal gradient variables so that they have diff --git a/python/paddle/fluid/incubate/fleet/collective/__init__.py b/python/paddle/fluid/incubate/fleet/collective/__init__.py index 32a45c2dc9e..4f939deac66 100644 --- a/python/paddle/fluid/incubate/fleet/collective/__init__.py +++ b/python/paddle/fluid/incubate/fleet/collective/__init__.py @@ -105,6 +105,8 @@ class DistributedStrategy(fluid.BuildStrategy): self.mode = "nccl2" # or collective self.collective_mode = None # local_sgd or grad_allreduce self.nccl_comm_num = 1 + self.forward_recompute = False + self.recompute_checkpoints = [] self.exec_strategy = fluid.ExecutionStrategy() @@ -150,6 +152,11 @@ class CollectiveOptimizer(DistributedOptimizer): def __init__(self, optimizer, strategy=DistributedStrategy()): super(CollectiveOptimizer, self).__init__(optimizer, strategy) + if strategy.forward_recompute: + self.forward_recompute = True + self.recompute_checkpoints = strategy.recompute_checkpoints + else: + self.forward_recompute = False self.print_config = False def backward(self, @@ -347,6 +354,13 @@ class CollectiveOptimizer(DistributedOptimizer): self._check_collective_mode(main_program, self._optimizer, self._strategy) + if self.forward_recompute: + assert (isinstance(self.recompute_checkpoints, list) and + len(self.recompute_checkpoints) > 0) + self._optimizer = \ + fluid.optimizer.RecomputeOptimizer(self._optimizer) + self._optimizer._set_checkpoints(self.recompute_checkpoints) + optimize_ops, param_grads = self._optimizer.minimize( loss, startup_program=startup_program, diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index e33f5c13be5..e0d68eb2d43 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -36,6 +36,7 @@ from paddle.fluid import core 
from paddle.fluid.layers import tensor from functools import reduce from .wrapped_decorator import signature_safe_contextmanager +from .. import compat as cpt __all__ = [ 'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl', @@ -43,7 +44,8 @@ __all__ = [ 'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer', 'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'LarsMomentum', 'LarsMomentumOptimizer', 'DGCMomentumOptimizer', 'LambOptimizer', - 'ExponentialMovingAverage', 'PipelineOptimizer', 'LookaheadOptimizer' + 'ExponentialMovingAverage', 'PipelineOptimizer', 'LookaheadOptimizer', + 'RecomputeOptimizer' ] @@ -2977,6 +2979,298 @@ class PipelineOptimizer(object): } +class RecomputeOptimizer(Optimizer): + """ + Recompute Optimizer Wrapper + + Normally, a training step contains three sub-steps: first, run forward + Operators to calculate the loss; second, run backward Operators to + calculate gradient of the parameters; third, apply optimization method + to update the value of the parameters. + + In the forward computation process, all variables that are needed by + backward computation process will be kept in memory, which occupy a great + amount of memory when the network becomes very deep. + + Recompute split the network to k segments. In each segment, It will + recompute the forward Operators, before running backward operators. It is + very helpful for saving memory. + + The Variables that separate a network to segments are called as checkpoints, + and users should set it manually. The usage is very simple: + + Args: + optimizer (Optimizer): The optimizer that is applied to parameters. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import numpy as np + def gen_data(): + return {"x": np.random.random(size=(32, 32)).astype('float32'), + "y": np.random.randint(2, size=(32, 1)).astype('int64')} + def mlp(input_x, input_y, hid_dim=128, label_dim=2): + print(input_x) + fc_1 = fluid.layers.fc(input=input_x, size=hid_dim) + prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax') + cost = fluid.layers.cross_entropy(input=prediction, label=input_y) + sum_cost = fluid.layers.reduce_mean(cost) + return sum_cost, fc_1, prediction + input_x = fluid.layers.data(name="x", shape=[32], dtype='float32') + input_y = fluid.layers.data(name="y", shape=[1], dtype='int64') + cost, fc_1, pred = mlp(input_x, input_y) + + sgd = fluid.optimizer.Adam(learning_rate=0.01) + sgd = fluid.optimizer.RecomputeOptimizer(sgd) + sgd._set_checkpoints([fc_1, pred]) + sgd.minimize(cost) + + print("Finished optimize") + place = fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + step = 10 + + for i in range(step): + cost_val = exe.run(feed=gen_data(), + program=fluid.default_main_program(), + fetch_list=[cost.name]) + print("step=%d cost=%f" % (i, cost_val[0])) + + """ + + def __init__(self, optimizer): + self._optimizer = optimizer + self._checkpoints = None + + def _set_checkpoints(self, checkpoints): + self._checkpoints = checkpoints + + def load(self, stat_dict): + """ + load function is not supported by Recompute Optimizer for now. + :return: None + + Args: + stat_dict: the dict load by load_persistable method + + Examples: + .. 
code-block:: python + + import paddle.fluid as fluid + import paddle.compat as cpt + + def mlp(input_x, input_y, hid_dim=128, label_dim=2): + fc_1 = fluid.layers.fc(input=input_x, size=hid_dim) + prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax') + cost = fluid.layers.cross_entropy(input=prediction, label=input_y) + sum_cost = fluid.layers.reduce_mean(cost) + return sum_cost, fc_1, prediction + + input_x = fluid.layers.data(name="x", shape=[32], dtype='float32') + input_y = fluid.layers.data(name="y", shape=[1], dtype='int64') + cost, fc_1, pred = mlp(input_x, input_y) + print("Finished FF") + + sgd = fluid.optimizer.Adam(learning_rate=0.01) + sgd = fluid.optimizer.RecomputeOptimizer(sgd) + sgd._set_checkpoints([fc_1, pred]) + try: + stat_dict = {} + sgd.load(stat_dict) + except NotImplementedError as e: + print(cpt.get_exception_message(e)) + """ + raise NotImplementedError( + "load function is not supported by Recompute Optimizer for now") + + def apply_gradients(self, params_grads): + """ + call apply_gradients function of self._optimizer. + + Args: + params_grads (list): list of (param, grad) pair to do optimization. + + Returns: + list: A list of operators appended to the current program. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import paddle.fluid.framework as framework + + def mlp(input_x, input_y, hid_dim=128, label_dim=2): + fc_1 = fluid.layers.fc(input=input_x, size=hid_dim) + prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax') + cost = fluid.layers.cross_entropy(input=prediction, label=input_y) + sum_cost = fluid.layers.reduce_mean(cost) + return sum_cost, fc_1, prediction + + + input_x = fluid.layers.data(name="x", shape=[32], dtype='float32') + input_y = fluid.layers.data(name="y", shape=[1], dtype='int64') + cost, fc_1, pred = mlp(input_x, input_y) + print("Finished FF") + + sgd = fluid.optimizer.Adam(learning_rate=0.01) + sgd = fluid.optimizer.RecomputeOptimizer(sgd) + params_grads = sgd.backward( + cost, + startup_program=None, + parameter_list=None, + no_grad_set=None, + checkpoints=[fc_1, pred]) + + program = cost.block.program + with framework.program_guard(program, None): + optimize_ops = sgd.apply_gradients(params_grads) + + print("Finished apply gradients") + """ + + return self._optimizer.apply_gradients(params_grads=params_grads) + + def backward(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None, + callbacks=None, + checkpoints=None): + """ + call append_backward with checkpoints. + + Args: + loss (Variable): loss variable to run optimizations. + startup_program (Program): startup_program for initializing parameters + in `parameter_list`. + parameter_list (list): list of Variables to update. + no_grad_set (set|None): set of Variables should be ignored. + callbacks (list|None): list of callables to run when appending backward + operator for one parameter. + checkpoints (list): list of Variables as checkpoints + + Examples: + .. 
code-block:: python + + import paddle.fluid as fluid + + def mlp(input_x, input_y, hid_dim=128, label_dim=2): + fc_1 = fluid.layers.fc(input=input_x, size=hid_dim) + prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax') + cost = fluid.layers.cross_entropy(input=prediction, label=input_y) + sum_cost = fluid.layers.reduce_mean(cost) + return sum_cost, fc_1, prediction + + + input_x = fluid.layers.data(name="x", shape=[32], dtype='float32') + input_y = fluid.layers.data(name="y", shape=[1], dtype='int64') + cost, fc_1, pred = mlp(input_x, input_y) + print("Finished FF") + + sgd = fluid.optimizer.Adam(learning_rate=0.01) + sgd = fluid.optimizer.RecomputeOptimizer(sgd) + params_grads = sgd.backward( + cost, + startup_program=None, + parameter_list=None, + no_grad_set=None, + checkpoints=[fc_1, pred]) + print("Finished backward") + """ + + if framework.in_dygraph_mode(): + raise NotImplementedError( + "DyGraph current does not support recompute") + + self._dtype = loss.dtype + program = loss.block.program + with program_guard(program, startup_program): + params_grads = append_backward( + loss, + parameter_list, + no_grad_set, + checkpoints=self._checkpoints) + return params_grads + + def apply_optimize(self, loss, startup_program, params_grads): + """ + call the apply_optimize function of self._optimizer + + Args: + loss (Variable): loss variable to run optimizations. + startup_program (Program): startup_program for initializing parameters + in `parameter_list`. + params_grads (list): list of (param, grad) pair to do optimization. + + Examples: + .. code-block:: python + import paddle.fluid as fluid + + def mlp(input_x, input_y, hid_dim=128, label_dim=2): + fc_1 = fluid.layers.fc(input=input_x, size=hid_dim) + prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax') + cost = fluid.layers.cross_entropy(input=prediction, label=input_y) + sum_cost = fluid.layers.reduce_mean(cost) + return sum_cost, fc_1, prediction + + + input_x = fluid.layers.data(name="x", shape=[32], dtype='float32') + input_y = fluid.layers.data(name="y", shape=[1], dtype='int64') + cost, fc_1, pred = mlp(input_x, input_y) + print("Finished FF") + + sgd = fluid.optimizer.Adam(learning_rate=0.01) + sgd = fluid.optimizer.RecomputeOptimizer(sgd) + params_grads = sgd.backward( + cost, + startup_program=None, + parameter_list=None, + no_grad_set=None, + checkpoints=[fc_1, pred]) + + optimize_ops = sgd.apply_optimize( + cost, startup_program=None, params_grads=params_grads) + + print("Finished apply_optimize") + """ + + return self._optimizer.apply_optimize( + loss, startup_program=startup_program, params_grads=params_grads) + + def minimize(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None, + grad_clip=None): + + assert (isinstance(loss, Variable)), "The loss should be an Variable." 
+ assert (self._checkpoints is not None + ), "You should call _set_checkpoints first" + if framework.in_dygraph_mode(): + raise NotImplementedError( + "DyGraph current does not support recompute") + + params_grads = self.backward( + loss, + startup_program=startup_program, + parameter_list=parameter_list, + no_grad_set=no_grad_set, + checkpoints=self._checkpoints) + + if grad_clip: + # TODO(guru4elephant): should add grad_clip for static graph + pass + + optimize_ops = self.apply_optimize( + loss, startup_program=startup_program, params_grads=params_grads) + + return optimize_ops, params_grads + + class LookaheadOptimizer(object): """ This implements the Lookahead optimizer of the diff --git a/python/paddle/fluid/tests/unittests/test_optimizer.py b/python/paddle/fluid/tests/unittests/test_optimizer.py index fefee65c979..9761698e991 100644 --- a/python/paddle/fluid/tests/unittests/test_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_optimizer.py @@ -18,6 +18,7 @@ import unittest import paddle.fluid.framework as framework import paddle.fluid.optimizer as optimizer +import paddle.compat as cpt from paddle.fluid.backward import append_backward @@ -571,5 +572,154 @@ class TestLookaheadOptimizer(unittest.TestCase): self.assertEqual([op.type for op in opts], ["scale", "sgd"]) +class TestRecomputeOptimizer(unittest.TestCase): + def net(self): + program = framework.Program() + block = program.global_block() + mul_x = block.create_parameter( + dtype="float32", shape=[5, 10], lod_level=0, name="mul.x") + mul_y = block.create_var( + dtype="float32", shape=[10, 8], lod_level=0, name="mul.y") + mul_out = block.create_var( + dtype="float32", shape=[5, 8], lod_level=0, name="mul.out") + b1 = block.create_parameter( + dtype="float32", shape=[5, 8], lod_level=0, name="b1") + b1_out = block.create_var( + dtype="float32", shape=[5, 8], lod_level=0, name="b1_out") + b2 = block.create_parameter( + dtype="float32", shape=[5, 8], lod_level=0, name="b2") + b2_out = block.create_var( + dtype="float32", shape=[5, 8], lod_level=0, name="b2_out") + mean_out = block.create_var( + dtype="float32", shape=[1], lod_level=0, name="mean.out") + block.append_op( + type="mul", + inputs={"X": mul_x, + "Y": mul_y}, + outputs={"Out": mul_out}, + attrs={"x_num_col_dims": 1}) + block.append_op( + type="elementwise_add", + inputs={"X": mul_out, + "Y": b1}, + outputs={"Out": b1_out}) + block.append_op( + type="elementwise_add", + inputs={"X": b1_out, + "Y": b2}, + outputs={"Out": b2_out}) + block.append_op( + type="mean", inputs={"X": b2_out}, outputs={"Out": mean_out}) + + return mul_out, b1_out, b2_out, mean_out + + def test_no_checkpoint(self): + mul_out, b1_out, b2_out, mean_out = self.net() + self.assertEqual(len(mean_out.block.ops), 4) + self.assertEqual([op.type for op in mean_out.block.ops], + ["mul", "elementwise_add", "elementwise_add", "mean"]) + sgd_optimizer = optimizer.SGD(learning_rate=1.0) + recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer) + recompute_optimizer._set_checkpoints([]) + opts, params_grads = recompute_optimizer.minimize(mean_out) + + self.assertEqual(len(mean_out.block.ops), 12) + self.assertEqual([op.type for op in mean_out.block.ops], [ + "mul", "elementwise_add", "elementwise_add", "mean", + "fill_constant", "mean_grad", "elementwise_add_grad", + "elementwise_add_grad", "mul_grad", "sgd", "sgd", "sgd" + ]) + + def test_one_checkpoint(self): + mul_out, b1_out, b2_out, mean_out = self.net() + self.assertEqual(len(mean_out.block.ops), 4) + self.assertEqual([op.type for op in 
mean_out.block.ops], + ["mul", "elementwise_add", "elementwise_add", "mean"]) + sgd_optimizer = optimizer.SGD(learning_rate=1.0) + recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer) + recompute_optimizer._set_checkpoints([b1_out]) + opts, params_grads = recompute_optimizer.minimize(mean_out) + + self.assertEqual(len(mean_out.block.ops), 13) + self.assertEqual([op.type for op in mean_out.block.ops], [ + "mul", "elementwise_add", "elementwise_add", "mean", + "fill_constant", "mean_grad", "elementwise_add_grad", "mul", + "elementwise_add_grad", "mul_grad", "sgd", "sgd", "sgd" + ]) + + def test_multi_checkpoint(self): + mul_out, b1_out, b2_out, mean_out = self.net() + self.assertEqual(len(mean_out.block.ops), 4) + self.assertEqual([op.type for op in mean_out.block.ops], + ["mul", "elementwise_add", "elementwise_add", "mean"]) + sgd_optimizer = optimizer.SGD(learning_rate=1.0) + recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer) + recompute_optimizer._set_checkpoints([mul_out, b2_out]) + opts, params_grads = recompute_optimizer.minimize(mean_out) + + self.assertEqual(len(mean_out.block.ops), 13) + self.assertEqual([op.type for op in mean_out.block.ops], [ + "mul", "elementwise_add", "elementwise_add", "mean", + "fill_constant", "mean_grad", "elementwise_add", + "elementwise_add_grad", "elementwise_add_grad", "mul_grad", "sgd", + "sgd", "sgd" + ]) + + def test_adjacent_checkpoint(self): + mul_out, b1_out, b2_out, mean_out = self.net() + self.assertEqual(len(mean_out.block.ops), 4) + self.assertEqual([op.type for op in mean_out.block.ops], + ["mul", "elementwise_add", "elementwise_add", "mean"]) + sgd_optimizer = optimizer.SGD(learning_rate=1.0) + recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer) + recompute_optimizer._set_checkpoints([mul_out, b1_out]) + opts, params_grads = recompute_optimizer.minimize(mean_out) + + self.assertEqual(len(mean_out.block.ops), 12) + self.assertEqual([op.type for op in mean_out.block.ops], [ + "mul", "elementwise_add", "elementwise_add", "mean", + "fill_constant", "mean_grad", "elementwise_add_grad", + "elementwise_add_grad", "mul_grad", "sgd", "sgd", "sgd" + ]) + + def test_apply_gradients(self): + mul_out, b1_out, b2_out, mean_out = self.net() + sgd_optimizer = optimizer.SGD(learning_rate=1.0) + recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer) + recompute_optimizer._set_checkpoints([b1_out]) + # apply backward + params_grads = recompute_optimizer.backward( + mean_out, + startup_program=None, + parameter_list=None, + no_grad_set=None, + checkpoints=[b1_out]) + + # apply gradient + program = mean_out.block.program + with framework.program_guard(program, None): + optimize_ops = recompute_optimizer.apply_gradients(params_grads) + + self.assertEqual(len(mean_out.block.ops), 13) + self.assertEqual([op.type for op in mean_out.block.ops], [ + "mul", "elementwise_add", "elementwise_add", "mean", + "fill_constant", "mean_grad", "elementwise_add_grad", "mul", + "elementwise_add_grad", "mul_grad", "sgd", "sgd", "sgd" + ]) + + def test_load(self): + mul_out, b1_out, b2_out, mean_out = self.net() + sgd_optimizer = optimizer.SGD(learning_rate=1.0) + recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer) + recompute_optimizer._set_checkpoints([b1_out]) + try: + stat_dict = {} + recompute_optimizer.load(stat_dict) + except NotImplementedError as e: + self.assertEqual( + "load function is not supported by Recompute Optimizer for now", + cpt.get_exception_message(e)) + + if __name__ == '__main__': 
unittest.main() -- GitLab
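
A minimal usage sketch of the distributed path this patch enables: the collective/__init__.py hunk above adds forward_recompute and recompute_checkpoints to fleet's DistributedStrategy, while the in-diff examples only cover RecomputeOptimizer used directly. The snippet below shows how those two strategy fields might drive the recompute path through CollectiveOptimizer. The role-maker/fleet.init setup and the launch environment (paddle.distributed.launch or an equivalent collective launcher) are assumed from the fleet collective API of this release and are not taken from this patch; the network is adapted from the RecomputeOptimizer docstring above.

.. code-block:: python

    import paddle.fluid as fluid
    from paddle.fluid.incubate.fleet.base import role_maker
    from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy

    def mlp(input_x, input_y, hid_dim=128, label_dim=2):
        fc_1 = fluid.layers.fc(input=input_x, size=hid_dim)
        prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax')
        cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
        return fluid.layers.reduce_mean(cost), fc_1, prediction

    input_x = fluid.layers.data(name="x", shape=[32], dtype='float32')
    input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
    cost, fc_1, pred = mlp(input_x, input_y)

    # Collective role maker / init: assumed setup, not part of this patch.
    role = role_maker.PaddleCloudRoleMaker(is_collective=True)
    fleet.init(role)

    # The two strategy fields introduced by this patch.
    strategy = DistributedStrategy()
    strategy.forward_recompute = True
    strategy.recompute_checkpoints = [fc_1, pred]

    optimizer = fluid.optimizer.Adam(learning_rate=0.01)
    # When forward_recompute is enabled, CollectiveOptimizer.minimize wraps the
    # inner optimizer in RecomputeOptimizer and calls _set_checkpoints with the
    # strategy's checkpoint variables (see the collective/__init__.py hunk above).
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    optimizer.minimize(cost)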