From 56635d5b6c38a986663f2d92b30fe6e680930cb4 Mon Sep 17 00:00:00 2001
From: Ammar Ahmad Awan
Date: Thu, 28 Oct 2021 10:08:00 -0700
Subject: [PATCH] enable/disable moe token dropping. (#1492)

* Add a flag to enable/disable token dropping in moe/top-1 gating.

* fix syntax and formatting.
---
 deepspeed/moe/layer.py       |  6 ++++--
 deepspeed/moe/sharded_moe.py | 20 +++++++++++++++-----
 2 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/deepspeed/moe/layer.py b/deepspeed/moe/layer.py
index dd1fb462..dc0a4582 100644
--- a/deepspeed/moe/layer.py
+++ b/deepspeed/moe/layer.py
@@ -24,7 +24,8 @@ class MoE(torch.nn.Module):
                  capacity_factor=1.,
                  eval_capacity_factor=1.,
                  min_capacity=4,
-                 noisy_gate_policy: typing.Optional[str] = None):
+                 noisy_gate_policy: typing.Optional[str] = None,
+                 drop_tokens: bool = True):
         """Initialize an MoE layer.
 
         Arguments:
@@ -66,7 +67,8 @@ class MoE(torch.nn.Module):
                                                capacity_factor,
                                                eval_capacity_factor,
                                                min_capacity,
-                                               noisy_gate_policy),
+                                               noisy_gate_policy,
+                                               drop_tokens),
                                       experts,
                                       num_local_experts,
                                       group=groups.get_expert_parallel_group())
diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py
index 08cbc407..5d919621 100644
--- a/deepspeed/moe/sharded_moe.py
+++ b/deepspeed/moe/sharded_moe.py
@@ -136,9 +136,10 @@ def top1gating(logits: torch.Tensor,
                capacity_factor: float,
                min_capacity: int,
                used_token: torch.Tensor = None,
-               noisy_gate_policy: Optional[str] = None) -> Tuple[Tensor,
-                                                                 Tensor,
-                                                                 Tensor]:
+               noisy_gate_policy: Optional[str] = None,
+               drop_tokens: bool = True) -> Tuple[Tensor,
+                                                  Tensor,
+                                                  Tensor]:
     """Implements Top1Gating on logits."""
     if noisy_gate_policy == 'RSample':
         logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
@@ -167,6 +168,12 @@ def top1gating(logits: torch.Tensor,
     # gating decisions
     exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')
 
+    # if we don't want to drop any tokens
+    if not drop_tokens:
+        new_capacity = torch.max(exp_counts).to(logits.device)
+        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.group.WORLD)
+        capacity = new_capacity
+
     # Compute l_aux
     me = torch.mean(gates, dim=0)
     ce = torch.mean(mask1.float(), dim=0)
@@ -306,7 +313,8 @@ class TopKGate(torch.nn.Module):
                  capacity_factor: float = 1.0,
                  eval_capacity_factor: float = 1.0,
                  min_capacity: int = 4,
-                 noisy_gate_policy: Optional[str] = None) -> None:
+                 noisy_gate_policy: Optional[str] = None,
+                 drop_tokens: bool = True) -> None:
         super().__init__()
 
         # Only top-1 and top-2 are supported at the moment.
@@ -321,6 +329,7 @@ class TopKGate(torch.nn.Module):
         self.timers = SynchronizedWallClockTimer()
         self.wall_clock_breakdown = False
         self.gate_time = 0.0
+        self.drop_tokens = drop_tokens
 
     def forward(
             self,
@@ -347,7 +356,8 @@ class TopKGate(torch.nn.Module):
                 self.capacity_factor if self.training else self.eval_capacity_factor,
                 self.min_capacity,
                 used_token,
-                self.noisy_gate_policy if self.training else None)
+                self.noisy_gate_policy if self.training else None,
+                self.drop_tokens)
 
         else:
             gate_output = top2gating(
-- 
GitLab
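
Usage note (illustrative, not part of the patch): passing drop_tokens=False to
the MoE layer propagates through TopKGate into top1gating, which then replaces
the capacity computed from capacity_factor with the maximum per-expert token
count, all-reduced with MAX across ranks so every rank agrees on the new
capacity and no token is dropped. The sketch below is a minimal example of the
new flag, assuming the pre-existing hidden_size/expert/num_experts constructor
arguments of MoE, which are not shown in this diff:

    import torch
    from deepspeed.moe.layer import MoE

    # Any nn.Module can serve as the expert; a Linear keeps the sketch small.
    expert = torch.nn.Linear(1024, 1024)

    moe = MoE(hidden_size=1024,        # assumed pre-existing argument
              expert=expert,           # assumed pre-existing argument
              num_experts=8,           # assumed pre-existing argument
              capacity_factor=1.,
              min_capacity=4,
              drop_tokens=False)       # new flag: disable token dropping

Note that drop_tokens only takes effect on the top-1 gating path; the
top2gating call is left unchanged by this patch.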