Unverified commit baef92e2, authored by Stas Bekman, committed by GitHub

[save_16bit_model] add missing prologue (#1741)

* [save_16bit_model] add missing prologue

* fix

* fix

* adjust the API rename

* add test

* style

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent 5fe5b38e
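Context for the change: under ZeRO-3 each rank holds only a shard of every parameter, so `save_16bit_model()` must gather the full fp16 weights onto rank 0 before writing them out. When the save happened in the middle of a gradient-accumulation cycle, that gather could start from an inconsistent state; this PR brackets it with `checkpoint_event_prologue()` / `checkpoint_event_epilogue()`, the same bracketing the regular checkpoint path uses. Below is a minimal user-side sketch of the affected API; the toy model, output path, and the `config=` keyword are illustrative assumptions, not taken from this PR:

```python
import torch
import deepspeed

class TinyNet(torch.nn.Module):  # placeholder model, not part of the PR
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.fc(x)

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 2,
    "optimizer": {"type": "Adam"},
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        # lets rank 0 gather the full fp16 weights when saving the model
        "stage3_gather_fp16_weights_on_model_save": True,
    },
}

model = TinyNet()
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

x = torch.randn(1, 10, device=engine.device, dtype=torch.half)
loss = engine(x).sum()
engine.backward(loss)
engine.step()  # only 1 of the 2 accumulation micro-steps has run

# previously this could fail under ZeRO-3 mid-accumulation; with this PR it works
engine.save_16bit_model("/tmp/ds_out", "model.pt")
```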
@@ -21,6 +21,8 @@ from torch.distributed.distributed_c10d import _get_global_rank
from typing import Callable, Dict, Optional, Union, Iterable
import deepspeed
from deepspeed.runtime.utils import see_memory_usage, get_ma_status, DummyOptim
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
@@ -3061,8 +3063,6 @@ class DeepSpeedEngine(Module):
            a consolidated fp16 ``state_dict`` on cpu on rank 0, ``None`` on other ranks
        """
        import deepspeed

        if not self.zero_optimization_partition_weights():
            raise ValueError("this function requires ZeRO-3 mode")
@@ -3105,10 +3105,15 @@ class DeepSpeedEngine(Module):
                if child is not None:
                    get_layer_state_dict(child, prefix + name + ".")

        # Prepare for checkpoint save by ensuring all parameters are partitioned
        self.optimizer.checkpoint_event_prologue()

        see_memory_usage("before get_layer_state_dict", force=False)
        get_layer_state_dict(self.module, prefix="")
        see_memory_usage("after get_layer_state_dict", force=False)

        self.optimizer.checkpoint_event_epilogue()

        return state_dict

    def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"):
......
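For orientation, here is a paraphrased sketch of what the consolidated-gather path looks like after this patch. It is not the verbatim source: the rank-0 `OrderedDict` and the `deepspeed.zero.GatheredParameters` context reflect the surrounding code of this method, reconstructed best-effort.

```python
from collections import OrderedDict

import torch
import deepspeed

def _zero3_consolidated_fp16_state_dict(self):
    """Paraphrased sketch of the DeepSpeedEngine method after this patch."""
    if not self.zero_optimization_partition_weights():
        raise ValueError("this function requires ZeRO-3 mode")

    # only rank 0 materializes the consolidated dict; other ranks return None
    state_dict = OrderedDict() if torch.distributed.get_rank() == 0 else None

    def get_layer_state_dict(module, prefix=""):
        # temporarily gather this layer's full fp16 params onto rank 0,
        # copy them to cpu, then release them back to their partitions
        with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False))):
            if torch.distributed.get_rank() == 0:
                for name, param in module.named_parameters(recurse=False):
                    state_dict[prefix + name] = param.detach().cpu()
        for name, child in module.named_children():
            if child is not None:
                get_layer_state_dict(child, prefix + name + ".")

    # the fix: re-partition anything left in flight (e.g. from an unfinished
    # grad-accumulation cycle) before walking the module tree, and clean up after
    self.optimizer.checkpoint_event_prologue()
    get_layer_state_dict(self.module, prefix="")
    self.optimizer.checkpoint_event_epilogue()

    return state_dict
```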
@@ -1345,3 +1345,54 @@ def test_load_immediate_save(tmpdir, zero_stage):
        ds_model.save_checkpoint(tmpdir)

    _test_load_immediate_save(args, model, tmpdir)

@pytest.mark.parametrize('zero_stage', [0, 1, 2, 3])
def test_save_before_accum_grad_is_done(tmpdir, zero_stage):
    config_dict = {
        "optimizer": {
            "type": 'Adam'
        },
        "fp16": {
            "enabled": True,
            "initial_scale_power": 8
        },
        "zero_optimization": {
            "stage": zero_stage,
            "stage3_gather_fp16_weights_on_model_save": True,
        },
        "gradient_accumulation_steps": 2,
        "train_micro_batch_size_per_gpu": 1,
        "train_batch_size": 2,
    }
    hidden_dim = 10
    model = SimpleModel(hidden_dim)
    args = args_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=[1])
    def _test_save_before_accum_grad_is_done(args, model, tmpdir):
        # This test reproduces a bug where a 16bit model is retrieved before the
        # grad_accum cycle has completed: we configure grad_accum=2, step only
        # once, then call save_16bit_model.
        ds_model = create_deepspeed_model(args=args, model=model, base_optimizer=None)

        data_loader = random_dataloader(model=ds_model,
                                        total_samples=2,
                                        hidden_dim=hidden_dim,
                                        device=ds_model.device,
                                        dtype=torch.half)

        batch = next(iter(data_loader))
        loss = ds_model(batch[0], batch[1])
        ds_model.backward(loss)
        ds_model.step()

        # we stepped only once, so save the 16bit model before gradient_accumulation_steps=2 is complete
        ds_model.save_16bit_model(tmpdir, "model.pt")

        # let's test just as well that we can save the checkpoint too
        ds_model.save_checkpoint(tmpdir)

    _test_save_before_accum_grad_is_done(args, model, tmpdir)
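As an aside, the batch-size settings in this config are tied together: DeepSpeed requires `train_batch_size == train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size`, which is why micro-batch 1 and grad-accum 2 pair with `train_batch_size` 2 under `world_size=[1]`. A quick check of that identity:

```python
# batch-size identity DeepSpeed enforces for the config in this test
micro_batch, grad_accum, world_size = 1, 2, 1
assert micro_batch * grad_accum * world_size == 2  # train_batch_size
```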