From e40558ded241edefed410aa44dbf490ab60f05f6 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Fri, 21 Jan 2022 06:32:48 -0800 Subject: [PATCH] Fix checkpoint api (#1714) --- deepspeed/runtime/engine.py | 2 +- deepspeed/runtime/zero/stage3.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 7d22e103..6c1aadee 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2460,7 +2460,7 @@ class DeepSpeedEngine(Module): tag, load_optimizer_states=load_optimizer_states) if not success: - self.optimizer._restore_from_fp16_weights() + self.optimizer._restore_from_bit16_weights() return load_path, client_states diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index cdec9bda..359bb273 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2971,13 +2971,13 @@ class DeepSpeedZeroOptimizer_Stage3(object): current.data.copy_(saved.data) # Restore base optimizer fp32 weights from ZeRO fp16 weights - def _restore_from_fp16_weights(self): + def _restore_from_bit16_weights(self): for fp16_partitions, fp32_partition in zip(self.fp16_partitioned_groups_flat, self.fp32_partitioned_groups_flat): fp32_partition.data.copy_(fp16_partitions.data) # Refresh the fp32 master params from the fp16 copies. def refresh_fp32_params(self): - self._restore_from_fp16_weights() + self._restore_from_bit16_weights() # Extract flattened partition for current rank from all partitions def _get_flattened_partition(self, all_partition_states): -- GitLab