Unverified commit 9d79cfd1 authored by Olatunji Ruwase, committed by GitHub

Respect memory pinning config (#4131)

* Respect memory pinning config

* Bug fix
Parent 7a282db8
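For context, the pinning behaviour this commit respects is driven by the `pin_memory` field of the optimizer-offload section in the DeepSpeed config. A minimal sketch of such a config (the values below are illustrative and not part of this commit):

```python
# Illustrative DeepSpeed config: the ZeRO stage-1/2 optimizer offload section.
# After this change, pin_memory=False is honored instead of host buffers being
# pinned unconditionally whenever CPU offload is enabled.
ds_config = {
    "train_batch_size": 8,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",      # enables CPU offload of optimizer state
            "pin_memory": False,  # request pageable (unpinned) CPU buffers
        },
    },
}
```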
@@ -1452,7 +1452,7 @@ class DeepSpeedEngine(Module):
                 expert_data_parallel_group=self.expert_data_parallel_group if self.has_moe_layers else None,
                 reduce_scatter=self.zero_reduce_scatter(),
                 overlap_comm=overlap_comm,
-                cpu_offload=self.zero_cpu_offload(),
+                offload_optimizer_config=self.zero_offload_optimizer(),
                 mpu=self.mpu,
                 postscale_gradients=self.postscale_gradients(),
                 gradient_predivide_factor=self.gradient_predivide_factor(),
@@ -120,7 +120,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                 expert_data_parallel_group=None,
                 reduce_scatter=True,
                 overlap_comm=False,
-                cpu_offload=False,
+                offload_optimizer_config=None,
                 mpu=None,
                 clip_grad=0.0,
                 gradient_accumulation_dtype=torch.float32,
@@ -135,10 +135,17 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                 fp16_master_weights_and_gradients=False,
                 elastic_checkpoint=False):
 
+        if offload_optimizer_config is not None and offload_optimizer_config.device != OffloadDeviceEnum.none:
+            self.cpu_offload = True
+            self.cpu_offload_pin_memory = offload_optimizer_config.pin_memory
+        else:
+            self.cpu_offload = False
+            self.cpu_offload_pin_memory = False
+
         if dist.get_rank() == 0:
             logger.info(f"Reduce bucket size {reduce_bucket_size}")
             logger.info(f"Allgather bucket size {allgather_bucket_size}")
-            logger.info(f"CPU Offload: {cpu_offload}")
+            logger.info(f"CPU Offload: {self.cpu_offload}")
             logger.info(f'Round robin gradient partitioning: {round_robin_gradients}')
         # The fused optimizer does all the work. We need this layer for two reason:
         # 1. maintain same user API from apex.fp16_utils
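The hunk above replaces the old `cpu_offload` boolean with two attributes derived from the offload config. A standalone sketch of the same derivation, using `types.SimpleNamespace` as a stand-in for the real offload config object and a plain `"none"` string in place of `OffloadDeviceEnum.none` (hypothetical names, for illustration only):

```python
from types import SimpleNamespace

def derive_offload_flags(offload_optimizer_config):
    """Mirror of the new __init__ logic: offload is active only when a config
    is given and its device is not 'none'; pinning follows the config."""
    if offload_optimizer_config is not None and offload_optimizer_config.device != "none":
        return True, offload_optimizer_config.pin_memory
    return False, False

# Example: CPU offload requested, but without pinned host memory.
cfg = SimpleNamespace(device="cpu", pin_memory=False)
cpu_offload, cpu_offload_pin_memory = derive_offload_flags(cfg)
assert (cpu_offload, cpu_offload_pin_memory) == (True, False)
```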
@@ -153,7 +160,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         # - flat by groups, not keeping state. TODO: remove state explicitly?
         # - master grad and unflat master weight never exist. TODO: a way to save out unflat master?
         if not get_accelerator().is_available():
-            raise SystemError("Cannot use fp16 without accelerator.")
+            raise SystemError("Accelerator is not detected, cannot perform low precision training (e.g., fp16, bf16).")
         self.optimizer = init_optimizer
 
         # Use torch (un)flatten ops
@@ -170,9 +177,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         self.overlap_comm = overlap_comm
 
-        self.cpu_offload = cpu_offload
-
-        self.deepspeed_adam_offload = cpu_offload
+        self.deepspeed_adam_offload = self.cpu_offload
 
         self.device = get_accelerator().current_device_name() if not self.cpu_offload else 'cpu'
@@ -195,7 +200,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         self.is_gradient_accumulation_boundary = True
 
         # CPU-Offload requires contiguous gradients
-        self.contiguous_gradients = contiguous_gradients or cpu_offload
+        self.contiguous_gradients = contiguous_gradients or self.cpu_offload
 
         self.has_moe_layers = has_moe_layers
         if self.has_moe_layers:
@@ -440,8 +445,12 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             self.norm_for_param_grads = {}
             self.local_overflow = False
             self.grad_position = {}
-            self.temp_grad_buffer_for_cpu_offload = get_accelerator().pin_memory(
-                torch.zeros(largest_param_numel, device=self.device, dtype=self.dtype))
+            self.temp_grad_buffer_for_cpu_offload = torch.zeros(largest_param_numel,
+                                                                device=self.device,
+                                                                dtype=self.dtype)
+            if self.cpu_offload_pin_memory:
+                self.temp_grad_buffer_for_cpu_offload = get_accelerator().pin_memory(
+                    self.temp_grad_buffer_for_cpu_offload)
             self.temp_grad_buffer_for_gpu_offload = torch.zeros(largest_param_numel,
                                                                 device=get_accelerator().current_device_name(),
                                                                 dtype=self.dtype)
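The hunk above now allocates the temporary CPU offload buffer first and pins it only when `cpu_offload_pin_memory` is set. A minimal sketch of that conditional-pinning pattern using plain PyTorch, with `Tensor.pin_memory()` standing in for `get_accelerator().pin_memory()`:

```python
import torch

def allocate_cpu_grad_buffer(numel: int, dtype=torch.float16, pin: bool = False) -> torch.Tensor:
    """Allocate a CPU staging buffer; pin it only if the config asks for it."""
    buf = torch.zeros(numel, dtype=dtype, device="cpu")
    # Pinning requires CUDA support; without it, fall back to pageable memory.
    if pin and torch.cuda.is_available():
        buf = buf.pin_memory()
    return buf

buf = allocate_cpu_grad_buffer(1024, pin=False)
print(buf.is_pinned())  # False: pageable host memory, as requested by the config
```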
@@ -631,7 +640,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                                                 dtype=self.single_partition_of_fp32_groups[i].dtype,
                                                 device=self.device)
             self.single_partition_of_fp32_groups[i].grad = get_accelerator().pin_memory(
-                single_grad_partition) if self.cpu_offload else single_grad_partition
+                single_grad_partition) if self.cpu_offload_pin_memory else single_grad_partition
 
             # Initialize the optimizer states with the flattened fp32 partition.
             # State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers
@@ -1101,7 +1110,8 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
 
         #buffer for storing gradients for this parameter in CPU
         def buffer_to_accumulate_to_in_cpu():
             if not self.fp16_master_weights_and_gradients:
-                return get_accelerator().pin_memory(torch.zeros(param.numel(), dtype=param.dtype, device=self.device))
+                buffer = torch.zeros(param.numel(), dtype=param.dtype, device=self.device)
+                return get_accelerator().pin_memory(buffer) if self.cpu_offload_pin_memory else buffer
             else:
                 return self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements)
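The per-parameter CPU accumulation buffer follows the same rule. Pinned host memory mainly matters for overlapping the device-to-host gradient copy with compute; a small sketch of the difference the config makes (assumes a CUDA device is present; buffer names are illustrative):

```python
import torch

if torch.cuda.is_available():
    grad_on_gpu = torch.randn(1 << 20, dtype=torch.float16, device="cuda")

    # Pinned destination: copy_ with non_blocking=True can run asynchronously
    # and overlap with other GPU work.
    pinned_dst = torch.empty_like(grad_on_gpu, device="cpu").pin_memory()
    pinned_dst.copy_(grad_on_gpu, non_blocking=True)

    # Pageable destination (pin_memory=False in the config): the copy is still
    # correct, it just executes synchronously and cannot be overlapped.
    pageable_dst = torch.empty_like(grad_on_gpu, device="cpu")
    pageable_dst.copy_(grad_on_gpu, non_blocking=True)

    torch.cuda.synchronize()
```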