diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 892394e049ad3f77f67b8193e8193edac5b9bf96..6fd84a6aa0730718f4b96bf2b6639ff1378c99f6 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -1089,6 +1089,8 @@ class DeepSpeedEngine(Module):
                     logger.warning(
                         "**** You are using ZeRO with an untested optimizer, proceed with caution *****"
                     )
+                # This optimizer in the engine is the ZeRO optimizer of stage 1/2 or stage 3, selected by the 'stage' config,
+                # while the ZeRO optimizer itself wraps the original optimizer.
                 self.optimizer = self._configure_zero_optimizer(basic_optimizer)
             elif self.amp_enabled():
                 assert not (self.fp16_enabled() or self.bfloat16_enabled()), "Cannot enable both amp with (legacy) fp16 or bfloat16 mode"
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index c6a2e26aac8fdc4c4fc94b4fb4f627046959437c..20f92956a3a5372aca2c5026b9f4d79e56ff3e96 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -346,7 +346,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             assert (partitioned_data.data_ptr() %
                     (2 * self.nccl_start_alignment_factor) == 0)
 
-            # a partition of the fp32 master weights that will be updated by this process
+            # A partition of the fp32 master weights that will be updated by this process.
+            # Note that the params in single_partition_of_fp32_groups are cloned and detached
+            # from the original params of the model.
             if not fp16_master_weights_and_gradients:
                 self.single_partition_of_fp32_groups.append(
                     self.parallel_partitioned_bit16_groups[i][partition_id].to(
@@ -356,7 +358,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                     self.parallel_partitioned_bit16_groups[i][partition_id].to(
                         self.device).clone().half().detach())
 
-            # modify optimizer of have flat master weight
+            # Set the local optimizer to have the flat params of its own partition.
+            # After this, the local optimizer will only contain its own partition of params.
+            # In that case, the local optimizer only saves the states (momentum, variance, etc.) of its own partition's params (ZeRO stage 1).
             self.single_partition_of_fp32_groups[
                 i].requires_grad = True  # keep this in case internal optimizer uses it
             param_group['params'] = [self.single_partition_of_fp32_groups[i]]
@@ -1426,7 +1430,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         partitions = []
 
         dp = dist.get_world_size(group=self.real_dp_process_group[group_id])
-        dp_id = dist.get_rank(group=self.real_dp_process_group[group_id])
+        # dp_id = dist.get_rank(group=self.real_dp_process_group[group_id])
 
         total_num_elements = tensor.numel()
 
@@ -1691,7 +1695,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         self.start_timers([OPTIMIZER_GRADIENTS])
         norm_groups = []
         single_partition_grad_groups = []
-        skip = False
+        # skip = False
         for i, group in enumerate(self.bit16_groups):
             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
             if self.cpu_offload:
@@ -1704,7 +1708,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                     self.get_grad_norm_direct(self.averaged_gradients[i],
                                               self.params_in_partition[i]))
 
-                # free gradients for all the parameters that are not updated by this process
+                # free gradients for all the parameters that are not updated by this process (ZeRO stage 2)
                 self.free_grad_in_param_list(self.params_not_in_partition[i])
 
                 # create a flat gradients for parameters updated by this process
@@ -1723,7 +1727,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                         single_grad_partition.numel(), self.partition_size[i],
                         i, partition_id)
             self.single_partition_of_fp32_groups[i].grad = single_grad_partition
-            # release all the gradient since we have already created a necessary copy in dp_grad_partition
+            # release all the gradients since we have already created a necessary copy in dp_grad_partition (ZeRO stage 2)
             self.free_grad_in_param_list(self.params_in_partition[i])
 
             self.averaged_gradients[i] = None
@@ -1752,6 +1756,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                 self.optimizer.step(fp16_param_groups=bit16_param_groups)
             else:
                 self.optimizer.step()
+                # After step(), single_partition_of_fp32_groups holds the local optimizer's own partition of updated params.
                 for bit16_partitions, fp32_partition in zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups):
                     bit16_partitions[partition_id].data.copy_(fp32_partition.data)
         else:
@@ -1772,7 +1777,8 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             self.reset_cpu_buffers()
 
         self.start_timers([OPTIMIZER_ALLGATHER])
-        # gather the updated weights from everyone
+        # Gather the updated weights from everyone.
+        # Then all partitions of the model parameters are updated and ready for the next forward pass.
        all_gather_dp_groups(
            partitioned_param_groups=self.parallel_partitioned_bit16_groups,
            dp_process_group=self.real_dp_process_group,
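
For context on the engine.py comment above: after deepspeed.initialize() with a ZeRO stage 1/2 config, the engine-level optimizer is the ZeRO wrapper, and the wrapper's own `optimizer` attribute is the original base optimizer (the same inner optimizer that the `self.optimizer.step()` call in stage_1_and_2.py drives). A minimal sketch of how to observe this, assuming a CUDA machine and a launch via the deepspeed launcher; the toy model and config values are illustrative, not part of this PR:

import torch
import deepspeed

model = torch.nn.Linear(8, 8)
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config={
        "train_batch_size": 1,
        "fp16": {"enabled": True},
        "zero_optimization": {"stage": 2},
        "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    },
)
# The engine's optimizer is the stage 1/2 ZeRO wrapper ...
print(type(engine.optimizer).__name__)  # DeepSpeedZeroOptimizer
# ... which in turn wraps the base optimizer built from the 'optimizer' config entry.
print(type(engine.optimizer.optimizer).__name__)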
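
And a self-contained, single-process sketch of the stage_1_and_2.py mechanics that the new comments describe; this is a toy reenactment with made-up tensors, not DeepSpeed's actual code. Each "rank" keeps a cloned and detached fp32 master copy of its own flat partition, its optimizer steps only on that slice, the result is copied back into the bit16 shard, and the final all-gather is simulated by the shards being views into the flat buffer:

import torch

world_size = 2
flat_bit16 = torch.randn(8).to(torch.bfloat16)  # stand-in for the flattened model params
partition_size = flat_bit16.numel() // world_size

shards, masters, optimizers = [], [], []
for rank in range(world_size):
    # Each rank owns one contiguous shard (a view into the flat buffer).
    shard = flat_bit16.narrow(0, rank * partition_size, partition_size)
    # Clone + detach an fp32 master copy: optimizer updates touch this copy,
    # not the model's bit16 params, until they are explicitly copied back.
    master = shard.clone().float().detach()
    master.requires_grad = True
    shards.append(shard)
    masters.append(master)
    # The wrapped optimizer only sees this rank's flat partition, so its
    # states (momentum, variance, ...) cover 1/world_size of the params.
    optimizers.append(torch.optim.Adam([master], lr=0.1))

for rank in range(world_size):
    masters[rank].grad = torch.ones(partition_size)  # stand-in for reduced grads
    optimizers[rank].step()                          # update own fp32 partition only
    shards[rank].data.copy_(masters[rank].data)      # copy back into the bit16 shard

# Real ZeRO now all-gathers the bit16 shards across data-parallel ranks; here the
# shards are already views into flat_bit16, so the full updated params are in place.
print(flat_bit16)

The copy-back plus all-gather at the end corresponds to the tail of step() touched by the last two hunks above.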