PaddlePaddle / Paddle

Commit def2a87f (unverified)
Authored by xiaoxiaohehe001 on Dec 22, 2022; committed via GitHub on Dec 22, 2022
[Paddle Inference] Add moe phi kernel (#48703)
Parent: efa34534
Showing 18 changed files with 3,956 additions and 0 deletions (+3956 -0)
cmake/third_party.cmake                                          +1    -0
paddle/fluid/inference/tensorrt/dynamic_shape_infermeta.cc       +10   -0
paddle/fluid/operators/moe_op.cc                                 +64   -0
paddle/phi/infermeta/multiary.cc                                 +14   -0
paddle/phi/infermeta/multiary.h                                  +9    -0
paddle/phi/kernels/CMakeLists.txt                                +6    -0
paddle/phi/kernels/fusion/cutlass/default_moe_fc_traits.h        +206  -0
paddle/phi/kernels/fusion/cutlass/linear_combination_ft_gelu.h   +687  -0
paddle/phi/kernels/fusion/cutlass/moe_cutlass_kernel.h           +879  -0
paddle/phi/kernels/fusion/cutlass/moe_kernel.cu                  +911  -0
paddle/phi/kernels/fusion/cutlass/moe_kernel_impl.h              +779  -0
paddle/phi/kernels/fusion/moe_kernel.h                           +32   -0
python/paddle/fluid/tests/unittests/CMakeLists.txt               +2    -0
python/paddle/fluid/tests/unittests/test_fused_ec_moe_op.py      +176  -0
python/paddle/incubate/nn/__init__.py                            +2    -0
python/paddle/incubate/nn/functional/__init__.py                 +2    -0
python/paddle/incubate/nn/functional/fused_ec_moe.py             +75   -0
python/paddle/incubate/nn/layer/fused_ec_moe.py                  +101  -0
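The Python-side entry points added by this commit live under python/paddle/incubate/nn (a FusedEcMoE layer and a fused_ec_moe functional). A minimal usage sketch follows; the class name, constructor arguments (hidden size, intermediate size, number of experts, act_type) and the forward signature (x, gate) are assumptions inferred from the moe op inputs further below, not copied from the added files, so treat it as orientation only.

# Hypothetical usage sketch of the fused expert-choice MoE layer added by this
# commit. Signatures are assumed from the op definition (X, Gate, Bmm0/Bias0,
# Bmm1/Bias1, act_type); consult python/paddle/incubate/nn/layer/fused_ec_moe.py
# for the authoritative API.
import paddle
from paddle.incubate.nn import FusedEcMoe  # class name assumed from the new layer module

paddle.set_default_dtype("float16")  # the CUTLASS kernels target half precision on GPU (assumption)

batch, seq_len, d_model, d_ffn, num_experts = 2, 128, 768, 3072, 8

x = paddle.randn([batch, seq_len, d_model], dtype="float16")         # token activations
gate = paddle.randn([batch, seq_len, num_experts], dtype="float16")  # gating logits

moe = FusedEcMoe(d_model, d_ffn, num_experts, act_type="gelu")
y = moe(x, gate)    # MoeInferMeta (below) keeps the input shape
print(y.shape)      # [2, 128, 768]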
cmake/third_party.cmake
@@ -516,6 +516,7 @@ if(WITH_GPU
  if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.0)
    include(external/cutlass)  # download, build, install cusparselt
    list(APPEND third_party_deps extern_cutlass)
    set(WITH_CUTLASS ON)
  endif()
endif()
paddle/fluid/inference/tensorrt/dynamic_shape_infermeta.cc
@@ -235,6 +235,15 @@ nvinfer1::DimsExprs UnchangedInferMeta(
  return inputs[0];
}

nvinfer1::DimsExprs MoeInferMeta(int output_index,
                                 const nvinfer1::DimsExprs* inputs,
                                 int nb_inputs,
                                 nvinfer1::IExprBuilder& expr_builder,  // NOLINT
                                 const framework::OpDesc& op_desc) {
  return inputs[0];
}

nvinfer1::DimsExprs Pad3dInferMeta(int output_index,
                                   const nvinfer1::DimsExprs* inputs,
...

@@ -384,6 +393,7 @@ PD_REGISTER_DYNAMIC_INFER_META_FN(instance_norm, InstanceNormInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(unfold, UnflodInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(scatter_nd_add, ScatterNdAddInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(inverse, UnchangedInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(moe, MoeInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(pad3d, Pad3dInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(grid_sampler, GridSamplerInferMeta);
}  // namespace tensorrt
paddle/fluid/operators/moe_op.cc
new file (mode 100644)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"

namespace paddle {
namespace operators {

class MoeOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(data_type, ctx.device_context());
  }
};

class MoeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor), The source input tensor of Moe op.");
    AddInput("Gate", "(Tensor), The gating input tensor of Moe op.");
    AddInput("Bmm0", "(Tensor), The bmm0 input tensor of Moe op.");
    AddInput("Bias0", "(Tensor), The eltwise0 input tensor of Moe op.");
    AddInput("Bmm1", "(Tensor), The bmm1 input tensor of Moe op.");
    AddInput("Bias1", "(Tensor), The eltwise1 input tensor of Moe op.");
    AddOutput("Out", "(Tensor), The output tensor of Moe op.");
    AddAttr<std::string>(
        "act_type",
        R"DOC(activation type, currently only support `gelu`, `relu`. Default value is: `gelu`. )DOC")
        .SetDefault("gelu");
    AddComment(
        R"DOC(FusedEcMoe kernel. For more details you can refer to `FusedEcMoE` python documents. )DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(moe,
                            MoeInferShapeFunctor,
                            PD_INFER_META(phi::MoeInferMeta));
REGISTER_OPERATOR(moe, ops::MoeOp, ops::MoeOpMaker, MoeInferShapeFunctor);
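For orientation, the inputs registered above suggest a per-expert feed-forward block. A plausible reading of the fused computation (an assumption for readability, not taken from the kernel source) is

    y^(e) = act( x * W0^(e) + b0^(e) ) * W1^(e) + b1^(e)

where act is the `act_type` attribute (gelu or relu), W0/b0 come from Bmm0/Bias0, W1/b1 from Bmm1/Bias1, and the Gate input determines how tokens are assigned to and weighted across the experts e.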
paddle/phi/infermeta/multiary.cc
@@ -2931,6 +2931,20 @@ void YoloLossInferMeta(const MetaTensor& x,
  gt_match_mask->set_dtype(x.dtype());
}

void MoeInferMeta(const MetaTensor& x,
                  const MetaTensor& gate,
                  const MetaTensor& bmm0,
                  const MetaTensor& bias0,
                  const MetaTensor& bmm1,
                  const MetaTensor& bias1,
                  const std::string& act_type,
                  MetaTensor* out) {
  out->set_dims(x.dims());
  out->share_lod(x);
  out->set_dtype(x.dtype());
  out->set_layout(x.layout());
}

}  // namespace phi

PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
paddle/phi/infermeta/multiary.h
@@ -523,4 +523,13 @@ void YoloLossInferMeta(const MetaTensor& x,
                       MetaTensor* objectness_mask,
                       MetaTensor* gt_match_mask);

void MoeInferMeta(const MetaTensor& x,
                  const MetaTensor& gate,
                  const MetaTensor& bmm0,
                  const MetaTensor& bias0,
                  const MetaTensor& bmm1,
                  const MetaTensor& bias1,
                  const std::string& act_type,
                  MetaTensor* out);

}  // namespace phi
paddle/phi/kernels/CMakeLists.txt
@@ -104,6 +104,12 @@ file(
  "strings/gpu/*.cu"
  "fusion/gpu/*.cu")

if(WITH_CUTLASS)
  file(GLOB cutlass_cu "fusion/cutlass/default_moe_fc_traits.h"
       "fusion/cutlass/linear_combination_ft_gelu.h" "fusion/cutlass/moe*")
  list(APPEND kernel_cu ${cutlass_cu})
endif()

if(WITH_MKLDNN)
  file(
    GLOB
...
paddle/phi/kernels/fusion/cutlass/default_moe_fc_traits.h
new file (mode 100644)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"

namespace cutlass {
namespace gemm {
namespace kernel {

template <typename TypeA, typename TypeB, typename arch>
struct MoeArchTraits {};

template <typename arch>
struct MoeArchTraits<float, float, arch> {
  static constexpr int Stages = 2;
  using OperatorClass = cutlass::arch::OpClassSimt;
  using AccType = float;
  using LayoutB = cutlass::layout::RowMajor;

  static constexpr int ElementsPerAccessA = 1;
  static constexpr int ElementsPerAccessB = 1;
  static constexpr int ElementsPerAccessC = 1;
  using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 8>;
  using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  using Operator = cutlass::arch::OpMultiplyAdd;
};

// ========================= Volta Traits ===========================
// Volta will always dequantize after the global memory load.
template <typename TypeB>
struct MoeArchTraits<cutlass::half_t, TypeB, cutlass::arch::Sm70> {
 private:
  static constexpr int ThreadblockK = 32;

 public:
  static constexpr int Stages = 2;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using AccType = float;
  using LayoutB = cutlass::layout::RowMajor;

  static constexpr int ElementsPerAccessA =
      128 / cutlass::sizeof_bits<cutlass::half_t>::value;
  static constexpr int ElementsPerAccessB =
      128 / cutlass::sizeof_bits<cutlass::half_t>::value;
  static constexpr int ElementsPerAccessC =
      128 / cutlass::sizeof_bits<cutlass::half_t>::value;
  using ThreadBlockShape = cutlass::gemm::GemmShape<32, 128, ThreadblockK>;
  using WarpShape = cutlass::gemm::GemmShape<32, 32, ThreadblockK>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  using Operator = cutlass::arch::OpMultiplyAdd;
};

template <typename TypeB>
struct MoeArchTraits<cutlass::bfloat16_t, TypeB, cutlass::arch::Sm70> {
 private:
  static constexpr int ThreadblockK = 32;

 public:
  static constexpr int Stages = 2;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using AccType = float;
  using LayoutB = cutlass::layout::RowMajor;

  static constexpr int ElementsPerAccessA =
      128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value;
  static constexpr int ElementsPerAccessB =
      128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value;
  static constexpr int ElementsPerAccessC =
      128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value;
  using ThreadBlockShape = cutlass::gemm::GemmShape<32, 128, ThreadblockK>;
  using WarpShape = cutlass::gemm::GemmShape<32, 32, ThreadblockK>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  using Operator = cutlass::arch::OpMultiplyAdd;
};

// ======================= Turing Traits ==============================
// Turing will dequantize after LDSM

// fp16 x fp16 specialization
template <>
struct MoeArchTraits<cutlass::half_t, cutlass::half_t, cutlass::arch::Sm75> {
  static constexpr int Stages = 2;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using AccType = float;
  using LayoutB = cutlass::layout::RowMajor;

  static constexpr int ElementsPerAccessA =
      128 / cutlass::sizeof_bits<cutlass::half_t>::value;
  static constexpr int ElementsPerAccessB =
      128 / cutlass::sizeof_bits<cutlass::half_t>::value;
  static constexpr int ElementsPerAccessC =
      128 / cutlass::sizeof_bits<cutlass::half_t>::value;
  using ThreadBlockShape = cutlass::gemm::GemmShape<32, 128, 32>;
  using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

  using Operator = cutlass::arch::OpMultiplyAdd;
};

// bf16 x bf16 specialization
template <>
struct MoeArchTraits<cutlass::bfloat16_t,
                     cutlass::bfloat16_t,
                     cutlass::arch::Sm75> {
  static constexpr int Stages = 2;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using AccType = float;
  using LayoutB = cutlass::layout::RowMajor;

  static constexpr int ElementsPerAccessA =
      128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value;
  static constexpr int ElementsPerAccessB =
      128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value;
  static constexpr int ElementsPerAccessC =
      128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value;
  using ThreadBlockShape = cutlass::gemm::GemmShape<32, 128, 32>;
  using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

  using Operator = cutlass::arch::OpMultiplyAdd;
};

template <>
struct MoeArchTraits<float, float, cutlass::arch::Sm80> {
  static constexpr int Stages = 3;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using AccType = float;
  using LayoutB = cutlass::layout::RowMajor;

  static constexpr int ElementsPerAccessA = 4;
  static constexpr int ElementsPerAccessB = 4;
  static constexpr int ElementsPerAccessC = 4;
  using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 16>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

  using Operator = cutlass::arch::OpMultiplyAdd;
};

template <>
struct MoeArchTraits<cutlass::half_t, cutlass::half_t, cutlass::arch::Sm80> {
  static constexpr int Stages = 3;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using AccType = float;
  using LayoutB = cutlass::layout::RowMajor;

  static constexpr int ElementsPerAccessA =
      128 / cutlass::sizeof_bits<cutlass::half_t>::value;
  static constexpr int ElementsPerAccessB =
      128 / cutlass::sizeof_bits<cutlass::half_t>::value;
  static constexpr int ElementsPerAccessC =
      128 / cutlass::sizeof_bits<cutlass::half_t>::value;
  using ThreadBlockShape = cutlass::gemm::GemmShape<32, 128, 32>;
  using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  using Operator = cutlass::arch::OpMultiplyAdd;
};

template <>
struct MoeArchTraits<cutlass::bfloat16_t,
                     cutlass::bfloat16_t,
                     cutlass::arch::Sm80> {
  static constexpr int Stages = 3;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using AccType = float;
  using LayoutB = cutlass::layout::RowMajor;

  static constexpr int ElementsPerAccessA =
      128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value;
  static constexpr int ElementsPerAccessB =
      128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value;
  static constexpr int ElementsPerAccessC =
      128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value;
  using ThreadBlockShape = cutlass::gemm::GemmShape<32, 128, 32>;
  using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  using Operator = cutlass::arch::OpMultiplyAdd;
};

}  // namespace kernel
}  // namespace gemm
}  // namespace cutlass
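A quick check of the access widths fixed by these traits: the tensor-op specializations derive them from a 128-bit vector access, so ElementsPerAccess = 128 / sizeof_bits<half_t>::value = 128 / 16 = 8 for both half_t and bfloat16_t, while the SM80 float specialization pins them to 4 (four 32-bit floats per 128-bit access) and the SIMT float fallback uses scalar accesses of width 1.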
paddle/phi/kernels/fusion/cutlass/linear_combination_ft_gelu.h
new file (mode 100644)
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing linear combination with a maximum operation used by
epilogues.
*/
#pragma once
#include <cuda.h>
#include <cutlass/half.h>
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace detail {

/// Single source of truth for whether to unroll for `LinearCombinationClamp()`
constexpr bool LinearCombinationFtGeluIsHeavy() { return false; }

}  // namespace detail

/////////////////////////////////////////////////////////////////////////////////////////////////

__forceinline__ __device__ float copysignf_pos(float a, float b) {
  float r;
  r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
  return r;
}

__inline__ __device__ float tanh_opt(float x) {
#if (__CUDA_ARCH__ >= 750)
  float r;
  asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(r) : "f"(x));
  return r;
#else
  const float exp_val = -1.f * fabs(2 * x);
  return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
#endif
}

/////////////////////////////////////////////////////////////////////////////////////////////////
// GELU operator implemented using the Taylor series approximation
template <typename T>
struct FtGelu {
  static const bool kIsHeavy = true;

  CUTLASS_DEVICE
  T operator()(T const& z) const {
    T k0 = static_cast<float>(0.7978845608028654);
    T k1 = static_cast<float>(0.044715);
    return T(
        cutlass::constants::half<T>() * z *
        (cutlass::constants::one<T>() +
         fast_tanh(k0 * z * (cutlass::constants::one<T>() + k1 * z * z))));
  }
};

template <>
struct FtGelu<float> {
  static const bool kIsHeavy = true;

  CUTLASS_DEVICE
  float operator()(float const& z) const {
    float k0 = static_cast<float>(0.7978845608028654);
    float k1 = static_cast<float>(0.044715);
    return float(
        z * (cutlass::constants::one<float>() +
             tanh_opt(k0 * z *
                      (cutlass::constants::one<float>() + k1 * z * z))));
  }
};

template <int N>
struct FtGelu<Array<half_t, N>> {
  static const bool kIsHeavy = true;

  CUTLASS_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const& z) const {
    using T = half_t;
    Array<half_t, N> y;
    half_t k0 = half_t(0.7978845608028654);
    half_t k1 = half_t(0.044715);

    multiply_add<Array<half_t, N>> fma;
    multiplies<Array<half_t, N>> mul;
    plus<Array<half_t, N>> add;
    fast_tanh_op<Array<half_t, N>> tanh;

    Array<half_t, N> u =
        mul(mul(k0, z), fma(mul(k1, z), z, cutlass::constants::one<T>()));
    y = mul(mul(z, cutlass::constants::half<T>()),
            add(cutlass::constants::one<T>(), tanh(u)));
    return y;
  }
};

template <typename T, int N>
struct FtGelu<Array<T, N>> {
  static const bool kIsHeavy = true;

  CUTLASS_DEVICE
  Array<T, N> operator()(Array<T, N> const& rhs) const {
    Array<T, N> y;
    FtGelu<T> gelu_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      y[i] = gelu_op(rhs[i]);
    }
    return y;
  }
};
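Despite the "Taylor series" wording in the comment, the FtGelu functors compute the standard tanh approximation of GELU; the constants 0.7978845608028654 and 0.044715 are \sqrt{2/\pi} and the cubic coefficient of that approximation:

    \mathrm{GELU}(x) \approx \tfrac{x}{2}\left(1 + \tanh\!\left(\sqrt{2/\pi}\,(x + 0.044715\,x^{3})\right)\right)

tanh_opt evaluates the tanh either with the tanh.approx.f32 PTX instruction (SM 7.5 and newer) or, on older architectures, via exponentials as \tanh(x) = \operatorname{sign}(x)\,\dfrac{1 - e^{-2|x|}}{1 + e^{-2|x|}}, with the sign copied back from x by copysignf_pos.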
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Applies a linear combination operator to an array of elements.
///
/// D = alpha * accumulator + beta * source + uniform
///
template <
    typename ElementOutput_,  ///< Data type used to load and store tensors
    int Count,  ///< Number of elements computed per operation
                ///< Usually it is 128/sizeof_bits<ElementOutput_>,
                ///< but we use 64 or 32 sometimes when there are not enough
                ///< data to store
    typename ElementAccumulator_ = ElementOutput_,  ///< Accumulator data type
    typename ElementCompute_ =
        ElementOutput_,  ///< Data type used to compute linear combination
    ScaleType::Kind Scale =
        ScaleType::Default,  ///< Control Alpha and Beta scaling
    FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
class LinearCombinationFtGelu {
 public:
  using ElementOutput = ElementOutput_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;

  static int const kCount = Count;
  static const ScaleType::Kind kScale = Scale;
  using FragmentOutput = Array<ElementOutput, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
  using FragmentCompute = Array<ElementCompute, kCount>;
  using FragmentScaleBias = Array<ElementCompute, kCount>;

  static FloatRoundStyle const kRound = Round;
  static bool const kIsHeavy = detail::LinearCombinationFtGeluIsHeavy();

  /// Host-constructable parameters structure
  struct Params {
    ElementCompute alpha;             ///< scales accumulators
    ElementCompute beta;              ///< scales source tensor
    ElementCompute threshold;         ///< minimum value that is output
    ElementCompute const* alpha_ptr;  ///< pointer to accumulator scalar - if
                                      ///< not null, loads it from memory
    ElementCompute const* beta_ptr;   ///< pointer to source scalar - if not
                                      ///< null, loads it from memory

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Params()
        : alpha(ElementCompute(1)), beta(ElementCompute(0)),
          threshold(ElementCompute(0)), alpha_ptr(nullptr), beta_ptr(nullptr) {}

    CUTLASS_HOST_DEVICE
    Params(ElementCompute alpha,
           ElementCompute beta = ElementCompute(0),
           ElementCompute threshold = ElementCompute(0))
        : alpha(alpha), beta(beta), threshold(threshold),
          alpha_ptr(nullptr), beta_ptr(nullptr) {}

    CUTLASS_HOST_DEVICE
    Params(ElementCompute const* alpha_ptr,
           ElementCompute const* beta_ptr = nullptr,
           ElementCompute threshold = ElementCompute(0))
        : alpha(0), beta(0), threshold(threshold),
          alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {}
  };

 private:
  //
  // Data members
  //
  ElementCompute alpha_;
  ElementCompute beta_;
  ElementCompute threshold_;

 public:
  /// Constructs the function object, possibly loading from pointers in host
  /// memory
  CUTLASS_HOST_DEVICE
  explicit LinearCombinationFtGelu(Params const& params) {
    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
    threshold_ = params.threshold;
  }

  /// Returns true if source is needed
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    if (Scale == ScaleType::NoBetaScaling) return true;
    if (Scale == ScaleType::OnlyAlphaScaling) return false;
    if (Scale == ScaleType::Nothing) return false;
    return beta_ != ElementCompute(0);
  }

  /// Functionally required for serial reduction in the epilogue
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {
    if (k_partition) {
      beta_ = ElementCompute(1);
    }
    if (k_partition != k_partition_count - 1) {
      // set to NaN to make ReLU no-op for all except last k partitions
      int64_t allones = -1;
      threshold_ = reinterpret_cast<ElementCompute const&>(allones);
    }
  }

  /// Computes linear scaling: D = alpha * accumulator + beta * source
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(FragmentAccumulator const& accumulator,
                            FragmentOutput const& source) const {
    // Convert source to interal compute numeric type
    NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round>
        source_converter;
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
        accumulator_converter;

    FragmentCompute converted_source = source_converter(source);
    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations
    FragmentCompute intermediate;

    multiplies<FragmentCompute> mul_add_source;
    multiply_add<FragmentCompute> mul_add_accumulator;
    GELU<FragmentCompute> ftgelu;

    if (Scale == ScaleType::NoBetaScaling) {
      intermediate = converted_source;
      intermediate = mul_add_accumulator(
          alpha_, converted_accumulator, intermediate);  // D = alpha * Accum + X
    } else if (Scale == ScaleType::Nothing) {
      intermediate = converted_accumulator;
    } else {
      intermediate =
          mul_add_source(beta_, converted_source);  // X = beta * C + uniform
      intermediate = mul_add_accumulator(
          alpha_, converted_accumulator, intermediate);  // D = alpha * Accum + X
    }

    // Compute threshold optionally
    intermediate = ftgelu(intermediate);

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
        destination_converter;
    return destination_converter(intermediate);
  }

  /// Computes linear scaling: D = alpha * accumulator
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(FragmentAccumulator const& accumulator) const {
    // Convert source to interal compute numeric type
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
        accumulator_converter;

    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations
    FragmentCompute intermediate;

    multiplies<FragmentCompute> mul_accumulator;
    GELU<FragmentCompute> ftgelu;

    if (Scale == ScaleType::Nothing) {
      intermediate = converted_accumulator;
    } else {
      intermediate =
          mul_accumulator(alpha_, converted_accumulator);  // D = alpha * Accum
    }

    // Compute threshold optionally
    intermediate = ftgelu(intermediate);

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
        destination_converter;
    return destination_converter(intermediate);
  }

  /// Computes per-channel linear scaling and bias : D = scale * accumulator +
  /// bias Scale and Bias are from input Fragment
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(FragmentAccumulator const& accumulator,
                            FragmentScaleBias const& scale,
                            FragmentScaleBias const& bias) const {
    // Convert source to interal compute numeric type
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
        accumulator_converter;

    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform per-channel scale and bias
    FragmentCompute intermediate;

    multiply_add<FragmentCompute> mul_add_accumulator;

    if (Scale == ScaleType::OnlyAlphaPerChannelScaling)
      intermediate = mul_add_accumulator(
          scale, converted_accumulator, bias);  // D = scale * Accum + bias
    else
      intermediate = mul_add_accumulator(
          alpha_, converted_accumulator, bias);  // D = alpha * Accum + bias

    GELU<FragmentCompute> ftgelu;
    // Compute threshold optionally
    intermediate = ftgelu(intermediate);

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
        destination_converter;
    return destination_converter(intermediate);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

// Conditional guards to enable partial specialization for packed integers
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && \
    ((__CUDACC_VER_MAJOR__ > 10) ||                     \
     ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))

/// Applies a linear combination operator to an array of elements.
///
/// D = alpha * accumulator + beta * source + uniform
///
/// Special handling for int types
template <typename ElementOutput_,  ///< Data type used to load and store
                                    ///< tensors
          int Count,                ///< Number of elements computed per operation
          ScaleType::Kind Scale,    ///< Control Alpha and Beta scaling
          FloatRoundStyle Round>
class LinearCombinationFtGelu<ElementOutput_, Count, int, float, Scale, Round> {
 public:
  using ElementOutput = ElementOutput_;
  using ElementAccumulator = int;
  using ElementCompute = float;

  static bool const kIsHeavy = detail::LinearCombinationFtGeluIsHeavy();

  static int const kCount = Count;
  static const ScaleType::Kind kScale = Scale;

  using FragmentOutput = Array<ElementOutput, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
  using FragmentCompute = Array<ElementCompute, kCount>;
  using FragmentScaleBias = Array<ElementCompute, kCount>;

  static FloatRoundStyle const kRound = Round;

  /// Host-constructable parameters structure
  struct Params {
    ElementCompute alpha;             ///< scales accumulators
    ElementCompute beta;              ///< scales source tensor
    ElementCompute threshold;         ///< minimum value that is output
    ElementCompute const* alpha_ptr;  ///< pointer to accumulator scalar - if
                                      ///< not null, loads it from memory
    ElementCompute const* beta_ptr;   ///< pointer to source scalar - if not
                                      ///< null, loads it from memory

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Params()
        : alpha(ElementCompute(1)), beta(ElementCompute(0)),
          threshold(ElementCompute(0)), alpha_ptr(nullptr), beta_ptr(nullptr) {}

    CUTLASS_HOST_DEVICE
    Params(ElementCompute alpha,
           ElementCompute beta = ElementCompute(0),
           ElementCompute threshold = ElementCompute(0))
        : alpha(alpha), beta(beta), threshold(threshold),
          alpha_ptr(nullptr), beta_ptr(nullptr) {}

    CUTLASS_HOST_DEVICE
    Params(ElementCompute const* alpha_ptr,
           ElementCompute const* beta_ptr = nullptr,
           ElementCompute threshold = ElementCompute(0))
        : alpha(0), beta(0), threshold(threshold),
          alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {}
  };

 private:
  //
  // Data members
  //
  ElementCompute alpha_;
  ElementCompute beta_;
  ElementCompute threshold_;

 public:
  /// Constructs the function object, possibly loading from pointers in host
  /// memory
  CUTLASS_HOST_DEVICE
  explicit LinearCombinationFtGelu(Params const& params) {
    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
    threshold_ = params.threshold;
  }

  /// Returns true if source is needed
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    if (Scale == ScaleType::NoBetaScaling) return true;
    if (Scale == ScaleType::OnlyAlphaScaling) return false;
    if (Scale == ScaleType::Nothing) return false;
    return beta_ != ElementCompute(0);
  }

  /// Functionally required for serial reduction in the epilogue
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {
    if (k_partition) {
      beta_ = ElementCompute(1);
    }
    if (k_partition != k_partition_count - 1) {
      // set to NaN to make ReLU no-op for all except last k partitions
      int64_t allones = -1;
      threshold_ = reinterpret_cast<ElementCompute const&>(allones);
    }
  }

  /// Computes linear scaling: D = alpha * accumulator + beta * source
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(FragmentAccumulator const& accumulator,
                            FragmentOutput const& source) const {
    // Convert source to interal compute numeric type
    NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round>
        source_converter;
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
        accumulator_converter;

    FragmentCompute converted_source = source_converter(source);
    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations
    FragmentCompute intermediate;

    multiplies<FragmentCompute> mul_add_source;
    multiply_add<FragmentCompute> mul_add_accumulator;
    GELU<FragmentCompute> ftgelu;

    if (Scale == ScaleType::NoBetaScaling) {
      intermediate = converted_source;
      intermediate = mul_add_accumulator(
          alpha_, converted_accumulator, intermediate);  // D = alpha * Accum + X
    } else if (Scale == ScaleType::Nothing) {
      intermediate = converted_accumulator;
    } else {
      intermediate =
          mul_add_source(beta_, converted_source);  // X = beta * C + uniform
      intermediate = mul_add_accumulator(
          alpha_, converted_accumulator, intermediate);  // D = alpha * Accum + X
    }

    // Compute threshold optionally
    intermediate = ftgelu(intermediate);

    if (platform::numeric_limits<ElementOutput>::is_integer) {
      // Convert floats back to INT
      FragmentAccumulator scaled_accumulator;
      NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;
      scaled_accumulator = compute_converter(intermediate);

      // Convert to destination numeric type
      NumericArrayConverter<ElementOutput, int, kCount, Round>
          destination_converter;
      return destination_converter(scaled_accumulator);
    } else {
      NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
          destination_converter;
      return destination_converter(intermediate);
    }
  }

  /// Computes linear scaling: D = alpha * accumulator
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(FragmentAccumulator const& accumulator) const {
    // Convert source to interal compute numeric type
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
        accumulator_converter;

    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations
    FragmentCompute intermediate;

    multiplies<FragmentCompute> mul_accumulator;
    GELU<FragmentCompute> ftgelu;

    if (Scale == ScaleType::Nothing) {
      intermediate = converted_accumulator;
    } else {
      intermediate =
          mul_accumulator(alpha_, converted_accumulator);  // D = alpha * Accum
    }

    // Compute threshold optionally
    intermediate = ftgelu(intermediate);

    if (platform::numeric_limits<ElementOutput>::is_integer) {
      // Convert floats back to INT
      FragmentAccumulator scaled_accumulator;
      NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;
      scaled_accumulator = compute_converter(intermediate);

      // Convert to destination numeric type
      NumericArrayConverter<ElementOutput, int, kCount, Round>
          destination_converter;
      return destination_converter(scaled_accumulator);
    } else {
      NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
          destination_converter;
      return destination_converter(intermediate);
    }
  }

  /// Computes per-channel linear scaling and bias : D = scale * accumulator +
  /// bias Scale and Bias are from input Fragment
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(FragmentAccumulator const& accumulator,
                            FragmentScaleBias const& scale,
                            FragmentScaleBias const& bias) const {
    // Convert source to interal compute numeric type
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
        accumulator_converter;

    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform per-channel scale and bias
    FragmentCompute intermediate;

    multiply_add<FragmentCompute> mul_add_accumulator;

    if (Scale == ScaleType::OnlyAlphaPerChannelScaling)
      intermediate = mul_add_accumulator(
          scale, converted_accumulator, bias);  // D = scale * Accum + bias
    else
      intermediate = mul_add_accumulator(
          alpha_, converted_accumulator, bias);  // D = alpha * Accum + bias

    GELU<FragmentCompute> ftgelu;
    // Compute threshold optionally
    intermediate = ftgelu(intermediate);

    if (platform::numeric_limits<ElementOutput>::is_integer) {
      // Convert floats back to INT
      FragmentAccumulator scaled_accumulator;
      NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;
      scaled_accumulator = compute_converter(intermediate);

      // Convert to destination numeric type
      NumericArrayConverter<ElementOutput, int, kCount, Round>
          destination_converter;
      return destination_converter(scaled_accumulator);
    } else {
      NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
          destination_converter;
      return destination_converter(intermediate);
    }
  }
};

#endif  // Conditional guards to enable partial specialization for packed
        // integers

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace thread
}  // namespace epilogue
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
paddle/phi/kernels/fusion/cutlass/moe_cutlass_kernel.h
new file (mode 100644)
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
*modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
*notice, this list of conditions and the following disclaimer in the
*documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its
*contributors may be used to endorse or promote products derived from this
*software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
*/
#pragma once
#include "cutlass/complex.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {

/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ProblemSizeHelper, typename ThreadblockShape_>
struct BaseMoeProblemVisitor {
  using ThreadblockShape = ThreadblockShape_;

  struct ProblemInfo {
    static int32_t const kNoPrefetchEntry = -1;
    int32_t problem_idx;
    int32_t problem_start;

    CUTLASS_DEVICE
    ProblemInfo()
        : problem_idx(kNoPrefetchEntry), problem_start(kNoPrefetchEntry) {}

    CUTLASS_DEVICE
    ProblemInfo(int32_t problem_idx_, int32_t problem_start_)
        : problem_idx(problem_idx_), problem_start(problem_start_) {}
  };

  struct Params {
    int64_t const* last_row_for_problem;
    int64_t gemm_n;
    int64_t gemm_k;
    int32_t problem_count;
    void const* workspace;
    int32_t tile_count;

    //
    // Methods
    //

    /// Ctor
    CUTLASS_HOST_DEVICE
    Params()
        : last_row_for_problem(nullptr), gemm_n(0), gemm_k(0),
          problem_count(0), workspace(nullptr), tile_count(0) {}

    /// Ctor
    CUTLASS_HOST_DEVICE
    Params(int64_t const* last_row_for_problem,
           int64_t gemm_n,
           int64_t gemm_k,
           int32_t problem_count,
           void const* workspace = nullptr,
           int32_t tile_count = 0)
        : last_row_for_problem(last_row_for_problem), gemm_n(gemm_n),
          gemm_k(gemm_k), problem_count(problem_count), workspace(workspace),
          tile_count(tile_count) {}
  };

  Params const& params;
  int32_t tile_idx;
  int32_t problem_tile_start;
  int32_t problem_idx;

  //
  // Methods
  //
  CUTLASS_DEVICE
  BaseMoeProblemVisitor(Params const& params_, int32_t block_idx)
      : params(params_), tile_idx(block_idx),
        problem_tile_start(0), problem_idx(0) {}

  /// Get the grid shape
  CUTLASS_HOST_DEVICE
  static cutlass::gemm::GemmCoord grid_shape(
      const cutlass::gemm::GemmCoord& problem) {
    return cutlass::gemm::GemmCoord(
        ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM),
        ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN),
        1);
  }

  /// Gets the global tile index
  CUTLASS_HOST_DEVICE
  int32_t tile_index() const { return tile_idx; }

  /// Gets the index of the problem
  CUTLASS_HOST_DEVICE
  int32_t problem_index() const { return problem_idx; }

  CUTLASS_HOST_DEVICE
  int32_t threadblock_idx() const { return tile_idx - problem_tile_start; }

  CUTLASS_DEVICE
  void advance(int32_t grid_size) { tile_idx += grid_size; }

  CUTLASS_HOST_DEVICE
  static void possibly_transpose_problem(
      cutlass::gemm::GemmCoord& problem) {  // NOLINT
    ProblemSizeHelper::possibly_transpose_problem(problem);
  }

  /// Returns the problem size for the current problem
  CUTLASS_HOST_DEVICE
  cutlass::gemm::GemmCoord problem_size() const {
    return problem_size(problem_idx);
  }

  CUTLASS_HOST_DEVICE
  cutlass::gemm::GemmCoord problem_size(int idx) const {
    const int64_t prev_problem_row =
        idx == 0 ? 0 : params.last_row_for_problem[idx - 1];
    const int64_t current_problem_row = params.last_row_for_problem[idx];
    const int64_t gemm_m = current_problem_row - prev_problem_row;
    GemmCoord problem(GemmCoord::Index(gemm_m),
                      GemmCoord::Index(params.gemm_n),
                      GemmCoord::Index(params.gemm_k));
    ProblemSizeHelper::possibly_transpose_problem(problem);
    return problem;
  }

  CUTLASS_HOST_DEVICE
  static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) {
    return ProblemSizeHelper::tile_count(grid);
  }

  static int32_t group_tile_count(
      const cutlass::gemm::GemmCoord* host_problem_sizes_ptr,
      int32_t problem_count) {
    int32_t total_tiles = 0;
    for (int32_t i = 0; i < problem_count; ++i) {
      auto problem = host_problem_sizes_ptr[i];
      possibly_transpose_problem(problem);
      auto grid = grid_shape(problem);
      total_tiles += tile_count(grid);
    }
    return total_tiles;
  }
};
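The visitor never stores per-expert GEMM shapes explicitly; problem_size(idx) reconstructs them from the running row offsets. With last_row_for_problem holding the cumulative number of rows assigned to experts 0..i, expert i is the GEMM

    m_i = \text{last\_row\_for\_problem}[i] - \text{last\_row\_for\_problem}[i-1], \qquad (m_i \times \text{gemm\_n} \times \text{gemm\_k})

so only the M dimension varies across experts while N and K are shared by the whole group.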
/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename ProblemSizeHelper,
          typename ThreadblockShape,
          GroupScheduleMode GroupScheduleMode_,
          int PrefetchTileCount,
          int ThreadCount>
struct MoeProblemVisitor;

/////////////////////////////////////////////////////////////////////////////////////////////////
// ProblemVisitor that performs all scheduling on device
//
template <typename ProblemSizeHelper,
          typename ThreadblockShape,
          int PrefetchTileCount,
          int ThreadCount>
struct MoeProblemVisitor<ProblemSizeHelper,
                         ThreadblockShape,
                         GroupScheduleMode::kDeviceOnly,
                         PrefetchTileCount,
                         ThreadCount>
    : public BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape> {
  using Base = BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>;
  using Params = typename Base::Params;
  static int const kThreadCount = ThreadCount;
  static bool const kRequiresPrecomputation = false;
  static int const kThreadsPerWarp = 32;

  struct SharedStorage {};

  // Final tile of the problem loaded by this thread. Each thread will hold
  // a separate value.
  int32_t problem_ending_tile;

  SharedStorage& shared_storage;

  //
  // Methods
  //
  CUTLASS_DEVICE
  MoeProblemVisitor(Params const& params_,
                    SharedStorage& shared_storage_,  // NOLINT
                    int32_t block_idx)
      : Base(params_, block_idx),
        problem_ending_tile(0),
        shared_storage(shared_storage_) {
    this->problem_idx = -1 * kThreadsPerWarp;
    this->problem_tile_start = 0;
  }

  CUTLASS_DEVICE
  bool next_tile() {
    // Check whether the tile to compute is within the range of the current
    // problem.
    int32_t problem_tile_end = __shfl_sync(
        0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp);
    if (this->tile_idx < problem_tile_end) {
      return true;
    }

    // Check whether the tile to compute is within the current group of problems
    // fetched by the warp. The last tile for this group is the final tile of
    // the problem held by the final thread in the warp.
    int32_t group_tile_end =
        __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);

    // Keep the starting problem for this group in `problem_idx`. This is done
    // to reduce register pressure. The starting problem for this group is
    // simply the first problem in the group most recently fetched by the warp.
    int32_t& group_problem_start = this->problem_idx;
    group_problem_start =
        (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp;

    // Keep the starting tile for this group in `problem_tile_start`. This is
    // done to reduce register pressure.
    int32_t& group_tile_start = this->problem_tile_start;

    // Each thread in the warp processes a separate problem to advance until
    // reaching a problem whose starting tile is less less than tile_idx.
    while (group_tile_end <= this->tile_idx) {
      group_problem_start += kThreadsPerWarp;
      if (group_problem_start > this->params.problem_count) {
        return false;
      }

      // Since `group_tile_start` is a reference to `this->problem_tile_start`,
      // this also sets `this->problem_tile_start`. The fact that
      // `this->problem_tile_start` is also set here is used later in
      // `next_tile`.
      group_tile_start = group_tile_end;

      int lane_idx = threadIdx.x % kThreadsPerWarp;
      int32_t lane_problem = group_problem_start + lane_idx;

      // Compute the number of tiles in the problem assigned to each thread.
      problem_ending_tile = 0;
      if (lane_problem < this->params.problem_count) {
        cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem);
        cutlass::gemm::GemmCoord grid = this->grid_shape(problem);
        problem_ending_tile = this->tile_count(grid);
      }

      // Compute a warp-wide inclusive prefix sum to compute the ending tile
      // index of each thread's problem.
      CUTLASS_PRAGMA_UNROLL
      for (int i = 1; i < kThreadsPerWarp; i <<= 1) {
        int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i);
        if (lane_idx >= i) {
          problem_ending_tile += val;
        }
      }

      // The total tile count for this group is now in the final position of the
      // prefix sum
      int32_t tiles_in_group =
          __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);

      problem_ending_tile += group_tile_start;
      group_tile_end += tiles_in_group;
    }

    // The next problem to process is the first one that does not have ending
    // tile position that is greater than or equal to tile index.
    int32_t problem_idx_in_group = __popc(
        __ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx));

    this->problem_idx = group_problem_start + problem_idx_in_group;

    // The starting tile for this problem is the ending tile of the previous
    // problem. In cases where `problem_idx_in_group` is the first problem in
    // the group, we do not need to reset `problem_tile_start`, because it is
    // set to the previous group's ending tile in the while loop above.
    if (problem_idx_in_group > 0) {
      this->problem_tile_start = __shfl_sync(
          0xffffffff, problem_ending_tile, problem_idx_in_group - 1);
    }

    return true;
  }

  static size_t get_workspace_size(
      const cutlass::gemm::GemmCoord* host_problem_sizes_ptr,
      int32_t problem_count,
      int32_t block_count) {
    return 0;
  }

  static void host_precompute(
      const cutlass::gemm::GemmCoord* host_problem_sizes_ptr,
      int32_t problem_count,
      int32_t block_count,
      void* host_workspace_ptr) {}
};
/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ThreadblockShape,
          GroupScheduleMode GroupScheduleMode_,
          int PrefetchTileCount,
          int ThreadCount,
          bool Transposed = false>
struct GemmMoeProblemVisitor
    : public MoeProblemVisitor<detail::GemmGroupedProblemSizeHelper<Transposed>,
                               ThreadblockShape,
                               GroupScheduleMode_,
                               PrefetchTileCount,
                               ThreadCount> {
  static bool const kTransposed = Transposed;

  using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<Transposed>;
  using Base = MoeProblemVisitor<ProblemSizeHelper,
                                 ThreadblockShape,
                                 GroupScheduleMode_,
                                 PrefetchTileCount,
                                 ThreadCount>;
  using Params = typename Base::Params;
  using SharedStorage = typename Base::SharedStorage;

  //
  // Methods
  //
  CUTLASS_DEVICE
  GemmMoeProblemVisitor(Params const& params_,
                        SharedStorage& shared_storage_,  // NOLINT
                        int32_t block_idx)
      : Base(params_, shared_storage_, block_idx) {}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

// This section exists to that we can use the same kernel code for regular gemm
// and dequantizing gemms. It will dispatch to the dequantizing gemm if the Mma
// type has an Iterator for scales in global.
template <typename...>
using void_t = void;

template <typename Mma, typename = void>
struct use_dq_gemm : platform::false_type {};

template <typename Mma>
struct use_dq_gemm<Mma, void_t<typename Mma::IteratorScale>>
    : platform::true_type {};

// SFINAE overload for dequantizing gemm
template <typename Mma,
          typename ElementScale,
          typename platform::enable_if<use_dq_gemm<Mma>::value, bool>::type =
              true>
CUTLASS_DEVICE static void run_mma(Mma mma,
                                   int gemm_k_iterations,
                                   typename Mma::FragmentC& accum,  // NOLINT
                                   typename Mma::IteratorA iterator_A,
                                   typename Mma::IteratorB iterator_B,
                                   typename Mma::FragmentC const& src_accum,
                                   ElementScale* weight_scale_ptr,
                                   MatrixCoord scale_extent,
                                   const int thread_idx,
                                   MatrixCoord tb_offset_scale) {
  typename Mma::IteratorScale iterator_scale(
      Mma::IteratorScale::Layout(scale_extent.column()),
      weight_scale_ptr,
      scale_extent,
      thread_idx,
      tb_offset_scale);

  mma(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_scale,
      src_accum);
}

// SFINAE overload for normal gemm. This completely ignores the scale parameters
template <typename Mma,
          typename ElementScale,
          typename platform::enable_if<!use_dq_gemm<Mma>::value, bool>::type =
              true>
CUTLASS_DEVICE static void run_mma(Mma mma,
                                   int gemm_k_iterations,
                                   typename Mma::FragmentC& accum,  // NOLINT
                                   typename Mma::IteratorA iterator_A,
                                   typename Mma::IteratorB iterator_B,
                                   typename Mma::FragmentC const& src_accum,
                                   ElementScale* weight_scale_ptr,
                                   MatrixCoord scale_extent,
                                   const int thread_idx,
                                   MatrixCoord tb_offset_scale) {
  mma(gemm_k_iterations, accum, iterator_A, iterator_B, src_accum);
}
/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Mma_,       ///! Threadblock-scoped matrix multiply-accumulate
          typename Epilogue_,  ///! Epilogue
          typename ThreadblockSwizzle_,  ///! Threadblock swizzling function
          GroupScheduleMode GroupScheduleMode_  ///! Type of scheduling to
                                                /// perform
          >
struct MoeFCGemm {
 public:
  using Mma = Mma_;
  using Epilogue = Epilogue_;
  using EpilogueOutputOp = typename Epilogue::OutputOp;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
  static bool const kTransposed = false;

  // Optional transpose
  using MapArguments =
      kernel::detail::MapArguments<typename Mma::IteratorA::Element,
                                   typename Mma::IteratorA::Layout,
                                   Mma::kTransformA,
                                   Mma::IteratorA::AccessType::kElements,
                                   typename Mma::IteratorB::Element,
                                   typename Mma::IteratorB::Layout,
                                   Mma::kTransformB,
                                   Mma::IteratorB::AccessType::kElements,
                                   typename Mma::LayoutC,
                                   kTransposed>;

  // Public-facing type definitions related to operand element type, layout, and
  // complex conjugate operation. Must interact with the 'kTransposed' notion.
  static_assert(!kTransposed, "Transpose problem not supported");
  using ElementA = typename MapArguments::ElementA;
  using LayoutA = typename MapArguments::LayoutA;
  using ElementB = typename MapArguments::ElementB;
  using LayoutB = typename MapArguments::LayoutB;
  using ElementC = typename Epilogue::OutputTileIterator::Element;
  using LayoutC = typename MapArguments::LayoutC;
  using ElementScale = ElementC;

  static ComplexTransform const kTransformA = MapArguments::kTransformA;
  static ComplexTransform const kTransformB = MapArguments::kTransformB;

  // Type definitions about the mainloop.
  using Operator = typename Mma::Operator;
  using OperatorClass = typename Mma::Operator::OperatorClass;
  using ThreadblockShape = typename Mma::Shape;
  using WarpShape = typename Mma::Operator::Shape;
  using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
  using ArchTag = typename Mma::ArchTag;

  static int const kStages = Mma::kStages;
  static int const kAlignmentA = MapArguments::kAlignmentA;
  static int const kAlignmentB = MapArguments::kAlignmentB;
  static int const kAlignmentC =
      Epilogue::OutputTileIterator::kElementsPerAccess;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;
  static int const kThreadCount = 32 * WarpCount::kCount;

  using ProblemVisitor = GemmMoeProblemVisitor<ThreadblockShape,
                                               kGroupScheduleMode,
                                               kThreadCount,
                                               kThreadCount,
                                               kTransposed>;

  //
  // Structures
  //

  /// Argument structure
  struct Arguments {
    //
    // Data members
    //
    int problem_count;
    int threadblock_count;

    typename EpilogueOutputOp::Params output_op;

    ElementA* ptr_A;
    ElementB* ptr_B;
    ElementScale* weight_scales;
    ElementC* ptr_C;
    ElementC* ptr_D;

    int64_t* total_rows_before_expert;
    int64_t gemm_n;
    int64_t gemm_k;

    // Only used by device-level operator
    GemmCoord* host_problem_sizes;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments()
        : problem_count(0), threadblock_count(0), ptr_A(nullptr),
          ptr_B(nullptr), weight_scales(nullptr), ptr_C(nullptr),
          ptr_D(nullptr), total_rows_before_expert(nullptr), gemm_n(0),
          gemm_k(0), host_problem_sizes(nullptr) {}

    /// Ctor
    CUTLASS_HOST_DEVICE
    Arguments(int problem_count,
              int threadblock_count,
              typename EpilogueOutputOp::Params output_op,
              const ElementA* ptr_A,
              const ElementB* ptr_B,
              const ElementScale* weight_scales,
              const ElementC* ptr_C,
              ElementC* ptr_D,
              int64_t* total_rows_before_expert,
              int64_t gemm_n,
              int64_t gemm_k,
              GemmCoord* host_problem_sizes = nullptr)
        : problem_count(problem_count),
          threadblock_count(threadblock_count),
          output_op(output_op),
          ptr_A(const_cast<ElementA*>(ptr_A)),
          ptr_B(const_cast<ElementB*>(ptr_B)),
          weight_scales(const_cast<ElementScale*>(weight_scales)),
          ptr_C(const_cast<ElementC*>(ptr_C)),
          ptr_D(ptr_D),
          total_rows_before_expert(total_rows_before_expert),
          gemm_n(gemm_n),
          gemm_k(gemm_k),
          host_problem_sizes(nullptr) {
      if (platform::is_same<uint8_t, ElementB>::value ||
          platform::is_same<uint4b_t, ElementB>::value) {
        assert(weight_scales);
      }
    }
  };

  //
  // Structure for precomputing values in host memory and passing to kernels
  //

  /// Parameters structure
  struct Params {
    typename ProblemVisitor::Params problem_visitor;
    int threadblock_count;

    typename EpilogueOutputOp::Params output_op;

    ElementA* ptr_A;
    ElementB* ptr_B;
    ElementScale* weight_scales;
    ElementC* ptr_C;
    ElementC* ptr_D;

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Params()
        : ptr_A(nullptr), ptr_B(nullptr), weight_scales(nullptr),
          ptr_C(nullptr), ptr_D(nullptr) {}

    CUTLASS_HOST_DEVICE
    Params(Arguments const& args,
           void* workspace = nullptr,
           int tile_count = 0)  // NOLINT
        : problem_visitor(args.total_rows_before_expert,
                          args.gemm_n,
                          args.gemm_k,
                          args.problem_count,
                          workspace,
                          tile_count),
          threadblock_count(args.threadblock_count),
          output_op(args.output_op),
          ptr_A(args.ptr_A),
          ptr_B(args.ptr_B),
          weight_scales(args.weight_scales),
          ptr_C(args.ptr_C),
          ptr_D(args.ptr_D) {}

    CUTLASS_HOST_DEVICE
    void update(Arguments const& args,
                void* workspace = nullptr,
                int tile_count = 0) {
      problem_visitor =
          typename ProblemVisitor::Params(args.total_rows_before_expert,
                                          args.gemm_n,
                                          args.gemm_k,
                                          args.problem_count,
                                          workspace,
                                          tile_count);
      threadblock_count = args.threadblock_count;
      output_op = args.output_op;
      ptr_A = args.ptr_A;
      ptr_B = args.ptr_B;
      weight_scales = args.weight_scales;
      ptr_C = args.ptr_C;
      ptr_D = args.ptr_D;
    }
  };

  /// Shared memory storage structure
  union SharedStorage {
    typename ProblemVisitor::SharedStorage problem_visitor;
    typename Mma::SharedStorage main_loop;
    typename Epilogue::SharedStorage epilogue;
  };

 public:
  //
  // Methods
  //
  CUTLASS_DEVICE
  MoeFCGemm() {}

  /// Determines whether kernel satisfies alignment
  static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) {
    return Status::kSuccess;
  }

  static Status can_implement(Arguments const& args) {
    if (platform::is_same<uint8_t, ElementB>::value ||
        platform::is_same<uint4b_t, ElementB>::value) {
      if (args.weight_scales == nullptr) {
        CUTLASS_TRACE_HOST(
            "MoeFCGemm::can_implement() - weight scales are required for "
            "uint8_t and uint4b_t");
        return Status::kInvalid;
      }
    } else if (args.weight_scales != nullptr) {
      CUTLASS_TRACE_HOST(
          "MoeFCGemm::can_implement() - weight scales are ignored for all "
          "types except uint8_t and uint4b_t");
      return Status::kInvalid;
    }
    return Status::kSuccess;
  }

  static size_t get_extra_workspace_size(
      Arguments const& args,
      cutlass::gemm::GemmCoord const& grid_tiled_shape) {
    return 0;
  }

  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const& params,
                  SharedStorage& shared_storage) {  // NOLINT
    //
    // These types shadow the type-level definitions and support the ability to
    // implement a 'transposed' GEMM that computes the transposed problems.
    //
    using ElementA = typename Mma::IteratorA::Element;
    using LayoutA = typename Mma::IteratorA::Layout;
    using ElementB = typename Mma::IteratorB::Element;
    using LayoutB = typename Mma::IteratorB::Layout;
    using ElementC = typename Epilogue::OutputTileIterator::Element;
    using LayoutC = typename Epilogue::OutputTileIterator::Layout;
    static constexpr int kInterleave =
        Mma::IteratorB::Shape::kRow / Mma::SmemIteratorB::Shape::kRow;
    static_assert(platform::is_same<LayoutB, layout::RowMajor>::value &&
                          kInterleave == 1 ||
                      platform::is_same<LayoutB, layout::ColumnMajor>::value &&
                          kInterleave >= 1,
                  "B must be row major/col major OR col major interleaved.");

    //
    // Problem visitor.
    //
    ProblemVisitor problem_visitor(
        params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);

    const int64_t gemm_k = params.problem_visitor.gemm_k;
    const int64_t gemm_n = params.problem_visitor.gemm_n;
    int64_t bytes_per_expert_matrix =
        (gemm_k * gemm_n / 8) * cutlass::sizeof_bits<ElementB>::value;

    // Outer 'persistent' loop to iterate over tiles
    while (problem_visitor.next_tile()) {
      GemmCoord problem_size = problem_visitor.problem_size();
      int32_t problem_idx
=
problem_visitor
.
problem_index
();
int32_t
cta_idx
=
int32_t
(
problem_visitor
.
threadblock_idx
());
GemmCoord
grid_shape
=
problem_visitor
.
grid_shape
(
problem_size
);
cutlass
::
gemm
::
GemmCoord
threadblock_offset
(
int
(
cta_idx
/
grid_shape
.
n
())
*
Mma
::
Shape
::
kM
,
// NOLINT
int
(
cta_idx
%
grid_shape
.
n
())
*
Mma
::
Shape
::
kN
,
// NOLINT
0
);
// Load element pointers. Exchange pointers and strides if working on the
// transpose
const
int64_t
rows_to_jump
=
problem_idx
==
0
?
0
:
params
.
problem_visitor
.
last_row_for_problem
[
problem_idx
-
1
];
ElementA
*
ptr_A
=
reinterpret_cast
<
ElementA
*>
(
params
.
ptr_A
)
+
rows_to_jump
*
gemm_k
;
typename
LayoutA
::
LongIndex
ldm_A
=
gemm_k
;
char
*
byte_ptr_B
=
((
char
*
)
params
.
ptr_B
)
+
// NOLINT
problem_idx
*
bytes_per_expert_matrix
;
ElementB
*
ptr_B
=
reinterpret_cast
<
ElementB
*>
(
byte_ptr_B
);
typename
LayoutB
::
LongIndex
ldm_B
=
platform
::
is_same
<
layout
::
RowMajor
,
LayoutB
>::
value
?
gemm_n
:
gemm_k
*
kInterleave
;
// Compute initial location in logical coordinates
cutlass
::
MatrixCoord
tb_offset_A
{
threadblock_offset
.
m
(),
0
,
};
cutlass
::
MatrixCoord
tb_offset_B
{
0
,
threadblock_offset
.
n
()
/
kInterleave
};
cutlass
::
MatrixCoord
tb_offset_scale
{
0
,
threadblock_offset
.
n
()};
// Compute position within threadblock
int
thread_idx
=
threadIdx
.
x
;
// Construct iterators to A and B operands
typename
Mma
::
IteratorA
iterator_A
(
LayoutA
(
ldm_A
),
ptr_A
,
{
problem_size
.
m
(),
problem_size
.
k
()},
thread_idx
,
tb_offset_A
);
typename
Mma
::
IteratorB
iterator_B
(
LayoutB
(
ldm_B
),
ptr_B
,
{
problem_size
.
k
()
*
kInterleave
,
problem_size
.
n
()
/
kInterleave
},
thread_idx
,
tb_offset_B
);
typename
Mma
::
FragmentC
accumulators
;
accumulators
.
clear
();
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int
warp_idx
=
__shfl_sync
(
0xffffffff
,
threadIdx
.
x
/
32
,
0
);
int
lane_idx
=
threadIdx
.
x
%
32
;
//
// Matrix multiply phase
//
// Construct thread-scoped matrix multiply
Mma
mma
(
shared_storage
.
main_loop
,
thread_idx
,
warp_idx
,
lane_idx
);
// Compute threadblock-scoped matrix multiply-add
int
gemm_k_iterations
=
(
problem_size
.
k
()
+
Mma
::
Shape
::
kK
-
1
)
/
Mma
::
Shape
::
kK
;
// Wait for all threads to finish their epilogue phases from the previous
// tile.
__syncthreads
();
// Compute threadblock-scoped matrix multiply-add
ElementScale
*
weight_scale_ptr
=
params
.
weight_scales
+
problem_idx
*
problem_size
.
n
();
run_mma
<
Mma
>
(
mma
,
gemm_k_iterations
,
accumulators
,
iterator_A
,
iterator_B
,
accumulators
,
weight_scale_ptr
,
{
1
,
problem_size
.
n
()},
thread_idx
,
tb_offset_scale
);
//
// Epilogue
//
EpilogueOutputOp
output_op
(
params
.
output_op
);
ElementC
*
ptr_C
=
reinterpret_cast
<
ElementC
*>
(
params
.
ptr_C
)
+
problem_idx
*
gemm_n
;
ElementC
*
ptr_D
=
reinterpret_cast
<
ElementC
*>
(
params
.
ptr_D
)
+
rows_to_jump
*
gemm_n
;
LayoutC
layout_C
(
0
);
LayoutC
layout_D
(
gemm_n
);
typename
Epilogue
::
OutputTileIterator
::
Params
params_C
(
layout_C
);
typename
Epilogue
::
OutputTileIterator
::
Params
params_D
(
layout_D
);
// Tile iterator loading from source tensor.
typename
Epilogue
::
OutputTileIterator
iterator_C
(
params_C
,
ptr_C
,
problem_size
.
mn
(),
thread_idx
,
threadblock_offset
.
mn
());
// Tile iterator writing to destination tensor.
typename
Epilogue
::
OutputTileIterator
iterator_D
(
params_D
,
ptr_D
,
problem_size
.
mn
(),
thread_idx
,
threadblock_offset
.
mn
());
Epilogue
epilogue
(
shared_storage
.
epilogue
,
thread_idx
,
warp_idx
,
lane_idx
);
// Execute the epilogue operator to update the destination tensor.
epilogue
(
output_op
,
iterator_D
,
accumulators
,
iterator_C
);
// Next tile
problem_visitor
.
advance
(
gridDim
.
x
);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace kernel
}
// namespace gemm
}
// namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
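Note that the grouped GEMM above never receives per-expert problem sizes directly: it reconstructs them from `total_rows_before_expert`, an inclusive prefix sum over the rows routed to each expert (the `rows_to_jump` computation reads the previous expert's entry). Below is a minimal host-side sketch of how such an array is formed from per-expert row counts; `BuildTotalRowsBeforeExpert` is a hypothetical helper for illustration and is not part of this patch.

#include <cstdint>
#include <vector>

// Hypothetical helper (illustration only): builds the inclusive prefix sum
// that MoeFCGemm expects in `total_rows_before_expert`. Expert i then owns
// rows [total[i-1], total[i]) of the permuted activation matrix, which is
// exactly how the kernel derives `rows_to_jump` above.
std::vector<int64_t> BuildTotalRowsBeforeExpert(
    const std::vector<int64_t>& rows_per_expert) {
  std::vector<int64_t> total(rows_per_expert.size());
  int64_t running = 0;
  for (size_t i = 0; i < rows_per_expert.size(); ++i) {
    running += rows_per_expert[i];
    total[i] = running;  // inclusive prefix sum
  }
  return total;
}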
paddle/phi/kernels/fusion/cutlass/moe_kernel.cu
0 → 100644
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fusion/moe_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/fusion/cutlass/moe_kernel_impl.h"
// Ignore CUTLASS warnings about type punning
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wunused-function"
#include "cutlass/array.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/numeric_conversion.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/kernels/fusion/cutlass/default_moe_fc_traits.h"
#include "paddle/phi/kernels/fusion/cutlass/linear_combination_ft_gelu.h"
#include "paddle/phi/kernels/fusion/cutlass/moe_cutlass_kernel.h"
#pragma GCC diagnostic pop
namespace phi {

namespace {

inline int getSMVersion() {
  const int device = phi::backends::gpu::GetCurrentDeviceId();
  const phi::gpuDeviceProp prop =
      phi::backends::gpu::GetDeviceProperties(device);
  return prop.major * 10 + prop.minor;
}

struct EpilogueOpBiasReLU {};

struct EpilogueOpBiasFtGelu {};

struct EpilogueOpBias {};

struct EpilogueOpNoBias {};

template <typename ElementType,
          int ElementsPerVectorAccess,
          typename ElementAccumulator,
          typename Op>
struct Epilogue {};

template <typename ElementType,
          int ElementsPerVectorAccess,
          typename ElementAccumulator>
struct Epilogue<ElementType,
                ElementsPerVectorAccess,
                ElementAccumulator,
                EpilogueOpBiasReLU> {
  using Op = cutlass::epilogue::thread::LinearCombinationRelu<
      ElementType,
      ElementsPerVectorAccess,
      ElementAccumulator,
      ElementAccumulator,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
};

template <typename ElementType,
          int ElementsPerVectorAccess,
          typename ElementAccumulator>
struct Epilogue<ElementType,
                ElementsPerVectorAccess,
                ElementAccumulator,
                EpilogueOpBiasFtGelu> {
  using Op = cutlass::epilogue::thread::LinearCombinationFtGelu<
      ElementType,
      ElementsPerVectorAccess,
      ElementAccumulator,
      ElementAccumulator,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
};

template <typename ElementType,
          int ElementsPerVectorAccess,
          typename ElementAccumulator>
struct Epilogue<ElementType,
                ElementsPerVectorAccess,
                ElementAccumulator,
                EpilogueOpBias> {
  using Op = cutlass::epilogue::thread::LinearCombination<
      ElementType,
      ElementsPerVectorAccess,
      ElementAccumulator,
      ElementAccumulator,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
};

template <typename ElementType,
          int ElementsPerVectorAccess,
          typename ElementAccumulator>
struct Epilogue<ElementType,
                ElementsPerVectorAccess,
                ElementAccumulator,
                EpilogueOpNoBias> {
  using Op = cutlass::epilogue::thread::LinearCombination<
      ElementType,
      ElementsPerVectorAccess,
      ElementAccumulator,
      ElementAccumulator,
      cutlass::epilogue::thread::ScaleType::Nothing>;
};

}  // namespace

namespace fusion {

template <typename T>
void InitExpertChoiceRouteKernelLauncher(
    int* expert_for_source_row,
    int* source_row,
    int* expanded_source_row_to_expanded_dest_row,
    int64_t* total_rows_before_expert,
    T* attr_mask,
    const int num_experts,
    const int num_rows,
    const int k,
    const int batch_size,
    cudaStream_t stream) {
  const int threads = 128;
  const int blocks = num_experts;

  initialize_expert_choice_route_kernel<<<blocks, threads, 0, stream>>>(
      expert_for_source_row,
      source_row,
      expanded_source_row_to_expanded_dest_row,
      total_rows_before_expert,
      attr_mask,
      num_rows,
      k,
      batch_size);
}
#define SOFTMAX_KERNEL(ITEMS_PER_THREAD) \
block.x /= ITEMS_PER_THREAD; \
assert(block.x <= 1024); \
if (is_half2) { \
if (grid.x % 4 == 0) { \
grid.x /= 4; \
softmax_kernel_v5_half2<__half, ITEMS_PER_THREAD, 4> \
<<<grid, block, 0, stream>>>(reinterpret_cast<half*>(buffer), \
(const half*)attr_mask, \
batch_size, \
head_num, \
seq_len_1, \
seq_len_2, \
(const half)scalar); \
} else { \
softmax_kernel_v4_half2<__half, ITEMS_PER_THREAD> \
<<<grid, block, 0, stream>>>(reinterpret_cast<half*>(buffer), \
(const half*)attr_mask, \
batch_size, \
head_num, \
seq_len_1, \
seq_len_2, \
(const half)scalar); \
} \
} else { \
softmax_kernel_v4<ITEMS_PER_THREAD, T> \
<<<grid, block, 0, stream>>>(buffer, \
buffer_src, \
attr_mask, \
batch_size, \
head_num, \
seq_len_1, \
seq_len_2, \
scalar); \
}
template <typename T>
void invokeMaskedSoftMax(T* buffer,
                         const T* buffer_src,
                         const T* attr_mask,
                         const int batch_size,
                         const int seq_len_1,
                         const int seq_len_2,
                         const int head_num,
                         const T scalar,
                         cudaStream_t stream) {
  // NOTE: attention scores shape (batch_size, head_num, seq_len_1, seq_len_2)
  dim3 grid(seq_len_1, batch_size, head_num);
  if (batch_size * head_num > 360) {
    grid.x = ceil(static_cast<float>(seq_len_1) / 32.0f);
  }

  bool is_half2 = sizeof(T) == 2 && sizeof(T) == 2 && seq_len_2 % 2 == 0;
  dim3 block((seq_len_2 / (is_half2 ? 2 : 1) + 31) / 32 * 32);

  if (block.x > 2048 && block.x <= 4096) {
    SOFTMAX_KERNEL(4)
  } else if (block.x > 1024) {
    SOFTMAX_KERNEL(2)
  } else if (block.x > 0) {
    SOFTMAX_KERNEL(1)
  } else {
    PADDLE_ENFORCE_EQ(true,
                      false,
                      phi::errors::InvalidArgument(
                          "Softmax kernel only support columns in 0 - 4096. "));
  }
}

template <typename T>
void InvokeTransposeAxis01(T* out,
                           T* in,
                           const int dim0,
                           const int dim1,
                           const int dim2,
                           cudaStream_t stream) {
  dim3 block(512);
  dim3 grid(static_cast<int>(ceil(dim0 * dim1 * dim2 / 512.)));
  transposeAxis01<<<grid, block, 0, stream>>>(out, in, dim0, dim1, dim2);
}

template <typename T>
void InvokePadding(T* output1,
                   int* output2,
                   const T* input1,
                   const int* input2,
                   const int* input_lengths,
                   const int num_tokens,
                   const int batch_size,
                   const int max_seq_len,
                   const int num_experts,
                   cudaStream_t stream) {
  assert(max_seq_len <= 1024);
  dim3 block(max_seq_len);
  dim3 grid(num_experts);
  paddingKernel<<<grid, block, 0, stream>>>(output1,
                                            output2,
                                            input1,
                                            input2,
                                            input_lengths,
                                            num_tokens,
                                            batch_size,
                                            max_seq_len,
                                            num_experts);
}

template <typename T>
void InvokeGeneralTopKPairSort(T* out_keys,
                               int* out_values,
                               T* in_keys,
                               int* in_values,
                               const int m,
                               const int n,
                               cudaStream_t stream) {
  assert(n <= 4096);
  const int blocks = m;

  if (n == 128) {
    general_topk_pair_sort<T, 32, 4>
        <<<blocks, 32, 0, stream>>>(out_keys, out_values, in_keys, in_values);
  }
  if (n == 256) {
    general_topk_pair_sort<T, 64, 4>
        <<<blocks, 64, 0, stream>>>(out_keys, out_values, in_keys, in_values);
  }
  if (n == 1024) {
    general_topk_pair_sort<T, 256, 4>
        <<<blocks, 256, 0, stream>>>(out_keys, out_values, in_keys, in_values);
  } else if (n == 2048) {
    general_topk_pair_sort<T, 512, 4>
        <<<blocks, 512, 0, stream>>>(out_keys, out_values, in_keys, in_values);
  } else if (n == 4096) {
    general_topk_pair_sort<T, 1024, 4>
        <<<blocks, 1024, 0, stream>>>(out_keys, out_values, in_keys, in_values);
  }
}

template <typename T>
void InitMoeRoutingKernelLauncher(
    const T* unpermuted_input,
    T* permuted_output,
    const int* expanded_dest_row_to_expanded_source_row,
    int* expanded_source_row_to_expanded_dest_row,
    const int num_experts,
    const int num_rows,
    const int active_rows,
    const int cols,
    const int k,
    const int batch_size,
    const int max_seq_len,
    bool ec_route,
    cudaStream_t stream) {
  const int blocks = ec_route ? num_experts * k * batch_size : num_rows * k;
  if (ec_route) {
    constexpr int max_pack_size = 16 / sizeof(T);
    const int threads = std::min(cols / max_pack_size, 1024);
    if (cols % max_pack_size == 0) {
      initialize_moe_routing_kernel<T, max_pack_size>
          <<<blocks, threads, 0, stream>>>(
              unpermuted_input,
              permuted_output,
              expanded_dest_row_to_expanded_source_row,
              expanded_source_row_to_expanded_dest_row,
              num_rows,
              batch_size * k * num_experts,
              cols,
              k,
              max_seq_len,
              ec_route);
    } else {
      initialize_moe_routing_kernel<T, 1><<<blocks, threads, 0, stream>>>(
          unpermuted_input,
          permuted_output,
          expanded_dest_row_to_expanded_source_row,
          expanded_source_row_to_expanded_dest_row,
          num_rows,
          batch_size * k * num_experts,
          cols,
          k,
          max_seq_len,
          ec_route);
    }
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "Currently only support `ec_route = True`. "));
  }
}

template <typename T,
          typename WeightType,
          typename arch,
          typename EpilogueType>
void GenericMoeGemmKernelLauncher(const T* A,
                                  const T* B,
                                  const T* weight_scales,
                                  const T* biases,
                                  T* C,
                                  int64_t* total_rows_before_expert,
                                  int64_t gemm_n,
                                  int64_t gemm_k,
                                  int num_experts,
                                  const int multi_processor_count,
                                  cudaStream_t stream) {
  static_assert(cutlass::platform::is_same<T, half>::value ||
                    cutlass::platform::is_same<T, float>::value,
                "Specialized for half, float");
  static_assert(
      cutlass::platform::is_same<T, WeightType>::value ||
          cutlass::platform::is_same<WeightType, uint8_t>::value ||
          cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
      "cutlass weight type only support float, half, uint8_t, uint4b_t");

  // The cutlass type for the input elements. This is needed to convert to
  // cutlass::half_t if necessary.
  using ElementType_ = typename cutlass::platform::conditional<
      cutlass::platform::is_same<T, half>::value,
      cutlass::half_t,
      T>::type;
  using ElementType = ElementType_;

  using CutlassWeightType_ = typename cutlass::platform::conditional<
      cutlass::platform::is_same<WeightType, half>::value,
      cutlass::half_t,
      WeightType>::type;
  using CutlassWeightType = CutlassWeightType_;

  // We need separate config for each architecture since we will target
  // different tensorcore instructions. For float, we do not target TCs.
  using MoeArchTraits = cutlass::gemm::kernel::
      MoeArchTraits<ElementType, CutlassWeightType, arch>;
  using ElementAccumulator = typename MoeArchTraits::AccType;
  using EpilogueOp = typename Epilogue<ElementType,
                                       MoeArchTraits::ElementsPerAccessC,
                                       ElementAccumulator,
                                       EpilogueType>::Op;

  // Finally, set up the kernel.
  using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped<
      ElementType,
      cutlass::layout::RowMajor,
      cutlass::ComplexTransform::kNone,
      MoeArchTraits::ElementsPerAccessA,
      CutlassWeightType,
      typename MoeArchTraits::LayoutB,
      cutlass::ComplexTransform::kNone,
      MoeArchTraits::ElementsPerAccessB,
      ElementType,
      cutlass::layout::RowMajor,
      ElementAccumulator,
      typename MoeArchTraits::OperatorClass,
      arch,
      typename MoeArchTraits::ThreadBlockShape,
      typename MoeArchTraits::WarpShape,
      typename MoeArchTraits::InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
      MoeArchTraits::Stages,
      cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
      typename MoeArchTraits::Operator>::GemmKernel;

  using GemmKernel =
      cutlass::gemm::kernel::MoeFCGemm<typename GemmKernel_::Mma,
                                       typename GemmKernel_::Epilogue,
                                       typename GemmKernel_::ThreadblockSwizzle,
                                       GemmKernel_::kGroupScheduleMode>;

  using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;

  int occupancy = GemmGrouped::maximum_active_blocks();
  const int threadblock_count = multi_processor_count * occupancy;
  if (occupancy == 0) {
    PADDLE_THROW(paddle::platform::errors::Fatal(
        "[MoE Runner] GPU lacks the shared memory resources to run GroupedGEMM "
        "kernel"));
  }

  typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f),
                                          ElementAccumulator(1.f));
  typename GemmGrouped::Arguments args(
      num_experts,
      threadblock_count,
      epilogue_op,
      reinterpret_cast<const ElementType*>(A),
      reinterpret_cast<const CutlassWeightType*>(B),
      reinterpret_cast<const ElementType*>(weight_scales),
      reinterpret_cast<const ElementType*>(biases),
      reinterpret_cast<ElementType*>(C),
      total_rows_before_expert,
      gemm_n,
      gemm_k);

  GemmGrouped gemm;
  auto can_implement = gemm.can_implement(args);
  if (can_implement != cutlass::Status::kSuccess) {
    std::string err_msg = "MoEFC kernel will fail for params. Error: " +
                          std::string(cutlassGetStatusString(can_implement));
    PADDLE_THROW(paddle::platform::errors::Fatal("[MoE Runner] " + err_msg));
  }
  auto init_status = gemm.initialize(args);
  if (init_status != cutlass::Status::kSuccess) {
    std::string err_msg =
        "Failed to initialize cutlass variable batched gemm. Error: " +
        std::string(cutlassGetStatusString(init_status));
    PADDLE_THROW(paddle::platform::errors::Fatal("[MoE Runner] " + err_msg));
  }
  auto run_status = gemm.run(stream);
  if (run_status != cutlass::Status::kSuccess) {
    std::string err_msg =
        "Failed to run cutlass variable batched gemm. Error: " +
        std::string(cutlassGetStatusString(run_status));
    PADDLE_THROW(paddle::platform::errors::Fatal("[MoE Runner] " + err_msg));
  }
}

template <typename T>
void gemm_bias_act(const T* A,
                   const T* B,
                   const T* weight_scales,
                   const T* biases,
                   T* C,
                   int64_t* total_rows_before_expert,
                   int64_t gemm_n,
                   int64_t gemm_k,
                   int num_experts,
                   int sm,
                   int multi_processor_count,
                   const std::string& act_type,
                   cudaStream_t stream) {
  if (act_type == "gelu") {
    if (sm == 75) {
      GenericMoeGemmKernelLauncher<T, T, cutlass::arch::Sm75,
                                   EpilogueOpBiasFtGelu>(
          A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n,
          gemm_k, num_experts, multi_processor_count, stream);
    } else if (sm == 80 || sm == 86) {
      GenericMoeGemmKernelLauncher<T, T, cutlass::arch::Sm80,
                                   EpilogueOpBiasFtGelu>(
          A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n,
          gemm_k, num_experts, multi_processor_count, stream);
    } else {
      GenericMoeGemmKernelLauncher<T, T, cutlass::arch::Sm70,
                                   EpilogueOpBiasFtGelu>(
          A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n,
          gemm_k, num_experts, multi_processor_count, stream);
    }
  } else {
    // act type is relu.
    if (sm == 75) {
      GenericMoeGemmKernelLauncher<T, T, cutlass::arch::Sm75,
                                   EpilogueOpBiasReLU>(
          A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n,
          gemm_k, num_experts, multi_processor_count, stream);
    } else if (sm == 80 || sm == 86) {
      GenericMoeGemmKernelLauncher<T, T, cutlass::arch::Sm80,
                                   EpilogueOpBiasReLU>(
          A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n,
          gemm_k, num_experts, multi_processor_count, stream);
    } else {
      GenericMoeGemmKernelLauncher<T, T, cutlass::arch::Sm70,
                                   EpilogueOpBiasReLU>(
          A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n,
          gemm_k, num_experts, multi_processor_count, stream);
    }
  }
}

template <typename T>
void gemm(const T* A,
          const T* B,
          const T* weight_scales,
          T* C,
          int64_t* total_rows_before_expert,
          const int gemm_n,
          const int gemm_k,
          const int num_experts,
          int sm,
          int multi_processor_count,
          cudaStream_t stream) {
  if (sm == 75) {
    GenericMoeGemmKernelLauncher<T, T, cutlass::arch::Sm75, EpilogueOpNoBias>(
        A, B, weight_scales, nullptr, C, total_rows_before_expert, gemm_n,
        gemm_k, num_experts, multi_processor_count, stream);
  } else if (sm == 80 || sm == 86) {
    GenericMoeGemmKernelLauncher<T, T, cutlass::arch::Sm80, EpilogueOpNoBias>(
        A, B, weight_scales, nullptr, C, total_rows_before_expert, gemm_n,
        gemm_k, num_experts, multi_processor_count, stream);
  } else {
    GenericMoeGemmKernelLauncher<T, T, cutlass::arch::Sm70, EpilogueOpNoBias>(
        A, B, weight_scales, nullptr, C, total_rows_before_expert, gemm_n,
        gemm_k, num_experts, multi_processor_count, stream);
  }
}

template <typename T>
void finalize_moe_routing_kernelLauncher(
    const T* expanded_permuted_rows,
    T* reduced_unpermuted_output,
    const T* skip,
    const T* bias,
    const T* scales,
    const int* expanded_source_row_to_expanded_dest_row,
    const int* expert_for_source_row,
    const int num_experts,
    const int num_rows,
    const int cols,
    const int k,
    bool ec_route,
    cudaStream_t stream) {
  const int blocks = num_rows;
  const int threads = std::min(cols, 1024);
  {
    finalize_moe_routing_kernel<T><<<blocks, threads, 0, stream>>>(
        expanded_permuted_rows,
        reduced_unpermuted_output,
        skip,
        bias,
        scales,
        expanded_source_row_to_expanded_dest_row,
        expert_for_source_row,
        cols,
        num_experts,
        ec_route);
  }
}

template <typename T, typename Context>
void MoeKernel(const Context& ctx,
               const DenseTensor& x,
               const DenseTensor& gate,
               const DenseTensor& bmm0,
               const DenseTensor& bias0,
               const DenseTensor& bmm1,
               const DenseTensor& bias1,
               const std::string& act_type,
               DenseTensor* output) {
  const T* input_activations = x.data<T>();
  T* gating_output = const_cast<T*>(gate.data<T>());
  const T* fc1_expert_weights = bmm0.data<T>();
  const T* fc1_expert_biases = bias0.data<T>();
  const T* fc2_expert_weights = bmm1.data<T>();
  const T* fc2_expert_biases = bias1.data<T>();
  // int moe_act = static_cast<int>(act);
  T* output_ = ctx.template Alloc<T>(output);
  auto stream = ctx.stream();

  auto input_dims = x.dims();
  auto bmm0_dims = bmm0.dims();
  const bool IS_FP16 = std::is_same<T, phi::dtype::float16>::value;
  const int num_rows = input_dims[0] * input_dims[1];
  const int hidden_size = input_dims[2];
  const int inter_size = bmm0_dims[2];
  const int num_experts = bmm0_dims[0];
  const int k = input_dims[1] / 16;
  const int batch_size = input_dims[0];
  const int max_seq_len = 128;
  int64_t bytes = getWorkspaceSize<T>(num_rows,
                                      hidden_size,
                                      inter_size,
                                      num_experts,
                                      k,
                                      batch_size,
                                      max_seq_len);

  // Pointers
  int* source_rows;
  int* padded_source_rows;
  int* permuted_rows;
  int* permuted_experts;
  char* sorter_ws_;
  T* permuted_data;
  T* padded_expert_scales;
  int64_t* total_rows_before_expert;
  T* sorted_softmax_output;
  T* attr_mask;
  T* fc1_result;

  phi::DenseTensor ws_ptr_tensor = phi::Empty<int8_t>(ctx, {bytes});
  int8_t* ws_ptr = ws_ptr_tensor.data<int8_t>();

  const int buf_size = AlignTo16(num_experts * batch_size * k * hidden_size);
  const int padded_experts = AlignTo16(num_experts);
  const int num_moe_inputs = AlignTo16(num_experts * num_rows);
  // padded_num_moe_inputs for topk sort
  int padded_num_moe_inputs = num_experts * batch_size * max_seq_len;

  source_rows = reinterpret_cast<int*>(ws_ptr);
  padded_source_rows = source_rows + num_moe_inputs;
  permuted_rows = padded_source_rows + padded_num_moe_inputs;
  permuted_experts = permuted_rows + padded_num_moe_inputs;
  permuted_data = reinterpret_cast<T*>(permuted_experts + num_experts * k);
  padded_expert_scales = reinterpret_cast<T*>(permuted_data + buf_size);
  total_rows_before_expert =
      reinterpret_cast<int64_t*>(padded_expert_scales + padded_num_moe_inputs);
  sorted_softmax_output =
      reinterpret_cast<T*>(total_rows_before_expert + padded_experts);
  attr_mask =
      reinterpret_cast<T*>(sorted_softmax_output + padded_num_moe_inputs);
  fc1_result = reinterpret_cast<T*>(attr_mask + num_moe_inputs);

  phi::DenseTensor expert_for_source_row_tensor =
      phi::Empty<int>(ctx, {num_experts, num_rows});
  int* expert_for_source_row = expert_for_source_row_tensor.data<int>();
  phi::DenseTensor expanded_source_row_to_expanded_dest_row_tensor =
      phi::Empty<int>(ctx, {num_experts, num_rows});
  int* expanded_source_row_to_expanded_dest_row =
      expanded_source_row_to_expanded_dest_row_tensor.data<int>();
  phi::DenseTensor expert_scales_tensor =
      phi::Empty<T>(ctx, {num_experts, num_rows});
  T* expert_scales = expert_scales_tensor.data<T>();
  phi::DenseTensor fc2_output_tensor =
      phi::Empty<T>(ctx, {num_experts * batch_size * k, hidden_size});
  T* fc2_result = fc2_output_tensor.data<T>();
  phi::DenseTensor input_lengths_tensor = phi::Empty<int>(ctx, {batch_size});
  int* input_lengths = input_lengths_tensor.data<int>();
  funcs::SetConstant<Context, int> set_len;
  set_len(ctx, &input_lengths_tensor, static_cast<int>(max_seq_len));

  int sm = getSMVersion();
  int multi_processor_count = phi::backends::gpu::GetGPUMultiProcessors(
      phi::backends::gpu::GetCurrentDeviceId());

  InitExpertChoiceRouteKernelLauncher<T>(
      expert_for_source_row,
      source_rows,
      expanded_source_row_to_expanded_dest_row,
      total_rows_before_expert,
      attr_mask,
      num_experts,
      num_rows,
      k,
      batch_size,
      ctx.stream());

  T scalar = (T)1.0f;
  if (IS_FP16) {
    invokeMaskedSoftMax<__half>(
        reinterpret_cast<__half*>(gating_output),
        reinterpret_cast<const __half*>(gating_output),
        reinterpret_cast<const __half*>(attr_mask),
        /*batch_size=*/num_rows,
        /*seq_len_1=*/1,
        /*seq_len_2=*/num_experts,
        /*head_num=*/1,
        *reinterpret_cast<const __half*>(&scalar),
        ctx.stream());
  } else {
    invokeMaskedSoftMax<float>(
        reinterpret_cast<float*>(gating_output),
        reinterpret_cast<const float*>(gating_output),
        reinterpret_cast<const float*>(attr_mask),
        /*batch_size=*/num_rows,
        /*seq_len_1=*/1,
        /*seq_len_2=*/num_experts,
        /*head_num=*/1,
        *reinterpret_cast<const float*>(&scalar),
        ctx.stream());
  }
  InvokeTransposeAxis01(
      expert_scales, gating_output, num_rows, num_experts, 1, ctx.stream());

  int padded_max_seq_len = max_seq_len <= 128 ? 128 : 256;
  InvokePadding(padded_expert_scales,
                padded_source_rows,
                expert_scales,
                source_rows,
                input_lengths,
                num_rows,
                batch_size,
                padded_max_seq_len,
                num_experts,
                ctx.stream());

  if (IS_FP16) {
    InvokeGeneralTopKPairSort<__half>(
        reinterpret_cast<__half*>(sorted_softmax_output),
        permuted_rows,
        reinterpret_cast<__half*>(padded_expert_scales),
        padded_source_rows,
        num_experts * batch_size,
        padded_max_seq_len,
        ctx.stream());
  } else {
    InvokeGeneralTopKPairSort<float>(
        reinterpret_cast<float*>(sorted_softmax_output),
        permuted_rows,
        reinterpret_cast<float*>(padded_expert_scales),
        padded_source_rows,
        num_experts * batch_size,
        padded_max_seq_len,
        ctx.stream());
  }

  InitMoeRoutingKernelLauncher(input_activations,
                               permuted_data,
                               permuted_rows,
                               expanded_source_row_to_expanded_dest_row,
                               num_experts,
                               num_rows,
                               num_rows,
                               hidden_size,
                               k,
                               batch_size,
                               max_seq_len,
                               true,
                               ctx.stream());

  const T* fc1_scales = nullptr;
  const T* fc2_scales = nullptr;
  if (IS_FP16) {
    gemm_bias_act(reinterpret_cast<const __half*>(permuted_data),
                  reinterpret_cast<const __half*>(fc1_expert_weights),
                  reinterpret_cast<const __half*>(fc1_scales),
                  reinterpret_cast<const __half*>(fc1_expert_biases),
                  reinterpret_cast<__half*>(fc1_result),
                  total_rows_before_expert,
                  inter_size,
                  hidden_size,
                  num_experts,
                  sm,
                  multi_processor_count,
                  act_type,
                  ctx.stream());
    gemm(reinterpret_cast<const __half*>(fc1_result),
         reinterpret_cast<const __half*>(fc2_expert_weights),
         reinterpret_cast<const __half*>(fc2_scales),
         reinterpret_cast<__half*>(fc2_result),
         total_rows_before_expert,
         hidden_size,
         inter_size,
         num_experts,
         sm,
         multi_processor_count,
         ctx.stream());
  } else {
    gemm_bias_act<float>(reinterpret_cast<const float*>(permuted_data),
                         reinterpret_cast<const float*>(fc1_expert_weights),
                         reinterpret_cast<const float*>(fc1_scales),
                         reinterpret_cast<const float*>(fc1_expert_biases),
                         reinterpret_cast<float*>(fc1_result),
                         total_rows_before_expert,
                         inter_size,
                         hidden_size,
                         num_experts,
                         sm,
                         multi_processor_count,
                         act_type,
                         ctx.stream());
    gemm<float>(reinterpret_cast<const float*>(fc1_result),
                reinterpret_cast<const float*>(fc2_expert_weights),
                reinterpret_cast<const float*>(fc2_scales),
                reinterpret_cast<float*>(fc2_result),
                total_rows_before_expert,
                hidden_size,
                inter_size,
                num_experts,
                sm,
                multi_processor_count,
                ctx.stream());
  }
  finalize_moe_routing_kernelLauncher(fc2_result,
                                      output_,
                                      input_activations,
                                      fc2_expert_biases,
                                      expert_scales,
                                      expanded_source_row_to_expanded_dest_row,
                                      expert_for_source_row,
                                      num_experts,
                                      num_rows,
                                      hidden_size,
                                      k,
                                      true,
                                      ctx.stream());
}

}  // namespace fusion
}  // namespace phi

PD_REGISTER_KERNEL(moe,
                   GPU,
                   ALL_LAYOUT,
                   phi::fusion::MoeKernel,
                   float,
                   phi::dtype::float16) {}
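One detail worth calling out: with the expert-choice routing implemented here every expert keeps the same number of rows. `initialize_expert_choice_route_kernel` fills `total_rows_before_expert[e]` with `batch_size * k * (e + 1)`, and `MoeKernel` derives the capacity factor as `k = input_dims[1] / 16`, i.e. `seq_len / 16`. A small sketch of that size arithmetic follows; it is illustrative only and not part of the patch.

#include <cstdint>

// Illustrative only: per-expert capacity implied by the expert-choice route
// set up in MoeKernel above (k = input_dims[1] / 16, i.e. seq_len / 16).
int64_t EcRowsPerExpert(int batch_size, int seq_len) {
  const int k = seq_len / 16;
  return static_cast<int64_t>(batch_size) * k;
}

// Value written to total_rows_before_expert[expert_idx] by
// initialize_expert_choice_route_kernel (an inclusive prefix sum, since all
// experts receive the same row count).
int64_t EcTotalRowsBeforeExpert(int batch_size, int seq_len, int expert_idx) {
  return EcRowsPerExpert(batch_size, seq_len) * (expert_idx + 1);
}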
paddle/phi/kernels/fusion/cutlass/moe_kernel_impl.h
0 → 100644
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "cub/cub.cuh"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
namespace phi {

static const float HALF_FLT_MAX = 65504.F;
static const float HALF_FLT_MIN = -65504.F;

static inline size_t AlignTo16(const size_t& input) {
  static constexpr int ALIGNMENT = 16;
  return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
}

/*
  WarpReduce multi values.
  TODO(zhengzekang): Add blocksize templates to reduce shared memory usage.
*/
template <typename T, int NUM>
__inline__ __device__ T warpReduceSumV2(T* val) {
#pragma unroll
  for (int i = 0; i < NUM; i++) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1)
      val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
  }
  return (T)(0.0f);
}

template <typename T, int NUM>
__inline__ __device__ T blockReduceSumV2(T* val) {
  static __shared__ T shared[NUM][33];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  warpReduceSumV2<T, NUM>(val);

  if (lane == 0) {
#pragma unroll
    for (int i = 0; i < NUM; i++) {
      shared[i][wid] = val[i];
    }
  }

  __syncthreads();

  bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
  for (int i = 0; i < NUM; i++) {
    val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
  }
  warpReduceSumV2<T, NUM>(val);
  return (T)0.0f;
}

template <typename T, int NUM>
__inline__ __device__ T warpReduceMaxV2(T* val) {
#pragma unroll
  for (int i = 0; i < NUM; i++) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1)
      val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));
  }
  return (T)(0.0f);
}

template <typename T, int NUM>
__inline__ __device__ T blockReduceMaxV2(T* val) {
  static __shared__ T shared[32][NUM];
  int lane = threadIdx.x & 0x1f;  // in-warp idx
  int wid = threadIdx.x >> 5;     // warp idx

  warpReduceMaxV2<T, NUM>(val);  // get maxx in each warp

  if (lane == 0) {  // record in-warp maxx by warp Idx
#pragma unroll
    for (int i = 0; i < NUM; i++) {
      shared[wid][i] = val[i];
    }
  }

  __syncthreads();

  // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
  // blockDim.x is not divided by 32
  bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
  for (int i = 0; i < NUM; i++) {
    val[i] = is_mask ? shared[lane][i] : (T)-1e20f;
  }
  warpReduceMaxV2<T, NUM>(val);

  return (T)0.0f;
}

class CubKeyValueSorter {
 public:
  CubKeyValueSorter();

  explicit CubKeyValueSorter(const int num_experts);

  void update_num_experts(const int num_experts);

  size_t getWorkspaceSize(const size_t num_key_value_pairs,
                          bool descending = false);

  template <typename KeyT>
  void run(void* workspace,
           const size_t workspace_size,
           const KeyT* keys_in,
           KeyT* keys_out,
           const int* values_in,
           int* values_out,
           const size_t num_key_value_pairs,
           bool descending,
           cudaStream_t stream);

 private:
  size_t num_key_value_pairs_;
  int num_experts_;
  int num_bits_;
};

// ===== CUB Sorting things =====
CubKeyValueSorter::CubKeyValueSorter()
    : num_experts_(0), num_bits_(sizeof(int) * 8) {}

CubKeyValueSorter::CubKeyValueSorter(const int num_experts)
    : num_experts_(num_experts),
      num_bits_(static_cast<int>(log2(num_experts)) + 1) {}

void CubKeyValueSorter::update_num_experts(const int num_experts) {
  num_experts_ = num_experts;
  num_bits_ = static_cast<int>(log2(num_experts)) + 1;
}

size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs,
                                           bool descending) {
  num_key_value_pairs_ = num_key_value_pairs;
  size_t required_storage = 0;
  int* null_int = nullptr;
  if (descending) {
    cub::DeviceRadixSort::SortPairsDescending(NULL,
                                              required_storage,
                                              null_int,
                                              null_int,
                                              null_int,
                                              null_int,
                                              num_key_value_pairs,
                                              0,
                                              32);
  } else {
    cub::DeviceRadixSort::SortPairs(NULL,
                                    required_storage,
                                    null_int,
                                    null_int,
                                    null_int,
                                    null_int,
                                    num_key_value_pairs,
                                    0,
                                    num_bits_);
  }
  return required_storage;
}

template <typename KeyT>
void CubKeyValueSorter::run(void* workspace,
                            const size_t workspace_size,
                            const KeyT* keys_in,
                            KeyT* keys_out,
                            const int* values_in,
                            int* values_out,
                            const size_t num_key_value_pairs,
                            bool descending,
                            cudaStream_t stream) {
  size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs);
  size_t actual_ws_size = workspace_size;

  if (expected_ws_size > workspace_size) {
    std::stringstream err_ss;
    err_ss << "[Error][CubKeyValueSorter::run]\n";
    err_ss << "Error. The allocated workspace is too small to run this "
              "problem.\n";
    err_ss << "Expected workspace size of at least " << expected_ws_size
           << " but got problem size " << workspace_size << "\n";
    throw std::runtime_error(err_ss.str());
  }
  if (descending) {
    cub::DeviceRadixSort::SortPairsDescending(workspace,
                                              actual_ws_size,
                                              keys_in,
                                              keys_out,
                                              values_in,
                                              values_out,
                                              num_key_value_pairs,
                                              0,
                                              32,
                                              stream);
  } else {
    cub::DeviceRadixSort::SortPairs(workspace,
                                    actual_ws_size,
                                    keys_in,
                                    keys_out,
                                    values_in,
                                    values_out,
                                    num_key_value_pairs,
                                    0,
                                    num_bits_,
                                    stream);
  }
}

template <>
void CubKeyValueSorter::run(void* workspace,
                            const size_t workspace_size,
                            const __nv_bfloat16* keys_in,
                            __nv_bfloat16* keys_out,
                            const int* values_in,
                            int* values_out,
                            const size_t num_key_value_pairs,
                            bool descending,
                            cudaStream_t stream) {}

CubKeyValueSorter sorter_;

// -------- getWorkspaceSize -------- //
template <typename T>
size_t getWorkspaceSize(const int num_rows,
                        const int hidden_size,
                        const int inter_size,
                        const int num_experts,
                        const int k,
                        const int batch_size,
                        const int max_seq_len) {
  const int buf_size = AlignTo16(num_experts * batch_size * k * hidden_size);
  const int interbuf_size =
      AlignTo16(num_experts * batch_size * k * inter_size);
  const int padded_experts = AlignTo16(num_experts);
  const int num_moe_inputs = AlignTo16(num_experts * num_rows);
  int padded_num_moe_inputs = num_experts * batch_size * max_seq_len;

  size_t total_ws_bytes = sizeof(int) * num_moe_inputs;  // source_rows_
  total_ws_bytes +=
      sizeof(int) * padded_num_moe_inputs;  // padded_source_rows_
  total_ws_bytes +=
      sizeof(T) * padded_num_moe_inputs;  // padded_expert_scales_
  total_ws_bytes += sizeof(int) * padded_num_moe_inputs;  // permuted_rows_
  total_ws_bytes += sizeof(int) * num_experts * k;        // permuted_experts_
  total_ws_bytes += buf_size * sizeof(T);                 // permuted_data_
  total_ws_bytes +=
      padded_experts * sizeof(int64_t);  // Hold total_rows_before_expert_
  total_ws_bytes += sizeof(T) * num_moe_inputs;  // attr_mask: [e, n]
  total_ws_bytes +=
      sizeof(T) * padded_num_moe_inputs;  // sorted_softmax_output

  const int bytes_for_fc1_result = interbuf_size * sizeof(T);
  const int sorter_ws_size_bytes =
      AlignTo16(sorter_.getWorkspaceSize(num_experts * k));
  sorter_.update_num_experts(k);

  int bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
  if (sorter_ws_size_bytes > bytes_for_fc1_result) {
    int remaining_bytes =
        AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result);
    bytes_for_intermediate_and_sorting += remaining_bytes;
  }
  total_ws_bytes +=
      bytes_for_intermediate_and_sorting;  // intermediate (fc1) output + cub
                                           // sorting workspace
  return total_ws_bytes;
}

// -------- initialize_expert_choice_route_kernel -------- //
template <typename T>
__global__ void initialize_expert_choice_route_kernel(
    int* expert_for_source_row,
    int* source_row,
    int* expanded_source_row_to_expanded_dest_row,
    int64_t* total_rows_before_expert,
    T* attr_mask,
    const int cols,
    const int k,
    const int batch_size) {
  int start = cols * blockIdx.x;

  for (int i = threadIdx.x; i < cols; i += blockDim.x) {
    expert_for_source_row[start + i] = blockIdx.x;
    source_row[start + i] = start + i;
    expanded_source_row_to_expanded_dest_row[start + i] = -1;
    attr_mask[start + i] = (T)1.0f;
  }
  if (threadIdx.x == 0) {
    total_rows_before_expert[blockIdx.x] = batch_size * k * (blockIdx.x + 1);
  }
}

// -------- softmax_kernel -------- //
template <int ITEMS_PER_THREAD, typename T>
__global__ void softmax_kernel_v4(
    T* qk_buf_,
    const T* qk_buf_src,  // shape [batch_size, head_num, seq_len_1, seq_len_2]
    const T* attr_mask,   // shape [batch_size, seq_len_1, seq_len_2]
    const int batch_size,
    const int head_num,
    const int seq_len_1,
    const int seq_len_2,
    const T scalar) {
  for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x) {
    float data[ITEMS_PER_THREAD];
    int qk_offset;
    __shared__ float s_mean, s_max;
    float local_max = -1e20f;
    for (int i = 0; blockDim.x * i + threadIdx.x < seq_len_2; i++) {
      qk_offset =
          ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) *
              seq_len_2 +
          blockDim.x * i + threadIdx.x;
      int mask_offset = (blockIdx.y * seq_len_1 + seq_id) * seq_len_2 +
                        blockDim.x * i + threadIdx.x;

      float qk = static_cast<float>(qk_buf_src[qk_offset]);
      float mask_val = static_cast<float>(__ldg(&attr_mask[mask_offset]));

      mask_val = (1.0f - mask_val) * -10000.0f;

      data[i] = qk * static_cast<float>(scalar) + mask_val;
      local_max = fmax(local_max, data[i]);
    }

    float max_val =
        blockDim.x <= 32
            ? phi::funcs::warpReduceMax<float>(local_max, 0xFFFFFFFF)
            : phi::funcs::blockReduceMax<float>(local_max, 0xffffffff);
    if (threadIdx.x == 0) {
      s_max = max_val;
    }
    __syncthreads();

    float local_sum = 0;
    for (int i = 0; blockDim.x * i + threadIdx.x < seq_len_2; i++) {
      data[i] = __expf(data[i] - s_max);
      local_sum += data[i];
    }
    float sum_val =
        blockDim.x <= 32
            ? phi::funcs::warpReduceSum<float>(local_sum, 0xffffffff)
            : phi::funcs::blockReduceSum<float>(local_sum, 0xffffffff);
    if (threadIdx.x == 0) {
      s_mean = sum_val + 1e-6f;
      s_mean = __fdividef(1.0f, s_mean);
    }
    __syncthreads();

    for (int i = 0; blockDim.x * i + threadIdx.x < seq_len_2; i++) {
      qk_offset =
          ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) *
              seq_len_2 +
          blockDim.x * i + threadIdx.x;
      qk_buf_[qk_offset] = (T)(data[i] * s_mean);
    }
  }
}

template <typename T, int ITEMS_PER_THREAD>
__global__ void softmax_kernel_v4_half2(T* qk_buf_,
                                        const T* attr_mask,
                                        const int batch_size,
                                        const int head_num,
                                        const int seq_len_1,
                                        const int seq_len_2,
                                        const T scalar) {
  using T2 = half2;
  T2* qk_buf_half2 = reinterpret_cast<T2*>(qk_buf_);
  const T2* attr_mask_half2 = (const T2*)attr_mask;

  for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x) {
    T2 data[ITEMS_PER_THREAD];
    int qk_offset;
    __shared__ float s_mean, s_max;
    float local_max = -1e20f;
    for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len_2 / 2) &&
                    i < ITEMS_PER_THREAD;
         i++) {
      qk_offset =
          ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) *
              (seq_len_2 / 2) +
          blockDim.x * i + threadIdx.x;
      int mask_offset = (blockIdx.y * seq_len_1 + seq_id) * (seq_len_2 / 2) +
                        blockDim.x * i + threadIdx.x;

      T2 qk = qk_buf_half2[qk_offset];
      T2 mask_val = __ldg(&attr_mask_half2[mask_offset]);
      mask_val = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val),
                         __float2half2_rn(-10000.0f));

      data[i] = __hadd2(__hmul2(qk, __half2half2(scalar)), mask_val);

      local_max = fmax(local_max,
                       fmax(static_cast<float>(data[i].x),
                            static_cast<float>(data[i].y)));
    }

    float max_val =
        blockDim.x <= 32
            ? phi::funcs::warpReduceMax<float>(local_max, 0xFFFFFFFF)
            : phi::funcs::blockReduceMax<float>(local_max, 0xFFFFFFFF);
    if (threadIdx.x == 0) {
      s_max = max_val;
    }
    __syncthreads();

    float local_sum = 0;
    for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len_2 / 2) &&
                    i < ITEMS_PER_THREAD;
         i++) {
      data[i] = h2exp(__hsub2(data[i], __float2half2_rn(s_max)));
      local_sum += static_cast<float>(data[i].x + data[i].y);
    }

    float sum_val =
        blockDim.x <= 32
            ? phi::funcs::warpReduceSum<float>(local_sum, 0xFFFFFFFF)
            : phi::funcs::blockReduceSum<float>(local_sum, 0xFFFFFFFF);

    if (threadIdx.x == 0) {
      s_mean = sum_val + 1e-6f;
      s_mean = __fdividef(1.0f, s_mean);
    }
    __syncthreads();

    for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len_2 / 2) &&
                    i < ITEMS_PER_THREAD;
         i++) {
      qk_offset =
          ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) *
              (seq_len_2 / 2) +
          blockDim.x * i + threadIdx.x;
      qk_buf_half2[qk_offset] = __hmul2(data[i], __float2half2_rn(s_mean));
    }
  }
}

template <typename T, int ITEMS_PER_THREAD, int NUM>
__global__ void softmax_kernel_v5_half2(T* qk_buf_,
                                        const T* attr_mask,
                                        const int batch_size,
                                        const int head_num,
                                        const int seq_len_1,
                                        const int seq_len_2,
                                        const T scalar) {
  using T2 = half2;
  T2* qk_buf_half2 = reinterpret_cast<T2*>(qk_buf_);
  const T2* attr_mask_half2 = (const T2*)attr_mask;

  for (int seq_id = blockIdx.x; seq_id < seq_len_1;
       seq_id += gridDim.x * NUM) {
    T2 data[NUM][ITEMS_PER_THREAD];

    int qk_offset[NUM];

    __shared__ float s_sum[NUM], s_max[NUM];
    float local_max[NUM];
#pragma unroll
    for (int j = 0; j < NUM; j++) {
      local_max[j] = -1e20f;
    }

    const int MAX_NUM =
        min((seq_len_1 - seq_id + gridDim.x - 1) / gridDim.x, NUM);
    for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len_2 / 2) &&
                    i < ITEMS_PER_THREAD;
         i++) {
      int mask_offset[NUM];
#pragma unroll
      for (int j = 0; j < MAX_NUM; j++) {
        qk_offset[j] =
            ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id +
             j * gridDim.x) *
                (seq_len_2 / 2) +
            blockDim.x * i + threadIdx.x;
        mask_offset[j] =
            (blockIdx.y * seq_len_1 + seq_id + j * gridDim.x) *
                (seq_len_2 / 2) +
            blockDim.x * i + threadIdx.x;
      }

      T2 mask_val[NUM];
#pragma unroll
      for (int j = 0; j < MAX_NUM; j++) {
        mask_val[j] = __ldg(&attr_mask_half2[mask_offset[j]]);
      }

      T2 qk[NUM];
#pragma unroll
      for (int j = 0; j < MAX_NUM; j++) {
        qk[j] = qk_buf_half2[qk_offset[j]];
      }
#pragma unroll
      for (int j = 0; j < MAX_NUM; j++) {
        mask_val[j] = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val[j]),
                              __float2half2_rn(-10000.0f));
      }
#pragma unroll
      for (int j = 0; j < MAX_NUM; j++) {
        data[j][i] =
            __hadd2(__hmul2(qk[j], __half2half2(scalar)), mask_val[j]);
        local_max[j] = fmax(local_max[j],
                            fmax(static_cast<float>(data[j][i].x),
                                 static_cast<float>(data[j][i].y)));
      }
    }

    if (blockDim.x <= 32) {
      warpReduceMaxV2<float, NUM>(local_max);
    } else {
      blockReduceMaxV2<float, NUM>(local_max);
    }

    if (threadIdx.x == 0) {
#pragma unroll
      for (int j = 0; j < NUM; j++) {
        s_max[j] = local_max[j];
      }
    }
    __syncthreads();

    float local_sum[NUM];
#pragma unroll
    for (int j = 0; j < NUM; j++) {
      local_sum[j] = {0.f};
    }

    for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len_2 / 2) &&
                    i < ITEMS_PER_THREAD;
         i++) {
#pragma unroll
      for (int j = 0; j < MAX_NUM; j++) {
        data[j][i] = h2exp(__hsub2(data[j][i], __float2half2_rn(s_max[j])));
      }
#pragma unroll
      for (int j = 0; j < MAX_NUM; j++) {
        local_sum[j] += static_cast<float>(data[j][i].x + data[j][i].y);
      }
    }

    if (blockDim.x <= 32) {
      warpReduceSumV2<float, NUM>(local_sum);
    } else {
      blockReduceSumV2<float, NUM>(local_sum);
    }

    if (threadIdx.x == 0) {
#pragma unroll
      for (int j = 0; j < NUM; j++) {
        s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f);
      }
    }
    __syncthreads();

    for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len_2 / 2) &&
                    i < ITEMS_PER_THREAD;
         i++) {
#pragma unroll
      for (int j = 0; j < MAX_NUM; j++) {
        qk_offset[j] =
            ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id +
             j * gridDim.x) *
                (seq_len_2 / 2) +
            blockDim.x * i + threadIdx.x;
      }
#pragma unroll
      for (int j = 0; j < MAX_NUM; j++) {
        qk_buf_half2[qk_offset[j]] =
            __hmul2(data[j][i], __float2half2_rn(s_sum[j]));
      }
    }
  }
}

// -------- transpose_kernel -------- //
template <typename T>
__global__ void transposeAxis01(
    T* out, T* in, const int dim0, const int dim1, const int dim2) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < dim0 * dim1 * dim2) {
    const int input_dim2_index = index % dim2;
    index = (index - input_dim2_index) / dim2;
    const int input_dim1_index = index % dim1;
    index = (index - input_dim1_index) / dim1;
    const int input_dim0_index = index % dim0;

    out[input_dim1_index * dim0 * dim2 + input_dim0_index * dim2 +
        input_dim2_index] =
        in[input_dim0_index * dim1 * dim2 + input_dim1_index * dim2 +
           input_dim2_index];
  }
}

// -------- padding_kernel -------- //
template <typename T>
__global__ void paddingKernel(T* output1,
                              int* output2,
                              const T* input1,
                              const int* input2,
                              const int* input_lengths,
                              const int num_tokens,
                              const int batch_size,
                              const int max_seq_len,
                              const int num_experts) {
  const bool IS_FP16 = std::is_same<T, phi::dtype::float16>::value;
  const T MIN_T_VAL = (IS_FP16) ? (T)HALF_FLT_MIN : (T)FLT_MIN;
  int offset1 = blockIdx.x * num_tokens;
  int offset2 = blockIdx.x * batch_size * max_seq_len;
  for (int i = 0; i < batch_size; i++) {
    const T* in1_ptr = input1 + offset1;
    const int* in2_ptr = input2 + offset1;
    int input_length = input_lengths[i];
    offset1 += input_length;

    T* out1_ptr = output1 + offset2;
    int* out2_ptr = output2 + offset2;
    offset2 += max_seq_len;

    for (int j = threadIdx.x; j < max_seq_len; j += max_seq_len) {
      if (j < input_length) {
        out1_ptr[j] = in1_ptr[j];
        out2_ptr[j] = in2_ptr[j];
      } else {
        out1_ptr[j] = MIN_T_VAL;
        out2_ptr[j] = 0;
      }
    }
  }
}

// -------- general_topk_pair_sort_kernel -------- //
template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD>
__global__ void general_topk_pair_sort(T* out_keys,
                                       int* out_values,
                                       T* in_keys,
                                       int* in_values) {
  typedef cub::BlockRadixSort<T, BLOCK_THREADS, ITEMS_PER_THREAD, int>
      BlockRadixSort;
  typedef cub::
      BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
          BlockLoadKey;
  typedef cub::
      BlockLoad<int, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
          BlockLoadValue;
  typedef cub::
      BlockStore<T, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_TRANSPOSE>
          BlockStoreKey;
  typedef cub::BlockStore<int,
                          BLOCK_THREADS,
                          ITEMS_PER_THREAD,
                          cub::BLOCK_STORE_TRANSPOSE>
      BlockStoreValue;

  __shared__ union {
    typename BlockRadixSort::TempStorage sort;
    typename BlockLoadKey::TempStorage loadkey;
    typename BlockLoadValue::TempStorage loadvalue;
    typename BlockStoreKey::TempStorage storekey;
    typename BlockStoreValue::TempStorage storevalue;
  } temp_storage;

  int block_offset = blockIdx.x * BLOCK_THREADS * ITEMS_PER_THREAD;

  T thread_keys[ITEMS_PER_THREAD];
  int thread_values[ITEMS_PER_THREAD];
  BlockLoadKey(temp_storage.loadkey).Load(in_keys + block_offset, thread_keys);
  BlockLoadValue(temp_storage.loadvalue)
      .Load(in_values + block_offset, thread_values);
  __syncthreads();

  BlockRadixSort(temp_storage.sort).SortDescending(thread_keys, thread_values);
  __syncthreads();

  BlockStoreKey(temp_storage.storekey)
      .Store(out_keys + block_offset, thread_keys);
  BlockStoreValue(temp_storage.storevalue)
      .Store(out_values + block_offset, thread_values);
}

// -------- finalize_moe_routing_kernel -------- //
template <typename T>
__global__ void finalize_moe_routing_kernel(
    const T* expanded_permuted_rows,
    T* reduced_unpermuted_output,
    const T* skip,
    const T* bias,
    const T* scales,
    const int* expanded_source_row_to_expanded_dest_row,
    const int* expert_for_source_row,
    const int cols,
    const int k,
    bool ec_route) {
  const int original_row = blockIdx.x;
  const int num_rows = gridDim.x;
  T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols;
  const T* skip_row_ptr = skip + original_row * cols;

  for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) {
    T thread_output = skip_row_ptr[tid];
    for (int k_idx = 0; k_idx < k; ++k_idx) {
      const int expanded_original_row = original_row + k_idx * num_rows;
      const int expanded_permuted_row =
          expanded_source_row_to_expanded_dest_row[expanded_original_row];
      if (ec_route && expanded_permuted_row == -1) continue;
      const int64_t k_offset =
          ec_route ? expanded_original_row : original_row * k + k_idx;
      const T row_scale = scales[k_offset];
      const T* expanded_permuted_rows_row_ptr =
          expanded_permuted_rows + expanded_permuted_row * cols;

      const int expert_idx =
          ec_route ? k_idx : expert_for_source_row[k_offset];
      const T* bias_ptr = bias + expert_idx * cols;

      thread_output =
          thread_output +
          row_scale * (expanded_permuted_rows_row_ptr[tid] + bias_ptr[tid]);
    }
    reduced_row_ptr[tid] = thread_output;
  }
}

// -------- initialize_moe_routing_kernel -------- //
template <typename T, int VecSize>
__global__ void initialize_moe_routing_kernel(
    const T* unpermuted_input,
    T* permuted_output,
    const int* expanded_dest_row_to_expanded_source_row,
    int* expanded_source_row_to_expanded_dest_row,
    const int num_rows,
    const int active_rows,
    const int cols,
    const int k,
    const int max_seq_len,
    bool ec_route) {
  using LoadT = phi::AlignedVector<T, VecSize>;
  LoadT src_vec;

  // Reverse permutation map.
  // I do this so that later, we can use the source -> dest map to do the k-way
  // reduction and unpermuting. I need the reverse map for that reduction to
  // allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
  // thread block will be responsible for all k summations.
  const int expanded_dest_row = blockIdx.x;
  const int expanded_source_row =
      ec_route ? expanded_dest_row_to_expanded_source_row
                     [expanded_dest_row / k * max_seq_len +
                      expanded_dest_row % k]
               : expanded_dest_row_to_expanded_source_row[expanded_dest_row];
  if (threadIdx.x == 0) {
    expanded_source_row_to_expanded_dest_row[expanded_source_row] =
        expanded_dest_row;
  }
  if (blockIdx.x < active_rows) {
    // Duplicate and permute rows
    const int source_row = expanded_source_row % num_rows;

    const T* source_row_ptr = unpermuted_input + source_row * cols;
    T* dest_row_ptr = permuted_output + expanded_dest_row * cols;

    for (int tid = threadIdx.x * VecSize; tid < cols;
         tid += blockDim.x * VecSize) {
      // dest_row_ptr[tid] = source_row_ptr[tid];
      phi::Load<T, VecSize>(&source_row_ptr[tid], &src_vec);
      phi::Store<T, VecSize>(src_vec, &dest_row_ptr[tid]);
    }
  }
}

}  // namespace phi
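The `CubKeyValueSorter` declared above wraps CUB's two-phase radix-sort pattern: a first call with a null workspace reports the temporary-storage requirement, and the actual key/value sort runs only once a large enough buffer is supplied. The following host-side sketch of that calling pattern is illustrative only and not part of the patch; the buffers, sizes and stream are hypothetical.

// Illustrative sketch of the two-phase CUB sort usage wrapped above.
// `keys_*`, `values_*` and `workspace` are assumed to be valid device buffers.
void SortScoresExample(phi::CubKeyValueSorter* sorter,
                       const float* keys_in, float* keys_out,
                       const int* values_in, int* values_out,
                       size_t num_pairs, void* workspace,
                       size_t workspace_bytes, cudaStream_t stream) {
  // Phase 1: query how much temporary storage CUB needs for this problem.
  const size_t required = sorter->getWorkspaceSize(num_pairs);
  if (required > workspace_bytes) {
    return;  // caller must supply a larger workspace (run() would throw)
  }
  // Phase 2: perform the descending key/value radix sort on the given stream.
  sorter->run(workspace, workspace_bytes, keys_in, keys_out,
              values_in, values_out, num_pairs, /*descending=*/true, stream);
}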
paddle/phi/kernels/fusion/moe_kernel.h
0 → 100644
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {

template <typename T, typename Context>
void MoeKernel(const Context& ctx,
               const DenseTensor& x,
               const DenseTensor& gate,
               const DenseTensor& bmm0,
               const DenseTensor& bias0,
               const DenseTensor& bmm1,
               const DenseTensor& bias1,
               const std::string& act_type,
               DenseTensor* output);

}  // namespace phi
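Reviewer note: for orientation, the shapes these DenseTensor arguments are expected to carry (as exercised by the unit test and the Python API later in this diff) are sketched in Python below. The sizes are the ones used in the test, and the small fp16 helper is purely illustrative.

# Shape sketch for MoeKernel's arguments, using the sizes from the unit test below.
import numpy as np
import paddle

bsz, seq_len, d_model, d_ffn, num_expert = 10, 128, 768, 3072, 32

def fp16(shape):
    # The test builds float16 tensors from NumPy in the same way.
    return paddle.to_tensor(np.random.randn(*shape) * 0.001, dtype=paddle.float16)

x = fp16([bsz, seq_len, d_model])           # X: input tokens
gate = fp16([bsz, seq_len, num_expert])     # Gate: routing logits
bmm0 = fp16([num_expert, d_model, d_ffn])   # Bmm0: first per-expert weight
bias0 = fp16([num_expert, 1, d_ffn])        # Bias0
bmm1 = fp16([num_expert, d_ffn, d_model])   # Bmm1: second per-expert weight
bias1 = fp16([num_expert, 1, d_model])      # Bias1
# MoeKernel(ctx, x, gate, bmm0, bias0, bmm1, bias1, act_type, output)
# writes an output with the same shape as x: [bsz, seq_len, d_model].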
python/paddle/fluid/tests/unittests/CMakeLists.txt
...
@@ -78,6 +78,7 @@ if(NOT WITH_GPU)
  list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api)
endif()
list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
list(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op)
list(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_grad_op)
list(REMOVE_ITEM TEST_OPS test_fuse_gemm_epilogue_pass)
...
@@ -143,6 +144,7 @@ if(WIN32)
  list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_bias)
  list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_no_bias)
  list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op)
  list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
endif()
list(REMOVE_ITEM TEST_OPS test_checkpoint_saver)
...
python/paddle/fluid/tests/unittests/test_fused_ec_moe_op.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np
from op_test import OpTest

import paddle
import paddle.nn.functional as F
from paddle.fluid.framework import default_main_program
from paddle.incubate.nn.functional import fused_ec_moe
from paddle.nn.layer.common import Linear

default_main_program().random_seed = 42


class TestFusedEcMoEOp(OpTest):
    def setUp(self):
        self.config()
        self.rtol = 1e-3
        self.atol = 1e-3

        paddle.set_default_dtype(self.x_type)
        self.__class__.op_type = "fused_ec_moe"
        # Since it's only used in inference.
        self.__class__.no_need_check_grad = True

        self.bmm_w0 = paddle.to_tensor(
            np.random.randn(self.num_expert, self.d_model, self.d_feedforward)
            * 0.001,
            dtype=paddle.float16,
        )
        self.bmm_b0 = paddle.to_tensor(
            np.random.randn(self.num_expert, 1, self.d_feedforward) * 0.001,
            dtype=paddle.float16,
        )
        self.bmm_w1 = paddle.to_tensor(
            np.random.randn(self.num_expert, self.d_feedforward, self.d_model)
            * 0.001,
            dtype=paddle.float16,
        )
        self.bmm_b1 = paddle.to_tensor(
            np.random.randn(self.num_expert, 1, self.d_model) * 0.001,
            dtype=paddle.float16,
        )
        self.tensor_x = paddle.to_tensor(
            np.random.randn(self.batch_size, self.seq_len, self.d_model)
            * 0.001,
            dtype=paddle.float16,
        )

        self.bmm_w0.stop_gradient = True
        self.bmm_b0.stop_gradient = True
        self.bmm_w1.stop_gradient = True
        self.bmm_b1.stop_gradient = True
        self.tensor_x.stop_gradient = True

        self.gate = Linear(self.d_model, self.num_expert)
        paddle.set_default_dtype("float16")

        self.activation = getattr(F, self.act_method)

    def config(self):
        self.x_type = np.float16
        self.batch_size = 10
        self.seq_len = 128
        self.num_expert = 32
        self.d_model = 768
        self.d_feedforward = 3072
        self.act_method = 'gelu'

    def GetBaselineOut(self, tensor_x, gate_logits):
        def expert_choice_gating(logits, capacity, batch_idx, expert_idx):
            gates = F.softmax(logits, -1)
            indices1_s = paddle.topk(
                logits.transpose([0, 2, 1]), k=capacity, axis=-1
            )[1].cast("int32")
            seqlen_idx = indices1_s.reshape([-1])
            gather_idx = paddle.stack([batch_idx, seqlen_idx, expert_idx], -1)
            prob = paddle.gather_nd(gates, gather_idx)
            return prob, expert_idx, gather_idx, capacity

        paddle.disable_static()
        capacity = self.seq_len // 16
        batch_expert_idx = paddle.nonzero(
            paddle.ones(shape=[self.batch_size, self.num_expert, capacity])
        ).cast('int32')
        batch_idx = batch_expert_idx[:, 0]
        expert_idx = batch_expert_idx[:, 1]

        (
            expert_prob_flatten,
            expert_idx_flatten,
            gather_idx,
            cap,
        ) = expert_choice_gating(gate_logits, capacity, batch_idx, expert_idx)
        outputs = paddle.zeros_like(tensor_x)
        batch_prob = expert_prob_flatten.reshape(
            [self.batch_size, self.num_expert, -1, 1]
        )

        batch_idx = gather_idx[:, :2]
        selected_token = tensor_x.gather_nd(batch_idx)
        batch_selected_token = selected_token.reshape(
            [self.batch_size, self.num_expert, -1, tensor_x.shape[-1]]
        )
        batch_selected_token = batch_selected_token.transpose(
            [1, 0, 2, 3]
        ).reshape([self.num_expert, -1, tensor_x.shape[-1]])

        output = paddle.bmm(batch_selected_token, self.bmm_w0) + self.bmm_b0
        output = self.activation(output)
        output = paddle.bmm(output, self.bmm_w1) + self.bmm_b1

        output = output.transpose([1, 0, 2]).reshape(
            [self.batch_size, -1, self.num_expert, tensor_x.shape[-1]]
        )
        output = output.transpose([0, 2, 1, 3])

        output = batch_prob * output
        output = output.reshape([-1, tensor_x.shape[-1]])

        outputs = outputs.scatter_nd_add(batch_idx, output)
        return outputs + tensor_x

    def GetFusedEcMoeOut(self, tensor_x, gate_logits):
        paddle.disable_static()
        fused_out = fused_ec_moe(
            tensor_x,
            gate_logits,
            self.bmm_w0,
            self.bmm_b0,
            self.bmm_w1,
            self.bmm_b1,
            self.act_method,
        )

        return fused_out

    def test_fused_ec_moe_op(self):
        gate_logits = self.gate(self.tensor_x)
        final_out_ref = self.GetBaselineOut(self.tensor_x, gate_logits)
        final_out = self.GetFusedEcMoeOut(self.tensor_x, gate_logits)

        np.testing.assert_allclose(
            final_out_ref, final_out, rtol=self.rtol, atol=self.atol
        )


class TestFusedEcMoEOpActGeluFp16(TestFusedEcMoEOp):
    def config(self):
        super().config()
        self.x_type = np.float16


class TestFusedEcMoEOpActReluFp16(TestFusedEcMoEOp):
    def config(self):
        super().config()
        self.x_type = np.float16
        self.act_method = "relu"


if __name__ == "__main__":
    unittest.main()
python/paddle/incubate/nn/__init__.py
...
@@ -20,6 +20,7 @@ from .layer.fused_linear import FusedLinear  # noqa: F401
from .layer.fused_transformer import (
    FusedBiasDropoutResidualLayerNorm,
)  # noqa: F401
from .layer.fused_ec_moe import FusedEcMoe  # noqa: F401

__all__ = [  # noqa
    'FusedMultiHeadAttention',
...
@@ -28,4 +29,5 @@ __all__ = [  # noqa
    'FusedMultiTransformer',
    'FusedLinear',
    'FusedBiasDropoutResidualLayerNorm',
    'FusedEcMoe',
]
python/paddle/incubate/nn/functional/__init__.py
...
@@ -17,6 +17,7 @@ from .fused_transformer import fused_feedforward
from .fused_transformer import fused_multi_transformer
from .fused_matmul_bias import fused_matmul_bias, fused_linear
from .fused_transformer import fused_bias_dropout_residual_layer_norm
from .fused_ec_moe import fused_ec_moe

__all__ = [
    'fused_multi_head_attention',
...
@@ -25,4 +26,5 @@ __all__ = [
    'fused_matmul_bias',
    'fused_linear',
    'fused_bias_dropout_residual_layer_norm',
    'fused_ec_moe',
]
python/paddle/incubate/nn/functional/fused_ec_moe.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.layer_helper import LayerHelper


def fused_ec_moe(
    x, gate, bmm0_weight, bmm0_bias, bmm1_weight, bmm1_bias, act_type
):
    """
    Applies the fused expert-choice MoE (ec_moe) kernel.
    This method requires SM_ARCH in sm75, sm80, sm86.

    Args:
        x (Tensor): the input Tensor. Its shape is [bsz, seq_len, d_model].
        gate (Tensor): the gate Tensor used to choose the experts. Its shape is [bsz, seq_len, e].
        bmm0_weight (Tensor): the first batch matrix matmul weight. Its shape is [e, d_model, d_feed_forward].
        bmm0_bias (Tensor): the first batch matrix matmul bias. Its shape is [e, 1, d_feed_forward].
        bmm1_weight (Tensor): the second batch matrix matmul weight. Its shape is [e, d_feed_forward, d_model].
        bmm1_bias (Tensor): the second batch matrix matmul bias. Its shape is [e, 1, d_model].
        act_type (string): the activation type. Currently only `gelu` and `relu` are supported.

    Returns:
        Tensor: the output Tensor with shape [bsz, seq_len, d_model].

    Examples:
        .. code-block:: python

            # required: gpu
            import paddle
            from paddle.incubate.nn.functional import fused_ec_moe

            batch = 10
            seq_len = 128
            d_model = 1024
            d_feed_forward = d_model * 4
            num_expert = 8

            x = paddle.randn([batch, seq_len, d_model])
            gate = paddle.randn([batch, seq_len, num_expert])
            bmm0_weight = paddle.randn([num_expert, d_model, d_feed_forward])
            bmm0_bias = paddle.randn([num_expert, 1, d_feed_forward])
            bmm1_weight = paddle.randn([num_expert, d_feed_forward, d_model])
            bmm1_bias = paddle.randn([num_expert, 1, d_model])

            out = fused_ec_moe(x, gate, bmm0_weight, bmm0_bias, bmm1_weight, bmm1_bias, act_type="gelu")

            print(out.shape)  # [batch, seq_len, d_model]
    """
    helper = LayerHelper('fused_moe', **locals())
    out = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
        type='moe',
        inputs={
            'X': x,
            'Gate': gate,
            'Bmm0': bmm0_weight,
            'Bias0': bmm0_bias,
            'Bmm1': bmm1_weight,
            'Bias1': bmm1_bias,
        },
        outputs={'Out': out},
        attrs={'act_type': act_type},
    )
    return out
python/paddle/incubate/nn/layer/fused_ec_moe.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.incubate.nn import functional as F
from paddle.nn import Layer


class FusedEcMoe(Layer):
    r"""A FusedEcMoe Layer.

    Parameters:
        hidden_size (int): The dim size of the input units.
        inter_size (int): The dim size of the feed forward network.
        num_experts (int): The number of experts.
        act_type (string): The activation type. Currently only `gelu` and `relu` are supported.
        weight_attr (ParamAttr, optional): The attribute for the learnable
            weight of this layer. The default value is None and the weight will be
            initialized to zero. For detailed information, please refer to
            paddle.ParamAttr.
        bias_attr (ParamAttr|bool, optional): The attribute for the learnable bias
            of this layer. If it is set to False, no bias will be added to the output.
            If it is set to None or one kind of ParamAttr, a bias parameter will
            be created according to ParamAttr. For detailed information, please refer
            to paddle.ParamAttr. The default value is None and the bias will be
            initialized to zero.

    Attribute:
        **weight** (Parameter): the learnable weight of this layer.
        **bias** (Parameter): the learnable bias of this layer.

    Shape:
        - input: Multi-dimensional tensor with shape :math:`[batch\_size, seq\_len, d\_model]` .
        - output: Multi-dimensional tensor with shape :math:`[batch\_size, seq\_len, d\_model]` .

    Examples:
        .. code-block:: python

            # required: gpu
            import paddle
            from paddle.incubate.nn.layer.fused_ec_moe import FusedEcMoe

            x = paddle.randn([10, 128, 1024])  # [bsz, seq_len, d_model]
            gate = paddle.randn([10, 128, 8])  # [bsz, seq_len, num_experts]
            moe = FusedEcMoe(1024, 4096, 8, act_type="gelu")
            y = moe(x, gate)
            print(y.shape)  # [10, 128, 1024]
    """

    def __init__(
        self,
        hidden_size,
        inter_size,
        num_experts,
        act_type,
        weight_attr=None,
        bias_attr=None,
    ):
        super().__init__()
        weight0_shape = [num_experts, hidden_size, inter_size]
        bias0_shape = [num_experts, 1, inter_size]
        weight1_shape = [num_experts, inter_size, hidden_size]
        bias1_shape = [num_experts, 1, hidden_size]

        dtype = self._helper.get_default_dtype()
        self.bmm_weight0 = self.create_parameter(
            shape=weight0_shape, attr=weight_attr, dtype=dtype, is_bias=False
        )
        self.bmm_bias0 = self.create_parameter(
            shape=bias0_shape, attr=bias_attr, dtype=dtype, is_bias=True
        )
        self.bmm_weight1 = self.create_parameter(
            shape=weight1_shape, attr=weight_attr, dtype=dtype, is_bias=False
        )
        self.bmm_bias1 = self.create_parameter(
            shape=bias1_shape, attr=bias_attr, dtype=dtype, is_bias=True
        )
        self.act_type = act_type
        if self.act_type not in ["gelu", "relu"]:
            raise NotImplementedError(
                "Currently only `gelu` and `relu` are supported."
            )

    def forward(self, x, gate):
        return F.fused_ec_moe(
            x,
            gate,
            self.bmm_weight0,
            self.bmm_bias0,
            self.bmm_weight1,
            self.bmm_bias1,
            self.act_type,
        )