未验证 提交 dce054db 编写于 作者: J Jeff Rasley 提交者: GitHub

backwards compatability w. v020 ckpts, fix issue with zero-1 ckpts (#543)

上级 9de21b72
......@@ -42,6 +42,8 @@ sys.modules['deepspeed.pt'] = deepspeed.pt
sys.modules['deepspeed.pt.deepspeed_utils'] = deepspeed.runtime.utils
setattr(deepspeed.pt, 'deepspeed_config', deepspeed.runtime.config)
sys.modules['deepspeed.pt.deepspeed_config'] = deepspeed.runtime.config
setattr(deepspeed.pt, 'loss_scaler', deepspeed.runtime.fp16.loss_scaler)
sys.modules['deepspeed.pt.loss_scaler'] = deepspeed.runtime.fp16.loss_scaler
def initialize(args,
......
......@@ -194,7 +194,6 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
# max elems per param group
self.max_elems_per_comm = []
self.legacy_max_elements_per_comm = max_elements_per_comm
# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
......@@ -859,7 +858,6 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
state_dict['zero_stage'] = ZERO_OPTIMIZATION_OPTIMIZER_STATES
state_dict['partition_count'] = self.partition_count
state_dict['num_comm_intervals_per_group'] = self.num_comm_intervals_per_group
state_dict['max_elems_per_comm'] = self.max_elems_per_comm
# Remove paddings for DP alignment to enable loading for other alignment values
fp32_groups_without_padding = self._get_groups_without_padding(
......@@ -933,10 +931,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
sd['local_sub_partitions_of_fp32_groups'][group_idx]
for sd in all_state_dict
]
if 'max_elems_per_comm' in all_state_dict[0]:
max_elems_per_comm = all_state_dict[0]['max_elems_per_comm'][group_idx]
else:
max_elems_per_comm = self.legacy_max_elements_per_comm
max_elems_per_comm = self.max_elems_per_comm[group_idx]
sub_partition_weights = self._retrieve_group_sub_partition_weights(
all_partition_fp32_weights,
......@@ -1009,10 +1004,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
all_partition_group_states = [
sd['base_optimizer_state'][group_idx] for sd in state_dict_list
]
if 'max_elems_per_comm' in state_dict_list[0]:
max_elems_per_comm = state_dict_list[0]['max_elems_per_comm'][group_idx]
else:
max_elems_per_comm = self.legacy_max_elements_per_comm
max_elems_per_comm = self.max_elems_per_comm[group_idx]
group_optimizer_states = self._retrieve_group_optimizer_states(
all_partition_group_states,
max_elems_per_comm)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册