// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // The file has been adapted from the two files: // https://github.com/laekov/fastmoe/blob/master/cuda/balancing.cu // https://github.com/laekov/fastmoe/blob/master/cuda/balancing.cuh // Git commit hash: 295a615aacce7e54a37e7935274ba15e901c78e4 // We retain the following license from the original files: // Copyright 2021, Jiaao He. All rights reserved. // Licensed under the Apache License, Version 2.0 (the "License"). #include "paddle/fluid/operators/prune_gate_by_capacity_op.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" namespace ops = paddle::operators; namespace plat = paddle::platform; DECLARE_bool(avoid_op_randomness); namespace paddle { namespace operators { using LoDTensor = framework::LoDTensor; static constexpr int kNumCUDAThreads = 512; static constexpr int kNumMaxinumNumBlocks = 4096; static inline int NumBlocks(const int N) { return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, kNumMaxinumNumBlocks); } template __global__ void prune_gate_by_capacity_kernel(const T1* gate_idx_data, T1* new_gate_idx_data, T2* expert_count_data, const int64_t batch_size) { CUDA_KERNEL_LOOP(i, batch_size) { auto orig_cap = platform::CudaAtomicAdd(expert_count_data + gate_idx_data[i], -1); if (orig_cap <= 0) { new_gate_idx_data[i] = -1; } else { new_gate_idx_data[i] = gate_idx_data[i]; } } } template class PruneGateByCapacityFunctor { public: PruneGateByCapacityFunctor(const framework::ExecutionContext& context, const framework::LoDTensor* gate_idx, framework::LoDTensor* expert_count_out, T1* new_gate_idx_data) : context_(context), gate_idx_(gate_idx), expert_count_out_(expert_count_out), new_gate_idx_data_(new_gate_idx_data) {} template void apply() { auto batch_size = gate_idx_->numel(); auto* gate_idx_data = gate_idx_->data(); auto& dev_ctx = context_.template device_context(); auto* expert_count_out_data = expert_count_out_->data(); int blocks = NumBlocks(batch_size); int threads = kNumCUDAThreads; prune_gate_by_capacity_kernel<<>>( gate_idx_data, new_gate_idx_data_, expert_count_out_data, batch_size); } private: const framework::ExecutionContext context_; const framework::LoDTensor* gate_idx_; framework::LoDTensor* expert_count_out_; T1* new_gate_idx_data_; }; template static void VisitDataType(paddle::experimental::DataType type, Visitor visitor) { if (type == paddle::experimental::DataType::INT64) { visitor.template apply(); } else { PADDLE_THROW(platform::errors::InvalidArgument( "The recieved values gate_id type %s can not meet input requirements. " "Because the given gate_id data type of operators must be " "int64. Please input appropriate gate_id again! ", "framework::DataTypeToString(type)")); } } template class PruneGateByCapacityCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* gate_idx = context.Input("GateIdx"); auto* expert_count = context.Input("ExpertCount"); // auto* expert_count_out = context.Output("ExpertCountOut"); auto* new_gate_idx = context.Output("NewGateIdx"); auto* new_gate_idx_data = new_gate_idx->mutable_data(context.GetPlace()); framework::LoDTensor expert_count_out; framework::TensorCopy(*expert_count, context.GetPlace(), &expert_count_out); PruneGateByCapacityFunctor functor( context, gate_idx, &expert_count_out, new_gate_idx_data); VisitDataType(expert_count->type(), functor); } }; } // namespace operators } // namespace paddle REGISTER_OP_CUDA_KERNEL( prune_gate_by_capacity, ops::PruneGateByCapacityCUDAKernel);