# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # The file has been adapted from the file: # https://github.com/laekov/fastmoe/blob/master/fmoe/layers.py # Git commit hash: 295a615aacce7e54a37e7935274ba15e901c78e4 # We retain the following license from the original files: # Copyright 2021, Jiaao He. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"). import collections import math import numpy as np import paddle import paddle.nn as nn import paddle.nn.functional as F from paddle.distributed.utils import global_scatter, global_gather from paddle.distributed import alltoall, all_gather from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.distributed import fleet from paddle.autograd import PyLayer from .gate import NaiveGate, GShardGate, SwitchGate, BaseGate from .utils import count_by_gate from paddle import fluid from paddle.fluid.framework import in_dygraph_mode from paddle.incubate.distributed.fleet import recompute_hybrid def _local_scatter(inp, pos): if pos.shape != [0]: inp_buf = paddle.index_select(inp, pos, 0) else: inp_buf = paddle.empty([0, inp.shape[1]], dtype=inp.dtype) return inp_buf def _local_gather(inp, pos, out_batch_size, maybe_overlap=True): if pos.shape != [0]: origin_dtype = inp.dtype inp = paddle.cast(inp, dtype="float32") inp_buf = paddle.scatter(paddle.zeros( shape=[out_batch_size, inp.shape[-1]], dtype="float32"), pos, inp, overwrite=True) inp_buf = paddle.cast(inp_buf, dtype=origin_dtype) else: inp_buf = paddle.zeros([out_batch_size, inp.shape[-1]], dtype=inp.dtype) return inp_buf def _all_gather(tensor, group=None, use_calc_stream=True): if group is not None and not group.is_member(): return if in_dygraph_mode(): group = paddle.distributed.collective._get_default_group( ) if group is None else group tensor_shape = list(tensor.shape) tensor_shape[0] *= group.nranks out = paddle.empty(tensor_shape, tensor.dtype) task = group.process_group.all_gather(tensor, out) task.wait() return out else: ring_id = 0 if group is None else group.id nranks = paddle.distributed.collective._get_global_group( ).nranks if group is None else group.nranks return paddle._legacy_C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'nranks', nranks) class MoEScatter(PyLayer): r""" Scatter input samples from [batch x sequences] to contiguous alone experts. If `world_size` is greater than 1, the samples will first be locally scattered, and then exchanged across workers. """ @staticmethod def forward(ctx, inp, pos, local_expert_count, global_expert_count, fwd_batch_size, world_size, group=None): local_input_buf = _local_scatter(inp, pos) if world_size > 1: global_input_buf = global_scatter(local_input_buf, local_expert_count, global_expert_count, group=group) else: global_input_buf = local_input_buf ctx.moe_args = inp.shape[0], world_size, group variables = (pos, local_expert_count, global_expert_count) ctx.save_for_backward(*variables) return global_input_buf @staticmethod def backward(ctx, grad): (pos, local_expert_count, global_expert_count) = ctx.saved_tensor() (inp_batch_size, world_size, group) = ctx.moe_args if world_size > 1: local_grad_in = global_gather(grad, local_expert_count, global_expert_count, group=group) else: local_grad_in = grad grad_in = _local_gather(local_grad_in, pos, inp_batch_size) return grad_in, None, None, None class MoEGather(PyLayer): r""" Gather output samples from contiguous alone experts back to [batch x sequences]. Works symmetrically with MoEScatter. """ @staticmethod def forward(ctx, global_output_buf, pos, local_expert_count, global_expert_count, local_batch_size, world_size, group=None): if world_size > 1: local_output_buf = global_gather(global_output_buf, local_expert_count, global_expert_count, group=group) else: local_output_buf = global_output_buf output = _local_gather(local_output_buf, pos, local_batch_size, maybe_overlap=False) ctx.moe_args = (global_output_buf.shape[0], world_size, group) variables = (pos, local_expert_count, global_expert_count) ctx.save_for_backward(*variables) return output @staticmethod def backward(ctx, grad_out): pos, local_expert_count, global_expert_count = ctx.saved_tensor() fwd_batch_size, world_size, group = ctx.moe_args grad_out_buf = _local_scatter(grad_out, pos) if world_size > 1: global_grad_out_buf = global_scatter(grad_out_buf, local_expert_count, global_expert_count, group=group) else: global_grad_out_buf = grad_out_buf return global_grad_out_buf, None, None, None class AllGather(PyLayer): r""" A wrapper for the All-Gather function to support auto-differentiation. """ @staticmethod def forward(ctx, inp, rank, world_size, group): tensor_list = [] paddle.distributed.all_gather(tensor_list, inp, group=group) output = paddle.concat(tensor_list, axis=0) ctx.args = rank, inp.shape[0] return output @staticmethod def backward(ctx, grad_out): rank, dim0 = ctx.args return paddle.slice(grad_out, axes=[0], starts=[rank * dim0], ends=[(rank + 1) * dim0]) class Slice(PyLayer): r""" A wrapper for the Slice function to support auto-differentiation. """ @staticmethod def forward(ctx, inp, rank, world_size, group): B = inp.shape[0] local_batch_size = B // world_size batch_start = local_batch_size * rank batch_end = min(batch_start + local_batch_size, B) inp = paddle.slice(inp, axes=[0], starts=[batch_start], ends=[batch_end]) ctx.args = world_size, group return inp @staticmethod def backward(ctx, grad_out): world_size, group = ctx.args return _all_gather(grad_out, group=group) def prepare_forward(gate, num_expert, world_size, moe_group): pos, local_expert_count, global_expert_count = count_by_gate( gate, num_expert, world_size, group=moe_group) with paddle.no_grad(): fwd_expert_count = global_expert_count.reshape_( [world_size, num_expert]).sum(axis=0) fwd_batch_size = int(fwd_expert_count.sum().item()) return ( pos, local_expert_count, global_expert_count, fwd_expert_count, fwd_batch_size, ) class MoELayer(nn.Layer): """MoE Layer Args: d_model: (int) model dimention experts: (nn.LayerList) expert networks list gate: (dict|NaiveGate|SwitchGate|NaiveGate): if gate is a dict: gate is a gate network config, containing 2 keys: `type`(str) value can be: "naive", "gshard", "switch" or None, default is "gshard" `top_k`(int) default value is 2 else gate is an instance of NaiveGate|SwitchGate|NaiveGate: moe_group: moe group for experts communication mp_group: mp group for mp commutication recompute_interval(int, optional): whether to use recompute, default 0, means to disable recompute. recompute_ctx(dict, optional): the context for recompute, if recompute_interval > 1, recompute_ctx must be given. Examples: .. code-block:: python from paddle.nn import layer, LayerList from paddle.distributed.moe import MoElayer from paddle.distributed.collective import Group from paddle.distributed import fleet moe_group = Group(fleet.worker_index(), fleet.worker_num(), 0, list(range(fleet.worker_num()))) mp_group = None num_experts=8 dim_feedforward=512 d_model=8 top_k=2 class ExpertLayer(Layer): def __init__(self, d_model, d_hidden, name=None,rank=0, windex = 0, num_expert=1): super(ExpertLayer, self).__init__() self.htoh4 = nn.Linear(d_model, d_hidden) self.h4toh = nn.Linear(d_hidden, d_model) def forward(self, x): x = self.htoh4(x) x = self.h4toh(x) return x gate_config = { "type": "gshard", "top_k": top_k, } experts_list = LayerList() for expi in range(num_experts): exp_layer = ExpertLayer(d_model, dim_feedforward // top_k, windex=expi, num_expert=num_experts) experts_list.append(exp_layer) moeLayer = MoELayer(d_model = d_model, experts=experts_list, gate=gate_config, moe_group=moe_group, mp_group=mp_group, recompute_interval=0) """ def __init__(self, d_model, experts, gate=None, moe_group=None, mp_group=None, recompute_interval=0, recompute_ctx=None): super(MoELayer, self).__init__() self.recompute_ctx = recompute_ctx if gate is None: gate = dict() assert isinstance(gate, (dict, BaseGate)), \ "gate config' type must be dict or an instance of BaseGate" # only support mp/dp self.group = moe_group self.world_size = 1 if self.group is not None: self.world_size = self.group.nranks self.num_expert = len(experts) self.recompute_interval = recompute_interval assert experts is not None self.experts = experts self.mp_group = mp_group self.d_model = d_model if isinstance(gate, dict): self.top_k = gate.get("top_k", 2) gate = gate.get("type", "gshard") if gate == "naive" or gate is None: gate = NaiveGate(self.d_model, num_expert=len(experts), world_size=self.world_size, topk=self.top_k) elif gate == "gshard": gate = GShardGate(self.d_model, num_expert=len(experts), world_size=self.world_size, topk=self.top_k, group=self.group) elif gate == "switch": gate = SwitchGate(self.d_model, num_expert=len(experts), world_size=self.world_size, topk=self.top_k, group=self.group) else: assert False, "We only support naive gate, \ gshard gate and switch gate, \ but you choose {} gate.".format(str(gate)) elif isinstance(gate, NaiveGate): self.top_k = gate.top_k elif isinstance(gate, BaseGate): raise TypeError("Unimplemented gate type: ", type(gate)) else: raise TypeError("gate's type must be either dict or moe.BaseGate") self.gate = gate def forward(self, inp): # inp shape: b * s * m assert len(inp.shape) == 3 origin_shape = inp.shape inp = inp.reshape_([-1, origin_shape[2]]) mp_rank = 0 mp_size = 1 if self.mp_group is not None: mp_rank = self.mp_group.rank mp_size = self.mp_group.nranks if mp_size > 1: inp = Slice.apply(inp, mp_rank, mp_size, self.mp_group) value, gate = self.gate(inp) ( pos, local_expert_count, global_expert_count, fwd_expert_count, fwd_batch_size, ) = prepare_forward(gate, self.num_expert, self.world_size, self.group) topk = 1 if len(gate.shape) == 2: topk = gate.shape[1] if pos.shape != [0]: temp_pos = pos // topk else: temp_pos = pos assert topk == self.top_k x = MoEScatter.apply(inp, temp_pos, local_expert_count, global_expert_count, fwd_batch_size, self.world_size, self.group) d_model = self.d_model def experts_fwd(x, fwd_expert_count, experts): if x.shape[0] == 0: return x y = [] last_index = 0 assert isinstance(fwd_expert_count, np.ndarray) assert len(experts) == len(fwd_expert_count) for idx, expert_count in enumerate(fwd_expert_count): if expert_count <= 0: continue y.append(experts[idx](x[last_index:expert_count + last_index])) last_index = expert_count + last_index return paddle.concat(y, axis=0) if self.recompute_interval <= 0 or x.shape[0] == 0: x = experts_fwd(x, fwd_expert_count.numpy(), self.experts) else: x = recompute_hybrid(self.recompute_ctx, experts_fwd, x, fwd_expert_count.numpy(), self.experts) out_batch_size = inp.shape[0] if len(gate.shape) == 2: out_batch_size *= gate.shape[1] x = MoEGather.apply(x, pos, local_expert_count, global_expert_count, out_batch_size, self.world_size, self.group) x = x.reshape([-1, self.top_k, d_model]) value = value.reshape([x.shape[0], 1, self.top_k]) x = paddle.bmm(value, x).reshape([-1, d_model]) if mp_size > 1: x = AllGather.apply(x, mp_rank, mp_size, self.mp_group) x = paddle.reshape_(x, origin_shape) return x