未验证 提交 39825a90 编写于 作者: S ShijieZZZZ 提交者: GitHub

Fix redundant shared_params in zero_to_fp32.py (#3149)

* submit changes

* update format

* fix fomrat

* revert

* test

* add top

* treat z1 as z2

* fix shared

* remove old changes

---------
Co-authored-by: NOlatunji Ruwase <olruwase@microsoft.com>
上级 085981bf
......@@ -125,7 +125,7 @@ def parse_model_states(files):
shared_params = []
for param in state_dict["module"]:
if param not in [*param_names, *buffer_names]:
for share_param in state_dict["module"]:
for share_param in [*param_names, *buffer_names]:
if (state_dict["module"][share_param].data_ptr() == state_dict["module"][param].data_ptr()
and share_param != param):
shared_params.append([param, share_param])
......@@ -340,7 +340,8 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zer
# recover shared parameters
for pair in zero_model_states[0].shared_params:
state_dict[pair[0]] = state_dict[pair[1]]
if pair[1] in state_dict:
state_dict[pair[0]] = state_dict[pair[1]]
return state_dict
......@@ -457,7 +458,8 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer
# recover shared parameters
for pair in zero_model_states[0].shared_params:
state_dict[pair[0]] = state_dict[pair[1]]
if pair[1] in state_dict:
state_dict[pair[0]] = state_dict[pair[1]]
return state_dict
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册