Unverified commit 607814fe, authored by Olatunji Ruwase, committed by GitHub

Fix bug in fp32 optimizer state loading (#289)

Parent commit: 7ccc9daf
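Context for the fix: when fp16 is disabled (and ZeRO is off), the engine's optimizer is a plain torch.optim optimizer rather than DeepSpeed's fp16 wrapper, and torch.optim.Optimizer.load_state_dict() accepts only a state_dict. Forwarding the wrapper-specific load_optimizer_states keyword therefore fails, which is what the engine change below avoids. A minimal PyTorch-only sketch (not part of the commit) of the failure mode:

# Minimal sketch, plain PyTorch: the fp32 (non-ZeRO) case holds a stock
# torch.optim optimizer, whose load_state_dict() takes only a state_dict.
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
saved = optimizer.state_dict()

optimizer.load_state_dict(saved)  # fp32 path after the fix: no extra keyword

try:
    # Shape of the old call, which only the fp16 optimizer wrapper understands.
    optimizer.load_state_dict(saved, load_optimizer_states=True)
except TypeError as err:
    print(f"plain optimizer rejects the extra keyword: {err}")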
@@ -1140,8 +1140,12 @@ class DeepSpeedLight(Module):
         self.load_module_state_dict(state_dict=checkpoint['module'],
                                     strict=load_module_strict)
         if not self.zero_optimization():
-            self.optimizer.load_state_dict(checkpoint['optimizer'],
-                                           load_optimizer_states=load_optimizer_states)
+            if self.fp16_enabled():
+                self.optimizer.load_state_dict(
+                    checkpoint['optimizer'],
+                    load_optimizer_states=load_optimizer_states)
+            else:
+                self.optimizer.load_state_dict(checkpoint['optimizer'])
 
         if load_lr_scheduler_states and self.lr_scheduler is not None:
             self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
@@ -41,9 +41,9 @@ class SimpleOptimizer(torch.optim.Optimizer):
         return loss
 
 
-def random_dataloader(model, total_samples, hidden_dim, device):
+def random_dataloader(model, total_samples, hidden_dim, device, dtype=torch.half):
     batch_size = model.train_micro_batch_size_per_gpu()
-    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.half)
+    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
     train_label = torch.empty(total_samples,
                               dtype=torch.long,
                               device=device).random_(hidden_dim)
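The random_dataloader change above lets the new fp32 test request full-precision inputs while existing tests keep the torch.half default; labels stay torch.long either way. A standalone sketch (make_random_batch is a hypothetical stand-in, no DeepSpeed engine required):

# Hedged sketch of the dtype-parameterized test inputs; generation happens in
# fp32 and is then cast, so it also runs on CPU-only setups.
import torch

def make_random_batch(total_samples, hidden_dim, device='cpu', dtype=torch.half):
    data = torch.randn(total_samples, hidden_dim, device=device).to(dtype)
    labels = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
    return data, labels

x_half, y = make_random_batch(50, 10)                       # fp16 default
x_fp32, _ = make_random_batch(50, 10, dtype=torch.float32)  # fp32 checkpoint test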
@@ -47,14 +47,18 @@ def compare_model_states(saved_model, loaded_model):
         for params0, params1 in zip(saved_model.optimizer.fp32_groups, loaded_model.optimizer.fp32_groups):
             for p0, p1 in zip(params0, params1):
                 assert torch.allclose(p0,p1,atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
+    elif isinstance(saved_model.optimizer, torch.optim.Optimizer):
+        pass
     else:
-        assert False, 'Unexpected Optimizer Type'
+        assert False, f'Unexpected Optimizer Type: {saved_model.optimizer}'
 
 
-def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
-    for state0, state1 in zip(saved_model.optimizer.optimizer.state.values(),
-                              loaded_model.optimizer.optimizer.state.values()):
+def compare_optimizer_states(saved_model, loaded_model, hidden_dim, fp16=True):
+    saved_optimizer = saved_model.optimizer.optimizer if fp16 else saved_model.optimizer
+    loaded_optimizer = loaded_model.optimizer.optimizer if fp16 else loaded_model.optimizer
+
+    for state0, state1 in zip(saved_optimizer.state.values(),
+                              loaded_optimizer.state.values()):
         for s0, s1 in zip(state0.values(), state1.values()):
             if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor):
                 assert torch.equal(s0, s1)
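The new fp16 flag exists because on mixed-precision runs the engine's optimizer attribute is a wrapper that holds the real torch optimizer as .optimizer, whereas on fp32 runs it already is the torch optimizer. A toy illustration of that unwrapping (the wrapper class here is a stand-in, not DeepSpeed's):

# Hedged illustration with toy classes: why compare_optimizer_states unwraps
# one extra level only when fp16=True.
import torch

class ToyFP16Wrapper:
    """Stand-in for DeepSpeed's fp16 optimizer wrapper."""
    def __init__(self, inner):
        self.optimizer = inner  # underlying torch.optim optimizer

model = torch.nn.Linear(4, 4)
inner = torch.optim.Adam(model.parameters())

fp16_style = ToyFP16Wrapper(inner)  # shape of model.optimizer with fp16 enabled
fp32_style = inner                  # shape of model.optimizer with fp16 disabled

def unwrap(opt, fp16=True):
    return opt.optimizer if fp16 else opt

assert unwrap(fp16_style, fp16=True) is inner
assert unwrap(fp32_style, fp16=False) is inner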
@@ -90,15 +94,17 @@ def checkpoint_correctness_verification(args,
                                         hidden_dim,
                                         tmpdir,
                                         load_optimizer_states=False,
-                                        load_lr_scheduler_states=False):
+                                        load_lr_scheduler_states=False,
+                                        fp16=True):
+    dtype = torch.half if fp16 else torch.float32
     ds_model, _, _,_ = deepspeed.initialize(args=args,
                                             model=model,
                                             model_parameters=model.parameters())
     data_loader = random_dataloader(model=ds_model,
                                     total_samples=50,
                                     hidden_dim=hidden_dim,
-                                    device=ds_model.device)
+                                    device=ds_model.device,
+                                    dtype=dtype)
     for n, batch in enumerate(data_loader):
         loss = ds_model(batch[0], batch[1])
         ds_model.backward(loss)
@@ -123,7 +129,7 @@ def checkpoint_correctness_verification(args,
     compare_model_states(trained_model, loaded_model)
 
     if load_optimizer_states:
-        compare_optimizer_states(trained_model, loaded_model, hidden_dim)
+        compare_optimizer_states(trained_model, loaded_model, hidden_dim, fp16)
 
     if load_lr_scheduler_states:
         compare_lr_scheduler_states(trained_model, loaded_model)
@@ -420,3 +426,34 @@ def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage):
                                             hidden_dim=hidden_dim,
                                             load_optimizer_states=False,
                                             load_lr_scheduler_states=False)
+
+
+def test_checkpoint_fp32_optimizer(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015,
+                "betas": [0.8,
+                          0.999],
+                "eps": 1e-8,
+                "weight_decay": 3e-7
+            }
+        },
+        "fp16": {
+            "enabled": False
+        }
+    }
+
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=False)
+
+    @distributed_test(world_size=[2])
+    def _test_checkpoint_fp32_optimizer(args, model, hidden_dim):
+        checkpoint_correctness_verification(args, model, hidden_dim, tmpdir, fp16=False)
+
+    _test_checkpoint_fp32_optimizer(args=args, model=model, hidden_dim=hidden_dim)
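Engine-free summary of what test_checkpoint_fp32_optimizer exercises: train a plain fp32 Adam with the config's hyperparameters, save its state, reload it into a fresh optimizer via the single-argument load_state_dict(), and check that the state tensors round-trip exactly. A hedged sketch in plain PyTorch (the real test drives a DeepSpeed engine through checkpoint_correctness_verification and the comparison helpers above):

# Hedged, engine-free sketch of the fp32 round trip the new test verifies.
import copy
import torch

torch.manual_seed(0)
model = torch.nn.Linear(10, 10)
adam_kwargs = dict(lr=0.00015, betas=(0.8, 0.999), eps=1e-8, weight_decay=3e-7)
optimizer = torch.optim.Adam(model.parameters(), **adam_kwargs)

for _ in range(5):  # a few fp32 steps so Adam builds exp_avg / exp_avg_sq state
    optimizer.zero_grad()
    model(torch.randn(2, 10)).sum().backward()
    optimizer.step()

checkpoint = copy.deepcopy(optimizer.state_dict())  # stand-in for the saved file

fresh = torch.optim.Adam(model.parameters(), **adam_kwargs)
fresh.load_state_dict(checkpoint)  # fp32 path: single-argument call

for state0, state1 in zip(optimizer.state.values(), fresh.state.values()):
    for s0, s1 in zip(state0.values(), state1.values()):
        if isinstance(s0, torch.Tensor):
            assert torch.equal(s0, s1)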