diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 06cad6d0c364391d84aa9ea2662d03d4fbc5e07e..e496589fba0daf155f25d4018ac63697029ddad3 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -21,6 +21,8 @@
 from torch.distributed.distributed_c10d import _get_global_rank
 from typing import Callable, Dict, Optional, Union, Iterable
 
+import deepspeed
+
 from deepspeed.runtime.utils import see_memory_usage, get_ma_status, DummyOptim
 from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
@@ -3061,8 +3063,6 @@ class DeepSpeedEngine(Module):
             a consolidated fp16 ``state_dict`` on cpu on rank 0, ``None`` on other ranks
 
         """
-        import deepspeed
-
         if not self.zero_optimization_partition_weights():
             raise ValueError("this function requires ZeRO-3 mode")
 
@@ -3105,10 +3105,15 @@
                 if child is not None:
                     get_layer_state_dict(child, prefix + name + ".")
 
+        # Prepare for checkpoint save by ensuring all parameters are partitioned
+        self.optimizer.checkpoint_event_prologue()
+
         see_memory_usage("before get_layer_state_dict", force=False)
         get_layer_state_dict(self.module, prefix="")
         see_memory_usage("after get_layer_state_dict", force=False)
 
+        self.optimizer.checkpoint_event_epilogue()
+
         return state_dict
 
     def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"):
diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py
index 7e830e8f4f7eb295b465616a88227f3132e9e60f..251e90e9320ef1b5fe8ccbcaaa33b33fdfa37c5c 100755
--- a/tests/unit/test_checkpointing.py
+++ b/tests/unit/test_checkpointing.py
@@ -1345,3 +1345,53 @@ def test_load_immediate_save(tmpdir, zero_stage):
         ds_model.save_checkpoint(tmpdir)
 
     _test_load_immediate_save(args, model, tmpdir)
+
+
+@pytest.mark.parametrize('zero_stage', [0, 1, 2, 3])
+def test_save_before_accum_grad_is_done(tmpdir, zero_stage):
+    config_dict = {
+        "optimizer": {
+            "type": 'Adam'
+        },
+        "fp16": {
+            "enabled": True,
+            "initial_scale_power": 8
+        },
+        "zero_optimization": {
+            "stage": zero_stage,
+            "stage3_gather_fp16_weights_on_model_save": True,
+        },
+        "gradient_accumulation_steps": 2,
+        "train_micro_batch_size_per_gpu": 1,
+        "train_batch_size": 2,
+    }
+    hidden_dim = 10
+    model = SimpleModel(hidden_dim)
+    args = args_from_dict(tmpdir, config_dict)
+
+    @distributed_test(world_size=[1])
+    def _test_save_before_accum_grad_is_done(args, model, tmpdir):
+
+        # This test reproduces a bug where one tries to retrieve a 16bit model before the
+        # grad_accum cycle has completed.
+        # So we configure grad_accum=2, step only once, and call save_16bit_model.
+        ds_model = create_deepspeed_model(args=args, model=model, base_optimizer=None)
+
+        data_loader = random_dataloader(model=ds_model,
+                                        total_samples=2,
+                                        hidden_dim=hidden_dim,
+                                        device=ds_model.device,
+                                        dtype=torch.half)
+
+        batch = next(iter(data_loader))
+        loss = ds_model(batch[0], batch[1])
+        ds_model.backward(loss)
+        ds_model.step()
+
+        # we stepped only once; now save the 16bit model before gradient_accumulation_steps=2 is complete
+        ds_model.save_16bit_model(tmpdir, "model.pt")
+
+        # let's also test that we can save the checkpoint
+        ds_model.save_checkpoint(tmpdir)
+
+    _test_save_before_accum_grad_is_done(args, model, tmpdir)
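For context, here is a minimal sketch of the user-facing scenario the new test covers; it is not part of the patch. A ZeRO-3 engine configured with `gradient_accumulation_steps=2` runs a single micro-batch and then saves the fp16 weights mid-accumulation. The toy model, config values, and `./out` paths are illustrative assumptions; `deepspeed.initialize`, `engine.backward`/`engine.step`, `save_16bit_model`, and `save_checkpoint` are the same calls exercised by the test above.

```python
# Illustrative sketch only -- not part of the patch. Assumes a single GPU and a
# working DeepSpeed install; the toy model, config values and "./out" paths are
# made up for demonstration.
import torch
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 2,  # accumulation window of two micro-batches
    "train_batch_size": 2,             # 1 GPU x micro-batch 1 x grad_accum 2
    "optimizer": {"type": "Adam"},
    "fp16": {"enabled": True, "initial_scale_power": 8},
    "zero_optimization": {
        "stage": 3,
        "stage3_gather_fp16_weights_on_model_save": True,
    },
}

model = torch.nn.Linear(10, 10)  # any small model is enough to hit the code path
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

# Run a single micro-batch: the engine is now mid-way through its accumulation window.
x = torch.randn(1, 10, device=engine.device, dtype=torch.half)
loss = engine(x).float().sum()
engine.backward(loss)
engine.step()

# With the checkpoint_event_prologue()/epilogue() calls added in this patch, the
# ZeRO-3 parameters are re-partitioned before the layer-by-layer gather, so both
# saves work even though gradient_accumulation_steps=2 has not completed yet.
engine.save_16bit_model("./out", "model.pt")
engine.save_checkpoint("./out")
```

The new unit test drives the same flow through the existing `SimpleModel`/`random_dataloader` helpers and additionally parametrizes over ZeRO stages 0 through 3.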