From 39825a9092b997a2fdcf129e91224183dda93cbc Mon Sep 17 00:00:00 2001 From: ShijieZZZZ <116392778+ShijieZZZZ@users.noreply.github.com> Date: Tue, 25 Apr 2023 06:15:18 -0700 Subject: [PATCH] Fix redundant shared_params in zero_to_fp32.py (#3149) * submit changes * update format * fix fomrat * revert * test * add top * treat z1 as z2 * fix shared * remove old changes --------- Co-authored-by: Olatunji Ruwase --- deepspeed/utils/zero_to_fp32.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py index 44cd87b8..aa9cb530 100755 --- a/deepspeed/utils/zero_to_fp32.py +++ b/deepspeed/utils/zero_to_fp32.py @@ -125,7 +125,7 @@ def parse_model_states(files): shared_params = [] for param in state_dict["module"]: if param not in [*param_names, *buffer_names]: - for share_param in state_dict["module"]: + for share_param in [*param_names, *buffer_names]: if (state_dict["module"][share_param].data_ptr() == state_dict["module"][param].data_ptr() and share_param != param): shared_params.append([param, share_param]) @@ -340,7 +340,8 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zer # recover shared parameters for pair in zero_model_states[0].shared_params: - state_dict[pair[0]] = state_dict[pair[1]] + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] return state_dict @@ -457,7 +458,8 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer # recover shared parameters for pair in zero_model_states[0].shared_params: - state_dict[pair[0]] = state_dict[pair[1]] + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] return state_dict -- GitLab