Unverified commit 97207c8c, authored by Olatunji Ruwase and committed by GitHub

ZeRO2-Offload: Disable copy overlapping (#1219)

* Disable copy stream

* Format fixes

* Remove debug codes

* Remove debug codes

* Fix indent
Parent commit: d1a7a55e
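The change in this commit stops routing offloaded gradient copies through a dedicated CUDA copy stream: a copy enqueued on one stream can race with a reduction running on another unless the two are explicitly ordered. Below is a minimal sketch of that hazard; it is not DeepSpeed code, and every name in it is illustrative.

import torch

def reduce_then_offload(grad, cpu_buffer, use_copy_stream):
    reduction_stream = torch.cuda.Stream()
    copy_stream = torch.cuda.Stream()

    with torch.cuda.stream(reduction_stream):
        grad.mul_(0.5)  # stand-in for averaging/allreducing the gradient

    if use_copy_stream:
        # Hazard: nothing orders copy_stream after reduction_stream, so this
        # copy may read `grad` before the reduction above has finished.
        with torch.cuda.stream(copy_stream):
            cpu_buffer.copy_(grad, non_blocking=True)
    else:
        # What the commit falls back to: wait for the reduction, then issue
        # the copy on the current stream so the two operations cannot overlap.
        torch.cuda.current_stream().wait_stream(reduction_stream)
        cpu_buffer.copy_(grad, non_blocking=True)

if torch.cuda.is_available():
    g = torch.ones(1 << 20, device="cuda")
    buf = torch.empty(1 << 20, pin_memory=True)
    reduce_then_offload(g, buf, use_copy_stream=False)
    torch.cuda.synchronize()

Serializing on the current stream gives up copy/compute overlap, which is why the TODO added further down flags this as a performance hit to be fixed.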
@@ -20,8 +20,8 @@ from deepspeed.ops.op_builder import UtilsBuilder
from deepspeed.utils import logger
from deepspeed.git_version_info import version
-#Toggle this to true to enable correctness test
-#with gradient partitioning and without
+# Toggle this to true to enable correctness test
+# with gradient partitioning and without
pg_correctness_test = False
@@ -171,29 +171,29 @@ class FP16_DeepSpeedZeroOptimizer(object):
self.fp16_groups = []
self.fp16_groups_flat = []
-#param partitioned by data parallel degree
-#this will contain a list of equal sized tensors
-#each of which will be updated by a different process
+# param partitioned by data parallel degree
+# this will contain a list of equal sized tensors
+# each of which will be updated by a different process
self.parallel_partitioned_fp16_groups = []
-#a single 32-bit partition of the parallel partitioned parameters
-#that this process will update
+# a single 32-bit partition of the parallel partitioned parameters
+# that this process will update
self.single_partition_of_fp32_groups = []
-#param partition info
+# param partition info
-#These are the parameters in each group that will not be updated by this process directly
+# These are the parameters in each group that will not be updated by this process directly
self.params_not_in_partition = []
-#These are the parameters that will be updated by this process directly
+# These are the parameters that will be updated by this process directly
self.params_in_partition = []
-#Offset from the first paramter in the the self.params_in_partition
-#the parameter boundaries may not align with partition boundaries
-#so we need to keep track of the offset
+# Offset from the first parameter in self.params_in_partition;
+# the parameter boundaries may not align with partition boundaries,
+# so we need to keep track of the offset
self.first_offset = []
-#number of elements per partition in each group
+# number of elements per partition in each group
self.partition_size = []
partition_id = dist.get_rank(group=self.dp_process_group)
@@ -220,11 +220,11 @@ class FP16_DeepSpeedZeroOptimizer(object):
padding = 0
self.groups_padding.append(padding)
-#not sure why apex was cloning the weights before flattening
-#removing cloning here
+# not sure why apex was cloning the weights before flattening
+# removing cloning here
see_memory_usage(f"Before moving param group {i} to CPU")
-#move all the parameters to cpu to free up GPU space for creating flat buffer
+# move all the parameters to cpu to free up GPU space for creating flat buffer
move_to_cpu(self.fp16_groups[i])
see_memory_usage(f"After moving param group {i} to CPU")
@@ -239,7 +239,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
self.round_robin_fp16_groups.append(round_robin_tensors)
self.round_robin_fp16_indices.append(round_robin_indices)
-#create flat buffer in CPU and move to GPU
+# create flat buffer in CPU and move to GPU
self.fp16_groups_flat.append(
self.flatten_dense_tensors_aligned(
self.round_robin_fp16_groups[i],
@@ -254,8 +254,8 @@ class FP16_DeepSpeedZeroOptimizer(object):
# set model fp16 weight to slices of flattened buffer
self._update_model_fp16_weights(i)
-#divide the flat weights into near equal partition equal to the data parallel degree
-#each process will compute on a different part of the partition
+# divide the flat weights into near-equal partitions, one per data parallel rank;
+# each process will compute on a different part of the partition
data_parallel_partitions = self.get_data_parallel_partitions(
self.fp16_groups_flat[i])
self.parallel_partitioned_fp16_groups.append(data_parallel_partitions)
@@ -288,13 +288,12 @@ class FP16_DeepSpeedZeroOptimizer(object):
self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
self.reduction_stream = torch.cuda.Stream()
self.cpu_computation_stream = torch.cuda.Stream()
-self.migration_stream = torch.cuda.Stream()
self.copy_grad_stream = torch.cuda.Stream()
self.callback_queued = False
self.param_dict = {}
-#map between param_id and bool to specify if a param is in this partition
+# map between param_id and bool to specify if a param is in this partition
self.is_param_in_current_partition = {}
# CPU-Offload requires contiguous gradients
@@ -306,7 +305,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
self._release_ipg_buffers()
self.previous_reduced_grads = None
-#simplified param id
+# simplified param id
self.param_id = {}
largest_param_numel = 0
@@ -349,40 +348,40 @@ class FP16_DeepSpeedZeroOptimizer(object):
self.first_offset[i],
self.partition_size[i])
-#mapping from parameter to partition that it belongs to
+# mapping from parameter to partition that it belongs to
self.param_to_partition_ids = {}
-#stores if a partition has been reduced in this step
+# stores if a partition has been reduced in this step
self.is_partition_reduced = {}
-#number of grads in partition that still need to be computed
+# number of grads in partition that still need to be computed
self.remaining_grads_in_partition = {}
-#total number of grads in partition
+# total number of grads in partition
self.total_grads_in_partition = {}
-#stores if a grad in a partition has been computed or not
+# stores if a grad in a partition has been computed or not
self.is_grad_computed = {}
-#stores the offset at which a parameter gradient needs to be inserted in a partition
+# stores the offset at which a parameter gradient needs to be inserted in a partition
self.grad_partition_insertion_offset = {}
-#the offset in the gradient at which it must be inserted at the beginning of the partition
+# the offset in the gradient at which it must be inserted at the beginning of the partition
self.grad_start_offset = {}
-#will store the averaged gradients required by this partition
+# will store the averaged gradients required by this partition
self.averaged_gradients = {}
# store index of first parameter in each partition
self.first_param_index_in_partition = {}
-#initializes all data structures for implementing gradient partitioning
+# initializes all data structures for implementing gradient partitioning
self.initialize_gradient_partitioning_data_structures()
-#resets the data structure value for the next backward propagation
+# resets the data structure value for the next backward propagation
self.reset_partition_gradient_structures()
-#creates backward hooks for gradient partitioning
+# creates backward hooks for gradient partitioning
if self.partition_gradients or self.overlap_comm:
self.create_reduce_and_remove_grad_hooks()
@@ -535,7 +534,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
self.reduce_ipg_grads()
self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0)
-#if dist.get_rank() == 0:
+# if dist.get_rank() == 0:
# logger.info("Params already reduced %s", self.params_already_reduced)
for i in range(len(self.params_already_reduced)):
self.params_already_reduced[i] = False
@@ -564,7 +563,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
device=torch.cuda.current_device(),
return_tensor_list=True)
-for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i],avg_new):
+for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i], avg_new):
accumulated_grad.add_(new_avg_grad)
self._release_ipg_buffers()
@@ -629,7 +628,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
elif start_index > current_index and start_index < (current_index +
param_size):
-assert (first_offset==0), "This can happen either zero or only once as this must be the first tensor in the partition"
+assert (first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
first_offset = start_index - current_index
set_key_value_list(self.param_to_partition_ids[i],
@@ -710,7 +709,6 @@ class FP16_DeepSpeedZeroOptimizer(object):
param.numel())
param_id = self.get_param_id(param)
assert self.params_already_reduced[param_id] == False, \
f"The parameter {param_id} has already been reduced. \
Gradient computed twice for this partition. \
@@ -720,7 +718,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
self.extra_large_param_to_reduce = param
elif self.contiguous_gradients:
-#keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening
+# keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening
new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(
0,
self.elements_in_ipg_bucket,
@@ -844,13 +842,13 @@ class FP16_DeepSpeedZeroOptimizer(object):
num_elements = tensor.numel()
tensor_offset = 0
-#we need to offset to get to the right element
+# we need to offset to get to the right element
if i == 0 and first_offset > 0:
tensor_offset = first_offset
num_elements = num_elements - tensor_offset
param_start_offset = first_offset
-#we dont need all elements of the tensor
+# we don't need all elements of the tensor
if num_elements > (partition_size - current_offset):
num_elements = partition_size - current_offset
@@ -869,7 +867,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
def async_accumulate_grad_in_cpu(self, param):
param_id = self.get_param_id(param)
-#copy to a preexisiting buffer to avoid memory allocation penalty
+# copy to a pre-existing buffer to avoid memory allocation penalty
dest_buffer = self.temp_grad_buffer_for_cpu_offload.view(-1).narrow(
0,
0,
@@ -887,7 +885,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
def async_accumulate_grad_in_cpu_via_gpu(self, param):
param_id = self.get_param_id(param)
-#copy to a preexisiting buffer to avoid memory allocation penalty
+# copy to a pre-existing buffer to avoid memory allocation penalty
dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow(
0,
0,
@@ -904,7 +902,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
non_blocking=True)
param.grad.data.view(-1).add_(dest_buffer)
-#at the boundary we will send 32bit directly
+# at the boundary we will send 32bit directly
if not self.is_gradient_accumulation_boundary:
self.accumulated_grads_in_cpu[param_id].data.copy_(param.grad.data.view(-1),
non_blocking=True)
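Both accumulation paths above reuse one preallocated scratch buffer, sized for the largest parameter, instead of allocating a fresh tensor for every gradient. A simplified sketch of the GPU-side round trip, with hypothetical argument names:

import torch

def accumulate_grad_via_gpu_buffer(param, cpu_accum, gpu_scratch):
    # Carve the needed slice out of the preallocated scratch buffer.
    n = param.grad.numel()
    dest = gpu_scratch.view(-1).narrow(0, 0, n)
    # Bring the running CPU sum to the GPU and fold in the fresh gradient...
    dest.copy_(cpu_accum.view(-1), non_blocking=True)
    param.grad.data.view(-1).add_(dest)
    # ...then write the updated sum back to the pinned CPU accumulator.
    cpu_accum.data.copy_(param.grad.data.view(-1), non_blocking=True)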
@@ -1039,7 +1037,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
device=torch.cuda.current_device())
see_memory_usage(f"after copying {total_size} gradients into partition")
-#The allreduce buffer will be rewritted. Copy the gradients in partition to a new buffer
+# The allreduce buffer will be rewritten. Copy the gradients in partition to a new buffer
new_grad_tensor = self.grads_in_partition.view(-1).narrow(
0,
self.grads_in_partition_offset,
@@ -1050,18 +1048,12 @@ class FP16_DeepSpeedZeroOptimizer(object):
self.grads_in_partition_offset += param.numel()
def reduce_ipg_grads(self):
-if self.overlap_comm:
-    stream = self.reduction_stream
-elif self.cpu_offload:
-    stream = self.copy_grad_stream
-else:
-    stream = torch.cuda.current_stream()
if self.contiguous_gradients:
if self.extra_large_param_to_reduce is not None:
assert len(self.params_in_ipg_bucket) == 1, "more than 1 param in ipg bucket, this shouldn't happen"
_, _, param_id = self.params_in_ipg_bucket[0]
-assert self.get_param_id(self.extra_large_param_to_reduce) == param_id, "param in ipg bucket does not match extra-large param"
+assert self.get_param_id(
+    self.extra_large_param_to_reduce) == param_id, "param in ipg bucket does not match extra-large param"
self.average_tensor(self.extra_large_param_to_reduce.grad.view(-1))
self.extra_large_param_to_reduce = None
else:
@@ -1072,6 +1064,16 @@ class FP16_DeepSpeedZeroOptimizer(object):
self.grads_in_ipg_bucket,
elements_per_buffer=self.elements_in_ipg_bucket)
+if self.overlap_comm:
+    stream = self.reduction_stream
+elif self.cpu_offload:
+    # TODO: copy_grad_stream is disabled because of race with reduce. This hurts perf and should be fixed.
+    #       torch.cuda.synchronize()
+    #       stream = self.copy_grad_stream
+    stream = torch.cuda.current_stream()
+else:
+    stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
for _, param, param_id in self.params_in_ipg_bucket:
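The relocated block above is the core of the commit: stream selection now happens after the bucket has been averaged, and the cpu_offload branch no longer returns copy_grad_stream. A condensed sketch of the resulting decision (function and parameter names are illustrative):

import torch

def pick_reduction_stream(overlap_comm, cpu_offload, reduction_stream):
    if overlap_comm:
        # Communication overlap keeps its dedicated stream.
        return reduction_stream
    # cpu_offload used to get a dedicated copy_grad_stream here; the copy it
    # carried raced with the reduction, so until that race is fixed both the
    # offload path and the default path serialize on the current stream.
    return torch.cuda.current_stream()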
@@ -1081,7 +1083,6 @@ class FP16_DeepSpeedZeroOptimizer(object):
Multiple gradient reduction is currently not supported"
self.params_already_reduced[param_id] = True
if not self.is_param_in_current_partition[param_id]:
if self.overlap_comm and self.contiguous_gradients is False:
# Clear grads of other partitions during the next reduction
@@ -1204,7 +1205,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
param.grad = None
self.previous_reduced_grads = None
-#if rank is specified do a reduction instead of an allreduce
+# if rank is specified do a reduction instead of an allreduce
def allreduce_and_copy(self, small_bucket, rank=None, log=None):
if self.overlap_comm:
torch.cuda.synchronize()
@@ -1233,10 +1234,12 @@ class FP16_DeepSpeedZeroOptimizer(object):
if numel > numel_per_bucket:
self.allreduce_and_copy(small_bucket, rank=rank, log=None)
small_bucket = []
if len(small_bucket) > 0:
self.allreduce_and_copy(small_bucket, rank=rank, log=log)
-#allows using reduction of gradients instead of using all_reduce
+# allows using reduction of gradients instead of using all_reduce
def buffered_reduce_fallback(self,
rank,
grads,
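allreduce_no_retain above flushes a flat collective whenever the pending bucket exceeds numel_per_bucket, and buffered_reduce_fallback drives the same machinery with a reduction to a single rank instead of an allreduce. A self-contained sketch of that bucketing pattern, assuming an initialized process group and illustrative names:

import torch
import torch.distributed as dist

def bucketed_reduce(tensors, numel_per_bucket, rank=None, group=None):
    bucket, numel = [], 0

    def flush():
        # Flatten the pending bucket, reduce it in one collective, then
        # scatter the reduced values back into the original tensors.
        flat = torch.cat([t.view(-1) for t in bucket])
        if rank is None:
            dist.all_reduce(flat, group=group)
        else:
            dist.reduce(flat, dst=rank, group=group)
        offset = 0
        for t in bucket:
            t.view(-1).copy_(flat.narrow(0, offset, t.numel()))
            offset += t.numel()

    for t in tensors:
        bucket.append(t)
        numel += t.numel()
        if numel > numel_per_bucket:
            flush()
            bucket, numel = [], 0
    if bucket:
        flush()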
@@ -1254,8 +1257,8 @@ class FP16_DeepSpeedZeroOptimizer(object):
#############################################################################
#############################################################################
-#views the tensor as multiple partitions and returns
-#those partitions
+# views the tensor as multiple partitions and returns
+# those partitions
def get_data_parallel_partitions(self, tensor):
partitions = []
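get_data_parallel_partitions slices one flat buffer into near-equal views, one per data parallel rank, without copying anything. One plausible version of the logic, for illustration (DeepSpeed's real implementation also has to respect the alignment padding added earlier):

import torch

def split_flat_tensor(flat_tensor, dp_world_size):
    total = flat_tensor.numel()
    base, remainder = divmod(total, dp_world_size)
    partitions, start = [], 0
    for dp_rank in range(dp_world_size):
        # Earlier ranks absorb the remainder so sizes stay near-equal.
        size = base + (1 if dp_rank < remainder else 0)
        # narrow() returns a view into the same storage, not a copy.
        partitions.append(flat_tensor.narrow(0, start, size))
        start += size
    return partitions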
@@ -1297,7 +1300,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
tensor_size):
params_in_partition.append(tensor)
-assert (first_offset==0), "This can happen either zero or only once as this must be the first tensor in the partition"
+assert (first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
first_offset = start_index - current_index
else:
@@ -1363,7 +1366,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0.0
-#if dist.get_rank() == 0:
+# if dist.get_rank() == 0:
# logger.info(f"Total Norm begining {total_norm}")
for g, p in zip(gradients, params):
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
@@ -1387,9 +1390,9 @@ class FP16_DeepSpeedZeroOptimizer(object):
return total_norm
-#creates a flat fused tensor from the tensor list starting at the first_offset
-#in the first tensor of the list. If there are not enough elements in the tensor
-#list then the flat tensor will be padded with zeros
+# creates a flat fused tensor from the tensor list starting at the first_offset
+# in the first tensor of the list. If there are not enough elements in the tensor
+# list then the flat tensor will be padded with zeros
def get_flat_partition(self,
tensor_list,
first_offset,
@@ -1407,17 +1410,17 @@ class FP16_DeepSpeedZeroOptimizer(object):
num_elements = tensor.numel()
tensor_offset = 0
-#we need to offset to get to the right element
+# we need to offset to get to the right element
if i == 0 and first_offset > 0:
tensor_offset = first_offset
num_elements = num_elements - tensor_offset
-#we dont need all elements of the tensor
+# we don't need all elements of the tensor
if num_elements > (partition_size - current_size):
num_elements = partition_size - current_size
-#we need a narrow view of the tensor based on the tensor offset and number of elements that
-#we need from this tensor
+# we need a narrow view of the tensor based on the tensor offset and number of elements that
+# we need from this tensor
if tensor_offset > 0 or num_elements < tensor.numel():
flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
0,
@@ -1428,7 +1431,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
current_size = current_size + num_elements
-#this means its the last partition and does not align with the dp boundary. We need to pad before flattening
+# this means it's the last partition and does not align with the dp boundary. We need to pad before flattening
if current_size < partition_size:
flat_tensor_list.append(
torch.zeros(int(partition_size - current_size),
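Putting the comments in this hunk together: get_flat_partition walks the tensor list, takes a window that starts first_offset elements into the first tensor, and zero-pads when the last partition overruns the available elements. The sketch below is an illustrative reimplementation, not the DeepSpeed source, and assumes a non-empty tensor list.

import torch

def flat_partition(tensor_list, first_offset, partition_size):
    pieces, current = [], 0
    for i, t in enumerate(tensor_list):
        flat = t.contiguous().view(-1)
        # Only the first tensor may start partway in, at first_offset.
        offset = first_offset if i == 0 else 0
        num = min(flat.numel() - offset, partition_size - current)
        if num <= 0:
            break
        pieces.append(flat.narrow(0, offset, num))
        current += num
    if current < partition_size:
        # Last partition may not align with the dp boundary: pad with zeros.
        pieces.append(torch.zeros(partition_size - current,
                                  dtype=pieces[0].dtype,
                                  device=pieces[0].device))
    return torch.cat(pieces)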
@@ -1474,9 +1477,6 @@ class FP16_DeepSpeedZeroOptimizer(object):
"""
self.micro_step_id = -1
-if self.cpu_offload:
-    torch.cuda.current_stream().wait_stream(self.migration_stream)
see_memory_usage(f"In step before checking overflow")
# First compute norm for all group so we know if there is overflow
@@ -1524,10 +1524,10 @@ class FP16_DeepSpeedZeroOptimizer(object):
self.get_grad_norm_direct(self.averaged_gradients[i],
self.params_in_partition[i]))
-#free gradients for all the prameters that are not updated by this process
+# free gradients for all the parameters that are not updated by this process
self.free_grad_in_param_list(self.params_not_in_partition[i])
-#create a flat gradients for parameters updated by this process
+# create flat gradients for the parameters updated by this process
# If we are last partition, ensure we have same size grads and partition size, if not pad with zero tensors
if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
single_grad_partition = self.flatten_dense_tensors_aligned(
@@ -1538,10 +1538,11 @@ class FP16_DeepSpeedZeroOptimizer(object):
single_grad_partition = self.flatten(self.averaged_gradients[i]).to(
self.single_partition_of_fp32_groups[i].dtype)
assert single_grad_partition.numel() == self.partition_size[i], \
"averaged gradients have different number of elements that partition size {} {} {} {}".format(single_grad_partition.numel(), self.partition_size[i], i, partition_id)
"averaged gradients have different number of elements that partition size {} {} {} {}".format(
single_grad_partition.numel(), self.partition_size[i], i, partition_id)
self.single_partition_of_fp32_groups[i].grad = single_grad_partition
-#release all the gradient since we have already created a necessary copy in dp_grad_partition
+# release all the gradients since we have already created a necessary copy in dp_grad_partition
self.free_grad_in_param_list(self.params_in_partition[i])
self.averaged_gradients[i] = None
@@ -1567,7 +1568,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
else:
self.optimizer.step()
-#get rid of the fp32 gradients. Not needed anymore
+# get rid of the fp32 gradients. Not needed anymore
if not self.cpu_offload:
for group in self.single_partition_of_fp32_groups:
group.grad = None
@@ -1581,10 +1582,10 @@ class FP16_DeepSpeedZeroOptimizer(object):
self.reset_cpu_buffers()
self.start_timers([OPTIMIZER_ALLGATHER])
-#gather the updated weights from everyone
+# gather the updated weights from everyone
for group_id, partitioned_params in enumerate(self.parallel_partitioned_fp16_groups):
-#Sequential AllGather Best of both worlds
+# Sequential AllGather Best of both worlds
dp_world_size = dist.get_world_size(group=self.dp_process_group)
num_shards = max(
1,
@@ -1721,8 +1722,6 @@ class FP16_DeepSpeedZeroOptimizer(object):
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
"""
self.micro_step_id += 1
-if self.cpu_offload:
-    torch.cuda.current_stream().wait_stream(self.migration_stream)
if self.contiguous_gradients:
self.ipg_buffer = []