diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py
index f20bbdb3834c822abee6c9bc8549d804b608edd5..70a38ff3b9aa663f8b6b21c85ba7e1677310f109 100755
--- a/python/paddle/distributed/fleet/recompute/recompute.py
+++ b/python/paddle/distributed/fleet/recompute/recompute.py
@@ -18,7 +18,6 @@ import weakref
 import paddle
 from paddle import framework
 from paddle.autograd import PyLayer
-from paddle.autograd.py_layer import LegacyPyLayer
 from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
     get_rng_state_tracker,
 )
@@ -67,147 +66,6 @@ def swith_rng_state_tracker(rng_state, tracker):
         get_rng_state_tracker().set_states_tracker(orig_rng_tracker)
 
 
-class LegacyRecomputeFunction(LegacyPyLayer):
-    @staticmethod
-    def forward(ctx, run_function, preserve_rng_state, *args):
-        # store for recomputing
-        ctx.run_function = run_function
-        ctx.preserve_rng_state = preserve_rng_state
-
-        # NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input
-        # the order of tensors in backward()'s output should be the same as tensors in forward()'s input
-        # None tensor inputs will be filtered in backward inputs.
-
-        # save input for backward
-        ctx.inputs = []
-        ctx.tensor_indices = []
-        tensor_inputs = []
-        for i, arg in enumerate(args):
-            if paddle.is_tensor(arg):
-                tensor_inputs.append(arg)
-                ctx.tensor_indices.append(i)
-                ctx.inputs.append(None)
-            else:
-                ctx.inputs.append(arg)
-        ctx.save_for_backward(*tensor_inputs)
-
-        # NOTE recompute with restore RNG only support one senario where one process for one cuda gpu.
-        # one process with multiple gpu and mix-gpu-cpu senarios are not support
-        if ctx.preserve_rng_state:
-            ctx.fw_rng_state = paddle.get_rng_state()
-            ctx.fwd_rng_state_tracker = (
-                get_rng_state_tracker().get_states_tracker()
-            )
-
-        # TODO support AMP
-        tracer = framework._dygraph_tracer()
-        ctx.is_fw_autocast = (
-            False if tracer._amp_level == core.AmpLevel.O0 else True
-        )
-        if tracer._amp_level == core.AmpLevel.O2:
-            ctx.amp_level = 'O2'
-        elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
-            ctx.amp_level = 'O1'
-        else:
-            raise ValueError(
-                "unsupported amp level: {}".format(tracer._amp_level)
-            )
-
-        if tracer._amp_dtype == 'float16':
-            ctx.amp_dtype = 'float16'
-        elif tracer._amp_dtype in ('bfloat16', 'float32'):
-            ctx.amp_dtype = 'bfloat16'
-        else:
-            raise ValueError(
-                "unsupported amp dtype: {}".format(tracer._amp_dtype)
-            )
-
-        ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()
-
-        with paddle.no_grad():
-            outputs = run_function(*args)
-        return outputs
-
-    @staticmethod
-    def backward(ctx, *args):
-        with paddle.fluid.dygraph.guard():
-            # TODO need to check the recompute calling is vaild or not
-
-            # Restore inputs
-            inputs = list(ctx.inputs)
-            tensor_indices = ctx.tensor_indices
-            tensors = ctx.saved_tensor()
-            for i, idx in enumerate(tensor_indices):
-                inputs[idx] = tensors[i]
-
-            # paddle.enable_grad()
-            tracer = framework._dygraph_tracer()
-            tracer._has_grad = True
-
-            # NOTE support AMP
-            # need restore auto_cast state as well as w/b list
-            if ctx.preserve_rng_state:
-                with swith_rng_state_tracker(
-                    ctx.fw_rng_state, ctx.fwd_rng_state_tracker
-                ):
-                    with paddle.amp.auto_cast(
-                        enable=ctx.is_fw_autocast,
-                        custom_white_list=ctx.amp_white_list,
-                        custom_black_list=ctx.amp_black_list,
-                        level=ctx.amp_level,
-                        dtype=ctx.amp_dtype,
-                    ):
-                        detached_inputs = detach_variable(tuple(inputs))
-                        outputs = ctx.run_function(*detached_inputs)
-            else:
-                with paddle.amp.auto_cast(
-                    enable=ctx.is_fw_autocast,
-                    custom_white_list=ctx.amp_white_list,
-                    custom_black_list=ctx.amp_black_list,
-                    level=ctx.amp_level,
-                    dtype=ctx.amp_dtype,
-                ):
-                    detached_inputs = detach_variable(tuple(inputs))
-                    outputs = ctx.run_function(*detached_inputs)
-
-            if isinstance(outputs, core.VarBase):
-                outputs = (outputs,)
-            assert len(outputs) == len(args)
-
-            # run backward() with only tensor that requires grad
-            forward_outputs_with_grad = []
-            # NOTE In Transformer-like network, if user put the attention mask into the recompute segment output,
-            # pylayer will force the stop_gradient of attention mask to be False, which will make the number of
-            # tensor that need grad does not match.
-            # the following backward_inputs_with_grad is used to avoid this case.
-            backward_inputs_with_grad = []
-            for i in range(len(outputs)):
-                if (
-                    isinstance(outputs[i], core.VarBase)
-                    and not outputs[i].stop_gradient
-                ):
-                    forward_outputs_with_grad.append(outputs[i])
-                    backward_inputs_with_grad.append(args[i])
-
-            if len(forward_outputs_with_grad) == 0:
-                raise RuntimeError(
-                    "none of output has requires_grad=True, this recompute() is not necessary"
-                )
-
-            # actually backward
-            with paddle.amp.auto_cast(enable=False):
-                paddle.autograd.backward(
-                    forward_outputs_with_grad, backward_inputs_with_grad
-                )
-
-            grads = list(
-                inp._grad_ivar()
-                for inp in detached_inputs
-                if isinstance(inp, core.VarBase)
-            )
-            return grads
-
-
 class RecomputeFunction(PyLayer):
     @staticmethod
     def forward(ctx, run_function, preserve_rng_state, *args, **kwargs):
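
A minimal usage sketch (not part of the diff) of the path that remains after this removal: the public `paddle.distributed.fleet.utils.recompute` entry point, which dispatches to the `PyLayer`-based `RecomputeFunction` kept above rather than the deleted `LegacyPyLayer` path. The `Block` layer, its sizes, and the input shape below are illustrative assumptions, not code from this PR.

```python
import paddle
from paddle.distributed.fleet.utils import recompute


class Block(paddle.nn.Layer):
    def __init__(self, hidden=64):
        super().__init__()
        self.fc1 = paddle.nn.Linear(hidden, hidden)
        self.fc2 = paddle.nn.Linear(hidden, hidden)

    def block_fn(self, x):
        return self.fc2(paddle.nn.functional.relu(self.fc1(x)))

    def forward(self, x):
        # Activations inside the recomputed segment are not kept during the
        # forward pass; they are recomputed when backward reaches this layer.
        return recompute(self.block_fn, x)


x = paddle.randn([8, 64])
x.stop_gradient = False
loss = Block()(x).mean()
loss.backward()
```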