gshard_gate.py 2.3 KB
Newer Older
R
Roc 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import paddle
import paddle.nn.functional as F
import numpy as np
from .naive_gate import NaiveGate
from ..utils import limit_by_capacity


class GShardGate(NaiveGate):
    """Top-2 gate that records a GShard-style auxiliary balance loss and
    applies capacity limiting to the selected expert indices.

    Args:
        d_model: Forwarded to ``NaiveGate``.
        num_expert: Experts per worker; forwarded to ``NaiveGate`` and
            ``limit_by_capacity`` and used to scale the balance loss.
        world_size: Number of workers; forwarded to ``NaiveGate`` and
            ``limit_by_capacity``.
        topk: Must be 2; any other value raises ``ValueError``.
        capacity: ``(train_rate, eval_rate)`` pair. The rate for the current
            mode times ``x.shape[0]`` (rounded up) is the capacity passed to
            ``limit_by_capacity``.
        random_routing: If True, ``forward`` passes the capacity-limited
            indices through ``paddle.distributed.utils.random_routing``.
        group: Communication group forwarded to ``limit_by_capacity``.

    Raises:
        ValueError: If ``topk`` is not 2.
    """

    def __init__(self,
                 d_model,
                 num_expert,
                 world_size,
                 topk=2,
                 capacity=(1.2, 2.4),
                 random_routing=True,
                 group=None):
        # Explicit check: an ``assert`` would be stripped under ``python -O``
        # and an invalid ``topk`` would then be silently accepted.
        if topk != 2:
            raise ValueError("topk should be 2 in gshard")
        super().__init__(d_model, num_expert, world_size)
        self.capacity = capacity
        self.random_routing = random_routing
        self.group = group

    def forward(self, x):
        """Select experts for ``x`` and record the auxiliary balance loss.

        Args:
            x: Input tensor. ``x.shape[0]`` (presumably the number of tokens
                -- TODO confirm) scales the capacity.

        Returns:
            ``(topk_val, topk_idx)``: the gate values from ``NaiveGate`` and
            the expert indices after ``limit_by_capacity`` and, when enabled,
            random routing.
        """
        topk_val, topk_idx, gate_score = super().forward(
            x, return_all_scores=True)
        s = gate_score.shape[0]

        # Balance loss: c_e = per-expert assignment counts divided by s,
        # m_e = per-expert mean of the softmaxed gate scores.
        # NOTE(review): flatten() counts every selected index, not just the
        # first column, so c_e sums to 2 when each row holds two indices,
        # although the name suggests top-1 only -- confirm this is intended.
        top1_idx = topk_idx.flatten()
        c_e = paddle.scatter(
            # Explicit float32 so it always matches the dtype of ``updates``.
            paddle.zeros(shape=[self.tot_expert], dtype="float32"),
            top1_idx,
            paddle.ones_like(
                top1_idx, dtype="float32"),
            # overwrite=False accumulates updates at duplicate indices,
            # which turns the scatter into a histogram.
            overwrite=False) / s
        m_e = paddle.mean(F.softmax(gate_score, axis=1), axis=0)
        loss = paddle.mean(c_e * m_e) * (self.num_expert**2)
        self.set_loss(loss)

        # capacity[0] while training, capacity[1] in eval mode.
        cap_rate = self.capacity[0 if self.training else 1]
        capacity = math.ceil(cap_rate * x.shape[0])
        _new_lec, _new_gec, topk_idx = limit_by_capacity(
            topk_idx,
            self.num_expert,
            self.world_size,
            capacity,
            group=self.group)

        if self.random_routing:
            # One uniform sample per row of gate_score.
            rand_routing_prob = paddle.rand(shape=[s], dtype="float32")
            topk_idx = paddle.distributed.utils.random_routing(
                topk_idx, topk_val, rand_routing_prob)
        return topk_val, topk_idx