Unverified commit 9305916d, authored by kisseternity, committed by GitHub

Comments for better understanding of zero stage1_2 (#2027)

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent 9fc4e5f1
@@ -1089,6 +1089,8 @@ class DeepSpeedEngine(Module):
 logger.warning(
 "**** You are using ZeRO with an untested optimizer, proceed with caution *****"
 )
+# This optimizer in engine is ZeRO optimizer of stage1_2 or stage3 based on the 'stage' config,
+# while ZeRO optimizer itself wraps the original optimizer.
 self.optimizer = self._configure_zero_optimizer(basic_optimizer)
 elif self.amp_enabled():
 assert not (self.fp16_enabled() or self.bfloat16_enabled()), "Cannot enable both amp with (legacy) fp16 or bfloat16 mode"
...
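The comment added above points out the two-level structure: engine.optimizer is a ZeRO optimizer (stage 1/2 or stage 3, depending on the 'stage' config), and that ZeRO optimizer in turn holds the original optimizer built from the user config. A minimal sketch of just this wrapping relationship, with hypothetical stand-in names (ToyZeroOptimizer is not a DeepSpeed class):

import torch

class ToyZeroOptimizer:
    """Stand-in for DeepSpeedZeroOptimizer; only the wrapping is illustrated."""

    def __init__(self, init_optimizer):
        # The wrapped optimizer still performs the actual parameter update;
        # the ZeRO layer controls which partitioned params and grads it sees.
        self.optimizer = init_optimizer

    def step(self):
        self.optimizer.step()

model = torch.nn.Linear(4, 4)
basic_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
engine_optimizer = ToyZeroOptimizer(basic_optimizer)  # plays the role of engine.optimizer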
@@ -346,7 +346,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
 assert (partitioned_data.data_ptr() %
 (2 * self.nccl_start_alignment_factor) == 0)
-# a partition of the fp32 master weights that will be updated by this process
+# A partition of the fp32 master weights that will be updated by this process.
+# Note that the params in single_partition_of_fp32_groups is cloned and detached
+# from the origin params of the model.
 if not fp16_master_weights_and_gradients:
 self.single_partition_of_fp32_groups.append(
 self.parallel_partitioned_bit16_groups[i][partition_id].to(
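The added note stresses that single_partition_of_fp32_groups does not alias the model's bit16 weights. A hedged, single-process sketch of the clone-and-detach step (tensor size and device are made up for illustration):

import torch

# stand-in for parallel_partitioned_bit16_groups[i][partition_id]
bit16_partition = torch.randn(8, dtype=torch.float16)

# independent fp32 master copy: new storage, promoted to float32, cut from autograd
fp32_master = bit16_partition.to("cpu").clone().float().detach()

assert fp32_master.data_ptr() != bit16_partition.data_ptr()  # no memory shared with the model weights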
@@ -356,7 +358,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
 self.parallel_partitioned_bit16_groups[i][partition_id].to(
 self.device).clone().half().detach())
-# modify optimizer of have flat master weight
+# Set local optimizer to have flat params of its own partition.
+# After this, the local optimizer will only contain its own partition of params.
+# In that case, the local optimizer only saves the states(momentum, variance, etc.) related to its partition's params(zero stage1).
 self.single_partition_of_fp32_groups[
 i].requires_grad = True # keep this in case internal optimizer uses it
 param_group['params'] = [self.single_partition_of_fp32_groups[i]]
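The three new comment lines explain why optimizer state stays partition-sized: once param_group['params'] is re-pointed at the flat fp32 partition, the inner optimizer never sees the full parameter group. A small sketch of that effect under assumed sizes (with Adam, exp_avg and exp_avg_sq end up allocated only for the local slice):

import torch

full_fp32 = torch.randn(16, requires_grad=True)                     # imagine the full flattened group
my_partition = full_fp32[:8].clone().detach().requires_grad_(True)  # this rank's slice

inner_opt = torch.optim.Adam([full_fp32], lr=1e-3)
inner_opt.param_groups[0]['params'] = [my_partition]                # analogue of param_group['params'] = [...]

my_partition.grad = torch.randn_like(my_partition)
inner_opt.step()

# Adam state (momentum/variance) now exists only for the local partition.
assert my_partition in inner_opt.state and full_fp32 not in inner_opt.state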
@@ -1426,7 +1430,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
 partitions = []
 dp = dist.get_world_size(group=self.real_dp_process_group[group_id])
-dp_id = dist.get_rank(group=self.real_dp_process_group[group_id])
+# dp_id = dist.get_rank(group=self.real_dp_process_group[group_id])
 total_num_elements = tensor.numel()
@@ -1691,7 +1695,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
 self.start_timers([OPTIMIZER_GRADIENTS])
 norm_groups = []
 single_partition_grad_groups = []
-skip = False
+# skip = False
 for i, group in enumerate(self.bit16_groups):
 partition_id = dist.get_rank(group=self.real_dp_process_group[i])
 if self.cpu_offload:
@@ -1704,7 +1708,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
 self.get_grad_norm_direct(self.averaged_gradients[i],
 self.params_in_partition[i]))
-# free gradients for all the parameters that are not updated by this process
+# free gradients for all the parameters that are not updated by this process(ZeRO stage2)
 self.free_grad_in_param_list(self.params_not_in_partition[i])
 # create a flat gradients for parameters updated by this process
@@ -1723,7 +1727,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
 single_grad_partition.numel(), self.partition_size[i], i, partition_id)
 self.single_partition_of_fp32_groups[i].grad = single_grad_partition
-# release all the gradient since we have already created a necessary copy in dp_grad_partition
+# release all the gradient since we have already created a necessary copy in dp_grad_partition(ZeRO stage2)
 self.free_grad_in_param_list(self.params_in_partition[i])
 self.averaged_gradients[i] = None
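The two "(ZeRO stage2)" annotations above mark where gradient memory is released: grads owned by other ranks are dropped, and the grads this rank does own are flattened into one buffer that becomes .grad of the flat fp32 partition. A rough single-process sketch of that sequence (list sizes and the flatten helper are assumptions, not the exact DeepSpeed code path):

import torch
from torch._utils import _flatten_dense_tensors

params_not_in_partition = [torch.nn.Parameter(torch.randn(4)) for _ in range(2)]
params_in_partition = [torch.nn.Parameter(torch.randn(4)) for _ in range(2)]
for p in params_not_in_partition + params_in_partition:
    p.grad = torch.randn_like(p)

# analogue of free_grad_in_param_list: drop grads this rank never uses for its update
for p in params_not_in_partition:
    p.grad = None

# one flat gradient covering exactly this rank's partition
single_grad_partition = _flatten_dense_tensors(
    [p.grad for p in params_in_partition]).float()

fp32_partition = torch.randn(8, requires_grad=True)  # stand-in for single_partition_of_fp32_groups[i]
fp32_partition.grad = single_grad_partition           # the inner optimizer updates from this

# the original per-parameter grads can now be released as well
for p in params_in_partition:
    p.grad = None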
@@ -1752,6 +1756,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
 self.optimizer.step(fp16_param_groups=bit16_param_groups)
 else:
 self.optimizer.step()
+# after step(), single_partition_of_fp32_groups has the local optimizer's own partition of updated params
 for bit16_partitions, fp32_partition in zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups):
 bit16_partitions[partition_id].data.copy_(fp32_partition.data)
 else:
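The added comment marks the hand-off after the inner optimizer's step(): the freshly updated fp32 partition is copied back into this rank's bit16 shard. A minimal sketch of that copy-back with stand-in tensors (one group, one rank):

import torch

partition_id = 0
parallel_partitioned_bit16_groups = [[torch.zeros(8, dtype=torch.float16)]]  # this rank's bit16 shard
single_partition_of_fp32_groups = [torch.randn(8)]                           # just updated by optimizer.step()

for bit16_partitions, fp32_partition in zip(parallel_partitioned_bit16_groups,
                                             single_partition_of_fp32_groups):
    # copy_ performs the fp32 -> fp16 cast in place on the shard
    bit16_partitions[partition_id].data.copy_(fp32_partition.data)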
@@ -1772,7 +1777,8 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
 self.reset_cpu_buffers()
 self.start_timers([OPTIMIZER_ALLGATHER])
-# gather the updated weights from everyone
+# Gather the updated weights from everyone.
+# Then all partitions of the model parameters are updated and ready for next round forward.
 all_gather_dp_groups(
 partitioned_param_groups=self.parallel_partitioned_bit16_groups,
 dp_process_group=self.real_dp_process_group,
...
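The reworded comment explains why the all-gather is needed: each rank has only updated its own shard, so the shards must be collected across the data-parallel group before the next forward pass. A loose single-process illustration of the resulting layout (the real code performs a distributed all-gather via all_gather_dp_groups; here the gather is simulated by concatenation):

import torch

world_size = 4
# each rank holds one updated bit16 shard after its local optimizer step
shards = [torch.full((8,), float(rank), dtype=torch.float16) for rank in range(world_size)]

# after the all-gather, every rank holds the concatenation of all shards,
# i.e. the full flat bit16 parameter tensor used by the next forward pass
full_flat_param = torch.cat(shards)
assert full_flat_param.numel() == world_size * 8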