From 2ded2ff0be6ef5f1e592af7d2f7e05c1151a8d44 Mon Sep 17 00:00:00 2001 From: Alexander Jipa Date: Fri, 30 Jun 2023 10:40:13 -0400 Subject: [PATCH] checking process_group before merging bucket ranges (#3521) (#3577) Co-authored-by: Alexander Jipa Co-authored-by: Olatunji Ruwase --- deepspeed/runtime/zero/stage_1_and_2.py | 6 +- tests/unit/moe/test_moe.py | 77 ++++++++++++++++++++++--- tests/unit/simple_model.py | 40 +++++++------ 3 files changed, 96 insertions(+), 27 deletions(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 06a776b8..37c0ef06 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -915,7 +915,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer): rank_and_offsets = [] real_dp_process_group = [] curr_size = 0 - prev_id = -1 + prev_id, prev_process_group = -1, None process_group = self.dp_process_group # count = 0 @@ -958,14 +958,14 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer): numel = partition_ids_w_offsets[idx + 1][1] - offset # Merge bucket ranges if they belong to the same rank - if partition_id == prev_id: + if partition_id == prev_id and process_group == prev_process_group: prev_pid, prev_size, prev_numel = rank_and_offsets[-1] rank_and_offsets[-1] = (prev_pid, prev_size, prev_numel + numel) else: rank_and_offsets.append((partition_id, curr_size, numel)) real_dp_process_group.append(process_group) curr_size += numel - prev_id = partition_id + prev_id, prev_process_group = partition_id, process_group if not self.ipg_bucket_has_moe_params: tensor.div_(dist.get_world_size(group=self.dp_process_group)) diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py index 83894b29..afbcaf83 100644 --- a/tests/unit/moe/test_moe.py +++ b/tests/unit/moe/test_moe.py @@ -8,33 +8,96 @@ import deepspeed import pytest from unit.common import DistributedTest from unit.simple_model import SimplePRMoEModel, SimpleMoEModel, sequence_dataloader +from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer, is_moe_param from unit.util import required_torch_version @pytest.mark.parametrize("ep_size", [2, 4]) +@pytest.mark.parametrize("zero_stage", [0, 1, 2]) @pytest.mark.parametrize("use_residual", [True, False]) class TestMoE(DistributedTest): world_size = 4 - def test(self, ep_size, use_residual): + def test(self, ep_size, zero_stage, use_residual): if not required_torch_version(): pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly") - config_dict = {"train_batch_size": 8, "steps_per_print": 1, "fp16": {"enabled": True}} + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "fp16": { + "enabled": True + }, + "zero_optimization": { + "stage": zero_stage + } + } hidden_dim = 16 # E+D -- ep_size = 2 # E only -- ep_size = 4 model = SimpleMoEModel(hidden_dim, ep_size=ep_size, use_residual=use_residual) - optimizer = torch.optim.AdamW(params=model.parameters()) - model, _, _, _ = deepspeed.initialize(config=config_dict, - model=model, - optimizer=optimizer, - dist_init_required=False) + param_group = {'params': [p for p in model.parameters()], 'name': 'random-unique-name'} + params = split_params_into_different_moe_groups_for_optimizer(param_group) + optimizer = torch.optim.AdamW(params=params) + model, optimizer, _, _ = deepspeed.initialize(config=config_dict, + model=model, + optimizer=optimizer, + dist_init_required=False) #dist_init_required=False -- parameterize to True/False? 
         data_loader = sequence_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device)
 
+        def strict_average_tensor(tensor):
+            process_group = optimizer.dp_process_group
+            curr_size = 0
+            pg_offsets = []
+            for i, param, param_id in optimizer.params_in_ipg_bucket:
+                process_group = optimizer.dp_process_group
+                if optimizer.ipg_bucket_has_moe_params:
+                    process_group = optimizer.expert_dp_process_group[param.group_name] if is_moe_param(
+                        param) else optimizer.dp_process_group
+                partition_ids = optimizer.param_to_partition_ids[i][param_id]
+                # Get all partition ids + their offsets
+                partition_offsets = []
+                for partition_id in partition_ids:
+                    offset = optimizer.grad_start_offset[i][partition_id][param_id]
+                    partition_offsets.append(offset)
+                partition_offsets.sort()
+                # Calculate rank and offsets for grad slices
+                for idx, offset in enumerate(partition_offsets):
+                    # Calculate numel for grad slice depending on partition location
+                    if idx == len(partition_offsets) - 1:
+                        # Last partition_id uses its own offset
+                        numel = param.numel() - offset
+                    else:
+                        # Set numel to next partition's offset
+                        numel = partition_offsets[idx + 1] - offset
+                    pg_offsets.append((curr_size, process_group))
+                    curr_size += numel
+
+            def strict_narrow(dim, start, length):
+                lo, hi = 0, len(pg_offsets) - 1
+                while lo < hi:
+                    mi = lo + (hi - lo) // 2
+                    if pg_offsets[mi][0] >= start:
+                        hi = mi
+                    else:
+                        lo = mi + 1
+                curr_slice, reduce_process_group = lo, pg_offsets[lo][1]
+                while curr_slice < len(pg_offsets) and start + length > pg_offsets[curr_slice][0]:
+                    assert reduce_process_group == pg_offsets[curr_slice][
+                        1], "reduce process_group does not match the parameter's process_group"
+                    curr_slice += 1
+                return orig_narrow(dim, start, length)  # real call
+
+            orig_narrow, tensor.narrow = tensor.narrow, strict_narrow
+            type(optimizer).average_tensor(optimizer, tensor)  # real call
+            tensor.narrow = orig_narrow
+
+        if "average_tensor" in dir(optimizer):
+            optimizer.average_tensor = strict_average_tensor
+
         for n, batch in enumerate(data_loader):
             loss = model(batch[0], batch[1])
             model.backward(loss)
diff --git a/tests/unit/simple_model.py b/tests/unit/simple_model.py
index 330f612e..9fd0cff8 100644
--- a/tests/unit/simple_model.py
+++ b/tests/unit/simple_model.py
@@ -71,27 +71,33 @@ class SimpleMoEModel(torch.nn.Module):
 
     def __init__(self, hidden_dim, num_experts=4, ep_size=1, use_residual=False):
         super(SimpleMoEModel, self).__init__()
-        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
-        expert = torch.nn.Linear(hidden_dim, hidden_dim)
+        self.linear1 = torch.nn.Linear(hidden_dim, hidden_dim)
+        expert = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.Linear(hidden_dim, hidden_dim))
         # using two MoE layers to check implications of sharing a single storage
-        self.linear2 = MoE(hidden_size=hidden_dim,
-                           expert=expert,
-                           ep_size=ep_size,
-                           use_residual=use_residual,
-                           num_experts=num_experts,
-                           k=1)
-        self.linear3 = MoE(hidden_size=hidden_dim,
-                           expert=expert,
-                           ep_size=ep_size,
-                           use_residual=use_residual,
-                           num_experts=num_experts,
-                           k=1)
+        self.moe_1 = MoE(hidden_size=hidden_dim,
+                         expert=expert,
+                         ep_size=ep_size,
+                         use_residual=use_residual,
+                         num_experts=num_experts,
+                         k=1)
+        # interleaving MoE modules with dense to create an opportunity
+        # for gradients to be merged in ZeRO stage 2 average_tensor reduce bucket
+        self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
+        self.moe_2 = MoE(hidden_size=hidden_dim,
+                         expert=expert,
+                         ep_size=ep_size,
+                         use_residual=use_residual,
+                         num_experts=num_experts,
+                         k=1)
+        self.linear3 = torch.nn.Linear(hidden_dim, hidden_dim)
         self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
 
     def forward(self, x, y):
-        hidden_dim = self.linear(x)
-        output, _, _ = self.linear2(hidden_dim)
-        output, _, _ = self.linear3(output)
+        hidden_dim = self.linear1(x)
+        output, _, _ = self.moe_1(hidden_dim)
+        output = self.linear2(output)
+        output, _, _ = self.moe_2(output)
+        output = self.linear3(output)
         hidden_dim = hidden_dim + output
         sentence_embed = hidden_dim.mean(1)
         return self.cross_entropy_loss(sentence_embed, y)
--
GitLab
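
Note on the fix, for readers outside the patch context: in DeepSpeedZeroOptimizer.average_tensor, adjacent gradient slices in an IPG bucket may only be coalesced into one reduce entry when they target both the same partition rank and the same process group, since MoE expert gradients reduce over the expert data-parallel group while dense gradients reduce over the full data-parallel group. The sketch below is a minimal, self-contained illustration of that merging rule, not DeepSpeed code; the names GradSlice and merge_reduce_slices are hypothetical.

# Simplified illustration of the merging rule this patch enforces.
# GradSlice and merge_reduce_slices are illustrative names, not DeepSpeed APIs.
from dataclasses import dataclass
from typing import Any, List, Tuple


@dataclass
class GradSlice:
    partition_id: int   # rank that owns the gradient partition this slice reduces into
    numel: int          # number of elements the slice contributes to the bucket
    process_group: Any  # group the slice must be reduced over (dense DP or expert DP)


def merge_reduce_slices(slices: List[GradSlice]) -> List[Tuple[int, int, int, Any]]:
    """Coalesce adjacent bucket slices into (rank, bucket_offset, numel, group) entries.

    Before the fix, adjacent slices were merged whenever they targeted the same
    partition rank; with MoE parameters in the bucket, a dense-gradient slice could
    then be folded into an expert-parallel slice. Comparing the process group as
    well keeps slices with different reduce groups separate.
    """
    merged: List[Tuple[int, int, int, Any]] = []
    curr_size = 0
    prev_id, prev_group = -1, None
    for s in slices:
        if s.partition_id == prev_id and s.process_group == prev_group:
            rank, offset, numel, group = merged[-1]
            merged[-1] = (rank, offset, numel + s.numel, group)
        else:
            merged.append((s.partition_id, curr_size, s.numel, s.process_group))
        curr_size += s.numel
        prev_id, prev_group = s.partition_id, s.process_group
    return merged


if __name__ == "__main__":
    dense, expert = "dense_dp_group", "expert_dp_group"  # stand-ins for real process groups
    bucket = [GradSlice(0, 4, dense), GradSlice(0, 4, expert), GradSlice(0, 4, expert)]
    # Same rank throughout, but the dense slice must stay separate from the expert slices:
    assert merge_reduce_slices(bucket) == [(0, 0, 4, dense), (0, 4, 8, expert)]

With one dense slice followed by expert slices that all target the same rank, the pre-patch rule (comparing partition_id only) would have produced a single merged reduce entry spanning both groups; keeping the groups apart is what the strict_narrow assertion in the new test checks during the real average_tensor call.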