diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 7d22e103276db94d4f1789fc35c76c7305bf5251..6c1aadee3a982b08f2e07744c48786b1b005751e 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2460,7 +2460,7 @@ class DeepSpeedEngine(Module): tag, load_optimizer_states=load_optimizer_states) if not success: - self.optimizer._restore_from_fp16_weights() + self.optimizer._restore_from_bit16_weights() return load_path, client_states diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index cdec9bdaadee90adc70e09683c077b6e2b9d59fb..359bb273196c105065037d7132508823e7a7f144 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2971,13 +2971,13 @@ class DeepSpeedZeroOptimizer_Stage3(object): current.data.copy_(saved.data) # Restore base optimizer fp32 weights from ZeRO fp16 weights - def _restore_from_fp16_weights(self): + def _restore_from_bit16_weights(self): for fp16_partitions, fp32_partition in zip(self.fp16_partitioned_groups_flat, self.fp32_partitioned_groups_flat): fp32_partition.data.copy_(fp16_partitions.data) # Refresh the fp32 master params from the fp16 copies. def refresh_fp32_params(self): - self._restore_from_fp16_weights() + self._restore_from_bit16_weights() # Extract flattened partition for current rank from all partitions def _get_flattened_partition(self, all_partition_states):