From def2a87fbbc4640d9eb49d1009d77ddd452a20f7 Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com> Date: Thu, 22 Dec 2022 14:56:20 +0800 Subject: [PATCH] [Paddle Inference] Add moe phi kernel (#48703) --- cmake/third_party.cmake | 1 + .../tensorrt/dynamic_shape_infermeta.cc | 10 + paddle/fluid/operators/moe_op.cc | 64 ++ paddle/phi/infermeta/multiary.cc | 14 + paddle/phi/infermeta/multiary.h | 9 + paddle/phi/kernels/CMakeLists.txt | 6 + .../fusion/cutlass/default_moe_fc_traits.h | 206 ++++ .../cutlass/linear_combination_ft_gelu.h | 687 +++++++++++++ .../fusion/cutlass/moe_cutlass_kernel.h | 879 +++++++++++++++++ .../phi/kernels/fusion/cutlass/moe_kernel.cu | 911 ++++++++++++++++++ .../kernels/fusion/cutlass/moe_kernel_impl.h | 779 +++++++++++++++ paddle/phi/kernels/fusion/moe_kernel.h | 32 + .../fluid/tests/unittests/CMakeLists.txt | 2 + .../tests/unittests/test_fused_ec_moe_op.py | 176 ++++ python/paddle/incubate/nn/__init__.py | 2 + .../paddle/incubate/nn/functional/__init__.py | 2 + .../incubate/nn/functional/fused_ec_moe.py | 75 ++ .../paddle/incubate/nn/layer/fused_ec_moe.py | 101 ++ 18 files changed, 3956 insertions(+) create mode 100644 paddle/fluid/operators/moe_op.cc create mode 100644 paddle/phi/kernels/fusion/cutlass/default_moe_fc_traits.h create mode 100644 paddle/phi/kernels/fusion/cutlass/linear_combination_ft_gelu.h create mode 100644 paddle/phi/kernels/fusion/cutlass/moe_cutlass_kernel.h create mode 100644 paddle/phi/kernels/fusion/cutlass/moe_kernel.cu create mode 100644 paddle/phi/kernels/fusion/cutlass/moe_kernel_impl.h create mode 100644 paddle/phi/kernels/fusion/moe_kernel.h create mode 100644 python/paddle/fluid/tests/unittests/test_fused_ec_moe_op.py create mode 100644 python/paddle/incubate/nn/functional/fused_ec_moe.py create mode 100644 python/paddle/incubate/nn/layer/fused_ec_moe.py diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index 199a61fca7c..f2bfa77b0e4 100755 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -516,6 +516,7 @@ if(WITH_GPU if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.0) include(external/cutlass) # download, build, install cusparselt list(APPEND third_party_deps extern_cutlass) + set(WITH_CUTLASS ON) endif() endif() diff --git a/paddle/fluid/inference/tensorrt/dynamic_shape_infermeta.cc b/paddle/fluid/inference/tensorrt/dynamic_shape_infermeta.cc index 4c5944e7945..a959977dfe0 100644 --- a/paddle/fluid/inference/tensorrt/dynamic_shape_infermeta.cc +++ b/paddle/fluid/inference/tensorrt/dynamic_shape_infermeta.cc @@ -235,6 +235,15 @@ nvinfer1::DimsExprs UnchangedInferMeta( return inputs[0]; } +nvinfer1::DimsExprs MoeInferMeta( + int output_index, + const nvinfer1::DimsExprs* inputs, + int nb_inputs, + nvinfer1::IExprBuilder& expr_builder, // NOLINT + const framework::OpDesc& op_desc) { + return inputs[0]; +} + nvinfer1::DimsExprs Pad3dInferMeta( int output_index, const nvinfer1::DimsExprs* inputs, @@ -384,6 +393,7 @@ PD_REGISTER_DYNAMIC_INFER_META_FN(instance_norm, InstanceNormInferMeta); PD_REGISTER_DYNAMIC_INFER_META_FN(unfold, UnflodInferMeta); PD_REGISTER_DYNAMIC_INFER_META_FN(scatter_nd_add, ScatterNdAddInferMeta); PD_REGISTER_DYNAMIC_INFER_META_FN(inverse, UnchangedInferMeta); +PD_REGISTER_DYNAMIC_INFER_META_FN(moe, MoeInferMeta); PD_REGISTER_DYNAMIC_INFER_META_FN(pad3d, Pad3dInferMeta); PD_REGISTER_DYNAMIC_INFER_META_FN(grid_sampler, GridSamplerInferMeta); } // namespace tensorrt diff --git a/paddle/fluid/operators/moe_op.cc b/paddle/fluid/operators/moe_op.cc new file mode 100644 index 00000000000..6832beeaa8e --- /dev/null +++ b/paddle/fluid/operators/moe_op.cc @@ -0,0 +1,64 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/binary.h" + +namespace paddle { +namespace operators { + +class MoeOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class MoeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The source input tensor of Moe op."); + AddInput("Gate", "(Tensor), The gating input tensor of Moe op."); + AddInput("Bmm0", "(Tensor), The bmm0 input tensor of Moe op."); + AddInput("Bias0", "(Tensor), The eltwise0 input tensor of Moe op."); + AddInput("Bmm1", "(Tensor), The bmm1 input tensor of Moe op."); + AddInput("Bias1", "(Tensor), The eltwise1 input tensor of Moe op."); + AddOutput("Out", "(Tensor), The output tensor of Moe op."); + AddAttr( + "act_type", + R"DOC(activation type, currently only support `gelu`, `relu`. Default value is: `gelu`. )DOC") + .SetDefault("gelu"); + AddComment( + R"DOC(FusedEcMoe kernel. For more details you can refer to `FusedEcMoE` python documents. )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +DECLARE_INFER_SHAPE_FUNCTOR(moe, + MoeInferShapeFunctor, + PD_INFER_META(phi::MoeInferMeta)); +REGISTER_OPERATOR(moe, ops::MoeOp, ops::MoeOpMaker, MoeInferShapeFunctor); diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index d713d59d2ce..6dcd938f72d 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2931,6 +2931,20 @@ void YoloLossInferMeta(const MetaTensor& x, gt_match_mask->set_dtype(x.dtype()); } +void MoeInferMeta(const MetaTensor& x, + const MetaTensor& gate, + const MetaTensor& bmm0, + const MetaTensor& bias0, + const MetaTensor& bmm1, + const MetaTensor& bias1, + const std::string& act_type, + MetaTensor* out) { + out->set_dims(x.dims()); + out->share_lod(x); + out->set_dtype(x.dtype()); + out->set_layout(x.layout()); +} + } // namespace phi PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index a8fc5077475..2ab3c2538a8 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -523,4 +523,13 @@ void YoloLossInferMeta(const MetaTensor& x, MetaTensor* objectness_mask, MetaTensor* gt_match_mask); +void MoeInferMeta(const MetaTensor& x, + const MetaTensor& gate, + const MetaTensor& bmm0, + const MetaTensor& bias0, + const MetaTensor& bmm1, + const MetaTensor& bias1, + const std::string& act_type, + MetaTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index cbe7b25ea07..aab850d0c83 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -104,6 +104,12 @@ file( "strings/gpu/*.cu" "fusion/gpu/*.cu") +if(WITH_CUTLASS) + file(GLOB cutlass_cu "fusion/cutlass/default_moe_fc_traits.h" + "fusion/cutlass/linear_combination_ft_gelu.h" "fusion/cutlass/moe*") + list(APPEND kernel_cu ${cutlass_cu}) +endif() + if(WITH_MKLDNN) file( GLOB diff --git a/paddle/phi/kernels/fusion/cutlass/default_moe_fc_traits.h b/paddle/phi/kernels/fusion/cutlass/default_moe_fc_traits.h new file mode 100644 index 00000000000..8f3924887b6 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/default_moe_fc_traits.h @@ -0,0 +1,206 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/bfloat16.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +template +struct MoeArchTraits {}; + +template +struct MoeArchTraits { + static constexpr int Stages = 2; + using OperatorClass = cutlass::arch::OpClassSimt; + using AccType = float; + using LayoutB = cutlass::layout::RowMajor; + + static constexpr int ElementsPerAccessA = 1; + static constexpr int ElementsPerAccessB = 1; + static constexpr int ElementsPerAccessC = 1; + using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// ========================= Volta Traits =========================== +// Volta will always dequantize after the global memory load. +template +struct MoeArchTraits { + private: + static constexpr int ThreadblockK = 32; + + public: + static constexpr int Stages = 2; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = cutlass::layout::RowMajor; + + static constexpr int ElementsPerAccessA = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + using ThreadBlockShape = cutlass::gemm::GemmShape<32, 128, ThreadblockK>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, ThreadblockK>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct MoeArchTraits { + private: + static constexpr int ThreadblockK = 32; + + public: + static constexpr int Stages = 2; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = cutlass::layout::RowMajor; + + static constexpr int ElementsPerAccessA = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + using ThreadBlockShape = cutlass::gemm::GemmShape<32, 128, ThreadblockK>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, ThreadblockK>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// ======================= Turing Traits ============================== +// Turing will dequantize after LDSM + +// fp16 x fp16 specialization +template <> +struct MoeArchTraits { + static constexpr int Stages = 2; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = cutlass::layout::RowMajor; + + static constexpr int ElementsPerAccessA = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + using ThreadBlockShape = cutlass::gemm::GemmShape<32, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// bf16 x bf16 specialization +template <> +struct MoeArchTraits { + static constexpr int Stages = 2; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = cutlass::layout::RowMajor; + + static constexpr int ElementsPerAccessA = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + using ThreadBlockShape = cutlass::gemm::GemmShape<32, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template <> +struct MoeArchTraits { + static constexpr int Stages = 3; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = cutlass::layout::RowMajor; + + static constexpr int ElementsPerAccessA = 4; + static constexpr int ElementsPerAccessB = 4; + static constexpr int ElementsPerAccessC = 4; + using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 16>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template <> +struct MoeArchTraits { + static constexpr int Stages = 3; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = cutlass::layout::RowMajor; + + static constexpr int ElementsPerAccessA = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + using ThreadBlockShape = cutlass::gemm::GemmShape<32, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template <> +struct MoeArchTraits { + static constexpr int Stages = 3; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = cutlass::layout::RowMajor; + + static constexpr int ElementsPerAccessA = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + using ThreadBlockShape = cutlass::gemm::GemmShape<32, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/paddle/phi/kernels/fusion/cutlass/linear_combination_ft_gelu.h b/paddle/phi/kernels/fusion/cutlass/linear_combination_ft_gelu.h new file mode 100644 index 00000000000..0dac70dd8a6 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/linear_combination_ft_gelu.h @@ -0,0 +1,687 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination with a maximum operation used by + epilogues. +*/ + +#pragma once +#include +#include +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Single source of truth for whether to unroll for `LinearCombinationClamp()` +constexpr bool LinearCombinationFtGeluIsHeavy() { return false; } + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +__forceinline__ __device__ float copysignf_pos(float a, float b) { + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; +} + +__inline__ __device__ float tanh_opt(float x) { +#if (__CUDA_ARCH__ >= 750) + float r; + asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(r) : "f"(x)); + return r; +#else + const float exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); +#endif +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GELU operator implemented using the Taylor series approximation +template +struct FtGelu { + static const bool kIsHeavy = true; + CUTLASS_DEVICE + T operator()(T const &z) const { + T k0 = static_cast(0.7978845608028654); + T k1 = static_cast(0.044715); + + return T(cutlass::constants::half() * z * + (cutlass::constants::one() + + fast_tanh(k0 * z * (cutlass::constants::one() + k1 * z * z)))); + } +}; + +template <> +struct FtGelu { + static const bool kIsHeavy = true; + CUTLASS_DEVICE + float operator()(float const &z) const { + float k0 = static_cast(0.7978845608028654); + float k1 = static_cast(0.044715); + + return float( + z * + (cutlass::constants::one() + + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); + } +}; + +template +struct FtGelu> { + static const bool kIsHeavy = true; + CUTLASS_DEVICE + Array operator()(Array const &z) const { + using T = half_t; + Array y; + + half_t k0 = half_t(0.7978845608028654); + half_t k1 = half_t(0.044715); + + multiply_add> fma; + multiplies> mul; + plus> add; + + fast_tanh_op> tanh; + + Array u = + mul(mul(k0, z), fma(mul(k1, z), z, cutlass::constants::one())); + + y = mul(mul(z, cutlass::constants::half()), + add(cutlass::constants::one(), tanh(u))); + + return y; + } +}; + +template +struct FtGelu> { + static const bool kIsHeavy = true; + CUTLASS_DEVICE + Array operator()(Array const &rhs) const { + Array y; + FtGelu gelu_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + y[i] = gelu_op(rhs[i]); + } + + return y; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +/// +/// D = alpha * accumulator + beta * source + uniform +/// +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation + ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough + ///< data to store + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = + ElementOutput_, ///< Data type used to compute linear combination + ScaleType::Kind Scale = + ScaleType::Default, ///< Control Alpha and Beta scaling + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> +class LinearCombinationFtGelu { + public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + static const ScaleType::Kind kScale = Scale; + + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentScaleBias = Array; + + static FloatRoundStyle const kRound = Round; + + static bool const kIsHeavy = detail::LinearCombinationFtGeluIsHeavy(); + + /// Host-constructable parameters structure + struct Params { + ElementCompute alpha; ///< scales accumulators + ElementCompute beta; ///< scales source tensor + ElementCompute threshold; ///< minimum value that is output + ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if + ///< not null, loads it from memory + ElementCompute const *beta_ptr; ///< pointer to source scalar - if not + ///< null, loads it from memory + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : alpha(ElementCompute(1)), + beta(ElementCompute(0)), + threshold(ElementCompute(0)), + alpha_ptr(nullptr), + beta_ptr(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(ElementCompute alpha, + ElementCompute beta = ElementCompute(0), + ElementCompute threshold = ElementCompute(0)) + : alpha(alpha), + beta(beta), + threshold(threshold), + alpha_ptr(nullptr), + beta_ptr(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(ElementCompute const *alpha_ptr, + ElementCompute const *beta_ptr = nullptr, + ElementCompute threshold = ElementCompute(0)) + : alpha(0), + beta(0), + threshold(threshold), + alpha_ptr(alpha_ptr), + beta_ptr(beta_ptr) {} + }; + + private: + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + ElementCompute threshold_; + + public: + /// Constructs the function object, possibly loading from pointers in host + /// memory + CUTLASS_HOST_DEVICE + explicit LinearCombinationFtGelu(Params const ¶ms) { + alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); + beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); + threshold_ = params.threshold; + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + if (Scale == ScaleType::NoBetaScaling) return true; + + if (Scale == ScaleType::OnlyAlphaScaling) return false; + + if (Scale == ScaleType::Nothing) return false; + + return beta_ != ElementCompute(0); + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + if (k_partition) { + beta_ = ElementCompute(1); + } + + if (k_partition != k_partition_count - 1) { + // set to NaN to make ReLU no-op for all except last k partitions + int64_t allones = -1; + threshold_ = reinterpret_cast(allones); + } + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator const &accumulator, + FragmentOutput const &source) const { + // Convert source to interal compute numeric type + NumericArrayConverter + source_converter; + NumericArrayConverter + accumulator_converter; + + FragmentCompute converted_source = source_converter(source); + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + FragmentCompute intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + GELU ftgelu; + + if (Scale == ScaleType::NoBetaScaling) { + intermediate = converted_source; + intermediate = + mul_add_accumulator(alpha_, + converted_accumulator, + intermediate); // D = alpha * Accum + X + } else if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + } else { + intermediate = + mul_add_source(beta_, converted_source); // X = beta * C + uniform + intermediate = + mul_add_accumulator(alpha_, + converted_accumulator, + intermediate); // D = alpha * Accum + X + } + + // Compute threshold optionally + intermediate = ftgelu(intermediate); + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + return destination_converter(intermediate); + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator const &accumulator) const { + // Convert source to interal compute numeric type + NumericArrayConverter + accumulator_converter; + + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + FragmentCompute intermediate; + + multiplies mul_accumulator; + GELU ftgelu; + + if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + } else { + intermediate = + mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + } + + // Compute threshold optionally + intermediate = ftgelu(intermediate); + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + return destination_converter(intermediate); + } + + /// Computes per-channel linear scaling and bias : D = scale * accumulator + + /// bias Scale and Bias are from input Fragment + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator const &accumulator, + FragmentScaleBias const &scale, + FragmentScaleBias const &bias) const { + // Convert source to interal compute numeric type + NumericArrayConverter + accumulator_converter; + + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform per-channel scale and bias + FragmentCompute intermediate; + + multiply_add mul_add_accumulator; + + if (Scale == ScaleType::OnlyAlphaPerChannelScaling) + intermediate = mul_add_accumulator( + scale, converted_accumulator, bias); // D = scale * Accum + bias + else + intermediate = mul_add_accumulator( + alpha_, converted_accumulator, bias); // D = alpha * Accum + bias + + GELU ftgelu; + + // Compute threshold optionally + intermediate = ftgelu(intermediate); + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + return destination_converter(intermediate); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Conditional guards to enable partial specialization for packed integers +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && \ + ((__CUDACC_VER_MAJOR__ > 10) || \ + ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) + +/// Applies a linear combination operator to an array of elements. +/// +/// D = alpha * accumulator + beta * source + uniform +/// +/// Special handling for int types + +template +class LinearCombinationFtGelu { + public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = int; + using ElementCompute = float; + + static bool const kIsHeavy = detail::LinearCombinationFtGeluIsHeavy(); + + static int const kCount = Count; + static const ScaleType::Kind kScale = Scale; + + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentScaleBias = Array; + + static FloatRoundStyle const kRound = Round; + + /// Host-constructable parameters structure + struct Params { + ElementCompute alpha; ///< scales accumulators + ElementCompute beta; ///< scales source tensor + ElementCompute threshold; ///< minimum value that is output + ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if + ///< not null, loads it from memory + ElementCompute const *beta_ptr; ///< pointer to source scalar - if not + ///< null, loads it from memory + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : alpha(ElementCompute(1)), + beta(ElementCompute(0)), + threshold(ElementCompute(0)), + alpha_ptr(nullptr), + beta_ptr(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(ElementCompute alpha, + ElementCompute beta = ElementCompute(0), + ElementCompute threshold = ElementCompute(0)) + : alpha(alpha), + beta(beta), + threshold(threshold), + alpha_ptr(nullptr), + beta_ptr(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(ElementCompute const *alpha_ptr, + ElementCompute const *beta_ptr = nullptr, + ElementCompute threshold = ElementCompute(0)) + : alpha(0), + beta(0), + threshold(threshold), + alpha_ptr(alpha_ptr), + beta_ptr(beta_ptr) {} + }; + + private: + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + ElementCompute threshold_; + + public: + /// Constructs the function object, possibly loading from pointers in host + /// memory + CUTLASS_HOST_DEVICE + explicit LinearCombinationFtGelu(Params const ¶ms) { + alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); + beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); + threshold_ = params.threshold; + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + if (Scale == ScaleType::NoBetaScaling) return true; + + if (Scale == ScaleType::OnlyAlphaScaling) return false; + + if (Scale == ScaleType::Nothing) return false; + + return beta_ != ElementCompute(0); + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + if (k_partition) { + beta_ = ElementCompute(1); + } + + if (k_partition != k_partition_count - 1) { + // set to NaN to make ReLU no-op for all except last k partitions + int64_t allones = -1; + threshold_ = reinterpret_cast(allones); + } + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator const &accumulator, + FragmentOutput const &source) const { + // Convert source to interal compute numeric type + NumericArrayConverter + source_converter; + NumericArrayConverter + accumulator_converter; + + FragmentCompute converted_source = source_converter(source); + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + FragmentCompute intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + GELU ftgelu; + + if (Scale == ScaleType::NoBetaScaling) { + intermediate = converted_source; + intermediate = + mul_add_accumulator(alpha_, + converted_accumulator, + intermediate); // D = alpha * Accum + X + } else if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + } else { + intermediate = + mul_add_source(beta_, converted_source); // X = beta * C + uniform + intermediate = + mul_add_accumulator(alpha_, + converted_accumulator, + intermediate); // D = alpha * Accum + X + } + + // Compute threshold optionally + intermediate = ftgelu(intermediate); + + if (platform::numeric_limits::is_integer) { + // Convert floats back to INT + FragmentAccumulator scaled_accumulator; + + NumericArrayConverter + compute_converter; + + scaled_accumulator = compute_converter(intermediate); + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + return destination_converter(scaled_accumulator); + } else { + NumericArrayConverter + destination_converter; + return destination_converter(intermediate); + } + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator const &accumulator) const { + // Convert source to interal compute numeric type + NumericArrayConverter + accumulator_converter; + + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + FragmentCompute intermediate; + + multiplies mul_accumulator; + GELU ftgelu; + + if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + } else { + intermediate = + mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + } + + // Compute threshold optionally + intermediate = ftgelu(intermediate); + + if (platform::numeric_limits::is_integer) { + // Convert floats back to INT + FragmentAccumulator scaled_accumulator; + + NumericArrayConverter + compute_converter; + + scaled_accumulator = compute_converter(intermediate); + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + return destination_converter(scaled_accumulator); + } else { + NumericArrayConverter + destination_converter; + return destination_converter(intermediate); + } + } + + /// Computes per-channel linear scaling and bias : D = scale * accumulator + + /// bias Scale and Bias are from input Fragment + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator const &accumulator, + FragmentScaleBias const &scale, + FragmentScaleBias const &bias) const { + // Convert source to interal compute numeric type + NumericArrayConverter + accumulator_converter; + + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform per-channel scale and bias + FragmentCompute intermediate; + + multiply_add mul_add_accumulator; + + if (Scale == ScaleType::OnlyAlphaPerChannelScaling) + intermediate = mul_add_accumulator( + scale, converted_accumulator, bias); // D = scale * Accum + bias + else + intermediate = mul_add_accumulator( + alpha_, converted_accumulator, bias); // D = alpha * Accum + bias + + GELU ftgelu; + + // Compute threshold optionally + intermediate = ftgelu(intermediate); + + if (platform::numeric_limits::is_integer) { + // Convert floats back to INT + FragmentAccumulator scaled_accumulator; + + NumericArrayConverter + compute_converter; + + scaled_accumulator = compute_converter(intermediate); + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + return destination_converter(scaled_accumulator); + } else { + NumericArrayConverter + destination_converter; + return destination_converter(intermediate); + } + } +}; + +#endif // Conditional guards to enable partial specialization for packed + // integers + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/moe_cutlass_kernel.h b/paddle/phi/kernels/fusion/cutlass/moe_cutlass_kernel.h new file mode 100644 index 00000000000..f037f4e01b1 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/moe_cutlass_kernel.h @@ -0,0 +1,879 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + *modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + *notice, this list of conditions and the following disclaimer in the + *documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its + *contributors may be used to endorse or promote products derived from this + *software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, + *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, + *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct BaseMoeProblemVisitor { + using ThreadblockShape = ThreadblockShape_; + + struct ProblemInfo { + static int32_t const kNoPrefetchEntry = -1; + int32_t problem_idx; + int32_t problem_start; + + CUTLASS_DEVICE + ProblemInfo() + : problem_idx(kNoPrefetchEntry), problem_start(kNoPrefetchEntry) {} + + CUTLASS_DEVICE + ProblemInfo(int32_t problem_idx_, int32_t problem_start_) + : problem_idx(problem_idx_), problem_start(problem_start_) {} + }; + + struct Params { + int64_t const *last_row_for_problem; + int64_t gemm_n; + int64_t gemm_k; + int32_t problem_count; + void const *workspace; + int32_t tile_count; + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + Params() + : last_row_for_problem(nullptr), + gemm_n(0), + gemm_k(0), + problem_count(0), + workspace(nullptr), + tile_count(0) {} + + /// Ctor + CUTLASS_HOST_DEVICE + Params(int64_t const *last_row_for_problem, + int64_t gemm_n, + int64_t gemm_k, + int32_t problem_count, + void const *workspace = nullptr, + int32_t tile_count = 0) + : last_row_for_problem(last_row_for_problem), + gemm_n(gemm_n), + gemm_k(gemm_k), + problem_count(problem_count), + workspace(workspace), + tile_count(tile_count) {} + }; + + Params const ¶ms; + int32_t tile_idx; + int32_t problem_tile_start; + int32_t problem_idx; + + // + // Methods + // + CUTLASS_DEVICE + BaseMoeProblemVisitor(Params const ¶ms_, int32_t block_idx) + : params(params_), + tile_idx(block_idx), + problem_tile_start(0), + problem_idx(0) {} + + /// Get the grid shape + CUTLASS_HOST_DEVICE + static cutlass::gemm::GemmCoord grid_shape( + const cutlass::gemm::GemmCoord &problem) { + return cutlass::gemm::GemmCoord( + ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), + ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), + 1); + } + + /// Gets the global tile index + CUTLASS_HOST_DEVICE + int32_t tile_index() const { return tile_idx; } + + /// Gets the index of the problem + CUTLASS_HOST_DEVICE + int32_t problem_index() const { return problem_idx; } + + CUTLASS_HOST_DEVICE + int32_t threadblock_idx() const { return tile_idx - problem_tile_start; } + + CUTLASS_DEVICE + void advance(int32_t grid_size) { tile_idx += grid_size; } + + CUTLASS_HOST_DEVICE + static void possibly_transpose_problem( + cutlass::gemm::GemmCoord &problem) { // NOLINT + ProblemSizeHelper::possibly_transpose_problem(problem); + } + + /// Returns the problem size for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size() const { + return problem_size(problem_idx); + } + + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size(int idx) const { + const int64_t prev_problem_row = + idx == 0 ? 0 : params.last_row_for_problem[idx - 1]; + const int64_t current_problem_row = params.last_row_for_problem[idx]; + const int64_t gemm_m = current_problem_row - prev_problem_row; + GemmCoord problem(GemmCoord::Index(gemm_m), + GemmCoord::Index(params.gemm_n), + GemmCoord::Index(params.gemm_k)); + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + CUTLASS_HOST_DEVICE + static int32_t tile_count(const cutlass::gemm::GemmCoord &grid) { + return ProblemSizeHelper::tile_count(grid); + } + + static int32_t group_tile_count( + const cutlass::gemm::GemmCoord *host_problem_sizes_ptr, + int32_t problem_count) { + int32_t total_tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) { + auto problem = host_problem_sizes_ptr[i]; + possibly_transpose_problem(problem); + auto grid = grid_shape(problem); + total_tiles += tile_count(grid); + } + + return total_tiles; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MoeProblemVisitor; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// ProblemVisitor that performs all scheduling on device +// +template +struct MoeProblemVisitor + : public BaseMoeProblemVisitor { + using Base = BaseMoeProblemVisitor; + using Params = typename Base::Params; + static int const kThreadCount = ThreadCount; + static bool const kRequiresPrecomputation = false; + static int const kThreadsPerWarp = 32; + + struct SharedStorage {}; + + // Final tile of the problem loaded by this thread. Each thread will hold + // a separate value. + int32_t problem_ending_tile; + + SharedStorage &shared_storage; + + // + // Methods + // + CUTLASS_DEVICE + MoeProblemVisitor(Params const ¶ms_, + SharedStorage &shared_storage_, // NOLINT + int32_t block_idx) + : Base(params_, block_idx), + problem_ending_tile(0), + shared_storage(shared_storage_) { + this->problem_idx = -1 * kThreadsPerWarp; + this->problem_tile_start = 0; + } + + CUTLASS_DEVICE + bool next_tile() { + // Check whether the tile to compute is within the range of the current + // problem. + int32_t problem_tile_end = __shfl_sync( + 0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp); + if (this->tile_idx < problem_tile_end) { + return true; + } + + // Check whether the tile to compute is within the current group of problems + // fetched by the warp. The last tile for this group is the final tile of + // the problem held by the final thread in the warp. + int32_t group_tile_end = + __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); + + // Keep the starting problem for this group in `problem_idx`. This is done + // to reduce register pressure. The starting problem for this group is + // simply the first problem in the group most recently fetched by the warp. + int32_t &group_problem_start = this->problem_idx; + group_problem_start = + (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp; + + // Keep the starting tile for this group in `problem_tile_start`. This is + // done to reduce register pressure. + int32_t &group_tile_start = this->problem_tile_start; + + // Each thread in the warp processes a separate problem to advance until + // reaching a problem whose starting tile is less less than tile_idx. + while (group_tile_end <= this->tile_idx) { + group_problem_start += kThreadsPerWarp; + if (group_problem_start > this->params.problem_count) { + return false; + } + + // Since `group_tile_start` is a reference to `this->problem_tile_start`, + // this also sets `this->problem_tile_start`. The fact that + // `this->problem_tile_start` is also set here is used later in + // `next_tile`. + group_tile_start = group_tile_end; + + int lane_idx = threadIdx.x % kThreadsPerWarp; + int32_t lane_problem = group_problem_start + lane_idx; + + // Compute the number of tiles in the problem assigned to each thread. + problem_ending_tile = 0; + if (lane_problem < this->params.problem_count) { + cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + problem_ending_tile = this->tile_count(grid); + } + + // Compute a warp-wide inclusive prefix sum to compute the ending tile + // index of each thread's problem. + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kThreadsPerWarp; i <<= 1) { + int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i); + if (lane_idx >= i) { + problem_ending_tile += val; + } + } + + // The total tile count for this group is now in the final position of the + // prefix sum + int32_t tiles_in_group = + __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); + + problem_ending_tile += group_tile_start; + group_tile_end += tiles_in_group; + } + + // The next problem to process is the first one that does not have ending + // tile position that is greater than or equal to tile index. + int32_t problem_idx_in_group = __popc( + __ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx)); + + this->problem_idx = group_problem_start + problem_idx_in_group; + + // The starting tile for this problem is the ending tile of the previous + // problem. In cases where `problem_idx_in_group` is the first problem in + // the group, we do not need to reset `problem_tile_start`, because it is + // set to the previous group's ending tile in the while loop above. + if (problem_idx_in_group > 0) { + this->problem_tile_start = __shfl_sync( + 0xffffffff, problem_ending_tile, problem_idx_in_group - 1); + } + + return true; + } + + static size_t get_workspace_size( + const cutlass::gemm::GemmCoord *host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count) { + return 0; + } + + static void host_precompute( + const cutlass::gemm::GemmCoord *host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count, + void *host_workspace_ptr) {} +}; + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct GemmMoeProblemVisitor + : public MoeProblemVisitor, + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount> { + static bool const kTransposed = Transposed; + + using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; + using Base = MoeProblemVisitor; + using Params = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + + // + // Methods + // + CUTLASS_DEVICE + GemmMoeProblemVisitor(Params const ¶ms_, + SharedStorage &shared_storage_, // NOLINT + int32_t block_idx) + : Base(params_, shared_storage_, block_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// This section exists to that we can use the same kernel code for regular gemm +// and dequantizing gemms. It will dispatch to the dequantizing gemm if the Mma +// type has an Iterator for scales in global. +template +using void_t = void; + +template +struct use_dq_gemm : platform::false_type {}; + +template +struct use_dq_gemm> + : platform::true_type {}; + +// SFINAE overload for dequantizing gemm +template < + typename Mma, + typename ElementScale, + typename platform::enable_if::value, bool>::type = true> +CUTLASS_DEVICE static void run_mma(Mma mma, + int gemm_k_iterations, + typename Mma::FragmentC &accum, // NOLINT + typename Mma::IteratorA iterator_A, + typename Mma::IteratorB iterator_B, + typename Mma::FragmentC const &src_accum, + ElementScale *weight_scale_ptr, + MatrixCoord scale_extent, + const int thread_idx, + MatrixCoord tb_offset_scale) { + typename Mma::IteratorScale iterator_scale( + Mma::IteratorScale::Layout(scale_extent.column()), + weight_scale_ptr, + scale_extent, + thread_idx, + tb_offset_scale); + + mma(gemm_k_iterations, + accum, + iterator_A, + iterator_B, + iterator_scale, + src_accum); +} + +// SFINAE overload for normal gemm. This completely ignores the scale parameters +template < + typename Mma, + typename ElementScale, + typename platform::enable_if::value, bool>::type = true> +CUTLASS_DEVICE static void run_mma(Mma mma, + int gemm_k_iterations, + typename Mma::FragmentC &accum, // NOLINT + typename Mma::IteratorA iterator_A, + typename Mma::IteratorB iterator_B, + typename Mma::FragmentC const &src_accum, + ElementScale *weight_scale_ptr, + MatrixCoord scale_extent, + const int thread_idx, + MatrixCoord tb_offset_scale) { + mma(gemm_k_iterations, accum, iterator_A, iterator_B, src_accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MoeFCGemm { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = false; + + // Optional transpose + using MapArguments = + kernel::detail::MapArguments; + + // Public-facing type definitions related to operand element type, layout, and + // complex conjugate operation. Must interact with the 'kTransposed' notion. + static_assert(!kTransposed, "Transpose problem not supported"); + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. + using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = + Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor = GemmMoeProblemVisitor; + + // + // Structures + // + + /// Argument structure + struct Arguments { + // + // Data members + // + + int problem_count; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA *ptr_A; + ElementB *ptr_B; + ElementScale *weight_scales; + ElementC *ptr_C; + ElementC *ptr_D; + + int64_t *total_rows_before_expert; + int64_t gemm_n; + int64_t gemm_k; + + // Only used by device-level operator + GemmCoord *host_problem_sizes; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : problem_count(0), + threadblock_count(0), + ptr_A(nullptr), + ptr_B(nullptr), + weight_scales(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + total_rows_before_expert(nullptr), + gemm_n(0), + gemm_k(0), + host_problem_sizes(nullptr) {} + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(int problem_count, + int threadblock_count, + typename EpilogueOutputOp::Params output_op, + const ElementA *ptr_A, + const ElementB *ptr_B, + const ElementScale *weight_scales, + const ElementC *ptr_C, + ElementC *ptr_D, + int64_t *total_rows_before_expert, + int64_t gemm_n, + int64_t gemm_k, + GemmCoord *host_problem_sizes = nullptr) + : problem_count(problem_count), + threadblock_count(threadblock_count), + output_op(output_op), + ptr_A(const_cast(ptr_A)), + ptr_B(const_cast(ptr_B)), + weight_scales(const_cast(weight_scales)), + ptr_C(const_cast(ptr_C)), + ptr_D(ptr_D), + total_rows_before_expert(total_rows_before_expert), + gemm_n(gemm_n), + gemm_k(gemm_k), + host_problem_sizes(nullptr) { + if (platform::is_same::value || + platform::is_same::value) { + assert(weight_scales); + } + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA *ptr_A; + ElementB *ptr_B; + ElementScale *weight_scales; + ElementC *ptr_C; + ElementC *ptr_D; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : ptr_A(nullptr), + ptr_B(nullptr), + weight_scales(nullptr), + ptr_C(nullptr), + ptr_D(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(Arguments const &args, + void *workspace = nullptr, + int tile_count = 0) // NOLINT + : problem_visitor(args.total_rows_before_expert, + args.gemm_n, + args.gemm_k, + args.problem_count, + workspace, + tile_count), + threadblock_count(args.threadblock_count), + output_op(args.output_op), + ptr_A(args.ptr_A), + ptr_B(args.ptr_B), + weight_scales(args.weight_scales), + ptr_C(args.ptr_C), + ptr_D(args.ptr_D) {} + + CUTLASS_HOST_DEVICE + void update(Arguments const &args, + void *workspace = nullptr, + int tile_count = 0) { + problem_visitor = + typename ProblemVisitor::Params(args.total_rows_before_expert, + args.gemm_n, + args.gemm_k, + args.problem_count, + workspace, + tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + weight_scales = args.weight_scales; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename ProblemVisitor::SharedStorage problem_visitor; + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + MoeFCGemm() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const &problem_size) { + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + if (platform::is_same::value || + platform::is_same::value) { + if (args.weight_scales == nullptr) { + CUTLASS_TRACE_HOST( + "MoeFCGemm::can_implement() - weight scales are required for " + "uint8_t and uint4b_t"); + return Status::kInvalid; + } + } else if (args.weight_scales != nullptr) { + CUTLASS_TRACE_HOST( + "MoeFCGemm::can_implement() - weight scales are ignored for all " + "types except uint8_t and uint4b_t"); + return Status::kInvalid; + } + return Status::kSuccess; + } + + static size_t get_extra_workspace_size( + Arguments const &args, cutlass::gemm::GemmCoord const &grid_tiled_shape) { + return 0; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, + SharedStorage &shared_storage) { // NOLINT + // + // These types shadow the type-level definitions and support the ability to + // implement a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + static constexpr int kInterleave = + Mma::IteratorB::Shape::kRow / Mma::SmemIteratorB::Shape::kRow; + static_assert(platform::is_same::value && + kInterleave == 1 || + platform::is_same::value && + kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // + // Problem visitor. + // + ProblemVisitor problem_visitor( + params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + const int64_t gemm_k = params.problem_visitor.gemm_k; + const int64_t gemm_n = params.problem_visitor.gemm_n; + int64_t bytes_per_expert_matrix = + (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + cutlass::gemm::GemmCoord threadblock_offset( + int(cta_idx / grid_shape.n()) * Mma::Shape::kM, // NOLINT + int(cta_idx % grid_shape.n()) * Mma::Shape::kN, // NOLINT + 0); + + // Load element pointers. Exchange pointers and strides if working on the + // transpose + const int64_t rows_to_jump = + problem_idx == 0 + ? 0 + : params.problem_visitor.last_row_for_problem[problem_idx - 1]; + ElementA *ptr_A = + reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; + typename LayoutA::LongIndex ldm_A = gemm_k; + + char *byte_ptr_B = ((char *)params.ptr_B) + // NOLINT + problem_idx * bytes_per_expert_matrix; + ElementB *ptr_B = reinterpret_cast(byte_ptr_B); + typename LayoutB::LongIndex ldm_B = + platform::is_same::value + ? gemm_n + : gemm_k * kInterleave; + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + 0, + }; + + cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; + + cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(LayoutA(ldm_A), + ptr_A, + {problem_size.m(), problem_size.k()}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + LayoutB(ldm_B), + ptr_B, + {problem_size.k() * kInterleave, problem_size.n() / kInterleave}, + thread_idx, + tb_offset_B); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Wait for all threads to finish their epilogue phases from the previous + // tile. + __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + ElementScale *weight_scale_ptr = + params.weight_scales + problem_idx * problem_size.n(); + run_mma(mma, + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators, + weight_scale_ptr, + {1, problem_size.n()}, + thread_idx, + tb_offset_scale); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + ElementC *ptr_C = + reinterpret_cast(params.ptr_C) + problem_idx * gemm_n; + ElementC *ptr_D = + reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; + + LayoutC layout_C(0); + LayoutC layout_D(gemm_n); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C(params_C, + ptr_C, + problem_size.mn(), + thread_idx, + threadblock_offset.mn()); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D(params_D, + ptr_D, + problem_size.mn(), + thread_idx, + threadblock_offset.mn()); + + Epilogue epilogue( + shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // Next tile + problem_visitor.advance(gridDim.x); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/moe_kernel.cu b/paddle/phi/kernels/fusion/cutlass/moe_kernel.cu new file mode 100644 index 00000000000..dcdb30ed9a8 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/moe_kernel.cu @@ -0,0 +1,911 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/fusion/moe_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/fusion/cutlass/moe_kernel_impl.h" + +// Ignore CUTLASS warnings about type punning +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic ignored "-Wunused-function" + +#include "cutlass/array.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/numeric_conversion.h" +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/kernels/fusion/cutlass/default_moe_fc_traits.h" +#include "paddle/phi/kernels/fusion/cutlass/linear_combination_ft_gelu.h" +#include "paddle/phi/kernels/fusion/cutlass/moe_cutlass_kernel.h" +#pragma GCC diagnostic pop +namespace phi { + +namespace { +inline int getSMVersion() { + const int device = phi::backends::gpu::GetCurrentDeviceId(); + const phi::gpuDeviceProp prop = + phi::backends::gpu::GetDeviceProperties(device); + return prop.major * 10 + prop.minor; +} + +struct EpilogueOpBiasReLU {}; + +struct EpilogueOpBiasFtGelu {}; + +struct EpilogueOpBias {}; + +struct EpilogueOpNoBias {}; + +template +struct Epilogue {}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationRelu< + ElementType, + ElementsPerVectorAccess, + ElementAccumulator, + ElementAccumulator, + cutlass::epilogue::thread::ScaleType::NoBetaScaling>; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationFtGelu< + ElementType, + ElementsPerVectorAccess, + ElementAccumulator, + ElementAccumulator, + cutlass::epilogue::thread::ScaleType::NoBetaScaling>; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombination< + ElementType, + ElementsPerVectorAccess, + ElementAccumulator, + ElementAccumulator, + cutlass::epilogue::thread::ScaleType::NoBetaScaling>; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombination< + ElementType, + ElementsPerVectorAccess, + ElementAccumulator, + ElementAccumulator, + cutlass::epilogue::thread::ScaleType::Nothing>; +}; + +} // namespace + +namespace fusion { + +template +void InitExpertChoiceRouteKernelLauncher( + int* expert_for_source_row, + int* source_row, + int* expanded_source_row_to_expanded_dest_row, + int64_t* total_rows_before_expert, + T* attr_mask, + const int num_experts, + const int num_rows, + const int k, + const int batch_size, + cudaStream_t stream) { + const int threads = 128; + const int blocks = num_experts; + + initialize_expert_choice_route_kernel<<>>( + expert_for_source_row, + source_row, + expanded_source_row_to_expanded_dest_row, + total_rows_before_expert, + attr_mask, + num_rows, + k, + batch_size); +} + +#define SOFTMAX_KERNEL(ITEMS_PER_THREAD) \ + block.x /= ITEMS_PER_THREAD; \ + assert(block.x <= 1024); \ + if (is_half2) { \ + if (grid.x % 4 == 0) { \ + grid.x /= 4; \ + softmax_kernel_v5_half2<__half, ITEMS_PER_THREAD, 4> \ + <<>>(reinterpret_cast(buffer), \ + (const half*)attr_mask, \ + batch_size, \ + head_num, \ + seq_len_1, \ + seq_len_2, \ + (const half)scalar); \ + } else { \ + softmax_kernel_v4_half2<__half, ITEMS_PER_THREAD> \ + <<>>(reinterpret_cast(buffer), \ + (const half*)attr_mask, \ + batch_size, \ + head_num, \ + seq_len_1, \ + seq_len_2, \ + (const half)scalar); \ + } \ + } else { \ + softmax_kernel_v4 \ + <<>>(buffer, \ + buffer_src, \ + attr_mask, \ + batch_size, \ + head_num, \ + seq_len_1, \ + seq_len_2, \ + scalar); \ + } + +template +void invokeMaskedSoftMax(T* buffer, + const T* buffer_src, + const T* attr_mask, + const int batch_size, + const int seq_len_1, + const int seq_len_2, + const int head_num, + const T scalar, + cudaStream_t stream) { + // NOTE: attention scores shape (batch_size, head_num, seq_len_1, seq_len_2) + dim3 grid(seq_len_1, batch_size, head_num); + if (batch_size * head_num > 360) { + grid.x = ceil(static_cast(seq_len_1) / 32.0f); + } + + bool is_half2 = sizeof(T) == 2 && sizeof(T) == 2 && seq_len_2 % 2 == 0; + dim3 block((seq_len_2 / (is_half2 ? 2 : 1) + 31) / 32 * 32); + + if (block.x > 2048 && block.x <= 4096) { + SOFTMAX_KERNEL(4) + } else if (block.x > 1024) { + SOFTMAX_KERNEL(2) + } else if (block.x > 0) { + SOFTMAX_KERNEL(1) + } else { + PADDLE_ENFORCE_EQ(true, + false, + phi::errors::InvalidArgument( + "Softmax kernel only support columns in 0 - 4096. ")); + } +} + +template +void InvokeTransposeAxis01(T* out, + T* in, + const int dim0, + const int dim1, + const int dim2, + cudaStream_t stream) { + dim3 block(512); + dim3 grid(static_cast(ceil(dim0 * dim1 * dim2 / 512.))); + transposeAxis01<<>>(out, in, dim0, dim1, dim2); +} + +template +void InvokePadding(T* output1, + int* output2, + const T* input1, + const int* input2, + const int* input_lengths, + const int num_tokens, + const int batch_size, + const int max_seq_len, + const int num_experts, + cudaStream_t stream) { + assert(max_seq_len <= 1024); + dim3 block(max_seq_len); + dim3 grid(num_experts); + paddingKernel<<>>(output1, + output2, + input1, + input2, + input_lengths, + num_tokens, + batch_size, + max_seq_len, + num_experts); +} + +template +void InvokeGeneralTopKPairSort(T* out_keys, + int* out_values, + T* in_keys, + int* in_values, + const int m, + const int n, + cudaStream_t stream) { + assert(n <= 4096); + const int blocks = m; + + if (n == 128) { + general_topk_pair_sort + <<>>(out_keys, out_values, in_keys, in_values); + } + if (n == 256) { + general_topk_pair_sort + <<>>(out_keys, out_values, in_keys, in_values); + } + if (n == 1024) { + general_topk_pair_sort + <<>>(out_keys, out_values, in_keys, in_values); + } else if (n == 2048) { + general_topk_pair_sort + <<>>(out_keys, out_values, in_keys, in_values); + } else if (n == 4096) { + general_topk_pair_sort + <<>>(out_keys, out_values, in_keys, in_values); + } +} + +template +void InitMoeRoutingKernelLauncher( + const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int num_experts, + const int num_rows, + const int active_rows, + const int cols, + const int k, + const int batch_size, + const int max_seq_len, + bool ec_route, + cudaStream_t stream) { + const int blocks = ec_route ? num_experts * k * batch_size : num_rows * k; + if (ec_route) { + constexpr int max_pack_size = 16 / sizeof(T); + const int threads = std::min(cols / max_pack_size, 1024); + if (cols % max_pack_size == 0) { + initialize_moe_routing_kernel + <<>>( + unpermuted_input, + permuted_output, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, + num_rows, + batch_size * k * num_experts, + cols, + k, + max_seq_len, + ec_route); + } else { + initialize_moe_routing_kernel<<>>( + unpermuted_input, + permuted_output, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, + num_rows, + batch_size * k * num_experts, + cols, + k, + max_seq_len, + ec_route); + } + } else { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "Currently only support `ec_route = True`. ")); + } +} + +template +void GenericMoeGemmKernelLauncher(const T* A, + const T* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + const int multi_processor_count, + cudaStream_t stream) { + static_assert(cutlass::platform::is_same::value || + cutlass::platform::is_same::value, + "Specialized for half, float"); + static_assert( + cutlass::platform::is_same::value || + cutlass::platform::is_same::value || + cutlass::platform::is_same::value, + "cutlass weight type only support float, half, uint8_t, uint4b_t"); + // The cutlass type for the input elements. This is needed to convert to + // cutlass::half_t if necessary. + using ElementType_ = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::half_t, + T>::type; + using ElementType = ElementType_; + using CutlassWeightType_ = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::half_t, + WeightType>::type; + using CutlassWeightType = CutlassWeightType_; + + // We need separate config for each architecture since we will target + // different tensorcore instructions. For float, we do not target TCs. + using MoeArchTraits = cutlass::gemm::kernel:: + MoeArchTraits; + using ElementAccumulator = typename MoeArchTraits::AccType; + using EpilogueOp = typename Epilogue::Op; + + // Finally, set up the kernel. + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped< + ElementType, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + MoeArchTraits::ElementsPerAccessA, + CutlassWeightType, + typename MoeArchTraits::LayoutB, + cutlass::ComplexTransform::kNone, + MoeArchTraits::ElementsPerAccessB, + ElementType, + cutlass::layout::RowMajor, + ElementAccumulator, + typename MoeArchTraits::OperatorClass, + arch, + typename MoeArchTraits::ThreadBlockShape, + typename MoeArchTraits::WarpShape, + typename MoeArchTraits::InstructionShape, + EpilogueOp, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + MoeArchTraits::Stages, + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + typename MoeArchTraits::Operator>::GemmKernel; + + using GemmKernel = + cutlass::gemm::kernel::MoeFCGemm; + using GemmGrouped = cutlass::gemm::device::GemmGrouped; + + int occupancy = GemmGrouped::maximum_active_blocks(); + const int threadblock_count = multi_processor_count * occupancy; + if (occupancy == 0) { + PADDLE_THROW(paddle::platform::errors::Fatal( + "[MoE Runner] GPU lacks the shared memory resources to run GroupedGEMM " + "kernel")); + } + + typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f), + ElementAccumulator(1.f)); + typename GemmGrouped::Arguments args( + num_experts, + threadblock_count, + epilogue_op, + reinterpret_cast(A), + reinterpret_cast(B), + reinterpret_cast(weight_scales), + reinterpret_cast(biases), + reinterpret_cast(C), + total_rows_before_expert, + gemm_n, + gemm_k); + GemmGrouped gemm; + auto can_implement = gemm.can_implement(args); + if (can_implement != cutlass::Status::kSuccess) { + std::string err_msg = "MoEFC kernel will fail for params. Error: " + + std::string(cutlassGetStatusString(can_implement)); + PADDLE_THROW(paddle::platform::errors::Fatal("[MoE Runner] " + err_msg)); + } + auto init_status = gemm.initialize(args); + if (init_status != cutlass::Status::kSuccess) { + std::string err_msg = + "Failed to initialize cutlass variable batched gemm. Error: " + + std::string(cutlassGetStatusString(init_status)); + PADDLE_THROW(paddle::platform::errors::Fatal("[MoE Runner] " + err_msg)); + } + auto run_status = gemm.run(stream); + if (run_status != cutlass::Status::kSuccess) { + std::string err_msg = + "Failed to run cutlass variable batched gemm. Error: " + + std::string(cutlassGetStatusString(run_status)); + PADDLE_THROW(paddle::platform::errors::Fatal("[MoE Runner] " + err_msg)); + } +} + +template +void gemm_bias_act(const T* A, + const T* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + int sm, + int multi_processor_count, + const std::string& act_type, + cudaStream_t stream) { + if (act_type == "gelu") { + if (sm == 75) { + GenericMoeGemmKernelLauncher( + A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + multi_processor_count, + stream); + } else if (sm == 80 || sm == 86) { + GenericMoeGemmKernelLauncher( + A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + multi_processor_count, + stream); + } else { + GenericMoeGemmKernelLauncher( + A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + multi_processor_count, + stream); + } + } else { + // act type is relu. + if (sm == 75) { + GenericMoeGemmKernelLauncher(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + multi_processor_count, + stream); + } else if (sm == 80 || sm == 86) { + GenericMoeGemmKernelLauncher(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + multi_processor_count, + stream); + } else { + GenericMoeGemmKernelLauncher(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + multi_processor_count, + stream); + } + } +} + +template +void gemm(const T* A, + const T* B, + const T* weight_scales, + T* C, + int64_t* total_rows_before_expert, + const int gemm_n, + const int gemm_k, + const int num_experts, + int sm, + int multi_processor_count, + cudaStream_t stream) { + if (sm == 75) { + GenericMoeGemmKernelLauncher( + A, + B, + weight_scales, + nullptr, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + multi_processor_count, + stream); + } else if (sm == 80 || sm == 86) { + GenericMoeGemmKernelLauncher( + A, + B, + weight_scales, + nullptr, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + multi_processor_count, + stream); + } else { + GenericMoeGemmKernelLauncher( + A, + B, + weight_scales, + nullptr, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + multi_processor_count, + stream); + } +} + +template +void finalize_moe_routing_kernelLauncher( + const T* expanded_permuted_rows, + T* reduced_unpermuted_output, + const T* skip, + const T* bias, + const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, + const int num_experts, + const int num_rows, + const int cols, + const int k, + bool ec_route, + cudaStream_t stream) { + const int blocks = num_rows; + const int threads = std::min(cols, 1024); + { + finalize_moe_routing_kernel<<>>( + expanded_permuted_rows, + reduced_unpermuted_output, + skip, + bias, + scales, + expanded_source_row_to_expanded_dest_row, + expert_for_source_row, + cols, + num_experts, + ec_route); + } +} + +template +void MoeKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& gate, + const DenseTensor& bmm0, + const DenseTensor& bias0, + const DenseTensor& bmm1, + const DenseTensor& bias1, + const std::string& act_type, + DenseTensor* output) { + const T* input_activations = x.data(); + T* gating_output = const_cast(gate.data()); + const T* fc1_expert_weights = bmm0.data(); + const T* fc1_expert_biases = bias0.data(); + const T* fc2_expert_weights = bmm1.data(); + const T* fc2_expert_biases = bias1.data(); + // int moe_act = static_cast(act); + T* output_ = ctx.template Alloc(output); + auto stream = ctx.stream(); + + auto input_dims = x.dims(); + auto bmm0_dims = bmm0.dims(); + const bool IS_FP16 = std::is_same::value; + + const int num_rows = input_dims[0] * input_dims[1]; + const int hidden_size = input_dims[2]; + const int inter_size = bmm0_dims[2]; + const int num_experts = bmm0_dims[0]; + const int k = input_dims[1] / 16; + const int batch_size = input_dims[0]; + const int max_seq_len = 128; + int64_t bytes = getWorkspaceSize(num_rows, + hidden_size, + inter_size, + num_experts, + k, + batch_size, + max_seq_len); + + // Pointers + int* source_rows; + int* padded_source_rows; + int* permuted_rows; + int* permuted_experts; + char* sorter_ws_; + T* permuted_data; + T* padded_expert_scales; + int64_t* total_rows_before_expert; + T* sorted_softmax_output; + T* attr_mask; + T* fc1_result; + + phi::DenseTensor ws_ptr_tensor = phi::Empty(ctx, {bytes}); + int8_t* ws_ptr = ws_ptr_tensor.data(); + + const int buf_size = AlignTo16(num_experts * batch_size * k * hidden_size); + const int padded_experts = AlignTo16(num_experts); + const int num_moe_inputs = AlignTo16(num_experts * num_rows); + // padded_num_moe_inputs for topk sort + int padded_num_moe_inputs = num_experts * batch_size * max_seq_len; + + source_rows = reinterpret_cast(ws_ptr); + padded_source_rows = source_rows + num_moe_inputs; + permuted_rows = padded_source_rows + padded_num_moe_inputs; + permuted_experts = permuted_rows + padded_num_moe_inputs; + permuted_data = reinterpret_cast(permuted_experts + num_experts * k); + padded_expert_scales = reinterpret_cast(permuted_data + buf_size); + total_rows_before_expert = + reinterpret_cast(padded_expert_scales + padded_num_moe_inputs); + sorted_softmax_output = + reinterpret_cast(total_rows_before_expert + padded_experts); + attr_mask = + reinterpret_cast(sorted_softmax_output + padded_num_moe_inputs); + fc1_result = reinterpret_cast(attr_mask + num_moe_inputs); + + phi::DenseTensor expert_for_source_row_tensor = + phi::Empty(ctx, {num_experts, num_rows}); + int* expert_for_source_row = expert_for_source_row_tensor.data(); + phi::DenseTensor expanded_source_row_to_expanded_dest_row_tensor = + phi::Empty(ctx, {num_experts, num_rows}); + int* expanded_source_row_to_expanded_dest_row = + expanded_source_row_to_expanded_dest_row_tensor.data(); + phi::DenseTensor expert_scales_tensor = + phi::Empty(ctx, {num_experts, num_rows}); + T* expert_scales = expert_scales_tensor.data(); + phi::DenseTensor fc2_output_tensor = + phi::Empty(ctx, {num_experts * batch_size * k, hidden_size}); + T* fc2_result = fc2_output_tensor.data(); + phi::DenseTensor input_lengths_tensor = phi::Empty(ctx, {batch_size}); + int* input_lengths = input_lengths_tensor.data(); + funcs::SetConstant set_len; + set_len(ctx, &input_lengths_tensor, static_cast(max_seq_len)); + + int sm = getSMVersion(); + int multi_processor_count = phi::backends::gpu::GetGPUMultiProcessors( + phi::backends::gpu::GetCurrentDeviceId()); + + InitExpertChoiceRouteKernelLauncher( + expert_for_source_row, + source_rows, + expanded_source_row_to_expanded_dest_row, + total_rows_before_expert, + attr_mask, + num_experts, + num_rows, + k, + batch_size, + ctx.stream()); + T scalar = (T)1.0f; + if (IS_FP16) { + invokeMaskedSoftMax<__half>(reinterpret_cast<__half*>(gating_output), + reinterpret_cast(gating_output), + reinterpret_cast(attr_mask), + /*batch_size=*/num_rows, + /*seq_len_1=*/1, + /*seq_len_2=*/num_experts, + /*head_num=*/1, + *reinterpret_cast(&scalar), + ctx.stream()); + } else { + invokeMaskedSoftMax(reinterpret_cast(gating_output), + reinterpret_cast(gating_output), + reinterpret_cast(attr_mask), + /*batch_size=*/num_rows, + /*seq_len_1=*/1, + /*seq_len_2=*/num_experts, + /*head_num=*/1, + *reinterpret_cast(&scalar), + ctx.stream()); + } + InvokeTransposeAxis01( + expert_scales, gating_output, num_rows, num_experts, 1, ctx.stream()); + + int padded_max_seq_len = max_seq_len <= 128 ? 128 : 256; + InvokePadding(padded_expert_scales, + padded_source_rows, + expert_scales, + source_rows, + input_lengths, + num_rows, + batch_size, + padded_max_seq_len, + num_experts, + ctx.stream()); + if (IS_FP16) { + InvokeGeneralTopKPairSort<__half>( + reinterpret_cast<__half*>(sorted_softmax_output), + permuted_rows, + reinterpret_cast<__half*>(padded_expert_scales), + padded_source_rows, + num_experts * batch_size, + padded_max_seq_len, + ctx.stream()); + } else { + InvokeGeneralTopKPairSort( + reinterpret_cast(sorted_softmax_output), + permuted_rows, + reinterpret_cast(padded_expert_scales), + padded_source_rows, + num_experts * batch_size, + padded_max_seq_len, + ctx.stream()); + } + InitMoeRoutingKernelLauncher(input_activations, + permuted_data, + permuted_rows, + expanded_source_row_to_expanded_dest_row, + num_experts, + num_rows, + num_rows, + hidden_size, + k, + batch_size, + max_seq_len, + true, + ctx.stream()); + + const T* fc1_scales = nullptr; + const T* fc2_scales = nullptr; + if (IS_FP16) { + gemm_bias_act(reinterpret_cast(permuted_data), + reinterpret_cast(fc1_expert_weights), + reinterpret_cast(fc1_scales), + reinterpret_cast(fc1_expert_biases), + reinterpret_cast<__half*>(fc1_result), + total_rows_before_expert, + inter_size, + hidden_size, + num_experts, + sm, + multi_processor_count, + act_type, + ctx.stream()); + gemm(reinterpret_cast(fc1_result), + reinterpret_cast(fc2_expert_weights), + reinterpret_cast(fc2_scales), + reinterpret_cast<__half*>(fc2_result), + total_rows_before_expert, + hidden_size, + inter_size, + num_experts, + sm, + multi_processor_count, + ctx.stream()); + } else { + gemm_bias_act(reinterpret_cast(permuted_data), + reinterpret_cast(fc1_expert_weights), + reinterpret_cast(fc1_scales), + reinterpret_cast(fc1_expert_biases), + reinterpret_cast(fc1_result), + total_rows_before_expert, + inter_size, + hidden_size, + num_experts, + sm, + multi_processor_count, + act_type, + ctx.stream()); + gemm(reinterpret_cast(fc1_result), + reinterpret_cast(fc2_expert_weights), + reinterpret_cast(fc2_scales), + reinterpret_cast(fc2_result), + total_rows_before_expert, + hidden_size, + inter_size, + num_experts, + sm, + multi_processor_count, + ctx.stream()); + } + + finalize_moe_routing_kernelLauncher(fc2_result, + output_, + input_activations, + fc2_expert_biases, + expert_scales, + expanded_source_row_to_expanded_dest_row, + expert_for_source_row, + num_experts, + num_rows, + hidden_size, + k, + true, + ctx.stream()); +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL( + moe, GPU, ALL_LAYOUT, phi::fusion::MoeKernel, float, phi::dtype::float16) {} diff --git a/paddle/phi/kernels/fusion/cutlass/moe_kernel_impl.h b/paddle/phi/kernels/fusion/cutlass/moe_kernel_impl.h new file mode 100644 index 00000000000..2d8fbd68fda --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/moe_kernel_impl.h @@ -0,0 +1,779 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "cub/cub.cuh" +#include "paddle/phi/kernels/funcs/math_cuda_utils.h" + +namespace phi { + +static const float HALF_FLT_MAX = 65504.F; +static const float HALF_FLT_MIN = -65504.F; +static inline size_t AlignTo16(const size_t& input) { + static constexpr int ALIGNMENT = 16; + return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); +} + +/* +WarpReduce multi values. +TODO(zhengzekang): Add blocksize templates to reduce shared memory usage. +*/ +template +__inline__ __device__ T warpReduceSumV2(T* val) { +#pragma unroll + for (int i = 0; i < NUM; i++) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); + } + return (T)(0.0f); +} + +template +__inline__ __device__ T blockReduceSumV2(T* val) { + static __shared__ T shared[NUM][33]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduceSumV2(val); + + if (lane == 0) { +#pragma unroll + for (int i = 0; i < NUM; i++) { + shared[i][wid] = val[i]; + } + } + + __syncthreads(); + + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) { + val[i] = is_mask ? shared[i][lane] : (T)(0.0f); + } + warpReduceSumV2(val); + return (T)0.0f; +} + +template +__inline__ __device__ T warpReduceMaxV2(T* val) { +#pragma unroll + for (int i = 0; i < NUM; i++) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32)); + } + return (T)(0.0f); +} + +template +__inline__ __device__ T blockReduceMaxV2(T* val) { + static __shared__ T shared[32][NUM]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + + warpReduceMaxV2(val); // get maxx in each warp + + if (lane == 0) { // record in-warp maxx by warp Idx +#pragma unroll + for (int i = 0; i < NUM; i++) { + shared[wid][i] = val[i]; + } + } + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) { + val[i] = is_mask ? shared[lane][i] : (T)-1e20f; + } + warpReduceMaxV2(val); + + return (T)0.0f; +} + +class CubKeyValueSorter { + public: + CubKeyValueSorter(); + + explicit CubKeyValueSorter(const int num_experts); + + void update_num_experts(const int num_experts); + + size_t getWorkspaceSize(const size_t num_key_value_pairs, + bool descending = false); + + template + void run(void* workspace, + const size_t workspace_size, + const KeyT* keys_in, + KeyT* keys_out, + const int* values_in, + int* values_out, + const size_t num_key_value_pairs, + bool descending, + cudaStream_t stream); + + private: + size_t num_key_value_pairs_; + int num_experts_; + int num_bits_; +}; + +// ===== CUB Sorting things ===== +CubKeyValueSorter::CubKeyValueSorter() + : num_experts_(0), num_bits_(sizeof(int) * 8) {} + +CubKeyValueSorter::CubKeyValueSorter(const int num_experts) + : num_experts_(num_experts), + num_bits_(static_cast(log2(num_experts)) + 1) {} + +void CubKeyValueSorter::update_num_experts(const int num_experts) { + num_experts_ = num_experts; + num_bits_ = static_cast(log2(num_experts)) + 1; +} + +size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs, + bool descending) { + num_key_value_pairs_ = num_key_value_pairs; + size_t required_storage = 0; + int* null_int = nullptr; + if (descending) { + cub::DeviceRadixSort::SortPairsDescending(NULL, + required_storage, + null_int, + null_int, + null_int, + null_int, + num_key_value_pairs, + 0, + 32); + } else { + cub::DeviceRadixSort::SortPairs(NULL, + required_storage, + null_int, + null_int, + null_int, + null_int, + num_key_value_pairs, + 0, + num_bits_); + } + return required_storage; +} + +template +void CubKeyValueSorter::run(void* workspace, + const size_t workspace_size, + const KeyT* keys_in, + KeyT* keys_out, + const int* values_in, + int* values_out, + const size_t num_key_value_pairs, + bool descending, + cudaStream_t stream) { + size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs); + size_t actual_ws_size = workspace_size; + + if (expected_ws_size > workspace_size) { + std::stringstream err_ss; + err_ss << "[Error][CubKeyValueSorter::run]\n"; + err_ss + << "Error. The allocated workspace is too small to run this problem.\n"; + err_ss << "Expected workspace size of at least " << expected_ws_size + << " but got problem size " << workspace_size << "\n"; + throw std::runtime_error(err_ss.str()); + } + if (descending) { + cub::DeviceRadixSort::SortPairsDescending(workspace, + actual_ws_size, + keys_in, + keys_out, + values_in, + values_out, + num_key_value_pairs, + 0, + 32, + stream); + } else { + cub::DeviceRadixSort::SortPairs(workspace, + actual_ws_size, + keys_in, + keys_out, + values_in, + values_out, + num_key_value_pairs, + 0, + num_bits_, + stream); + } +} + +template <> +void CubKeyValueSorter::run(void* workspace, + const size_t workspace_size, + const __nv_bfloat16* keys_in, + __nv_bfloat16* keys_out, + const int* values_in, + int* values_out, + const size_t num_key_value_pairs, + bool descending, + cudaStream_t stream) {} + +CubKeyValueSorter sorter_; + +// -------- getWorkspaceSize -------- // +template +size_t getWorkspaceSize(const int num_rows, + const int hidden_size, + const int inter_size, + const int num_experts, + const int k, + const int batch_size, + const int max_seq_len) { + const int buf_size = AlignTo16(num_experts * batch_size * k * hidden_size); + const int interbuf_size = + AlignTo16(num_experts * batch_size * k * inter_size); + const int padded_experts = AlignTo16(num_experts); + const int num_moe_inputs = AlignTo16(num_experts * num_rows); + int padded_num_moe_inputs = num_experts * batch_size * max_seq_len; + + size_t total_ws_bytes = sizeof(int) * num_moe_inputs; // source_rows_ + total_ws_bytes += sizeof(int) * padded_num_moe_inputs; // padded_source_rows_ + total_ws_bytes += sizeof(T) * padded_num_moe_inputs; // padded_expert_scales_ + total_ws_bytes += sizeof(int) * padded_num_moe_inputs; // permuted_rows_ + total_ws_bytes += sizeof(int) * num_experts * k; // permuted_experts_ + total_ws_bytes += buf_size * sizeof(T); // permuted_data_ + total_ws_bytes += + padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ + + total_ws_bytes += sizeof(T) * num_moe_inputs; // attr_mask: [e, n] + total_ws_bytes += sizeof(T) * padded_num_moe_inputs; // sorted_softmax_output + + const int bytes_for_fc1_result = interbuf_size * sizeof(T); + const int sorter_ws_size_bytes = + AlignTo16(sorter_.getWorkspaceSize(num_experts * k)); + sorter_.update_num_experts(k); + + int bytes_for_intermediate_and_sorting = bytes_for_fc1_result; + if (sorter_ws_size_bytes > bytes_for_fc1_result) { + int remaining_bytes = + AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result); + bytes_for_intermediate_and_sorting += remaining_bytes; + } + + total_ws_bytes += + bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub + // sorting workspace + return total_ws_bytes; +} + +// -------- initialize_expert_choice_route_kernel -------- // +template +__global__ void initialize_expert_choice_route_kernel( + int* expert_for_source_row, + int* source_row, + int* expanded_source_row_to_expanded_dest_row, + int64_t* total_rows_before_expert, + T* attr_mask, + const int cols, + const int k, + const int batch_size) { + int start = cols * blockIdx.x; + + for (int i = threadIdx.x; i < cols; i += blockDim.x) { + expert_for_source_row[start + i] = blockIdx.x; + source_row[start + i] = start + i; + expanded_source_row_to_expanded_dest_row[start + i] = -1; + attr_mask[start + i] = (T)1.0f; + } + if (threadIdx.x == 0) { + total_rows_before_expert[blockIdx.x] = batch_size * k * (blockIdx.x + 1); + } +} + +// -------- softmax_kernel -------- // +template +__global__ void softmax_kernel_v4( + T* qk_buf_, + const T* qk_buf_src, // shape [batch_size, head_num, seq_len_1, seq_len_2] + const T* attr_mask, // shape [batch_size, seq_len_1, seq_len_2] + const int batch_size, + const int head_num, + const int seq_len_1, + const int seq_len_2, + const T scalar) { + for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x) { + float data[ITEMS_PER_THREAD]; + int qk_offset; + __shared__ float s_mean, s_max; + float local_max = -1e20f; + for (int i = 0; blockDim.x * i + threadIdx.x < seq_len_2; i++) { + qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) * + seq_len_2 + + blockDim.x * i + threadIdx.x; + int mask_offset = (blockIdx.y * seq_len_1 + seq_id) * seq_len_2 + + blockDim.x * i + threadIdx.x; + + float qk = static_cast(qk_buf_src[qk_offset]); + float mask_val = static_cast(__ldg(&attr_mask[mask_offset])); + + mask_val = (1.0f - mask_val) * -10000.0f; + + data[i] = qk * static_cast(scalar) + mask_val; + local_max = fmax(local_max, data[i]); + } + + float max_val = + blockDim.x <= 32 + ? phi::funcs::warpReduceMax(local_max, 0xFFFFFFFF) + : phi::funcs::blockReduceMax(local_max, 0xffffffff); + if (threadIdx.x == 0) { + s_max = max_val; + } + __syncthreads(); + + float local_sum = 0; + for (int i = 0; blockDim.x * i + threadIdx.x < seq_len_2; i++) { + data[i] = __expf(data[i] - s_max); + local_sum += data[i]; + } + float sum_val = + blockDim.x <= 32 + ? phi::funcs::warpReduceSum(local_sum, 0xffffffff) + : phi::funcs::blockReduceSum(local_sum, 0xffffffff); + if (threadIdx.x == 0) { + s_mean = sum_val + 1e-6f; + s_mean = __fdividef(1.0f, s_mean); + } + __syncthreads(); + + for (int i = 0; blockDim.x * i + threadIdx.x < seq_len_2; i++) { + qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) * + seq_len_2 + + blockDim.x * i + threadIdx.x; + qk_buf_[qk_offset] = (T)(data[i] * s_mean); + } + } +} + +template +__global__ void softmax_kernel_v4_half2(T* qk_buf_, + const T* attr_mask, + const int batch_size, + const int head_num, + const int seq_len_1, + const int seq_len_2, + const T scalar) { + using T2 = half2; + T2* qk_buf_half2 = reinterpret_cast(qk_buf_); + const T2* attr_mask_half2 = (const T2*)attr_mask; + + for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x) { + T2 data[ITEMS_PER_THREAD]; + int qk_offset; + __shared__ float s_mean, s_max; + float local_max = -1e20f; + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD; + i++) { + qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) * + (seq_len_2 / 2) + + blockDim.x * i + threadIdx.x; + int mask_offset = (blockIdx.y * seq_len_1 + seq_id) * (seq_len_2 / 2) + + blockDim.x * i + threadIdx.x; + + T2 qk = qk_buf_half2[qk_offset]; + T2 mask_val = __ldg(&attr_mask_half2[mask_offset]); + mask_val = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val), + __float2half2_rn(-10000.0f)); + + data[i] = __hadd2(__hmul2(qk, __half2half2(scalar)), mask_val); + + local_max = fmax( + local_max, + fmax(static_cast(data[i].x), static_cast(data[i].y))); + } + + float max_val = + blockDim.x <= 32 + ? phi::funcs::warpReduceMax(local_max, 0xFFFFFFFF) + : phi::funcs::blockReduceMax(local_max, 0xFFFFFFFF); + if (threadIdx.x == 0) { + s_max = max_val; + } + __syncthreads(); + + float local_sum = 0; + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD; + i++) { + data[i] = h2exp(__hsub2(data[i], __float2half2_rn(s_max))); + local_sum += static_cast(data[i].x + data[i].y); + } + + float sum_val = + blockDim.x <= 32 + ? phi::funcs::warpReduceSum(local_sum, 0xFFFFFFFF) + : phi::funcs::blockReduceSum(local_sum, 0xFFFFFFFF); + + if (threadIdx.x == 0) { + s_mean = sum_val + 1e-6f; + s_mean = __fdividef(1.0f, s_mean); + } + __syncthreads(); + + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD; + i++) { + qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) * + (seq_len_2 / 2) + + blockDim.x * i + threadIdx.x; + qk_buf_half2[qk_offset] = __hmul2(data[i], __float2half2_rn(s_mean)); + } + } +} + +template +__global__ void softmax_kernel_v5_half2(T* qk_buf_, + const T* attr_mask, + const int batch_size, + const int head_num, + const int seq_len_1, + const int seq_len_2, + const T scalar) { + using T2 = half2; + T2* qk_buf_half2 = reinterpret_cast(qk_buf_); + const T2* attr_mask_half2 = (const T2*)attr_mask; + + for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x * NUM) { + T2 data[NUM][ITEMS_PER_THREAD]; + + int qk_offset[NUM]; + + __shared__ float s_sum[NUM], s_max[NUM]; + float local_max[NUM]; +#pragma unroll + for (int j = 0; j < NUM; j++) { + local_max[j] = -1e20f; + } + + const int MAX_NUM = + min((seq_len_1 - seq_id + gridDim.x - 1) / gridDim.x, NUM); + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD; + i++) { + int mask_offset[NUM]; +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + + seq_id + j * gridDim.x) * + (seq_len_2 / 2) + + blockDim.x * i + threadIdx.x; + mask_offset[j] = (blockIdx.y * seq_len_1 + seq_id + j * gridDim.x) * + (seq_len_2 / 2) + + blockDim.x * i + threadIdx.x; + } + + T2 mask_val[NUM]; +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + mask_val[j] = __ldg(&attr_mask_half2[mask_offset[j]]); + } + + T2 qk[NUM]; +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + qk[j] = qk_buf_half2[qk_offset[j]]; + } +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + mask_val[j] = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val[j]), + __float2half2_rn(-10000.0f)); + } +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + data[j][i] = __hadd2(__hmul2(qk[j], __half2half2(scalar)), mask_val[j]); + local_max[j] = fmax(local_max[j], + fmax(static_cast(data[j][i].x), + static_cast(data[j][i].y))); + } + } + if (blockDim.x <= 32) { + warpReduceMaxV2(local_max); + } else { + blockReduceMaxV2(local_max); + } + + if (threadIdx.x == 0) { +#pragma unroll + for (int j = 0; j < NUM; j++) { + s_max[j] = local_max[j]; + } + } + __syncthreads(); + float local_sum[NUM]; +#pragma unroll + for (int j = 0; j < NUM; j++) { + local_sum[j] = {0.f}; + } + + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD; + i++) { +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + data[j][i] = h2exp(__hsub2(data[j][i], __float2half2_rn(s_max[j]))); + } + +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + local_sum[j] += static_cast(data[j][i].x + data[j][i].y); + } + } + + if (blockDim.x <= 32) { + warpReduceSumV2(local_sum); + } else { + blockReduceSumV2(local_sum); + } + + if (threadIdx.x == 0) { +#pragma unroll + for (int j = 0; j < NUM; j++) { + s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f); + } + } + __syncthreads(); + + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD; + i++) { +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + + seq_id + j * gridDim.x) * + (seq_len_2 / 2) + + blockDim.x * i + threadIdx.x; + } + +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + qk_buf_half2[qk_offset[j]] = + __hmul2(data[j][i], __float2half2_rn(s_sum[j])); + } + } + } +} + +// -------- transpose_kernel -------- // +template +__global__ void transposeAxis01( + T* out, T* in, const int dim0, const int dim1, const int dim2) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < dim0 * dim1 * dim2) { + const int input_dim2_index = index % dim2; + index = (index - input_dim2_index) / dim2; + const int input_dim1_index = index % dim1; + index = (index - input_dim1_index) / dim1; + const int input_dim0_index = index % dim0; + + out[input_dim1_index * dim0 * dim2 + input_dim0_index * dim2 + + input_dim2_index] = in[input_dim0_index * dim1 * dim2 + + input_dim1_index * dim2 + input_dim2_index]; + } +} + +// -------- padding_kernel -------- // +template +__global__ void paddingKernel(T* output1, + int* output2, + const T* input1, + const int* input2, + const int* input_lengths, + const int num_tokens, + const int batch_size, + const int max_seq_len, + const int num_experts) { + const bool IS_FP16 = std::is_same::value; + const T MIN_T_VAL = (IS_FP16) ? (T)HALF_FLT_MIN : (T)FLT_MIN; + int offset1 = blockIdx.x * num_tokens; + int offset2 = blockIdx.x * batch_size * max_seq_len; + for (int i = 0; i < batch_size; i++) { + const T* in1_ptr = input1 + offset1; + const int* in2_ptr = input2 + offset1; + int input_length = input_lengths[i]; + offset1 += input_length; + + T* out1_ptr = output1 + offset2; + int* out2_ptr = output2 + offset2; + offset2 += max_seq_len; + + for (int j = threadIdx.x; j < max_seq_len; j += max_seq_len) { + if (j < input_length) { + out1_ptr[j] = in1_ptr[j]; + out2_ptr[j] = in2_ptr[j]; + } else { + out1_ptr[j] = MIN_T_VAL; + out2_ptr[j] = 0; + } + } + } +} + +// -------- general_topk_pair_sort_kernel -------- // +template +__global__ void general_topk_pair_sort(T* out_keys, + int* out_values, + T* in_keys, + int* in_values) { + typedef cub::BlockRadixSort + BlockRadixSort; + typedef cub:: + BlockLoad + BlockLoadKey; + typedef cub:: + BlockLoad + BlockLoadValue; + typedef cub:: + BlockStore + BlockStoreKey; + typedef cub::BlockStore + BlockStoreValue; + + __shared__ union { + typename BlockRadixSort::TempStorage sort; + typename BlockLoadKey::TempStorage loadkey; + typename BlockLoadValue::TempStorage loadvalue; + typename BlockStoreKey::TempStorage storekey; + typename BlockStoreValue::TempStorage storevalue; + } temp_storage; + + int block_offset = blockIdx.x * BLOCK_THREADS * ITEMS_PER_THREAD; + + T thread_keys[ITEMS_PER_THREAD]; + int thread_values[ITEMS_PER_THREAD]; + BlockLoadKey(temp_storage.loadkey).Load(in_keys + block_offset, thread_keys); + BlockLoadValue(temp_storage.loadvalue) + .Load(in_values + block_offset, thread_values); + __syncthreads(); + + BlockRadixSort(temp_storage.sort).SortDescending(thread_keys, thread_values); + __syncthreads(); + + BlockStoreKey(temp_storage.storekey) + .Store(out_keys + block_offset, thread_keys); + BlockStoreValue(temp_storage.storevalue) + .Store(out_values + block_offset, thread_values); +} + +// -------- finalize_moe_routing_kernel -------- // +template +__global__ void finalize_moe_routing_kernel( + const T* expanded_permuted_rows, + T* reduced_unpermuted_output, + const T* skip, + const T* bias, + const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, + const int cols, + const int k, + bool ec_route) { + const int original_row = blockIdx.x; + const int num_rows = gridDim.x; + T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols; + const T* skip_row_ptr = skip + original_row * cols; + + for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { + T thread_output = skip_row_ptr[tid]; + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int expanded_original_row = original_row + k_idx * num_rows; + const int expanded_permuted_row = + expanded_source_row_to_expanded_dest_row[expanded_original_row]; + + if (ec_route && expanded_permuted_row == -1) continue; + const int64_t k_offset = + ec_route ? expanded_original_row : original_row * k + k_idx; + const T row_scale = scales[k_offset]; + const T* expanded_permuted_rows_row_ptr = + expanded_permuted_rows + expanded_permuted_row * cols; + + const int expert_idx = ec_route ? k_idx : expert_for_source_row[k_offset]; + const T* bias_ptr = bias + expert_idx * cols; + + thread_output = + thread_output + + row_scale * (expanded_permuted_rows_row_ptr[tid] + bias_ptr[tid]); + } + reduced_row_ptr[tid] = thread_output; + } +} + +// -------- initialize_moe_routing_kernel -------- // +template +__global__ void initialize_moe_routing_kernel( + const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int num_rows, + const int active_rows, + const int cols, + const int k, + const int max_seq_len, + bool ec_route) { + using LoadT = phi::AlignedVector; + LoadT src_vec; + + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way + // reduction and unpermuting. I need the reverse map for that reduction to + // allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 + // thread block will be responsible for all k summations. + const int expanded_dest_row = blockIdx.x; + const int expanded_source_row = + ec_route ? expanded_dest_row_to_expanded_source_row[expanded_dest_row / + k * max_seq_len + + expanded_dest_row % k] + : expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + if (threadIdx.x == 0) { + expanded_source_row_to_expanded_dest_row[expanded_source_row] = + expanded_dest_row; + } + + if (blockIdx.x < active_rows) { + // Duplicate and permute rows + const int source_row = expanded_source_row % num_rows; + + const T* source_row_ptr = unpermuted_input + source_row * cols; + T* dest_row_ptr = permuted_output + expanded_dest_row * cols; + + for (int tid = threadIdx.x * VecSize; tid < cols; + tid += blockDim.x * VecSize) { + // dest_row_ptr[tid] = source_row_ptr[tid]; + phi::Load(&source_row_ptr[tid], &src_vec); + phi::Store(src_vec, &dest_row_ptr[tid]); + } + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/fusion/moe_kernel.h b/paddle/phi/kernels/fusion/moe_kernel.h new file mode 100644 index 00000000000..65a54ed1978 --- /dev/null +++ b/paddle/phi/kernels/fusion/moe_kernel.h @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void MoeKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& gate, + const DenseTensor& bmm0, + const DenseTensor& bias0, + const DenseTensor& bmm1, + const DenseTensor& bias1, + const std::string& act_type, + DenseTensor* output); + +} // namespace phi diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 25c545861d3..3602d375911 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -78,6 +78,7 @@ if(NOT WITH_GPU) list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api) endif() +list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op) list(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op) list(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_grad_op) list(REMOVE_ITEM TEST_OPS test_fuse_gemm_epilogue_pass) @@ -143,6 +144,7 @@ if(WIN32) list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_bias) list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_no_bias) list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op) + list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op) endif() list(REMOVE_ITEM TEST_OPS test_checkpoint_saver) diff --git a/python/paddle/fluid/tests/unittests/test_fused_ec_moe_op.py b/python/paddle/fluid/tests/unittests/test_fused_ec_moe_op.py new file mode 100644 index 00000000000..14bdac34870 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_ec_moe_op.py @@ -0,0 +1,176 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from op_test import OpTest + +import paddle +import paddle.nn.functional as F +from paddle.fluid.framework import default_main_program +from paddle.incubate.nn.functional import fused_ec_moe +from paddle.nn.layer.common import Linear + +default_main_program().random_seed = 42 + + +class TestFusedEcMoEOp(OpTest): + def setUp(self): + self.config() + self.rtol = 1e-3 + self.atol = 1e-3 + + paddle.set_default_dtype(self.x_type) + self.__class__.op_type = "fused_ec_moe" + # Since it's only used in inference. + self.__class__.no_need_check_grad = True + + self.bmm_w0 = paddle.to_tensor( + np.random.randn(self.num_expert, self.d_model, self.d_feedforward) + * 0.001, + dtype=paddle.float16, + ) + self.bmm_b0 = paddle.to_tensor( + np.random.randn(self.num_expert, 1, self.d_feedforward) * 0.001, + dtype=paddle.float16, + ) + self.bmm_w1 = paddle.to_tensor( + np.random.randn(self.num_expert, self.d_feedforward, self.d_model) + * 0.001, + dtype=paddle.float16, + ) + self.bmm_b1 = paddle.to_tensor( + np.random.randn(self.num_expert, 1, self.d_model) * 0.001, + dtype=paddle.float16, + ) + self.tensor_x = paddle.to_tensor( + np.random.randn(self.batch_size, self.seq_len, self.d_model) + * 0.001, + dtype=paddle.float16, + ) + + self.bmm_w0.stop_gradient = True + self.bmm_b0.stop_gradient = True + self.bmm_w1.stop_gradient = True + self.bmm_b1.stop_gradient = True + self.tensor_x.stop_gradient = True + + self.gate = Linear(self.d_model, self.num_expert) + + paddle.set_default_dtype("float16") + self.activation = getattr(F, self.act_method) + + def config(self): + self.x_type = np.float16 + self.batch_size = 10 + self.seq_len = 128 + self.num_expert = 32 + self.d_model = 768 + self.d_feedforward = 3072 + self.act_method = 'gelu' + + def GetBaselineOut(self, tensor_x, gate_logits): + def expert_choice_gating(logits, capacity, batch_idx, expert_idx): + gates = F.softmax(logits, -1) + indices1_s = paddle.topk( + logits.transpose([0, 2, 1]), k=capacity, axis=-1 + )[1].cast("int32") + seqlen_idx = indices1_s.reshape([-1]) + gather_idx = paddle.stack([batch_idx, seqlen_idx, expert_idx], -1) + prob = paddle.gather_nd(gates, gather_idx) + return prob, expert_idx, gather_idx, capacity + + paddle.disable_static() + capacity = self.seq_len // 16 + batch_expert_idx = paddle.nonzero( + paddle.ones(shape=[self.batch_size, self.num_expert, capacity]) + ).cast('int32') + batch_idx = batch_expert_idx[:, 0] + expert_idx = batch_expert_idx[:, 1] + + ( + expert_prob_flatten, + expert_idx_flatten, + gather_idx, + cap, + ) = expert_choice_gating(gate_logits, capacity, batch_idx, expert_idx) + outputs = paddle.zeros_like(tensor_x) + batch_prob = expert_prob_flatten.reshape( + [self.batch_size, self.num_expert, -1, 1] + ) + + batch_idx = gather_idx[:, :2] + selected_token = tensor_x.gather_nd(batch_idx) + + batch_selected_token = selected_token.reshape( + [self.batch_size, self.num_expert, -1, tensor_x.shape[-1]] + ) + batch_selected_token = batch_selected_token.transpose( + [1, 0, 2, 3] + ).reshape([self.num_expert, -1, tensor_x.shape[-1]]) + + output = paddle.bmm(batch_selected_token, self.bmm_w0) + self.bmm_b0 + output = self.activation(output) + output = paddle.bmm(output, self.bmm_w1) + self.bmm_b1 + + output = output.transpose([1, 0, 2]).reshape( + [self.batch_size, -1, self.num_expert, tensor_x.shape[-1]] + ) + output = output.transpose([0, 2, 1, 3]) + output = batch_prob * output + output = output.reshape([-1, tensor_x.shape[-1]]) + + outputs = outputs.scatter_nd_add(batch_idx, output) + return outputs + tensor_x + + def GetFusedEcMoeOut(self, tensor_x, gate_logits): + paddle.disable_static() + fused_out = fused_ec_moe( + tensor_x, + gate_logits, + self.bmm_w0, + self.bmm_b0, + self.bmm_w1, + self.bmm_b1, + self.act_method, + ) + + return fused_out + + def test_fused_ec_moe_op(self): + gate_logits = self.gate(self.tensor_x) + final_out_ref = self.GetBaselineOut(self.tensor_x, gate_logits) + final_out = self.GetFusedEcMoeOut(self.tensor_x, gate_logits) + + np.testing.assert_allclose( + final_out_ref, final_out, rtol=self.rtol, atol=self.atol + ) + + +class TestFusedEcMoEOpActGeluFp16(TestFusedEcMoEOp): + def config(self): + super().config() + self.x_type = np.float16 + + +class TestFusedEcMoEOpActReluFp16(TestFusedEcMoEOp): + def config(self): + super().config() + self.x_type = np.float16 + self.act_method = "relu" + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/incubate/nn/__init__.py b/python/paddle/incubate/nn/__init__.py index 62a48b783df..fe6a2abd9bb 100644 --- a/python/paddle/incubate/nn/__init__.py +++ b/python/paddle/incubate/nn/__init__.py @@ -20,6 +20,7 @@ from .layer.fused_linear import FusedLinear # noqa: F401 from .layer.fused_transformer import ( FusedBiasDropoutResidualLayerNorm, ) # noqa: F401 +from .layer.fused_ec_moe import FusedEcMoe # noqa: F401 __all__ = [ # noqa 'FusedMultiHeadAttention', @@ -28,4 +29,5 @@ __all__ = [ # noqa 'FusedMultiTransformer', 'FusedLinear', 'FusedBiasDropoutResidualLayerNorm', + 'FusedEcMoe', ] diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py index e9894990455..a8f9cb70eca 100644 --- a/python/paddle/incubate/nn/functional/__init__.py +++ b/python/paddle/incubate/nn/functional/__init__.py @@ -17,6 +17,7 @@ from .fused_transformer import fused_feedforward from .fused_transformer import fused_multi_transformer from .fused_matmul_bias import fused_matmul_bias, fused_linear from .fused_transformer import fused_bias_dropout_residual_layer_norm +from .fused_ec_moe import fused_ec_moe __all__ = [ 'fused_multi_head_attention', @@ -25,4 +26,5 @@ __all__ = [ 'fused_matmul_bias', 'fused_linear', 'fused_bias_dropout_residual_layer_norm', + 'fused_ec_moe', ] diff --git a/python/paddle/incubate/nn/functional/fused_ec_moe.py b/python/paddle/incubate/nn/functional/fused_ec_moe.py new file mode 100644 index 00000000000..ca2057fc016 --- /dev/null +++ b/python/paddle/incubate/nn/functional/fused_ec_moe.py @@ -0,0 +1,75 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.fluid.layer_helper import LayerHelper + + +def fused_ec_moe( + x, gate, bmm0_weight, bmm0_bias, bmm1_weight, bmm1_bias, act_type +): + """ + Applies fused ec_moe kernel. + This method requires SM_ARCH in sm75, sm80, sm86. + + Args: + x (Tensor): the input Tensor. Its shape is [bsz, seq_len, d_model]. + gate (Tensor): the gate Tensor to choose expert. Its shape is [bsz, seq_len, e]. + bmm0_weight (Tensor): the first batch matrix matmul weight. Its shape is [e, d_model, d_feed_forward]. + bmm0_bias (Tensor): the first batch matrix matmul bias. Its shape is [e, 1, d_feed_forward]. + bmm1_weight (Tensor): the second batch matrix matmul weight. Its shape is [e, d_model, d_feed_forward]. + bmm1_bias (Tensor): the second batch matrix matmul bias. Its shape is [e, 1, d_feed_forward]. + act_type (string): the Activation Type. Currently only support `gelu`, `relu`. + + Returns: + Tensor: the output Tensor. + + Examples: + .. code-block:: python + + # required: gpu + import paddle + from paddle.incubate.nn.functional import fused_ec_moe + + batch = 10 + seq_len = 128 + d_model = 1024 + d_feed_forward = d_model * 4 + num_expert = 8 + + x = paddle.randn([batch, seq_len, d_model]) + gate = paddle.randn([batch, seq_len, num_expert]) + bmm0_weight = paddle.randn([num_expert, d_model, d_feed_forward]) + bmm0_bias = paddle.randn([num_expert, d_model, d_feed_forward]) + bmm1_weight = paddle.randn([num_expert, d_model, d_feed_forward]) + bmm1_bias = paddle.randn([num_expert, d_model, d_feed_forward]) + out = fused_ec_moe(x, gate, bmm0_weight, bmm0_bias, bmm1_weight, bmm1_bias, act_type="gelu") + + print(out.shape) # [batch, seq_len, num_expert] + """ + helper = LayerHelper('fused_moe', **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='moe', + inputs={ + 'X': x, + 'Gate': gate, + 'Bmm0': bmm0_weight, + 'Bias0': bmm0_bias, + 'Bmm1': bmm1_weight, + 'Bias1': bmm1_bias, + }, + outputs={'Out': out}, + attrs={'act_type': act_type}, + ) + return out diff --git a/python/paddle/incubate/nn/layer/fused_ec_moe.py b/python/paddle/incubate/nn/layer/fused_ec_moe.py new file mode 100644 index 00000000000..407c8753519 --- /dev/null +++ b/python/paddle/incubate/nn/layer/fused_ec_moe.py @@ -0,0 +1,101 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.incubate.nn import functional as F +from paddle.nn import Layer + + +class FusedEcMoe(Layer): + r"""A FusedEcMoe Layer. + + Parameters: + hidden_size (int): The dim size of input units. + inter_size (int): The dim size of feed forward network. + num_expert (int): The number of experts. + act_type (string): The activation type. Currently only support `gelu`, `relu`. + weight_attr (ParamAttr, optional): The attribute for the learnable + weight of this layer. The default value is None and the weight will be + initialized to zero. For detailed information, please refer to + paddle.ParamAttr. + bias_attr (ParamAttr|bool, optional): The attribute for the learnable bias + of this layer. If it is set to False, no bias will be added to the output. + If it is set to None or one kind of ParamAttr, a bias parameter will + be created according to ParamAttr. For detailed information, please refer + to paddle.ParamAttr. The default value is None and the bias will be + initialized to zero. + + Attribute: + **weight** (Parameter): the learnable weight of this layer. + **bias** (Parameter): the learnable bias of this layer. + + Shape: + - input: Multi-dimentional tensor with shape :math:`[batch\_size, seq\_len, d\_model]` . + - output: Multi-dimentional tensor with shape :math:`[batch\_size, seq\_len, d\_model]` . + + Examples: + .. code-block:: python + + # required: gpu + import paddle + from paddle.incubate.nn.layer.fused_ec_moe import FusedEcMoe + + x = paddle.randn([10, 128, 1024]) # [bsz, seq_len, d_model] + gate = paddle.randn([10, 128, 8]) # [bsz, seq_len, num_experts] + moe = FusedEcMoe(1024, 4096, 8, act_type="gelu") + y = moe(x, gate) + print(y.shape) # [10, 128, 1024] + """ + + def __init__( + self, + hidden_size, + inter_size, + num_experts, + act_type, + weight_attr=None, + bias_attr=None, + ): + super().__init__() + weight0_shape = [num_experts, hidden_size, inter_size] + bias0_shape = [num_experts, 1, inter_size] + weight1_shape = [num_experts, inter_size, hidden_size] + bias1_shape = [num_experts, 1, hidden_size] + + dtype = self._helper.get_default_dtype() + self.bmm_weight0 = self.create_parameter( + shape=weight0_shape, attr=weight_attr, dtype=dtype, is_bias=False + ) + self.bmm_bias0 = self.create_parameter( + shape=bias0_shape, attr=bias_attr, dtype=dtype, is_bias=True + ) + self.bmm_weight1 = self.create_parameter( + shape=weight1_shape, attr=weight_attr, dtype=dtype, is_bias=False + ) + self.bmm_bias1 = self.create_parameter( + shape=bias1_shape, attr=bias_attr, dtype=dtype, is_bias=True + ) + self.act_type = act_type + if self.act_type not in ["gelu", "relu"]: + raise NotImplementedError("Currently only support `gelu`, `relu`. ") + + def forward(self, x, gate): + return F.fused_ec_moe( + x, + gate, + self.bmm_weight0, + self.bmm_bias0, + self.bmm_weight1, + self.bmm_bias1, + self.act_type, + ) -- GitLab