Unverified commit 77ebf760, authored by Stas Bekman, committed by GitHub

[zero_to_fp32] fix shared param recovery (#3407)

* [zero_to_fp32] fix shared param recovery

* cleanup

* cleanup

* better naming

* not all params have ds_id it seems

---------
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent dcb4a7d6
@@ -3068,6 +3068,7 @@ class DeepSpeedEngine(Module):
                      param_shapes=self._get_zero_param_shapes() if self.optimizer and zero_optimizer_state else None,
                      frozen_param_shapes=self._get_zero_frozen_param_attributes(self._get_param_shape_func)
                      if save_frozen_param else None,
+                     shared_params=self._get_shared_params() if self.optimizer and zero_optimizer_state else None,
                      frozen_param_fragments=self._get_zero_frozen_param_attributes(self._get_param_fragment_func)
                      if save_frozen_param else None,
                      lr_scheduler=self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
@@ -3170,6 +3171,40 @@ class DeepSpeedEngine(Module):
         return param_group_shapes
+
+    def _get_shared_params(self):
+        """
+        Returns a dict of shared params, which can later be used to reconstruct the original state dict,
+        e.g. in `zero_to_fp32`. Each dict entry is a pair of param names, where the key is the name
+        of the variable that isn't stored and the value is the actual param holding data.
+        """
+        shared_ds_ids = {}
+        shared_params_by_full_name = {}
+
+        def get_layer_state_dict(module, prefix=""):
+            # handle params
+            for name, param in module.named_parameters(recurse=False):
+                if param is None or not hasattr(param, "ds_id"):
+                    continue
+                key = prefix + name
+                # can't rely on param.data_ptr() as it will be reused as weights get
+                # gathered and reduced, but param.ds_id is unique across all zero weights
+                # (and shared params will have the same param.ds_id)
+                if param.ds_id in shared_ds_ids:
+                    # shared weights
+                    #print(f"`{key}` is shared with `{shared_ds_ids[param.ds_id]}`")
+                    shared_params_by_full_name[key] = shared_ds_ids[param.ds_id]
+                else:
+                    shared_ds_ids[param.ds_id] = key
+
+            for name, child in module.named_children():
+                if child is not None:
+                    get_layer_state_dict(child, prefix + name + ".")
+
+        if dist.get_rank() == 0:
+            get_layer_state_dict(self.module, prefix="")
+
+        return shared_params_by_full_name
+
     def _copy_recovery_script(self, save_path):
         base_dir = os.path.dirname(os.path.dirname(__file__))
         script = "zero_to_fp32.py"
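The new `_get_shared_params()` walks the module tree and, for every ZeRO parameter, remembers the first fully qualified name seen for its `ds_id`; any later name that resolves to the same `ds_id` is a tied (shared) parameter and is recorded as a key pointing at the name that actually holds the data. As the in-code comment explains, `data_ptr()` cannot be used for this because storage gets reused as weights are gathered and reduced, while `ds_id` stays unique per logical parameter. Below is a minimal illustrative sketch of the same traversal on a plain PyTorch model, using Python's `id()` of the parameter object as a stand-in for `ds_id` (an assumption made only so the example runs without DeepSpeed; `find_shared_params` is a made-up name, not part of the commit).

# Illustrative sketch only (not part of the commit): the same traversal as
# _get_shared_params(), but on a plain PyTorch model, with id() of the
# parameter object standing in for param.ds_id.
import torch.nn as nn

def find_shared_params(model):
    seen = {}                  # id(param) -> first fully qualified name encountered
    shared_by_full_name = {}   # tied name -> name of the param holding the data

    def visit(module, prefix=""):
        # named_parameters(recurse=False) yields only this module's own params,
        # so tied copies in different submodules are both visited
        for name, param in module.named_parameters(recurse=False):
            if param is None:
                continue
            key = prefix + name
            if id(param) in seen:
                shared_by_full_name[key] = seen[id(param)]
            else:
                seen[id(param)] = key
        for name, child in module.named_children():
            if child is not None:
                visit(child, prefix + name + ".")

    visit(model)
    return shared_by_full_name

# tied input embedding / output projection, as in many language models
model = nn.Module()
model.embed = nn.Embedding(10, 4)
model.head = nn.Linear(4, 10, bias=False)
model.head.weight = model.embed.weight
print(find_shared_params(model))   # {'head.weight': 'embed.weight'}

The corresponding consumer-side change, in `parse_model_states` on the `zero_to_fp32` side, follows.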
@@ -120,16 +120,8 @@ def parse_model_states(files):
                 print(f"Found frozen_param_shapes: {frozen_param_shapes}")
             param_names += list(frozen_param_shapes.keys())
 
-        # record shared parameters so that they can be recovered based on partners
-        # this is because such parameters holding reference only are not saved by optimizer
-        shared_params = []
-        for param in state_dict["module"]:
-            if param not in [*param_names, *buffer_names]:
-                for share_param in [*param_names, *buffer_names]:
-                    if (state_dict["module"][share_param].data_ptr() == state_dict["module"][param].data_ptr()
-                            and share_param != param):
-                        shared_params.append([param, share_param])
-                        break
+        # handle shared params
+        shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
 
         ds_version = state_dict.get(DS_VERSION, None)
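With the shared-param map now saved in the checkpoint, `parse_model_states` no longer has to guess tied weights by comparing `data_ptr()` across the saved module; it simply turns the recorded mapping into `[name, partner]` pairs. As a hedged sketch of how such pairs could be consumed once a consolidated fp32 state dict has been rebuilt, one might do something like the following (`restore_shared_params` is a hypothetical helper, not code from this diff).

# Hypothetical helper (not part of this diff): re-create entries for tied
# parameters after the consolidated fp32 state dict has been reconstructed.
def restore_shared_params(state_dict, shared_params):
    # shared_params is a list of [name, partner] pairs where `name` was not
    # stored on its own and `partner` is the parameter that holds the data
    for name, partner in shared_params:
        if partner in state_dict:
            state_dict[name] = state_dict[partner]
    return state_dict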