diff --git a/doc/design/backward.md b/doc/design/backward.md
index 35f03692bb052e0a04db18d28f6f8d901215a553..20fda7a98f514a3f1c1c2d0ba7447ec954b21d5a 100644
--- a/doc/design/backward.md
+++ b/doc/design/backward.md
@@ -106,9 +106,11 @@ See function `_addup_repetitive_outputs_` in `backward.py` for implementation de
 
 In our framework, variables can be marked as *no_gradient*, it means that the gradient of this variable is unnecessary and can be considered as zero in model training. Apparently, when all the outputs of some `grad_op` are marked as *no_gradient*, the `grad_op` itself can be skipped in backward pass.
 
-But these unnecessary gradients still need to be creating and initialized by something, otherwise following `grad_op`s who take these gradients as inputs take the risk of using uninitialized memory. In our code, we employ `fill_zeros_like_op` to initialize them as all zeros.
+Another situation is that all the gradient inputs of some `grad_op` are marked as *no_gradient*, which means all of them can be considered as zeros. Since `grad_op`s are in essence the propagation of gradients, all the outputs must be zeros when all the gradient inputs are zeros. Therefore, such a `grad_op` can also be skipped.
 
-This features are implemented in function `_remove_no_grad_branch_`. It checks new created `grad_op`s one-by-one, removes whose outputs are all in `no_grad_set` or inserts `fill_zeros_like_op` when its necessary. We can get the `no_grad_set` from the `_append_backward_ops_` argument `no_grad_dict` or generate it on the fly by scanning all variables' `no_gradient` attribute(True or False).
+It should be noted that all these zero gradients still need to be created and initialized by something, otherwise the following `grad_op`s that take these gradients as inputs risk using uninitialized memory. In our code, we employ `fill_zeros_like_op` to initialize them as all zeros.
+
+These features are implemented in the function `_remove_no_grad_branch_`. It checks newly created `grad_op`s one by one, removes those that can be skipped, and inserts `fill_zeros_like_op` when necessary. We can get the `no_grad_set` from the `_append_backward_ops_` argument `no_grad_dict` or generate it on the fly by scanning all variables' `no_gradient` attribute (True or False).
 
 ### Creating Backward Variables
 
diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py
index f11c83f59c930784ca355acc3193c3c352db10e5..ac60bf543600008fd5339c1a378951374afc4ad6 100644
--- a/python/paddle/v2/fluid/backward.py
+++ b/python/paddle/v2/fluid/backward.py
@@ -57,6 +57,8 @@ def _all_in_set_(cands, s):
     """
     Test if all elements of 'cands' are in set 's'
     """
+    if len(cands) == 0:
+        return False
     for c in cands:
         if not c in s:
             return False
@@ -136,12 +138,23 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
     Remove unnecessary grad ops
     A grad op can be removed in two cases:
         1. all outputs of the grad op are in 'no_grad_set'
-        2. (TODO) all grad inputs of the grad op are in 'no_grad_set'
+        2. all grad inputs of the grad op are in 'no_grad_set'
     """
+
+    def _op_can_be_removed_(op_desc, no_grad_set):
+        out_arg_names = op_desc.output_arg_names()
+        if len(out_arg_names) == 0 or _all_in_set_(out_arg_names, no_grad_set):
+            return True
+        if _all_in_set_(
+                filter(lambda name: name.find(core.grad_var_suffix()) != -1,
+                       op_desc.input_arg_names()), no_grad_set):
+            no_grad_set.update(out_arg_names)
+            return True
+        return False
+
     # Remove ops whose outputs are all in no_grad_dict
     op_descs = filter(
-        lambda op_desc: not _all_in_set_(op_desc.output_arg_names(), no_grad_set),
-        op_descs)
+        lambda op_desc: not _op_can_be_removed_(op_desc, no_grad_set), op_descs)
     # Insert fill_zeros_like_op
     to_insert = []
     for idx, op_desc in enumerate(op_descs):
@@ -284,7 +297,9 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
             block_no_grad_set.add(_append_grad_suffix_(var.name))
         no_grad_dict[block.idx] = block_no_grad_set
     elif isinstance(no_grad_set, set):
-        no_grad_dict = {0: no_grad_set}
+        no_grad_dict = {
+            0: set([_append_grad_suffix_(name) for name in no_grad_set])
+        }
     else:
         raise ValueError("'no_grad_set' should be a set or None.")
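
As a side note (not part of the patch above), the pruning rule can be shown in isolation with a small standalone sketch. `FakeOpDesc`, `GRAD_SUFFIX`, and `remove_no_grad_branch` below are made-up stand-ins invented for illustration, not Paddle's real `OpDesc` or `core` API; the sketch only mirrors the intent of the new `_op_can_be_removed_` helper.

```python
# Standalone illustration of the two removal cases; FakeOpDesc and
# GRAD_SUFFIX are hypothetical stand-ins, not Paddle's real classes.
GRAD_SUFFIX = "@GRAD"


class FakeOpDesc(object):
    def __init__(self, type, inputs, outputs):
        self.type = type
        self._inputs = inputs
        self._outputs = outputs

    def input_arg_names(self):
        return self._inputs

    def output_arg_names(self):
        return self._outputs


def remove_no_grad_branch(op_descs, no_grad_set):
    """Drop grad ops whose outputs, or whose gradient inputs, are all
    in no_grad_set, propagating the skipped outputs forward."""
    kept = []
    for op in op_descs:
        outs = op.output_arg_names()
        grad_ins = [n for n in op.input_arg_names() if GRAD_SUFFIX in n]
        if outs and not all(n in no_grad_set for n in outs):
            if not (grad_ins and all(n in no_grad_set for n in grad_ins)):
                kept.append(op)
                continue
            # Case 2: every gradient input is zero, so the outputs are
            # zero as well; record them so later ops can be pruned too.
            no_grad_set.update(outs)
        # Case 1 (or case 2 fall-through): skip this grad op.
    return kept


if __name__ == "__main__":
    ops = [
        FakeOpDesc("mul_grad", ["out@GRAD", "x", "y"], ["x@GRAD", "y@GRAD"]),
        FakeOpDesc("prev_op_grad", ["x@GRAD", "w"], ["w@GRAD"]),
    ]
    # Marking both outputs of mul_grad as no-grad removes it, and the
    # pruning cascades to prev_op_grad because x@GRAD becomes zero.
    remaining = remove_no_grad_branch(ops, {"x@GRAD", "y@GRAD"})
    print([op.type for op in remaining])  # -> []
```

The mutation of `no_grad_set` inside the loop is what makes the pruning cascade along a whole branch rather than stopping at the first skipped op, which matches the behavior the docstring's second case describes.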