未验证 提交 e40558de 编写于 作者: O Olatunji Ruwase 提交者: GitHub

Fix checkpoint api (#1714)

上级 4912e0ad
......@@ -2460,7 +2460,7 @@ class DeepSpeedEngine(Module):
tag,
load_optimizer_states=load_optimizer_states)
if not success:
self.optimizer._restore_from_fp16_weights()
self.optimizer._restore_from_bit16_weights()
return load_path, client_states
......
......@@ -2971,13 +2971,13 @@ class DeepSpeedZeroOptimizer_Stage3(object):
current.data.copy_(saved.data)
# Restore base optimizer fp32 weights from ZeRO fp16 weights
def _restore_from_fp16_weights(self):
def _restore_from_bit16_weights(self):
for fp16_partitions, fp32_partition in zip(self.fp16_partitioned_groups_flat, self.fp32_partitioned_groups_flat):
fp32_partition.data.copy_(fp16_partitions.data)
# Refresh the fp32 master params from the fp16 copies.
def refresh_fp32_params(self):
self._restore_from_fp16_weights()
self._restore_from_bit16_weights()
# Extract flattened partition for current rank from all partitions
def _get_flattened_partition(self, all_partition_states):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册