diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index d24e99af488e4ee2e875376a76f220a6107c0443..e5408db473d851ad3fa558fd92bcbf0f48e86e2f 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -2039,34 +2039,38 @@ class DeepSpeedEngine(Module):
             raise ValueError("this function requires ZeRO-3 mode")
 
         state_dict = OrderedDict() if torch.distributed.get_rank() == 0 else None
-        shared_weights = {}
+        shared_params = {}
 
         def get_layer_state_dict(module, prefix=""):
             # gather one layer at a time to be memory-efficient
+            # must use modifier_rank=0 to release GPU memory after each layer gathered
+            #see_memory_usage("before GatheredParameters", force=True)
             with deepspeed.zero.GatheredParameters(list(
                     module.parameters(recurse=False)),
-                                                   modifier_rank=None):
+                                                   modifier_rank=0):
                 if torch.distributed.get_rank() == 0:
+                    # handle params
                     for name, param in module.named_parameters(recurse=False):
                         if param is None:
                             continue
                         key = prefix + name
-                        # for shared weights we want to make sure not to unshare them when copying to cpu
-                        data_ptr_id = param.storage().data_ptr()
-                        if data_ptr_id in shared_weights:
+                        # can't rely on param.data_ptr() as it will be reused as weights get
+                        # gathered and reduced, but param.ds_id is unique across all zero weights
+                        # (and shared params will have the same param.ds_id)
+                        if param.ds_id in shared_params:
                             # shared weights
-                            # print(f"`{key}` is shared with `{shared_weights[data_ptr_id]}`")
-                            state_dict[key] = state_dict[shared_weights[data_ptr_id]]
+                            #print(f"`{key}` is shared with `{shared_params[param.ds_id]}`")
+                            state_dict[key] = state_dict[shared_params[param.ds_id]]
                         else:
                             state_dict[key] = param.detach().cpu()
-                            shared_weights[data_ptr_id] = key
-                        #print(f"param {name} {param.shape}")
-                        #print(f"param {key} {param.shape} {state_dict[key].storage().data_ptr()}")
+                            shared_params[param.ds_id] = key
+                        #print(f"param {param.ds_id} {param.shape} {key}")
 
                     # now buffers - not sure if need to take care of potentially shared weights here
                     for name, buf in module.named_buffers(recurse=False):
                         if buf is not None and name not in module._non_persistent_buffers_set:
                             state_dict[prefix + name] = buf.detach().cpu()
+            #see_memory_usage("after GatheredParameters", force=True)
 
             for name, child in module.named_children():
                 if child is not None:
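For context on the `ds_id` switch: under ZeRO-3 the gathered parameter buffers are recycled from layer to layer, so `param.storage().data_ptr()` is not a stable identity across gathers and tied weights can be missed or falsely matched. `param.ds_id` is assigned once per ZeRO parameter, and tied weights carry the same `ds_id`, which makes it a reliable dedup key. Below is a minimal standalone sketch of the same dedup pattern in plain PyTorch, without ZeRO or distributed setup; `consolidated_cpu_state_dict` is a hypothetical helper invented for illustration, and `id(param)` stands in for `param.ds_id`, which only exists on ZeRO-3 partitioned parameters.

```python
from collections import OrderedDict

import torch.nn as nn


def consolidated_cpu_state_dict(model: nn.Module) -> OrderedDict:
    # hypothetical helper: not DeepSpeed API, just the dedup pattern in isolation
    state_dict = OrderedDict()
    shared_params = {}  # stable per-param identity -> first state_dict key that stored it

    def get_layer_state_dict(module, prefix=""):
        # walk one module at a time, mirroring the patch's layer-by-layer traversal
        for name, param in module.named_parameters(recurse=False):
            if param is None:
                continue
            key = prefix + name
            # id(param) stands in for param.ds_id from the patch
            if id(param) in shared_params:
                # tied weight: alias the existing CPU tensor instead of copying again
                state_dict[key] = state_dict[shared_params[id(param)]]
            else:
                state_dict[key] = param.detach().cpu()
                shared_params[id(param)] = key
        for name, child in module.named_children():
            get_layer_state_dict(child, prefix + name + ".")

    get_layer_state_dict(model)
    return state_dict


# tiny model with tied input/output embeddings, as in many language models
class TinyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(10, 4)
        self.head = nn.Linear(4, 10, bias=False)
        self.head.weight = self.emb.weight  # tie the weights

    def forward(self, x):
        return self.head(self.emb(x))


sd = consolidated_cpu_state_dict(TinyLM())
assert sd["emb.weight"] is sd["head.weight"]  # one CPU copy, two keys
```

A `data_ptr()`-keyed version of this sketch would also happen to work here because plain PyTorch tensors keep their storage, which is exactly the assumption that breaks once ZeRO-3 starts reusing gather buffers; hence the patch keys on `ds_id` instead.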