未验证 提交 e40558de 编写于 作者: O Olatunji Ruwase 提交者: GitHub

Fix checkpoint api (#1714)

上级 4912e0ad
...@@ -2460,7 +2460,7 @@ class DeepSpeedEngine(Module): ...@@ -2460,7 +2460,7 @@ class DeepSpeedEngine(Module):
tag, tag,
load_optimizer_states=load_optimizer_states) load_optimizer_states=load_optimizer_states)
if not success: if not success:
self.optimizer._restore_from_fp16_weights() self.optimizer._restore_from_bit16_weights()
return load_path, client_states return load_path, client_states
......
...@@ -2971,13 +2971,13 @@ class DeepSpeedZeroOptimizer_Stage3(object): ...@@ -2971,13 +2971,13 @@ class DeepSpeedZeroOptimizer_Stage3(object):
current.data.copy_(saved.data) current.data.copy_(saved.data)
# Restore base optimizer fp32 weights from ZeRO fp16 weights # Restore base optimizer fp32 weights from ZeRO fp16 weights
def _restore_from_fp16_weights(self): def _restore_from_bit16_weights(self):
for fp16_partitions, fp32_partition in zip(self.fp16_partitioned_groups_flat, self.fp32_partitioned_groups_flat): for fp16_partitions, fp32_partition in zip(self.fp16_partitioned_groups_flat, self.fp32_partitioned_groups_flat):
fp32_partition.data.copy_(fp16_partitions.data) fp32_partition.data.copy_(fp16_partitions.data)
# Refresh the fp32 master params from the fp16 copies. # Refresh the fp32 master params from the fp16 copies.
def refresh_fp32_params(self): def refresh_fp32_params(self):
self._restore_from_fp16_weights() self._restore_from_bit16_weights()
# Extract flattened partition for current rank from all partitions # Extract flattened partition for current rank from all partitions
def _get_flattened_partition(self, all_partition_states): def _get_flattened_partition(self, all_partition_states):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册