未验证 提交 30d97705 编写于 作者: S ShijieZZZZ 提交者: GitHub

Recover shared parameters (#3033)

* submit changes

* update format

* fix format

* revert

* test

* add top

* treat z1 as z2

* revert

---------
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
上级 4b358333
......@@ -84,9 +84,26 @@ def parse_model_state(file):
buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
param_shapes = state_dict[PARAM_SHAPES]
# collect parameters that are included in param_shapes
param_names = []
for s in param_shapes:
for name in s.keys():
param_names.append(name)
# record shared parameters so that they can be recovered from their partners;
# this is needed because such parameters only hold a reference and are not saved by the optimizer
shared_params = []
for param in state_dict["module"]:
if param not in [*param_names, *buffer_names]:
for share_param in state_dict["module"]:
if (state_dict["module"][share_param].data_ptr() == state_dict["module"][param].data_ptr()
and share_param != param):
shared_params.append([param, share_param])
break
ds_version = state_dict.get(DS_VERSION, None)
return buffers, param_shapes, ds_version
return buffers, param_shapes, shared_params, ds_version
def parse_optim_states(files, ds_checkpoint_dir):
......@@ -153,16 +170,18 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
buffers, param_shapes, ds_version = parse_model_state(model_file)
buffers, param_shapes, shared_params, ds_version = parse_model_state(model_file)
print(f'Parsing checkpoint created by deepspeed=={ds_version}')
if zero_stage == 2:
return _get_fp32_state_dict_from_zero2_checkpoint(world_size, param_shapes, fp32_flat_groups, buffers)
return _get_fp32_state_dict_from_zero2_checkpoint(world_size, param_shapes, fp32_flat_groups, buffers,
shared_params)
elif zero_stage == 3:
return _get_fp32_state_dict_from_zero3_checkpoint(world_size, param_shapes, fp32_flat_groups, buffers)
return _get_fp32_state_dict_from_zero3_checkpoint(world_size, param_shapes, fp32_flat_groups, buffers,
shared_params)
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, param_shapes, fp32_flat_groups, buffers):
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, param_shapes, fp32_flat_groups, buffers, shared_params):
# Reconstruction protocol:
#
......@@ -238,6 +257,10 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, param_shapes, fp32_fl
if offset != avail_numel:
raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
# recover shared parameters
for pair in shared_params:
state_dict[pair[0]] = state_dict[pair[1]]
print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
return state_dict
......@@ -250,7 +273,7 @@ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
return partitioned_numel, padding_numel
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, param_shapes, fp32_flat_groups, buffers):
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, param_shapes, fp32_flat_groups, buffers, shared_params):
# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
# param, re-consolidating each param, while dealing with padding if any
......@@ -307,6 +330,10 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, param_shapes, fp32_fl
if offset != avail_numel:
raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
# recover shared parameters
for pair in shared_params:
state_dict[pair[0]] = state_dict[pair[1]]
print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
return state_dict
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册