Unverified commit 936117b5, authored by Samyam Rajbhandari, committed by GitHub

Enhancement: Ability to load checkpoint without loading the optimizer… (#128)

* Enhancement: Ability to load a checkpoint without loading the optimizer states. Unit test covering saving and loading checkpoints with the fused, unfused, and ZeRO optimizers. The unit test takes about 165s.
Parent 1c0b326e
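The new `load_optimizer_states` flag lets a checkpoint be restored for its model weights only, e.g. when fine-tuning from pretrained weights with a fresh optimizer. A minimal usage sketch, assuming a model, its `args`, and a DeepSpeed config already exist (the variable names and checkpoint tag below are illustrative, not part of this PR):

```python
import deepspeed

# model and args are assumed to be defined by the caller, as in the unit test below.
model_engine, optimizer, _, _ = deepspeed.initialize(args=args,
                                                     model=model,
                                                     model_parameters=model.parameters())

# Restore module weights (and lr scheduler state) but skip the optimizer states,
# e.g. ADAM's momentum and variance buffers.
load_path, client_state = model_engine.load_checkpoint('saved_checkpoint',
                                                       '1',
                                                       load_optimizer_states=False)
```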
@@ -933,26 +933,28 @@ class DeepSpeedLight(Module):
         if not os.path.exists(dirname):
             os.makedirs(dirname)
-    def load_checkpoint(self, load_dir, tag):
+    def load_checkpoint(self, load_dir, tag, load_optimizer_states=True):
         r"""Load training checkpoint
         Arguments:
             load_dir: Required. Directory to load the checkpoint from
             tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
+            load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. Ex. ADAM's momentum and variance
         Return:
             load_path: Path of the loaded checkpoint. None if loading the checkpoint failed
             client_state: State dictionary used for loading required training states in the client code.
         """
-        load_path, client_states = self._load_checkpoint(load_dir, tag)
+        load_path, client_states = self._load_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
         if self.zero_optimization() and load_path is not None:
-            self._load_zero_checkpoint(load_dir, tag)
+            self._load_zero_checkpoint(load_dir,
+                                       tag,
+                                       load_optimizer_states=load_optimizer_states)
         return load_path, client_states
-    def _load_checkpoint(self, load_dir, tag):
+    def _load_checkpoint(self, load_dir, tag, load_optimizer_states=True):
         load_path = self._get_ckpt_name(load_dir, tag)
@@ -967,7 +969,8 @@ class DeepSpeedLight(Module):
         self.load_module_state_dict(checkpoint['module'])
         if not self.zero_optimization():
-            self.optimizer.load_state_dict(checkpoint['optimizer'])
+            self.optimizer.load_state_dict(checkpoint['optimizer'],
+                                           load_optimizer_states=load_optimizer_states)
         if self.lr_scheduler is not None:
             self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
@@ -990,7 +993,7 @@ class DeepSpeedLight(Module):
         return load_path, client_state
-    def _load_zero_checkpoint(self, load_dir, tag):
+    def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
         zero_checkpoint_name = self._get_zero_ckpt_name(load_dir, tag)
         if not os.path.exists(zero_checkpoint_name):
@@ -1000,7 +1003,8 @@ class DeepSpeedLight(Module):
             return None
         zero_sd = torch.load(zero_checkpoint_name, map_location='cpu')
-        self.optimizer.load_state_dict(zero_sd['optimizer_state_dict'])
+        self.optimizer.load_state_dict(zero_sd['optimizer_state_dict'],
+                                       load_optimizer_states=load_optimizer_states)
         logging.info('loading zero checkpoint {}'.format(zero_checkpoint_name))
     def save_checkpoint(self, save_dir, tag, client_state={}):
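For reference, the "optimizer states" this flag controls are the per-parameter buffers kept by the wrapped torch optimizer, e.g. the ADAM momentum and variance mentioned in the docstring above. A standalone sketch, independent of DeepSpeed, showing where those buffers live in a plain `torch.optim.Adam` state dict (the toy model is only illustrative):

```python
import torch

# Tiny model and a plain Adam optimizer, used only to inspect the state layout.
model = torch.nn.Linear(10, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()

# 'state' holds the per-parameter buffers that load_optimizer_states=False leaves out:
# for Adam these are 'step', 'exp_avg' (momentum) and 'exp_avg_sq' (variance).
for param_state in optimizer.state_dict()['state'].values():
    print(sorted(param_state.keys()))
```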
@@ -507,7 +507,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
             'single_partition_of_fp32_groups'] = self.single_partition_of_fp32_groups
         return state_dict
-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict, load_optimizer_states=True):
         """
         Loads a state_dict created by an earlier call to state_dict().
         If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
@@ -527,7 +527,8 @@ class FP16_DeepSpeedZeroOptimizer(object):
         self.loss_scaler = state_dict['loss_scaler']
         self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
         self.overflow = state_dict['overflow']
-        self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
+        if load_optimizer_states:
+            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
         for current, saved in zip(self.single_partition_of_fp32_groups, state_dict['single_partition_of_fp32_groups']):
             current.data.copy_(saved.data)
@@ -294,7 +294,7 @@ class FP16_Optimizer(object):
         state_dict['clip_grad'] = self.clip_grad
         return state_dict
-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict, load_optimizer_states=True):
         """
         Loads a state_dict created by an earlier call to state_dict().
         If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
@@ -318,7 +318,8 @@ class FP16_Optimizer(object):
         self.last_overflow_iter = state_dict['last_overflow_iter']
         self.scale_factor = state_dict['scale_factor']
         self.scale_window = state_dict['scale_window']
-        self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
+        if load_optimizer_states:
+            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
         self.clip_grad = state_dict['clip_grad']
         # At this point, the optimizer's references to the model's fp32 parameters are up to date.
         # The optimizer's hyperparameters and internal buffers are also up to date.
@@ -281,7 +281,7 @@ class FP16_UnfusedOptimizer(object):
         state_dict['fp32_groups'] = self.fp32_groups
         return state_dict
-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict, load_optimizer_states=True):
         """
         Loads a state_dict created by an earlier call to state_dict().
         If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
@@ -305,7 +305,9 @@ class FP16_UnfusedOptimizer(object):
         self.last_overflow_iter = state_dict['last_overflow_iter']
         self.scale_factor = state_dict['scale_factor']
         self.scale_window = state_dict['scale_window']
-        self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
+        if load_optimizer_states:
+            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
         # At this point, the optimizer's references to the model's fp32 parameters are up to date.
         # The optimizer's hyperparameters and internal buffers are also up to date.
         # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
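All three FP16 optimizer wrappers apply the same change: wrapper bookkeeping (loss scaler, overflow state, fp32 master weights or partitions) is still restored unconditionally, and only the wrapped torch optimizer's `load_state_dict` call is guarded. A condensed, hypothetical sketch of the pattern (not DeepSpeed code; the class and attribute names are illustrative):

```python
class FP16WrapperSketch:
    """Illustrative only: mirrors the guard added to the DeepSpeed FP16 wrappers."""
    def __init__(self, inner_optimizer):
        self.optimizer = inner_optimizer
        self.loss_scaler = None

    def load_state_dict(self, state_dict, load_optimizer_states=True):
        # Wrapper bookkeeping is always restored.
        self.loss_scaler = state_dict['loss_scaler']
        # The inner optimizer's per-parameter buffers (e.g. Adam's
        # exp_avg / exp_avg_sq) are restored only when requested.
        if load_optimizer_states:
            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
```

Note that in the ZeRO wrapper the fp32 partitions (`single_partition_of_fp32_groups`) are still copied from the checkpoint regardless of the flag, so the master weights always match the saved model.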
File mode changed from 100644 to 100755

The PR also adds a unit test that saves and loads checkpoints with the fused, unfused, and ZeRO optimizers, with and without optimizer states:
import torch
import deepspeed
from deepspeed.pt.deepspeed_zero_optimizer import FP16_DeepSpeedZeroOptimizer
from deepspeed.pt.fp16_optimizer import FP16_Optimizer
from deepspeed.pt.fp16_unfused_optimizer import FP16_UnfusedOptimizer
import argparse
import pytest
import json
import os
from common import distributed_test
from simple_model import SimpleModel, random_dataloader, args_from_dict
def compare_model_states(saved_model, loaded_model):
    for p0, p1 in zip(saved_model.module.parameters(), loaded_model.module.parameters()):
        assert torch.allclose(p0, p1, atol=1e-07), f"FP16 model state {p0} is not equal to {p1}"

    if isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer):
        for p0, p1 in zip(saved_model.optimizer.single_partition_of_fp32_groups,
                          loaded_model.optimizer.single_partition_of_fp32_groups):
            assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
    elif isinstance(saved_model.optimizer, FP16_Optimizer):
        for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat,
                          loaded_model.optimizer.fp32_groups_flat):
            assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
    elif isinstance(saved_model.optimizer, FP16_UnfusedOptimizer):
        for params0, params1 in zip(saved_model.optimizer.fp32_groups,
                                    loaded_model.optimizer.fp32_groups):
            for p0, p1 in zip(params0, params1):
                assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
    else:
        assert False, 'Unexpected Optimizer Type'
def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
    compare_model_states(saved_model, loaded_model)

    # Compare the per-parameter state tensors (e.g. Adam's momentum and variance)
    # of the wrapped torch optimizers.
    for state0, state1 in zip(saved_model.optimizer.optimizer.state.values(),
                              loaded_model.optimizer.optimizer.state.values()):
        for s0, s1 in zip(state0.values(), state1.values()):
            if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor):
                assert torch.equal(s0, s1)
            else:
                assert s0 == s1
def checkpoint_correctness_verification(args,
                                        model,
                                        hidden_dim,
                                        load_optimizer_states=True):
    # Train for a few steps, then save a checkpoint.
    ds_model, _, _, _ = deepspeed.initialize(args=args,
                                             model=model,
                                             model_parameters=model.parameters())
    data_loader = random_dataloader(model=ds_model,
                                    total_samples=50,
                                    hidden_dim=hidden_dim,
                                    device=ds_model.device)
    for n, batch in enumerate(data_loader):
        loss = ds_model(batch[0], batch[1])
        ds_model.backward(loss)
        ds_model.step()

    trained_model = ds_model

    save_folder = 'saved_checkpoint'
    save_tag = '1'

    trained_model.save_checkpoint(save_folder, save_tag)

    # Re-initialize a fresh engine and load the checkpoint back.
    loaded_model, _, _, _ = deepspeed.initialize(args=args,
                                                 model=model,
                                                 model_parameters=model.parameters())

    loaded_model.load_checkpoint(save_folder,
                                 save_tag,
                                 load_optimizer_states=load_optimizer_states)

    # Optimizer states are only expected to match when they were loaded.
    if load_optimizer_states:
        compare_optimizer_states(trained_model, loaded_model, hidden_dim)
    else:
        compare_model_states(trained_model, loaded_model)
def test_checkpoint_unfused_optimizer(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Lamb",
            "params": {
                "lr": 0.00015,
                "max_grad_norm": 1.0
            }
        },
        "fp16": {
            "enabled": True
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_unfused_optimizer(args,
                                           model,
                                           hidden_dim,
                                           load_optimizer_states):
        checkpoint_correctness_verification(args,
                                            model,
                                            hidden_dim,
                                            load_optimizer_states=load_optimizer_states)

    _test_checkpoint_unfused_optimizer(args=args,
                                       model=model,
                                       hidden_dim=hidden_dim,
                                       load_optimizer_states=True)
    _test_checkpoint_unfused_optimizer(args=args,
                                       model=model,
                                       hidden_dim=hidden_dim,
                                       load_optimizer_states=False)
def test_checkpoint_fused_optimizer(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015,
                "betas": [0.8, 0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "fp16": {
            "enabled": True
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_states):
        checkpoint_correctness_verification(args,
                                            model,
                                            hidden_dim,
                                            load_optimizer_states=load_optimizer_states)

    _test_checkpoint_fused_optimizer(args=args,
                                     model=model,
                                     hidden_dim=hidden_dim,
                                     load_optimizer_states=True)
    _test_checkpoint_fused_optimizer(args=args,
                                     model=model,
                                     hidden_dim=hidden_dim,
                                     load_optimizer_states=False)
def test_checkpoint_zero_optimizer(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015,
                "betas": [0.8, 0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "fp16": {
            "enabled": True
        },
        "zero_optimization": True
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_states):
        checkpoint_correctness_verification(args,
                                            model,
                                            hidden_dim,
                                            load_optimizer_states=load_optimizer_states)

    _test_checkpoint_zero_optimizer(args=args,
                                    model=model,
                                    hidden_dim=hidden_dim,
                                    load_optimizer_states=True)
    _test_checkpoint_zero_optimizer(args=args,
                                    model=model,
                                    hidden_dim=hidden_dim,
                                    load_optimizer_states=False)