diff --git a/paddle/fluid/operators/limit_by_capacity_op.cc b/paddle/fluid/operators/limit_by_capacity_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..2b6e088deddc4ce693e39220789159e66e1fea96 --- /dev/null +++ b/paddle/fluid/operators/limit_by_capacity_op.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/limit_by_capacity_op.h" + +namespace paddle { +namespace operators { + +class LimitByCapacityOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("expert_count"), "Input", "expert_count", + "LimitByCapacity"); + OP_INOUT_CHECK(ctx->HasInput("capacity"), "Input", "capacity", + "LimitByCapacity"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "LimitByCapacity"); + + ctx->ShareDim("expert_count", "Out"); + ctx->ShareLoD("expert_count", "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + // the dtype of the expert_count and capacity should be same as int64 + auto expert_count_dtype = + OperatorWithKernel::IndicateVarDataType(ctx, "expert_count"); + auto capacity_dtype = + OperatorWithKernel::IndicateVarDataType(ctx, "capacity"); + + PADDLE_ENFORCE_EQ( + expert_count_dtype, capacity_dtype, + platform::errors::InvalidArgument( + "The dtype of the expert_count and capacity should be same")); + + PADDLE_ENFORCE_EQ( + expert_count_dtype, framework::proto::VarType::INT64, + platform::errors::InvalidArgument("The dtype of the expert_count and " + "capacity should be same as int64")); + return framework::OpKernelType(expert_count_dtype, ctx.GetPlace()); + } +}; + +class LimitByCapacityOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("expert_count", "(Tensor) The input expert count tensor."); + AddInput("capacity", "(Tensor) The input capacity."); + AddOutput("Out", + "(Tensor) The output tensor expert count limit by capacity."); + AddAttr("n_worker", "(int), The number of works."); + AddComment( + R"DOC(limit_by_capacity Operator.limit expert count by capacity.)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CPU_KERNEL(limit_by_capacity, ops::LimitByCapacityOpCPUKernel, + ops::LimitByCapacityOpCPUKernel); + +REGISTER_OP_WITHOUT_GRADIENT(limit_by_capacity, ops::LimitByCapacityOp, + ops::LimitByCapacityOpMaker); diff --git a/paddle/fluid/operators/limit_by_capacity_op.cu b/paddle/fluid/operators/limit_by_capacity_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..ebc6d1a927c57d61e12a3e3aa0f9b699bbbc5920 --- /dev/null +++ b/paddle/fluid/operators/limit_by_capacity_op.cu @@ -0,0 +1,83 @@ +// Copyright (c) 2022 
PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/limit_by_capacity_op.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +#define CEIL(_x_, _y_) (((_x_)-1) / (_y_) + 1) + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + +template +__global__ void limit_by_capacity_impl(const T* expc, T* cap, T* out, + const int n_expert, const int n_worker) { + int eid = blockIdx.y; + int wid = blockIdx.x * blockDim.x + threadIdx.x; + if (wid < n_worker) { + auto proposal = expc[wid * n_expert + eid]; + // int cap_left = atomicSub(cap + eid, proposal); + auto cap_left = paddle::platform::CudaAtomicAdd(cap + eid, proposal * (-1)); + if (cap_left >= proposal) { + out[wid * n_expert + eid] = proposal; + } else if (cap_left >= 0) { + out[wid * n_expert + eid] = cap_left; + } else { + out[wid * n_expert + eid] = 0; + } + } +} + +template +class LimitByCapacityOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto expert_count = context.Input("expert_count"); + auto capacity = context.Input("capacity"); + auto n_worker = context.Attr("n_worker"); + auto out = context.Output("Out"); + + auto n_expert = expert_count->numel() / n_worker; + // std::cout << "n_expert" << n_expert << std::endl; + const auto place = context.GetPlace(); + const auto& dev_ctx = + context.template device_context(); + + dim3 grid_dim(CEIL(n_worker, 1024), n_expert); + dim3 block_dim(1024); + auto out_data = out->mutable_data(place); + const T* ec_data = expert_count->data(); + + framework::Tensor capacity_copy; + framework::TensorCopy(*capacity, place, dev_ctx, &capacity_copy); + T* cap_data = capacity_copy.mutable_data(place); + + limit_by_capacity_impl<<>>( + ec_data, cap_data, out_data, n_expert, n_worker); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL(limit_by_capacity, + ops::LimitByCapacityOpCUDAKernel); diff --git a/paddle/fluid/operators/limit_by_capacity_op.h b/paddle/fluid/operators/limit_by_capacity_op.h new file mode 100644 index 0000000000000000000000000000000000000000..c76d298f4298216b74b3c580e7b6dcca72480d52 --- /dev/null +++ b/paddle/fluid/operators/limit_by_capacity_op.h @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +#if defined(PADDLE_WITH_GLOO) +#include "paddle/fluid/framework/fleet/gloo_wrapper.h" +#endif + +namespace paddle { +namespace operators { + +template +class LimitByCapacityOpCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_THROW(platform::errors::Unavailable( + "Do not support limit by capacity op for cpu kernel now.")); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/prune_gate_by_capacity_op.cc b/paddle/fluid/operators/prune_gate_by_capacity_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..091b33884bfbaadedf70ed98f933bb6485f33dd6 --- /dev/null +++ b/paddle/fluid/operators/prune_gate_by_capacity_op.cc @@ -0,0 +1,123 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/prune_gate_by_capacity_op.h" + +namespace paddle { +namespace operators { + +class PruneGateByCapacityOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("GateIdx"), "Input", "GateIdx", + "prun_gate_by_capacity"); + OP_INOUT_CHECK(ctx->HasInput("ExpertCount"), "Input", "ExpertCount", + "prun_gate_by_capacity"); + + OP_INOUT_CHECK(ctx->HasOutput("NewGateIdx"), "Output", "NewGateIdx", + "prun_gate_by_capacity"); + // OP_INOUT_CHECK(ctx->HasOutput("ExpertCountOut"), "Output", + // "ExpertCountOut", + // "prun_gate_by_capacity"); + // auto gate_idx_dims = ctx->GetInputDim("GateIdx"); + auto expert_count_dims = ctx->GetInputDim("ExpertCount"); + + int64_t n_expert = ctx->Attrs().Get("n_expert"); + int64_t n_worker = ctx->Attrs().Get("n_worker"); + + int64_t expert_count_num_ele = 1; + for (int64_t i = 0; i < expert_count_dims.size(); i++) { + expert_count_num_ele *= expert_count_dims[i]; + } + + PADDLE_ENFORCE_EQ( + expert_count_num_ele, n_expert * n_worker, + platform::errors::Unavailable( + "The number of elements for expert_count is ( %ld ) incorrect. " + "Because the number of expert_count must equal the " + "product of n_worker ( %ld ) and n_expert ( %ld ). 
" + "Please input appropriate expert_count again!", + expert_count_num_ele, n_worker, n_expert)); + + auto gate_idx_in_dims = ctx->GetInputDim("GateIdx"); + // auto expert_count_in_dims = ctx->GetInputDim("ExpertCount"); + ctx->SetOutputDim("NewGateIdx", gate_idx_in_dims); + // ctx->SetOutputDim("ExpertCountOut", expert_count_in_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto gate_idx_data_type = + OperatorWithKernel::IndicateVarDataType(ctx, "GateIdx"); + auto expert_count_data_type = + OperatorWithKernel::IndicateVarDataType(ctx, "ExpertCount"); + PADDLE_ENFORCE_EQ( + gate_idx_data_type, expert_count_data_type, + platform::errors::InvalidArgument( + "The dtype of the gate_idx and expert_count should be same")); + PADDLE_ENFORCE_EQ(gate_idx_data_type, framework::proto::VarType::INT64, + platform::errors::InvalidArgument( + "The dtype of the gate_idx and expert_count should " + "be same as int64")); + return framework::OpKernelType(gate_idx_data_type, ctx.device_context()); + } +}; + +class PruneGateByCapacityOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("GateIdx", + "(Tensor), The gate_id sequence corresponding to the input data."); + AddInput("ExpertCount", + "(Tensor), The quantity value counted on the gate_id sequence of " + "the input data."); + AddAttr("n_expert", "The number of Experts on each worker") + .SetDefault(0); + AddAttr("n_worker", "The number of workers on the trainer") + .SetDefault(0); + + AddOutput("NewGateIdx", + "(Tensor), The gate_id sequence corresponding to the new input " + "data after passing through prune."); + // AddOutput( + // "ExpertCountOut", + // "(Tensor), The copy quantity value counted on the gate_id sequence of + // " + // "the input data."); + + AddComment(R"DOC( +prune_gate_by_capacity Operator. + +This operator is used to prune gate by capacity(CUDA). + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_WITHOUT_GRADIENT(prune_gate_by_capacity, ops::PruneGateByCapacityOp, + ops::PruneGateByCapacityOpMaker); + +REGISTER_OP_CPU_KERNEL( + prune_gate_by_capacity, + ops::PruneGateByCapacityCPUKernel, + ops::PruneGateByCapacityCPUKernel); diff --git a/paddle/fluid/operators/prune_gate_by_capacity_op.cu b/paddle/fluid/operators/prune_gate_by_capacity_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..953847512bc1a775e9475d9419b475ebeaf5e569 --- /dev/null +++ b/paddle/fluid/operators/prune_gate_by_capacity_op.cu @@ -0,0 +1,123 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/operators/prune_gate_by_capacity_op.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +DECLARE_bool(avoid_op_randomness); + +namespace paddle { +namespace operators { +using LoDTensor = framework::LoDTensor; + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaxinumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaxinumNumBlocks); +} + +template +__global__ void prune_gate_by_capacity_kernel(const T1* gate_idx_data, + T1* new_gate_idx_data, + T2* expert_count_data, + const int64_t batch_size) { + CUDA_KERNEL_LOOP(i, batch_size) { + auto orig_cap = + platform::CudaAtomicAdd(expert_count_data + gate_idx_data[i], -1); + if (orig_cap <= 0) { + new_gate_idx_data[i] = -1; + } else { + new_gate_idx_data[i] = gate_idx_data[i]; + } + } +} + +template +class PruneGateByCapacityFunctor { + public: + PruneGateByCapacityFunctor(const framework::ExecutionContext& context, + const framework::LoDTensor* gate_idx, + framework::LoDTensor* expert_count_out, + T1* new_gate_idx_data) + : context_(context), + gate_idx_(gate_idx), + expert_count_out_(expert_count_out), + new_gate_idx_data_(new_gate_idx_data) {} + + template + void apply() { + auto batch_size = gate_idx_->numel(); + auto* gate_idx_data = gate_idx_->data(); + + auto& dev_ctx = context_.template device_context(); + auto* expert_count_out_data = expert_count_out_->data(); + + int blocks = NumBlocks(batch_size); + int threads = kNumCUDAThreads; + + prune_gate_by_capacity_kernel<<>>( + gate_idx_data, new_gate_idx_data_, expert_count_out_data, batch_size); + } + + private: + const framework::ExecutionContext context_; + const framework::LoDTensor* gate_idx_; + framework::LoDTensor* expert_count_out_; + T1* new_gate_idx_data_; +}; + +template +static void VisitDataType(paddle::experimental::DataType type, + Visitor visitor) { + if (type == paddle::experimental::DataType::INT64) { + visitor.template apply(); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The recieved values gate_id type %s can not meet input requirements. " + "Because the given gate_id data type of operators must be " + "int64. Please input appropriate gate_id again! 
", + "framework::DataTypeToString(type)")); + } +} + +template +class PruneGateByCapacityCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* gate_idx = context.Input("GateIdx"); + auto* expert_count = context.Input("ExpertCount"); + // auto* expert_count_out = context.Output("ExpertCountOut"); + auto* new_gate_idx = context.Output("NewGateIdx"); + auto* new_gate_idx_data = new_gate_idx->mutable_data(context.GetPlace()); + + framework::LoDTensor expert_count_out; + framework::TensorCopy(*expert_count, context.GetPlace(), &expert_count_out); + PruneGateByCapacityFunctor functor( + context, gate_idx, &expert_count_out, new_gate_idx_data); + VisitDataType(expert_count->type(), functor); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_CUDA_KERNEL( + prune_gate_by_capacity, + ops::PruneGateByCapacityCUDAKernel); diff --git a/paddle/fluid/operators/prune_gate_by_capacity_op.h b/paddle/fluid/operators/prune_gate_by_capacity_op.h new file mode 100644 index 0000000000000000000000000000000000000000..d7a00bd40d786f669f2d8d0cca68938b7285ac5f --- /dev/null +++ b/paddle/fluid/operators/prune_gate_by_capacity_op.h @@ -0,0 +1,33 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { + +template +class PruneGateByCapacityCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "prune_gate_by_capacity is not supported on CPU.")); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/random_routing_op.cc b/paddle/fluid/operators/random_routing_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..5e02ea39745b324be611b13b8f7c42080e49ce9b --- /dev/null +++ b/paddle/fluid/operators/random_routing_op.cc @@ -0,0 +1,96 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/operators/random_routing_op.h" + +namespace paddle { +namespace operators { + +class RandomRoutingOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Prob"), "Input", "Porb", "RandomRouting"); + OP_INOUT_CHECK(ctx->HasInput("TopK_Value"), "Input", "TopKValue", + "RandomRouting"); + OP_INOUT_CHECK(ctx->HasInput("TopK_Idx"), "Input", "TopKIdx", + "RandomRouting"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "RandomRouting"); + + // check dims + + auto topk_val_dims = ctx->GetInputDim("TopK_Value"); + auto prob_dims = ctx->GetInputDim("Prob"); + auto topk_idx_dims = ctx->GetInputDim("TopK_Idx"); + + PADDLE_ENFORCE_EQ(prob_dims[0], topk_val_dims[0], + platform::errors::InvalidArgument( + "Output(Out) of ScatterNdAddOp should not be null.")); + + PADDLE_ENFORCE_EQ(topk_idx_dims[1], topk_val_dims[1], + platform::errors::InvalidArgument( + "Output(Out) of ScatterNdAddOp should not be null.")); + + PADDLE_ENFORCE_EQ(topk_idx_dims[0], topk_val_dims[0], + platform::errors::InvalidArgument( + "Output(Out) of ScatterNdAddOp should not be null.")); + + ctx->SetOutputDim("Out", topk_idx_dims); + ctx->ShareLoD("TopK_Idx", /*->*/ "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + // the dtype of the gate_idx should be same as int64 + const auto topk_idx_dtype = + OperatorWithKernel::IndicateVarDataType(ctx, "TopK_Idx"); + PADDLE_ENFORCE_EQ(topk_idx_dtype, framework::proto::VarType::INT64, + platform::errors::InvalidArgument( + "The dtype of the topk_idx_dtype should be int64")); + + const auto& topk_value_type = + OperatorWithKernel::IndicateVarDataType(ctx, "TopK_Value"); + return framework::OpKernelType(topk_value_type, ctx.GetPlace()); + } +}; + +class RandomRoutingOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Prob", "(Tensor) The input Prob index tensor."); + AddInput("TopK_Value", "(Tensor) The input TopK_Value index tensor."); + AddInput("TopK_Idx", "(Tensor) The input TopK_Idx index tensor."); + AddOutput("Out", "(Tensor) The output random routing tensor."); + AddComment(R"DOC(expert_count Operator random routing.)DOC"); + } +}; + +DECLARE_INPLACE_OP_INFERER(RandomRoutingInplaceInferer, {"TopK_Idx", "Out"}); + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OPERATOR( + random_routing, ops::RandomRoutingOp, ops::RandomRoutingOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker, + ops::RandomRoutingInplaceInferer) + +REGISTER_OP_CPU_KERNEL(random_routing, ops::RandomRoutingOpCPUKernel, + ops::RandomRoutingOpCPUKernel); diff --git a/paddle/fluid/operators/random_routing_op.cu b/paddle/fluid/operators/random_routing_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..fec65518a9d4851128e1ceb74b415971a526dda2 --- /dev/null +++ b/paddle/fluid/operators/random_routing_op.cu @@ -0,0 +1,88 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/random_routing_op.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +#define CEIL(_x_, _y_) (((_x_)-1) / (_y_) + 1) +#define PERTHREAD_EXPERTS 256 +#define WARP_SIZE 32 + +const int CUDA_NUM_THREADS = 512; +static inline int GET_BLOCKS(const int N) { + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + +template +__global__ void random_routing_kernel(int64_t* data, const int64_t length, + const size_t N, const size_t D, + const T* prob, const int64_t* topk_idx, + const T* topk_value) { + CUDA_KERNEL_LOOP(idx, length) { + size_t row = idx / D; + size_t col = idx % D; + if (col != 1) return; + if (static_cast(2) * topk_value[idx] < prob[row]) { + data[idx] = static_cast(-1); + } + } +} + +template +class RandomRoutingOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto topk_idx = context.Input("TopK_Idx"); + auto topk_value = context.Input("TopK_Value"); + auto prob = context.Input("Prob"); + auto out = context.Output("Out"); + + auto place = context.GetPlace(); + const auto& dev_ctx = + context.template device_context(); + framework::TensorCopy(*topk_idx, place, out); + + size_t N = topk_idx->dims()[0]; + size_t D = topk_idx->dims()[1]; + + int64_t num_idx = topk_idx->numel(); + + auto prob_data = prob->data(); + auto topk_value_data = topk_value->data(); + auto topk_idx_data = topk_idx->data(); + auto out_data = out->data(); + + random_routing_kernel< + T><<>>( + out_data, num_idx, N, D, prob_data, topk_idx_data, topk_value_data); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL(random_routing, ops::RandomRoutingOpCUDAKernel, + ops::RandomRoutingOpCUDAKernel, + ops::RandomRoutingOpCUDAKernel); diff --git a/paddle/fluid/operators/random_routing_op.h b/paddle/fluid/operators/random_routing_op.h new file mode 100644 index 0000000000000000000000000000000000000000..c4e0ffaa78434466d6c66a868c704f93998a6d10 --- /dev/null +++ b/paddle/fluid/operators/random_routing_op.h @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
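The kernel above only touches the second column of the top-2 result. A NumPy sketch of the same rule, assuming topk == 2 as the Python wrapper later in this change also requires; the helper name is illustrative.

```python
import numpy as np

def random_routing_ref(topk_idx, topk_value, prob):
    # keep the top-1 expert; drop the top-2 expert whenever twice its gate
    # value falls below the token's random threshold.
    out = topk_idx.copy()
    drop = 2.0 * topk_value[:, 1] < prob
    out[drop, 1] = -1
    return out

idx = np.array([[0, 3], [2, 1]])
val = np.array([[0.6, 0.1], [0.5, 0.4]])
print(random_routing_ref(idx, val, np.array([0.3, 0.3])))
# row 0: 2 * 0.1 < 0.3, so its second expert becomes -1; row 1 is kept
```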
+ +#pragma once +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +#if defined(PADDLE_WITH_GLOO) +#include "paddle/fluid/framework/fleet/gloo_wrapper.h" +#endif + +namespace paddle { +namespace operators { + +template +class RandomRoutingOpCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_THROW(platform::errors::Unavailable( + "Do not support expert count op for cpu kernel now.")); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/distributed/models/moe/utils.py b/python/paddle/distributed/models/moe/utils.py index ffc4a1c637c17bf4be1e4fe228c5b1c6f40bea73..6fb6a5ca32b3c3d556262f2c744fa5c7b557b8d0 100644 --- a/python/paddle/distributed/models/moe/utils.py +++ b/python/paddle/distributed/models/moe/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ from paddle.fluid import core from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.framework import _non_static_mode +from paddle.fluid.data_feeder import check_variable_and_dtype def _number_count(numbers, upper_range): @@ -41,7 +42,6 @@ def _number_count(numbers, upper_range): """ if _non_static_mode(): return core.ops.number_count(numbers, 'upper_range', upper_range) - else: op_type = 'number_count' @@ -103,3 +103,113 @@ def _assign_pos(x, cum_count): }, outputs={'Out': [out]}) return out + + +def _random_routing(topk_idx, topk_value, prob, topk=2): + r""" + random routing topk gate idx + ``` + out = topk_idx + for i in len(topk_idx): + if topk * value[i][topk-1] < prob[i]: + out[i][topk-1] = -1 + ``` + Args: + topk_idx: gate idx, shape=(N, topk) + topk_value: values, shape = topk_idx.shape + prob: random prob, shape=(topk_idx.shape[0],) + """ + if topk == 2: + if _non_static_mode(): + return core.ops.random_routing(prob, topk_value, topk_idx) + else: + raise RuntimeError("Not supporting static mode now") + else: + raise RuntimeError("only topk=2 is supported now") + + +def _limit_by_capacity(expert_count, capacity, n_worker): + """ + limit the expert count by capacity. + Args: + expert_count (Tensor): Tensor. The input expert count whose data type should be int32 or int64. + capacity (Tensor): Tensor. The input capacity whose data type should be int32 or int64 and the elements of capacity should be the same with expert_count.numel()/n_work. + n_work (int): The number of the works. + Returns: + out (Tensor): The output expert count limit by capacity. + Examples: + .. 
code-block:: python + # required: distributed + import paddle + expert_count = [1, 2, 2, 8, 3, 6] + capacity = [5, 5, 5] + n_work = 2 + expert_count = paddle.to_tensor(expert_count, dtype="int32") + capacity = paddle.to_tensor(capacity, dtype="int32") + out = paddle.distributed.utils.limit_by_capacity(expert_count, capacity, n_work) + print(out) # the result: [1, 2, 2, 4, 3, 3] + """ + if _non_static_mode(): + return core.ops.limit_by_capacity(expert_count, capacity, 'n_worker', + n_worker) + else: + op_type = 'limit_by_capacity' + + helper = LayerHelper(op_type, **locals()) + out = helper.create_variable_for_type_inference( + dtype=expert_count.dtype) + + helper.append_op( + type=op_type, + inputs={'expert_count': expert_count, + 'capacity': capacity}, + outputs={'Out': out}, + attrs={'n_worker': n_worker}) + return out + + +def _prune_gate_by_capacity(gate_idx, expert_count, n_expert, n_worker): + """ + prune gate by capacity(only support CUDA) + + Args: + gate_idx (Tensor): Represents the gate_id sequence corresponding to the input data with type int32, int64. + expert_count (Tensor): The quantity value counted on the gate_id sequence of the input data with type int32, int64. + n_worker(int,optional): The number of workers on the trainer with type int64. + + Returns: + new_gate_idx (Tensor): The gate_id sequence corresponding to the new input data after passing through prune. + + Examples: + .. code-block:: python + + import paddle + gate_idx = paddle.to_tensor([1, 3, 3, 3, 3, 2, 1, 1], dtype='int32') + expert_count = paddle.to_tensor([0, 3, 1, 3, 0, 0, 0, 0], dtype='int32') + n_worker = 1 + new_gate_id = paddle.distributed.utils.prune_gate_by_capacity(gate_idx, expert_count, n_expert, n_worker) + print(new_gate_id) + # Tensor(shape=[8], dtype=int32, place=CUDAPlace(0), stop_gradient=True, + [1, 3, 3, 3, -1, 2, 1, 1]) + """ + + if _non_static_mode(): + return core.ops.prune_gate_by_capacity( + gate_idx, expert_count, "n_expert", n_expert, "n_worker", n_worker) + check_variable_and_dtype(gate_idx, 'GateIdx', ['int32', 'int64'], + 'paddle.distributed.utils.prune_gate_by_capacity') + check_variable_and_dtype(expert_count, 'ExpertCount', ['int32', 'int64'], + 'paddle.distributed.utils.prune_gate_by_capacity') + + helper = LayerHelper('prune_gate_by_capacity', **locals()) + new_gate_idx = helper.create_variable_for_type_inference( + dtype=gate_idx.dtype) + helper.append_op( + type='prune_gate_by_capacity', + inputs={'GateIdx': gate_idx, + "ExpertCount": expert_count}, + outputs={'NewGateIdx': new_gate_idx}, + attrs={"n_expert": n_expert, + "n_worker": n_worker}) + + return new_gate_idx diff --git a/python/paddle/fluid/tests/unittests/test_limit_by_capacity_op.py b/python/paddle/fluid/tests/unittests/test_limit_by_capacity_op.py new file mode 100644 index 0000000000000000000000000000000000000000..e5ec67d41f7efa9835d8c8ccc19a03357e18878f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_limit_by_capacity_op.py @@ -0,0 +1,99 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import numpy as np +from paddle.distributed.models.moe import utils +from paddle.fluid import core + + +def limit_by_capacity(expert_count, _capacity, n_worker): + capacity = np.copy(_capacity) + old_shape = expert_count.shape + expert_count = np.reshape(expert_count, (n_worker, len(capacity))) + output = np.zeros_like(expert_count) + for wid in range(len(expert_count)): + for eid in range(len(expert_count[wid])): + last_cap = capacity[eid] + if last_cap >= 0: + capacity[eid] -= expert_count[wid][eid] + if last_cap >= expert_count[wid][eid]: + output[wid][eid] = expert_count[wid][eid] + elif last_cap >= 0: + output[wid][eid] = last_cap + return output.reshape(old_shape) + + +def all_close(exp, out, n_worker): + exp = exp.reshape(n_worker, -1) + out = out.reshape(n_worker, -1) + return np.allclose(exp.sum(0), out.sum(0)) + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestLimitByCapacityInt64API(unittest.TestCase): + def init_test_case(self): + self.expert_count = np.random.randint( + 0, 1000, size=(len(self.capacity) * self.n_worker)) + self.out = limit_by_capacity(self.expert_count, self.capacity, + self.n_worker) + self.expert_count = self.expert_count.astype("int64") + self.capacity = self.capacity.astype("int64") + self.place = paddle.CUDAPlace(0) + + def setUp(self): + self.capacity = np.array([100, 12000, 1200, 800, 4700, 10000, 57, 99]) + self.n_worker = 1024 * 8 + self.init_test_case() + + def test_static_api(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + capacity = paddle.static.data( + 'capacity', shape=self.capacity.shape, dtype="int64") + expert_count_tensor = paddle.static.data( + 'ExpertCount', shape=self.expert_count.shape, dtype="int64") + out = utils._limit_by_capacity(expert_count_tensor, capacity, + self.n_worker) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={ + 'capacity': self.capacity, + 'ExpertCount': self.expert_count, + }, + fetch_list=out) + + assert all_close(self.out, res[0], self.n_worker) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + capacity = paddle.to_tensor(self.capacity) + expert_count_tensor = paddle.to_tensor(self.expert_count) + out = utils._limit_by_capacity(expert_count_tensor, capacity, + self.n_worker) + assert all_close(self.out, out.numpy(), self.n_worker) + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestLimitByCapacityInt64API_SmallWorker(TestLimitByCapacityInt64API): + def setUp(self): + self.capacity = np.array([100, 12000, 1200, 0, 4700, 1000, 57, 200]) + self.n_worker = 1 + self.init_test_case() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_prune_gate_by_capacity_op.py b/python/paddle/fluid/tests/unittests/test_prune_gate_by_capacity_op.py new file mode 100644 index 0000000000000000000000000000000000000000..d9d110f45ff79cd654a8a812a219bf4f40f93e61 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_prune_gate_by_capacity_op.py @@ -0,0 +1,125 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import numpy as np +from paddle.distributed.models.moe import utils +from paddle.fluid import core + + +def count(x, upper_num): + res = np.zeros((upper_num, )).astype(int) + for i in x.reshape(-1): + if i >= 0 and i < len(res): + res[i] += 1 + return res + + +def limit_by_capacity(expert_count, _capacity, n_worker): + capacity = np.copy(_capacity) + old_shape = expert_count.shape + expert_count = np.reshape(expert_count, (n_worker, len(capacity))) + output = np.zeros_like(expert_count) + for wid in range(len(expert_count)): + for eid in range(len(expert_count[wid])): + last_cap = capacity[eid] + if last_cap >= 0: + capacity[eid] -= expert_count[wid][eid] + if last_cap >= expert_count[wid][eid]: + output[wid][eid] = expert_count[wid][eid] + elif last_cap >= 0: + output[wid][eid] = last_cap + return output.reshape(old_shape) + + +def prune_gate_by_capacity(gate_idx, expert_count, n_expert, n_worker): + new_gate_idx = np.copy(gate_idx) + expert_count = np.copy(expert_count) + for i in range(len(gate_idx)): + idx = gate_idx[i] + last_cap = expert_count[idx] + if last_cap > 0: + expert_count[idx] -= 1 + else: + new_gate_idx[i] = -1 + return new_gate_idx + + +def assert_allclose(output, expected, n_expert): + c1 = count(output, n_expert) + c2 = count(expected, n_expert) + assert np.allclose(c1, c2) + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestPruneGateByCapacityAPI1(unittest.TestCase): + def init_test_case(self): + self.gate_idx = np.random.randint( + 0, self.n_expert, size=(200, )).astype(self.dtype) + expert_count = count(self.gate_idx, self.n_expert * self.n_worker) + capacity = np.random.randint(10, 200, size=(self.n_expert, )) + self.expert_count = limit_by_capacity(expert_count, capacity, + self.n_worker).astype(self.dtype) + self.out = prune_gate_by_capacity(self.gate_idx, self.expert_count, + self.n_expert, + self.n_worker).astype(self.dtype) + self.place = paddle.CUDAPlace(0) + + def setUp(self): + self.n_expert = 24 + self.n_worker = 2 + self.dtype = "int64" + self.init_test_case() + + def test_static_api(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + gate_idx_tensor = paddle.static.data( + 'GateIdx', shape=self.gate_idx.shape, dtype="int64") + expert_count_tensor = paddle.static.data( + 'ExpertCount', shape=self.expert_count.shape, dtype="int64") + out = utils._prune_gate_by_capacity(gate_idx_tensor, + expert_count_tensor, + self.n_expert, self.n_worker) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={ + 'GateIdx': self.gate_idx, + 'ExpertCount': self.expert_count, + }, + fetch_list=out) + assert_allclose(res[0], self.out, self.n_expert) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + gate_idx_tensor = paddle.to_tensor(self.gate_idx) + expert_count_tensor = paddle.to_tensor(self.expert_count) + out = utils._prune_gate_by_capacity( + gate_idx_tensor, expert_count_tensor, self.n_expert, self.n_worker) + assert_allclose(out.numpy(), self.out, self.n_expert) + + +@unittest.skipIf(not 
core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestPruneGateByCapacityAPI2(TestPruneGateByCapacityAPI1): + def setUp(self): + self.n_expert = 12 + self.n_worker = 1 + self.dtype = "int64" + self.init_test_case() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_random_routing_op.py b/python/paddle/fluid/tests/unittests/test_random_routing_op.py new file mode 100644 index 0000000000000000000000000000000000000000..dc8f6f5fcec153d19ebfac3d3d72df86fcacbc94 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_random_routing_op.py @@ -0,0 +1,77 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import op_test +import numpy as np +import unittest +import paddle +import paddle.fluid.core as core +from paddle.fluid.op import Operator +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard +from paddle.fluid.backward import append_backward +from paddle.distributed.models.moe import utils + + +def random_routing(topk_idx, topk_value, prob, topk=2): + if topk == 2: + new_topk_idx = np.copy(topk_idx) + for i in range(len(topk_idx)): + val = topk_value[i][1] + if val * 2 < prob[i]: + new_topk_idx[i][1] = -1 + return new_topk_idx + else: + raise RuntimeError("only topk=2 is supported now") + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestNumberCountAPIFp32(unittest.TestCase): + def setUp(self): + self.dtype = "float32" + self.init() + + def init(self): + self.upper_range = 8 + self.x = np.random.randint( + -1, self.upper_range, size=(200, 2)).astype('int64') + self.prob = np.random.random((self.x.shape[0], )).astype(self.dtype) + self.topk_value = np.random.random(self.x.shape).astype(self.dtype) + self.out = random_routing(self.x, self.topk_value, + self.prob).astype(self.dtype) + self.place = paddle.CUDAPlace(0) + + def test_api_dygraph(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + value = paddle.to_tensor(self.topk_value) + prob = paddle.to_tensor(self.prob) + out = utils._random_routing(x, value, prob) + assert np.allclose(out.numpy(), self.out) + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestNumberCountAPIFp16(TestNumberCountAPIFp32): + def setUp(self): + self.dtype = "float16" + self.init() + + +if __name__ == '__main__': + paddle.enable_static() + unittest.main() diff --git a/python/paddle/incubate/distributed/models/moe/__init__.py b/python/paddle/incubate/distributed/models/moe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e1663029ef1f844676ce9484f724dc253d625386 --- /dev/null +++ b/python/paddle/incubate/distributed/models/moe/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/paddle/incubate/distributed/models/moe/gate/__init__.py b/python/paddle/incubate/distributed/models/moe/gate/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd9205f7c144da51fb66102b7d75da13f11f659 --- /dev/null +++ b/python/paddle/incubate/distributed/models/moe/gate/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .gshard_gate import GShardGate +from .switch_gate import SwitchGate +from .naive_gate import NaiveGate +from .base_gate import BaseGate diff --git a/python/paddle/incubate/distributed/models/moe/gate/base_gate.py b/python/paddle/incubate/distributed/models/moe/gate/base_gate.py new file mode 100644 index 0000000000000000000000000000000000000000..046051f6b6adbd8358c09bea9f7eb72eac88544d --- /dev/null +++ b/python/paddle/incubate/distributed/models/moe/gate/base_gate.py @@ -0,0 +1,36 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.nn as nn + + +class BaseGate(nn.Layer): + def __init__(self, num_expert, world_size): + super().__init__() + self.world_size = world_size + self.num_expert = num_expert + self.tot_expert = world_size * num_expert + self.loss = None + + def forward(self, x): + raise NotImplementedError("Please implement the forward function.") + + def set_loss(self, loss): + self.loss = loss + + def get_loss(self, clear=True): + loss = self.loss + if clear: + self.loss = None + return loss diff --git a/python/paddle/incubate/distributed/models/moe/gate/gshard_gate.py b/python/paddle/incubate/distributed/models/moe/gate/gshard_gate.py new file mode 100644 index 0000000000000000000000000000000000000000..ea441263790aba02bf495e1e7e432cc022d7db1d --- /dev/null +++ b/python/paddle/incubate/distributed/models/moe/gate/gshard_gate.py @@ -0,0 +1,67 @@ +# Copyright (c) 2021 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import paddle +import paddle.nn.functional as F +import numpy as np +from .naive_gate import NaiveGate +from ..utils import limit_by_capacity + + +class GShardGate(NaiveGate): + def __init__(self, + d_model, + num_expert, + world_size, + topk=2, + capacity=(1.2, 2.4), + random_routing=True, + group=None): + assert topk == 2, "topk should be 2 in gshard" + super().__init__(d_model, num_expert, world_size) + self.capacity = capacity + self.random_routing = random_routing + self.group = group + + def forward(self, x): + topk_val, topk_idx, gate_score = super().forward( + x, return_all_scores=True) + s = gate_score.shape[0] + top1_idx = topk_idx.flatten() + c_e = paddle.scatter( + paddle.zeros(shape=[self.tot_expert]), + top1_idx, + paddle.ones_like( + top1_idx, dtype="float32"), + overwrite=False) / s + m_e = paddle.mean(F.softmax(gate_score, axis=1), axis=0) + loss = paddle.mean(c_e * m_e) * (self.num_expert**2) + self.set_loss(loss) + + cap_rate = self.capacity[0 if self.training else 1] + capacity = math.ceil(cap_rate * x.shape[0]) + _new_lec, _new_gec, topk_idx = limit_by_capacity( + topk_idx, + self.num_expert, + self.world_size, + capacity, + group=self.group) + + if self.random_routing: + rand_routing_prob = paddle.rand( + shape=[gate_score.shape[0]], dtype="float32") + topk_idx = paddle.distributed.utils.random_routing( + topk_idx, topk_val, rand_routing_prob) + return topk_val, topk_idx diff --git a/python/paddle/incubate/distributed/models/moe/gate/naive_gate.py b/python/paddle/incubate/distributed/models/moe/gate/naive_gate.py new file mode 100644 index 0000000000000000000000000000000000000000..ac6164ceace8b153326608df8dfb2525d110fb5a --- /dev/null +++ b/python/paddle/incubate/distributed/models/moe/gate/naive_gate.py @@ -0,0 +1,37 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
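For clarity, a NumPy reading of the balancing loss computed in GShardGate.forward above. This is a sketch of my interpretation, not the Paddle code itself, and it assumes world_size == 1 so that tot_expert equals num_expert.

```python
import numpy as np

def gshard_balance_loss(gate_score, topk_idx, num_expert):
    # c_e: how often each expert appears among the flattened top-k picks,
    #      normalized by the number of tokens
    # m_e: mean softmax probability assigned to each expert
    s = gate_score.shape[0]
    probs = np.exp(gate_score)
    probs = probs / probs.sum(axis=1, keepdims=True)
    c_e = np.bincount(topk_idx.reshape(-1), minlength=num_expert) / s
    m_e = probs.mean(axis=0)
    return (c_e * m_e).mean() * num_expert ** 2
```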
+ +from .base_gate import BaseGate + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +class NaiveGate(BaseGate): + def __init__(self, d_model, num_expert, world_size, topk=2): + super().__init__(num_expert, world_size) + self.gate = nn.Linear(d_model, self.tot_expert) + self.gate.weight.name = "gate_" + self.gate.weight.name + self.gate.bias.name = "gate_" + self.gate.bias.name + self.top_k = topk + + def forward(self, inp, return_all_scores=False): + gate = self.gate(inp) + gate_top_k_val, gate_top_k_idx = paddle.topk( + gate, k=self.top_k, axis=-1, largest=True, sorted=False) + + if return_all_scores: + return gate_top_k_val, gate_top_k_idx, gate + return gate_top_k_val, gate_top_k_idx diff --git a/python/paddle/incubate/distributed/models/moe/gate/switch_gate.py b/python/paddle/incubate/distributed/models/moe/gate/switch_gate.py new file mode 100644 index 0000000000000000000000000000000000000000..94347ea15eb0bee4f897b66f94a155f2918ed2ed --- /dev/null +++ b/python/paddle/incubate/distributed/models/moe/gate/switch_gate.py @@ -0,0 +1,69 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from .naive_gate import NaiveGate +from ..utils import limit_by_capacity + + +class SwitchGate(NaiveGate): + def __init__(self, + d_model, + num_expert, + world_size, + topk=1, + switch_eps=.1, + capacity=(1.2, 2.4), + group=None): + assert topk == 1, "topk should be 1 in switch" + super().__init__(d_model, num_expert, world_size, topk=1) + self.switch_eps = switch_eps + self.capacity = capacity + self.group = group + + def forward(self, inp): + score = self.gate(inp) + + if self.training: + noise = paddle.rand(shape=score.shape) + noise = noise * 2 * self.switch_eps + 1.0 - self.switch_eps + score += noise + + score = F.softmax(score, axis=-1) + top1_score, top1_idx = paddle.topk(score, k=1, axis=-1, largest=True) + + cap_rate = self.capacity[0 if self.training else 1] + capacity = math.ceil(cap_rate * inp.shape[0]) + _new_lec, _new_gec, top1_idx = limit_by_capacity( + top1_idx, + self.num_expert, + self.world_size, + capacity, + group=self.group) + valid_idx = top1_idx[top1_idx > -1] + valid_idx_tmp = paddle.reshape(valid_idx, shape=[len(valid_idx), 1]) + fraction_expert = paddle.scatter_nd_add( + x=paddle.zeros(shape=[self.tot_expert]), + index=valid_idx_tmp, + updates=paddle.ones_like( + valid_idx, dtype=paddle.float32).reshape( + shape=[len(valid_idx)]), ) / valid_idx.numel() + prob_expert = score.sum(axis=0) / valid_idx.numel() + loss = (fraction_expert * prob_expert).sum() * self.tot_expert + self.set_loss(loss) + + return top1_score, top1_idx diff --git a/python/paddle/incubate/distributed/models/moe/grad_clip.py b/python/paddle/incubate/distributed/models/moe/grad_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..cde5455d271683c4a6867e0a4ac4a0472b24b2df --- /dev/null +++ 
b/python/paddle/incubate/distributed/models/moe/grad_clip.py @@ -0,0 +1,217 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +from paddle.fluid.clip import ClipGradBase, _squared_l2_norm +from paddle.fluid.dygraph import base as imperative_base +from paddle.fluid import core, layers, framework +from paddle.distributed import collective +import six +import warnings +import copy + + +class ClipGradForMOEByGlobalNorm(ClipGradBase): + r""" + The Algrithm is the same as paddle.fluid.clip.ClipGradByGlobalNorm + Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in + :math:`t\_list` , and limit it to ``clip_norm`` . + + - If the global norm is greater than ``clip_norm`` , all elements of :math:`t\_list` will be compressed by a ratio. + + - If the global norm is less than or equal to ``clip_norm`` , nothing will be done. + + The list of Tensor :math:`t\_list` is not passed from this class, but the gradients of all parameters set in ``optimizer``. + If ``need_clip`` of specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped. + + Gradient clip will takes effect after being set in ``optimizer`` , see the document ``optimizer`` + (for example: :ref:`api_paddle_optimizer_SGD`). + + The clipping formula is: + + .. math:: + + t\_list[i] = t\_list[i] * \frac{clip\_norm}{\max(global\_norm, clip\_norm)} + + where: + + .. math:: + + global\_norm = \sqrt{\sum_{i=0}^{N-1}(l2norm(t\_list[i]))^2} + + Note: + ``need_clip`` of ``ClipGradyGlobalNorm`` HAS BEEN DEPRECATED since 2.0. + Please use ``need_clip`` in ``ParamAttr`` to speficiy the clip scope. + + Args: + clip_norm (float): The maximum norm value. + is_expert_param_func (function): a function to decide whether a param should be put into moe_params_grads + moe_group (Group): group for moe experts communication. + group_name (str, optional): The group name for this clip. Default value is ``default_moe_group``. + + Examples: + .. 
code-block:: python + + import paddle + + x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32') + linear = paddle.nn.Linear(in_features=10, out_features=10, + weight_attr=paddle.ParamAttr(need_clip=True), + bias_attr=paddle.ParamAttr(need_clip=False)) + out = linear(x) + loss = paddle.mean(out) + loss.backward() + + is_expert_func = lambda param: "expert_" in param.name + clip = paddle.nn.ClipGradForMOEByGlobalNorm(clip_norm=1.0,is_expert_func, None) + sdg = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters(), grad_clip=clip) + sdg.step() + """ + + def __init__(self, + clip_norm, + is_expert_param_func=None, + moe_group=None, + group_name="default_moe_group"): + super(ClipGradForMOEByGlobalNorm, self).__init__() + self.clip_norm = float(clip_norm) + self.group_name = group_name + self.moe_group = moe_group + if moe_group is not None and moe_group.nranks > 1: + assert is_expert_param_func is not None, \ + "When moe group size > 1, a function for selecting expert params must be specified." + self.is_expert_param_func = is_expert_param_func + + def __str__(self): + return "Gradient Clip By GlobalNorm, global_norm=%f" % (self.clip_norm) + + @staticmethod + def get_l2_norm_pow(params_grads, sum_dtype=None): + sum_square_list = [] + sum_square_list_fp16 = [] + sum_square_list_fp32 = [] + for p, g in params_grads: + if g is None: + continue + if getattr(p, 'need_clip', True) is False: + continue + merge_grad = g + if g.type == core.VarDesc.VarType.SELECTED_ROWS: + merge_grad = layers.merge_selected_rows(g) + merge_grad = layers.get_tensor_from_selected_rows(merge_grad) + sum_square = _squared_l2_norm(merge_grad) + if sum_square.dtype == core.VarDesc.VarType.FP16: + sum_square_list_fp16.append(sum_square) + elif sum_square.dtype == core.VarDesc.VarType.FP32: + sum_square_list_fp32.append(sum_square) + else: + sum_square_list.append(sum_square) + + # all parameters have been filterd out + if len(sum_square_list) + len(sum_square_list_fp16) + len( + sum_square_list_fp32) == 0: + return None, None + assert sum_dtype in ["float64", "float32", None], \ + "sum's type must be float64/ float32 / None" + if sum_dtype != "float64": + sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32" + + global_norm_var = [] + if len(sum_square_list_fp16) > 0: + global_norm_var_fp16 = layers.concat(sum_square_list_fp16) + global_norm_var_fp16 = layers.reduce_sum(global_norm_var_fp16) + global_norm_var.append(global_norm_var_fp16.astype(sum_dtype)) + if len(sum_square_list_fp32) > 0: + global_norm_var_fp32 = layers.concat(sum_square_list_fp32) + global_norm_var_fp32 = layers.reduce_sum(global_norm_var_fp32) + if sum_dtype == 'float32': + global_norm_var.append(global_norm_var_fp32) + else: + global_norm_var.append(global_norm_var_fp32.astype(sum_dtype)) + if len(sum_square_list) > 0: + global_norm_var_fp64 = layers.concat(sum_square_list) + global_norm_var_fp64 = layers.reduce_sum(global_norm_var_fp64) + global_norm_var.append(global_norm_var_fp64) + global_norm_var = layers.concat(global_norm_var) + global_norm_var = layers.reduce_sum(global_norm_var) + return global_norm_var, sum_dtype + + @imperative_base.no_grad + def _dygraph_clip(self, params_grads): + normal_params_grads = [] + moe_params_grads = [] + + # seperate moe params from normal params + if self.moe_group is not None and self.moe_group.nranks > 1: + for p, g in params_grads: + if self.is_expert_param_func(p): + moe_params_grads.append((p, g)) + else: + normal_params_grads.append((p, g)) + else: + normal_params_grads 
= params_grads + + # why to return sum_dtype? + # we will call `get_l2_norm_pow` twice and the precisions may be different. + # For convenience and simplification, we use sum_dtype directly instead of global_norm_var_normal.dtype + global_norm_var_normal, sum_dtype \ + = self.get_l2_norm_pow(normal_params_grads) + global_norm_var_moe = None + if len(moe_params_grads) > 0: + global_norm_var_moe, _ \ + = self.get_l2_norm_pow(moe_params_grads, sum_dtype) + if global_norm_var_moe is not None: + collective.all_reduce( + global_norm_var_moe, + op=collective.ReduceOp.SUM, + group=self.moe_group) + + if global_norm_var_normal is None and global_norm_var_moe is None: + return params_grads + elif global_norm_var_normal is None: + global_norm_var = global_norm_var_moe + elif global_norm_var_moe is None: + global_norm_var = global_norm_var_normal + else: + if global_norm_var_normal.dtype != global_norm_var_moe.dtype: + # compared with normal norm, moe norm is the later one, + # so its precision is no lower than normal norm + global_norm_var_normal = \ + global_norm_var_normal.astype(global_norm_var_moe.dtype) + global_norm_var = global_norm_var_normal + global_norm_var_moe + + params_and_grads = [] + global_norm_var = layers.sqrt(global_norm_var) + max_global_norm = layers.fill_constant( + shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm) + clip_var = layers.elementwise_div( + x=max_global_norm, + y=layers.elementwise_max( + x=global_norm_var, y=max_global_norm)) + for p, g in params_grads: + if g is None: + continue + if getattr(p, 'need_clip', True) is False: + params_and_grads.append((p, g)) + continue + # TODO(wangxi): use inplace elementwise_mul + clip_input = (clip_var.astype('float16') + if g.dtype == core.VarDesc.VarType.FP16 else clip_var) + new_grad = layers.elementwise_mul(x=g, y=clip_input) + params_and_grads.append((p, new_grad)) + return params_and_grads + + +ClipGradByGlobalNorm = ClipGradForMOEByGlobalNorm diff --git a/python/paddle/incubate/distributed/models/moe/moe_layer.py b/python/paddle/incubate/distributed/models/moe/moe_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..6ee2a30589cddfdd8c6e67891c6eed671f0a2cb8 --- /dev/null +++ b/python/paddle/incubate/distributed/models/moe/moe_layer.py @@ -0,0 +1,431 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
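The combination rule implemented by `_dygraph_clip` above is easier to see in scalar form: the squared norms of normal and expert gradients are accumulated separately, the expert part is summed across the moe group, and every gradient is scaled by clip_norm / max(global_norm, clip_norm). The snippet below is a minimal NumPy sketch of that arithmetic only, not the layer's implementation; the array values and the names `normal_grads` / `moe_grads` are illustrative.

import numpy as np

clip_norm = 1.0
normal_grads = [np.array([3.0, 4.0])]        # gradients of replicated (non-expert) parameters
moe_grads = [np.array([1.0, 2.0, 2.0])]      # gradients of this rank's expert parameters

norm_sq_normal = sum(float((g ** 2).sum()) for g in normal_grads)
norm_sq_moe = sum(float((g ** 2).sum()) for g in moe_grads)
# in the class above, norm_sq_moe is additionally all-reduced (SUM) over the moe group
global_norm = (norm_sq_normal + norm_sq_moe) ** 0.5
scale = clip_norm / max(global_norm, clip_norm)   # never greater than 1.0
clipped = [g * scale for g in normal_grads + moe_grads]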
+ +import collections +import math + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.distributed.utils import global_scatter, global_gather +from paddle.distributed import alltoall, all_gather + +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddle.distributed import fleet +from paddle.autograd import PyLayer +from .gate import NaiveGate, GShardGate, SwitchGate, BaseGate +from .utils import count_by_gate +from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _hp_recompute +from paddle import fluid + +__all__ = ["MoeLayer"] + + +def _local_scatter(inp, pos): + if pos.shape != [0]: + inp_buf = paddle.index_select(inp, pos, 0) + else: + inp_buf = paddle.empty([0, inp.shape[1]], dtype=inp.dtype) + return inp_buf + + +def _local_gather(inp, pos, out_batch_size, maybe_overlap=True): + if pos.shape != [0]: + origin_dtype = inp.dtype + inp = paddle.cast(inp, dtype="float32") + inp_buf = paddle.scatter( + paddle.zeros( + shape=[out_batch_size, inp.shape[-1]], dtype="float32"), + pos, + inp, + overwrite=True) + inp_buf = paddle.cast(inp_buf, dtype=origin_dtype) + else: + inp_buf = paddle.zeros([out_batch_size, inp.shape[-1]], dtype=inp.dtype) + return inp_buf + + +def _all_gather(tensor, group=None, use_calc_stream=True): + """ + The main difference with paddle.distributed.all_gather: + no need to pass in tensor_list, the returned tensor is spliced + """ + if group is not None and not group.is_member(): + return + ring_id = 0 if group is None else group.id + nranks = paddle.distributed.collective._get_global_group( + ).nranks if group is None else group.nranks + return paddle._C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream, + 'ring_id', ring_id, 'nranks', nranks) + + +class MOEScatter(PyLayer): + r""" + Scatter input samples from [batch x sequences] to contiguous alone experts. + If `world_size` is greater than 1, the samples will first be locally + scattered, and then exchanged across workers. + """ + + @staticmethod + def forward(ctx, + inp, + pos, + local_expert_count, + global_expert_count, + fwd_batch_size, + world_size, + group=None): + local_input_buf = _local_scatter(inp, pos) + if world_size > 1: + global_input_buf = global_scatter( + local_input_buf, + local_expert_count, + global_expert_count, + group=group) + else: + global_input_buf = local_input_buf + + ctx.moe_args = inp.shape[0], world_size, group + + variables = (pos, local_expert_count, global_expert_count) + ctx.save_for_backward(*variables) + return global_input_buf + + @staticmethod + def backward(ctx, grad): + (pos, local_expert_count, global_expert_count) = ctx.saved_tensor() + (inp_batch_size, world_size, group) = ctx.moe_args + + if world_size > 1: + local_grad_in = global_gather( + grad, local_expert_count, global_expert_count, group=group) + else: + local_grad_in = grad + grad_in = _local_gather(local_grad_in, pos, inp_batch_size) + return grad_in, None, None, None + + +class MOEGather(PyLayer): + r""" + Gather output samples from contiguous alone experts back to [batch x + sequences]. Works symmetrically with MOEScatter. 
+ """ + + @staticmethod + def forward(ctx, + global_output_buf, + pos, + local_expert_count, + global_expert_count, + local_batch_size, + world_size, + group=None): + if world_size > 1: + local_output_buf = global_gather( + global_output_buf, + local_expert_count, + global_expert_count, + group=group) + else: + local_output_buf = global_output_buf + output = _local_gather( + local_output_buf, pos, local_batch_size, maybe_overlap=False) + + ctx.moe_args = (global_output_buf.shape[0], world_size, group) + variables = (pos, local_expert_count, global_expert_count) + ctx.save_for_backward(*variables) + return output + + @staticmethod + def backward(ctx, grad_out): + pos, local_expert_count, global_expert_count = ctx.saved_tensor() + fwd_batch_size, world_size, group = ctx.moe_args + grad_out_buf = _local_scatter(grad_out, pos) + if world_size > 1: + global_grad_out_buf = global_scatter( + grad_out_buf, + local_expert_count, + global_expert_count, + group=group) + else: + global_grad_out_buf = grad_out_buf + return global_grad_out_buf, None, None, None + + +class AllGather(PyLayer): + r""" + A wrapper for the All-Gather function to support auto-differentiation. + """ + + @staticmethod + def forward(ctx, inp, rank, world_size, group): + tensor_list = [] + paddle.distributed.all_gather(tensor_list, inp, group=group) + output = paddle.concat(tensor_list, axis=0) + ctx.args = rank, inp.shape[0] + return output + + @staticmethod + def backward(ctx, grad_out): + rank, dim0 = ctx.args + return paddle.slice( + grad_out, axes=[0], starts=[rank * dim0], ends=[(rank + 1) * dim0]) + + +class Slice(PyLayer): + r""" + A wrapper for the Slice function to support auto-differentiation. + """ + + @staticmethod + def forward(ctx, inp, rank, world_size, group): + B = inp.shape[0] + local_batch_size = B // world_size + batch_start = local_batch_size * rank + batch_end = min(batch_start + local_batch_size, B) + inp = paddle.slice( + inp, axes=[0], starts=[batch_start], ends=[batch_end]) + ctx.args = world_size, group + return inp + + @staticmethod + def backward(ctx, grad_out): + world_size, group = ctx.args + # tensor_list = [] + # paddle.distributed.all_gather(tensor_list, grad_out, group=group) + # grad_out = paddle.concat(tensor_list, axis=0) + return _all_gather(grad_out, group=group) + # return grad_out + + +def prepare_forward(gate, num_expert, world_size, moe_group): + pos, local_expert_count, global_expert_count = count_by_gate( + gate, num_expert, world_size, group=moe_group) + with paddle.no_grad(): + fwd_expert_count = global_expert_count.reshape_( + [world_size, num_expert]).sum(axis=0) + fwd_batch_size = int(fwd_expert_count.sum().item()) + return ( + pos, + local_expert_count, + global_expert_count, + fwd_expert_count, + fwd_batch_size, ) + + +class MoeLayer(nn.Layer): + """Moe Layer + Args: + d_model: (int) model dimention + experts: (nn.LayerList) expert networks list + gate: (dict|NaiveGate|SwitchGate|NaiveGate): + if gate is a dict: + gate is a gate network config, containing 2 keys: + `type`(str) value can be: "naive", "gshard", "switch" or None, default is "gshard" + `top_k`(int) default value is 2 + else gate is an instance of NaiveGate|SwitchGate|NaiveGate: + + moe_group: moe group for experts communication + mp_group: mp group for mp commutication + kwargs: other parameters + Examples: + .. 
code-block:: python + import paddle.nn as nn + from paddle.nn import Layer, LayerList + from paddle.incubate.distributed.models.moe.moe_layer import MoeLayer + from paddle.distributed.collective import Group + from paddle.distributed import fleet +
+ moe_group = Group(fleet.worker_index(), + fleet.worker_num(), + 0, + list(range(fleet.worker_num()))) + mp_group = None +
+ num_experts=8 + dim_feedforward=512 + d_model=8 + top_k=2 +
+ class ExpertLayer(Layer): + def __init__(self, d_model, d_hidden, name=None, rank=0, windex=0, num_expert=1): + super(ExpertLayer, self).__init__() + self.htoh4 = nn.Linear(d_model, d_hidden) + self.h4toh = nn.Linear(d_hidden, d_model) +
+ def forward(self, x): + x = self.htoh4(x) + x = self.h4toh(x) + return x +
+ gate_config = { + "type": "gshard", + "top_k": top_k, + } +
+ experts_list = LayerList() + for expi in range(num_experts): + exp_layer = ExpertLayer(d_model, dim_feedforward // top_k, windex=expi, num_expert=num_experts) + experts_list.append(exp_layer) +
+ moeLayer = MoeLayer(d_model=d_model, + experts=experts_list, + gate=gate_config, + moe_group=moe_group, + mp_group=mp_group, + recompute_interval=0) +
+ """ +
+ def __init__(self, + d_model, + experts, + gate=None, + moe_group=None, + mp_group=None, + **kwargs): + super(MoeLayer, self).__init__() +
+ recompute_interval = kwargs.get("recompute_interval", 0) +
+ if gate is None: + gate = dict() +
+ assert isinstance(gate, (dict, BaseGate)), \ + "gate config's type must be dict or an instance of BaseGate" + # only support mp/dp + self.group = moe_group +
+ self.world_size = 1 + if self.group is not None: + self.world_size = self.group.nranks + self.num_expert = len(experts) + self.recompute_interval = recompute_interval + assert experts is not None + self.experts = experts +
+ self.mp_group = mp_group + self.d_model = d_model + if isinstance(gate, dict): + self.top_k = gate.get("top_k", 2) + gate = gate.get("type", "gshard")
+ if gate == "naive" or gate is None: + gate = NaiveGate( + self.d_model, + num_expert=len(experts), + world_size=self.world_size, + topk=self.top_k)
+ elif gate == "gshard": + gate = GShardGate( + self.d_model, + num_expert=len(experts), + world_size=self.world_size, + topk=self.top_k, + group=self.group)
+ elif gate == "switch": + gate = SwitchGate( + self.d_model, + num_expert=len(experts), + world_size=self.world_size, + topk=self.top_k, + group=self.group)
+ else: + assert False, "We only support naive gate, \ + gshard gate and switch gate, \ + but you chose {} gate.".format(str(gate))
+ elif isinstance(gate, NaiveGate): + self.top_k = gate.top_k + elif isinstance(gate, BaseGate): + raise TypeError("Unimplemented gate type: ", type(gate)) + else: + raise TypeError("gate's type must be either dict or moe.BaseGate") + self.gate = gate +
+ def forward(self, inp): + # inp shape: b * s * m + assert len(inp.shape) == 3 + origin_shape = inp.shape + inp = inp.reshape_([-1, origin_shape[2]]) +
+ mp_rank = 0 + mp_size = 1 + if self.mp_group is not None: + mp_rank = self.mp_group.rank + mp_size = self.mp_group.nranks + if mp_size > 1: + inp = Slice.apply(inp, mp_rank, mp_size, self.mp_group) + value, gate = self.gate(inp) +
+ ( + pos, + local_expert_count, + global_expert_count, + fwd_expert_count, + fwd_batch_size, ) = prepare_forward(gate, self.num_expert, + self.world_size, self.group) +
+ topk = 1 + if len(gate.shape) == 2: + topk = gate.shape[1] +
+ if pos.shape != [0]: + temp_pos = pos // topk + else: + temp_pos = pos + assert topk == self.top_k +
+ x = MOEScatter.apply(inp, temp_pos, local_expert_count, + global_expert_count,
fwd_batch_size, + self.world_size, self.group) + + d_model = self.d_model + + def experts_fwd(x, fwd_expert_count, experts): + + if x.shape[0] == 0: + return paddle.empty(x.shape, x.dtype) + y = [] + last_index = 0 + assert isinstance(fwd_expert_count, np.ndarray) + assert len(experts) == len(fwd_expert_count) + for idx, expert_count in enumerate(fwd_expert_count): + if expert_count <= 0: + continue + y.append(experts[idx](x[last_index:expert_count + last_index])) + last_index = expert_count + last_index + return paddle.concat(y, axis=0) + + if self.recompute_interval <= 0: + x = experts_fwd(x, fwd_expert_count.numpy(), self.experts) + else: + x = _hp_recompute(experts_fwd, x, + fwd_expert_count.numpy(), self.experts) + + out_batch_size = inp.shape[0] + if len(gate.shape) == 2: + out_batch_size *= gate.shape[1] + + x = MOEGather.apply(x, pos, local_expert_count, global_expert_count, + out_batch_size, self.world_size, self.group) + + x = x.reshape([-1, self.top_k, d_model]) + value = value.reshape([x.shape[0], 1, self.top_k]) + x = paddle.bmm(value, x).reshape([-1, d_model]) + + if mp_size > 1: + x = AllGather.apply(x, mp_rank, mp_size, self.mp_group) + + x = paddle.reshape_(x, origin_shape) + + return x diff --git a/python/paddle/incubate/distributed/models/moe/utils.py b/python/paddle/incubate/distributed/models/moe/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..99e31a16273bf7ef939d724c00d35e7fb647aada --- /dev/null +++ b/python/paddle/incubate/distributed/models/moe/utils.py @@ -0,0 +1,59 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
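The `experts_fwd` helper above walks the scattered batch expert by expert, slicing `x` with a running offset taken from `fwd_expert_count` and concatenating the expert outputs in the same order. A small self-contained sketch of that slicing, using plain NumPy and stand-in experts (all names and values here are illustrative, not part of the API):

import numpy as np

fwd_expert_count = [2, 0, 3]                          # samples routed to each local expert
x = np.arange(5 * 4, dtype=np.float32).reshape(5, 4)  # scattered inputs, one row per sample
experts = [lambda t: t + 1, lambda t: t * 2, lambda t: t - 1]  # stand-ins for expert sub-layers

outputs, start = [], 0
for idx, count in enumerate(fwd_expert_count):
    if count <= 0:
        continue                                      # experts that received no samples are skipped
    outputs.append(experts[idx](x[start:start + count]))
    start += count
y = np.concatenate(outputs, axis=0)                   # rows stay grouped by expert, as MOEGather expects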
+from paddle.distributed.models.moe.utils import * + + +def _alltoall(in_tensor_list, group=None, use_calc_stream=True): + if group is not None and not group.is_member(): + return + ring_id = 0 if group is None else group.id + nranks = len(in_tensor_list) + return paddle._C_ops.alltoall(in_tensor_list, 'use_calc_stream', + use_calc_stream, 'ring_id', ring_id) + + +def count_by_gate(gate, num_expert, world_size, require_pos=True, group=None): + total_expert_count = num_expert * world_size + with paddle.no_grad(): + local_expert_count = _number_count(gate, total_expert_count) + + if world_size > 1: + global_expert_count = _alltoall(local_expert_count, group=group) + else: + global_expert_count = local_expert_count + if not require_pos: + pos = None + else: + lec_cum = paddle.cumsum(local_expert_count, axis=0) + pos = _assign_pos(gate, lec_cum) + return pos, local_expert_count, global_expert_count + + +def limit_by_capacity(topk_idx, num_expert, world_size, capacity, group=None): + with paddle.no_grad(): + capacity = paddle.ones( + shape=[num_expert], dtype=paddle.int64) * capacity + pos, lec, gec = count_by_gate( + topk_idx, num_expert, world_size, require_pos=False, group=group) + new_gec = _limit_by_capacity(gec, capacity, world_size) + if world_size > 1: + assert group.nranks == world_size + new_lec = _alltoall(new_gec, group=group) + else: + new_lec = new_gec + + topk_idx = _prune_gate_by_capacity(topk_idx, new_lec, num_expert, + world_size) + + return new_lec, new_gec, topk_idx
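For orientation, `count_by_gate` above boils down to two pieces of bookkeeping on the routing result: a histogram of how many tokens go to each global expert (`_number_count`) and, when positions are required, an ordering of token indices grouped by target expert (`cumsum` plus `_assign_pos`). A rough NumPy analogue of those two outputs, assuming `topk_idx` holds one global expert id per token (values are illustrative):

import numpy as np

num_expert, world_size = 2, 2
topk_idx = np.array([0, 3, 1, 1, 2])     # chosen global expert id for each token

local_expert_count = np.bincount(topk_idx, minlength=num_expert * world_size)  # -> [1 2 1 1]
pos = np.argsort(topk_idx, kind="stable")                                      # -> [0 2 3 4 1]
# with world_size > 1, local_expert_count is then exchanged with an all-to-all so that
# each rank learns how many tokens it will receive for each of its local experts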