Unverified commit 85acf14c, authored by Olatunji Ruwase, committed by GitHub

Activation checkpointing improvements (#1254)

* Rename PA_TO_cpu

* Code cleanup

* Revert accidental change
Parent 9645e7bc
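The rename from `PA_TO_CPU` to `CPU_CHECKPOINT` tracks the `cpu_checkpointing` entry of the activation-checkpointing config, which is read in `_configure_using_config_file` further down in this diff. A minimal config sketch in Python form; the key names are inferred from the config attributes visible in this diff and should be treated as assumptions:

```python
# Sketch of the activation-checkpointing section of a DeepSpeed config.
# Key names follow the attributes read in _configure_using_config_file
# below; treat the exact spelling as an assumption.
ds_config = {
    "activation_checkpointing": {
        "partition_activations": True,             # -> PARTITION_ACTIVATIONS
        "cpu_checkpointing": True,                 # -> CPU_CHECKPOINT (was PA_TO_CPU)
        "contiguous_memory_optimization": False,   # -> CONTIGUOUS_CHECKPOINTING
        "number_checkpoints": None,                # -> num_layers
        "synchronize_checkpoint_boundary": False,  # -> SYNCHRONIZE
        "profile": False,                          # -> PROFILE_TIME
    }
}
```

The same flags can also be set programmatically through `configure()`; see the `checkpoint_in_cpu` parameter handled near the end of this diff.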
......@@ -24,7 +24,7 @@ from torch.cuda import _lazy_call, device as device_ctx_manager
from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.utils import logger
from deepspeed.runtime.utils import move_to_device, see_memory_usage
from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage
from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers
# DeepSpeed Checkpointing Enabled or Disabled
......@@ -50,7 +50,7 @@ timers = None
# optimization flags
PARTITION_ACTIVATIONS = False
PA_TO_CPU = False
CPU_CHECKPOINT = False
CONTIGUOUS_CHECKPOINTING = False
SYNCHRONIZE = False
PROFILE_TIME = False
......@@ -253,14 +253,20 @@ def get_partition_size(item):
return int(partition_size)
def get_full_inputs(tensors, device=None):
def gather_partitioned_activations(tensors, device=None):
global mp_rank, mp_size, mp_group
assert len(tensors) % 2 == 0, f'Expected even count of tensors, instead got {len(tensors)}'
inputs = []
num_args = int(len(tensors) / 2)
for i in range(num_args - 1):
for i in range(num_args):
item = tensors[2 * i]
size = tensors[2 * i + 1]
if not is_activation_to_checkpoint(item):
inputs.append(item)
continue
partition_size = item.numel()
tensor_size = partition_size * mp_size
if device is not None:
......@@ -281,7 +287,6 @@ def get_full_inputs(tensors, device=None):
item.data = input_tensor.data
inputs.append(item)
inputs.append(tensors[-2])
return tuple(inputs)
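The even-count assertion at the top of `gather_partitioned_activations` reflects the save format used further down: checkpointed activations are stored as interleaved `(partition, original_size)` pairs. Below is a single-process sketch of that round trip, assuming `mp_size` equal slices of a flattened tensor; the real code keeps only one slice per rank and gathers the rest across the model-parallel group.

```python
import torch

# Single-process illustration; names like mp_size mirror the globals above.
mp_size = 4
x = torch.randn(8, 16)                       # activation to checkpoint
flat = x.detach().contiguous().view(-1)
partition_size = flat.numel() // mp_size     # assumes numel divisible by mp_size

# What gets saved: each rank's slice, interleaved with the original size.
partitions = [
    flat.narrow(0, rank * partition_size, partition_size).clone()
    for rank in range(mp_size)
]
size = torch.tensor(x.size())                # the companion "size" tensor

# Reassembly, analogous to the gather + reshape in gather_partitioned_activations.
full = torch.cat(partitions).view(*size.tolist())
assert torch.equal(full, x)
```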
......@@ -324,7 +329,7 @@ def merge_tensors(tensor_objects, non_tensor_objects, tensor_flags):
real_tensor_flags = None
#remove the flags that are assigned to the size of the flattened tensors
# remove the flags that are assigned to the size of the flattened tensors
if PARTITION_ACTIVATIONS:
real_tensor_flags = []
previous_flag = False
......@@ -348,6 +353,132 @@ def merge_tensors(tensor_objects, non_tensor_objects, tensor_flags):
return tuple(merged_objects)
def is_activation_to_checkpoint(item):
"""
Is an activation to be checkpointed
"""
global mp_size
return torch.is_tensor(item) and item.is_floating_point() and item.numel() >= mp_size
def partition_activations(args, cpu_checkpoint, contiguous_checkpoint):
global contiguous_data_buffers, data_offsets
inputs = []
for i, item in enumerate(args):
if not is_activation_to_checkpoint(item):
inputs.append(item)
continue
partition_size = get_partition_size(item)
partition = item.detach().contiguous().view(-1).narrow(
0,
get_partition_start(item),
partition_size).clone()
buffer_device = torch.device('cpu') if cpu_checkpoint else partition.device
if contiguous_checkpoint:
if i >= len(contiguous_data_buffers):
tensor_list = [
torch.tensor(()).new_empty([partition_size],
dtype=partition.dtype,
device=buffer_device)
for i in range(num_layers)
]
contiguous_data_buffers.append(tensor_list)
data_offsets.append(0)
elif contiguous_data_buffers[i] is None:
tensor_list = [
torch.tensor(()).new_empty([partition_size],
dtype=partition.dtype,
device=buffer_device)
for i in range(num_layers)
]
contiguous_data_buffers[i] = tensor_list
data_offsets[i] = 0
# Because the 'new_empty' returns uninitialized pages,
# the pages need to be populated during the cudaMemcpy time
# which increases the data copy time. To avoid this, we
# pre-populate these pages by simply writing 0 ahead of
# the actual cudaMemcpy operation time. Due to the
# previously launched GPU kernels, there is a small
# window of time here for CPUs to populate pages asynchronously.
contiguous_data_buffers[i][data_offsets[i]].data[range(
0,
contiguous_data_buffers[i][data_offsets[i]].data.shape[0],
int(mmap.PAGESIZE /
contiguous_data_buffers[i][data_offsets[i]].data.element_size())
)] = 0
contiguous_partition = contiguous_data_buffers[i][
data_offsets[i]].data.copy_(partition.data)
data_offsets[i] = data_offsets[i] + 1
inputs.append(contiguous_partition)
else:
partition = partition.cpu() if CPU_CHECKPOINT else partition
inputs.append(partition)
return inputs
def get_partitioned_activations_for_backward(args, inputs, contiguous_checkpoint):
global contiguous_size_buffers, size_offsets
new_args = []
for i, (arg, inp) in enumerate(zip(args, inputs)):
size = torch.tensor(arg.size()) if torch.is_tensor(arg) else None
if not is_activation_to_checkpoint(arg):
new_args.append(arg)
new_args.append(size)
continue
arg.data = inp.data
new_args.append(arg)
if contiguous_checkpoint:
numel = size.numel()
if i >= len(contiguous_size_buffers):
tmp = torch.tensor(())
contiguous_size_buffers.append(
tmp.new_empty([numel * num_layers],
dtype=size.dtype,
device=size.device))
size_offsets.append(0)
elif contiguous_size_buffers[i] is None:
tmp = torch.tensor(())
contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers],
dtype=size.dtype,
device=size.device)
size_offsets[i] = 0
contiguous_size = contiguous_size_buffers[i].narrow(
0,
size_offsets[i],
numel).data.copy_(size.data)
contiguous_size = contiguous_size.view_as(size)
size_offsets[i] = size_offsets[i] + numel
new_args.append(contiguous_size)
else:
new_args.append(size)
return new_args
def get_cpu_activations_for_backward(args, inputs):
new_args = []
for i, (arg, inp) in enumerate(zip(args, inputs)):
if not is_activation_to_checkpoint(arg):
new_args.append(arg)
continue
arg.data = inp.data
new_args.append(arg)
return new_args
class CheckpointFunction(torch.autograd.Function):
"""This function is adapted from torch.utils.checkpoint with
two main changes:
......@@ -399,7 +530,7 @@ class CheckpointFunction(torch.autograd.Function):
if dist.get_rank() == 0:
logger.info(f"Activation Checkpointing Information")
logger.info(
f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {PA_TO_CPU}"
f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}"
)
logger.info(
f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers"
......@@ -411,69 +542,18 @@ class CheckpointFunction(torch.autograd.Function):
transport_stream = torch.cuda.Stream(device=cuda_device)
if PARTITION_ACTIVATIONS:
#inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]]
# inputs.append(args[-1])
inputs = []
for i, item in enumerate(args[:-1]):
if not torch.is_tensor(item) or mp_size > item.numel():
inputs.append(item)
continue
partition_size = get_partition_size(item)
partition = item.detach().contiguous().view(-1).narrow(
0,
get_partition_start(item),
partition_size).clone()
if CONTIGUOUS_CHECKPOINTING:
buffer_device = torch.device(
'cpu') if PA_TO_CPU else partition.device
if i >= len(contiguous_data_buffers):
tensor_list = [
torch.tensor(()).new_empty([partition_size],
dtype=partition.dtype,
device=buffer_device)
for i in range(num_layers)
]
contiguous_data_buffers.append(tensor_list)
data_offsets.append(0)
elif contiguous_data_buffers[i] is None:
tensor_list = [
torch.tensor(()).new_empty([partition_size],
dtype=partition.dtype,
device=buffer_device)
for i in range(num_layers)
]
contiguous_data_buffers[i] = tensor_list
data_offsets[i] = 0
# Because the 'new_empty' returns uninitialized pages,
# the pages need to be populated during the cudaMemcpy time
# which increases the data copy time. To avoid this, we
# pre-populate these pages by simply writing 0 ahead of
# the actual cudaMemcpy operation time. Due to the
# previously launched GPU kernels, there is a small
# window of time here for CPUs to populate pages asynchronously.
contiguous_data_buffers[i][data_offsets[i]].data[range(
0,
contiguous_data_buffers[i][data_offsets[i]].data.shape[0],
int(mmap.PAGESIZE / contiguous_data_buffers[i][
data_offsets[i]].data.element_size()))] = 0
contiguous_partition = contiguous_data_buffers[i][
data_offsets[i]].data.copy_(partition.data)
data_offsets[i] = data_offsets[i] + 1
inputs.append(contiguous_partition)
else:
partition = partition.cpu() if PA_TO_CPU else partition
inputs.append(partition)
inputs.append(args[-1])
#just in case something funky is happening such as reuse of inputs
inputs_cuda = move_to_device(args, cuda_device)
inputs = partition_activations(args,
CPU_CHECKPOINT,
CONTIGUOUS_CHECKPOINTING)
elif CPU_CHECKPOINT:
inputs = copy_to_device(args,
device=torch.device('cpu'),
criterion_func=is_activation_to_checkpoint)
# just in case something funky is happening such as reuse of inputs
inputs_cuda = copy_to_device(args,
device=cuda_device,
criterion_func=is_activation_to_checkpoint)
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
......@@ -488,56 +568,15 @@ class CheckpointFunction(torch.autograd.Function):
see_memory_usage("After running forward on the layer", force=False)
del inputs_cuda
# with torch.cuda.stream(transport_stream):
# if PARTITION_ACTIVATIONS:
# new_args = []
# for arg, inp in zip(args,inputs):
# size= torch.tensor(arg.size())
# arg.data = inp.data
# new_args.append(arg)
# new_args.append(size)
# ctx.save_for_backward(*new_args)
if PARTITION_ACTIVATIONS:
new_args = []
for i, (arg, inp) in enumerate(zip(args, inputs)):
if not torch.is_tensor(arg):
new_args.append(arg)
continue
size = torch.tensor(arg.size())
arg.data = inp.data
new_args.append(arg)
if CONTIGUOUS_CHECKPOINTING:
numel = size.numel()
if i >= len(contiguous_size_buffers):
tmp = torch.tensor(())
contiguous_size_buffers.append(
tmp.new_empty([numel * num_layers],
dtype=size.dtype,
device=size.device))
size_offsets.append(0)
elif contiguous_size_buffers[i] is None:
tmp = torch.tensor(())
contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers],
dtype=size.dtype,
device=size.device)
size_offsets[i] = 0
contiguous_size = contiguous_size_buffers[i].narrow(
0,
size_offsets[i],
numel).data.copy_(size.data)
contiguous_size = contiguous_size.view_as(size)
size_offsets[i] = size_offsets[i] + numel
new_args.append(contiguous_size)
else:
new_args.append(size)
# if dist.get_rank() == 0:
# logger.info(f"The stored tensor is {contiguous_size} and orginal one is {size} ")
new_args = get_partitioned_activations_for_backward(
args,
inputs,
CONTIGUOUS_CHECKPOINTING)
assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}'
save_args_for_backward(*new_args)
elif CPU_CHECKPOINT:
new_args = get_cpu_activations_for_backward(args, inputs)
save_args_for_backward(*new_args)
else:
save_args_for_backward(*args)
......@@ -600,8 +639,14 @@ class CheckpointFunction(torch.autograd.Function):
if PARTITION_ACTIVATIONS:
# with torch.cuda.stream(transport_stream):
inputs = get_full_inputs(ctx.saved_tensors,
device=cuda_device if PA_TO_CPU else None)
inputs = gather_partitioned_activations(
ctx.saved_tensors,
device=cuda_device if CPU_CHECKPOINT else None)
detached_inputs = detach_variable(inputs)
elif CPU_CHECKPOINT:
inputs = move_to_device(ctx.saved_tensors,
cuda_device,
is_activation_to_checkpoint)
detached_inputs = detach_variable(inputs)
else:
inputs = ctx.saved_tensors
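For the `CPU_CHECKPOINT` branch above, the saved activations are simply moved back onto the GPU before being detached and re-run. The detach step exists so the recomputed forward builds a fresh autograd graph rooted at the checkpointed inputs. A generic sketch of that detach-and-recompute pattern follows (plain PyTorch, not DeepSpeed code; `detach_leaf` is a hypothetical stand-in for `detach_variable`):

```python
import torch

def detach_leaf(t):
    # New leaf sharing storage with t, so gradients of the recomputed
    # forward accumulate on the checkpointed input (cf. detach_variable).
    out = t.detach()
    out.requires_grad_(True)
    return out

# Pretend this came from ctx.saved_tensors, already back on the right device.
saved_input = torch.randn(4, 4)

inp = detach_leaf(saved_input)
with torch.enable_grad():          # re-run the forward during backward
    out = torch.tanh(inp).sum()
out.backward()                     # gradient now lives on the recomputed leaf
print(inp.grad.norm())
```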
......@@ -727,7 +772,7 @@ def reset():
def _configure_using_config_file(config, mpu=None):
global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME
CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME
config = DeepSpeedConfig(config, mpu=mpu).activation_checkpointing_config
if dist.get_rank() == 0:
......@@ -735,7 +780,7 @@ def _configure_using_config_file(config, mpu=None):
PARTITION_ACTIVATIONS = config.partition_activations
CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization
num_layers = config.number_checkpoints
PA_TO_CPU = config.cpu_checkpointing
CPU_CHECKPOINT = config.cpu_checkpointing
SYNCHRONIZE = config.synchronize_checkpoint_boundary
PROFILE_TIME = config.profile
......@@ -745,12 +790,12 @@ def _configure_defaults():
global mpu, num_layers, deepspeed_checkpointing_enabled
global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME
CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME
PARTITION_ACTIVATIONS = False
CONTIGUOUS_CHECKPOINTING = False
num_layers = False
PA_TO_CPU = False
CPU_CHECKPOINT = False
SYNCHRONIZE = False
PROFILE_TIME = False
deepspeed_checkpointing_enabled = True
......@@ -804,7 +849,7 @@ def configure(
global mpu, num_layers, deepspeed_checkpointing_enabled
global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME
CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME
_configure_defaults()
......@@ -824,7 +869,7 @@ def configure(
num_layers = num_checkpoints
if checkpoint_in_cpu is not None:
PA_TO_CPU = checkpoint_in_cpu
CPU_CHECKPOINT = checkpoint_in_cpu
if synchronize is not None:
SYNCHRONIZE = synchronize
......@@ -832,8 +877,8 @@ def configure(
if profile is not None:
PROFILE_TIME = profile
if PA_TO_CPU or CONTIGUOUS_CHECKPOINTING:
assert PARTITION_ACTIVATIONS, "CPU Checkpointing/Contiguous Checkpointing is only availble with partitioned activations. Set partitioned activations to true in deepspeed config"
if CONTIGUOUS_CHECKPOINTING:
assert PARTITION_ACTIVATIONS, "Contiguous Checkpointing is only availble with partitioned activations. Set partitioned activations to true in deepspeed config"
if CONTIGUOUS_CHECKPOINTING:
assert num_layers is not None, "Must specify the number of layers with contiguous memory checkpointing"
......
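The long comment inside `partition_activations` (and the identical block removed from `forward`) describes pre-touching the pages of a freshly allocated `new_empty` CPU buffer so they are faulted in before the device-to-host copy. A standalone sketch of that trick, assuming an ordinary pageable CPU tensor:

```python
import mmap
import torch

# Allocate an uninitialized CPU buffer, as partition_activations does.
partition_size = 1 << 20
buf = torch.tensor(()).new_empty([partition_size],
                                 dtype=torch.float16,
                                 device=torch.device('cpu'))

# Touch one element per OS page so the kernel backs the buffer with real
# pages now, instead of during the later copy_ (cudaMemcpy) call.
stride = int(mmap.PAGESIZE / buf.element_size())
buf.data[range(0, buf.data.shape[0], stride)] = 0

# Later, buf.copy_(gpu_partition) no longer pays the page-fault cost.
```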
......@@ -59,24 +59,52 @@ def set_random_seed(seed):
torch.manual_seed(seed)
def move_to_device(item, device):
def copy_to_device(item, device, criterion_func):
"""
Move tensor onto device. Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.
Return a copy of the tensor on the specified device.
Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.
Parameters:
item: tensor to move or (possibly nested) container of tensors to move.
item: tensor to copy or (possibly nested) container of tensors to copy.
device: target device
criterion_func: Function that restricts the copy operation to items meeting the criterion
Returns:
Copy of the item, with tensors matching the criterion placed on the target device
"""
if torch.is_tensor(item):
if criterion_func(item):
return item.to(device)
elif isinstance(item, list):
return [move_to_device(v, device) for v in item]
return [copy_to_device(v, device, criterion_func) for v in item]
elif isinstance(item, tuple):
return tuple([copy_to_device(v, device, criterion_func) for v in item])
elif isinstance(item, dict):
return {k: copy_to_device(v, device, criterion_func) for k, v in item.items()}
else:
return item
def move_to_device(item, device, criterion_func):
"""
Move tensor onto the specified device by changing its underlying storage in place.
Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.
Parameters:
item: tensor to move or (possibly nested) container of tensors to move.
device: target device
criterion_func: Function that restricts the move operation to items meeting the criterion
Returns:
The same item, with matching tensors now backed by storage on the target device
"""
if criterion_func(item):
device_copy = item.to(device)
item.data = device_copy.data
return item
elif isinstance(item, list):
return [move_to_device(v, device, criterion_func) for v in item]
elif isinstance(item, tuple):
return tuple([move_to_device(v, device) for v in item])
return tuple([move_to_device(v, device, criterion_func) for v in item])
elif isinstance(item, dict):
return {k: move_to_device(v, device) for k, v in item.items()}
return {k: move_to_device(v, device, criterion_func) for k, v in item.items()}
else:
return item
......
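The two helpers added here differ in aliasing semantics: `copy_to_device` returns new tensors and leaves the originals where they were, while `move_to_device` swaps `.data` in place so every existing reference to a tensor now sees storage on the target device (this is what lets the saved `args` keep pointing at CPU copies in `forward`). A small usage sketch, assuming a DeepSpeed build that already contains these helpers and using a simplified criterion function:

```python
import torch
from deepspeed.runtime.utils import copy_to_device, move_to_device

def is_fp_tensor(item):
    # Simplified stand-in for is_activation_to_checkpoint.
    return torch.is_tensor(item) and item.is_floating_point()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch = {"act": torch.randn(2, 3), "step": 7}   # non-tensors pass through untouched

# copy_to_device: returns a new container of copies; `batch` is unchanged.
copied = copy_to_device(batch, device, is_fp_tensor)
assert copied["act"].device.type == device.type

# move_to_device: same tensor objects, but their storage now lives on `device`.
moved = move_to_device(batch, device, is_fp_tensor)
assert moved["act"] is batch["act"]
assert batch["act"].device.type == device.type
```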