Unverified commit 42c1e916, authored by Hugh Pu and committed by GitHub

feat(activation_checkpointing): add `non_reentrant_checkpoint` to support inputs that require no grad (#4118)

* feat: add `non_reentrant_checkpoint`

* feat: add missing output postprocess and change the hook to record leaf forward tensor refs

* fix: make the multi_grad_hook registered after graph construction

* fix: backward compatibility for multi_tensor_hook

* fix: nonlocal reference error of deepspeed_saved_tensors

* fix: reduce repeating hook registration

* test: add test for `activation_checkpointing.checkpointing.non_reentrant_checkpoint`

* Pass correct node size for ZeRO++ (#4085)

* Pass correct node size

* formatting

---------
Co-authored-by: Connor Holmes <development@cmikeh2.me>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>

* add deepspeed chat arxiv report (#4110)

* add deepspeed chat arxiv report

* add zeroquant v2 and fp

* add selective enhancement

* add ignore for 'Youn' in spell checker

---------
Co-authored-by: yaozhewei <zheweiy@berkeley.edu>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>

* style: fix flake8-detected style mismatch

* test: hack to clone the `test_activation_checkpointing` module for reuse and add regression tests

* doc: explain the introduction of `non_reentrant_checkpoint`

* doc: explain the test of `non_reentrant_checkpoint`

---------
Co-authored-by: Connor Holmes <connorholmes@microsoft.com>
Co-authored-by: Connor Holmes <development@cmikeh2.me>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Conglong Li <conglong.li@gmail.com>
Co-authored-by: yaozhewei <zheweiy@berkeley.edu>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent c69bd1f7
@@ -18,6 +18,7 @@ import copy
import torch
import contextlib
from deepspeed import comm as dist
import weakref
import mmap
from torch import _C
@@ -705,6 +706,271 @@ class CheckpointFunction(torch.autograd.Function):
return tuple(ret_list)
def non_reentrant_checkpoint(function, *args):
"""This function is union of `torch.utils.checkpoint._checkpoint_without_reentrant` and `CheckpointFunction` in this module
This function is aim to solve the back probagation error raised from all input requires no grad.
* has already been implemented in pytorch for a while, the solution is stable at most time except for jit module mode.
* can help to solve the issue which is hacked by `deepspeed.runtime.pipe.module.PipelineModule._is_checkpointable`
Main modifications compared to the implementation of torch:
1. adapt to the signature of `checkpoint` function in this module
2. solve the non-deterministic by random state management consistent with deepspeed `CheckpointFunction`
3. when there is partition or cpu checkpointing, gather them in the unpack_hook during back probagation
4. make all after backward blocks in the hook which will executed after all leaf nodes backward execution.
5. above 4. is inspired by `torch.autograd.graph.register_multi_grad_hook`, which is only implemented after 2.0.0
"""
global mpu, timers, SYNCHRONIZE, PROFILE_TIME
deepspeed_saved_tensors = None
non_tensor_args = None
tensor_flags = None
def save_args_for_backward(*all_args):
"""keep this function to reduce the modification from original implementation"""
nonlocal deepspeed_saved_tensors, non_tensor_args, tensor_flags
tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args)
deepspeed_saved_tensors = tensor_args
non_tensor_args = non_tensor_args
tensor_flags = tensor_flags
if SYNCHRONIZE:
get_accelerator().synchronize()
if timers is None and PROFILE_TIME:
timers = Timers()
if PROFILE_TIME:
timers(FORWARD_GLOBAL_TIMER).start()
global num_layers
global mp_rank, mp_size, mp_group
global contiguous_data_buffers, contiguous_size_buffers
global data_offsets, size_offsets
if mp_rank is None:
if mpu is not None:
if hasattr(mpu, 'get_tensor_model_parallel_rank'):
mp_rank = mpu.get_tensor_model_parallel_rank()
mp_size = mpu.get_tensor_model_parallel_world_size()
mp_group = mpu.get_tensor_model_parallel_group()
else:
mp_rank = mpu.get_model_parallel_rank()
mp_size = mpu.get_model_parallel_world_size()
mp_group = mpu.get_model_parallel_group()
else:
mp_rank = 0
mp_size = 1
mp_group = None
global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset
if cuda_device is None:
see_memory_usage("First Forward Beginning", force=False)
if dist.get_rank() == 0:
logger.info(f"Activation Checkpointing Information")
logger.info(f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}")
logger.info(
f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers")
logger.info(f"----Synchronization {SYNCHRONIZE}")
logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}")
cuda_device = get_accelerator().current_device_name()
transport_stream = get_accelerator().Stream(device=cuda_device)
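# Preprocess the inputs the same way `CheckpointFunction.forward` does: with activation partitioning
# the input activations are split across the model-parallel group, with CPU checkpointing they are
# offloaded to host memory, and `inputs_cuda` keeps a device copy of the original args so the
# forward below can run unchanged.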
if PARTITION_ACTIVATIONS:
inputs = partition_activations(args, CPU_CHECKPOINT, CONTIGUOUS_CHECKPOINTING)
elif CPU_CHECKPOINT:
inputs = copy_to_device(args, device=torch.device('cpu'), criterion_func=is_activation_to_checkpoint)
# just in case something funky is happening such as reuse of inputs
inputs_cuda = copy_to_device(args, device=cuda_device, criterion_func=is_activation_to_checkpoint)
# Copy the rng states.
fwd_cpu_rng_state = torch.get_rng_state()
fwd_cuda_rng_state = get_accelerator().get_rng_state()
fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
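# only the (possibly partitioned or offloaded) inputs are stashed for backward; no intermediate
# activation is stored here, everything else is recomputed inside `checkpoint_unpack`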
if PARTITION_ACTIVATIONS:
new_args = get_partitioned_activations_for_backward(args, inputs, CONTIGUOUS_CHECKPOINTING)
assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}'
save_args_for_backward(*new_args)
elif CPU_CHECKPOINT:
new_args = get_cpu_activations_for_backward(args, inputs)
save_args_for_backward(*new_args)
else:
save_args_for_backward(*args)
class Holder():
"""placeholder object handed to autograd in place of an activation, to save memory"""
pass
# weak references are used so that tensors deleted before a full forward/backward
# pair has finished can be detected and skipped during the replay
storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
weak_holder_list = []
leaf_tensors = []
backward_visited_leaf_nodes = 0
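# bookkeeping shared between the pack/unpack hooks:
#   storage           maps a live Holder to its recomputed activation (filled during the backward replay)
#   weak_holder_list  records the order in which activations were packed during the first forward
#   leaf_tensors      collects leaf tensors (inputs or parameters) that require grad, so the
#                     post-backward epilogue can be hooked onto them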
def checkpoint_pack(tensor_from_forward):
"""used to record the activation order in the `weak_holder_list`
the activation order in holder list is consistent between the first forward and recomputing forward.
* the jit compiled forward will break the order consistency *
"""
res = Holder()
weak_holder_list.append(weakref.ref(res))
# if this is a leaf tensor, keep a reference to it so the backward progress can be traced;
# leaf tensors are inputs or parameters rather than activations, so keeping
# them incurs no extra memory overhead
if tensor_from_forward.requires_grad and tensor_from_forward.is_leaf:
leaf_tensors.append(tensor_from_forward)
return res
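# `checkpoint_unpack` is only called during backward, when autograd asks for a tensor that was
# replaced by a Holder: the first call replays the forward to repopulate `storage`, and every call
# returns the recomputed activation that belongs to the given Holder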
def checkpoint_unpack(holder_from_backward):
"""retrieve the activations from recompute"""
nonlocal deepspeed_saved_tensors, non_tensor_args, tensor_flags
# if this is the first unpack of this backward pass, recompute the graph and save
# all the activations in the same order as `checkpoint_pack` recorded them
if len(storage) == 0:
unpack_counter = 0
def replay_pack(tensor_from_replay):
"""save recompute activations"""
nonlocal unpack_counter
unpack_counter += 1
if weak_holder_list[unpack_counter - 1]() is None:
return
detached_activations = tensor_from_replay.detach()
storage[weak_holder_list[unpack_counter - 1]()] = detached_activations
return
def replay_unpack(none_value):
"""recompute graph need not to backward"""
raise RuntimeError("You are calling backwards on a tensor that is never exposed.")
global timers
see_memory_usage("In backward", force=False)
# removing pointers to the contiguous buffer memory
# so that they can be garbage collected once the checkpoints
# have been used
if SYNCHRONIZE:
get_accelerator().synchronize()
if PROFILE_TIME:
timers('backward').start()
if CONTIGUOUS_CHECKPOINTING:
global data_offsets, size_offsets
global contiguous_data_buffers, contiguous_size_buffers
for buffers in contiguous_data_buffers:
buffers = []
# frees up all the pointers to the checkpoints except for the ones
# stored by save for backward
contiguous_data_buffers = []
contiguous_size_buffers = []
data_offsets = []
size_offsets = []
see_memory_usage("In backward checkpointing code", force=False)
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("Checkpointing is not compatible with .grad(), "
"please use .backward() if possible")
global cuda_device, transport_stream, PARTITION_ACTIVATIONS
# gather the inputs that were partitioned or CPU-checkpointed before the first forward
if PARTITION_ACTIVATIONS:
# with get_accelerator().stream(transport_stream):
inputs = gather_partitioned_activations(deepspeed_saved_tensors,
device=cuda_device if CPU_CHECKPOINT else None)
detached_inputs = detach_variable(inputs)
elif CPU_CHECKPOINT:
inputs = move_to_device(deepspeed_saved_tensors, cuda_device, is_activation_to_checkpoint)
detached_inputs = detach_variable(inputs)
else:
inputs = deepspeed_saved_tensors
detached_inputs = detach_variable(inputs)
# Add non tensor input args
detached_inputs = merge_tensors(tensor_objects=detached_inputs,
non_tensor_objects=non_tensor_args,
tensor_flags=tensor_flags)
# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
bwd_cuda_rng_state = get_accelerator().get_rng_state()
bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
# Set the states to what they were before the forward pass.
torch.set_rng_state(fwd_cpu_rng_state)
_set_cuda_rng_state(fwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(fwd_cuda_rng_state_tracker)
see_memory_usage("In backward checkpointing code before forward", force=False)
with torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(replay_pack, replay_unpack):
_unused = function(*detached_inputs)
see_memory_usage("In backward checkpointing code after forward", force=False)
# Set the states back to what they were at the start of this function.
torch.set_rng_state(bwd_cpu_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
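# drop the references to the stashed inputs once the replay has repopulated `storage`,
# so the gathered/offloaded copies can be freed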
deepspeed_saved_tensors = None
non_tensor_args = None
tensor_flags = None
if holder_from_backward not in storage:
raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
" recomputation being triggered in between, this is not currently supported.")
return storage[holder_from_backward]
def after_backward_hook(_nonuse_grads):
"""the hook registered to all leaf tensors"""
nonlocal leaf_tensors, backward_visited_leaf_nodes
backward_visited_leaf_nodes += 1
if backward_visited_leaf_nodes == len(leaf_tensors):
see_memory_usage("After backward checkpointing code after backward", force=False)
if PROFILE_TIME:
timers('backward').stop()
timers.log(['backward'])
if SYNCHRONIZE:
get_accelerator().synchronize()
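# run the real forward under the pack/unpack hooks: the activations produced here are released
# immediately and only the Holder placeholders are kept alive until backward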
with torch.autograd.graph.saved_tensors_hooks(checkpoint_pack, checkpoint_unpack):
outputs = function(*inputs_cuda)
for leaf_tensor in leaf_tensors:
leaf_tensor.register_hook(after_backward_hook)
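# registering the same hook on every recorded leaf tensor emulates
# `torch.autograd.graph.register_multi_grad_hook` (only available from torch 2.0.0): the epilogue in
# `after_backward_hook` runs exactly once, after the last leaf has received its gradient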
see_memory_usage("After running forward on the layer", force=False)
if PROFILE_TIME:
timers(FORWARD_GLOBAL_TIMER).stop()
timers.log([FORWARD_GLOBAL_TIMER])
if SYNCHRONIZE:
get_accelerator().synchronize()
all_outputs = []
if torch.is_tensor(outputs):
all_outputs += [outputs]
else:
all_outputs += outputs
if len(all_outputs) == 1:
return all_outputs[0]
else:
return tuple(all_outputs)
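# Minimal usage sketch (hypothetical layer and input names, not part of this module):
#
#   from deepspeed.runtime.activation_checkpointing.checkpointing import non_reentrant_checkpoint
#   layer = torch.nn.Linear(1024, 1024)      # parameters require grad
#   frozen_input = torch.rand(8, 1024)       # requires_grad is False
#   out = non_reentrant_checkpoint(layer, frozen_input)
#   out.sum().backward()                     # parameters still receive gradients
#
# With the reentrant `checkpoint` below, this case is exactly the "all inputs require no grad"
# failure described in the docstring of `non_reentrant_checkpoint`.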
def checkpoint(function, *args):
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint. """
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# TODO: add tests with model parallelism for activation partitioning and other features.
import sys
import torch
import pytest
from importlib import util
from deepspeed.runtime.activation_checkpointing.checkpointing import non_reentrant_checkpoint
from unit.common import DistributedTest
# hack: clone the `test_activation_checkpointing` module and inject
# `non_reentrant_checkpoint` as the `ckpt` function of the original test module
ORG_SPEC = util.find_spec('test_activation_checkpointing')
test_act_ckpt = util.module_from_spec(ORG_SPEC)
ORG_SPEC.loader.exec_module(test_act_ckpt)
sys.modules['test_act_ckpt'] = test_act_ckpt
test_act_ckpt.ckpt = non_reentrant_checkpoint
HIDDEN_DIM = test_act_ckpt.HIDDEN_DIM
MaskedLinear = test_act_ckpt.MaskedLinear
MaskedLinearSeq = test_act_ckpt.MaskedLinearSeq
MaskedLinearSeqDup = test_act_ckpt.MaskedLinearSeqDup
DropMaskLinear = test_act_ckpt.DropMaskLinear
LinearNonTensorInput = test_act_ckpt.LinearNonTensorInput
LinearNonTensorOutput = test_act_ckpt.LinearNonTensorOutput
_test_activation_checkpoint = test_act_ckpt._test_activation_checkpoint
_mixed_mask = test_act_ckpt._mixed_mask
_bool_to_float = test_act_ckpt._bool_to_float
_test_activation_checkpoint_ordering = test_act_ckpt._test_activation_checkpoint_ordering
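# subclassing the original test classes reruns their test methods unchanged, but against the cloned
# module where `ckpt` is now bound to `non_reentrant_checkpoint`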
class TestActivationCheckpointWithGrad(test_act_ckpt.TestActivationCheckpoint):
"""test `non_reentrant_checkpoint` can still checkpoint activations for inputs with grad"""
pass
class TestCheckpointNonTensorWithGrad(test_act_ckpt.TestCheckpointNonTensor):
"""test `non_reentrant_checkpoint` can still checkpoint activations for inputs with grad"""
pass
class TestCheckpointNonTensorOutputOrderingWithGrad(test_act_ckpt.TestCheckpointNonTensorOutputOrdering):
"""test `non_reentrant_checkpoint` can still checkpoint activations for inputs with grad"""
pass
# the classes below test graphs whose inputs require no grad while the parameters do, i.e. a partial graph
@pytest.mark.parametrize('mask', [
_mixed_mask(),
_bool_to_float(_mixed_mask()),
])
class TestActivationCheckpointWithoutGrad(DistributedTest):
"""test all input tensors without grad"""
world_size = 1
def test_ckpt_inputs1_outputs1(self, mask):
module = torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
inputs = torch.rand(HIDDEN_DIM)
_test_activation_checkpoint(module, inputs)
def test_ckpt_inputs2_outputs1(self, mask):
module = MaskedLinear(HIDDEN_DIM, HIDDEN_DIM)
inputs = torch.rand(HIDDEN_DIM)
_test_activation_checkpoint(module, inputs, mask)
def test_ckpt_inputs2_outputs2(self, mask):
module = MaskedLinearSeq(HIDDEN_DIM, HIDDEN_DIM)
inputs = torch.rand(HIDDEN_DIM)
_test_activation_checkpoint(module, inputs, mask)
def test_ckpt_inputs2_outputs3(self, mask):
module = MaskedLinearSeqDup(HIDDEN_DIM, HIDDEN_DIM)
inputs = torch.rand(HIDDEN_DIM)
_test_activation_checkpoint(module, inputs, mask)
def test_ckpt_arg_none(self, mask):
module = DropMaskLinear(HIDDEN_DIM, HIDDEN_DIM)
inputs = (torch.rand(HIDDEN_DIM), None)
_test_activation_checkpoint(module, *inputs)
@pytest.mark.parametrize('non_tensor', [None, 2, True, (None, 2.5), (None, True, torch.randn(HIDDEN_DIM))])
class TestCheckpointNonTensorWithoutGrad(DistributedTest):
"""test all input tensors without grad"""
world_size = 1
def test_ckpt_non_tensor_input(self, non_tensor):
module = LinearNonTensorInput(HIDDEN_DIM, HIDDEN_DIM)
inputs = torch.rand(HIDDEN_DIM)
_test_activation_checkpoint(module, inputs, non_tensor)
def test_ckpt_non_tensor_output(self, non_tensor):
module = LinearNonTensorOutput(non_tensor)
inputs = torch.rand(HIDDEN_DIM)
_test_activation_checkpoint(module, inputs)
@pytest.mark.parametrize('non_tensor_output', [
None, (torch.randn(HIDDEN_DIM), 2.5), (None, torch.randn(HIDDEN_DIM), True), (None, True, torch.randn(HIDDEN_DIM))
])
class TestCheckpointNonTensorOutputOrderingWithoutGrad(DistributedTest):
"""test all input tensors without grad"""
world_size = 1
def test_ckpt_non_tensor_output_ordering(self, non_tensor_output):
module = LinearNonTensorOutput(non_tensor_output)
inputs = torch.rand(HIDDEN_DIM)
# First return is a tensor
ordering = [True]
if type(non_tensor_output) in [list, tuple]:
ordering += [torch.is_tensor(t) for t in non_tensor_output]
else:
ordering += [torch.is_tensor(non_tensor_output)]
_test_activation_checkpoint_ordering(module, ordering, inputs)