Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
def2a87f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
def2a87f
编写于
12月 22, 2022
作者:
X
xiaoxiaohehe001
提交者:
GitHub
12月 22, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Paddle Inference] Add moe phi kernel (#48703)
上级
efa34534
变更
18
展开全部
隐藏空白更改
内联
并排
Showing
18 changed file
with
3956 addition
and
0 deletion
+3956
-0
cmake/third_party.cmake
cmake/third_party.cmake
+1
-0
paddle/fluid/inference/tensorrt/dynamic_shape_infermeta.cc
paddle/fluid/inference/tensorrt/dynamic_shape_infermeta.cc
+10
-0
paddle/fluid/operators/moe_op.cc
paddle/fluid/operators/moe_op.cc
+64
-0
paddle/phi/infermeta/multiary.cc
paddle/phi/infermeta/multiary.cc
+14
-0
paddle/phi/infermeta/multiary.h
paddle/phi/infermeta/multiary.h
+9
-0
paddle/phi/kernels/CMakeLists.txt
paddle/phi/kernels/CMakeLists.txt
+6
-0
paddle/phi/kernels/fusion/cutlass/default_moe_fc_traits.h
paddle/phi/kernels/fusion/cutlass/default_moe_fc_traits.h
+206
-0
paddle/phi/kernels/fusion/cutlass/linear_combination_ft_gelu.h
...e/phi/kernels/fusion/cutlass/linear_combination_ft_gelu.h
+687
-0
paddle/phi/kernels/fusion/cutlass/moe_cutlass_kernel.h
paddle/phi/kernels/fusion/cutlass/moe_cutlass_kernel.h
+879
-0
paddle/phi/kernels/fusion/cutlass/moe_kernel.cu
paddle/phi/kernels/fusion/cutlass/moe_kernel.cu
+911
-0
paddle/phi/kernels/fusion/cutlass/moe_kernel_impl.h
paddle/phi/kernels/fusion/cutlass/moe_kernel_impl.h
+779
-0
paddle/phi/kernels/fusion/moe_kernel.h
paddle/phi/kernels/fusion/moe_kernel.h
+32
-0
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+2
-0
python/paddle/fluid/tests/unittests/test_fused_ec_moe_op.py
python/paddle/fluid/tests/unittests/test_fused_ec_moe_op.py
+176
-0
python/paddle/incubate/nn/__init__.py
python/paddle/incubate/nn/__init__.py
+2
-0
python/paddle/incubate/nn/functional/__init__.py
python/paddle/incubate/nn/functional/__init__.py
+2
-0
python/paddle/incubate/nn/functional/fused_ec_moe.py
python/paddle/incubate/nn/functional/fused_ec_moe.py
+75
-0
python/paddle/incubate/nn/layer/fused_ec_moe.py
python/paddle/incubate/nn/layer/fused_ec_moe.py
+101
-0
未找到文件。
cmake/third_party.cmake
浏览文件 @
def2a87f
...
...
@@ -516,6 +516,7 @@ if(WITH_GPU
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
GREATER_EQUAL 11.0
)
include
(
external/cutlass
)
# download, build, install cusparselt
list
(
APPEND third_party_deps extern_cutlass
)
set
(
WITH_CUTLASS ON
)
endif
()
endif
()
...
...
paddle/fluid/inference/tensorrt/dynamic_shape_infermeta.cc
浏览文件 @
def2a87f
...
...
@@ -235,6 +235,15 @@ nvinfer1::DimsExprs UnchangedInferMeta(
return
inputs
[
0
];
}
nvinfer1
::
DimsExprs
MoeInferMeta
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
,
// NOLINT
const
framework
::
OpDesc
&
op_desc
)
{
return
inputs
[
0
];
}
nvinfer1
::
DimsExprs
Pad3dInferMeta
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
...
...
@@ -384,6 +393,7 @@ PD_REGISTER_DYNAMIC_INFER_META_FN(instance_norm, InstanceNormInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN
(
unfold
,
UnflodInferMeta
);
PD_REGISTER_DYNAMIC_INFER_META_FN
(
scatter_nd_add
,
ScatterNdAddInferMeta
);
PD_REGISTER_DYNAMIC_INFER_META_FN
(
inverse
,
UnchangedInferMeta
);
PD_REGISTER_DYNAMIC_INFER_META_FN
(
moe
,
MoeInferMeta
);
PD_REGISTER_DYNAMIC_INFER_META_FN
(
pad3d
,
Pad3dInferMeta
);
PD_REGISTER_DYNAMIC_INFER_META_FN
(
grid_sampler
,
GridSamplerInferMeta
);
}
// namespace tensorrt
...
...
paddle/fluid/operators/moe_op.cc
0 → 100644
浏览文件 @
def2a87f
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
namespace
paddle
{
namespace
operators
{
class
MoeOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
);
return
framework
::
OpKernelType
(
data_type
,
ctx
.
device_context
());
}
};
class
MoeOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor), The source input tensor of Moe op."
);
AddInput
(
"Gate"
,
"(Tensor), The gating input tensor of Moe op."
);
AddInput
(
"Bmm0"
,
"(Tensor), The bmm0 input tensor of Moe op."
);
AddInput
(
"Bias0"
,
"(Tensor), The eltwise0 input tensor of Moe op."
);
AddInput
(
"Bmm1"
,
"(Tensor), The bmm1 input tensor of Moe op."
);
AddInput
(
"Bias1"
,
"(Tensor), The eltwise1 input tensor of Moe op."
);
AddOutput
(
"Out"
,
"(Tensor), The output tensor of Moe op."
);
AddAttr
<
std
::
string
>
(
"act_type"
,
R"DOC(activation type, currently only support `gelu`, `relu`. Default value is: `gelu`. )DOC"
)
.
SetDefault
(
"gelu"
);
AddComment
(
R"DOC(FusedEcMoe kernel. For more details you can refer to `FusedEcMoE` python documents. )DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
DECLARE_INFER_SHAPE_FUNCTOR
(
moe
,
MoeInferShapeFunctor
,
PD_INFER_META
(
phi
::
MoeInferMeta
));
REGISTER_OPERATOR
(
moe
,
ops
::
MoeOp
,
ops
::
MoeOpMaker
,
MoeInferShapeFunctor
);
paddle/phi/infermeta/multiary.cc
浏览文件 @
def2a87f
...
...
@@ -2931,6 +2931,20 @@ void YoloLossInferMeta(const MetaTensor& x,
gt_match_mask
->
set_dtype
(
x
.
dtype
());
}
void
MoeInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
gate
,
const
MetaTensor
&
bmm0
,
const
MetaTensor
&
bias0
,
const
MetaTensor
&
bmm1
,
const
MetaTensor
&
bias1
,
const
std
::
string
&
act_type
,
MetaTensor
*
out
)
{
out
->
set_dims
(
x
.
dims
());
out
->
share_lod
(
x
);
out
->
set_dtype
(
x
.
dtype
());
out
->
set_layout
(
x
.
layout
());
}
}
// namespace phi
PD_REGISTER_INFER_META_FN
(
batch_norm_infer
,
phi
::
BatchNormInferInferMeta
);
paddle/phi/infermeta/multiary.h
浏览文件 @
def2a87f
...
...
@@ -523,4 +523,13 @@ void YoloLossInferMeta(const MetaTensor& x,
MetaTensor
*
objectness_mask
,
MetaTensor
*
gt_match_mask
);
void
MoeInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
gate
,
const
MetaTensor
&
bmm0
,
const
MetaTensor
&
bias0
,
const
MetaTensor
&
bmm1
,
const
MetaTensor
&
bias1
,
const
std
::
string
&
act_type
,
MetaTensor
*
out
);
}
// namespace phi
paddle/phi/kernels/CMakeLists.txt
浏览文件 @
def2a87f
...
...
@@ -104,6 +104,12 @@ file(
"strings/gpu/*.cu"
"fusion/gpu/*.cu"
)
if
(
WITH_CUTLASS
)
file
(
GLOB cutlass_cu
"fusion/cutlass/default_moe_fc_traits.h"
"fusion/cutlass/linear_combination_ft_gelu.h"
"fusion/cutlass/moe*"
)
list
(
APPEND kernel_cu
${
cutlass_cu
}
)
endif
()
if
(
WITH_MKLDNN
)
file
(
GLOB
...
...
paddle/phi/kernels/fusion/cutlass/default_moe_fc_traits.h
0 → 100644
浏览文件 @
def2a87f
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
namespace
cutlass
{
namespace
gemm
{
namespace
kernel
{
template
<
typename
TypeA
,
typename
TypeB
,
typename
arch
>
struct
MoeArchTraits
{};
template
<
typename
arch
>
struct
MoeArchTraits
<
float
,
float
,
arch
>
{
static
constexpr
int
Stages
=
2
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassSimt
;
using
AccType
=
float
;
using
LayoutB
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
ElementsPerAccessA
=
1
;
static
constexpr
int
ElementsPerAccessB
=
1
;
static
constexpr
int
ElementsPerAccessC
=
1
;
using
ThreadBlockShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
8
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
8
>
;
using
InstructionShape
=
cutlass
::
gemm
::
GemmShape
<
1
,
1
,
1
>
;
using
Operator
=
cutlass
::
arch
::
OpMultiplyAdd
;
};
// ========================= Volta Traits ===========================
// Volta will always dequantize after the global memory load.
template
<
typename
TypeB
>
struct
MoeArchTraits
<
cutlass
::
half_t
,
TypeB
,
cutlass
::
arch
::
Sm70
>
{
private:
static
constexpr
int
ThreadblockK
=
32
;
public:
static
constexpr
int
Stages
=
2
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
using
AccType
=
float
;
using
LayoutB
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
ElementsPerAccessA
=
128
/
cutlass
::
sizeof_bits
<
cutlass
::
half_t
>::
value
;
static
constexpr
int
ElementsPerAccessB
=
128
/
cutlass
::
sizeof_bits
<
cutlass
::
half_t
>::
value
;
static
constexpr
int
ElementsPerAccessC
=
128
/
cutlass
::
sizeof_bits
<
cutlass
::
half_t
>::
value
;
using
ThreadBlockShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
128
,
ThreadblockK
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
32
,
ThreadblockK
>
;
using
InstructionShape
=
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
4
>
;
using
Operator
=
cutlass
::
arch
::
OpMultiplyAdd
;
};
template
<
typename
TypeB
>
struct
MoeArchTraits
<
cutlass
::
bfloat16_t
,
TypeB
,
cutlass
::
arch
::
Sm70
>
{
private:
static
constexpr
int
ThreadblockK
=
32
;
public:
static
constexpr
int
Stages
=
2
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
using
AccType
=
float
;
using
LayoutB
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
ElementsPerAccessA
=
128
/
cutlass
::
sizeof_bits
<
cutlass
::
bfloat16_t
>::
value
;
static
constexpr
int
ElementsPerAccessB
=
128
/
cutlass
::
sizeof_bits
<
cutlass
::
bfloat16_t
>::
value
;
static
constexpr
int
ElementsPerAccessC
=
128
/
cutlass
::
sizeof_bits
<
cutlass
::
bfloat16_t
>::
value
;
using
ThreadBlockShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
128
,
ThreadblockK
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
32
,
ThreadblockK
>
;
using
InstructionShape
=
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
4
>
;
using
Operator
=
cutlass
::
arch
::
OpMultiplyAdd
;
};
// ======================= Turing Traits ==============================
// Turing will dequantize after LDSM
// fp16 x fp16 specialization
template
<
>
struct
MoeArchTraits
<
cutlass
::
half_t
,
cutlass
::
half_t
,
cutlass
::
arch
::
Sm75
>
{
static
constexpr
int
Stages
=
2
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
using
AccType
=
float
;
using
LayoutB
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
ElementsPerAccessA
=
128
/
cutlass
::
sizeof_bits
<
cutlass
::
half_t
>::
value
;
static
constexpr
int
ElementsPerAccessB
=
128
/
cutlass
::
sizeof_bits
<
cutlass
::
half_t
>::
value
;
static
constexpr
int
ElementsPerAccessC
=
128
/
cutlass
::
sizeof_bits
<
cutlass
::
half_t
>::
value
;
using
ThreadBlockShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
128
,
32
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
32
,
32
>
;
using
InstructionShape
=
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
8
>
;
using
Operator
=
cutlass
::
arch
::
OpMultiplyAdd
;
};
// bf16 x bf16 specialization
template
<
>
struct
MoeArchTraits
<
cutlass
::
bfloat16_t
,
cutlass
::
bfloat16_t
,
cutlass
::
arch
::
Sm75
>
{
static
constexpr
int
Stages
=
2
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
using
AccType
=
float
;
using
LayoutB
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
ElementsPerAccessA
=
128
/
cutlass
::
sizeof_bits
<
cutlass
::
bfloat16_t
>::
value
;
static
constexpr
int
ElementsPerAccessB
=
128
/
cutlass
::
sizeof_bits
<
cutlass
::
bfloat16_t
>::
value
;
static
constexpr
int
ElementsPerAccessC
=
128
/
cutlass
::
sizeof_bits
<
cutlass
::
bfloat16_t
>::
value
;
using
ThreadBlockShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
128
,
32
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
32
,
32
>
;
using
InstructionShape
=
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
8
>
;
using
Operator
=
cutlass
::
arch
::
OpMultiplyAdd
;
};
template
<
>
struct
MoeArchTraits
<
float
,
float
,
cutlass
::
arch
::
Sm80
>
{
static
constexpr
int
Stages
=
3
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
using
AccType
=
float
;
using
LayoutB
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
ElementsPerAccessA
=
4
;
static
constexpr
int
ElementsPerAccessB
=
4
;
static
constexpr
int
ElementsPerAccessC
=
4
;
using
ThreadBlockShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
16
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
16
>
;
using
InstructionShape
=
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
8
>
;
using
Operator
=
cutlass
::
arch
::
OpMultiplyAdd
;
};
template
<
>
struct
MoeArchTraits
<
cutlass
::
half_t
,
cutlass
::
half_t
,
cutlass
::
arch
::
Sm80
>
{
static
constexpr
int
Stages
=
3
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
using
AccType
=
float
;
using
LayoutB
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
ElementsPerAccessA
=
128
/
cutlass
::
sizeof_bits
<
cutlass
::
half_t
>::
value
;
static
constexpr
int
ElementsPerAccessB
=
128
/
cutlass
::
sizeof_bits
<
cutlass
::
half_t
>::
value
;
static
constexpr
int
ElementsPerAccessC
=
128
/
cutlass
::
sizeof_bits
<
cutlass
::
half_t
>::
value
;
using
ThreadBlockShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
128
,
32
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
32
,
32
>
;
using
InstructionShape
=
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
16
>
;
using
Operator
=
cutlass
::
arch
::
OpMultiplyAdd
;
};
template
<
>
struct
MoeArchTraits
<
cutlass
::
bfloat16_t
,
cutlass
::
bfloat16_t
,
cutlass
::
arch
::
Sm80
>
{
static
constexpr
int
Stages
=
3
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
using
AccType
=
float
;
using
LayoutB
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
ElementsPerAccessA
=
128
/
cutlass
::
sizeof_bits
<
cutlass
::
bfloat16_t
>::
value
;
static
constexpr
int
ElementsPerAccessB
=
128
/
cutlass
::
sizeof_bits
<
cutlass
::
bfloat16_t
>::
value
;
static
constexpr
int
ElementsPerAccessC
=
128
/
cutlass
::
sizeof_bits
<
cutlass
::
bfloat16_t
>::
value
;
using
ThreadBlockShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
128
,
32
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
32
,
32
>
;
using
InstructionShape
=
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
16
>
;
using
Operator
=
cutlass
::
arch
::
OpMultiplyAdd
;
};
}
// namespace kernel
}
// namespace gemm
}
// namespace cutlass
paddle/phi/kernels/fusion/cutlass/linear_combination_ft_gelu.h
0 → 100644
浏览文件 @
def2a87f
此差异已折叠。
点击以展开。
paddle/phi/kernels/fusion/cutlass/moe_cutlass_kernel.h
0 → 100644
浏览文件 @
def2a87f
此差异已折叠。
点击以展开。
paddle/phi/kernels/fusion/cutlass/moe_kernel.cu
0 → 100644
浏览文件 @
def2a87f
此差异已折叠。
点击以展开。
paddle/phi/kernels/fusion/cutlass/moe_kernel_impl.h
0 → 100644
浏览文件 @
def2a87f
此差异已折叠。
点击以展开。
paddle/phi/kernels/fusion/moe_kernel.h
0 → 100644
浏览文件 @
def2a87f
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
MoeKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
gate
,
const
DenseTensor
&
bmm0
,
const
DenseTensor
&
bias0
,
const
DenseTensor
&
bmm1
,
const
DenseTensor
&
bias1
,
const
std
::
string
&
act_type
,
DenseTensor
*
output
);
}
// namespace phi
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
def2a87f
...
...
@@ -78,6 +78,7 @@ if(NOT WITH_GPU)
list
(
REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api
)
endif
()
list
(
REMOVE_ITEM TEST_OPS test_fused_ec_moe_op
)
list
(
REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op
)
list
(
REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_grad_op
)
list
(
REMOVE_ITEM TEST_OPS test_fuse_gemm_epilogue_pass
)
...
...
@@ -143,6 +144,7 @@ if(WIN32)
list
(
REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_bias
)
list
(
REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_no_bias
)
list
(
REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op
)
list
(
REMOVE_ITEM TEST_OPS test_fused_ec_moe_op
)
endif
()
list
(
REMOVE_ITEM TEST_OPS test_checkpoint_saver
)
...
...
python/paddle/fluid/tests/unittests/test_fused_ec_moe_op.py
0 → 100644
浏览文件 @
def2a87f
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
import
paddle
import
paddle.nn.functional
as
F
from
paddle.fluid.framework
import
default_main_program
from
paddle.incubate.nn.functional
import
fused_ec_moe
from
paddle.nn.layer.common
import
Linear
default_main_program
().
random_seed
=
42
class
TestFusedEcMoEOp
(
OpTest
):
def
setUp
(
self
):
self
.
config
()
self
.
rtol
=
1e-3
self
.
atol
=
1e-3
paddle
.
set_default_dtype
(
self
.
x_type
)
self
.
__class__
.
op_type
=
"fused_ec_moe"
# Since it's only used in inference.
self
.
__class__
.
no_need_check_grad
=
True
self
.
bmm_w0
=
paddle
.
to_tensor
(
np
.
random
.
randn
(
self
.
num_expert
,
self
.
d_model
,
self
.
d_feedforward
)
*
0.001
,
dtype
=
paddle
.
float16
,
)
self
.
bmm_b0
=
paddle
.
to_tensor
(
np
.
random
.
randn
(
self
.
num_expert
,
1
,
self
.
d_feedforward
)
*
0.001
,
dtype
=
paddle
.
float16
,
)
self
.
bmm_w1
=
paddle
.
to_tensor
(
np
.
random
.
randn
(
self
.
num_expert
,
self
.
d_feedforward
,
self
.
d_model
)
*
0.001
,
dtype
=
paddle
.
float16
,
)
self
.
bmm_b1
=
paddle
.
to_tensor
(
np
.
random
.
randn
(
self
.
num_expert
,
1
,
self
.
d_model
)
*
0.001
,
dtype
=
paddle
.
float16
,
)
self
.
tensor_x
=
paddle
.
to_tensor
(
np
.
random
.
randn
(
self
.
batch_size
,
self
.
seq_len
,
self
.
d_model
)
*
0.001
,
dtype
=
paddle
.
float16
,
)
self
.
bmm_w0
.
stop_gradient
=
True
self
.
bmm_b0
.
stop_gradient
=
True
self
.
bmm_w1
.
stop_gradient
=
True
self
.
bmm_b1
.
stop_gradient
=
True
self
.
tensor_x
.
stop_gradient
=
True
self
.
gate
=
Linear
(
self
.
d_model
,
self
.
num_expert
)
paddle
.
set_default_dtype
(
"float16"
)
self
.
activation
=
getattr
(
F
,
self
.
act_method
)
def
config
(
self
):
self
.
x_type
=
np
.
float16
self
.
batch_size
=
10
self
.
seq_len
=
128
self
.
num_expert
=
32
self
.
d_model
=
768
self
.
d_feedforward
=
3072
self
.
act_method
=
'gelu'
def
GetBaselineOut
(
self
,
tensor_x
,
gate_logits
):
def
expert_choice_gating
(
logits
,
capacity
,
batch_idx
,
expert_idx
):
gates
=
F
.
softmax
(
logits
,
-
1
)
indices1_s
=
paddle
.
topk
(
logits
.
transpose
([
0
,
2
,
1
]),
k
=
capacity
,
axis
=-
1
)[
1
].
cast
(
"int32"
)
seqlen_idx
=
indices1_s
.
reshape
([
-
1
])
gather_idx
=
paddle
.
stack
([
batch_idx
,
seqlen_idx
,
expert_idx
],
-
1
)
prob
=
paddle
.
gather_nd
(
gates
,
gather_idx
)
return
prob
,
expert_idx
,
gather_idx
,
capacity
paddle
.
disable_static
()
capacity
=
self
.
seq_len
//
16
batch_expert_idx
=
paddle
.
nonzero
(
paddle
.
ones
(
shape
=
[
self
.
batch_size
,
self
.
num_expert
,
capacity
])
).
cast
(
'int32'
)
batch_idx
=
batch_expert_idx
[:,
0
]
expert_idx
=
batch_expert_idx
[:,
1
]
(
expert_prob_flatten
,
expert_idx_flatten
,
gather_idx
,
cap
,
)
=
expert_choice_gating
(
gate_logits
,
capacity
,
batch_idx
,
expert_idx
)
outputs
=
paddle
.
zeros_like
(
tensor_x
)
batch_prob
=
expert_prob_flatten
.
reshape
(
[
self
.
batch_size
,
self
.
num_expert
,
-
1
,
1
]
)
batch_idx
=
gather_idx
[:,
:
2
]
selected_token
=
tensor_x
.
gather_nd
(
batch_idx
)
batch_selected_token
=
selected_token
.
reshape
(
[
self
.
batch_size
,
self
.
num_expert
,
-
1
,
tensor_x
.
shape
[
-
1
]]
)
batch_selected_token
=
batch_selected_token
.
transpose
(
[
1
,
0
,
2
,
3
]
).
reshape
([
self
.
num_expert
,
-
1
,
tensor_x
.
shape
[
-
1
]])
output
=
paddle
.
bmm
(
batch_selected_token
,
self
.
bmm_w0
)
+
self
.
bmm_b0
output
=
self
.
activation
(
output
)
output
=
paddle
.
bmm
(
output
,
self
.
bmm_w1
)
+
self
.
bmm_b1
output
=
output
.
transpose
([
1
,
0
,
2
]).
reshape
(
[
self
.
batch_size
,
-
1
,
self
.
num_expert
,
tensor_x
.
shape
[
-
1
]]
)
output
=
output
.
transpose
([
0
,
2
,
1
,
3
])
output
=
batch_prob
*
output
output
=
output
.
reshape
([
-
1
,
tensor_x
.
shape
[
-
1
]])
outputs
=
outputs
.
scatter_nd_add
(
batch_idx
,
output
)
return
outputs
+
tensor_x
def
GetFusedEcMoeOut
(
self
,
tensor_x
,
gate_logits
):
paddle
.
disable_static
()
fused_out
=
fused_ec_moe
(
tensor_x
,
gate_logits
,
self
.
bmm_w0
,
self
.
bmm_b0
,
self
.
bmm_w1
,
self
.
bmm_b1
,
self
.
act_method
,
)
return
fused_out
def
test_fused_ec_moe_op
(
self
):
gate_logits
=
self
.
gate
(
self
.
tensor_x
)
final_out_ref
=
self
.
GetBaselineOut
(
self
.
tensor_x
,
gate_logits
)
final_out
=
self
.
GetFusedEcMoeOut
(
self
.
tensor_x
,
gate_logits
)
np
.
testing
.
assert_allclose
(
final_out_ref
,
final_out
,
rtol
=
self
.
rtol
,
atol
=
self
.
atol
)
class
TestFusedEcMoEOpActGeluFp16
(
TestFusedEcMoEOp
):
def
config
(
self
):
super
().
config
()
self
.
x_type
=
np
.
float16
class
TestFusedEcMoEOpActReluFp16
(
TestFusedEcMoEOp
):
def
config
(
self
):
super
().
config
()
self
.
x_type
=
np
.
float16
self
.
act_method
=
"relu"
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/incubate/nn/__init__.py
浏览文件 @
def2a87f
...
...
@@ -20,6 +20,7 @@ from .layer.fused_linear import FusedLinear # noqa: F401
from
.layer.fused_transformer
import
(
FusedBiasDropoutResidualLayerNorm
,
)
# noqa: F401
from
.layer.fused_ec_moe
import
FusedEcMoe
# noqa: F401
__all__
=
[
# noqa
'FusedMultiHeadAttention'
,
...
...
@@ -28,4 +29,5 @@ __all__ = [ # noqa
'FusedMultiTransformer'
,
'FusedLinear'
,
'FusedBiasDropoutResidualLayerNorm'
,
'FusedEcMoe'
,
]
python/paddle/incubate/nn/functional/__init__.py
浏览文件 @
def2a87f
...
...
@@ -17,6 +17,7 @@ from .fused_transformer import fused_feedforward
from
.fused_transformer
import
fused_multi_transformer
from
.fused_matmul_bias
import
fused_matmul_bias
,
fused_linear
from
.fused_transformer
import
fused_bias_dropout_residual_layer_norm
from
.fused_ec_moe
import
fused_ec_moe
__all__
=
[
'fused_multi_head_attention'
,
...
...
@@ -25,4 +26,5 @@ __all__ = [
'fused_matmul_bias'
,
'fused_linear'
,
'fused_bias_dropout_residual_layer_norm'
,
'fused_ec_moe'
,
]
python/paddle/incubate/nn/functional/fused_ec_moe.py
0 → 100644
浏览文件 @
def2a87f
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
paddle.fluid.layer_helper
import
LayerHelper
def
fused_ec_moe
(
x
,
gate
,
bmm0_weight
,
bmm0_bias
,
bmm1_weight
,
bmm1_bias
,
act_type
):
"""
Applies fused ec_moe kernel.
This method requires SM_ARCH in sm75, sm80, sm86.
Args:
x (Tensor): the input Tensor. Its shape is [bsz, seq_len, d_model].
gate (Tensor): the gate Tensor to choose expert. Its shape is [bsz, seq_len, e].
bmm0_weight (Tensor): the first batch matrix matmul weight. Its shape is [e, d_model, d_feed_forward].
bmm0_bias (Tensor): the first batch matrix matmul bias. Its shape is [e, 1, d_feed_forward].
bmm1_weight (Tensor): the second batch matrix matmul weight. Its shape is [e, d_model, d_feed_forward].
bmm1_bias (Tensor): the second batch matrix matmul bias. Its shape is [e, 1, d_feed_forward].
act_type (string): the Activation Type. Currently only support `gelu`, `relu`.
Returns:
Tensor: the output Tensor.
Examples:
.. code-block:: python
# required: gpu
import paddle
from paddle.incubate.nn.functional import fused_ec_moe
batch = 10
seq_len = 128
d_model = 1024
d_feed_forward = d_model * 4
num_expert = 8
x = paddle.randn([batch, seq_len, d_model])
gate = paddle.randn([batch, seq_len, num_expert])
bmm0_weight = paddle.randn([num_expert, d_model, d_feed_forward])
bmm0_bias = paddle.randn([num_expert, d_model, d_feed_forward])
bmm1_weight = paddle.randn([num_expert, d_model, d_feed_forward])
bmm1_bias = paddle.randn([num_expert, d_model, d_feed_forward])
out = fused_ec_moe(x, gate, bmm0_weight, bmm0_bias, bmm1_weight, bmm1_bias, act_type="gelu")
print(out.shape) # [batch, seq_len, num_expert]
"""
helper
=
LayerHelper
(
'fused_moe'
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
type
=
'moe'
,
inputs
=
{
'X'
:
x
,
'Gate'
:
gate
,
'Bmm0'
:
bmm0_weight
,
'Bias0'
:
bmm0_bias
,
'Bmm1'
:
bmm1_weight
,
'Bias1'
:
bmm1_bias
,
},
outputs
=
{
'Out'
:
out
},
attrs
=
{
'act_type'
:
act_type
},
)
return
out
python/paddle/incubate/nn/layer/fused_ec_moe.py
0 → 100644
浏览文件 @
def2a87f
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
paddle.incubate.nn
import
functional
as
F
from
paddle.nn
import
Layer
class
FusedEcMoe
(
Layer
):
r
"""A FusedEcMoe Layer.
Parameters:
hidden_size (int): The dim size of input units.
inter_size (int): The dim size of feed forward network.
num_expert (int): The number of experts.
act_type (string): The activation type. Currently only support `gelu`, `relu`.
weight_attr (ParamAttr, optional): The attribute for the learnable
weight of this layer. The default value is None and the weight will be
initialized to zero. For detailed information, please refer to
paddle.ParamAttr.
bias_attr (ParamAttr|bool, optional): The attribute for the learnable bias
of this layer. If it is set to False, no bias will be added to the output.
If it is set to None or one kind of ParamAttr, a bias parameter will
be created according to ParamAttr. For detailed information, please refer
to paddle.ParamAttr. The default value is None and the bias will be
initialized to zero.
Attribute:
**weight** (Parameter): the learnable weight of this layer.
**bias** (Parameter): the learnable bias of this layer.
Shape:
- input: Multi-dimentional tensor with shape :math:`[batch\_size, seq\_len, d\_model]` .
- output: Multi-dimentional tensor with shape :math:`[batch\_size, seq\_len, d\_model]` .
Examples:
.. code-block:: python
# required: gpu
import paddle
from paddle.incubate.nn.layer.fused_ec_moe import FusedEcMoe
x = paddle.randn([10, 128, 1024]) # [bsz, seq_len, d_model]
gate = paddle.randn([10, 128, 8]) # [bsz, seq_len, num_experts]
moe = FusedEcMoe(1024, 4096, 8, act_type="gelu")
y = moe(x, gate)
print(y.shape) # [10, 128, 1024]
"""
def
__init__
(
self
,
hidden_size
,
inter_size
,
num_experts
,
act_type
,
weight_attr
=
None
,
bias_attr
=
None
,
):
super
().
__init__
()
weight0_shape
=
[
num_experts
,
hidden_size
,
inter_size
]
bias0_shape
=
[
num_experts
,
1
,
inter_size
]
weight1_shape
=
[
num_experts
,
inter_size
,
hidden_size
]
bias1_shape
=
[
num_experts
,
1
,
hidden_size
]
dtype
=
self
.
_helper
.
get_default_dtype
()
self
.
bmm_weight0
=
self
.
create_parameter
(
shape
=
weight0_shape
,
attr
=
weight_attr
,
dtype
=
dtype
,
is_bias
=
False
)
self
.
bmm_bias0
=
self
.
create_parameter
(
shape
=
bias0_shape
,
attr
=
bias_attr
,
dtype
=
dtype
,
is_bias
=
True
)
self
.
bmm_weight1
=
self
.
create_parameter
(
shape
=
weight1_shape
,
attr
=
weight_attr
,
dtype
=
dtype
,
is_bias
=
False
)
self
.
bmm_bias1
=
self
.
create_parameter
(
shape
=
bias1_shape
,
attr
=
bias_attr
,
dtype
=
dtype
,
is_bias
=
True
)
self
.
act_type
=
act_type
if
self
.
act_type
not
in
[
"gelu"
,
"relu"
]:
raise
NotImplementedError
(
"Currently only support `gelu`, `relu`. "
)
def
forward
(
self
,
x
,
gate
):
return
F
.
fused_ec_moe
(
x
,
gate
,
self
.
bmm_weight0
,
self
.
bmm_bias0
,
self
.
bmm_weight1
,
self
.
bmm_bias1
,
self
.
act_type
,
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录