Unverified commit c88af214, authored by Zhen Zhang, committed by GitHub

[MiCS] [Fix] saving and loading model checkpoint logic for MiCS sharding (#3440)

* fix mics save checkpoint hanging

* MiCS load_checkpoint

* copyright

* fix for torch-1.9.0

the all_reduce_coalesced API does not support the NCCL backend

* Naming alignment

* adding more test conditions for mics shard size

* test with different shard sizes

* adding assertion for better error msg

---------
Co-authored-by: Zhen Zhang <zhzhn@amazon.com>
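As background for the fix, here is a minimal sketch (not part of the commit) of the user-facing flow it repairs: parameters are created under `deepspeed.zero.MiCS_Init`, and checkpoints are saved and reloaded through the engine. The config values, toy model, and paths are illustrative; the entry points (`deepspeed.zero.MiCS_Init`, `deepspeed.initialize`, `save_checkpoint`, `load_checkpoint`) are the standard DeepSpeed APIs exercised by this change.

```python
# Minimal sketch (illustrative, not part of this commit): save/load a checkpoint
# with MiCS sharding. Launch with the deepspeed launcher, e.g.
#   deepspeed --num_gpus 2 mics_ckpt_example.py
import torch
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 2,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 3, "mics_shard_size": 2},  # MiCS partition size
}

# Parameters must be created under MiCS_Init so the MiCS_CommGroups get attached
# (see the assertion added to MiCS_Optimizer below).
with deepspeed.zero.MiCS_Init(config_dict_or_path=ds_config):
    model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 32))

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

# Save: each shard group writes its own ZeRO checkpoint files sequentially.
engine.save_checkpoint("/tmp/mics_ckpt", tag="step0")

# Load: MiCS_Optimizer now delegates to the ZeRO-3 loading logic over the shard group.
engine.load_checkpoint("/tmp/mics_ckpt", tag="step0", load_optimizer_states=True)
```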
Parent f483c034
@@ -20,13 +20,24 @@ def is_torch_two():
        return False


+def torch_ver_ge_1_13():
+   if is_torch_two():
+       return True
+   else:
+       TORCH_MAJOR = int(torch.__version__.split('.')[0])
+       assert TORCH_MAJOR == 1
+       TORCH_MINOR = int(torch.__version__.split('.')[1])
+       return TORCH_MINOR >= 13


def has_coalescing_manager():
    has_c10d = hasattr(torch.distributed, 'distributed_c10d')
    return has_c10d and hasattr(torch.distributed.distributed_c10d, '_coalescing_manager')


def has_all_reduce_coalesced():
-   return hasattr(torch.distributed, "all_reduce_coalesced")
+   return hasattr(torch.distributed, "all_reduce_coalesced") and torch_ver_ge_1_13()


def get_coalescing_manager(group, device, reqs):
......
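The tightened check matters because on torch older than 1.13 (e.g. the 1.9.0 case named in the commit message) `torch.distributed.all_reduce_coalesced` exists but cannot be used with the NCCL backend. A hedged sketch of the kind of caller-side dispatch this check enables; the helper `reduce_tensors` is illustrative, not a DeepSpeed API:

```python
import torch
import torch.distributed as dist


def torch_ver_ge_1_13():
    # Mirrors the version guard in the diff above.
    major, minor = (int(v) for v in torch.__version__.split('.')[:2])
    return major > 1 or (major == 1 and minor >= 13)


def has_all_reduce_coalesced():
    # Mirrors the guarded capability check in the diff above.
    return hasattr(torch.distributed, "all_reduce_coalesced") and torch_ver_ge_1_13()


def reduce_tensors(tensors, group=None):
    """Illustrative dispatch: use the coalesced op when safe, else loop."""
    if has_all_reduce_coalesced():
        dist.all_reduce_coalesced(tensors, group=group)
    else:
        # Pre-1.13 (e.g. torch 1.9) with NCCL: fall back to one all_reduce per tensor.
        for t in tensors:
            dist.all_reduce(t, group=group)
```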
@@ -3057,11 +3057,11 @@ class DeepSpeedEngine(Module):
    def _create_zero_checkpoint_files(self, save_dir, tag):
        success = True
        # zero checkpoint files are created sequentially
-       for rank in range(self.world_size):
+       for rank in range(dist.get_world_size(self.optimizer.dp_process_group)):
            if rank == self.global_rank:
                success = self._create_checkpoint_file(save_dir, tag, True)

-           dist.barrier()
+           dist.barrier(group=self.optimizer.dp_process_group)

        return success
......
@@ -393,15 +393,20 @@ class MiCS_Optimizer(DeepSpeedZeroOptimizer_Stage3):
                         gradient_accumulation_steps, elastic_checkpoint, aio_config)

        first_param = next(module.parameters())
        # overload the dp_process_group and partition_count
+       assert hasattr(first_param, "comm"), " ".join([
+           "Sharded parameters don't have the MiCS_CommGroups attached.",
+           "This might be due to using the deepspeed.zero.Init context to initialize the weights.",
+           "To use MiCS sharding, please initialize the parameters with deepspeed.zero.MiCS_Init instead."
+       ])
        self.dp_process_group = first_param.comm.param_shard_group
        self.partition_count = first_param.comm.param_shard_size

    def initialize_ds_offload(self, module, timers, ds_config, overlap_comm, prefetch_bucket_size, max_reuse_distance,
                              max_live_parameters, param_persistence_threshold, model_persistence_threshold,
-                             offload_optimizer_config, mpu):
+                             offload_param_config, mpu):
        return MiCS_Offload(module, timers, ds_config, overlap_comm, prefetch_bucket_size, max_reuse_distance,
                            max_live_parameters, param_persistence_threshold, model_persistence_threshold,
-                           offload_optimizer_config, mpu)
+                           offload_param_config, mpu)

    def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None:
        grad_buffers = super().partition_grads(params_to_release, grad_partitions)
@@ -440,14 +445,13 @@ class MiCS_Optimizer(DeepSpeedZeroOptimizer_Stage3):
            grad_buff.view(-1).copy_(aggregated_buffer.narrow(0, offset, grad_buff.numel()))
            offset += grad_buff.numel()

    # TODO: Support different/changing load/save DP degree.
    def load_state_dict(self,
                        state_dict_list,
                        load_optimizer_states=True,
                        load_from_fp32_weights=False,
                        checkpoint_folder=None):
-       r""" Loading the MiCS checkpoints
-       TODO: move the implementation from zhen/merged_ds_master branch
+       r""" Loading the ZeRO-3/MiCS partitioned checkpoints.
+       Because self.dp_process_group has been replaced with the communicator of the
+       partition (shard) group, we can reuse the load_state_dict logic from ZeRO-3.
        """
-       raise NotImplementedError("Not implemented for loading MiCS checkpoints")
+       super().load_state_dict(state_dict_list, load_optimizer_states, load_from_fp32_weights, checkpoint_folder)
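The new docstring captures the key design choice: `MiCS_Optimizer.__init__` overwrites `self.dp_process_group` and `self.partition_count` with the shard group, so every inherited ZeRO-3 routine, including checkpoint loading, automatically operates over the smaller group. A toy, DeepSpeed-independent sketch of that pattern; all names below are illustrative:

```python
# Illustrative only: mimics how overriding the "data-parallel group" changes
# which ranks the inherited checkpoint/partition logic shards state across.

class BaseShardedOptimizer:
    """Stands in for ZeRO-3: partitions state across self.dp_process_group."""

    def __init__(self, dp_group_ranks):
        self.dp_process_group = dp_group_ranks      # list of ranks (toy stand-in for a comm group)
        self.partition_count = len(dp_group_ranks)

    def state_dict_shards(self, num_elements):
        # Every rank in dp_process_group owns ~1/partition_count of the state.
        per_rank = num_elements // self.partition_count
        return {rank: per_rank for rank in self.dp_process_group}


class MiCSLikeOptimizer(BaseShardedOptimizer):
    """Overloads the group with a smaller shard group, like MiCS_Optimizer.__init__."""

    def __init__(self, shard_group_ranks):
        super().__init__(dp_group_ranks=shard_group_ranks)
        # No new load/save logic needed: inherited methods now act on the shard group.


if __name__ == "__main__":
    full_dp = list(range(8))      # 8-way data parallel
    mics_shard = list(range(4))   # 4-way MiCS shard group
    print(BaseShardedOptimizer(full_dp).state_dict_shards(1024))    # 128 elements per rank over 8 ranks
    print(MiCSLikeOptimizer(mics_shard).state_dict_shards(1024))    # 256 elements per rank over 4 ranks
```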
#!/bin/bash
deepspeed test_mics_config.py --mics_shard_size=1
deepspeed test_mics_config.py --mics_shard_size=2
# for debugging the hierarchical params gathering
export NDEV_PER_NODE=2
deepspeed test_mics_config.py --mics_shard_size=4 --mics_hierarchical_params_gather
@@ -5,6 +5,10 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
+"""
+Testing on an 8-GPU node:
+    NDEV_PER_NODE=2 torchrun --nnodes 1 --nproc-per-node 8 test_mics_config.py
+"""
import os
import json
@@ -39,7 +43,7 @@ def create_config_from_dict(tmpdir, config_dict):
def get_data_loader(model, total_samples, hidden_dim, device):
    batch_size = model.train_micro_batch_size_per_gpu()
-   train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.half)
+   train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.float)
    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    sampler = DistributedSampler(train_dataset)
@@ -49,11 +53,17 @@ def get_data_loader(model, total_samples, hidden_dim, device):
def get_args(tmpdir, config_dict):
    parser = argparse.ArgumentParser()
-   parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument('--zero', type=int, default=3)
+   parser.add_argument('--local_rank', type=int)
+   parser.add_argument('--mics_shard_size', default=2, type=int)
+   parser.add_argument('--mics_hierarchical_params_gather', default=False, action='store_true')
    args = parser.parse_args()  # args=''

    config_dict["zero_optimization"]["stage"] = args.zero
+   config_dict["zero_optimization"]["mics_shard_size"] = args.mics_shard_size
+   config_dict["zero_optimization"]["mics_hierarchical_params_gather"] = args.mics_hierarchical_params_gather
    # print('config_dict["zero_optimization"]', config_dict["zero_optimization"])
    config_path = create_config_from_dict(tmpdir, config_dict)
@@ -80,7 +90,7 @@ config_dict = {
        }
    },
    "fp16": {
-       "enabled": True,
+       "enabled": False,
        "initial_scale_power": 15
    },
    "zero_optimization": {
@@ -95,8 +105,8 @@ config_dict = {
args = get_args('/tmp/', config_dict)
hidden_dim = 32

# with deepspeed.zero.Init():
-model = SimpleModel(hidden_dim, empty_grad=False)
+with deepspeed.zero.MiCS_Init(config_dict_or_path=config_dict):
+    model = SimpleModel(hidden_dim, empty_grad=False)
# print('------> init model with deepspeed.zero.Init()')

model, _, _, _ = deepspeed.initialize(args=args,
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import deepspeed

from unit.common import DistributedTest
from unit.simple_model import *
from unit.checkpoint.common import *

import pytest
class TestMiCSCheckpoint(DistributedTest):
    world_size = 4

    def _toy_model_config(self, shard_size):
        config_dict = {
            "train_micro_batch_size_per_gpu": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": 'Adam',
                "params": {
                    "lr": 0.00015,
                    "betas": [0.8, 0.999],
                    "eps": 1e-8,
                    "weight_decay": 3e-7
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            },
            "wall_clock_breakdown": True,
            "zero_optimization": {
                "stage": 3,
                "mics_shard_size": shard_size
            }
        }

        hidden_dim = 10
        with deepspeed.zero.MiCS_Init(config_dict_or_path=config_dict):
            models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

        return config_dict, hidden_dim, models

    @pytest.mark.parametrize('shard_size', [1, 2, 4])
    def test_load_optimizer_state(self, tmpdir, shard_size):
        config_dict, hidden_dim, models = self._toy_model_config(shard_size)
        checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_optimizer_states=True)

    @pytest.mark.parametrize('shard_size', [1, 2, 4])
    def test_not_load_optimizer_state(self, tmpdir, shard_size):
        config_dict, hidden_dim, models = self._toy_model_config(shard_size)
        checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_optimizer_states=False)

    @pytest.mark.parametrize('shard_size', [1, 2, 4])
    def test_load_module_only(self, tmpdir, shard_size):
        config_dict, hidden_dim, models = self._toy_model_config(shard_size)
        checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_module_only=True)