Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
84103819
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
84103819
编写于
9月 02, 2020
作者:
Z
zlsh80826
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add 64/96/384 support
上级
d80ae5bc
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
56 addition
and
6 deletion
+56
-6
paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.cu
...le/fluid/inference/tensorrt/plugin/convert_mask_plugin.cu
+56
-6
未找到文件。
paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.cu
浏览文件 @
84103819
...
@@ -31,26 +31,63 @@ namespace plugin {
...
@@ -31,26 +31,63 @@ namespace plugin {
to the mask with the bertQKV fused_multihead_attention format */
to the mask with the bertQKV fused_multihead_attention format */
constexpr
size_t
threadsPerCta128
=
2
*
2
*
32
;
constexpr
size_t
threadsPerCta128
=
2
*
2
*
32
;
constexpr
size_t
threadsPerCta384
=
1
*
8
*
32
;
constexpr
size_t
xmmasM128
=
4
;
constexpr
size_t
xmmasM128
=
4
;
constexpr
size_t
xmmasM384
=
24
;
constexpr
size_t
packedMaskSize64
=
xmmasM128
*
threadsPerCta128
;
constexpr
size_t
packedMaskSize96
=
xmmasM128
*
threadsPerCta128
;
constexpr
size_t
packedMaskSize128
=
xmmasM128
*
threadsPerCta128
;
constexpr
size_t
packedMaskSize128
=
xmmasM128
*
threadsPerCta128
;
constexpr
size_t
packedMaskSize384
=
xmmasM384
*
threadsPerCta384
;
nvinfer1
::
DimsExprs
ConvertMaskPluginDynamic
::
getOutputDimensions
(
nvinfer1
::
DimsExprs
ConvertMaskPluginDynamic
::
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
{
nvinfer1
::
IExprBuilder
&
expr_builder
)
{
assert
(
output_index
==
0
);
assert
(
output_index
==
0
);
constexpr
int
BDIM
=
0
;
constexpr
int
SDIM
=
1
;
if
(
type_
==
nvinfer1
::
DataType
::
kHALF
)
{
if
(
type_
==
nvinfer1
::
DataType
::
kHALF
)
{
auto
cms64
=
expr_builder
.
constant
(
packedMaskSize64
);
auto
cms96
=
expr_builder
.
constant
(
packedMaskSize96
);
auto
cms128
=
expr_builder
.
constant
(
packedMaskSize128
);
auto
cms128
=
expr_builder
.
constant
(
packedMaskSize128
);
auto
cms384
=
expr_builder
.
constant
(
packedMaskSize384
);
auto
c64
=
expr_builder
.
constant
(
64
);
auto
c96
=
expr_builder
.
constant
(
96
);
auto
c128
=
expr_builder
.
constant
(
128
);
auto
c384
=
expr_builder
.
constant
(
384
);
auto
is64
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kEQUAL
,
*
inputs
[
0
].
d
[
SDIM
],
*
c64
);
auto
is96
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kEQUAL
,
*
inputs
[
0
].
d
[
SDIM
],
*
c96
);
auto
is128
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kEQUAL
,
*
inputs
[
0
].
d
[
SDIM
],
*
c128
);
auto
is384
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kEQUAL
,
*
inputs
[
0
].
d
[
SDIM
],
*
c384
);
auto
sel64
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kPROD
,
*
is64
,
*
cms64
);
auto
sel96
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kPROD
,
*
is96
,
*
cms96
);
auto
sel128
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kPROD
,
*
is128
,
*
cms128
);
auto
sel384
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kPROD
,
*
is384
,
*
cms384
);
auto
maskSize1
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUM
,
*
sel64
,
*
sel96
);
auto
maskSize2
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUM
,
*
sel384
,
*
sel128
);
auto
maskSize
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUM
,
*
maskSize1
,
*
maskSize2
);
auto
fp16maskSize
=
auto
fp16maskSize
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kPROD
,
*
cms128
,
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kPROD
,
*
maskSize
,
*
expr_builder
.
constant
(
2
));
*
expr_builder
.
constant
(
2
));
nvinfer1
::
DimsExprs
ret
;
nvinfer1
::
DimsExprs
ret
;
ret
.
nbDims
=
2
;
ret
.
nbDims
=
2
;
ret
.
d
[
0
]
=
inputs
[
0
].
d
[
0
];
ret
.
d
[
0
]
=
inputs
[
0
].
d
[
BDIM
];
ret
.
d
[
1
]
=
fp16maskSize
;
ret
.
d
[
1
]
=
fp16maskSize
;
return
ret
;
return
ret
;
}
}
nvinfer1
::
DimsExprs
ret
;
nvinfer1
::
DimsExprs
ret
;
...
@@ -187,7 +224,7 @@ int ConvertMaskPluginDynamic::enqueue(
...
@@ -187,7 +224,7 @@ int ConvertMaskPluginDynamic::enqueue(
int
batch
=
input_dims
.
d
[
0
];
int
batch
=
input_dims
.
d
[
0
];
int
seq_len
=
input_dims
.
d
[
1
];
int
seq_len
=
input_dims
.
d
[
1
];
assert
(
seq_len
==
128
);
// assert(seq_len == 64 || seq_len == 96 || seq_len == 128 || seq_len == 384
);
if
(
type_
==
nvinfer1
::
DataType
::
kFLOAT
)
{
if
(
type_
==
nvinfer1
::
DataType
::
kFLOAT
)
{
IMaskPreprocess
<<<
batch
,
seq_len
,
0
,
stream
>>>
(
IMaskPreprocess
<<<
batch
,
seq_len
,
0
,
stream
>>>
(
...
@@ -204,11 +241,24 @@ int ConvertMaskPluginDynamic::enqueue(
...
@@ -204,11 +241,24 @@ int ConvertMaskPluginDynamic::enqueue(
static_cast
<
const
half
*>
(
inputs
[
0
]),
inputMaskSB
,
seq_len
,
batch
);
static_cast
<
const
half
*>
(
inputs
[
0
]),
inputMaskSB
,
seq_len
,
batch
);
}
}
size_t
warps_m
=
0
,
warps_n
=
0
,
warps_k
=
1
;
size_t
warps_m
=
0
,
warps_n
=
0
,
warps_k
=
1
;
if
(
seq_len
==
128
)
{
if
(
seq_len
==
64
||
seq_len
==
96
||
seq_len
==
128
)
{
warps_m
=
2
;
warps_m
=
2
;
warps_n
=
2
;
warps_n
=
2
;
}
else
if
(
seq_len
==
384
)
{
warps_m
=
1
;
warps_n
=
8
;
}
else
{
assert
(
false
);
}
}
/*
int* buf_h = (int*)malloc(batch * seq_len * sizeof(int));
cudaMemcpy(buf_h, inputMaskSB, batch * seq_len * sizeof(int),
cudaMemcpyDeviceToHost);
for (int i = 0; i < batch*seq_len; ++ i) {
std::cerr << buf_h[i] << " ";
}
std::cerr << std::endl;
*/
convertMask
(
seq_len
,
batch
,
warps_m
,
warps_n
,
warps_k
,
inputMaskSB
,
convertMask
(
seq_len
,
batch
,
warps_m
,
warps_n
,
warps_k
,
inputMaskSB
,
static_cast
<
uint32_t
*>
(
outputs
[
0
]),
stream
);
static_cast
<
uint32_t
*>
(
outputs
[
0
]),
stream
);
cudaFree
(
inputMaskSB
);
cudaFree
(
inputMaskSB
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录