Unverified commit c1af73f7 authored by Siddharth Singh, committed by GitHub

Improving memory utilization of Z2+MoE (#2079)

* Shard expert parameter groups
* Do upscaling, the optimizer step, and deletion of fp32 gradients one-by-one for each parameter group in ZeRO-2
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent b0523787
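For orientation, a minimal sketch of how the helper touched by the first hunks is typically wired into optimizer construction (the model and optimizer here are hypothetical; the call pattern follows DeepSpeed's MoE examples, and the new max_group_size argument keeps the default introduced by this patch):

    import torch
    from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer

    def create_moe_param_groups(model):
        # Start from one group holding every parameter; the helper splits the
        # expert parameters out into separate groups (now also capped by
        # max_group_size) so ZeRO can shard and step them independently.
        params = {'params': [p for p in model.parameters()], 'name': 'parameters'}
        return split_params_into_different_moe_groups_for_optimizer(params)

    # optimizer = torch.optim.AdamW(create_moe_param_groups(model), lr=1e-4)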
@@ -59,8 +59,9 @@ def split_params_grads_into_shared_and_expert_params(
     return shared_grads, expert_grads
 
 
-def split_params_into_different_moe_groups_for_optimizer(
-        param_groups: Tuple[Dict]) -> Tuple[Dict]:
+def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dict],
+                                                         max_group_size=178956971
+                                                         ) -> Tuple[Dict]:
     """Split parameters into different MoE groups for optimizer
 
     Args:
@@ -112,8 +113,32 @@ def split_params_into_different_moe_groups_for_optimizer(
         param_group['params'] = new_params
 
     # Flatten the moe groups
-    for k, v in group_moe.items():
-        for k1, v1 in v.items():
-            param_groups.append(v1)
+    if max_group_size is not None:
+        for k, v in group_moe.items():
+            for k1, v1 in v.items():
+                cur_group = []
+                all_groups = []
+                size_of_cur_group = 0
+                for param in v1['params']:
+                    if size_of_cur_group + param.numel() <= max_group_size:
+                        cur_group.append(param)
+                        size_of_cur_group += param.numel()
+                    else:
+                        all_groups.append(cur_group)
+                        cur_group = [param]
+                        size_of_cur_group = param.numel()
+                if cur_group:
+                    all_groups.append(cur_group)
+                for group in all_groups:
+                    new_dict = {}
+                    for key, val in v1.items():
+                        if key != 'params':
+                            new_dict[key] = val
+                    new_dict['params'] = group
+                    param_groups.append(new_dict)
+    else:
+        for k, v in group_moe.items():
+            for k1, v1 in v.items():
+                param_groups.append(v1)
     return tuple(param_groups)
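To see what the new branch does in isolation, here is a standalone sketch of its greedy packing (my own illustration with a toy max_group_size, not code from the patch): parameters accumulate into a group until adding the next one would push the group's total element count past max_group_size, at which point a new group is started.

    import torch

    def chunk_params_by_numel(params, max_group_size):
        # Greedy packing: keep appending parameters while the running element
        # count stays within max_group_size, otherwise start a new group.
        all_groups, cur_group, cur_size = [], [], 0
        for p in params:
            if cur_size + p.numel() <= max_group_size:
                cur_group.append(p)
                cur_size += p.numel()
            else:
                all_groups.append(cur_group)
                cur_group, cur_size = [p], p.numel()
        if cur_group:
            all_groups.append(cur_group)
        return all_groups

    params = [torch.zeros(n) for n in (4, 3, 2, 5)]
    groups = chunk_params_by_numel(params, max_group_size=6)
    print([sum(p.numel() for p in g) for g in groups])  # [4, 5, 5]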
@@ -1653,6 +1653,44 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         self.custom_loss_scaler = True
         self.external_loss_scale = loss_scale
 
+    def scaled_global_norm(self, norm_type=2):
+        assert norm_type == 2, "only L2 norm supported"
+        norm_groups = []
+        for i, group in enumerate(self.bit16_groups):
+            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
+            if self.cpu_offload:
+                norm_groups.append(
+                    self.complete_grad_norm_calculation_for_cpu_offload(
+                        self.params_in_partition[i]))
+                single_grad_partition = self.single_partition_of_fp32_groups[i].grad
+            else:
+                norm_groups.append(
+                    self.get_grad_norm_direct(self.averaged_gradients[i],
+                                              self.params_in_partition[i]))
+
+        if self.has_moe_layers:
+            self._average_expert_grad_norms(norm_groups)
+
+        # note that the get_global_norm function only supports l2 norm
+        return get_global_norm(norm_list=norm_groups)
+
+    def get_bit16_param_group(self, group_no):
+        bit16_partitions = self.parallel_partitioned_bit16_groups[group_no]
+        partition_id = dist.get_rank(group=self.real_dp_process_group[group_no])
+        return [
+            bit16_partitions[dist.get_rank(group=self.real_dp_process_group[group_no])]
+        ]
+
+    def _optimizer_step(self, group_no):
+        original_param_groups = self.optimizer.param_groups
+        self.optimizer.param_groups = [original_param_groups[group_no]]
+        from deepspeed.ops.adam import DeepSpeedCPUAdam
+        if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
+            self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no)])
+        else:
+            self.optimizer.step()
+        self.optimizer.param_groups = original_param_groups
+
     def step(self, closure=None):
         """
         Not supporting closure.
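The new scaled_global_norm collects one L2 norm per parameter group and hands the list to get_global_norm. For the L2 norm this reduction is just the root of the sum of squares, so no flat concatenation of gradients is needed; a minimal illustration of the combination rule (my own sketch, not the DeepSpeed source):

    import math
    import torch

    def combine_l2_norms(norm_list):
        # ||concat(g_1, ..., g_k)||_2 == sqrt(sum_i ||g_i||_2 ** 2), so per-group
        # norms can be reduced without materializing one flat gradient tensor.
        return math.sqrt(sum(float(n) ** 2 for n in norm_list))

    grads = [torch.randn(1000) for _ in range(4)]
    per_group = [g.norm(2).item() for g in grads]
    assert abs(combine_l2_norms(per_group) - torch.cat(grads).norm(2).item()) < 1e-2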
@@ -1671,7 +1709,6 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
 
         prev_scale = self.loss_scale
         self._update_scale(self.overflow)
         if self.overflow:
-
             if dist.get_rank() == 0:
                 logger.info(
                     "[deepspeed] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, "
@@ -1692,22 +1729,33 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             self.stop_timers(timer_names)
             return
 
-        self.start_timers([OPTIMIZER_GRADIENTS])
-        norm_groups = []
-        single_partition_grad_groups = []
-        # skip = False
+        # Step 1:- Calculate gradient norm using fp-16 grads
+        see_memory_usage('Before norm calculation')
+        scaled_global_grad_norm = self.scaled_global_norm()
+        self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
+
+        see_memory_usage('After norm before optimizer')
+        # Step 2:- run optimizer and upscaling simultaneously
         for i, group in enumerate(self.bit16_groups):
+            self.start_timers([OPTIMIZER_GRADIENTS])
             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
             if self.cpu_offload:
-                norm_groups.append(
-                    self.complete_grad_norm_calculation_for_cpu_offload(
-                        self.params_in_partition[i]))
                 single_grad_partition = self.single_partition_of_fp32_groups[i].grad
-            else:
-                norm_groups.append(
-                    self.get_grad_norm_direct(self.averaged_gradients[i],
-                                              self.params_in_partition[i]))
+                self.unscale_and_clip_grads([single_grad_partition],
+                                            scaled_global_grad_norm)
+                self.stop_timers([OPTIMIZER_GRADIENTS])
+
+                self.start_timers([OPTIMIZER_STEP])
+                self._optimizer_step(i)
+
+                from deepspeed.ops.adam import DeepSpeedCPUAdam
+                if not (type(self.optimizer) == DeepSpeedCPUAdam
+                        and self.dtype == torch.half):
+                    bit16_partitions = self.parallel_partitioned_bit16_groups[i]
+                    fp32_partition = self.single_partition_of_fp32_groups[i]
+                    bit16_partitions[partition_id].data.copy_(fp32_partition.data)
+
+                self.stop_timers([OPTIMIZER_STEP])
+            else:
                 # free gradients for all the parameters that are not updated by this process(ZeRO stage2)
                 self.free_grad_in_param_list(self.params_not_in_partition[i])
@@ -1732,53 +1780,22 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                 self.averaged_gradients[i] = None
 
-            single_partition_grad_groups.append(single_grad_partition)
-
-        if self.has_moe_layers:
-            self._average_expert_grad_norms(norm_groups)
-
-        scaled_global_grad_norm = get_global_norm(norm_list=norm_groups)
-        self.unscale_and_clip_grads(single_partition_grad_groups,
-                                    scaled_global_grad_norm)
-
-        # Stash unscaled gradient norm
-        self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
-
-        self.stop_timers([OPTIMIZER_GRADIENTS])
-
-        self.start_timers([OPTIMIZER_STEP])
-        if self.deepspeed_adam_offload:
-            from deepspeed.ops.adam import DeepSpeedCPUAdam
-            if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
-                bit16_param_groups = [
-                    [
-                        bit16_partitions[dist.get_rank(
-                            group=self.real_dp_process_group[group_id])]
-                    ] for group_id,
-                    bit16_partitions in enumerate(self.parallel_partitioned_bit16_groups)
-                ]
-                self.optimizer.step(fp16_param_groups=bit16_param_groups)
-            else:
-                self.optimizer.step()
-                for group_id, (bit16_partitions, fp32_partition) in enumerate(zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
-                    partition_id = dist.get_rank(
-                        group=self.real_dp_process_group[group_id])
-                    bit16_partitions[partition_id].data.copy_(fp32_partition.data)
-        else:
-            self.optimizer.step()
-
-            # get rid of the fp32 gradients. Not needed anymore
-            if not self.cpu_offload:
-                for group in self.single_partition_of_fp32_groups:
-                    group.grad = None  # in step
-
-            for group_id, (bit16_partitions, fp32_partition) in enumerate(zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
-                partition_id = dist.get_rank(group=self.real_dp_process_group[group_id])
+            self.unscale_and_clip_grads([single_grad_partition],
+                                        scaled_global_grad_norm)
+            self.stop_timers([OPTIMIZER_GRADIENTS])
+
+            # Step 3:- run the optimizer if no offloading
+            self.start_timers([OPTIMIZER_STEP])
+            self._optimizer_step(i)
+
+            # Step 4:- get rid of the fp32 gradients. Not needed anymore
+            self.single_partition_of_fp32_groups[i].grad = None
+            del single_grad_partition
+
+            bit16_partitions = self.parallel_partitioned_bit16_groups[i]
+            fp32_partition = self.single_partition_of_fp32_groups[i]
                 bit16_partitions[partition_id].data.copy_(fp32_partition.data)
-        self.stop_timers([OPTIMIZER_STEP])
+            self.stop_timers([OPTIMIZER_STEP])
 
+        see_memory_usage('After optimizer before all-gather')
         if self.cpu_offload:
             self.reset_cpu_buffers()
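The memory win in the restructured loop comes from the ordering: each group is unscaled, stepped via _optimizer_step, and has its fp32 gradient dropped before the next group is touched, instead of keeping every group's fp32 gradient alive across one global optimizer.step(). A toy emulation of that per-group pattern with a plain torch optimizer (an illustration only; the real loop also handles partitioning, clipping and the bit16 copies):

    import torch

    params = [torch.nn.Parameter(torch.ones(8)) for _ in range(3)]
    opt = torch.optim.SGD([{'params': [p]} for p in params], lr=0.1)
    loss_scale = 1024.0
    for p in params:                          # pretend these are loss-scaled grads
        p.grad = torch.full((8,), loss_scale)

    backup = opt.param_groups
    for i, p in enumerate(params):
        p.grad.mul_(1.0 / loss_scale)         # unscale this group only
        opt.param_groups = [backup[i]]        # expose just group i to step()
        opt.step()
        opt.param_groups = backup
        p.grad = None                         # free its grad before the next group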
@@ -1794,7 +1811,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         self.stop_timers([OPTIMIZER_ALLGATHER])
 
         # TODO: we probably don't need this? just to be safe
-        for i in range(len(norm_groups)):
+        for i in range(len(self.bit16_groups)):
             self._update_model_bit16_weights(i)
 
         self.log_timers(timer_names)