Unverified commit 9901f696, authored by mapingshuo, committed by GitHub

Forward recompute3 (#19913)

* add recompute based checkpoints methods for large batch training
test=develop

* add append_backward_with_forward_recomputation
test=develop

* refine optimizer
test=develop

* update backward and optimizer
test=develop

* make Variable usable
test=develop

* add recompute code

* refine optimizer
test=develop

* refine addup _append_backward_ops_with_checkpoints_
1) for recompute part, just cache the grad_op_desc without appending to block
2) before appending grad_op_desc to backward part, addup_repetitive_vars, remove unused branch
test=develop

* make method private

* add recompute strategy into DistributedStrategy
test=develop

* checkpoint version3
test=develop

* remove some print information
test=develop

* remove unused sumop
test=develop

* try to fix recompute with graph building modules

* add input names to vars that should be held

* add memory debug tool

* backup backward

* Fix bugs

* add backward desc for op not in any segments

* add exception info for sub_block

test=develop

* modify code style

test=develop

* modify code style

test=develop

* remove print functions

test=develop

* add API spec

test=develop
test=document_preview

* make Recompute a child class of Optimizer

test=develop
test=document_preview

* add API spec

test=develop
test=document_preview

* modify API spec

test=develop
test=document_preview

* add document for Recompute

test=develop
test=document_preview

* change API doc of Recompute

test=develop
test=document_preview

* code cleaning

test=develop
test=document_preview

* modify API spec

* fix bugs when segments hold no element

* add testcase for Recompute Optimizer

test=develop
test=document_preview

* add test for apply_gradient, and code cleaning

test=develop
test=document_preview

* add test case for load function

* enable CI

test=develop
test=document

* add test case

test=develop
test=document_preview

* add sample code for 4 functions of recompute optimizer

test=develop
test=document_preview
Parent d7251a8e
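For orientation before reading the diff, the snippet below condenses the usage example from the RecomputeOptimizer docstring added in optimizer.py (a minimal single-process sketch; the layer sizes and variable names are illustrative):

    import paddle.fluid as fluid

    input_x = fluid.layers.data(name="x", shape=[32], dtype='float32')
    input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
    fc_1 = fluid.layers.fc(input=input_x, size=128)
    prediction = fluid.layers.fc(input=[fc_1], size=2, act='softmax')
    cost = fluid.layers.reduce_mean(
        fluid.layers.cross_entropy(input=prediction, label=input_y))

    sgd = fluid.optimizer.Adam(learning_rate=0.01)
    sgd = fluid.optimizer.RecomputeOptimizer(sgd)  # wrap any base optimizer
    sgd._set_checkpoints([fc_1, prediction])       # user-chosen checkpoint Variables
    sgd.minimize(cost)                             # backward is built with forward recomputation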
......@@ -1012,7 +1012,15 @@ paddle.fluid.optimizer.PipelineOptimizer.minimize (ArgSpec(args=['self', 'loss',
paddle.fluid.optimizer.LookaheadOptimizer ('paddle.fluid.optimizer.LookaheadOptimizer', ('document', 'c291cadfa7452c7bf58b9e2f900a3511'))
paddle.fluid.optimizer.LookaheadOptimizer.__init__ (ArgSpec(args=['self', 'inner_optimizer', 'alpha', 'k'], varargs=None, keywords=None, defaults=(0.5, 5)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.LookaheadOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.backward.append_backward (ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '52488008103886c793843a3828bacd5e'))
paddle.fluid.optimizer.RecomputeOptimizer ('paddle.fluid.optimizer.RecomputeOptimizer', ('document', '05769ba1182270f808f85488a50c8caa'))
paddle.fluid.optimizer.RecomputeOptimizer.__init__ (ArgSpec(args=['self', 'optimizer'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.RecomputeOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '7838e157ec5ff4f835f814adf3a2b9cc'))
paddle.fluid.optimizer.RecomputeOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'ec8dfa14fcd958d7c196f3d1a0ce6fa7'))
paddle.fluid.optimizer.RecomputeOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks', 'checkpoints'], varargs=None, keywords=None, defaults=(None, None, None, None, None)), ('document', 'a26b3dbb0f63ee81d847d92e9fb942dc'))
paddle.fluid.optimizer.RecomputeOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.RecomputeOptimizer.load (ArgSpec(args=['self', 'stat_dict'], varargs=None, keywords=None, defaults=None), ('document', '7b2b8ae72011bc4decb67e97623f2c56'))
paddle.fluid.optimizer.RecomputeOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'grad_clip'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.backward.append_backward (ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks', 'checkpoints'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', '52488008103886c793843a3828bacd5e'))
paddle.fluid.backward.gradients (ArgSpec(args=['targets', 'inputs', 'target_gradients', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None)), ('document', 'e2097e1e0ed84ae44951437bfe269a1b'))
paddle.fluid.regularizer.L1DecayRegularizer ('paddle.fluid.regularizer.L1DecayRegularizer', ('document', '34603757e70974d2fcc730643b382925'))
paddle.fluid.regularizer.L1DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
......
......@@ -22,7 +22,161 @@ import six
from .. import compat as cpt
from . import unique_name
__all__ = ['append_backward', 'gradients']
__all__ = [
'append_backward',
'gradients',
]
class ProgramStats(object):
def __init__(self, block, ops):
self.block = block
self.ops = ops
        self.op_deps = {}  # op index -> {"in_ops": [...], "out_ops": [...]}
        self.var_op_deps = {}  # var name -> {"var_as_input_ops": [...], "var_as_output_ops": [...]}
def get_input_nodes(self):
input_names = []
for name in self.var_op_deps:
if len(self.var_op_deps[name]["var_as_output_ops"]) <= 0 and \
len(self.var_op_deps[name]["var_as_input_ops"]) > 0:
if self.block.var(name).persistable:
continue
input_names.append(name)
for op in self.ops:
if op.desc.type() == "read":
input_names.extend(op.desc.output_arg_names())
return input_names
def get_reserved_vars(self):
var_name = []
for op in self.ops:
if op.desc.type() == "dropout":
var_name.extend(op.desc.output_arg_names())
return var_name
def get_out_of_subgraph_vars(self, begin_op_idx, end_op_idx):
var_name = []
for i in range(begin_op_idx, end_op_idx, 1):
for name in self.ops[i].desc.output_arg_names():
if name in self.var_op_deps:
for idx in self.var_op_deps[name]["var_as_input_ops"]:
if idx >= end_op_idx:
var_name.append(name)
return var_name
def is_subgraph(self, var_group1, var_group2):
# should traverse from var_group1 to var_group2
# max op idx in var_group2
# min op idx in var_group1
min_op_idx = len(self.ops)
max_op_idx = -1
for name in var_group1:
if name not in self.var_op_deps:
return False, min_op_idx, max_op_idx
for name in var_group2:
if name not in self.var_op_deps:
return False, min_op_idx, max_op_idx
for name in var_group1:
op_idx = self.var_op_deps[name]["var_as_input_ops"]
for idx in op_idx:
min_op_idx = min(min_op_idx, idx)
for name in var_group2:
op_idx = self.var_op_deps[name]["var_as_output_ops"]
for idx in op_idx:
max_op_idx = max(max_op_idx, idx)
if min_op_idx >= max_op_idx:
return False, min_op_idx, max_op_idx
return True, min_op_idx, max_op_idx
def build_stats(self):
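        # scan the forward ops in order: for each op, record which earlier ops
        # produce its inputs (in_ops), and register this op as a reader
        # (var_as_input_ops) / writer (var_as_output_ops) of each of its variables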
for i, op in enumerate(self.ops):
self.op_deps[i] = {"in_ops": [], "out_ops": []}
for j, name in enumerate(op.desc.input_arg_names()):
if name in self.var_op_deps:
self.op_deps[i]["in_ops"].extend(self.var_op_deps[name][
"var_as_output_ops"])
for j, name in enumerate(op.desc.input_arg_names()):
if name in self.var_op_deps:
self.var_op_deps[name]["var_as_input_ops"].extend([i])
else:
self.var_op_deps[name] = {}
self.var_op_deps[name]["var_as_input_ops"] = [i]
self.var_op_deps[name]["var_as_output_ops"] = []
for j, name in enumerate(op.desc.output_arg_names()):
if name in self.var_op_deps:
self.var_op_deps[name]["var_as_output_ops"].extend([i])
else:
self.var_op_deps[name] = {}
self.var_op_deps[name]["var_as_input_ops"] = []
self.var_op_deps[name]["var_as_output_ops"] = [i]
for op_idx in self.op_deps[i]["in_ops"]:
self.op_deps[op_idx]["out_ops"].extend([i])
def _pretty_op_desc_(op_desc, prefix):
out_s = "%s\tname:[%s]\n%s \tinputs:[%s]\n%s \toutputs:[%s]" % \
(prefix + "_op", str(op_desc.type()), prefix + "_input", " ".join(op_desc.input_arg_names()),
prefix + "_output", " ".join(op_desc.output_arg_names()))
return out_s
def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
if len(descs) == 0:
return []
result_descs = []
op_role_attr_name = \
core.op_proto_and_checker_maker.kOpRoleAttrName()
backward = core.op_proto_and_checker_maker.OpRole.Backward
for desc in descs:
if isinstance(desc, framework.Operator):
desc = desc.desc
if isinstance(desc, tuple):
desc = desc[0]
is_needed = False
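        # an op needs to be re-emitted for recomputation only if it produces at
        # least one output that is neither persistable nor already held in memory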
for name in desc.output_arg_names():
if main_block.has_var(name) and main_block.var(name).persistable:
continue
if name not in in_memory_vars:
is_needed = True
if is_needed:
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(desc)
new_op_desc._set_attr(op_role_attr_name, backward)
result_descs.append(new_op_desc)
return result_descs
def _add_descs_to_block(descs, block):
if len(descs) == 0:
return []
result_descs = []
op_role_attr_name = \
core.op_proto_and_checker_maker.kOpRoleAttrName()
backward = core.op_proto_and_checker_maker.OpRole.Backward
for desc in descs:
if isinstance(desc, framework.Operator):
desc = desc.desc
if isinstance(desc, tuple):
desc = desc[0]
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(desc)
new_op_desc._set_attr(op_role_attr_name, backward)
result_descs.append(new_op_desc)
return result_descs
def _find_loss_op_(loss):
for op in reversed(loss.block.ops):
assert isinstance(op, framework.Operator)
if len(op.output_arg_names) == 1 and op.output_arg_names[
0] == loss.name:
loss.op = op
break
if loss.op is None:
        raise ValueError("loss.op is None. Should not happen")
def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None):
......@@ -74,6 +228,20 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
return op_desc
def _create_loss_op_desc_(loss):
op_desc = _create_op_desc_(
"fill_constant", {}, {"Out": [_append_grad_suffix_(loss.name)]}, {
"shape": [1],
"value": 1.0,
"dtype": loss.dtype,
"force_cpu": False,
core.op_proto_and_checker_maker.kOpRoleAttrName():
int(core.op_proto_and_checker_maker.OpRole.Backward) |
int(core.op_proto_and_checker_maker.OpRole.Loss),
})
return op_desc
def _infer_var_data_type_(grad_var_name, block):
"""
Infer the data type of given grad variable
......@@ -115,7 +283,7 @@ def _some_in_set_(cands, s):
def _strip_grad_suffix_(name):
"""
Strip the grad suffix from the given varibale name
Strip the grad suffix from the given variable name
e.g. x@GRAD ==> x
y@GRAD@RENAME@1 ==> y
"""
......@@ -145,6 +313,8 @@ def _addup_repetitive_outputs_(op_descs):
renamed_var_start_idx = collections.defaultdict(list)
for idx, op_desc in enumerate(op_descs):
for var_name in op_desc.input_arg_names():
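            # only gradient variables (ending with @GRAD) take part in gradient
            # accumulation; forward variables re-emitted by recompute are skipped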
if "@GRAD" not in var_name:
continue
if len(renamed_vars[var_name]) > 1:
pending_sum_ops.append((_create_op_desc_(
"sum", {"X": renamed_vars[var_name]}, {"Out": [var_name]},
......@@ -153,6 +323,10 @@ def _addup_repetitive_outputs_(op_descs):
for param_idx, param_name in enumerate(op_desc.output_names()):
arg_names = op_desc.output(param_name)
for arg_idx, var_name in enumerate(arg_names):
if "@GRAD" not in var_name:
continue
#if "@RENAME@" in var_name:
# continue
if var_name == core.empty_var_name(
) or var_name in op_desc.input_arg_names():
# empty variable or inplace op
......@@ -237,8 +411,11 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
to_insert = []
for idx, op_desc in enumerate(op_descs):
for arg in op_desc.input_arg_names():
            # arg is a gradient var name whose gradient should not be computed
if core.grad_var_suffix() in arg and arg in no_grad_set:
x_in = _strip_grad_suffix_(arg)
                # arg is still consumed by this (kept) grad op, so feed it
                # with a zero-filled variable instead of a real gradient
to_insert.append((_create_op_desc_(
"fill_zeros_like", {"X": [x_in]}, {"Out": [arg]}, {}), idx))
......@@ -375,6 +552,170 @@ def serialize_op_decs(op_desc):
return proto.__str__()
def _append_backward_ops_with_checkpoints_(
block, ops, target_block, no_grad_dict, grad_to_var, checkpoints):
    """
    Create grad ops with forward ops, and insert them into given block

    Args:
        block(Block): the block where forward ops are
        ops(Op): the forward operators whose backward ops (built with forward
            recomputation) need to be added
        target_block(Block): the block which is going to hold new generated grad ops
        no_grad_dict(dict):
            key(int): block index
            val(set of str): names of variables whose gradients should not be computed
        checkpoints: variables that the user defined as checkpoints for forward recomputation

    Algorithm:
        1) go through all forward ops and collect all checkpoint vars
            a. input variables can be deduced from the forward program
            b. input variables are checkpoints
            c. variables that are used across segments will be held in memory
        2) find ops between checkpoints, i.e. recompute_segments
        3) go through each recompute_segment, add backward ops with forward recomputation
            a. add ops in the current recompute_segment as forward recomputation ops
            b. rename all non-checkpoint variables in the recomputation ops
            c. add sum_op to merge gradients if needed
            d. add backward ops of the current recomputation ops
        4) remove the no-grad branch as in _remove_no_grad_branch_
        5) Note1: all appended ops' OpRole are Backward
        6) Note2: variables that are used across segments will be held in memory
        7) Note3: all variables with new names should be returned so that _append_backward_vars_ can be called
        8) Note4: current forward recomputation backpropagation does not handle programs with sub-blocks
    """
    checkpoints_name = [x.name for x in checkpoints]
local_block = block.program._create_block()
buffer_block = block.program._create_block()
program_stat = ProgramStats(block, ops)
program_stat.build_stats()
segments = []
if len(checkpoints) == 1:
# only one checkpoint
max_op_idx = -1
var_group = [checkpoints_name[0]]
for name in var_group:
if name not in program_stat.var_op_deps:
break
op_idx = program_stat.var_op_deps[name]["var_as_output_ops"]
for idx in op_idx:
max_op_idx = max(max_op_idx, idx)
if max_op_idx > 0:
segments.append([0, max_op_idx + 1])
else:
start_idx = 0
while True:
if start_idx >= len(checkpoints_name) - 1:
break
flag, min_idx, max_idx = program_stat.is_subgraph(
[checkpoints_name[start_idx]],
[checkpoints_name[start_idx + 1]])
if flag:
segments.append([min_idx, max_idx + 1])
start_idx += 1
checkpoints_name = list(set(checkpoints_name))
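    # if the first checkpoint-bounded segment does not start at op 0, the ops
    # before it form an extra leading segment that will also be recomputed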
if segments != [] and segments[0][0] != 0:
recompute_segments = [[0, segments[0][0]]] + segments
else:
recompute_segments = segments
vars_should_be_hold = []
for segment in recompute_segments:
vars_should_be_hold.extend(
program_stat.get_out_of_subgraph_vars(segment[0], segment[1]))
vars_should_be_hold.extend(program_stat.get_reserved_vars())
vars_should_be_hold.extend(program_stat.get_input_nodes())
vars_should_be_hold = list(set(vars_should_be_hold))
# find variables that can not be deleted
grad_should_be_hold = [x + "@GRAD" for x in vars_should_be_hold]
vars_should_be_hold.extend(grad_should_be_hold)
grad_op_descs = []
var_name_dict = {}
vars_in_memory = vars_should_be_hold + checkpoints_name
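    # held variables (and the checkpoints themselves) keep their original names;
    # ops whose outputs are all already in memory are not re-emitted for recomputation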
max_calculated_op_position = len(ops)
if recompute_segments == []:
gap_ops = ops[0:max_calculated_op_position]
for op in reversed(gap_ops):
if op.has_attr("sub_block"):
                raise Exception("Recompute doesn't support ops with sub_block, "
"invoke op: %s" %
_pretty_op_desc_(op.desc, "with_sub_block"))
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
added_descs = _add_descs_to_block(grad_op_desc, local_block)
grad_op_descs.extend(added_descs)
grad_to_var.update(op_grad_to_var)
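    # walk the segments from last to first: emit plain backward ops for the gap ops
    # after each segment, then re-emit the segment's forward ops (renamed with a
    # ".subprog_<i>" suffix) and append the corresponding grad ops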
for i, segment in enumerate(recompute_segments[::-1]):
# add grad op for ops not in any segments
gap_ops = ops[segment[1]:max_calculated_op_position]
max_calculated_op_position = segment[0]
for op in reversed(gap_ops):
if op.has_attr("sub_block"):
                raise Exception("Recompute doesn't support ops with sub_block, "
"invoke op: %s" %
_pretty_op_desc_(op.desc, "with_sub_block"))
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
added_descs = _add_descs_to_block(grad_op_desc, local_block)
grad_op_descs.extend(added_descs)
grad_to_var.update(op_grad_to_var)
ff_ops = ops[segment[0]:segment[1]]
var_suffix = ".subprog_%d" % i
for op in ff_ops:
if op.has_attr("sub_block"):
                raise Exception("Recompute doesn't support ops with sub_block, "
"invoke op: %s" %
_pretty_op_desc_(op.desc, "with_sub_block"))
input_and_output_names = []
input_and_output_names.extend(op.desc.input_arg_names())
input_and_output_names.extend(op.desc.output_arg_names())
for name in input_and_output_names:
if block.var(name).persistable or name in checkpoints_name:
continue
if name in vars_should_be_hold:
continue
if name not in var_name_dict:
var_name_dict[name] = name + var_suffix
buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block,
vars_in_memory)
added_descs = _add_descs_to_block(ff_ops, local_block)
# rename variable names in added_descs
for key in var_name_dict:
_rename_arg_(buffer_descs, key, var_name_dict[key])
        # the re-emitted forward ops (buffer_descs) are appended to grad_op_descs
        # because they are executed as part of the backward pass
grad_op_descs.extend(buffer_descs)
#for op_desc in reversed(buffer_descs):
for op_desc in reversed(added_descs):
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op_desc, cpt.to_text(no_grad_dict[block.idx]), [])
for key in var_name_dict:
_rename_arg_(grad_op_desc, key, var_name_dict[key])
grad_op_descs.extend(grad_op_desc)
grad_to_var.update(op_grad_to_var)
grad_op_descs = _addup_repetitive_outputs_(grad_op_descs)
grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
no_grad_dict[block.idx])
added_descs = _add_descs_to_block(grad_op_descs, target_block)
return program_stat, checkpoints_name, vars_should_be_hold, recompute_segments
def _append_backward_ops_(block,
ops,
target_block,
......@@ -459,12 +800,19 @@ def _append_backward_ops_(block,
grad_op_descs.extend(grad_op_desc)
grad_to_var.update(op_grad_to_var)
# add grad_op_desc by reversed ops
    # sum a parameter's gradient when multiple gradient variables were produced for it
grad_op_descs = _addup_repetitive_outputs_(grad_op_descs)
    # a grad op is removed if all of its outputs are in no_grad_set, or if all of
    # its gradient inputs are in no_grad_set; grad inputs in no_grad_set that are
    # still needed by kept ops are filled with zeros
grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
no_grad_dict[block.idx])
# remove some backward ops
not_need_ops = _find_not_need_ops(grad_op_descs, ops, input_grad_names_set)
grad_op_descs = [
op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops
]
......@@ -530,6 +878,8 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map):
op_desc._rename_input(name, var_map[name])
for name in op_desc.output_arg_names():
if "@GRAD" not in name:
continue
if block.desc.find_var(name.encode("ascii")):
new_name = unique_name.generate(name)
op_desc._rename_output(name, new_name)
......@@ -555,8 +905,11 @@ def _get_stop_gradients_(program):
return no_grad_dict
def append_backward(loss, parameter_list=None, no_grad_set=None,
callbacks=None):
def append_backward(loss,
parameter_list=None,
no_grad_set=None,
callbacks=None,
checkpoints=None):
"""
Append backward part to main_program.
......@@ -629,14 +982,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
if loss.op is None:
# the loss is from a cloned program. Find loss op manually.
for op in reversed(loss.block.ops):
assert isinstance(op, framework.Operator)
if len(op.output_arg_names) == 1 and op.output_arg_names[
0] == loss.name:
loss.op = op
break
if loss.op is None:
raise ValueError("loss.op is None. Should not happend")
_find_loss_op_(loss)
loss.op._set_attr(core.op_proto_and_checker_maker.kOpRoleAttrName(),
int(core.op_proto_and_checker_maker.OpRole.Forward) |
......@@ -661,19 +1007,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
current_block_idx = program.current_block_idx
grad_to_var = dict()
op_desc = _create_op_desc_(
"fill_constant",
{},
{"Out": [_append_grad_suffix_(loss.name)]},
{
"shape": [1], # TODO(panyx0718): This can be loss.shape.
"value": 1.0,
"dtype": loss.dtype,
"force_cpu": False,
core.op_proto_and_checker_maker.kOpRoleAttrName():
int(core.op_proto_and_checker_maker.OpRole.Backward) |
int(core.op_proto_and_checker_maker.OpRole.Loss),
})
op_desc = _create_loss_op_desc_(loss)
root_block.desc.append_op().copy_from(op_desc)
block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0]))
......@@ -689,6 +1023,21 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
if program._appending_grad_times > 1:
input_grad_names_set = set([_append_grad_suffix_(loss.name)])
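    # when user-specified checkpoint Variables are given, build the backward pass
    # with forward recomputation; otherwise fall back to the regular
    # _append_backward_ops_ path below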
    if checkpoints is not None and \
isinstance(checkpoints, list) and \
len(checkpoints) > 0:
program_stat, checkpoint_names, \
vars_should_be_hold, \
recompute_segments = \
_append_backward_ops_with_checkpoints_(
root_block,
op_path,
root_block,
no_grad_dict,
grad_to_var,
checkpoints)
else:
_append_backward_ops_(
root_block,
op_path,
......
......@@ -105,6 +105,8 @@ class DistributedStrategy(fluid.BuildStrategy):
self.mode = "nccl2" # or collective
self.collective_mode = None # local_sgd or grad_allreduce
self.nccl_comm_num = 1
self.forward_recompute = False
self.recompute_checkpoints = []
self.exec_strategy = fluid.ExecutionStrategy()
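For illustration, a rough sketch of how the new strategy fields are consumed by CollectiveOptimizer (assuming both classes are importable from paddle.fluid.incubate.fleet.collective; actually running minimize also needs an initialized distributed/fleet environment, which is omitted here):

    import paddle.fluid as fluid
    from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, DistributedStrategy

    x = fluid.layers.data(name="x", shape=[32], dtype='float32')
    y = fluid.layers.data(name="y", shape=[1], dtype='int64')
    fc_1 = fluid.layers.fc(input=x, size=128)
    prediction = fluid.layers.fc(input=[fc_1], size=2, act='softmax')
    cost = fluid.layers.reduce_mean(
        fluid.layers.cross_entropy(input=prediction, label=y))

    strategy = DistributedStrategy()
    strategy.forward_recompute = True                    # enable forward recomputation
    strategy.recompute_checkpoints = [fc_1, prediction]  # checkpoint Variables

    dist_optimizer = CollectiveOptimizer(
        fluid.optimizer.Adam(learning_rate=0.01), strategy=strategy)
    # dist_optimizer.minimize(cost) would wrap the inner Adam with RecomputeOptimizer
    # and set the checkpoints before building the backward pass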
......@@ -150,6 +152,11 @@ class CollectiveOptimizer(DistributedOptimizer):
def __init__(self, optimizer, strategy=DistributedStrategy()):
super(CollectiveOptimizer, self).__init__(optimizer, strategy)
if strategy.forward_recompute:
self.forward_recompute = True
self.recompute_checkpoints = strategy.recompute_checkpoints
else:
self.forward_recompute = False
self.print_config = False
def backward(self,
......@@ -347,6 +354,13 @@ class CollectiveOptimizer(DistributedOptimizer):
self._check_collective_mode(main_program, self._optimizer,
self._strategy)
if self.forward_recompute:
assert (isinstance(self.recompute_checkpoints, list) and
len(self.recompute_checkpoints) > 0)
self._optimizer = \
fluid.optimizer.RecomputeOptimizer(self._optimizer)
self._optimizer._set_checkpoints(self.recompute_checkpoints)
optimize_ops, param_grads = self._optimizer.minimize(
loss,
startup_program=startup_program,
......
......@@ -36,6 +36,7 @@ from paddle.fluid import core
from paddle.fluid.layers import tensor
from functools import reduce
from .wrapped_decorator import signature_safe_contextmanager
from .. import compat as cpt
__all__ = [
'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl',
......@@ -43,7 +44,8 @@ __all__ = [
'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer',
'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'LarsMomentum',
'LarsMomentumOptimizer', 'DGCMomentumOptimizer', 'LambOptimizer',
'ExponentialMovingAverage', 'PipelineOptimizer', 'LookaheadOptimizer'
'ExponentialMovingAverage', 'PipelineOptimizer', 'LookaheadOptimizer',
'RecomputeOptimizer'
]
......@@ -2977,6 +2979,298 @@ class PipelineOptimizer(object):
}
class RecomputeOptimizer(Optimizer):
"""
Recompute Optimizer Wrapper
    Normally, a training step contains three sub-steps: first, run forward
    Operators to calculate the loss; second, run backward Operators to
    calculate the gradients of the parameters; third, apply an optimization
    method to update the values of the parameters.

    In the forward computation process, all variables that are needed by the
    backward computation process will be kept in memory, which occupies a
    great amount of memory when the network becomes very deep.

    Recompute splits the network into k segments. In each segment, it will
    recompute the forward Operators before running the backward operators.
    This is very helpful for saving memory.

    The Variables that separate a network into segments are called checkpoints,
    and users should set them manually. The usage is very simple:
Args:
optimizer (Optimizer): The optimizer that is applied to parameters.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
def gen_data():
return {"x": np.random.random(size=(32, 32)).astype('float32'),
"y": np.random.randint(2, size=(32, 1)).astype('int64')}
def mlp(input_x, input_y, hid_dim=128, label_dim=2):
print(input_x)
fc_1 = fluid.layers.fc(input=input_x, size=hid_dim)
prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
sum_cost = fluid.layers.reduce_mean(cost)
return sum_cost, fc_1, prediction
input_x = fluid.layers.data(name="x", shape=[32], dtype='float32')
input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
cost, fc_1, pred = mlp(input_x, input_y)
sgd = fluid.optimizer.Adam(learning_rate=0.01)
sgd = fluid.optimizer.RecomputeOptimizer(sgd)
sgd._set_checkpoints([fc_1, pred])
sgd.minimize(cost)
print("Finished optimize")
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
step = 10
for i in range(step):
cost_val = exe.run(feed=gen_data(),
program=fluid.default_main_program(),
fetch_list=[cost.name])
print("step=%d cost=%f" % (i, cost_val[0]))
"""
def __init__(self, optimizer):
self._optimizer = optimizer
self._checkpoints = None
def _set_checkpoints(self, checkpoints):
self._checkpoints = checkpoints
def load(self, stat_dict):
"""
load function is not supported by Recompute Optimizer for now.
:return: None
Args:
            stat_dict: the dict loaded by the load_persistable method
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle.compat as cpt
def mlp(input_x, input_y, hid_dim=128, label_dim=2):
fc_1 = fluid.layers.fc(input=input_x, size=hid_dim)
prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
sum_cost = fluid.layers.reduce_mean(cost)
return sum_cost, fc_1, prediction
input_x = fluid.layers.data(name="x", shape=[32], dtype='float32')
input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
cost, fc_1, pred = mlp(input_x, input_y)
print("Finished FF")
sgd = fluid.optimizer.Adam(learning_rate=0.01)
sgd = fluid.optimizer.RecomputeOptimizer(sgd)
sgd._set_checkpoints([fc_1, pred])
try:
stat_dict = {}
sgd.load(stat_dict)
except NotImplementedError as e:
print(cpt.get_exception_message(e))
"""
raise NotImplementedError(
"load function is not supported by Recompute Optimizer for now")
def apply_gradients(self, params_grads):
"""
        Call the apply_gradients function of self._optimizer.
Args:
params_grads (list): list of (param, grad) pair to do optimization.
Returns:
list: A list of operators appended to the current program.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle.fluid.framework as framework
def mlp(input_x, input_y, hid_dim=128, label_dim=2):
fc_1 = fluid.layers.fc(input=input_x, size=hid_dim)
prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
sum_cost = fluid.layers.reduce_mean(cost)
return sum_cost, fc_1, prediction
input_x = fluid.layers.data(name="x", shape=[32], dtype='float32')
input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
cost, fc_1, pred = mlp(input_x, input_y)
print("Finished FF")
sgd = fluid.optimizer.Adam(learning_rate=0.01)
sgd = fluid.optimizer.RecomputeOptimizer(sgd)
params_grads = sgd.backward(
cost,
startup_program=None,
parameter_list=None,
no_grad_set=None,
checkpoints=[fc_1, pred])
program = cost.block.program
with framework.program_guard(program, None):
optimize_ops = sgd.apply_gradients(params_grads)
print("Finished apply gradients")
"""
return self._optimizer.apply_gradients(params_grads=params_grads)
def backward(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None,
callbacks=None,
checkpoints=None):
"""
        Call append_backward with checkpoints.
Args:
loss (Variable): loss variable to run optimizations.
startup_program (Program): startup_program for initializing parameters
in `parameter_list`.
parameter_list (list): list of Variables to update.
            no_grad_set (set|None): set of Variables that should be ignored.
callbacks (list|None): list of callables to run when appending backward
operator for one parameter.
checkpoints (list): list of Variables as checkpoints
Examples:
.. code-block:: python
import paddle.fluid as fluid
def mlp(input_x, input_y, hid_dim=128, label_dim=2):
fc_1 = fluid.layers.fc(input=input_x, size=hid_dim)
prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
sum_cost = fluid.layers.reduce_mean(cost)
return sum_cost, fc_1, prediction
input_x = fluid.layers.data(name="x", shape=[32], dtype='float32')
input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
cost, fc_1, pred = mlp(input_x, input_y)
print("Finished FF")
sgd = fluid.optimizer.Adam(learning_rate=0.01)
sgd = fluid.optimizer.RecomputeOptimizer(sgd)
params_grads = sgd.backward(
cost,
startup_program=None,
parameter_list=None,
no_grad_set=None,
checkpoints=[fc_1, pred])
print("Finished backward")
"""
if framework.in_dygraph_mode():
            raise NotImplementedError(
                "DyGraph does not currently support recompute")
self._dtype = loss.dtype
program = loss.block.program
with program_guard(program, startup_program):
params_grads = append_backward(
loss,
parameter_list,
no_grad_set,
checkpoints=self._checkpoints)
return params_grads
def apply_optimize(self, loss, startup_program, params_grads):
"""
        Call the apply_optimize function of self._optimizer.
Args:
loss (Variable): loss variable to run optimizations.
startup_program (Program): startup_program for initializing parameters
in `parameter_list`.
params_grads (list): list of (param, grad) pair to do optimization.
Examples:
.. code-block:: python
import paddle.fluid as fluid
def mlp(input_x, input_y, hid_dim=128, label_dim=2):
fc_1 = fluid.layers.fc(input=input_x, size=hid_dim)
prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
sum_cost = fluid.layers.reduce_mean(cost)
return sum_cost, fc_1, prediction
input_x = fluid.layers.data(name="x", shape=[32], dtype='float32')
input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
cost, fc_1, pred = mlp(input_x, input_y)
print("Finished FF")
sgd = fluid.optimizer.Adam(learning_rate=0.01)
sgd = fluid.optimizer.RecomputeOptimizer(sgd)
params_grads = sgd.backward(
cost,
startup_program=None,
parameter_list=None,
no_grad_set=None,
checkpoints=[fc_1, pred])
optimize_ops = sgd.apply_optimize(
cost, startup_program=None, params_grads=params_grads)
print("Finished apply_optimize")
"""
return self._optimizer.apply_optimize(
loss, startup_program=startup_program, params_grads=params_grads)
def minimize(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None,
grad_clip=None):
        assert (isinstance(loss, Variable)), "The loss should be a Variable."
assert (self._checkpoints is not None
), "You should call _set_checkpoints first"
if framework.in_dygraph_mode():
            raise NotImplementedError(
                "DyGraph does not currently support recompute")
params_grads = self.backward(
loss,
startup_program=startup_program,
parameter_list=parameter_list,
no_grad_set=no_grad_set,
checkpoints=self._checkpoints)
if grad_clip:
# TODO(guru4elephant): should add grad_clip for static graph
pass
optimize_ops = self.apply_optimize(
loss, startup_program=startup_program, params_grads=params_grads)
return optimize_ops, params_grads
class LookaheadOptimizer(object):
"""
This implements the Lookahead optimizer of the
......
......@@ -18,6 +18,7 @@ import unittest
import paddle.fluid.framework as framework
import paddle.fluid.optimizer as optimizer
import paddle.compat as cpt
from paddle.fluid.backward import append_backward
......@@ -571,5 +572,154 @@ class TestLookaheadOptimizer(unittest.TestCase):
self.assertEqual([op.type for op in opts], ["scale", "sgd"])
class TestRecomputeOptimizer(unittest.TestCase):
def net(self):
program = framework.Program()
block = program.global_block()
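        # build a tiny static program by hand: mul -> elementwise_add -> elementwise_add -> mean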
mul_x = block.create_parameter(
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
b1 = block.create_parameter(
dtype="float32", shape=[5, 8], lod_level=0, name="b1")
b1_out = block.create_var(
dtype="float32", shape=[5, 8], lod_level=0, name="b1_out")
b2 = block.create_parameter(
dtype="float32", shape=[5, 8], lod_level=0, name="b2")
b2_out = block.create_var(
dtype="float32", shape=[5, 8], lod_level=0, name="b2_out")
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mul",
inputs={"X": mul_x,
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
block.append_op(
type="elementwise_add",
inputs={"X": mul_out,
"Y": b1},
outputs={"Out": b1_out})
block.append_op(
type="elementwise_add",
inputs={"X": b1_out,
"Y": b2},
outputs={"Out": b2_out})
block.append_op(
type="mean", inputs={"X": b2_out}, outputs={"Out": mean_out})
return mul_out, b1_out, b2_out, mean_out
def test_no_checkpoint(self):
mul_out, b1_out, b2_out, mean_out = self.net()
self.assertEqual(len(mean_out.block.ops), 4)
self.assertEqual([op.type for op in mean_out.block.ops],
["mul", "elementwise_add", "elementwise_add", "mean"])
sgd_optimizer = optimizer.SGD(learning_rate=1.0)
recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer)
recompute_optimizer._set_checkpoints([])
opts, params_grads = recompute_optimizer.minimize(mean_out)
self.assertEqual(len(mean_out.block.ops), 12)
self.assertEqual([op.type for op in mean_out.block.ops], [
"mul", "elementwise_add", "elementwise_add", "mean",
"fill_constant", "mean_grad", "elementwise_add_grad",
"elementwise_add_grad", "mul_grad", "sgd", "sgd", "sgd"
])
def test_one_checkpoint(self):
mul_out, b1_out, b2_out, mean_out = self.net()
self.assertEqual(len(mean_out.block.ops), 4)
self.assertEqual([op.type for op in mean_out.block.ops],
["mul", "elementwise_add", "elementwise_add", "mean"])
sgd_optimizer = optimizer.SGD(learning_rate=1.0)
recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer)
recompute_optimizer._set_checkpoints([b1_out])
opts, params_grads = recompute_optimizer.minimize(mean_out)
self.assertEqual(len(mean_out.block.ops), 13)
self.assertEqual([op.type for op in mean_out.block.ops], [
"mul", "elementwise_add", "elementwise_add", "mean",
"fill_constant", "mean_grad", "elementwise_add_grad", "mul",
"elementwise_add_grad", "mul_grad", "sgd", "sgd", "sgd"
])
def test_multi_checkpoint(self):
mul_out, b1_out, b2_out, mean_out = self.net()
self.assertEqual(len(mean_out.block.ops), 4)
self.assertEqual([op.type for op in mean_out.block.ops],
["mul", "elementwise_add", "elementwise_add", "mean"])
sgd_optimizer = optimizer.SGD(learning_rate=1.0)
recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer)
recompute_optimizer._set_checkpoints([mul_out, b2_out])
opts, params_grads = recompute_optimizer.minimize(mean_out)
self.assertEqual(len(mean_out.block.ops), 13)
self.assertEqual([op.type for op in mean_out.block.ops], [
"mul", "elementwise_add", "elementwise_add", "mean",
"fill_constant", "mean_grad", "elementwise_add",
"elementwise_add_grad", "elementwise_add_grad", "mul_grad", "sgd",
"sgd", "sgd"
])
def test_adjacent_checkpoint(self):
mul_out, b1_out, b2_out, mean_out = self.net()
self.assertEqual(len(mean_out.block.ops), 4)
self.assertEqual([op.type for op in mean_out.block.ops],
["mul", "elementwise_add", "elementwise_add", "mean"])
sgd_optimizer = optimizer.SGD(learning_rate=1.0)
recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer)
recompute_optimizer._set_checkpoints([mul_out, b1_out])
opts, params_grads = recompute_optimizer.minimize(mean_out)
self.assertEqual(len(mean_out.block.ops), 12)
self.assertEqual([op.type for op in mean_out.block.ops], [
"mul", "elementwise_add", "elementwise_add", "mean",
"fill_constant", "mean_grad", "elementwise_add_grad",
"elementwise_add_grad", "mul_grad", "sgd", "sgd", "sgd"
])
def test_apply_gradients(self):
mul_out, b1_out, b2_out, mean_out = self.net()
sgd_optimizer = optimizer.SGD(learning_rate=1.0)
recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer)
recompute_optimizer._set_checkpoints([b1_out])
# apply backward
params_grads = recompute_optimizer.backward(
mean_out,
startup_program=None,
parameter_list=None,
no_grad_set=None,
checkpoints=[b1_out])
# apply gradient
program = mean_out.block.program
with framework.program_guard(program, None):
optimize_ops = recompute_optimizer.apply_gradients(params_grads)
self.assertEqual(len(mean_out.block.ops), 13)
self.assertEqual([op.type for op in mean_out.block.ops], [
"mul", "elementwise_add", "elementwise_add", "mean",
"fill_constant", "mean_grad", "elementwise_add_grad", "mul",
"elementwise_add_grad", "mul_grad", "sgd", "sgd", "sgd"
])
def test_load(self):
mul_out, b1_out, b2_out, mean_out = self.net()
sgd_optimizer = optimizer.SGD(learning_rate=1.0)
recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer)
recompute_optimizer._set_checkpoints([b1_out])
try:
stat_dict = {}
recompute_optimizer.load(stat_dict)
except NotImplementedError as e:
self.assertEqual(
"load function is not supported by Recompute Optimizer for now",
cpt.get_exception_message(e))
if __name__ == '__main__':
unittest.main()