未验证 提交 9305916d 编写于 作者: kisseternity's avatar kisseternity 提交者: GitHub

Comments for better understanding of zero stage1_2 (#2027)

Co-authored-by: NOlatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: NJeff Rasley <jerasley@microsoft.com>
上级 9fc4e5f1
......@@ -1089,6 +1089,8 @@ class DeepSpeedEngine(Module):
logger.warning(
"**** You are using ZeRO with an untested optimizer, proceed with caution *****"
)
# This optimizer in engine is ZeRO optimizer of stage1_2 or stage3 based on the 'stage' config,
# while ZeRO optimizer itself wraps the original optimizer.
self.optimizer = self._configure_zero_optimizer(basic_optimizer)
elif self.amp_enabled():
assert not (self.fp16_enabled() or self.bfloat16_enabled()), "Cannot enable both amp with (legacy) fp16 or bfloat16 mode"
......
......@@ -346,7 +346,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
assert (partitioned_data.data_ptr() %
(2 * self.nccl_start_alignment_factor) == 0)
# a partition of the fp32 master weights that will be updated by this process
# A partition of the fp32 master weights that will be updated by this process.
# Note that the params in single_partition_of_fp32_groups is cloned and detached
# from the origin params of the model.
if not fp16_master_weights_and_gradients:
self.single_partition_of_fp32_groups.append(
self.parallel_partitioned_bit16_groups[i][partition_id].to(
......@@ -356,7 +358,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
self.parallel_partitioned_bit16_groups[i][partition_id].to(
self.device).clone().half().detach())
# modify optimizer of have flat master weight
# Set local optimizer to have flat params of its own partition.
# After this, the local optimizer will only contain its own partition of params.
# In that case, the local optimizer only saves the states(momentum, variance, etc.) related to its partition's params(zero stage1).
self.single_partition_of_fp32_groups[
i].requires_grad = True # keep this in case internal optimizer uses it
param_group['params'] = [self.single_partition_of_fp32_groups[i]]
......@@ -1426,7 +1430,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
partitions = []
dp = dist.get_world_size(group=self.real_dp_process_group[group_id])
dp_id = dist.get_rank(group=self.real_dp_process_group[group_id])
# dp_id = dist.get_rank(group=self.real_dp_process_group[group_id])
total_num_elements = tensor.numel()
......@@ -1691,7 +1695,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
self.start_timers([OPTIMIZER_GRADIENTS])
norm_groups = []
single_partition_grad_groups = []
skip = False
# skip = False
for i, group in enumerate(self.bit16_groups):
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
if self.cpu_offload:
......@@ -1704,7 +1708,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
self.get_grad_norm_direct(self.averaged_gradients[i],
self.params_in_partition[i]))
# free gradients for all the parameters that are not updated by this process
# free gradients for all the parameters that are not updated by this process(ZeRO stage2)
self.free_grad_in_param_list(self.params_not_in_partition[i])
# create a flat gradients for parameters updated by this process
......@@ -1723,7 +1727,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
single_grad_partition.numel(), self.partition_size[i], i, partition_id)
self.single_partition_of_fp32_groups[i].grad = single_grad_partition
# release all the gradient since we have already created a necessary copy in dp_grad_partition
# release all the gradient since we have already created a necessary copy in dp_grad_partition(ZeRO stage2)
self.free_grad_in_param_list(self.params_in_partition[i])
self.averaged_gradients[i] = None
......@@ -1752,6 +1756,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
self.optimizer.step(fp16_param_groups=bit16_param_groups)
else:
self.optimizer.step()
# after step(), single_partition_of_fp32_groups has the local optimizer's own partition of updated params
for bit16_partitions, fp32_partition in zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups):
bit16_partitions[partition_id].data.copy_(fp32_partition.data)
else:
......@@ -1772,7 +1777,8 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
self.reset_cpu_buffers()
self.start_timers([OPTIMIZER_ALLGATHER])
# gather the updated weights from everyone
# Gather the updated weights from everyone.
# Then all partitions of the model parameters are updated and ready for next round forward.
all_gather_dp_groups(
partitioned_param_groups=self.parallel_partitioned_bit16_groups,
dp_process_group=self.real_dp_process_group,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册