diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py index 44cd87b890889937e4444f970f3329a14268a6e6..aa9cb530d44023c6d348e75790dc3e003601bd18 100755 --- a/deepspeed/utils/zero_to_fp32.py +++ b/deepspeed/utils/zero_to_fp32.py @@ -125,7 +125,7 @@ def parse_model_states(files): shared_params = [] for param in state_dict["module"]: if param not in [*param_names, *buffer_names]: - for share_param in state_dict["module"]: + for share_param in [*param_names, *buffer_names]: if (state_dict["module"][share_param].data_ptr() == state_dict["module"][param].data_ptr() and share_param != param): shared_params.append([param, share_param]) @@ -340,7 +340,8 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zer # recover shared parameters for pair in zero_model_states[0].shared_params: - state_dict[pair[0]] = state_dict[pair[1]] + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] return state_dict @@ -457,7 +458,8 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer # recover shared parameters for pair in zero_model_states[0].shared_params: - state_dict[pair[0]] = state_dict[pair[1]] + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] return state_dict