Unverified commit 81039463, authored by Stas Bekman, committed by GitHub

[zero3] release tmp memory when consolidating fp16 weights take2 (#1223)

* release tmp memory when consolidating fp16 weights take2

* cleanup
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent 54bed327
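The change below keeps only one layer's full weights on GPU at a time: under ZeRO-3 each parameter is partitioned across ranks, and deepspeed.zero.GatheredParameters temporarily reassembles the full tensors inside the context. Per the comment added in this patch, passing modifier_rank=0 (instead of None) is what allows the gathered GPU copies to be released again after each layer. A minimal sketch of that pattern, assuming an already-initialized ZeRO-3 engine and torch.distributed process group (the helper name is illustrative, not part of DeepSpeed):

import torch
import deepspeed

def gather_layer_to_cpu(module, prefix=""):
    # Gather this layer's partitioned ZeRO-3 parameters into full tensors;
    # with a modifier_rank set, the gathered GPU copies are re-partitioned
    # (and thus freed) when the context exits.
    cpu_copy = {}
    with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)),
                                           modifier_rank=0):
        if torch.distributed.get_rank() == 0:
            for name, param in module.named_parameters(recurse=False):
                cpu_copy[prefix + name] = param.detach().cpu()
    return cpu_copy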
@@ -2039,34 +2039,38 @@ class DeepSpeedEngine(Module):
             raise ValueError("this function requires ZeRO-3 mode")

         state_dict = OrderedDict() if torch.distributed.get_rank() == 0 else None
-        shared_weights = {}
+        shared_params = {}

         def get_layer_state_dict(module, prefix=""):
             # gather one layer at a time to be memory-efficient
+            # must use modifier_rank=0 to release GPU memory after each layer gathered
+            #see_memory_usage("before GatheredParameters", force=True)
             with deepspeed.zero.GatheredParameters(list(
                     module.parameters(recurse=False)),
-                                                   modifier_rank=None):
+                                                   modifier_rank=0):
                 if torch.distributed.get_rank() == 0:
                     # handle params
                     for name, param in module.named_parameters(recurse=False):
                         if param is None:
                             continue
                         key = prefix + name

-                        # for shared weights we want to make sure not to unshare them when copying to cpu
-                        data_ptr_id = param.storage().data_ptr()
-                        if data_ptr_id in shared_weights:
+                        # can't rely on param.data_ptr() as it will be reused as weights gets
+                        # gathered and reduced, but param.ds_id is unique across all zero weights
+                        # (and shared params will have the same param.ds_id)
+                        if param.ds_id in shared_params:
                             # shared weights
-                            # print(f"`{key}` is shared with `{shared_weights[data_ptr_id]}`")
-                            state_dict[key] = state_dict[shared_weights[data_ptr_id]]
+                            #print(f"`{key}` is shared with `{shared_params[param.ds_id]}`")
+                            state_dict[key] = state_dict[shared_params[param.ds_id]]
                         else:
                             state_dict[key] = param.detach().cpu()
-                            shared_weights[data_ptr_id] = key
-                        #print(f"param {name} {param.shape}")
-                        #print(f"param {key} {param.shape} {state_dict[key].storage().data_ptr()}")
+                            shared_params[param.ds_id] = key
+                        #print(f"param {param.ds_id} {param.shape} {key} ")

                     # now buffers - not sure if need to take care of potentially shared weights here
                     for name, buf in module.named_buffers(recurse=False):
                         if buf is not None and name not in module._non_persistent_buffers_set:
                             state_dict[prefix + name] = buf.detach().cpu()
+            #see_memory_usage("after GatheredParameters", force=True)

             for name, child in module.named_children():
                 if child is not None:
......
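The shared-weights logic above exists because tied parameters (for example tied input and output embeddings) appear under more than one name but must map to a single tensor in the consolidated state dict. The patch switches the dedup key from param.storage().data_ptr(), which gets reused as layers are gathered and released, to param.ds_id, which stays unique per ZeRO parameter. A standalone sketch of the same dedup idea, using Python's id() as a stand-in for a stable identifier (the helper is hypothetical, not DeepSpeed API):

from collections import OrderedDict
import torch

def dedup_state_dict(named_params):
    # Store each underlying tensor once; later names reuse the first CPU copy.
    state_dict, seen = OrderedDict(), {}
    for key, param in named_params:
        stable_id = id(param)  # stands in for param.ds_id in the patch above
        if stable_id in seen:
            state_dict[key] = state_dict[seen[stable_id]]
        else:
            state_dict[key] = param.detach().cpu()
            seen[stable_id] = key
    return state_dict

# Tied embeddings: both keys end up pointing at the same CPU tensor.
w = torch.nn.Parameter(torch.randn(4, 4))
sd = dedup_state_dict([("embed.weight", w), ("lm_head.weight", w)])
assert sd["embed.weight"] is sd["lm_head.weight"]

In DeepSpeed itself this code path is reached when saving an fp16 copy of a ZeRO-3 model (for example through the engine's fp16 model-saving helper in releases of this era), which walks the module tree with get_layer_state_dict one layer at a time.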