Unverified commit 336dd089 authored by Olatunji Ruwase, committed by GitHub

Use clone to avoid checkpoint bloat (#1326)

Parent 85acf14c
@@ -559,7 +559,18 @@ class PipelineModule(nn.Module):
             model_ckpt_path = self.ckpt_layer_path(save_dir, idx)
             if not hasattr(layer, 'state_dict'):
                 continue
-            torch.save(layer.state_dict(), model_ckpt_path)
+            # We pass cloned tensors to torch.save() to avoid checkpoint bloat, which occurs because torch.save()
+            # saves the underlying storage rather than the slice of the storage corresponding to individual tensors.
+            # This is a problem in DeepSpeed because we often allocate tensors using slices of large flattened buffers.
+            # Tensor cloning helps to avoid this problem because the storage of cloned tensors is closer to the true size.
+            # It is expected that the garbage collector will reclaim the cloned tensor storage to avoid memory bloat.
+            # See https://pytorch.org/docs/stable/notes/serialization.html#preserve-storage-sharing
+            orig_state_dict = layer.state_dict()
+            final_state_dict = type(orig_state_dict)(
+                {k: v.clone()
+                 for k,
+                 v in orig_state_dict.items()})
+            torch.save(final_state_dict, model_ckpt_path)
 
     def load_state_dir(self, load_dir, strict=True):
         for idx, layer in enumerate(self.forward_funcs):
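For reference, here is a minimal sketch (not part of the commit) of the storage-sharing behavior the new comment describes: saving a tensor that is a view into a large flattened buffer serializes the whole buffer, while saving a clone serializes only the slice. The buffer size and file names below are illustrative.

import os
import tempfile

import torch

# Hypothetical flattened parameter buffer, similar to what DeepSpeed allocates.
flat_buffer = torch.zeros(10_000_000)          # ~40 MB of float32 storage
param_view = flat_buffer.narrow(0, 0, 10)      # tiny slice sharing that storage

with tempfile.TemporaryDirectory() as tmp:
    bloated_path = os.path.join(tmp, "bloated.pt")
    compact_path = os.path.join(tmp, "compact.pt")

    # Saving the view writes the entire shared storage to disk.
    torch.save({"w": param_view}, bloated_path)

    # Cloning gives the tensor its own storage of the true size,
    # so only the 10 elements are written.
    torch.save({"w": param_view.clone()}, compact_path)

    print(os.path.getsize(bloated_path))   # on the order of 40 MB
    print(os.path.getsize(compact_path))   # a few hundred bytes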