Unverified commit e801e6d7, authored by Alexander Jipa, committed by GitHub

skipping redundant MoE optimizer state loading (#4120)

Co-authored-by: Alexander Jipa <azzhipa@amazon.com>
Parent 9894c06a
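With this change, `_load_checkpoint` only resolves and reads the per-expert (MoE) optimizer checkpoint file when optimizer state is actually going to be restored; previously the file was read from disk even when `load_optimizer_states=False` and then discarded. `optim_checkpoint` now starts as `None`, and the later `client_state` code guards against that. A minimal sketch of the guarded-loading idea (the helper and argument names below are illustrative, not DeepSpeed API):

    # Sketch only: defer the expensive read of the expert optimizer file until it is needed.
    # `read_expert_file` stands in for the checkpoint engine's load call.
    def select_optimizer_checkpoint(checkpoint, load_optimizer_states, has_moe_layers,
                                    expert_ckpt_path, read_expert_file):
        optim_checkpoint = None
        if load_optimizer_states:
            # Only touch the (potentially large) per-expert file on this branch.
            optim_checkpoint = read_expert_file(expert_ckpt_path) if has_moe_layers else checkpoint
        return optim_checkpoint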
......@@ -2759,26 +2759,29 @@ class DeepSpeedEngine(Module):
         self.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size']
 
+        optim_checkpoint = None
         if load_module_only:
             deepspeed_states = ['module']
             if self.optimizer is not None and self.fp16_enabled():
                 self.optimizer.refresh_fp32_params()
         else:
-            if self.has_moe_layers:
-                largest_group_name = groups._get_max_expert_size_name()
-                expp_rank = groups._get_expert_parallel_rank(largest_group_name)
-                optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank)
-                optim_checkpoint = self.checkpoint_engine.load(optim_load_path, map_location=torch.device('cpu'))
-            else:
-                optim_checkpoint = checkpoint
-
             has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
 
             if load_optimizer_states and self.optimizer is not None and not has_zero_optimizer_state:
-                if self.fp16_enabled():
+                if self.has_moe_layers:
+                    largest_group_name = groups._get_max_expert_size_name()
+                    expp_rank = groups._get_expert_parallel_rank(largest_group_name)
+                    optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank)
+                    optim_checkpoint = self.checkpoint_engine.load(optim_load_path, map_location=torch.device('cpu'))
+                else:
+                    optim_checkpoint = checkpoint
+
+                if self.fp16_enabled() or self.bfloat16_enabled():
                     self.optimizer.load_state_dict(optim_checkpoint['optimizer'],
                                                    load_optimizer_states=load_optimizer_states)
                 else:
                     self.optimizer.load_state_dict(optim_checkpoint['optimizer'])
 
         if load_lr_scheduler_states and self.lr_scheduler is not None:
             self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
......@@ -2835,7 +2838,7 @@ class DeepSpeedEngine(Module):
         client_state = {key: value for key, value in checkpoint.items() if not key in deepspeed_states}
 
-        if not load_optimizer_states and not load_module_only:
+        if optim_checkpoint is not None:
             client_state['optimizer'] = optim_checkpoint['optimizer']
 
         return load_path, client_state
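For callers, the visible effect is that a weights-only restore no longer opens the per-expert optimizer file. A usage sketch, assuming `model` and `ds_config` are already defined and using a placeholder checkpoint path:

    import deepspeed

    # Sketch: weights-only restore; with this commit the MoE expert optimizer
    # checkpoint file is no longer read on this code path.
    engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)
    engine.load_checkpoint('/path/to/checkpoints', load_optimizer_states=False,
                           load_lr_scheduler_states=False)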
......
......@@ -15,6 +15,7 @@ from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
 from unit.simple_model import *
+from unittest.mock import MagicMock, patch
 
 
 def compare_deepspeed_states(saved_model, loaded_model):
......@@ -209,11 +210,17 @@ def checkpoint_correctness_verification(config_dict,
     loaded_model = create_deepspeed_model(config_dict=config_dict, model=models[1], base_optimizer=base_optimizers[1])
 
     assert list(trained_model.parameters())[0].dtype == list(loaded_model.parameters())[0].dtype
 
-    loaded_model.load_checkpoint(save_folder,
-                                 tag=save_tag,
-                                 load_optimizer_states=load_optimizer_states,
-                                 load_lr_scheduler_states=load_lr_scheduler_states,
-                                 load_module_only=load_module_only)
+    context = patch.object(loaded_model, "_get_optimizer_ckpt_name",
+                           wraps=loaded_model._get_optimizer_ckpt_name) if not load_optimizer_states else MagicMock()
+    with context as optim_load_state_dict_mock:
+        loaded_model.load_checkpoint(save_folder,
+                                     tag=save_tag,
+                                     load_optimizer_states=load_optimizer_states,
+                                     load_lr_scheduler_states=load_lr_scheduler_states,
+                                     load_module_only=load_module_only)
+        if not load_optimizer_states:
+            # should not attempt to get the file name to load it
+            optim_load_state_dict_mock.assert_not_called()
 
     compare_model_states(trained_model,
                          loaded_model,
......
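The new test relies on `unittest.mock.patch.object(..., wraps=...)`, which swaps a method for a spy that still runs the real implementation while recording calls; `assert_not_called()` then proves `_get_optimizer_ckpt_name` was never reached when `load_optimizer_states=False`. A small self-contained illustration of that spy pattern (the `Loader` class is made up for the example):

    from unittest.mock import patch

    class Loader:
        def resolve_name(self, tag):
            return f'optim-{tag}.pt'

        def load(self, tag, want_optimizer):
            # Mirrors the engine logic: only resolve the optimizer file name when needed.
            return self.resolve_name(tag) if want_optimizer else None

    loader = Loader()
    # wraps= keeps the original behaviour but records every call for assertions.
    with patch.object(loader, 'resolve_name', wraps=loader.resolve_name) as spy:
        loader.load('step10', want_optimizer=False)
        spy.assert_not_called()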