未验证 提交 7392e459 编写于 作者: S Stas Bekman 提交者: GitHub

[zero2] zero_param_shapes: switch to round_robin_fp16_groups (#1240)

* zero_param_shapes: switch to round_robin_fp16_groups

* add test

* old torch workaround
上级 32e85eda
......@@ -1978,7 +1978,15 @@ class DeepSpeedEngine(Module):
param_shapes = OrderedDict()
cnt = 0
numel = 0
for fp16_group in self.optimizer.fp16_groups:
# zero2 started using a round_robin_fp16_groups which is a shuffled version of fp16_groups -
# if we don't use it, we get parameters ordered incorrectly
if hasattr(self.optimizer, "round_robin_fp16_groups"):
fp16_groups = self.optimizer.round_robin_fp16_groups
else:
fp16_groups = self.optimizer.fp16_groups
for fp16_group in fp16_groups:
for param in fp16_group:
cnt += 1
numel += param.ds_numel if hasattr(param, "ds_numel") else param.numel()
......
......@@ -422,6 +422,10 @@ class FP16_DeepSpeedZeroOptimizer(object):
param.data = self.round_robin_fp16_groups[group_index][new_index].data
def _round_robin_reorder(self, tensor_list, num_partitions):
# disable round robin if need to debug something
#return tensor_list, list(range(len(tensor_list)))
partition_tensors = {}
for i, tensor in enumerate(tensor_list):
......
......@@ -3,11 +3,13 @@ import pytest
import json
import argparse
import os
import torch.distributed as dist
from common import distributed_test
from simple_model import SimpleModel, random_dataloader, args_from_dict
import deepspeed
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
def run_unbalanced_gradients(model, data_loader):
......@@ -128,3 +130,104 @@ def test_zero3_repeat_forward_loop(tmpdir, zero_stage):
model.step()
_test_zero3_repeat_forward_loop(args=args, model=model, hidden_dim=hidden_dim)
# testing the fix https://github.com/microsoft/DeepSpeed/pull/1227
@pytest.mark.parametrize('zero_stage', [2, 3])
def test_zero_to_fp32(tmpdir, zero_stage):
# TODO:
# - need to test with multiple param groups
# force all params to be partitioned by forcing threshold=0
config_dict = {
"train_micro_batch_size_per_gpu": 2,
"gradient_accumulation_steps": 2,
"steps_per_print": 1,
"zero_optimization": {
"stage": zero_stage,
"stage3_param_persistence_threshold": 0
},
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-3
}
},
"fp16": {
"enabled": True,
"initial_scale_power": 8
}
}
@distributed_test(world_size=[2])
def _test_zero_to_fp32():
class MyModel(torch.nn.Module):
def __init__(self, hidden_dim, n_layers):
super().__init__()
self.ll = torch.nn.ModuleList(
torch.nn.Linear(hidden_dim,
hidden_dim) for i in range(n_layers))
self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
def forward(self, x, y):
hidden = x
for l in self.ll:
hidden = l(hidden)
return self.cross_entropy_loss(hidden, y)
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 2
world_size = dist.get_world_size()
# we want at least 2x layers as there are gpus to trigger round_robin_fp16_groups reshuffle in zero2
n_layers = world_size * 2
model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers)
model, _, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
data_loader = random_dataloader(model=model,
total_samples=16,
hidden_dim=hidden_dim,
device=model.device)
for i, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
model.save_checkpoint(tmpdir)
# make sure all sides saved it
dist.barrier()
def dump_state_dict(model):
if dist.get_rank() != 0:
return
for name, param in model.named_parameters():
print(f"{name} {param}")
if zero_stage == 3:
with deepspeed.zero.GatheredParameters(list(
model.module.parameters(recurse=True)),
modifier_rank=None):
pass # this forces gathering the model
#dump_state_dict(model)
orig_state_dict = {}
for name, param in model.module.named_parameters():
orig_state_dict[name] = param.detach().cpu()
print(orig_state_dict)
fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir)
#dump_state_dict(fp32_model)
fp32_state_dict = fp32_model.state_dict()
for name in orig_state_dict.keys():
# float() workaround for torch<1.6
assert torch.allclose(orig_state_dict[name].float(),
fp32_state_dict[name].float())
_test_zero_to_fp32()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册