Unverified commit def2a87f authored by xiaoxiaohehe001, committed by GitHub

[Paddle Inference] Add moe phi kernel (#48703)

Parent efa34534
......@@ -516,6 +516,7 @@ if(WITH_GPU
if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.0)
include(external/cutlass) # download, build, install cutlass
list(APPEND third_party_deps extern_cutlass)
set(WITH_CUTLASS ON)
endif()
endif()
......
......@@ -235,6 +235,15 @@ nvinfer1::DimsExprs UnchangedInferMeta(
return inputs[0];
}
nvinfer1::DimsExprs MoeInferMeta(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder, // NOLINT
const framework::OpDesc& op_desc) {
return inputs[0];
}
nvinfer1::DimsExprs Pad3dInferMeta(
int output_index,
const nvinfer1::DimsExprs* inputs,
......@@ -384,6 +393,7 @@ PD_REGISTER_DYNAMIC_INFER_META_FN(instance_norm, InstanceNormInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(unfold, UnflodInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(scatter_nd_add, ScatterNdAddInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(inverse, UnchangedInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(moe, MoeInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(pad3d, Pad3dInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(grid_sampler, GridSamplerInferMeta);
} // namespace tensorrt
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
class MoeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class MoeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The source input tensor of Moe op.");
AddInput("Gate", "(Tensor), The gating input tensor of Moe op.");
AddInput("Bmm0", "(Tensor), The bmm0 input tensor of Moe op.");
AddInput("Bias0", "(Tensor), The eltwise0 input tensor of Moe op.");
AddInput("Bmm1", "(Tensor), The bmm1 input tensor of Moe op.");
AddInput("Bias1", "(Tensor), The eltwise1 input tensor of Moe op.");
AddOutput("Out", "(Tensor), The output tensor of Moe op.");
AddAttr<std::string>(
"act_type",
R"DOC(activation type, currently only support `gelu`, `relu`. Default value is: `gelu`. )DOC")
.SetDefault("gelu");
AddComment(
R"DOC(FusedEcMoe kernel. For more details you can refer to `FusedEcMoE` python documents. )DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(moe,
MoeInferShapeFunctor,
PD_INFER_META(phi::MoeInferMeta));
REGISTER_OPERATOR(moe, ops::MoeOp, ops::MoeOpMaker, MoeInferShapeFunctor);
......@@ -2931,6 +2931,20 @@ void YoloLossInferMeta(const MetaTensor& x,
gt_match_mask->set_dtype(x.dtype());
}
void MoeInferMeta(const MetaTensor& x,
const MetaTensor& gate,
const MetaTensor& bmm0,
const MetaTensor& bias0,
const MetaTensor& bmm1,
const MetaTensor& bias1,
const std::string& act_type,
MetaTensor* out) {
out->set_dims(x.dims());
out->share_lod(x);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
}
} // namespace phi
PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
......@@ -523,4 +523,13 @@ void YoloLossInferMeta(const MetaTensor& x,
MetaTensor* objectness_mask,
MetaTensor* gt_match_mask);
void MoeInferMeta(const MetaTensor& x,
const MetaTensor& gate,
const MetaTensor& bmm0,
const MetaTensor& bias0,
const MetaTensor& bmm1,
const MetaTensor& bias1,
const std::string& act_type,
MetaTensor* out);
} // namespace phi
......@@ -104,6 +104,12 @@ file(
"strings/gpu/*.cu"
"fusion/gpu/*.cu")
if(WITH_CUTLASS)
file(GLOB cutlass_cu "fusion/cutlass/default_moe_fc_traits.h"
"fusion/cutlass/linear_combination_ft_gelu.h" "fusion/cutlass/moe*")
list(APPEND kernel_cu ${cutlass_cu})
endif()
if(WITH_MKLDNN)
file(
GLOB
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
namespace cutlass {
namespace gemm {
namespace kernel {
template <typename TypeA, typename TypeB, typename arch>
struct MoeArchTraits {};
template <typename arch>
struct MoeArchTraits<float, float, arch> {
static constexpr int Stages = 2;
using OperatorClass = cutlass::arch::OpClassSimt;
using AccType = float;
using LayoutB = cutlass::layout::RowMajor;
static constexpr int ElementsPerAccessA = 1;
static constexpr int ElementsPerAccessB = 1;
static constexpr int ElementsPerAccessC = 1;
using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 8>;
using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// ========================= Volta Traits ===========================
// Volta will always dequantize after the global memory load.
template <typename TypeB>
struct MoeArchTraits<cutlass::half_t, TypeB, cutlass::arch::Sm70> {
private:
static constexpr int ThreadblockK = 32;
public:
static constexpr int Stages = 2;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = cutlass::layout::RowMajor;
static constexpr int ElementsPerAccessA =
128 / cutlass::sizeof_bits<cutlass::half_t>::value;
static constexpr int ElementsPerAccessB =
128 / cutlass::sizeof_bits<cutlass::half_t>::value;
static constexpr int ElementsPerAccessC =
128 / cutlass::sizeof_bits<cutlass::half_t>::value;
using ThreadBlockShape = cutlass::gemm::GemmShape<32, 128, ThreadblockK>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, ThreadblockK>;
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <typename TypeB>
struct MoeArchTraits<cutlass::bfloat16_t, TypeB, cutlass::arch::Sm70> {
private:
static constexpr int ThreadblockK = 32;
public:
static constexpr int Stages = 2;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = cutlass::layout::RowMajor;
static constexpr int ElementsPerAccessA =
128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value;
static constexpr int ElementsPerAccessB =
128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value;
static constexpr int ElementsPerAccessC =
128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value;
using ThreadBlockShape = cutlass::gemm::GemmShape<32, 128, ThreadblockK>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, ThreadblockK>;
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// ======================= Turing Traits ==============================
// Turing will dequantize after LDSM
// fp16 x fp16 specialization
template <>
struct MoeArchTraits<cutlass::half_t, cutlass::half_t, cutlass::arch::Sm75> {
static constexpr int Stages = 2;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = cutlass::layout::RowMajor;
static constexpr int ElementsPerAccessA =
128 / cutlass::sizeof_bits<cutlass::half_t>::value;
static constexpr int ElementsPerAccessB =
128 / cutlass::sizeof_bits<cutlass::half_t>::value;
static constexpr int ElementsPerAccessC =
128 / cutlass::sizeof_bits<cutlass::half_t>::value;
using ThreadBlockShape = cutlass::gemm::GemmShape<32, 128, 32>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// bf16 x bf16 specialization
template <>
struct MoeArchTraits<cutlass::bfloat16_t,
cutlass::bfloat16_t,
cutlass::arch::Sm75> {
static constexpr int Stages = 2;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = cutlass::layout::RowMajor;
static constexpr int ElementsPerAccessA =
128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value;
static constexpr int ElementsPerAccessB =
128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value;
static constexpr int ElementsPerAccessC =
128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value;
using ThreadBlockShape = cutlass::gemm::GemmShape<32, 128, 32>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <>
struct MoeArchTraits<float, float, cutlass::arch::Sm80> {
static constexpr int Stages = 3;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = cutlass::layout::RowMajor;
static constexpr int ElementsPerAccessA = 4;
static constexpr int ElementsPerAccessB = 4;
static constexpr int ElementsPerAccessC = 4;
using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 16>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <>
struct MoeArchTraits<cutlass::half_t, cutlass::half_t, cutlass::arch::Sm80> {
static constexpr int Stages = 3;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = cutlass::layout::RowMajor;
static constexpr int ElementsPerAccessA =
128 / cutlass::sizeof_bits<cutlass::half_t>::value;
static constexpr int ElementsPerAccessB =
128 / cutlass::sizeof_bits<cutlass::half_t>::value;
static constexpr int ElementsPerAccessC =
128 / cutlass::sizeof_bits<cutlass::half_t>::value;
using ThreadBlockShape = cutlass::gemm::GemmShape<32, 128, 32>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <>
struct MoeArchTraits<cutlass::bfloat16_t,
cutlass::bfloat16_t,
cutlass::arch::Sm80> {
static constexpr int Stages = 3;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = cutlass::layout::RowMajor;
static constexpr int ElementsPerAccessA =
128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value;
static constexpr int ElementsPerAccessB =
128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value;
static constexpr int ElementsPerAccessC =
128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value;
using ThreadBlockShape = cutlass::gemm::GemmShape<32, 128, 32>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
} // namespace kernel
} // namespace gemm
} // namespace cutlass
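The traits above are looked up by (A type, B type, architecture) at compile time. A minimal sketch, assuming the header is included via the path referenced later in the kernel (`paddle/phi/kernels/fusion/cutlass/default_moe_fc_traits.h`), that exercises the Sm80 fp16 specialization; the alias name is illustrative only:

```cpp
// Illustrative only: checking the Sm80 fp16 specialization defined above.
#include "cutlass/numeric_types.h"
#include "paddle/phi/kernels/fusion/cutlass/default_moe_fc_traits.h"

using Sm80HalfTraits = cutlass::gemm::kernel::
    MoeArchTraits<cutlass::half_t, cutlass::half_t, cutlass::arch::Sm80>;

// Ampere fp16 traits select a 3-stage pipeline and 128-bit vector loads for A.
static_assert(Sm80HalfTraits::Stages == 3, "unexpected stage count");
static_assert(Sm80HalfTraits::ElementsPerAccessA ==
                  128 / cutlass::sizeof_bits<cutlass::half_t>::value,
              "unexpected A vector width");
```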
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing linear combination with a maximum operation used by
epilogues.
*/
#pragma once
#include <cuda.h>
#include <cutlass/half.h>
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
/// Single source of truth for whether to unroll for `LinearCombinationFtGelu()`
constexpr bool LinearCombinationFtGeluIsHeavy() { return false; }
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
__forceinline__ __device__ float copysignf_pos(float a, float b) {
float r;
r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
return r;
}
__inline__ __device__ float tanh_opt(float x) {
#if (__CUDA_ARCH__ >= 750)
float r;
asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(r) : "f"(x));
return r;
#else
const float exp_val = -1.f * fabs(2 * x);
return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
#endif
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// GELU operator implemented using the tanh approximation
template <typename T>
struct FtGelu {
static const bool kIsHeavy = true;
CUTLASS_DEVICE
T operator()(T const &z) const {
T k0 = static_cast<float>(0.7978845608028654);
T k1 = static_cast<float>(0.044715);
return T(cutlass::constants::half<T>() * z *
(cutlass::constants::one<T>() +
fast_tanh(k0 * z * (cutlass::constants::one<T>() + k1 * z * z))));
}
};
template <>
struct FtGelu<float> {
static const bool kIsHeavy = true;
CUTLASS_DEVICE
float operator()(float const &z) const {
float k0 = static_cast<float>(0.7978845608028654);
float k1 = static_cast<float>(0.044715);
return float(
    cutlass::constants::half<float>() * z *
    (cutlass::constants::one<float>() +
     tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
}
};
template <int N>
struct FtGelu<Array<half_t, N>> {
static const bool kIsHeavy = true;
CUTLASS_DEVICE
Array<half_t, N> operator()(Array<half_t, N> const &z) const {
using T = half_t;
Array<half_t, N> y;
half_t k0 = half_t(0.7978845608028654);
half_t k1 = half_t(0.044715);
multiply_add<Array<half_t, N>> fma;
multiplies<Array<half_t, N>> mul;
plus<Array<half_t, N>> add;
fast_tanh_op<Array<half_t, N>> tanh;
Array<half_t, N> u =
mul(mul(k0, z), fma(mul(k1, z), z, cutlass::constants::one<T>()));
y = mul(mul(z, cutlass::constants::half<T>()),
add(cutlass::constants::one<T>(), tanh(u)));
return y;
}
};
template <typename T, int N>
struct FtGelu<Array<T, N>> {
static const bool kIsHeavy = true;
CUTLASS_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs) const {
Array<T, N> y;
FtGelu<T> gelu_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
y[i] = gelu_op(rhs[i]);
}
return y;
}
};
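The constant 0.7978845608028654 is sqrt(2/pi), so the FtGelu functors above evaluate the tanh-based GELU approximation gelu(x) ≈ 0.5 · x · (1 + tanh(sqrt(2/pi) · (x + 0.044715 · x³))). A minimal host-side reference sketch (not part of the patch) that can be used to sanity-check epilogue output:

```cpp
#include <cmath>

// Host reference for the tanh-based GELU approximation computed by FtGelu.
// k0 == sqrt(2 / pi); note k0 * x * (1 + k1 * x * x) == k0 * (x + k1 * x^3).
inline float gelu_tanh_reference(float x) {
  const float k0 = 0.7978845608028654f;
  const float k1 = 0.044715f;
  return 0.5f * x * (1.0f + std::tanh(k0 * x * (1.0f + k1 * x * x)));
}
```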
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies a linear combination operator to an array of elements.
///
/// D = alpha * accumulator + beta * source + uniform
///
template <
typename ElementOutput_, ///< Data type used to load and store tensors
int Count, ///< Number of elements computed per operation
///< Usually it is 128/sizeof_bits<ElementOutput_>,
///< but 64 or 32 is sometimes used when there is not
///< enough data to store
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
typename ElementCompute_ =
ElementOutput_, ///< Data type used to compute linear combination
ScaleType::Kind Scale =
ScaleType::Default, ///< Control Alpha and Beta scaling
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
class LinearCombinationFtGelu {
public:
using ElementOutput = ElementOutput_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kCount = Count;
static const ScaleType::Kind kScale = Scale;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using FragmentCompute = Array<ElementCompute, kCount>;
using FragmentScaleBias = Array<ElementCompute, kCount>;
static FloatRoundStyle const kRound = Round;
static bool const kIsHeavy = detail::LinearCombinationFtGeluIsHeavy();
/// Host-constructable parameters structure
struct Params {
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
ElementCompute threshold; ///< minimum value that is output
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if
///< not null, loads it from memory
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not
///< null, loads it from memory
//
// Methods
//
CUTLASS_HOST_DEVICE
Params()
: alpha(ElementCompute(1)),
beta(ElementCompute(0)),
threshold(ElementCompute(0)),
alpha_ptr(nullptr),
beta_ptr(nullptr) {}
CUTLASS_HOST_DEVICE
Params(ElementCompute alpha,
ElementCompute beta = ElementCompute(0),
ElementCompute threshold = ElementCompute(0))
: alpha(alpha),
beta(beta),
threshold(threshold),
alpha_ptr(nullptr),
beta_ptr(nullptr) {}
CUTLASS_HOST_DEVICE
Params(ElementCompute const *alpha_ptr,
ElementCompute const *beta_ptr = nullptr,
ElementCompute threshold = ElementCompute(0))
: alpha(0),
beta(0),
threshold(threshold),
alpha_ptr(alpha_ptr),
beta_ptr(beta_ptr) {}
};
private:
//
// Data members
//
ElementCompute alpha_;
ElementCompute beta_;
ElementCompute threshold_;
public:
/// Constructs the function object, possibly loading from pointers in host
/// memory
CUTLASS_HOST_DEVICE
explicit LinearCombinationFtGelu(Params const &params) {
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
threshold_ = params.threshold;
}
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const {
if (Scale == ScaleType::NoBetaScaling) return true;
if (Scale == ScaleType::OnlyAlphaScaling) return false;
if (Scale == ScaleType::Nothing) return false;
return beta_ != ElementCompute(0);
}
/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {
if (k_partition) {
beta_ = ElementCompute(1);
}
if (k_partition != k_partition_count - 1) {
// set to NaN to make ReLU no-op for all except last k partitions
int64_t allones = -1;
threshold_ = reinterpret_cast<ElementCompute const &>(allones);
}
}
/// Computes linear scaling: D = alpha * accumulator + beta * source
CUTLASS_HOST_DEVICE
FragmentOutput operator()(FragmentAccumulator const &accumulator,
FragmentOutput const &source) const {
// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round>
source_converter;
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_converter;
FragmentCompute converted_source = source_converter(source);
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
FragmentCompute intermediate;
multiplies<FragmentCompute> mul_add_source;
multiply_add<FragmentCompute> mul_add_accumulator;
GELU<FragmentCompute> ftgelu;
if (Scale == ScaleType::NoBetaScaling) {
intermediate = converted_source;
intermediate =
mul_add_accumulator(alpha_,
converted_accumulator,
intermediate); // D = alpha * Accum + X
} else if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate =
mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate =
mul_add_accumulator(alpha_,
converted_accumulator,
intermediate); // D = alpha * Accum + X
}
// Apply the GELU activation
intermediate = ftgelu(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
destination_converter;
return destination_converter(intermediate);
}
/// Computes linear scaling: D = alpha * accumulator
CUTLASS_HOST_DEVICE
FragmentOutput operator()(FragmentAccumulator const &accumulator) const {
// Convert accumulator to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_converter;
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
FragmentCompute intermediate;
multiplies<FragmentCompute> mul_accumulator;
GELU<FragmentCompute> ftgelu;
if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate =
mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
}
// Apply the GELU activation
intermediate = ftgelu(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
destination_converter;
return destination_converter(intermediate);
}
/// Computes per-channel linear scaling and bias: D = scale * accumulator + bias.
/// Scale and bias are taken from the input fragments.
CUTLASS_HOST_DEVICE
FragmentOutput operator()(FragmentAccumulator const &accumulator,
FragmentScaleBias const &scale,
FragmentScaleBias const &bias) const {
// Convert accumulator to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_converter;
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
// Perform per-channel scale and bias
FragmentCompute intermediate;
multiply_add<FragmentCompute> mul_add_accumulator;
if (Scale == ScaleType::OnlyAlphaPerChannelScaling)
intermediate = mul_add_accumulator(
scale, converted_accumulator, bias); // D = scale * Accum + bias
else
intermediate = mul_add_accumulator(
alpha_, converted_accumulator, bias); // D = alpha * Accum + bias
GELU<FragmentCompute> ftgelu;
// Apply the GELU activation
intermediate = ftgelu(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
destination_converter;
return destination_converter(intermediate);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Conditional guards to enable partial specialization for packed integers
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && \
((__CUDACC_VER_MAJOR__ > 10) || \
((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
/// Applies a linear combination operator to an array of elements.
///
/// D = alpha * accumulator + beta * source + uniform
///
/// Special handling for int types
template <typename ElementOutput_, ///< Data type used to load and store
///< tensors
int Count, ///< Number of elements computed per operation
ScaleType::Kind Scale, ///< Control Alpha and Beta scaling
FloatRoundStyle Round>
class LinearCombinationFtGelu<ElementOutput_, Count, int, float, Scale, Round> {
public:
using ElementOutput = ElementOutput_;
using ElementAccumulator = int;
using ElementCompute = float;
static bool const kIsHeavy = detail::LinearCombinationFtGeluIsHeavy();
static int const kCount = Count;
static const ScaleType::Kind kScale = Scale;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using FragmentCompute = Array<ElementCompute, kCount>;
using FragmentScaleBias = Array<ElementCompute, kCount>;
static FloatRoundStyle const kRound = Round;
/// Host-constructable parameters structure
struct Params {
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
ElementCompute threshold; ///< minimum value that is output
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if
///< not null, loads it from memory
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not
///< null, loads it from memory
//
// Methods
//
CUTLASS_HOST_DEVICE
Params()
: alpha(ElementCompute(1)),
beta(ElementCompute(0)),
threshold(ElementCompute(0)),
alpha_ptr(nullptr),
beta_ptr(nullptr) {}
CUTLASS_HOST_DEVICE
Params(ElementCompute alpha,
ElementCompute beta = ElementCompute(0),
ElementCompute threshold = ElementCompute(0))
: alpha(alpha),
beta(beta),
threshold(threshold),
alpha_ptr(nullptr),
beta_ptr(nullptr) {}
CUTLASS_HOST_DEVICE
Params(ElementCompute const *alpha_ptr,
ElementCompute const *beta_ptr = nullptr,
ElementCompute threshold = ElementCompute(0))
: alpha(0),
beta(0),
threshold(threshold),
alpha_ptr(alpha_ptr),
beta_ptr(beta_ptr) {}
};
private:
//
// Data members
//
ElementCompute alpha_;
ElementCompute beta_;
ElementCompute threshold_;
public:
/// Constructs the function object, possibly loading from pointers in host
/// memory
CUTLASS_HOST_DEVICE
explicit LinearCombinationFtGelu(Params const &params) {
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
threshold_ = params.threshold;
}
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const {
if (Scale == ScaleType::NoBetaScaling) return true;
if (Scale == ScaleType::OnlyAlphaScaling) return false;
if (Scale == ScaleType::Nothing) return false;
return beta_ != ElementCompute(0);
}
/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {
if (k_partition) {
beta_ = ElementCompute(1);
}
if (k_partition != k_partition_count - 1) {
// set to NaN to make ReLU no-op for all except last k partitions
int64_t allones = -1;
threshold_ = reinterpret_cast<ElementCompute const &>(allones);
}
}
/// Computes linear scaling: D = alpha * accumulator + beta * source
CUTLASS_HOST_DEVICE
FragmentOutput operator()(FragmentAccumulator const &accumulator,
FragmentOutput const &source) const {
// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round>
source_converter;
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_converter;
FragmentCompute converted_source = source_converter(source);
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
FragmentCompute intermediate;
multiplies<FragmentCompute> mul_add_source;
multiply_add<FragmentCompute> mul_add_accumulator;
GELU<FragmentCompute> ftgelu;
if (Scale == ScaleType::NoBetaScaling) {
intermediate = converted_source;
intermediate =
mul_add_accumulator(alpha_,
converted_accumulator,
intermediate); // D = alpha * Accum + X
} else if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate =
mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate =
mul_add_accumulator(alpha_,
converted_accumulator,
intermediate); // D = alpha * Accum + X
}
// Apply the GELU activation
intermediate = ftgelu(intermediate);
if (platform::numeric_limits<ElementOutput>::is_integer) {
// Convert floats back to INT
FragmentAccumulator scaled_accumulator;
NumericArrayConverter<int, ElementCompute, kCount, Round>
compute_converter;
scaled_accumulator = compute_converter(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, int, kCount, Round>
destination_converter;
return destination_converter(scaled_accumulator);
} else {
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
destination_converter;
return destination_converter(intermediate);
}
}
/// Computes linear scaling: D = alpha * accumulator
CUTLASS_HOST_DEVICE
FragmentOutput operator()(FragmentAccumulator const &accumulator) const {
// Convert accumulator to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_converter;
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
FragmentCompute intermediate;
multiplies<FragmentCompute> mul_accumulator;
GELU<FragmentCompute> ftgelu;
if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate =
mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
}
// Apply the GELU activation
intermediate = ftgelu(intermediate);
if (platform::numeric_limits<ElementOutput>::is_integer) {
// Convert floats back to INT
FragmentAccumulator scaled_accumulator;
NumericArrayConverter<int, ElementCompute, kCount, Round>
compute_converter;
scaled_accumulator = compute_converter(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, int, kCount, Round>
destination_converter;
return destination_converter(scaled_accumulator);
} else {
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
destination_converter;
return destination_converter(intermediate);
}
}
/// Computes per-channel linear scaling and bias: D = scale * accumulator + bias.
/// Scale and bias are taken from the input fragments.
CUTLASS_HOST_DEVICE
FragmentOutput operator()(FragmentAccumulator const &accumulator,
FragmentScaleBias const &scale,
FragmentScaleBias const &bias) const {
// Convert accumulator to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_converter;
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
// Perform per-channel scale and bias
FragmentCompute intermediate;
multiply_add<FragmentCompute> mul_add_accumulator;
if (Scale == ScaleType::OnlyAlphaPerChannelScaling)
intermediate = mul_add_accumulator(
scale, converted_accumulator, bias); // D = scale * Accum + bias
else
intermediate = mul_add_accumulator(
alpha_, converted_accumulator, bias); // D = alpha * Accum + bias
GELU<FragmentCompute> ftgelu;
// Apply the GELU activation
intermediate = ftgelu(intermediate);
if (platform::numeric_limits<ElementOutput>::is_integer) {
// Convert floats back to INT
FragmentAccumulator scaled_accumulator;
NumericArrayConverter<int, ElementCompute, kCount, Round>
compute_converter;
scaled_accumulator = compute_converter(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, int, kCount, Round>
destination_converter;
return destination_converter(scaled_accumulator);
} else {
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
destination_converter;
return destination_converter(intermediate);
}
}
};
#endif // Conditional guards to enable partial specialization for packed
// integers
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
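A hypothetical instantiation of the epilogue functor above for fp16 output with fp32 accumulation; the alias and parameter values are a sketch under assumed settings, not code taken from this patch:

```cpp
// Illustrative only: alias and parameter choices are not part of this patch.
#include "paddle/phi/kernels/fusion/cutlass/linear_combination_ft_gelu.h"

using FtGeluEpilogue = cutlass::epilogue::thread::LinearCombinationFtGelu<
    cutlass::half_t,                                     // ElementOutput
    128 / cutlass::sizeof_bits<cutlass::half_t>::value,  // Count (8 for fp16)
    float,                                               // ElementAccumulator
    float,                                               // ElementCompute
    cutlass::epilogue::thread::ScaleType::NoBetaScaling>;

// alpha scales the accumulator; under NoBetaScaling the source is added
// unscaled, and beta/threshold keep their defaults.
FtGeluEpilogue::Params epilogue_params(/*alpha=*/1.0f);
```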
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
*modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
*notice, this list of conditions and the following disclaimer in the
*documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its
*contributors may be used to endorse or promote products derived from this
*software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
*/
#pragma once
#include "cutlass/complex.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ProblemSizeHelper, typename ThreadblockShape_>
struct BaseMoeProblemVisitor {
using ThreadblockShape = ThreadblockShape_;
struct ProblemInfo {
static int32_t const kNoPrefetchEntry = -1;
int32_t problem_idx;
int32_t problem_start;
CUTLASS_DEVICE
ProblemInfo()
: problem_idx(kNoPrefetchEntry), problem_start(kNoPrefetchEntry) {}
CUTLASS_DEVICE
ProblemInfo(int32_t problem_idx_, int32_t problem_start_)
: problem_idx(problem_idx_), problem_start(problem_start_) {}
};
struct Params {
int64_t const *last_row_for_problem;
int64_t gemm_n;
int64_t gemm_k;
int32_t problem_count;
void const *workspace;
int32_t tile_count;
//
// Methods
//
/// Ctor
CUTLASS_HOST_DEVICE
Params()
: last_row_for_problem(nullptr),
gemm_n(0),
gemm_k(0),
problem_count(0),
workspace(nullptr),
tile_count(0) {}
/// Ctor
CUTLASS_HOST_DEVICE
Params(int64_t const *last_row_for_problem,
int64_t gemm_n,
int64_t gemm_k,
int32_t problem_count,
void const *workspace = nullptr,
int32_t tile_count = 0)
: last_row_for_problem(last_row_for_problem),
gemm_n(gemm_n),
gemm_k(gemm_k),
problem_count(problem_count),
workspace(workspace),
tile_count(tile_count) {}
};
Params const &params;
int32_t tile_idx;
int32_t problem_tile_start;
int32_t problem_idx;
//
// Methods
//
CUTLASS_DEVICE
BaseMoeProblemVisitor(Params const &params_, int32_t block_idx)
: params(params_),
tile_idx(block_idx),
problem_tile_start(0),
problem_idx(0) {}
/// Get the grid shape
CUTLASS_HOST_DEVICE
static cutlass::gemm::GemmCoord grid_shape(
const cutlass::gemm::GemmCoord &problem) {
return cutlass::gemm::GemmCoord(
((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM),
((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN),
1);
}
/// Gets the global tile index
CUTLASS_HOST_DEVICE
int32_t tile_index() const { return tile_idx; }
/// Gets the index of the problem
CUTLASS_HOST_DEVICE
int32_t problem_index() const { return problem_idx; }
CUTLASS_HOST_DEVICE
int32_t threadblock_idx() const { return tile_idx - problem_tile_start; }
CUTLASS_DEVICE
void advance(int32_t grid_size) { tile_idx += grid_size; }
CUTLASS_HOST_DEVICE
static void possibly_transpose_problem(
cutlass::gemm::GemmCoord &problem) { // NOLINT
ProblemSizeHelper::possibly_transpose_problem(problem);
}
/// Returns the problem size for the current problem
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size() const {
return problem_size(problem_idx);
}
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size(int idx) const {
const int64_t prev_problem_row =
idx == 0 ? 0 : params.last_row_for_problem[idx - 1];
const int64_t current_problem_row = params.last_row_for_problem[idx];
const int64_t gemm_m = current_problem_row - prev_problem_row;
GemmCoord problem(GemmCoord::Index(gemm_m),
GemmCoord::Index(params.gemm_n),
GemmCoord::Index(params.gemm_k));
ProblemSizeHelper::possibly_transpose_problem(problem);
return problem;
}
CUTLASS_HOST_DEVICE
static int32_t tile_count(const cutlass::gemm::GemmCoord &grid) {
return ProblemSizeHelper::tile_count(grid);
}
static int32_t group_tile_count(
const cutlass::gemm::GemmCoord *host_problem_sizes_ptr,
int32_t problem_count) {
int32_t total_tiles = 0;
for (int32_t i = 0; i < problem_count; ++i) {
auto problem = host_problem_sizes_ptr[i];
possibly_transpose_problem(problem);
auto grid = grid_shape(problem);
total_tiles += tile_count(grid);
}
return total_tiles;
}
};
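The visitor derives each expert's GEMM M dimension from consecutive entries of `last_row_for_problem` (a running count of rows once each expert's rows are placed), while N and K are shared across experts. A small host-side sketch of the same bookkeeping with made-up sizes:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Running row count after each expert: expert 0 gets 128 rows,
  // expert 1 gets 0 rows, expert 2 gets 192 rows.
  std::vector<int64_t> last_row_for_problem = {128, 128, 320};
  const int64_t gemm_n = 1024, gemm_k = 768;  // shared across experts
  for (size_t i = 0; i < last_row_for_problem.size(); ++i) {
    const int64_t prev = (i == 0) ? 0 : last_row_for_problem[i - 1];
    const int64_t gemm_m = last_row_for_problem[i] - prev;
    std::printf("expert %zu: GEMM m=%lld n=%lld k=%lld\n",
                i,
                static_cast<long long>(gemm_m),
                static_cast<long long>(gemm_n),
                static_cast<long long>(gemm_k));
  }
  return 0;
}
```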
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename ProblemSizeHelper,
typename ThreadblockShape,
GroupScheduleMode GroupScheduleMode_,
int PrefetchTileCount,
int ThreadCount>
struct MoeProblemVisitor;
/////////////////////////////////////////////////////////////////////////////////////////////////
// ProblemVisitor that performs all scheduling on device
//
template <typename ProblemSizeHelper,
typename ThreadblockShape,
int PrefetchTileCount,
int ThreadCount>
struct MoeProblemVisitor<ProblemSizeHelper,
ThreadblockShape,
GroupScheduleMode::kDeviceOnly,
PrefetchTileCount,
ThreadCount>
: public BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape> {
using Base = BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>;
using Params = typename Base::Params;
static int const kThreadCount = ThreadCount;
static bool const kRequiresPrecomputation = false;
static int const kThreadsPerWarp = 32;
struct SharedStorage {};
// Final tile of the problem loaded by this thread. Each thread will hold
// a separate value.
int32_t problem_ending_tile;
SharedStorage &shared_storage;
//
// Methods
//
CUTLASS_DEVICE
MoeProblemVisitor(Params const &params_,
SharedStorage &shared_storage_, // NOLINT
int32_t block_idx)
: Base(params_, block_idx),
problem_ending_tile(0),
shared_storage(shared_storage_) {
this->problem_idx = -1 * kThreadsPerWarp;
this->problem_tile_start = 0;
}
CUTLASS_DEVICE
bool next_tile() {
// Check whether the tile to compute is within the range of the current
// problem.
int32_t problem_tile_end = __shfl_sync(
0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp);
if (this->tile_idx < problem_tile_end) {
return true;
}
// Check whether the tile to compute is within the current group of problems
// fetched by the warp. The last tile for this group is the final tile of
// the problem held by the final thread in the warp.
int32_t group_tile_end =
__shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);
// Keep the starting problem for this group in `problem_idx`. This is done
// to reduce register pressure. The starting problem for this group is
// simply the first problem in the group most recently fetched by the warp.
int32_t &group_problem_start = this->problem_idx;
group_problem_start =
(this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp;
// Keep the starting tile for this group in `problem_tile_start`. This is
// done to reduce register pressure.
int32_t &group_tile_start = this->problem_tile_start;
// Each thread in the warp processes a separate problem to advance until
// reaching a problem whose starting tile is less than tile_idx.
while (group_tile_end <= this->tile_idx) {
group_problem_start += kThreadsPerWarp;
if (group_problem_start > this->params.problem_count) {
return false;
}
// Since `group_tile_start` is a reference to `this->problem_tile_start`,
// this also sets `this->problem_tile_start`. The fact that
// `this->problem_tile_start` is also set here is used later in
// `next_tile`.
group_tile_start = group_tile_end;
int lane_idx = threadIdx.x % kThreadsPerWarp;
int32_t lane_problem = group_problem_start + lane_idx;
// Compute the number of tiles in the problem assigned to each thread.
problem_ending_tile = 0;
if (lane_problem < this->params.problem_count) {
cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem);
cutlass::gemm::GemmCoord grid = this->grid_shape(problem);
problem_ending_tile = this->tile_count(grid);
}
// Compute a warp-wide inclusive prefix sum to compute the ending tile
// index of each thread's problem.
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kThreadsPerWarp; i <<= 1) {
int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i);
if (lane_idx >= i) {
problem_ending_tile += val;
}
}
// The total tile count for this group is now in the final position of the
// prefix sum
int32_t tiles_in_group =
__shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);
problem_ending_tile += group_tile_start;
group_tile_end += tiles_in_group;
}
// The next problem to process is the first one whose ending tile position
// is greater than the tile index.
int32_t problem_idx_in_group = __popc(
__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx));
this->problem_idx = group_problem_start + problem_idx_in_group;
// The starting tile for this problem is the ending tile of the previous
// problem. In cases where `problem_idx_in_group` is the first problem in
// the group, we do not need to reset `problem_tile_start`, because it is
// set to the previous group's ending tile in the while loop above.
if (problem_idx_in_group > 0) {
this->problem_tile_start = __shfl_sync(
0xffffffff, problem_ending_tile, problem_idx_in_group - 1);
}
return true;
}
static size_t get_workspace_size(
const cutlass::gemm::GemmCoord *host_problem_sizes_ptr,
int32_t problem_count,
int32_t block_count) {
return 0;
}
static void host_precompute(
const cutlass::gemm::GemmCoord *host_problem_sizes_ptr,
int32_t problem_count,
int32_t block_count,
void *host_workspace_ptr) {}
};
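The shuffle loop inside `next_tile()` is a standard warp-wide inclusive scan. The same pattern in isolation, as a sketch of device code (compiled with nvcc) that assumes a full 32-thread warp:

```cpp
#include <cstdint>

// Warp-wide inclusive prefix sum of `val` across the 32 lanes of a warp,
// mirroring the shuffle loop used in next_tile() above.
__device__ int32_t warp_inclusive_prefix_sum(int32_t val) {
  const int lane_idx = threadIdx.x % 32;
#pragma unroll
  for (int i = 1; i < 32; i <<= 1) {
    const int32_t up = __shfl_up_sync(0xffffffff, val, i);
    if (lane_idx >= i) {
      val += up;
    }
  }
  return val;
}
```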
/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ThreadblockShape,
GroupScheduleMode GroupScheduleMode_,
int PrefetchTileCount,
int ThreadCount,
bool Transposed = false>
struct GemmMoeProblemVisitor
: public MoeProblemVisitor<detail::GemmGroupedProblemSizeHelper<Transposed>,
ThreadblockShape,
GroupScheduleMode_,
PrefetchTileCount,
ThreadCount> {
static bool const kTransposed = Transposed;
using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<Transposed>;
using Base = MoeProblemVisitor<ProblemSizeHelper,
ThreadblockShape,
GroupScheduleMode_,
PrefetchTileCount,
ThreadCount>;
using Params = typename Base::Params;
using SharedStorage = typename Base::SharedStorage;
//
// Methods
//
CUTLASS_DEVICE
GemmMoeProblemVisitor(Params const &params_,
SharedStorage &shared_storage_, // NOLINT
int32_t block_idx)
: Base(params_, shared_storage_, block_idx) {}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// This section exists so that we can use the same kernel code for regular gemm
// and dequantizing gemms. It will dispatch to the dequantizing gemm if the Mma
// type has an iterator for scales in global memory.
template <typename...>
using void_t = void;
template <typename Mma, typename = void>
struct use_dq_gemm : platform::false_type {};
template <typename Mma>
struct use_dq_gemm<Mma, void_t<typename Mma::IteratorScale>>
: platform::true_type {};
// SFINAE overload for dequantizing gemm
template <
typename Mma,
typename ElementScale,
typename platform::enable_if<use_dq_gemm<Mma>::value, bool>::type = true>
CUTLASS_DEVICE static void run_mma(Mma mma,
int gemm_k_iterations,
typename Mma::FragmentC &accum, // NOLINT
typename Mma::IteratorA iterator_A,
typename Mma::IteratorB iterator_B,
typename Mma::FragmentC const &src_accum,
ElementScale *weight_scale_ptr,
MatrixCoord scale_extent,
const int thread_idx,
MatrixCoord tb_offset_scale) {
typename Mma::IteratorScale iterator_scale(
Mma::IteratorScale::Layout(scale_extent.column()),
weight_scale_ptr,
scale_extent,
thread_idx,
tb_offset_scale);
mma(gemm_k_iterations,
accum,
iterator_A,
iterator_B,
iterator_scale,
src_accum);
}
// SFINAE overload for normal gemm. This completely ignores the scale parameters
template <
typename Mma,
typename ElementScale,
typename platform::enable_if<!use_dq_gemm<Mma>::value, bool>::type = true>
CUTLASS_DEVICE static void run_mma(Mma mma,
int gemm_k_iterations,
typename Mma::FragmentC &accum, // NOLINT
typename Mma::IteratorA iterator_A,
typename Mma::IteratorB iterator_B,
typename Mma::FragmentC const &src_accum,
ElementScale *weight_scale_ptr,
MatrixCoord scale_extent,
const int thread_idx,
MatrixCoord tb_offset_scale) {
mma(gemm_k_iterations, accum, iterator_A, iterator_B, src_accum);
}
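`use_dq_gemm` relies on the detection idiom: the partial specialization is selected only when the Mma type exposes a nested `IteratorScale`. A self-contained sketch of the same idiom with hypothetical types:

```cpp
#include <type_traits>

template <typename T, typename = void>
struct has_iterator_scale : std::false_type {};

template <typename T>
struct has_iterator_scale<T, std::void_t<typename T::IteratorScale>>
    : std::true_type {};

struct PlainMma {};                              // no scale iterator
struct DequantMma { struct IteratorScale {}; };  // dequantizing mainloop

static_assert(!has_iterator_scale<PlainMma>::value,
              "plain mainloops take the non-dequantizing run_mma overload");
static_assert(has_iterator_scale<DequantMma>::value,
              "dequantizing mainloops take the scale-aware run_mma overload");
```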
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to
/// perform
>
struct MoeFCGemm {
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
static bool const kTransposed = false;
// Optional transpose
using MapArguments =
kernel::detail::MapArguments<typename Mma::IteratorA::Element,
typename Mma::IteratorA::Layout,
Mma::kTransformA,
Mma::IteratorA::AccessType::kElements,
typename Mma::IteratorB::Element,
typename Mma::IteratorB::Layout,
Mma::kTransformB,
Mma::IteratorB::AccessType::kElements,
typename Mma::LayoutC,
kTransposed>;
// Public-facing type definitions related to operand element type, layout, and
// complex conjugate operation. Must interact with the 'kTransposed' notion.
static_assert(!kTransposed, "Transpose problem not supported");
using ElementA = typename MapArguments::ElementA;
using LayoutA = typename MapArguments::LayoutA;
using ElementB = typename MapArguments::ElementB;
using LayoutB = typename MapArguments::LayoutB;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename MapArguments::LayoutC;
using ElementScale = ElementC;
static ComplexTransform const kTransformA = MapArguments::kTransformA;
static ComplexTransform const kTransformB = MapArguments::kTransformB;
// Type definitions about the mainloop.
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = MapArguments::kAlignmentA;
static int const kAlignmentB = MapArguments::kAlignmentB;
static int const kAlignmentC =
Epilogue::OutputTileIterator::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
using ProblemVisitor = GemmMoeProblemVisitor<ThreadblockShape,
kGroupScheduleMode,
kThreadCount,
kThreadCount,
kTransposed>;
//
// Structures
//
/// Argument structure
struct Arguments {
//
// Data members
//
int problem_count;
int threadblock_count;
typename EpilogueOutputOp::Params output_op;
ElementA *ptr_A;
ElementB *ptr_B;
ElementScale *weight_scales;
ElementC *ptr_C;
ElementC *ptr_D;
int64_t *total_rows_before_expert;
int64_t gemm_n;
int64_t gemm_k;
// Only used by device-level operator
GemmCoord *host_problem_sizes;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments()
: problem_count(0),
threadblock_count(0),
ptr_A(nullptr),
ptr_B(nullptr),
weight_scales(nullptr),
ptr_C(nullptr),
ptr_D(nullptr),
total_rows_before_expert(nullptr),
gemm_n(0),
gemm_k(0),
host_problem_sizes(nullptr) {}
/// Ctor
CUTLASS_HOST_DEVICE
Arguments(int problem_count,
int threadblock_count,
typename EpilogueOutputOp::Params output_op,
const ElementA *ptr_A,
const ElementB *ptr_B,
const ElementScale *weight_scales,
const ElementC *ptr_C,
ElementC *ptr_D,
int64_t *total_rows_before_expert,
int64_t gemm_n,
int64_t gemm_k,
GemmCoord *host_problem_sizes = nullptr)
: problem_count(problem_count),
threadblock_count(threadblock_count),
output_op(output_op),
ptr_A(const_cast<ElementA *>(ptr_A)),
ptr_B(const_cast<ElementB *>(ptr_B)),
weight_scales(const_cast<ElementScale *>(weight_scales)),
ptr_C(const_cast<ElementC *>(ptr_C)),
ptr_D(ptr_D),
total_rows_before_expert(total_rows_before_expert),
gemm_n(gemm_n),
gemm_k(gemm_k),
host_problem_sizes(nullptr) {
if (platform::is_same<uint8_t, ElementB>::value ||
platform::is_same<uint4b_t, ElementB>::value) {
assert(weight_scales);
}
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params {
typename ProblemVisitor::Params problem_visitor;
int threadblock_count;
typename EpilogueOutputOp::Params output_op;
ElementA *ptr_A;
ElementB *ptr_B;
ElementScale *weight_scales;
ElementC *ptr_C;
ElementC *ptr_D;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params()
: ptr_A(nullptr),
ptr_B(nullptr),
weight_scales(nullptr),
ptr_C(nullptr),
ptr_D(nullptr) {}
CUTLASS_HOST_DEVICE
Params(Arguments const &args,
void *workspace = nullptr,
int tile_count = 0) // NOLINT
: problem_visitor(args.total_rows_before_expert,
args.gemm_n,
args.gemm_k,
args.problem_count,
workspace,
tile_count),
threadblock_count(args.threadblock_count),
output_op(args.output_op),
ptr_A(args.ptr_A),
ptr_B(args.ptr_B),
weight_scales(args.weight_scales),
ptr_C(args.ptr_C),
ptr_D(args.ptr_D) {}
CUTLASS_HOST_DEVICE
void update(Arguments const &args,
void *workspace = nullptr,
int tile_count = 0) {
problem_visitor =
typename ProblemVisitor::Params(args.total_rows_before_expert,
args.gemm_n,
args.gemm_k,
args.problem_count,
workspace,
tile_count);
threadblock_count = args.threadblock_count;
output_op = args.output_op;
ptr_A = args.ptr_A;
ptr_B = args.ptr_B;
weight_scales = args.weight_scales;
ptr_C = args.ptr_C;
ptr_D = args.ptr_D;
}
};
/// Shared memory storage structure
union SharedStorage {
typename ProblemVisitor::SharedStorage problem_visitor;
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
public:
//
// Methods
//
CUTLASS_DEVICE
MoeFCGemm() {}
/// Determines whether kernel satisfies alignment
static Status can_implement(cutlass::gemm::GemmCoord const &problem_size) {
return Status::kSuccess;
}
static Status can_implement(Arguments const &args) {
if (platform::is_same<uint8_t, ElementB>::value ||
platform::is_same<uint4b_t, ElementB>::value) {
if (args.weight_scales == nullptr) {
CUTLASS_TRACE_HOST(
"MoeFCGemm::can_implement() - weight scales are required for "
"uint8_t and uint4b_t");
return Status::kInvalid;
}
} else if (args.weight_scales != nullptr) {
CUTLASS_TRACE_HOST(
"MoeFCGemm::can_implement() - weight scales are ignored for all "
"types except uint8_t and uint4b_t");
return Status::kInvalid;
}
return Status::kSuccess;
}
static size_t get_extra_workspace_size(
Arguments const &args, cutlass::gemm::GemmCoord const &grid_tiled_shape) {
return 0;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params,
SharedStorage &shared_storage) { // NOLINT
//
// These types shadow the type-level definitions and support the ability to
// implement a 'transposed' GEMM that computes the transposed problems.
//
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
static constexpr int kInterleave =
Mma::IteratorB::Shape::kRow / Mma::SmemIteratorB::Shape::kRow;
static_assert(platform::is_same<LayoutB, layout::RowMajor>::value &&
kInterleave == 1 ||
platform::is_same<LayoutB, layout::ColumnMajor>::value &&
kInterleave >= 1,
"B must be row major/col major OR col major interleaved.");
//
// Problem visitor.
//
ProblemVisitor problem_visitor(
params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
const int64_t gemm_k = params.problem_visitor.gemm_k;
const int64_t gemm_n = params.problem_visitor.gemm_n;
int64_t bytes_per_expert_matrix =
(gemm_k * gemm_n / 8) * cutlass::sizeof_bits<ElementB>::value;
// Outer 'persistent' loop to iterate over tiles
while (problem_visitor.next_tile()) {
GemmCoord problem_size = problem_visitor.problem_size();
int32_t problem_idx = problem_visitor.problem_index();
int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
cutlass::gemm::GemmCoord threadblock_offset(
int(cta_idx / grid_shape.n()) * Mma::Shape::kM, // NOLINT
int(cta_idx % grid_shape.n()) * Mma::Shape::kN, // NOLINT
0);
// Load element pointers. Exchange pointers and strides if working on the
// transpose
const int64_t rows_to_jump =
problem_idx == 0
? 0
: params.problem_visitor.last_row_for_problem[problem_idx - 1];
ElementA *ptr_A =
reinterpret_cast<ElementA *>(params.ptr_A) + rows_to_jump * gemm_k;
typename LayoutA::LongIndex ldm_A = gemm_k;
char *byte_ptr_B = ((char *)params.ptr_B) + // NOLINT
problem_idx * bytes_per_expert_matrix;
ElementB *ptr_B = reinterpret_cast<ElementB *>(byte_ptr_B);
typename LayoutB::LongIndex ldm_B =
platform::is_same<layout::RowMajor, LayoutB>::value
? gemm_n
: gemm_k * kInterleave;
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_offset.m(),
0,
};
cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(LayoutA(ldm_A),
ptr_A,
{problem_size.m(), problem_size.k()},
thread_idx,
tb_offset_A);
typename Mma::IteratorB iterator_B(
LayoutB(ldm_B),
ptr_B,
{problem_size.k() * kInterleave, problem_size.n() / kInterleave},
thread_idx,
tb_offset_B);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Matrix multiply phase
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations =
(problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Wait for all threads to finish their epilogue phases from the previous
// tile.
__syncthreads();
// Compute threadblock-scoped matrix multiply-add
ElementScale *weight_scale_ptr =
params.weight_scales + problem_idx * problem_size.n();
run_mma<Mma>(mma,
gemm_k_iterations,
accumulators,
iterator_A,
iterator_B,
accumulators,
weight_scale_ptr,
{1, problem_size.n()},
thread_idx,
tb_offset_scale);
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
ElementC *ptr_C =
reinterpret_cast<ElementC *>(params.ptr_C) + problem_idx * gemm_n;
ElementC *ptr_D =
reinterpret_cast<ElementC *>(params.ptr_D) + rows_to_jump * gemm_n;
LayoutC layout_C(0);
LayoutC layout_D(gemm_n);
typename Epilogue::OutputTileIterator::Params params_C(layout_C);
typename Epilogue::OutputTileIterator::Params params_D(layout_D);
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(params_C,
ptr_C,
problem_size.mn(),
thread_idx,
threadblock_offset.mn());
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(params_D,
ptr_D,
problem_size.mn(),
thread_idx,
threadblock_offset.mn());
Epilogue epilogue(
shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);
// Next tile
problem_visitor.advance(gridDim.x);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fusion/moe_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/fusion/cutlass/moe_kernel_impl.h"
// Ignore CUTLASS warnings about type punning
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wunused-function"
#include "cutlass/array.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/numeric_conversion.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/kernels/fusion/cutlass/default_moe_fc_traits.h"
#include "paddle/phi/kernels/fusion/cutlass/linear_combination_ft_gelu.h"
#include "paddle/phi/kernels/fusion/cutlass/moe_cutlass_kernel.h"
#pragma GCC diagnostic pop
namespace phi {
namespace {
inline int getSMVersion() {
const int device = phi::backends::gpu::GetCurrentDeviceId();
const phi::gpuDeviceProp prop =
phi::backends::gpu::GetDeviceProperties(device);
return prop.major * 10 + prop.minor;
}
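// Empty tag types used to select which CUTLASS epilogue functor the
// Epilogue<> specializations below instantiate (bias + ReLU, bias + FT GELU,
// bias only, or no bias).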
struct EpilogueOpBiasReLU {};
struct EpilogueOpBiasFtGelu {};
struct EpilogueOpBias {};
struct EpilogueOpNoBias {};
template <typename ElementType,
int ElementsPerVectorAccess,
typename ElementAccumulator,
typename Op>
struct Epilogue {};
template <typename ElementType,
int ElementsPerVectorAccess,
typename ElementAccumulator>
struct Epilogue<ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
EpilogueOpBiasReLU> {
using Op = cutlass::epilogue::thread::LinearCombinationRelu<
ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
ElementAccumulator,
cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
};
template <typename ElementType,
int ElementsPerVectorAccess,
typename ElementAccumulator>
struct Epilogue<ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
EpilogueOpBiasFtGelu> {
using Op = cutlass::epilogue::thread::LinearCombinationFtGelu<
ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
ElementAccumulator,
cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
};
template <typename ElementType,
int ElementsPerVectorAccess,
typename ElementAccumulator>
struct Epilogue<ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
EpilogueOpBias> {
using Op = cutlass::epilogue::thread::LinearCombination<
ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
ElementAccumulator,
cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
};
template <typename ElementType,
int ElementsPerVectorAccess,
typename ElementAccumulator>
struct Epilogue<ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
EpilogueOpNoBias> {
using Op = cutlass::epilogue::thread::LinearCombination<
ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
ElementAccumulator,
cutlass::epilogue::thread::ScaleType::Nothing>;
};
} // namespace
namespace fusion {
template <typename T>
void InitExpertChoiceRouteKernelLauncher(
int* expert_for_source_row,
int* source_row,
int* expanded_source_row_to_expanded_dest_row,
int64_t* total_rows_before_expert,
T* attr_mask,
const int num_experts,
const int num_rows,
const int k,
const int batch_size,
cudaStream_t stream) {
const int threads = 128;
const int blocks = num_experts;
initialize_expert_choice_route_kernel<<<blocks, threads, 0, stream>>>(
expert_for_source_row,
source_row,
expanded_source_row_to_expanded_dest_row,
total_rows_before_expert,
attr_mask,
num_rows,
k,
batch_size);
}
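// Dispatch helper for the masked softmax: picks the half2-vectorized kernel
// when possible and folds the per-thread item count into the block size.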
#define SOFTMAX_KERNEL(ITEMS_PER_THREAD) \
block.x /= ITEMS_PER_THREAD; \
assert(block.x <= 1024); \
if (is_half2) { \
if (grid.x % 4 == 0) { \
grid.x /= 4; \
softmax_kernel_v5_half2<__half, ITEMS_PER_THREAD, 4> \
<<<grid, block, 0, stream>>>(reinterpret_cast<half*>(buffer), \
(const half*)attr_mask, \
batch_size, \
head_num, \
seq_len_1, \
seq_len_2, \
(const half)scalar); \
} else { \
softmax_kernel_v4_half2<__half, ITEMS_PER_THREAD> \
<<<grid, block, 0, stream>>>(reinterpret_cast<half*>(buffer), \
(const half*)attr_mask, \
batch_size, \
head_num, \
seq_len_1, \
seq_len_2, \
(const half)scalar); \
} \
} else { \
softmax_kernel_v4<ITEMS_PER_THREAD, T> \
<<<grid, block, 0, stream>>>(buffer, \
buffer_src, \
attr_mask, \
batch_size, \
head_num, \
seq_len_1, \
seq_len_2, \
scalar); \
}
template <typename T>
void invokeMaskedSoftMax(T* buffer,
const T* buffer_src,
const T* attr_mask,
const int batch_size,
const int seq_len_1,
const int seq_len_2,
const int head_num,
const T scalar,
cudaStream_t stream) {
// NOTE: attention scores shape (batch_size, head_num, seq_len_1, seq_len_2)
dim3 grid(seq_len_1, batch_size, head_num);
if (batch_size * head_num > 360) {
grid.x = ceil(static_cast<float>(seq_len_1) / 32.0f);
}
bool is_half2 = sizeof(T) == 2 && seq_len_2 % 2 == 0;
dim3 block((seq_len_2 / (is_half2 ? 2 : 1) + 31) / 32 * 32);
if (block.x > 2048 && block.x <= 4096) {
SOFTMAX_KERNEL(4)
} else if (block.x > 1024) {
SOFTMAX_KERNEL(2)
} else if (block.x > 0) {
SOFTMAX_KERNEL(1)
} else {
PADDLE_ENFORCE_EQ(true,
false,
phi::errors::InvalidArgument(
"Softmax kernel only support columns in 0 - 4096. "));
}
}
template <typename T>
void InvokeTransposeAxis01(T* out,
T* in,
const int dim0,
const int dim1,
const int dim2,
cudaStream_t stream) {
dim3 block(512);
dim3 grid(static_cast<int>(ceil(dim0 * dim1 * dim2 / 512.)));
transposeAxis01<<<grid, block, 0, stream>>>(out, in, dim0, dim1, dim2);
}
template <typename T>
void InvokePadding(T* output1,
int* output2,
const T* input1,
const int* input2,
const int* input_lengths,
const int num_tokens,
const int batch_size,
const int max_seq_len,
const int num_experts,
cudaStream_t stream) {
assert(max_seq_len <= 1024);
dim3 block(max_seq_len);
dim3 grid(num_experts);
paddingKernel<<<grid, block, 0, stream>>>(output1,
output2,
input1,
input2,
input_lengths,
num_tokens,
batch_size,
max_seq_len,
num_experts);
}
template <typename T>
void InvokeGeneralTopKPairSort(T* out_keys,
int* out_values,
T* in_keys,
int* in_values,
const int m,
const int n,
cudaStream_t stream) {
assert(n <= 4096);
const int blocks = m;
if (n == 128) {
general_topk_pair_sort<T, 32, 4>
<<<blocks, 32, 0, stream>>>(out_keys, out_values, in_keys, in_values);
}
if (n == 256) {
general_topk_pair_sort<T, 64, 4>
<<<blocks, 64, 0, stream>>>(out_keys, out_values, in_keys, in_values);
}
if (n == 1024) {
general_topk_pair_sort<T, 256, 4>
<<<blocks, 256, 0, stream>>>(out_keys, out_values, in_keys, in_values);
} else if (n == 2048) {
general_topk_pair_sort<T, 512, 4>
<<<blocks, 512, 0, stream>>>(out_keys, out_values, in_keys, in_values);
} else if (n == 4096) {
general_topk_pair_sort<T, 1024, 4>
<<<blocks, 1024, 0, stream>>>(out_keys, out_values, in_keys, in_values);
}
}
template <typename T>
void InitMoeRoutingKernelLauncher(
const T* unpermuted_input,
T* permuted_output,
const int* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row,
const int num_experts,
const int num_rows,
const int active_rows,
const int cols,
const int k,
const int batch_size,
const int max_seq_len,
bool ec_route,
cudaStream_t stream) {
const int blocks = ec_route ? num_experts * k * batch_size : num_rows * k;
if (ec_route) {
constexpr int max_pack_size = 16 / sizeof(T);
const int threads = std::min(cols / max_pack_size, 1024);
if (cols % max_pack_size == 0) {
initialize_moe_routing_kernel<T, max_pack_size>
<<<blocks, threads, 0, stream>>>(
unpermuted_input,
permuted_output,
expanded_dest_row_to_expanded_source_row,
expanded_source_row_to_expanded_dest_row,
num_rows,
batch_size * k * num_experts,
cols,
k,
max_seq_len,
ec_route);
} else {
initialize_moe_routing_kernel<T, 1><<<blocks, threads, 0, stream>>>(
unpermuted_input,
permuted_output,
expanded_dest_row_to_expanded_source_row,
expanded_source_row_to_expanded_dest_row,
num_rows,
batch_size * k * num_experts,
cols,
k,
max_seq_len,
ec_route);
}
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Currently only support `ec_route = True`. "));
}
}
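// Sets up and runs a CUTLASS grouped GEMM in which every expert is one
// problem of the group; total_rows_before_expert delimits the rows of A
// owned by each expert.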
template <typename T, typename WeightType, typename arch, typename EpilogueType>
void GenericMoeGemmKernelLauncher(const T* A,
const T* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const int multi_processor_count,
cudaStream_t stream) {
static_assert(cutlass::platform::is_same<T, half>::value ||
cutlass::platform::is_same<T, float>::value,
"Specialized for half, float");
static_assert(
cutlass::platform::is_same<T, WeightType>::value ||
cutlass::platform::is_same<WeightType, uint8_t>::value ||
cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
"cutlass weight type only support float, half, uint8_t, uint4b_t");
// The cutlass type for the input elements. This is needed to convert to
// cutlass::half_t if necessary.
using ElementType_ = typename cutlass::platform::conditional<
cutlass::platform::is_same<T, half>::value,
cutlass::half_t,
T>::type;
using ElementType = ElementType_;
using CutlassWeightType_ = typename cutlass::platform::conditional<
cutlass::platform::is_same<WeightType, half>::value,
cutlass::half_t,
WeightType>::type;
using CutlassWeightType = CutlassWeightType_;
// We need separate config for each architecture since we will target
// different tensorcore instructions. For float, we do not target TCs.
using MoeArchTraits = cutlass::gemm::kernel::
MoeArchTraits<ElementType, CutlassWeightType, arch>;
using ElementAccumulator = typename MoeArchTraits::AccType;
using EpilogueOp = typename Epilogue<ElementType,
MoeArchTraits::ElementsPerAccessC,
ElementAccumulator,
EpilogueType>::Op;
// Finally, set up the kernel.
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped<
ElementType,
cutlass::layout::RowMajor,
cutlass::ComplexTransform::kNone,
MoeArchTraits::ElementsPerAccessA,
CutlassWeightType,
typename MoeArchTraits::LayoutB,
cutlass::ComplexTransform::kNone,
MoeArchTraits::ElementsPerAccessB,
ElementType,
cutlass::layout::RowMajor,
ElementAccumulator,
typename MoeArchTraits::OperatorClass,
arch,
typename MoeArchTraits::ThreadBlockShape,
typename MoeArchTraits::WarpShape,
typename MoeArchTraits::InstructionShape,
EpilogueOp,
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
MoeArchTraits::Stages,
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
typename MoeArchTraits::Operator>::GemmKernel;
using GemmKernel =
cutlass::gemm::kernel::MoeFCGemm<typename GemmKernel_::Mma,
typename GemmKernel_::Epilogue,
typename GemmKernel_::ThreadblockSwizzle,
GemmKernel_::kGroupScheduleMode>;
using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
int occupancy = GemmGrouped::maximum_active_blocks();
const int threadblock_count = multi_processor_count * occupancy;
if (occupancy == 0) {
PADDLE_THROW(paddle::platform::errors::Fatal(
"[MoE Runner] GPU lacks the shared memory resources to run GroupedGEMM "
"kernel"));
}
typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f),
ElementAccumulator(1.f));
typename GemmGrouped::Arguments args(
num_experts,
threadblock_count,
epilogue_op,
reinterpret_cast<const ElementType*>(A),
reinterpret_cast<const CutlassWeightType*>(B),
reinterpret_cast<const ElementType*>(weight_scales),
reinterpret_cast<const ElementType*>(biases),
reinterpret_cast<ElementType*>(C),
total_rows_before_expert,
gemm_n,
gemm_k);
GemmGrouped gemm;
auto can_implement = gemm.can_implement(args);
if (can_implement != cutlass::Status::kSuccess) {
std::string err_msg = "MoEFC kernel will fail for params. Error: " +
std::string(cutlassGetStatusString(can_implement));
PADDLE_THROW(paddle::platform::errors::Fatal("[MoE Runner] " + err_msg));
}
auto init_status = gemm.initialize(args);
if (init_status != cutlass::Status::kSuccess) {
std::string err_msg =
"Failed to initialize cutlass variable batched gemm. Error: " +
std::string(cutlassGetStatusString(init_status));
PADDLE_THROW(paddle::platform::errors::Fatal("[MoE Runner] " + err_msg));
}
auto run_status = gemm.run(stream);
if (run_status != cutlass::Status::kSuccess) {
std::string err_msg =
"Failed to run cutlass variable batched gemm. Error: " +
std::string(cutlassGetStatusString(run_status));
PADDLE_THROW(paddle::platform::errors::Fatal("[MoE Runner] " + err_msg));
}
}
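// FC1: grouped GEMM + bias + activation. The SM version selects the
// architecture-specific traits (Sm75, Sm80/86, otherwise Sm70) and the
// activation selects the epilogue (FT GELU or ReLU).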
template <typename T>
void gemm_bias_act(const T* A,
const T* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
int sm,
int multi_processor_count,
const std::string& act_type,
cudaStream_t stream) {
if (act_type == "gelu") {
if (sm == 75) {
GenericMoeGemmKernelLauncher<T,
T,
cutlass::arch::Sm75,
EpilogueOpBiasFtGelu>(
A,
B,
weight_scales,
biases,
C,
total_rows_before_expert,
gemm_n,
gemm_k,
num_experts,
multi_processor_count,
stream);
} else if (sm == 80 || sm == 86) {
GenericMoeGemmKernelLauncher<T,
T,
cutlass::arch::Sm80,
EpilogueOpBiasFtGelu>(
A,
B,
weight_scales,
biases,
C,
total_rows_before_expert,
gemm_n,
gemm_k,
num_experts,
multi_processor_count,
stream);
} else {
GenericMoeGemmKernelLauncher<T,
T,
cutlass::arch::Sm70,
EpilogueOpBiasFtGelu>(
A,
B,
weight_scales,
biases,
C,
total_rows_before_expert,
gemm_n,
gemm_k,
num_experts,
multi_processor_count,
stream);
}
} else {
// act type is relu.
if (sm == 75) {
GenericMoeGemmKernelLauncher<T,
T,
cutlass::arch::Sm75,
EpilogueOpBiasReLU>(A,
B,
weight_scales,
biases,
C,
total_rows_before_expert,
gemm_n,
gemm_k,
num_experts,
multi_processor_count,
stream);
} else if (sm == 80 || sm == 86) {
GenericMoeGemmKernelLauncher<T,
T,
cutlass::arch::Sm80,
EpilogueOpBiasReLU>(A,
B,
weight_scales,
biases,
C,
total_rows_before_expert,
gemm_n,
gemm_k,
num_experts,
multi_processor_count,
stream);
} else {
GenericMoeGemmKernelLauncher<T,
T,
cutlass::arch::Sm70,
EpilogueOpBiasReLU>(A,
B,
weight_scales,
biases,
C,
total_rows_before_expert,
gemm_n,
gemm_k,
num_experts,
multi_processor_count,
stream);
}
}
}
template <typename T>
void gemm(const T* A,
const T* B,
const T* weight_scales,
T* C,
int64_t* total_rows_before_expert,
const int gemm_n,
const int gemm_k,
const int num_experts,
int sm,
int multi_processor_count,
cudaStream_t stream) {
if (sm == 75) {
GenericMoeGemmKernelLauncher<T, T, cutlass::arch::Sm75, EpilogueOpNoBias>(
A,
B,
weight_scales,
nullptr,
C,
total_rows_before_expert,
gemm_n,
gemm_k,
num_experts,
multi_processor_count,
stream);
} else if (sm == 80 || sm == 86) {
GenericMoeGemmKernelLauncher<T, T, cutlass::arch::Sm80, EpilogueOpNoBias>(
A,
B,
weight_scales,
nullptr,
C,
total_rows_before_expert,
gemm_n,
gemm_k,
num_experts,
multi_processor_count,
stream);
} else {
GenericMoeGemmKernelLauncher<T, T, cutlass::arch::Sm70, EpilogueOpNoBias>(
A,
B,
weight_scales,
nullptr,
C,
total_rows_before_expert,
gemm_n,
gemm_k,
num_experts,
multi_processor_count,
stream);
}
}
template <typename T>
void finalize_moe_routing_kernelLauncher(
const T* expanded_permuted_rows,
T* reduced_unpermuted_output,
const T* skip,
const T* bias,
const T* scales,
const int* expanded_source_row_to_expanded_dest_row,
const int* expert_for_source_row,
const int num_experts,
const int num_rows,
const int cols,
const int k,
bool ec_route,
cudaStream_t stream) {
const int blocks = num_rows;
const int threads = std::min(cols, 1024);
{
finalize_moe_routing_kernel<T><<<blocks, threads, 0, stream>>>(
expanded_permuted_rows,
reduced_unpermuted_output,
skip,
bias,
scales,
expanded_source_row_to_expanded_dest_row,
expert_for_source_row,
cols,
num_experts,
ec_route);
}
}
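// Expert-choice MoE pipeline:
// 1. initialize routing buffers and the all-ones attention mask,
// 2. masked softmax over the gating logits, transpose scores to [e, n],
// 3. pad per-expert scores and run a per-expert descending pair sort,
// 4. permute the selected tokens,
// 5. FC1 (grouped GEMM + bias + activation), 6. FC2 (grouped GEMM),
// 7. finalize: un-permute, scale by routing probability, add bias and skip.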
template <typename T, typename Context>
void MoeKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& gate,
const DenseTensor& bmm0,
const DenseTensor& bias0,
const DenseTensor& bmm1,
const DenseTensor& bias1,
const std::string& act_type,
DenseTensor* output) {
const T* input_activations = x.data<T>();
T* gating_output = const_cast<T*>(gate.data<T>());
const T* fc1_expert_weights = bmm0.data<T>();
const T* fc1_expert_biases = bias0.data<T>();
const T* fc2_expert_weights = bmm1.data<T>();
const T* fc2_expert_biases = bias1.data<T>();
// int moe_act = static_cast<int>(act);
T* output_ = ctx.template Alloc<T>(output);
auto stream = ctx.stream();
auto input_dims = x.dims();
auto bmm0_dims = bmm0.dims();
const bool IS_FP16 = std::is_same<T, phi::dtype::float16>::value;
const int num_rows = input_dims[0] * input_dims[1];
const int hidden_size = input_dims[2];
const int inter_size = bmm0_dims[2];
const int num_experts = bmm0_dims[0];
const int k = input_dims[1] / 16;
const int batch_size = input_dims[0];
const int max_seq_len = 128;
int64_t bytes = getWorkspaceSize<T>(num_rows,
hidden_size,
inter_size,
num_experts,
k,
batch_size,
max_seq_len);
// Pointers
int* source_rows;
int* padded_source_rows;
int* permuted_rows;
int* permuted_experts;
char* sorter_ws_;
T* permuted_data;
T* padded_expert_scales;
int64_t* total_rows_before_expert;
T* sorted_softmax_output;
T* attr_mask;
T* fc1_result;
phi::DenseTensor ws_ptr_tensor = phi::Empty<int8_t>(ctx, {bytes});
int8_t* ws_ptr = ws_ptr_tensor.data<int8_t>();
const int buf_size = AlignTo16(num_experts * batch_size * k * hidden_size);
const int padded_experts = AlignTo16(num_experts);
const int num_moe_inputs = AlignTo16(num_experts * num_rows);
// padded_num_moe_inputs for topk sort
int padded_num_moe_inputs = num_experts * batch_size * max_seq_len;
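// Carve the single workspace allocation into sub-buffers; the order and
// sizes below must stay in sync with getWorkspaceSize().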
source_rows = reinterpret_cast<int*>(ws_ptr);
padded_source_rows = source_rows + num_moe_inputs;
permuted_rows = padded_source_rows + padded_num_moe_inputs;
permuted_experts = permuted_rows + padded_num_moe_inputs;
permuted_data = reinterpret_cast<T*>(permuted_experts + num_experts * k);
padded_expert_scales = reinterpret_cast<T*>(permuted_data + buf_size);
total_rows_before_expert =
reinterpret_cast<int64_t*>(padded_expert_scales + padded_num_moe_inputs);
sorted_softmax_output =
reinterpret_cast<T*>(total_rows_before_expert + padded_experts);
attr_mask =
reinterpret_cast<T*>(sorted_softmax_output + padded_num_moe_inputs);
fc1_result = reinterpret_cast<T*>(attr_mask + num_moe_inputs);
phi::DenseTensor expert_for_source_row_tensor =
phi::Empty<int>(ctx, {num_experts, num_rows});
int* expert_for_source_row = expert_for_source_row_tensor.data<int>();
phi::DenseTensor expanded_source_row_to_expanded_dest_row_tensor =
phi::Empty<int>(ctx, {num_experts, num_rows});
int* expanded_source_row_to_expanded_dest_row =
expanded_source_row_to_expanded_dest_row_tensor.data<int>();
phi::DenseTensor expert_scales_tensor =
phi::Empty<T>(ctx, {num_experts, num_rows});
T* expert_scales = expert_scales_tensor.data<T>();
phi::DenseTensor fc2_output_tensor =
phi::Empty<T>(ctx, {num_experts * batch_size * k, hidden_size});
T* fc2_result = fc2_output_tensor.data<T>();
phi::DenseTensor input_lengths_tensor = phi::Empty<int>(ctx, {batch_size});
int* input_lengths = input_lengths_tensor.data<int>();
funcs::SetConstant<Context, int> set_len;
set_len(ctx, &input_lengths_tensor, static_cast<int>(max_seq_len));
int sm = getSMVersion();
int multi_processor_count = phi::backends::gpu::GetGPUMultiProcessors(
phi::backends::gpu::GetCurrentDeviceId());
InitExpertChoiceRouteKernelLauncher<T>(
expert_for_source_row,
source_rows,
expanded_source_row_to_expanded_dest_row,
total_rows_before_expert,
attr_mask,
num_experts,
num_rows,
k,
batch_size,
ctx.stream());
T scalar = (T)1.0f;
if (IS_FP16) {
invokeMaskedSoftMax<__half>(reinterpret_cast<__half*>(gating_output),
reinterpret_cast<const __half*>(gating_output),
reinterpret_cast<const __half*>(attr_mask),
/*batch_size=*/num_rows,
/*seq_len_1=*/1,
/*seq_len_2=*/num_experts,
/*head_num=*/1,
*reinterpret_cast<const __half*>(&scalar),
ctx.stream());
} else {
invokeMaskedSoftMax<float>(reinterpret_cast<float*>(gating_output),
reinterpret_cast<const float*>(gating_output),
reinterpret_cast<const float*>(attr_mask),
/*batch_size=*/num_rows,
/*seq_len_1=*/1,
/*seq_len_2=*/num_experts,
/*head_num=*/1,
*reinterpret_cast<const float*>(&scalar),
ctx.stream());
}
InvokeTransposeAxis01(
expert_scales, gating_output, num_rows, num_experts, 1, ctx.stream());
int padded_max_seq_len = max_seq_len <= 128 ? 128 : 256;
InvokePadding(padded_expert_scales,
padded_source_rows,
expert_scales,
source_rows,
input_lengths,
num_rows,
batch_size,
padded_max_seq_len,
num_experts,
ctx.stream());
if (IS_FP16) {
InvokeGeneralTopKPairSort<__half>(
reinterpret_cast<__half*>(sorted_softmax_output),
permuted_rows,
reinterpret_cast<__half*>(padded_expert_scales),
padded_source_rows,
num_experts * batch_size,
padded_max_seq_len,
ctx.stream());
} else {
InvokeGeneralTopKPairSort<float>(
reinterpret_cast<float*>(sorted_softmax_output),
permuted_rows,
reinterpret_cast<float*>(padded_expert_scales),
padded_source_rows,
num_experts * batch_size,
padded_max_seq_len,
ctx.stream());
}
InitMoeRoutingKernelLauncher(input_activations,
permuted_data,
permuted_rows,
expanded_source_row_to_expanded_dest_row,
num_experts,
num_rows,
num_rows,
hidden_size,
k,
batch_size,
max_seq_len,
true,
ctx.stream());
const T* fc1_scales = nullptr;
const T* fc2_scales = nullptr;
if (IS_FP16) {
gemm_bias_act(reinterpret_cast<const __half*>(permuted_data),
reinterpret_cast<const __half*>(fc1_expert_weights),
reinterpret_cast<const __half*>(fc1_scales),
reinterpret_cast<const __half*>(fc1_expert_biases),
reinterpret_cast<__half*>(fc1_result),
total_rows_before_expert,
inter_size,
hidden_size,
num_experts,
sm,
multi_processor_count,
act_type,
ctx.stream());
gemm(reinterpret_cast<const __half*>(fc1_result),
reinterpret_cast<const __half*>(fc2_expert_weights),
reinterpret_cast<const __half*>(fc2_scales),
reinterpret_cast<__half*>(fc2_result),
total_rows_before_expert,
hidden_size,
inter_size,
num_experts,
sm,
multi_processor_count,
ctx.stream());
} else {
gemm_bias_act<float>(reinterpret_cast<const float*>(permuted_data),
reinterpret_cast<const float*>(fc1_expert_weights),
reinterpret_cast<const float*>(fc1_scales),
reinterpret_cast<const float*>(fc1_expert_biases),
reinterpret_cast<float*>(fc1_result),
total_rows_before_expert,
inter_size,
hidden_size,
num_experts,
sm,
multi_processor_count,
act_type,
ctx.stream());
gemm<float>(reinterpret_cast<const float*>(fc1_result),
reinterpret_cast<const float*>(fc2_expert_weights),
reinterpret_cast<const float*>(fc2_scales),
reinterpret_cast<float*>(fc2_result),
total_rows_before_expert,
hidden_size,
inter_size,
num_experts,
sm,
multi_processor_count,
ctx.stream());
}
finalize_moe_routing_kernelLauncher(fc2_result,
output_,
input_activations,
fc2_expert_biases,
expert_scales,
expanded_source_row_to_expanded_dest_row,
expert_for_source_row,
num_experts,
num_rows,
hidden_size,
k,
true,
ctx.stream());
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(
moe, GPU, ALL_LAYOUT, phi::fusion::MoeKernel, float, phi::dtype::float16) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "cub/cub.cuh"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
namespace phi {
static const float HALF_FLT_MAX = 65504.F;
static const float HALF_FLT_MIN = -65504.F;
static inline size_t AlignTo16(const size_t& input) {
static constexpr int ALIGNMENT = 16;
return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
}
/*
WarpReduce multi values.
TODO(zhengzekang): Add blocksize templates to reduce shared memory usage.
*/
template <typename T, int NUM>
__inline__ __device__ T warpReduceSumV2(T* val) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
}
return (T)(0.0f);
}
template <typename T, int NUM>
__inline__ __device__ T blockReduceSumV2(T* val) {
static __shared__ T shared[NUM][33];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduceSumV2<T, NUM>(val);
if (lane == 0) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
shared[i][wid] = val[i];
}
}
__syncthreads();
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
for (int i = 0; i < NUM; i++) {
val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
}
warpReduceSumV2<T, NUM>(val);
return (T)0.0f;
}
template <typename T, int NUM>
__inline__ __device__ T warpReduceMaxV2(T* val) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));
}
return (T)(0.0f);
}
template <typename T, int NUM>
__inline__ __device__ T blockReduceMaxV2(T* val) {
static __shared__ T shared[32][NUM];
int lane = threadIdx.x & 0x1f; // in-warp idx
int wid = threadIdx.x >> 5; // warp idx
warpReduceMaxV2<T, NUM>(val); // get maxx in each warp
if (lane == 0) { // record in-warp maxx by warp Idx
#pragma unroll
for (int i = 0; i < NUM; i++) {
shared[wid][i] = val[i];
}
}
__syncthreads();
// Use blockDim.x / 32.f rather than blockDim.x >> 5 so the check stays
// correct when blockDim.x is not a multiple of 32.
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
for (int i = 0; i < NUM; i++) {
val[i] = is_mask ? shared[lane][i] : (T)-1e20f;
}
warpReduceMaxV2<T, NUM>(val);
return (T)0.0f;
}
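// Thin wrapper around cub::DeviceRadixSort for sorting (key, row-index)
// pairs; num_bits_ bounds the radix passes by the number of experts.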
class CubKeyValueSorter {
public:
CubKeyValueSorter();
explicit CubKeyValueSorter(const int num_experts);
void update_num_experts(const int num_experts);
size_t getWorkspaceSize(const size_t num_key_value_pairs,
bool descending = false);
template <typename KeyT>
void run(void* workspace,
const size_t workspace_size,
const KeyT* keys_in,
KeyT* keys_out,
const int* values_in,
int* values_out,
const size_t num_key_value_pairs,
bool descending,
cudaStream_t stream);
private:
size_t num_key_value_pairs_;
int num_experts_;
int num_bits_;
};
// ===== CUB Sorting things =====
CubKeyValueSorter::CubKeyValueSorter()
: num_experts_(0), num_bits_(sizeof(int) * 8) {}
CubKeyValueSorter::CubKeyValueSorter(const int num_experts)
: num_experts_(num_experts),
num_bits_(static_cast<int>(log2(num_experts)) + 1) {}
void CubKeyValueSorter::update_num_experts(const int num_experts) {
num_experts_ = num_experts;
num_bits_ = static_cast<int>(log2(num_experts)) + 1;
}
size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs,
bool descending) {
num_key_value_pairs_ = num_key_value_pairs;
size_t required_storage = 0;
int* null_int = nullptr;
if (descending) {
cub::DeviceRadixSort::SortPairsDescending(NULL,
required_storage,
null_int,
null_int,
null_int,
null_int,
num_key_value_pairs,
0,
32);
} else {
cub::DeviceRadixSort::SortPairs(NULL,
required_storage,
null_int,
null_int,
null_int,
null_int,
num_key_value_pairs,
0,
num_bits_);
}
return required_storage;
}
template <typename KeyT>
void CubKeyValueSorter::run(void* workspace,
const size_t workspace_size,
const KeyT* keys_in,
KeyT* keys_out,
const int* values_in,
int* values_out,
const size_t num_key_value_pairs,
bool descending,
cudaStream_t stream) {
size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs);
size_t actual_ws_size = workspace_size;
if (expected_ws_size > workspace_size) {
std::stringstream err_ss;
err_ss << "[Error][CubKeyValueSorter::run]\n";
err_ss
<< "Error. The allocated workspace is too small to run this problem.\n";
err_ss << "Expected workspace size of at least " << expected_ws_size
<< " but got problem size " << workspace_size << "\n";
throw std::runtime_error(err_ss.str());
}
if (descending) {
cub::DeviceRadixSort::SortPairsDescending(workspace,
actual_ws_size,
keys_in,
keys_out,
values_in,
values_out,
num_key_value_pairs,
0,
32,
stream);
} else {
cub::DeviceRadixSort::SortPairs(workspace,
actual_ws_size,
keys_in,
keys_out,
values_in,
values_out,
num_key_value_pairs,
0,
num_bits_,
stream);
}
}
template <>
void CubKeyValueSorter::run(void* workspace,
const size_t workspace_size,
const __nv_bfloat16* keys_in,
__nv_bfloat16* keys_out,
const int* values_in,
int* values_out,
const size_t num_key_value_pairs,
bool descending,
cudaStream_t stream) {}
CubKeyValueSorter sorter_;
// -------- getWorkspaceSize -------- //
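// Returns the number of bytes needed for all routing buffers plus the larger
// of the FC1 intermediate output and the CUB sorting workspace.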
template <typename T>
size_t getWorkspaceSize(const int num_rows,
const int hidden_size,
const int inter_size,
const int num_experts,
const int k,
const int batch_size,
const int max_seq_len) {
const int buf_size = AlignTo16(num_experts * batch_size * k * hidden_size);
const int interbuf_size =
AlignTo16(num_experts * batch_size * k * inter_size);
const int padded_experts = AlignTo16(num_experts);
const int num_moe_inputs = AlignTo16(num_experts * num_rows);
int padded_num_moe_inputs = num_experts * batch_size * max_seq_len;
size_t total_ws_bytes = sizeof(int) * num_moe_inputs; // source_rows_
total_ws_bytes += sizeof(int) * padded_num_moe_inputs; // padded_source_rows_
total_ws_bytes += sizeof(T) * padded_num_moe_inputs; // padded_expert_scales_
total_ws_bytes += sizeof(int) * padded_num_moe_inputs; // permuted_rows_
total_ws_bytes += sizeof(int) * num_experts * k; // permuted_experts_
total_ws_bytes += buf_size * sizeof(T); // permuted_data_
total_ws_bytes +=
padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_
total_ws_bytes += sizeof(T) * num_moe_inputs; // attr_mask: [e, n]
total_ws_bytes += sizeof(T) * padded_num_moe_inputs; // sorted_softmax_output
const int bytes_for_fc1_result = interbuf_size * sizeof(T);
const int sorter_ws_size_bytes =
AlignTo16(sorter_.getWorkspaceSize(num_experts * k));
sorter_.update_num_experts(k);
int bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
if (sorter_ws_size_bytes > bytes_for_fc1_result) {
int remaining_bytes =
AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result);
bytes_for_intermediate_and_sorting += remaining_bytes;
}
total_ws_bytes +=
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
// sorting workspace
return total_ws_bytes;
}
// -------- initialize_expert_choice_route_kernel -------- //
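// One block per expert: fills identity routing indices, marks destination
// rows as unassigned (-1), sets the attention mask to 1, and records the
// cumulative row count (batch_size * k per expert) in total_rows_before_expert.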
template <typename T>
__global__ void initialize_expert_choice_route_kernel(
int* expert_for_source_row,
int* source_row,
int* expanded_source_row_to_expanded_dest_row,
int64_t* total_rows_before_expert,
T* attr_mask,
const int cols,
const int k,
const int batch_size) {
int start = cols * blockIdx.x;
for (int i = threadIdx.x; i < cols; i += blockDim.x) {
expert_for_source_row[start + i] = blockIdx.x;
source_row[start + i] = start + i;
expanded_source_row_to_expanded_dest_row[start + i] = -1;
attr_mask[start + i] = (T)1.0f;
}
if (threadIdx.x == 0) {
total_rows_before_expert[blockIdx.x] = batch_size * k * (blockIdx.x + 1);
}
}
// -------- softmax_kernel -------- //
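// Row-wise softmax over the gating scores: positions with mask == 0 receive
// a -10000 penalty before the max-subtracted exp/normalize steps.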
template <int ITEMS_PER_THREAD, typename T>
__global__ void softmax_kernel_v4(
T* qk_buf_,
const T* qk_buf_src, // shape [batch_size, head_num, seq_len_1, seq_len_2]
const T* attr_mask, // shape [batch_size, seq_len_1, seq_len_2]
const int batch_size,
const int head_num,
const int seq_len_1,
const int seq_len_2,
const T scalar) {
for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x) {
float data[ITEMS_PER_THREAD];
int qk_offset;
__shared__ float s_mean, s_max;
float local_max = -1e20f;
for (int i = 0; blockDim.x * i + threadIdx.x < seq_len_2; i++) {
qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) *
seq_len_2 +
blockDim.x * i + threadIdx.x;
int mask_offset = (blockIdx.y * seq_len_1 + seq_id) * seq_len_2 +
blockDim.x * i + threadIdx.x;
float qk = static_cast<float>(qk_buf_src[qk_offset]);
float mask_val = static_cast<float>(__ldg(&attr_mask[mask_offset]));
mask_val = (1.0f - mask_val) * -10000.0f;
data[i] = qk * static_cast<float>(scalar) + mask_val;
local_max = fmax(local_max, data[i]);
}
float max_val =
blockDim.x <= 32
? phi::funcs::warpReduceMax<float>(local_max, 0xFFFFFFFF)
: phi::funcs::blockReduceMax<float>(local_max, 0xffffffff);
if (threadIdx.x == 0) {
s_max = max_val;
}
__syncthreads();
float local_sum = 0;
for (int i = 0; blockDim.x * i + threadIdx.x < seq_len_2; i++) {
data[i] = __expf(data[i] - s_max);
local_sum += data[i];
}
float sum_val =
blockDim.x <= 32
? phi::funcs::warpReduceSum<float>(local_sum, 0xffffffff)
: phi::funcs::blockReduceSum<float>(local_sum, 0xffffffff);
if (threadIdx.x == 0) {
s_mean = sum_val + 1e-6f;
s_mean = __fdividef(1.0f, s_mean);
}
__syncthreads();
for (int i = 0; blockDim.x * i + threadIdx.x < seq_len_2; i++) {
qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) *
seq_len_2 +
blockDim.x * i + threadIdx.x;
qk_buf_[qk_offset] = (T)(data[i] * s_mean);
}
}
}
template <typename T, int ITEMS_PER_THREAD>
__global__ void softmax_kernel_v4_half2(T* qk_buf_,
const T* attr_mask,
const int batch_size,
const int head_num,
const int seq_len_1,
const int seq_len_2,
const T scalar) {
using T2 = half2;
T2* qk_buf_half2 = reinterpret_cast<T2*>(qk_buf_);
const T2* attr_mask_half2 = (const T2*)attr_mask;
for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x) {
T2 data[ITEMS_PER_THREAD];
int qk_offset;
__shared__ float s_mean, s_max;
float local_max = -1e20f;
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD;
i++) {
qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) *
(seq_len_2 / 2) +
blockDim.x * i + threadIdx.x;
int mask_offset = (blockIdx.y * seq_len_1 + seq_id) * (seq_len_2 / 2) +
blockDim.x * i + threadIdx.x;
T2 qk = qk_buf_half2[qk_offset];
T2 mask_val = __ldg(&attr_mask_half2[mask_offset]);
mask_val = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val),
__float2half2_rn(-10000.0f));
data[i] = __hadd2(__hmul2(qk, __half2half2(scalar)), mask_val);
local_max = fmax(
local_max,
fmax(static_cast<float>(data[i].x), static_cast<float>(data[i].y)));
}
float max_val =
blockDim.x <= 32
? phi::funcs::warpReduceMax<float>(local_max, 0xFFFFFFFF)
: phi::funcs::blockReduceMax<float>(local_max, 0xFFFFFFFF);
if (threadIdx.x == 0) {
s_max = max_val;
}
__syncthreads();
float local_sum = 0;
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD;
i++) {
data[i] = h2exp(__hsub2(data[i], __float2half2_rn(s_max)));
local_sum += static_cast<float>(data[i].x + data[i].y);
}
float sum_val =
blockDim.x <= 32
? phi::funcs::warpReduceSum<float>(local_sum, 0xFFFFFFFF)
: phi::funcs::blockReduceSum<float>(local_sum, 0xFFFFFFFF);
if (threadIdx.x == 0) {
s_mean = sum_val + 1e-6f;
s_mean = __fdividef(1.0f, s_mean);
}
__syncthreads();
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD;
i++) {
qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) *
(seq_len_2 / 2) +
blockDim.x * i + threadIdx.x;
qk_buf_half2[qk_offset] = __hmul2(data[i], __float2half2_rn(s_mean));
}
}
}
template <typename T, int ITEMS_PER_THREAD, int NUM>
__global__ void softmax_kernel_v5_half2(T* qk_buf_,
const T* attr_mask,
const int batch_size,
const int head_num,
const int seq_len_1,
const int seq_len_2,
const T scalar) {
using T2 = half2;
T2* qk_buf_half2 = reinterpret_cast<T2*>(qk_buf_);
const T2* attr_mask_half2 = (const T2*)attr_mask;
for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x * NUM) {
T2 data[NUM][ITEMS_PER_THREAD];
int qk_offset[NUM];
__shared__ float s_sum[NUM], s_max[NUM];
float local_max[NUM];
#pragma unroll
for (int j = 0; j < NUM; j++) {
local_max[j] = -1e20f;
}
const int MAX_NUM =
min((seq_len_1 - seq_id + gridDim.x - 1) / gridDim.x, NUM);
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD;
i++) {
int mask_offset[NUM];
#pragma unroll
for (int j = 0; j < MAX_NUM; j++) {
qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 +
seq_id + j * gridDim.x) *
(seq_len_2 / 2) +
blockDim.x * i + threadIdx.x;
mask_offset[j] = (blockIdx.y * seq_len_1 + seq_id + j * gridDim.x) *
(seq_len_2 / 2) +
blockDim.x * i + threadIdx.x;
}
T2 mask_val[NUM];
#pragma unroll
for (int j = 0; j < MAX_NUM; j++) {
mask_val[j] = __ldg(&attr_mask_half2[mask_offset[j]]);
}
T2 qk[NUM];
#pragma unroll
for (int j = 0; j < MAX_NUM; j++) {
qk[j] = qk_buf_half2[qk_offset[j]];
}
#pragma unroll
for (int j = 0; j < MAX_NUM; j++) {
mask_val[j] = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val[j]),
__float2half2_rn(-10000.0f));
}
#pragma unroll
for (int j = 0; j < MAX_NUM; j++) {
data[j][i] = __hadd2(__hmul2(qk[j], __half2half2(scalar)), mask_val[j]);
local_max[j] = fmax(local_max[j],
fmax(static_cast<float>(data[j][i].x),
static_cast<float>(data[j][i].y)));
}
}
if (blockDim.x <= 32) {
warpReduceMaxV2<float, NUM>(local_max);
} else {
blockReduceMaxV2<float, NUM>(local_max);
}
if (threadIdx.x == 0) {
#pragma unroll
for (int j = 0; j < NUM; j++) {
s_max[j] = local_max[j];
}
}
__syncthreads();
float local_sum[NUM];
#pragma unroll
for (int j = 0; j < NUM; j++) {
local_sum[j] = {0.f};
}
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD;
i++) {
#pragma unroll
for (int j = 0; j < MAX_NUM; j++) {
data[j][i] = h2exp(__hsub2(data[j][i], __float2half2_rn(s_max[j])));
}
#pragma unroll
for (int j = 0; j < MAX_NUM; j++) {
local_sum[j] += static_cast<float>(data[j][i].x + data[j][i].y);
}
}
if (blockDim.x <= 32) {
warpReduceSumV2<float, NUM>(local_sum);
} else {
blockReduceSumV2<float, NUM>(local_sum);
}
if (threadIdx.x == 0) {
#pragma unroll
for (int j = 0; j < NUM; j++) {
s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f);
}
}
__syncthreads();
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD;
i++) {
#pragma unroll
for (int j = 0; j < MAX_NUM; j++) {
qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 +
seq_id + j * gridDim.x) *
(seq_len_2 / 2) +
blockDim.x * i + threadIdx.x;
}
#pragma unroll
for (int j = 0; j < MAX_NUM; j++) {
qk_buf_half2[qk_offset[j]] =
__hmul2(data[j][i], __float2half2_rn(s_sum[j]));
}
}
}
}
// -------- transpose_kernel -------- //
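// Swaps the first two axes of a [dim0, dim1, dim2] tensor:
// out[d1, d0, d2] = in[d0, d1, d2].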
template <typename T>
__global__ void transposeAxis01(
T* out, T* in, const int dim0, const int dim1, const int dim2) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < dim0 * dim1 * dim2) {
const int input_dim2_index = index % dim2;
index = (index - input_dim2_index) / dim2;
const int input_dim1_index = index % dim1;
index = (index - input_dim1_index) / dim1;
const int input_dim0_index = index % dim0;
out[input_dim1_index * dim0 * dim2 + input_dim0_index * dim2 +
input_dim2_index] = in[input_dim0_index * dim1 * dim2 +
input_dim1_index * dim2 + input_dim2_index];
}
}
// -------- padding_kernel -------- //
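// One block per expert: copies each batch's scores/indices into fixed
// max_seq_len segments and pads the tail with a sentinel key
// (HALF_FLT_MIN / FLT_MIN) and index 0.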
template <typename T>
__global__ void paddingKernel(T* output1,
int* output2,
const T* input1,
const int* input2,
const int* input_lengths,
const int num_tokens,
const int batch_size,
const int max_seq_len,
const int num_experts) {
const bool IS_FP16 = std::is_same<T, phi::dtype::float16>::value;
const T MIN_T_VAL = (IS_FP16) ? (T)HALF_FLT_MIN : (T)FLT_MIN;
int offset1 = blockIdx.x * num_tokens;
int offset2 = blockIdx.x * batch_size * max_seq_len;
for (int i = 0; i < batch_size; i++) {
const T* in1_ptr = input1 + offset1;
const int* in2_ptr = input2 + offset1;
int input_length = input_lengths[i];
offset1 += input_length;
T* out1_ptr = output1 + offset2;
int* out2_ptr = output2 + offset2;
offset2 += max_seq_len;
for (int j = threadIdx.x; j < max_seq_len; j += max_seq_len) {
if (j < input_length) {
out1_ptr[j] = in1_ptr[j];
out2_ptr[j] = in2_ptr[j];
} else {
out1_ptr[j] = MIN_T_VAL;
out2_ptr[j] = 0;
}
}
}
}
// -------- general_topk_pair_sort_kernel -------- //
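// Each block sorts BLOCK_THREADS * ITEMS_PER_THREAD (key, value) pairs in
// descending key order with cub::BlockRadixSort.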
template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD>
__global__ void general_topk_pair_sort(T* out_keys,
int* out_values,
T* in_keys,
int* in_values) {
typedef cub::BlockRadixSort<T, BLOCK_THREADS, ITEMS_PER_THREAD, int>
BlockRadixSort;
typedef cub::
BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
BlockLoadKey;
typedef cub::
BlockLoad<int, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
BlockLoadValue;
typedef cub::
BlockStore<T, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_TRANSPOSE>
BlockStoreKey;
typedef cub::BlockStore<int,
BLOCK_THREADS,
ITEMS_PER_THREAD,
cub::BLOCK_STORE_TRANSPOSE>
BlockStoreValue;
__shared__ union {
typename BlockRadixSort::TempStorage sort;
typename BlockLoadKey::TempStorage loadkey;
typename BlockLoadValue::TempStorage loadvalue;
typename BlockStoreKey::TempStorage storekey;
typename BlockStoreValue::TempStorage storevalue;
} temp_storage;
int block_offset = blockIdx.x * BLOCK_THREADS * ITEMS_PER_THREAD;
T thread_keys[ITEMS_PER_THREAD];
int thread_values[ITEMS_PER_THREAD];
BlockLoadKey(temp_storage.loadkey).Load(in_keys + block_offset, thread_keys);
BlockLoadValue(temp_storage.loadvalue)
.Load(in_values + block_offset, thread_values);
__syncthreads();
BlockRadixSort(temp_storage.sort).SortDescending(thread_keys, thread_values);
__syncthreads();
BlockStoreKey(temp_storage.storekey)
.Store(out_keys + block_offset, thread_keys);
BlockStoreValue(temp_storage.storevalue)
.Store(out_values + block_offset, thread_values);
}
// -------- finalize_moe_routing_kernel -------- //
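// One block per output token: starts from the skip connection, then for each
// of the k expert slots adds routing_scale * (expert_output + expert_bias),
// skipping slots that were never assigned (destination row == -1).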
template <typename T>
__global__ void finalize_moe_routing_kernel(
const T* expanded_permuted_rows,
T* reduced_unpermuted_output,
const T* skip,
const T* bias,
const T* scales,
const int* expanded_source_row_to_expanded_dest_row,
const int* expert_for_source_row,
const int cols,
const int k,
bool ec_route) {
const int original_row = blockIdx.x;
const int num_rows = gridDim.x;
T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols;
const T* skip_row_ptr = skip + original_row * cols;
for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) {
T thread_output = skip_row_ptr[tid];
for (int k_idx = 0; k_idx < k; ++k_idx) {
const int expanded_original_row = original_row + k_idx * num_rows;
const int expanded_permuted_row =
expanded_source_row_to_expanded_dest_row[expanded_original_row];
if (ec_route && expanded_permuted_row == -1) continue;
const int64_t k_offset =
ec_route ? expanded_original_row : original_row * k + k_idx;
const T row_scale = scales[k_offset];
const T* expanded_permuted_rows_row_ptr =
expanded_permuted_rows + expanded_permuted_row * cols;
const int expert_idx = ec_route ? k_idx : expert_for_source_row[k_offset];
const T* bias_ptr = bias + expert_idx * cols;
thread_output =
thread_output +
row_scale * (expanded_permuted_rows_row_ptr[tid] + bias_ptr[tid]);
}
reduced_row_ptr[tid] = thread_output;
}
}
// -------- initialize_moe_routing_kernel -------- //
template <typename T, int VecSize>
__global__ void initialize_moe_routing_kernel(
const T* unpermuted_input,
T* permuted_output,
const int* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row,
const int num_rows,
const int active_rows,
const int cols,
const int k,
const int max_seq_len,
bool ec_route) {
using LoadT = phi::AlignedVector<T, VecSize>;
LoadT src_vec;
// Build the reverse (source -> dest) permutation map. The finalize kernel
// uses it to perform the k-way reduction and un-permutation, so a single
// threadblock can reduce all k contributions of one token without atomics.
const int expanded_dest_row = blockIdx.x;
const int expanded_source_row =
ec_route ? expanded_dest_row_to_expanded_source_row[expanded_dest_row /
k * max_seq_len +
expanded_dest_row % k]
: expanded_dest_row_to_expanded_source_row[expanded_dest_row];
if (threadIdx.x == 0) {
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
expanded_dest_row;
}
if (blockIdx.x < active_rows) {
// Duplicate and permute rows
const int source_row = expanded_source_row % num_rows;
const T* source_row_ptr = unpermuted_input + source_row * cols;
T* dest_row_ptr = permuted_output + expanded_dest_row * cols;
for (int tid = threadIdx.x * VecSize; tid < cols;
tid += blockDim.x * VecSize) {
// dest_row_ptr[tid] = source_row_ptr[tid];
phi::Load<T, VecSize>(&source_row_ptr[tid], &src_vec);
phi::Store<T, VecSize>(src_vec, &dest_row_ptr[tid]);
}
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void MoeKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& gate,
const DenseTensor& bmm0,
const DenseTensor& bias0,
const DenseTensor& bmm1,
const DenseTensor& bias1,
const std::string& act_type,
DenseTensor* output);
} // namespace phi
......@@ -78,6 +78,7 @@ if(NOT WITH_GPU)
list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api)
endif()
list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
list(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op)
list(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_grad_op)
list(REMOVE_ITEM TEST_OPS test_fuse_gemm_epilogue_pass)
......@@ -143,6 +144,7 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_bias)
list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_no_bias)
list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op)
list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
endif()
list(REMOVE_ITEM TEST_OPS test_checkpoint_saver)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.nn.functional as F
from paddle.fluid.framework import default_main_program
from paddle.incubate.nn.functional import fused_ec_moe
from paddle.nn.layer.common import Linear
default_main_program().random_seed = 42
class TestFusedEcMoEOp(OpTest):
def setUp(self):
self.config()
self.rtol = 1e-3
self.atol = 1e-3
paddle.set_default_dtype(self.x_type)
self.__class__.op_type = "fused_ec_moe"
# Since it's only used in inference.
self.__class__.no_need_check_grad = True
self.bmm_w0 = paddle.to_tensor(
np.random.randn(self.num_expert, self.d_model, self.d_feedforward)
* 0.001,
dtype=paddle.float16,
)
self.bmm_b0 = paddle.to_tensor(
np.random.randn(self.num_expert, 1, self.d_feedforward) * 0.001,
dtype=paddle.float16,
)
self.bmm_w1 = paddle.to_tensor(
np.random.randn(self.num_expert, self.d_feedforward, self.d_model)
* 0.001,
dtype=paddle.float16,
)
self.bmm_b1 = paddle.to_tensor(
np.random.randn(self.num_expert, 1, self.d_model) * 0.001,
dtype=paddle.float16,
)
self.tensor_x = paddle.to_tensor(
np.random.randn(self.batch_size, self.seq_len, self.d_model)
* 0.001,
dtype=paddle.float16,
)
self.bmm_w0.stop_gradient = True
self.bmm_b0.stop_gradient = True
self.bmm_w1.stop_gradient = True
self.bmm_b1.stop_gradient = True
self.tensor_x.stop_gradient = True
self.gate = Linear(self.d_model, self.num_expert)
paddle.set_default_dtype("float16")
self.activation = getattr(F, self.act_method)
def config(self):
self.x_type = np.float16
self.batch_size = 10
self.seq_len = 128
self.num_expert = 32
self.d_model = 768
self.d_feedforward = 3072
self.act_method = 'gelu'
def GetBaselineOut(self, tensor_x, gate_logits):
def expert_choice_gating(logits, capacity, batch_idx, expert_idx):
gates = F.softmax(logits, -1)
indices1_s = paddle.topk(
logits.transpose([0, 2, 1]), k=capacity, axis=-1
)[1].cast("int32")
seqlen_idx = indices1_s.reshape([-1])
gather_idx = paddle.stack([batch_idx, seqlen_idx, expert_idx], -1)
prob = paddle.gather_nd(gates, gather_idx)
return prob, expert_idx, gather_idx, capacity
paddle.disable_static()
capacity = self.seq_len // 16
batch_expert_idx = paddle.nonzero(
paddle.ones(shape=[self.batch_size, self.num_expert, capacity])
).cast('int32')
batch_idx = batch_expert_idx[:, 0]
expert_idx = batch_expert_idx[:, 1]
(
expert_prob_flatten,
expert_idx_flatten,
gather_idx,
cap,
) = expert_choice_gating(gate_logits, capacity, batch_idx, expert_idx)
outputs = paddle.zeros_like(tensor_x)
batch_prob = expert_prob_flatten.reshape(
[self.batch_size, self.num_expert, -1, 1]
)
batch_idx = gather_idx[:, :2]
selected_token = tensor_x.gather_nd(batch_idx)
batch_selected_token = selected_token.reshape(
[self.batch_size, self.num_expert, -1, tensor_x.shape[-1]]
)
batch_selected_token = batch_selected_token.transpose(
[1, 0, 2, 3]
).reshape([self.num_expert, -1, tensor_x.shape[-1]])
output = paddle.bmm(batch_selected_token, self.bmm_w0) + self.bmm_b0
output = self.activation(output)
output = paddle.bmm(output, self.bmm_w1) + self.bmm_b1
output = output.transpose([1, 0, 2]).reshape(
[self.batch_size, -1, self.num_expert, tensor_x.shape[-1]]
)
output = output.transpose([0, 2, 1, 3])
output = batch_prob * output
output = output.reshape([-1, tensor_x.shape[-1]])
outputs = outputs.scatter_nd_add(batch_idx, output)
return outputs + tensor_x
def GetFusedEcMoeOut(self, tensor_x, gate_logits):
paddle.disable_static()
fused_out = fused_ec_moe(
tensor_x,
gate_logits,
self.bmm_w0,
self.bmm_b0,
self.bmm_w1,
self.bmm_b1,
self.act_method,
)
return fused_out
def test_fused_ec_moe_op(self):
gate_logits = self.gate(self.tensor_x)
final_out_ref = self.GetBaselineOut(self.tensor_x, gate_logits)
final_out = self.GetFusedEcMoeOut(self.tensor_x, gate_logits)
np.testing.assert_allclose(
final_out_ref, final_out, rtol=self.rtol, atol=self.atol
)
class TestFusedEcMoEOpActGeluFp16(TestFusedEcMoEOp):
def config(self):
super().config()
self.x_type = np.float16
class TestFusedEcMoEOpActReluFp16(TestFusedEcMoEOp):
def config(self):
super().config()
self.x_type = np.float16
self.act_method = "relu"
if __name__ == "__main__":
unittest.main()
......@@ -20,6 +20,7 @@ from .layer.fused_linear import FusedLinear # noqa: F401
from .layer.fused_transformer import (
FusedBiasDropoutResidualLayerNorm,
) # noqa: F401
from .layer.fused_ec_moe import FusedEcMoe # noqa: F401
__all__ = [ # noqa
'FusedMultiHeadAttention',
......@@ -28,4 +29,5 @@ __all__ = [ # noqa
'FusedMultiTransformer',
'FusedLinear',
'FusedBiasDropoutResidualLayerNorm',
'FusedEcMoe',
]
......@@ -17,6 +17,7 @@ from .fused_transformer import fused_feedforward
from .fused_transformer import fused_multi_transformer
from .fused_matmul_bias import fused_matmul_bias, fused_linear
from .fused_transformer import fused_bias_dropout_residual_layer_norm
from .fused_ec_moe import fused_ec_moe
__all__ = [
'fused_multi_head_attention',
......@@ -25,4 +26,5 @@ __all__ = [
'fused_matmul_bias',
'fused_linear',
'fused_bias_dropout_residual_layer_norm',
'fused_ec_moe',
]
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.layer_helper import LayerHelper
def fused_ec_moe(
x, gate, bmm0_weight, bmm0_bias, bmm1_weight, bmm1_bias, act_type
):
"""
Applies fused ec_moe kernel.
This method requires SM_ARCH in sm75, sm80, sm86.
Args:
x (Tensor): the input Tensor. Its shape is [bsz, seq_len, d_model].
gate (Tensor): the gate Tensor to choose expert. Its shape is [bsz, seq_len, e].
bmm0_weight (Tensor): the first batch matrix matmul weight. Its shape is [e, d_model, d_feed_forward].
bmm0_bias (Tensor): the first batch matrix matmul bias. Its shape is [e, 1, d_feed_forward].
bmm1_weight (Tensor): the second batch matrix matmul weight. Its shape is [e, d_feed_forward, d_model].
bmm1_bias (Tensor): the second batch matrix matmul bias. Its shape is [e, 1, d_model].
act_type (string): the activation type. Currently only `gelu` and `relu` are supported.
Returns:
Tensor: the output Tensor.
Examples:
.. code-block:: python
# required: gpu
import paddle
from paddle.incubate.nn.functional import fused_ec_moe
batch = 10
seq_len = 128
d_model = 1024
d_feed_forward = d_model * 4
num_expert = 8
x = paddle.randn([batch, seq_len, d_model])
gate = paddle.randn([batch, seq_len, num_expert])
bmm0_weight = paddle.randn([num_expert, d_model, d_feed_forward])
bmm0_bias = paddle.randn([num_expert, 1, d_feed_forward])
bmm1_weight = paddle.randn([num_expert, d_feed_forward, d_model])
bmm1_bias = paddle.randn([num_expert, 1, d_model])
out = fused_ec_moe(x, gate, bmm0_weight, bmm0_bias, bmm1_weight, bmm1_bias, act_type="gelu")
print(out.shape) # [batch, seq_len, d_model]
"""
helper = LayerHelper('fused_moe', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='moe',
inputs={
'X': x,
'Gate': gate,
'Bmm0': bmm0_weight,
'Bias0': bmm0_bias,
'Bmm1': bmm1_weight,
'Bias1': bmm1_bias,
},
outputs={'Out': out},
attrs={'act_type': act_type},
)
return out
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.incubate.nn import functional as F
from paddle.nn import Layer
class FusedEcMoe(Layer):
r"""A FusedEcMoe Layer.
Parameters:
hidden_size (int): The dim size of input units.
inter_size (int): The dim size of feed forward network.
num_expert (int): The number of experts.
act_type (string): The activation type. Currently only `gelu` and `relu` are supported.
weight_attr (ParamAttr, optional): The attribute for the learnable
weight of this layer. The default value is None and the weight will be
initialized to zero. For detailed information, please refer to
paddle.ParamAttr.
bias_attr (ParamAttr|bool, optional): The attribute for the learnable bias
of this layer. If it is set to False, no bias will be added to the output.
If it is set to None or one kind of ParamAttr, a bias parameter will
be created according to ParamAttr. For detailed information, please refer
to paddle.ParamAttr. The default value is None and the bias will be
initialized to zero.
Attribute:
**weight** (Parameter): the learnable weight of this layer.
**bias** (Parameter): the learnable bias of this layer.
Shape:
- input: Multi-dimensional tensor with shape :math:`[batch\_size, seq\_len, d\_model]` .
- output: Multi-dimensional tensor with shape :math:`[batch\_size, seq\_len, d\_model]` .
Examples:
.. code-block:: python
# required: gpu
import paddle
from paddle.incubate.nn.layer.fused_ec_moe import FusedEcMoe
x = paddle.randn([10, 128, 1024]) # [bsz, seq_len, d_model]
gate = paddle.randn([10, 128, 8]) # [bsz, seq_len, num_experts]
moe = FusedEcMoe(1024, 4096, 8, act_type="gelu")
y = moe(x, gate)
print(y.shape) # [10, 128, 1024]
"""
def __init__(
self,
hidden_size,
inter_size,
num_experts,
act_type,
weight_attr=None,
bias_attr=None,
):
super().__init__()
weight0_shape = [num_experts, hidden_size, inter_size]
bias0_shape = [num_experts, 1, inter_size]
weight1_shape = [num_experts, inter_size, hidden_size]
bias1_shape = [num_experts, 1, hidden_size]
dtype = self._helper.get_default_dtype()
self.bmm_weight0 = self.create_parameter(
shape=weight0_shape, attr=weight_attr, dtype=dtype, is_bias=False
)
self.bmm_bias0 = self.create_parameter(
shape=bias0_shape, attr=bias_attr, dtype=dtype, is_bias=True
)
self.bmm_weight1 = self.create_parameter(
shape=weight1_shape, attr=weight_attr, dtype=dtype, is_bias=False
)
self.bmm_bias1 = self.create_parameter(
shape=bias1_shape, attr=bias_attr, dtype=dtype, is_bias=True
)
self.act_type = act_type
if self.act_type not in ["gelu", "relu"]:
raise NotImplementedError("Currently only `gelu` and `relu` are supported.")
def forward(self, x, gate):
return F.fused_ec_moe(
x,
gate,
self.bmm_weight0,
self.bmm_bias0,
self.bmm_weight1,
self.bmm_bias1,
self.act_type,
)