Unverified commit 56635d5b, authored by Ammar Ahmad Awan, committed by GitHub

enable/disable moe token dropping. (#1492)

* Add a flag to enable/disable token dropping in moe/top-1 gating.

* Fix syntax and formatting.
Parent 99bd592d
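The diff below threads a single new flag, drop_tokens (default True), from the public MoE layer down into top1gating. A minimal usage sketch of the flag follows; the surrounding constructor arguments (hidden_size, expert, num_local_experts) are assumed from the existing API and are not part of this commit.

    import torch
    from deepspeed.moe.layer import MoE

    # Hypothetical expert block; any feed-forward module works here.
    expert = torch.nn.Linear(512, 512)

    # drop_tokens=True (the default) keeps the capacity-limited behavior;
    # drop_tokens=False grows capacity so no routed token is discarded.
    moe = MoE(hidden_size=512,
              expert=expert,
              num_local_experts=4,
              capacity_factor=1.,
              min_capacity=4,
              drop_tokens=False)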
@@ -24,7 +24,8 @@ class MoE(torch.nn.Module):
                  capacity_factor=1.,
                  eval_capacity_factor=1.,
                  min_capacity=4,
-                 noisy_gate_policy: typing.Optional[str] = None):
+                 noisy_gate_policy: typing.Optional[str] = None,
+                 drop_tokens: bool = True):
         """Initialize an MoE layer.

         Arguments:
@@ -66,7 +67,8 @@ class MoE(torch.nn.Module):
                           capacity_factor,
                           eval_capacity_factor,
                           min_capacity,
-                          noisy_gate_policy),
+                          noisy_gate_policy,
+                          drop_tokens),
                           experts,
                           num_local_experts,
                           group=groups.get_expert_parallel_group())
@@ -136,9 +136,10 @@ def top1gating(logits: torch.Tensor,
                capacity_factor: float,
                min_capacity: int,
                used_token: torch.Tensor = None,
-               noisy_gate_policy: Optional[str] = None) -> Tuple[Tensor,
-                                                                 Tensor,
-                                                                 Tensor]:
+               noisy_gate_policy: Optional[str] = None,
+               drop_tokens: bool = True) -> Tuple[Tensor,
+                                                  Tensor,
+                                                  Tensor]:
     """Implements Top1Gating on logits."""
     if noisy_gate_policy == 'RSample':
         logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
@@ -167,6 +168,12 @@
     # gating decisions
     exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')

+    # if we don't want to drop any tokens
+    if not drop_tokens:
+        new_capacity = torch.max(exp_counts).to(logits.device)
+        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.group.WORLD)
+        capacity = new_capacity
+
     # Compute l_aux
     me = torch.mean(gates, dim=0)
     ce = torch.mean(mask1.float(), dim=0)
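The new branch above is the heart of the feature: when drop_tokens is False, the precomputed capacity is replaced by the largest per-expert token count, max-reduced across all workers so that every rank agrees on one capacity and no token is dropped. A standalone sketch of that logic; the dist.is_initialized() guard is added here for single-process illustration and is not in the patch.

    import torch
    import torch.distributed as dist

    def no_drop_capacity(exp_counts: torch.Tensor, device: torch.device) -> torch.Tensor:
        # Largest number of tokens routed to any single local expert.
        new_capacity = torch.max(exp_counts).to(device)
        # Agree on the global maximum across data-parallel workers so all
        # ranks allocate identically sized expert buffers.
        if dist.is_initialized():
            dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.group.WORLD)
        return new_capacity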
@@ -306,7 +313,8 @@ class TopKGate(torch.nn.Module):
                  capacity_factor: float = 1.0,
                  eval_capacity_factor: float = 1.0,
                  min_capacity: int = 4,
-                 noisy_gate_policy: Optional[str] = None) -> None:
+                 noisy_gate_policy: Optional[str] = None,
+                 drop_tokens: bool = True) -> None:
         super().__init__()

         # Only top-1 and top-2 are supported at the moment.
@@ -321,6 +329,7 @@
         self.timers = SynchronizedWallClockTimer()
         self.wall_clock_breakdown = False
         self.gate_time = 0.0
+        self.drop_tokens = drop_tokens

     def forward(
         self,
@@ -347,7 +356,8 @@
                 self.capacity_factor if self.training else self.eval_capacity_factor,
                 self.min_capacity,
                 used_token,
-                self.noisy_gate_policy if self.training else None)
+                self.noisy_gate_policy if self.training else None,
+                self.drop_tokens)
         else:
             gate_output = top2gating(
                 ...
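Note that forward only passes self.drop_tokens into top1gating; the top2gating call in the else branch is unchanged, so the flag currently affects top-1 gating only, as the commit title says. Constructing the gate directly might look like this; the model_dim and num_experts parameter names are assumptions based on the existing TopKGate signature, not shown in this diff.

    from deepspeed.moe.sharded_moe import TopKGate

    # k=1 selects the top-1 path, the only one that honors drop_tokens.
    gate = TopKGate(model_dim=512, num_experts=8, k=1, drop_tokens=False)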