Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
6512e087
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6512e087
编写于
10月 10, 2022
作者:
W
Wangzheee
提交者:
GitHub
10月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Paddle Inference]fix embedding fused (#46789)
* fix embedding fused
上级
ae6b4713
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
1203 addition
and
275 deletion
+1203
-275
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
...fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
+3
-3
paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc
...inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc
+1
-1
paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu
...nsorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu
+403
-61
paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu
...nsorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu
+419
-71
paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu
...ce/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu
+263
-115
paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h
...nce/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h
+114
-24
未找到文件。
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
浏览文件 @
6512e087
...
@@ -210,14 +210,14 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
...
@@ -210,14 +210,14 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
"max_seqlen_tensor"
));
// max_seqlen, eval_placeholder_3
"max_seqlen_tensor"
));
// max_seqlen, eval_placeholder_3
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"ManyEmbLayerNormPluginDynamic"
,
"
2
"
);
"ManyEmbLayerNormPluginDynamic"
,
"
1
"
);
auto
plugin_obj
=
auto
plugin_obj
=
creator
->
createPlugin
(
"ManyEmbLayerNormPluginDynamic"
,
plugin_ptr
);
creator
->
createPlugin
(
"ManyEmbLayerNormPluginDynamic"
,
plugin_ptr
);
auto
plugin_layer
=
engine_
->
network
()
->
addPluginV2
(
auto
plugin_layer
=
engine_
->
network
()
->
addPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin_obj
);
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin_obj
);
plugin_layer
->
setName
((
"ManyEmbLayerNormPluginDynamic_V
2
(Output: "
+
plugin_layer
->
setName
((
"ManyEmbLayerNormPluginDynamic_V
1
(Output: "
+
op_desc
.
Output
(
"Out"
)[
0
]
+
")"
)
op_desc
.
Output
(
"Out"
)[
0
]
+
")"
)
.
c_str
());
.
c_str
());
free
(
plugin_ptr
);
free
(
plugin_ptr
);
...
@@ -248,7 +248,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
...
@@ -248,7 +248,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
layer
=
plugin_layer
;
layer
=
plugin_layer
;
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
RreplenishLayerAndOutput
(
layer
,
RreplenishLayerAndOutput
(
layer
,
"ManyEmbLayerNormPluginDynamic_V
2
"
,
"ManyEmbLayerNormPluginDynamic_V
1
"
,
{
output_name
,
std
::
string
(
"qkv_plugin_mask"
)},
{
output_name
,
std
::
string
(
"qkv_plugin_mask"
)},
test_mode
);
test_mode
);
}
}
...
...
paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc
浏览文件 @
6512e087
...
@@ -194,7 +194,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
...
@@ -194,7 +194,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
"max_seqlen_tensor"
));
// max_seqlen, eval_placeholder_3
"max_seqlen_tensor"
));
// max_seqlen, eval_placeholder_3
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"ManyEmbLayerNormPluginDynamic"
,
"
3
"
);
"ManyEmbLayerNormPluginDynamic"
,
"
2
"
);
auto
plugin_obj
=
auto
plugin_obj
=
creator
->
createPlugin
(
"ManyEmbLayerNormPluginDynamic"
,
plugin_ptr
);
creator
->
createPlugin
(
"ManyEmbLayerNormPluginDynamic"
,
plugin_ptr
);
...
...
paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu
浏览文件 @
6512e087
...
@@ -30,20 +30,22 @@ namespace tensorrt {
...
@@ -30,20 +30,22 @@ namespace tensorrt {
namespace
plugin
{
namespace
plugin
{
template
<
typename
T
,
unsigned
TPB
>
template
<
typename
T
,
unsigned
TPB
>
__global__
void
embLayerNormKernelHFace
(
int32_t
ld
,
__global__
void
embLayerNormKernelHFace_2
(
int32_t
ld
,
int32_t
**
inputIds
,
int32_t
const
*
inputIds0
,
int32_t
const
nbLookupTables
,
int32_t
const
*
inputIds1
,
int32_t
nbLookupTables
,
float
const
*
beta
,
float
const
*
beta
,
float
const
*
gamma
,
float
const
*
gamma
,
T
**
mIdsEmbDev
,
T
const
*
mIdsEmbDev0
,
int32_t
*
IdsSize
,
T
const
*
mIdsEmbDev1
,
int32_t
IdsSize0
,
int32_t
IdsSize1
,
T
*
output
)
{
T
*
output
)
{
cub
::
Sum
pairSum
;
cub
::
Sum
pairSum
;
int32_t
const
s
=
blockIdx
.
x
;
int32_t
const
s
=
blockIdx
.
x
;
int32_t
const
b
=
blockIdx
.
y
;
int32_t
const
b
=
blockIdx
.
y
;
int32_t
*
cuSeqlens
=
inputIds
[
0
];
int32_t
const
sumS
=
inputIds0
[
b
];
int32_t
const
sumS
=
cuSeqlens
[
b
];
int32_t
const
s_b
=
inputIds0
[
b
+
1
]
-
sumS
;
int32_t
const
s_b
=
cuSeqlens
[
b
+
1
]
-
sumS
;
if
(
s
>=
s_b
)
{
if
(
s
>=
s_b
)
{
return
;
// This CTA has nothing to do
return
;
// This CTA has nothing to do
}
}
...
@@ -52,17 +54,87 @@ __global__ void embLayerNormKernelHFace(int32_t ld,
...
@@ -52,17 +54,87 @@ __global__ void embLayerNormKernelHFace(int32_t ld,
extern
__shared__
int32_t
word_id
[];
extern
__shared__
int32_t
word_id
[];
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
for
(
int
i
=
1
;
i
<
nbLookupTables
;
++
i
)
{
if
(
static_cast
<
int32_t
const
*>
(
inputIds1
)[
seqPos
]
<
0
||
if
(
static_cast
<
int32_t
const
*>
(
inputIds
[
i
])[
seqPos
]
<
0
||
static_cast
<
int32_t
const
*>
(
inputIds1
)[
seqPos
]
>=
IdsSize1
)
{
static_cast
<
int32_t
const
*>
(
inputIds
[
i
])[
seqPos
]
>=
IdsSize
[
i
])
{
printf
(
printf
(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max "
);
"table: ID < 0 or ID > max "
);
return
;
return
;
}
else
{
}
else
{
word_id
[
i
-
1
]
=
static_cast
<
int32_t
const
*>
(
inputIds
[
i
]
)[
seqPos
];
word_id
[
0
]
=
static_cast
<
int32_t
const
*>
(
inputIds1
)[
seqPos
];
}
}
}
}
__syncthreads
();
// 2. load pos/tok/word embeddings and add them toghether
// offset into embeddings is given by wordId * hidden_size
int32_t
const
poffset
=
blockIdx
.
x
*
ld
;
int32_t
const
outOffset
=
seqPos
*
ld
;
// the output offset is given by b * (S*hidden_size) + s * hidden_size
kvp
<
T
>
threadData
(
0
,
0
);
for
(
int32_t
it
=
threadIdx
.
x
;
it
<
ld
;
it
+=
TPB
)
{
T
p
(
mIdsEmbDev0
[
poffset
+
it
]);
// pos id
T
val
=
p
;
int32_t
const
offset
=
word_id
[
0
]
*
ld
;
val
+=
mIdsEmbDev1
[
offset
+
it
];
output
[
outOffset
+
it
]
=
val
;
T
const
rldval
=
rld
*
val
;
threadData
=
pairSum
(
threadData
,
kvp
<
T
>
(
rldval
,
rldval
*
val
));
}
// 3. layer norm on the sum
layerNorm
<
T
,
T
,
float
,
TPB
>
(
threadData
,
ld
,
outOffset
,
beta
,
gamma
,
output
);
}
template
<
typename
T
,
unsigned
TPB
>
__global__
void
embLayerNormKernelHFace_3
(
int32_t
ld
,
int32_t
const
*
inputIds0
,
int32_t
const
*
inputIds1
,
int32_t
const
*
inputIds2
,
int32_t
nbLookupTables
,
float
const
*
beta
,
float
const
*
gamma
,
T
const
*
mIdsEmbDev0
,
T
const
*
mIdsEmbDev1
,
T
const
*
mIdsEmbDev2
,
int32_t
IdsSize0
,
int32_t
IdsSize1
,
int32_t
IdsSize2
,
T
*
output
)
{
cub
::
Sum
pairSum
;
int32_t
const
s
=
blockIdx
.
x
;
int32_t
const
b
=
blockIdx
.
y
;
int32_t
const
sumS
=
inputIds0
[
b
];
int32_t
const
s_b
=
inputIds0
[
b
+
1
]
-
sumS
;
if
(
s
>=
s_b
)
{
return
;
// This CTA has nothing to do
}
T
const
rld
=
T
(
1.
f
)
/
T
(
ld
);
int32_t
const
seqPos
=
sumS
+
s
;
extern
__shared__
int32_t
word_id
[];
if
(
threadIdx
.
x
==
0
)
{
if
(
static_cast
<
int32_t
const
*>
(
inputIds1
)[
seqPos
]
<
0
||
static_cast
<
int32_t
const
*>
(
inputIds1
)[
seqPos
]
>=
IdsSize1
)
{
printf
(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max "
);
return
;
}
else
{
word_id
[
0
]
=
static_cast
<
int32_t
const
*>
(
inputIds1
)[
seqPos
];
}
if
(
static_cast
<
int32_t
const
*>
(
inputIds2
)[
seqPos
]
<
0
||
static_cast
<
int32_t
const
*>
(
inputIds2
)[
seqPos
]
>=
IdsSize2
)
{
printf
(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max "
);
return
;
}
else
{
word_id
[
1
]
=
static_cast
<
int32_t
const
*>
(
inputIds2
)[
seqPos
];
}
}
}
__syncthreads
();
__syncthreads
();
...
@@ -74,12 +146,101 @@ __global__ void embLayerNormKernelHFace(int32_t ld,
...
@@ -74,12 +146,101 @@ __global__ void embLayerNormKernelHFace(int32_t ld,
kvp
<
T
>
threadData
(
0
,
0
);
kvp
<
T
>
threadData
(
0
,
0
);
for
(
int32_t
it
=
threadIdx
.
x
;
it
<
ld
;
it
+=
TPB
)
{
for
(
int32_t
it
=
threadIdx
.
x
;
it
<
ld
;
it
+=
TPB
)
{
T
p
(
mIdsEmbDev
[
0
]
[
poffset
+
it
]);
// pos id
T
p
(
mIdsEmbDev
0
[
poffset
+
it
]);
// pos id
T
val
=
p
;
T
val
=
p
;
for
(
int
i
=
1
;
i
<
nbLookupTables
;
++
i
)
{
int32_t
const
offset0
=
word_id
[
0
]
*
ld
;
int32_t
const
offset
=
word_id
[
i
-
1
]
*
ld
;
val
+=
mIdsEmbDev1
[
offset0
+
it
];
val
+=
mIdsEmbDev
[
i
][
offset
+
it
];
int32_t
const
offset1
=
word_id
[
1
]
*
ld
;
val
+=
mIdsEmbDev2
[
offset1
+
it
];
output
[
outOffset
+
it
]
=
val
;
T
const
rldval
=
rld
*
val
;
threadData
=
pairSum
(
threadData
,
kvp
<
T
>
(
rldval
,
rldval
*
val
));
}
// 3. layer norm on the sum
layerNorm
<
T
,
T
,
float
,
TPB
>
(
threadData
,
ld
,
outOffset
,
beta
,
gamma
,
output
);
}
template
<
typename
T
,
unsigned
TPB
>
__global__
void
embLayerNormKernelHFace_4
(
int32_t
ld
,
int32_t
const
*
inputIds0
,
int32_t
const
*
inputIds1
,
int32_t
const
*
inputIds2
,
int32_t
const
*
inputIds3
,
int32_t
nbLookupTables
,
float
const
*
beta
,
float
const
*
gamma
,
T
const
*
mIdsEmbDev0
,
T
const
*
mIdsEmbDev1
,
T
const
*
mIdsEmbDev2
,
T
const
*
mIdsEmbDev3
,
int32_t
IdsSize0
,
int32_t
IdsSize1
,
int32_t
IdsSize2
,
int32_t
IdsSize3
,
T
*
output
)
{
cub
::
Sum
pairSum
;
int32_t
const
s
=
blockIdx
.
x
;
int32_t
const
b
=
blockIdx
.
y
;
int32_t
const
sumS
=
inputIds0
[
b
];
int32_t
const
s_b
=
inputIds0
[
b
+
1
]
-
sumS
;
if
(
s
>=
s_b
)
{
return
;
// This CTA has nothing to do
}
T
const
rld
=
T
(
1.
f
)
/
T
(
ld
);
int32_t
const
seqPos
=
sumS
+
s
;
extern
__shared__
int32_t
word_id
[];
if
(
threadIdx
.
x
==
0
)
{
if
(
static_cast
<
int32_t
const
*>
(
inputIds1
)[
seqPos
]
<
0
||
static_cast
<
int32_t
const
*>
(
inputIds1
)[
seqPos
]
>=
IdsSize1
)
{
printf
(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max "
);
return
;
}
else
{
word_id
[
0
]
=
static_cast
<
int32_t
const
*>
(
inputIds1
)[
seqPos
];
}
if
(
static_cast
<
int32_t
const
*>
(
inputIds2
)[
seqPos
]
<
0
||
static_cast
<
int32_t
const
*>
(
inputIds2
)[
seqPos
]
>=
IdsSize2
)
{
printf
(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max "
);
return
;
}
else
{
word_id
[
1
]
=
static_cast
<
int32_t
const
*>
(
inputIds2
)[
seqPos
];
}
if
(
static_cast
<
int32_t
const
*>
(
inputIds3
)[
seqPos
]
<
0
||
static_cast
<
int32_t
const
*>
(
inputIds3
)[
seqPos
]
>=
IdsSize3
)
{
printf
(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max "
);
return
;
}
else
{
word_id
[
2
]
=
static_cast
<
int32_t
const
*>
(
inputIds3
)[
seqPos
];
}
}
}
__syncthreads
();
// 2. load pos/tok/word embeddings and add them toghether
// offset into embeddings is given by wordId * hidden_size
int32_t
const
poffset
=
blockIdx
.
x
*
ld
;
int32_t
const
outOffset
=
seqPos
*
ld
;
// the output offset is given by b * (S*hidden_size) + s * hidden_size
kvp
<
T
>
threadData
(
0
,
0
);
for
(
int32_t
it
=
threadIdx
.
x
;
it
<
ld
;
it
+=
TPB
)
{
T
p
(
mIdsEmbDev0
[
poffset
+
it
]);
// pos id
T
val
=
p
;
int32_t
const
offset0
=
word_id
[
0
]
*
ld
;
val
+=
mIdsEmbDev1
[
offset0
+
it
];
int32_t
const
offset1
=
word_id
[
1
]
*
ld
;
val
+=
mIdsEmbDev2
[
offset1
+
it
];
int32_t
const
offset2
=
word_id
[
2
]
*
ld
;
val
+=
mIdsEmbDev3
[
offset2
+
it
];
output
[
outOffset
+
it
]
=
val
;
output
[
outOffset
+
it
]
=
val
;
T
const
rldval
=
rld
*
val
;
T
const
rldval
=
rld
*
val
;
...
@@ -89,52 +250,233 @@ __global__ void embLayerNormKernelHFace(int32_t ld,
...
@@ -89,52 +250,233 @@ __global__ void embLayerNormKernelHFace(int32_t ld,
// 3. layer norm on the sum
// 3. layer norm on the sum
layerNorm
<
T
,
T
,
float
,
TPB
>
(
threadData
,
ld
,
outOffset
,
beta
,
gamma
,
output
);
layerNorm
<
T
,
T
,
float
,
TPB
>
(
threadData
,
ld
,
outOffset
,
beta
,
gamma
,
output
);
}
}
template
<
typename
T
>
int32_t
embSkipLayerNormHFace_2
(
cudaStream_t
stream
,
int32_t
ld
,
int32_t
B
,
int32_t
S
,
int
const
*
inputIds0
,
int
const
*
inputIds1
,
int32_t
nbLookupTables
,
float
const
*
beta
,
float
const
*
gamma
,
T
const
*
mIdsEmbDev0
,
T
const
*
mIdsEmbDev1
,
int32_t
IdsSize0
,
int32_t
IdsSize1
,
T
*
output
)
{
constexpr
int32_t
tpb
=
256
;
dim3
const
grid
(
S
,
B
,
1
);
dim3
const
block
(
tpb
,
1
,
1
);
size_t
cache_size
=
sizeof
(
int32_t
)
*
(
nbLookupTables
-
1
);
embLayerNormKernelHFace_2
<
T
,
tpb
>
<<<
grid
,
block
,
cache_size
,
stream
>>>
(
ld
,
inputIds0
,
inputIds1
,
nbLookupTables
,
beta
,
gamma
,
mIdsEmbDev0
,
mIdsEmbDev1
,
IdsSize0
,
IdsSize1
,
output
);
return
cudaPeekAtLastError
();
}
template
<
typename
T
>
int32_t
embSkipLayerNormHFace_3
(
cudaStream_t
stream
,
int32_t
ld
,
int32_t
B
,
int32_t
S
,
int
const
*
inputIds0
,
int
const
*
inputIds1
,
int
const
*
inputIds2
,
int32_t
nbLookupTables
,
float
const
*
beta
,
float
const
*
gamma
,
T
const
*
mIdsEmbDev0
,
T
const
*
mIdsEmbDev1
,
T
const
*
mIdsEmbDev2
,
int32_t
IdsSize0
,
int32_t
IdsSize1
,
int32_t
IdsSize2
,
T
*
output
)
{
constexpr
int32_t
tpb
=
256
;
dim3
const
grid
(
S
,
B
,
1
);
dim3
const
block
(
tpb
,
1
,
1
);
size_t
cache_size
=
sizeof
(
int32_t
)
*
(
nbLookupTables
-
1
);
embLayerNormKernelHFace_3
<
T
,
tpb
>
<<<
grid
,
block
,
cache_size
,
stream
>>>
(
ld
,
inputIds0
,
inputIds1
,
inputIds2
,
nbLookupTables
,
beta
,
gamma
,
mIdsEmbDev0
,
mIdsEmbDev1
,
mIdsEmbDev2
,
IdsSize0
,
IdsSize1
,
IdsSize2
,
output
);
return
cudaPeekAtLastError
();
}
template
<
typename
T
>
template
<
typename
T
>
int32_t
embSkipLayerNormHFace
(
cudaStream_t
stream
,
int32_t
embSkipLayerNormHFace
_4
(
cudaStream_t
stream
,
int32_t
ld
,
int32_t
ld
,
int32_t
B
,
int32_t
B
,
int32_t
S
,
int32_t
S
,
int32_t
**
inputIds
,
int
const
*
inputIds0
,
int32_t
const
nbLookupTables
,
int
const
*
inputIds1
,
int
const
*
inputIds2
,
int
const
*
inputIds3
,
int32_t
nbLookupTables
,
float
const
*
beta
,
float
const
*
beta
,
float
const
*
gamma
,
float
const
*
gamma
,
T
**
mIdsEmbDev
,
T
const
*
mIdsEmbDev0
,
int32_t
*
IdsSize
,
T
const
*
mIdsEmbDev1
,
T
const
*
mIdsEmbDev2
,
T
const
*
mIdsEmbDev3
,
int32_t
IdsSize0
,
int32_t
IdsSize1
,
int32_t
IdsSize2
,
int32_t
IdsSize3
,
T
*
output
)
{
T
*
output
)
{
constexpr
int32_t
tpb
=
256
;
constexpr
int32_t
tpb
=
256
;
dim3
const
grid
(
S
,
B
,
1
);
dim3
const
grid
(
S
,
B
,
1
);
dim3
const
block
(
tpb
,
1
,
1
);
dim3
const
block
(
tpb
,
1
,
1
);
size_t
cache_size
=
sizeof
(
int32_t
)
*
(
nbLookupTables
-
1
);
size_t
cache_size
=
sizeof
(
int32_t
)
*
(
nbLookupTables
-
1
);
embLayerNormKernelHFace
<
T
,
tpb
><<<
grid
,
block
,
cache_size
,
stream
>>>
(
embLayerNormKernelHFace_4
<
T
,
tpb
>
ld
,
inputIds
,
nbLookupTables
,
beta
,
gamma
,
mIdsEmbDev
,
IdsSize
,
output
);
<<<
grid
,
block
,
cache_size
,
stream
>>>
(
ld
,
inputIds0
,
inputIds1
,
inputIds2
,
inputIds3
,
nbLookupTables
,
beta
,
gamma
,
mIdsEmbDev0
,
mIdsEmbDev1
,
mIdsEmbDev2
,
mIdsEmbDev3
,
IdsSize0
,
IdsSize1
,
IdsSize2
,
IdsSize3
,
output
);
return
cudaPeekAtLastError
();
return
cudaPeekAtLastError
();
}
}
template
int32_t
embSkipLayerNormHFace
<
float
>(
cudaStream_t
,
template
int32_t
embSkipLayerNormHFace_2
<
float
>(
cudaStream_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
,
float
const
*
,
float
const
*
,
float
const
*
,
float
const
*
,
int32_t
,
int32_t
,
float
*
);
template
int32_t
embSkipLayerNormHFace_3
<
float
>(
cudaStream_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
**
,
int32_t
const
*
,
int32_t
const
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
,
float
const
*
,
float
const
*
,
float
const
*
,
float
const
*
,
float
const
*
,
int32_t
,
int32_t
,
int32_t
,
float
*
);
template
int32_t
embSkipLayerNormHFace_4
<
float
>(
cudaStream_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
,
float
const
*
,
float
const
*
,
float
const
*
,
float
const
*
,
float
const
*
,
float
const
*
,
float
const
*
,
float
**
,
float
const
*
,
int32_t
*
,
int32_t
,
int32_t
,
int32_t
,
int32_t
,
float
*
);
float
*
);
template
int32_t
embSkipLayerNormHFace
<
half
>(
cudaStream_t
,
template
int32_t
embSkipLayerNormHFace_2
<
half
>(
cudaStream_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
,
float
const
*
,
float
const
*
,
half
const
*
,
half
const
*
,
int32_t
,
int32_t
,
half
*
);
template
int32_t
embSkipLayerNormHFace_3
<
half
>(
cudaStream_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
**
,
int32_t
const
*
,
int32_t
const
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
,
float
const
*
,
float
const
*
,
float
const
*
,
float
const
*
,
half
**
,
half
const
*
,
int32_t
*
,
half
const
*
,
half
const
*
,
int32_t
,
int32_t
,
int32_t
,
half
*
);
half
*
);
template
int32_t
embSkipLayerNormHFace_4
<
half
>(
cudaStream_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
,
float
const
*
,
float
const
*
,
half
const
*
,
half
const
*
,
half
const
*
,
half
const
*
,
int32_t
,
int32_t
,
int32_t
,
int32_t
,
half
*
);
}
// namespace plugin
}
// namespace plugin
}
// namespace tensorrt
}
// namespace tensorrt
}
// namespace inference
}
// namespace inference
...
...
paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu
浏览文件 @
6512e087
...
@@ -30,121 +30,469 @@ namespace tensorrt {
...
@@ -30,121 +30,469 @@ namespace tensorrt {
namespace
plugin
{
namespace
plugin
{
template
<
typename
T
,
unsigned
TPB
>
template
<
typename
T
,
unsigned
TPB
>
__global__
void
embLayerNormKernelMTron
(
int32_t
ld
,
__global__
void
embLayerNormKernelMTron_2
(
int32_t
ld
,
int32_t
**
inputIds
,
int32_t
const
*
inputIds0
,
int32_t
const
nbLookupTables
,
int32_t
const
*
inputIds1
,
int32_t
nbLookupTables
,
float
const
*
beta
,
float
const
*
beta
,
float
const
*
gamma
,
float
const
*
gamma
,
T
**
mIdsEmbDev
,
T
const
*
mIdsEmbDev0
,
int32_t
*
IdsSize
,
T
const
*
mIdsEmbDev1
,
int32_t
IdsSize0
,
int32_t
IdsSize1
,
T
*
output
,
T
*
output
,
T
*
skip
)
{
T
*
skip
)
{
cub
::
Sum
pairSum
;
cub
::
Sum
pairSum
;
int32_t
const
s
=
blockIdx
.
x
;
int32_t
const
s
=
blockIdx
.
x
;
int32_t
const
b
=
blockIdx
.
y
;
int32_t
const
b
=
blockIdx
.
y
;
int32_t
*
cuSeqlens
=
inputIds
[
0
];
int32_t
const
sumS
=
inputIds0
[
b
];
int32_t
const
sumS
=
cuSeqlens
[
b
];
int32_t
const
s_b
=
inputIds0
[
b
+
1
]
-
sumS
;
int32_t
const
s_b
=
cuSeqlens
[
b
+
1
]
-
sumS
;
if
(
s
>=
s_b
)
{
if
(
s
>=
s_b
)
{
return
;
// This CTA has nothing to do
return
;
// This CTA has nothing to do
}
}
T
const
rld
=
T
(
1.
f
)
/
T
(
ld
);
T
const
rld
=
T
(
1.
f
)
/
T
(
ld
);
int32_t
cons
t
seqPos
=
sumS
+
s
;
const
int32_
t
seqPos
=
sumS
+
s
;
extern
__shared__
int32_t
word_id
[];
extern
__shared__
int32_t
word_id
[];
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
for
(
int
i
=
1
;
i
<
nbLookupTables
;
++
i
)
{
if
(
static_cast
<
int32_t
const
*>
(
inputIds1
)[
seqPos
]
<
0
||
if
(
static_cast
<
int32_t
const
*>
(
inputIds
[
i
])[
seqPos
]
<
0
||
static_cast
<
int32_t
const
*>
(
inputIds1
)[
seqPos
]
>=
IdsSize1
)
{
static_cast
<
int32_t
const
*>
(
inputIds
[
i
])[
seqPos
]
>=
IdsSize
[
i
])
{
printf
(
printf
(
"Error !!!!!!!!!!!!!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot
"
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup
"
"be lookup
table: ID < 0 or ID > max "
);
"
table: ID < 0 or ID > max "
);
return
;
return
;
}
else
{
}
else
{
word_id
[
i
-
1
]
=
static_cast
<
int32_t
const
*>
(
inputIds
[
i
]
)[
seqPos
];
word_id
[
0
]
=
static_cast
<
int32_t
const
*>
(
inputIds1
)[
seqPos
];
}
}
}
}
__syncthreads
();
// 2. load pos/tok/word embeddings and add them toghether
// offset into embeddings is given by wordId * hidden_size
const
int32_t
poffset
=
blockIdx
.
x
*
ld
;
const
int32_t
outOffset
=
seqPos
*
ld
;
// the output offset is given by b * (S*hidden_size) + s * hidden_size
kvp
<
T
>
threadData
(
0
,
0
);
for
(
int32_t
it
=
threadIdx
.
x
;
it
<
ld
;
it
+=
TPB
)
{
T
p
(
mIdsEmbDev0
[
poffset
+
it
]);
// pos id
T
val
=
p
;
const
int32_t
offset
=
word_id
[
0
]
*
ld
;
val
+=
mIdsEmbDev1
[
offset
+
it
];
output
[
outOffset
+
it
]
=
val
;
skip
[
outOffset
+
it
]
=
val
;
const
T
rldval
=
rld
*
val
;
threadData
=
pairSum
(
threadData
,
kvp
<
T
>
(
rldval
,
rldval
*
val
));
}
// 3. layer norm on the sum
layerNorm
<
T
,
T
,
float
,
TPB
>
(
threadData
,
ld
,
outOffset
,
beta
,
gamma
,
output
);
}
template
<
typename
T
,
unsigned
TPB
>
__global__
void
embLayerNormKernelMTron_3
(
int32_t
ld
,
int32_t
const
*
inputIds0
,
int32_t
const
*
inputIds1
,
int32_t
const
*
inputIds2
,
int32_t
nbLookupTables
,
float
const
*
beta
,
float
const
*
gamma
,
T
const
*
mIdsEmbDev0
,
T
const
*
mIdsEmbDev1
,
T
const
*
mIdsEmbDev2
,
int32_t
IdsSize0
,
int32_t
IdsSize1
,
int32_t
IdsSize2
,
T
*
output
,
T
*
skip
)
{
cub
::
Sum
pairSum
;
const
int32_t
s
=
blockIdx
.
x
;
const
int32_t
b
=
blockIdx
.
y
;
const
int32_t
sumS
=
inputIds0
[
b
];
const
int32_t
s_b
=
inputIds0
[
b
+
1
]
-
sumS
;
if
(
s
>=
s_b
)
{
return
;
// This CTA has nothing to do
}
const
T
rld
=
T
(
1.
f
)
/
T
(
ld
);
const
int32_t
seqPos
=
sumS
+
s
;
extern
__shared__
int32_t
word_id
[];
if
(
threadIdx
.
x
==
0
)
{
if
(
static_cast
<
int32_t
const
*>
(
inputIds1
)[
seqPos
]
<
0
||
static_cast
<
int32_t
const
*>
(
inputIds1
)[
seqPos
]
>=
IdsSize1
)
{
printf
(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max "
);
return
;
}
else
{
word_id
[
0
]
=
static_cast
<
int32_t
const
*>
(
inputIds1
)[
seqPos
];
}
if
(
static_cast
<
int32_t
const
*>
(
inputIds2
)[
seqPos
]
<
0
||
static_cast
<
int32_t
const
*>
(
inputIds2
)[
seqPos
]
>=
IdsSize2
)
{
printf
(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max "
);
return
;
}
else
{
word_id
[
1
]
=
static_cast
<
int32_t
const
*>
(
inputIds2
)[
seqPos
];
}
}
}
__syncthreads
();
__syncthreads
();
// 2. load pos/tok/word embeddings and add them toghether
// 2. load pos/tok/word embeddings and add them toghether
// offset into embeddings is given by wordId * hidden_size
// offset into embeddings is given by wordId * hidden_size
int32_t
cons
t
poffset
=
blockIdx
.
x
*
ld
;
const
int32_
t
poffset
=
blockIdx
.
x
*
ld
;
int32_t
cons
t
outOffset
=
seqPos
*
ld
;
const
int32_
t
outOffset
=
seqPos
*
ld
;
// the output offset is given by b * (S*hidden_size) + s * hidden_size
// the output offset is given by b * (S*hidden_size) + s * hidden_size
kvp
<
T
>
threadData
(
0
,
0
);
kvp
<
T
>
threadData
(
0
,
0
);
for
(
int32_t
it
=
threadIdx
.
x
;
it
<
ld
;
it
+=
TPB
)
{
for
(
int32_t
it
=
threadIdx
.
x
;
it
<
ld
;
it
+=
TPB
)
{
T
p
(
mIdsEmbDev
[
0
]
[
poffset
+
it
]);
// pos id
T
p
(
mIdsEmbDev
0
[
poffset
+
it
]);
// pos id
T
val
=
p
;
T
val
=
p
;
for
(
int
i
=
1
;
i
<
nbLookupTables
;
++
i
)
{
const
int32_t
offset0
=
word_id
[
0
]
*
ld
;
int32_t
const
offset
=
word_id
[
i
-
1
]
*
ld
;
val
+=
mIdsEmbDev1
[
offset0
+
it
];
val
+=
mIdsEmbDev
[
i
][
offset
+
it
];
const
int32_t
offset1
=
word_id
[
1
]
*
ld
;
val
+=
mIdsEmbDev2
[
offset1
+
it
];
output
[
outOffset
+
it
]
=
val
;
skip
[
outOffset
+
it
]
=
val
;
const
T
rldval
=
rld
*
val
;
threadData
=
pairSum
(
threadData
,
kvp
<
T
>
(
rldval
,
rldval
*
val
));
}
// 3. layer norm on the sum
layerNorm
<
T
,
T
,
float
,
TPB
>
(
threadData
,
ld
,
outOffset
,
beta
,
gamma
,
output
);
}
template
<
typename
T
,
unsigned
TPB
>
__global__
void
embLayerNormKernelMTron_4
(
int32_t
ld
,
int32_t
const
*
inputIds0
,
int32_t
const
*
inputIds1
,
int32_t
const
*
inputIds2
,
int32_t
const
*
inputIds3
,
int32_t
nbLookupTables
,
float
const
*
beta
,
float
const
*
gamma
,
T
const
*
mIdsEmbDev0
,
T
const
*
mIdsEmbDev1
,
T
const
*
mIdsEmbDev2
,
T
const
*
mIdsEmbDev3
,
int32_t
IdsSize0
,
int32_t
IdsSize1
,
int32_t
IdsSize2
,
int32_t
IdsSize3
,
T
*
output
,
T
*
skip
)
{
cub
::
Sum
pairSum
;
const
int32_t
s
=
blockIdx
.
x
;
const
int32_t
b
=
blockIdx
.
y
;
const
int32_t
sumS
=
inputIds0
[
b
];
const
int32_t
s_b
=
inputIds0
[
b
+
1
]
-
sumS
;
if
(
s
>=
s_b
)
{
return
;
// This CTA has nothing to do
}
const
T
rld
=
T
(
1.
f
)
/
T
(
ld
);
const
int32_t
seqPos
=
sumS
+
s
;
extern
__shared__
int32_t
word_id
[];
if
(
threadIdx
.
x
==
0
)
{
if
(
static_cast
<
int32_t
const
*>
(
inputIds1
)[
seqPos
]
<
0
||
static_cast
<
int32_t
const
*>
(
inputIds1
)[
seqPos
]
>=
IdsSize1
)
{
printf
(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max "
);
return
;
}
else
{
word_id
[
0
]
=
static_cast
<
int32_t
const
*>
(
inputIds1
)[
seqPos
];
}
if
(
static_cast
<
int32_t
const
*>
(
inputIds2
)[
seqPos
]
<
0
||
static_cast
<
int32_t
const
*>
(
inputIds2
)[
seqPos
]
>=
IdsSize2
)
{
printf
(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max "
);
return
;
}
else
{
word_id
[
1
]
=
static_cast
<
int32_t
const
*>
(
inputIds2
)[
seqPos
];
}
if
(
static_cast
<
int32_t
const
*>
(
inputIds3
)[
seqPos
]
<
0
||
static_cast
<
int32_t
const
*>
(
inputIds3
)[
seqPos
]
>=
IdsSize3
)
{
printf
(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max "
);
return
;
}
else
{
word_id
[
2
]
=
static_cast
<
int32_t
const
*>
(
inputIds3
)[
seqPos
];
}
}
}
__syncthreads
();
// 2. load pos/tok/word embeddings and add them toghether
// offset into embeddings is given by wordId * hidden_size
const
int32_t
poffset
=
blockIdx
.
x
*
ld
;
const
int32_t
outOffset
=
seqPos
*
ld
;
// the output offset is given by b * (S*hidden_size) + s * hidden_size
kvp
<
T
>
threadData
(
0
,
0
);
for
(
int32_t
it
=
threadIdx
.
x
;
it
<
ld
;
it
+=
TPB
)
{
T
p
(
mIdsEmbDev0
[
poffset
+
it
]);
// pos id
T
val
=
p
;
const
int32_t
offset0
=
word_id
[
0
]
*
ld
;
val
+=
mIdsEmbDev1
[
offset0
+
it
];
const
int32_t
offset1
=
word_id
[
1
]
*
ld
;
val
+=
mIdsEmbDev2
[
offset1
+
it
];
const
int32_t
offset2
=
word_id
[
2
]
*
ld
;
val
+=
mIdsEmbDev3
[
offset2
+
it
];
output
[
outOffset
+
it
]
=
val
;
output
[
outOffset
+
it
]
=
val
;
skip
[
outOffset
+
it
]
=
val
;
skip
[
outOffset
+
it
]
=
val
;
T
const
rldval
=
rld
*
val
;
const
T
rldval
=
rld
*
val
;
threadData
=
pairSum
(
threadData
,
kvp
<
T
>
(
rldval
,
rldval
*
val
));
threadData
=
pairSum
(
threadData
,
kvp
<
T
>
(
rldval
,
rldval
*
val
));
}
}
// 3. layer norm on the sum
// 3. layer norm on the sum
layerNorm
<
T
,
T
,
float
,
TPB
>
(
threadData
,
ld
,
outOffset
,
beta
,
gamma
,
output
);
layerNorm
<
T
,
T
,
float
,
TPB
>
(
threadData
,
ld
,
outOffset
,
beta
,
gamma
,
output
);
}
}
template
<
typename
T
>
int32_t
embSkipLayerNormMTron_2
(
cudaStream_t
stream
,
int32_t
ld
,
int32_t
B
,
int32_t
S
,
int32_t
const
*
inputIds0
,
int32_t
const
*
inputIds1
,
int32_t
nbLookupTables
,
float
const
*
beta
,
float
const
*
gamma
,
T
const
*
mIdsEmbDev0
,
T
const
*
mIdsEmbDev1
,
int32_t
IdsSize0
,
int32_t
IdsSize1
,
T
*
output
,
T
*
skip
)
{
constexpr
int32_t
tpb
=
256
;
dim3
const
grid
(
S
,
B
,
1
);
dim3
const
block
(
tpb
,
1
,
1
);
size_t
cache_size
=
sizeof
(
int32_t
)
*
(
nbLookupTables
-
1
);
embLayerNormKernelMTron_2
<
T
,
tpb
>
<<<
grid
,
block
,
cache_size
,
stream
>>>
(
ld
,
inputIds0
,
inputIds1
,
nbLookupTables
,
beta
,
gamma
,
mIdsEmbDev0
,
mIdsEmbDev1
,
IdsSize0
,
IdsSize1
,
output
,
skip
);
return
cudaPeekAtLastError
();
}
template
<
typename
T
>
int32_t
embSkipLayerNormMTron_3
(
cudaStream_t
stream
,
int32_t
ld
,
int32_t
B
,
int32_t
S
,
int32_t
const
*
inputIds0
,
int32_t
const
*
inputIds1
,
int32_t
const
*
inputIds2
,
int32_t
nbLookupTables
,
float
const
*
beta
,
float
const
*
gamma
,
T
const
*
mIdsEmbDev0
,
T
const
*
mIdsEmbDev1
,
T
const
*
mIdsEmbDev2
,
int32_t
IdsSize0
,
int32_t
IdsSize1
,
int32_t
IdsSize2
,
T
*
output
,
T
*
skip
)
{
constexpr
int32_t
tpb
=
256
;
dim3
const
grid
(
S
,
B
,
1
);
dim3
const
block
(
tpb
,
1
,
1
);
size_t
cache_size
=
sizeof
(
int32_t
)
*
(
nbLookupTables
-
1
);
embLayerNormKernelMTron_3
<
T
,
tpb
>
<<<
grid
,
block
,
cache_size
,
stream
>>>
(
ld
,
inputIds0
,
inputIds1
,
inputIds2
,
nbLookupTables
,
beta
,
gamma
,
mIdsEmbDev0
,
mIdsEmbDev1
,
mIdsEmbDev2
,
IdsSize0
,
IdsSize1
,
IdsSize2
,
output
,
skip
);
return
cudaPeekAtLastError
();
}
template
<
typename
T
>
template
<
typename
T
>
int32_t
embSkipLayerNormMTron
(
cudaStream_t
stream
,
int32_t
embSkipLayerNormMTron
_4
(
cudaStream_t
stream
,
int32_t
ld
,
int32_t
ld
,
int32_t
B
,
int32_t
B
,
int32_t
S
,
int32_t
S
,
int32_t
**
inputIds
,
int32_t
const
*
inputIds0
,
int32_t
const
nbLookupTables
,
int32_t
const
*
inputIds1
,
int32_t
const
*
inputIds2
,
int32_t
const
*
inputIds3
,
int32_t
nbLookupTables
,
float
const
*
beta
,
float
const
*
beta
,
float
const
*
gamma
,
float
const
*
gamma
,
T
**
mIdsEmbDev
,
T
const
*
mIdsEmbDev0
,
int32_t
*
IdsSize
,
T
const
*
mIdsEmbDev1
,
T
const
*
mIdsEmbDev2
,
T
const
*
mIdsEmbDev3
,
int32_t
IdsSize0
,
int32_t
IdsSize1
,
int32_t
IdsSize2
,
int32_t
IdsSize3
,
T
*
output
,
T
*
output
,
T
*
skip
)
{
T
*
skip
)
{
constexpr
int32_t
tpb
=
256
;
constexpr
int32_t
tpb
=
256
;
dim3
const
grid
(
S
,
B
,
1
);
dim3
const
grid
(
S
,
B
,
1
);
dim3
const
block
(
tpb
,
1
,
1
);
dim3
const
block
(
tpb
,
1
,
1
);
size_t
cache_size
=
sizeof
(
int32_t
)
*
(
nbLookupTables
-
1
);
size_t
cache_size
=
sizeof
(
int32_t
)
*
(
nbLookupTables
-
1
);
embLayerNormKernelMTron
<
T
,
tpb
>
embLayerNormKernelMTron
_4
<
T
,
tpb
>
<<<
grid
,
block
,
cache_size
,
stream
>>>
(
ld
,
<<<
grid
,
block
,
cache_size
,
stream
>>>
(
ld
,
inputIds
,
inputIds0
,
inputIds1
,
inputIds2
,
inputIds3
,
nbLookupTables
,
nbLookupTables
,
beta
,
beta
,
gamma
,
gamma
,
mIdsEmbDev
,
mIdsEmbDev0
,
IdsSize
,
mIdsEmbDev1
,
mIdsEmbDev2
,
mIdsEmbDev3
,
IdsSize0
,
IdsSize1
,
IdsSize2
,
IdsSize3
,
output
,
output
,
skip
);
skip
);
return
cudaPeekAtLastError
();
return
cudaPeekAtLastError
();
}
}
template
int32_t
embSkipLayerNormMTron
<
float
>(
cudaStream_t
,
template
int32_t
embSkipLayerNormMTron
_2
<
float
>(
cudaStream_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
**
,
int32_t
const
*
,
int32_t
const
,
int32_t
const
*
,
int32_t
,
float
const
*
,
float
const
*
,
float
const
*
,
float
const
*
,
float
const
*
,
float
const
*
,
float
**
,
int32_t
,
int32_t
*
,
int32_t
,
float
*
,
float
*
,
float
*
);
float
*
);
template
int32_t
embSkipLayerNormMTron
<
half
>(
cudaStream_t
,
template
int32_t
embSkipLayerNormMTron_3
<
float
>(
cudaStream_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
,
int32_t
,
int32_t
**
,
int32_t
const
,
float
const
*
,
float
const
*
,
float
const
*
,
float
const
*
,
half
**
,
float
const
*
,
int32_t
*
,
float
const
*
,
float
const
*
,
int32_t
,
int32_t
,
int32_t
,
float
*
,
float
*
);
template
int32_t
embSkipLayerNormMTron_4
<
float
>(
cudaStream_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
,
float
const
*
,
float
const
*
,
float
const
*
,
float
const
*
,
float
const
*
,
float
const
*
,
int32_t
,
int32_t
,
int32_t
,
int32_t
,
float
*
,
float
*
);
template
int32_t
embSkipLayerNormMTron_2
<
half
>(
cudaStream_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
,
float
const
*
,
float
const
*
,
half
const
*
,
half
const
*
,
int32_t
,
int32_t
,
half
*
,
half
*
);
template
int32_t
embSkipLayerNormMTron_3
<
half
>(
cudaStream_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
,
float
const
*
,
float
const
*
,
half
const
*
,
half
const
*
,
half
const
*
,
int32_t
,
int32_t
,
int32_t
,
half
*
,
half
*
);
template
int32_t
embSkipLayerNormMTron_4
<
half
>(
cudaStream_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
,
float
const
*
,
float
const
*
,
half
const
*
,
half
const
*
,
half
const
*
,
half
const
*
,
int32_t
,
int32_t
,
int32_t
,
int32_t
,
half
*
,
half
*
,
half
*
);
half
*
);
...
...
paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu
浏览文件 @
6512e087
...
@@ -37,8 +37,8 @@ constexpr size_t xmmasM384 = 24;
...
@@ -37,8 +37,8 @@ constexpr size_t xmmasM384 = 24;
constexpr
size_t
packedMaskSize128
=
xmmasM128
*
threadsPerCta128
;
constexpr
size_t
packedMaskSize128
=
xmmasM128
*
threadsPerCta128
;
constexpr
size_t
packedMaskSize256
=
xmmasM256
*
threadsPerCta256
;
constexpr
size_t
packedMaskSize256
=
xmmasM256
*
threadsPerCta256
;
constexpr
size_t
packedMaskSize384
=
xmmasM384
*
threadsPerCta384
;
constexpr
size_t
packedMaskSize384
=
xmmasM384
*
threadsPerCta384
;
char
const
*
EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE
{
"
2
"
};
char
const
*
EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE
{
"
1
"
};
char
const
*
EMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON
{
"
3
"
};
char
const
*
EMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON
{
"
2
"
};
char
const
*
EMB_LAYER_NORM_VAR_SEQLEN_NAME
{
"ManyEmbLayerNormPluginDynamic"
};
char
const
*
EMB_LAYER_NORM_VAR_SEQLEN_NAME
{
"ManyEmbLayerNormPluginDynamic"
};
// Static class fields initialization
// Static class fields initialization
nvinfer1
::
PluginFieldCollection
EmbLayerNormVarSeqlenPluginBaseCreator
::
mFC
{};
nvinfer1
::
PluginFieldCollection
EmbLayerNormVarSeqlenPluginBaseCreator
::
mFC
{};
...
@@ -74,7 +74,7 @@ EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase(
...
@@ -74,7 +74,7 @@ EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase(
tem_weight
.
values
,
tem_weight
.
values
,
getWeightsSize
(
tem_weight
,
mType
),
getWeightsSize
(
tem_weight
,
mType
),
cudaMemcpyHostToDevice
));
cudaMemcpyHostToDevice
));
mIdsEmb
Dev
.
push_back
(
cudaMem
);
mIdsEmb
Ptrs
.
push_back
(
cudaMem
);
}
}
}
}
...
@@ -83,7 +83,7 @@ EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase(
...
@@ -83,7 +83,7 @@ EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase(
:
mLayerName
(
name
),
:
mLayerName
(
name
),
mGammaDev
(
nullptr
),
mGammaDev
(
nullptr
),
mBetaDev
(
nullptr
),
mBetaDev
(
nullptr
),
mIdsEmb
Dev
{},
mIdsEmb
Ptrs
{},
mIdsEmb_
{}
{
mIdsEmb_
{}
{
// Deserialize in the same order as serialization
// Deserialize in the same order as serialization
deserialize_value
(
&
data
,
&
length
,
&
mType
);
deserialize_value
(
&
data
,
&
length
,
&
mType
);
...
@@ -141,8 +141,8 @@ EmbLayerNormVarSeqlenPluginMTron::EmbLayerNormVarSeqlenPluginMTron(
...
@@ -141,8 +141,8 @@ EmbLayerNormVarSeqlenPluginMTron::EmbLayerNormVarSeqlenPluginMTron(
// IPluginV2DynamicExt Methods
// IPluginV2DynamicExt Methods
nvinfer1
::
IPluginV2DynamicExt
*
EmbLayerNormVarSeqlenPluginHFace
::
clone
()
nvinfer1
::
IPluginV2DynamicExt
*
EmbLayerNormVarSeqlenPluginHFace
::
clone
()
const
noexcept
{
const
noexcept
{
TRANSFORMER_DEBUG_MSG
(
"EmbLayerNormVarSeqlenPlugin
MTron
clone"
);
TRANSFORMER_DEBUG_MSG
(
"EmbLayerNormVarSeqlenPlugin
HFace
clone"
);
auto
p
=
new
EmbLayerNormVarSeqlenPlugin
MTron
(
auto
p
=
new
EmbLayerNormVarSeqlenPlugin
HFace
(
mLayerName
,
mType
,
mBeta
,
mGamma
,
mIdsEmb_
);
mLayerName
,
mType
,
mBeta
,
mGamma
,
mIdsEmb_
);
p
->
setPluginNamespace
(
mNamespace
.
c_str
());
p
->
setPluginNamespace
(
mNamespace
.
c_str
());
return
p
;
return
p
;
...
@@ -333,7 +333,7 @@ int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue(
...
@@ -333,7 +333,7 @@ int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue(
void
*
const
*
outputs
,
void
*
const
*
outputs
,
void
*
workspace
,
void
*
workspace
,
cudaStream_t
stream
)
noexcept
{
cudaStream_t
stream
)
noexcept
{
int32_t
const
batchSize
=
inputDesc
[
0
].
dims
.
d
[
0
]
-
1
;
int32_t
batchSize
=
inputDesc
[
0
].
dims
.
d
[
0
]
-
1
;
// read out the maximum sequence length from the dummy input
// read out the maximum sequence length from the dummy input
int32_t
const
maxSeqlen
=
inputDesc
[
nbLookupTables_
].
dims
.
d
[
0
];
int32_t
const
maxSeqlen
=
inputDesc
[
nbLookupTables_
].
dims
.
d
[
0
];
int32_t
S
=
384
;
int32_t
S
=
384
;
...
@@ -346,60 +346,132 @@ int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue(
...
@@ -346,60 +346,132 @@ int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue(
}
}
const
float
*
beta
=
mBetaDev
.
get
();
const
float
*
beta
=
mBetaDev
.
get
();
const
float
*
gamma
=
mGammaDev
.
get
();
const
float
*
gamma
=
mGammaDev
.
get
();
int32_t
**
tem_inputs_ptr_dev
;
cudaMalloc
(
reinterpret_cast
<
void
**>
(
&
tem_inputs_ptr_dev
),
sizeof
(
void
*
)
*
nbLookupTables_
);
cudaMemcpy
(
tem_inputs_ptr_dev
,
inputs
,
sizeof
(
void
*
)
*
nbLookupTables_
,
cudaMemcpyHostToDevice
);
int32_t
*
mIdsVocabSize_dev
;
cudaMalloc
(
reinterpret_cast
<
void
**>
(
&
mIdsVocabSize_dev
),
sizeof
(
int32_t
)
*
mIdsVocabSize
.
size
());
cudaMemcpy
(
mIdsVocabSize_dev
,
&
(
mIdsVocabSize
[
0
]),
sizeof
(
int32_t
)
*
mIdsVocabSize
.
size
(),
cudaMemcpyHostToDevice
);
if
(
mType
==
nvinfer1
::
DataType
::
kFLOAT
)
{
if
(
mType
==
nvinfer1
::
DataType
::
kFLOAT
)
{
auto
output
=
static_cast
<
float
*>
(
outputs
[
0
]);
auto
output
=
static_cast
<
float
*>
(
outputs
[
0
]);
float
**
mIdsEmbDev_float
;
if
(
nbLookupTables_
==
2
)
{
cudaMalloc
(
reinterpret_cast
<
void
**>
(
&
mIdsEmbDev_float
),
return
embSkipLayerNormHFace_2
<
float
>
(
sizeof
(
void
*
)
*
nbLookupTables_
);
stream
,
cudaMemcpy
(
mIdsEmbDev_float
,
static_cast
<
int32_t
>
(
mLd
),
&
(
mIdsEmbDev
[
0
]),
batchSize
,
sizeof
(
void
*
)
*
nbLookupTables_
,
S
,
cudaMemcpyHostToDevice
);
static_cast
<
int32_t
const
*>
(
inputs
[
0
]),
return
embSkipLayerNormHFace
<
float
>
(
stream
,
static_cast
<
int32_t
const
*>
(
inputs
[
1
]),
nbLookupTables_
,
beta
,
gamma
,
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
0
]),
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
1
]),
mIdsVocabSize
[
0
],
mIdsVocabSize
[
1
],
output
);
}
else
if
(
nbLookupTables_
==
3
)
{
return
embSkipLayerNormHFace_3
<
float
>
(
stream
,
static_cast
<
int32_t
>
(
mLd
),
static_cast
<
int32_t
>
(
mLd
),
batchSize
,
batchSize
,
S
,
S
,
tem_inputs_ptr_dev
,
static_cast
<
int32_t
const
*>
(
inputs
[
0
]),
static_cast
<
int32_t
const
*>
(
inputs
[
1
]),
static_cast
<
int32_t
const
*>
(
inputs
[
2
]),
nbLookupTables_
,
nbLookupTables_
,
beta
,
beta
,
gamma
,
gamma
,
mIdsEmbDev_float
,
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
0
]),
mIdsVocabSize_dev
,
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
1
]),
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
2
]),
mIdsVocabSize
[
0
],
mIdsVocabSize
[
1
],
mIdsVocabSize
[
2
],
output
);
output
);
}
else
if
(
nbLookupTables_
==
4
)
{
return
embSkipLayerNormHFace_4
<
float
>
(
stream
,
static_cast
<
int32_t
>
(
mLd
),
batchSize
,
S
,
static_cast
<
int32_t
const
*>
(
inputs
[
0
]),
static_cast
<
int32_t
const
*>
(
inputs
[
1
]),
static_cast
<
int32_t
const
*>
(
inputs
[
2
]),
static_cast
<
int32_t
const
*>
(
inputs
[
3
]),
nbLookupTables_
,
beta
,
gamma
,
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
0
]),
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
1
]),
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
2
]),
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
3
]),
mIdsVocabSize
[
0
],
mIdsVocabSize
[
1
],
mIdsVocabSize
[
2
],
mIdsVocabSize
[
3
],
output
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Only support 2,3,4 lookup_tables fused "
));
}
}
else
if
(
mType
==
nvinfer1
::
DataType
::
kHALF
)
{
}
else
if
(
mType
==
nvinfer1
::
DataType
::
kHALF
)
{
auto
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
auto
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
half
**
mIdsEmbDev_half
;
if
(
nbLookupTables_
==
2
)
{
cudaMalloc
(
reinterpret_cast
<
void
**>
(
&
mIdsEmbDev_half
),
return
embSkipLayerNormHFace_2
<
half
>
(
sizeof
(
void
*
)
*
nbLookupTables_
);
stream
,
cudaMemcpy
(
mIdsEmbDev_half
,
&
(
mIdsEmbDev
[
0
]),
sizeof
(
void
*
)
*
nbLookupTables_
,
cudaMemcpyHostToDevice
);
return
embSkipLayerNormHFace
<
half
>
(
stream
,
static_cast
<
int32_t
>
(
mLd
),
static_cast
<
int32_t
>
(
mLd
),
batchSize
,
batchSize
,
S
,
S
,
tem_inputs_ptr_dev
,
static_cast
<
int32_t
const
*>
(
inputs
[
0
]),
static_cast
<
int32_t
const
*>
(
inputs
[
1
]),
nbLookupTables_
,
nbLookupTables_
,
beta
,
beta
,
gamma
,
gamma
,
mIdsEmbDev_half
,
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
0
]),
mIdsVocabSize_dev
,
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
1
]),
mIdsVocabSize
[
0
],
mIdsVocabSize
[
1
],
output
);
}
else
if
(
nbLookupTables_
==
3
)
{
return
embSkipLayerNormHFace_3
<
half
>
(
stream
,
static_cast
<
int32_t
>
(
mLd
),
batchSize
,
S
,
static_cast
<
int32_t
const
*>
(
inputs
[
0
]),
static_cast
<
int32_t
const
*>
(
inputs
[
1
]),
static_cast
<
int32_t
const
*>
(
inputs
[
2
]),
nbLookupTables_
,
beta
,
gamma
,
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
0
]),
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
1
]),
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
2
]),
mIdsVocabSize
[
0
],
mIdsVocabSize
[
1
],
mIdsVocabSize
[
2
],
output
);
}
else
if
(
nbLookupTables_
==
4
)
{
return
embSkipLayerNormHFace_4
<
half
>
(
stream
,
static_cast
<
int32_t
>
(
mLd
),
batchSize
,
S
,
static_cast
<
int32_t
const
*>
(
inputs
[
0
]),
static_cast
<
int32_t
const
*>
(
inputs
[
1
]),
static_cast
<
int32_t
const
*>
(
inputs
[
2
]),
static_cast
<
int32_t
const
*>
(
inputs
[
3
]),
nbLookupTables_
,
beta
,
gamma
,
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
0
]),
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
1
]),
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
2
]),
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
3
]),
mIdsVocabSize
[
0
],
mIdsVocabSize
[
1
],
mIdsVocabSize
[
2
],
mIdsVocabSize
[
3
],
output
);
output
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Only support 2,3,4 lookup_tables fused "
));
}
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Unsupported type error, expected [kHALF,kFLOAT]"
));
"Unsupported type error, expected [kHALF,kFLOAT]"
));
...
@@ -414,7 +486,7 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue(
...
@@ -414,7 +486,7 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue(
void
*
const
*
outputs
,
void
*
const
*
outputs
,
void
*
workspace
,
void
*
workspace
,
cudaStream_t
stream
)
noexcept
{
cudaStream_t
stream
)
noexcept
{
int32_t
const
batchSize
=
inputDesc
[
0
].
dims
.
d
[
0
]
-
1
;
int32_t
batchSize
=
inputDesc
[
0
].
dims
.
d
[
0
]
-
1
;
// read out the maximum sequence length from the dummy input
// read out the maximum sequence length from the dummy input
int32_t
const
maxSeqlen
=
inputDesc
[
nbLookupTables_
].
dims
.
d
[
0
];
int32_t
const
maxSeqlen
=
inputDesc
[
nbLookupTables_
].
dims
.
d
[
0
];
int32_t
S
=
384
;
int32_t
S
=
384
;
...
@@ -427,64 +499,141 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue(
...
@@ -427,64 +499,141 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue(
}
}
const
float
*
beta
=
mBetaDev
.
get
();
const
float
*
beta
=
mBetaDev
.
get
();
const
float
*
gamma
=
mGammaDev
.
get
();
const
float
*
gamma
=
mGammaDev
.
get
();
int32_t
**
tem_inputs_ptr_dev
;
cudaMalloc
(
reinterpret_cast
<
void
**>
(
&
tem_inputs_ptr_dev
),
sizeof
(
void
*
)
*
nbLookupTables_
);
cudaMemcpy
(
tem_inputs_ptr_dev
,
inputs
,
sizeof
(
void
*
)
*
nbLookupTables_
,
cudaMemcpyHostToDevice
);
int32_t
*
mIdsVocabSize_dev
;
cudaMalloc
(
reinterpret_cast
<
void
**>
(
&
mIdsVocabSize_dev
),
sizeof
(
int32_t
)
*
mIdsVocabSize
.
size
());
cudaMemcpy
(
mIdsVocabSize_dev
,
&
(
mIdsVocabSize
[
0
]),
sizeof
(
int32_t
)
*
mIdsVocabSize
.
size
(),
cudaMemcpyHostToDevice
);
if
(
mType
==
nvinfer1
::
DataType
::
kFLOAT
)
{
if
(
mType
==
nvinfer1
::
DataType
::
kFLOAT
)
{
auto
output
=
static_cast
<
float
*>
(
outputs
[
0
]);
auto
output
=
static_cast
<
float
*>
(
outputs
[
0
]);
auto
skip
=
static_cast
<
float
*>
(
outputs
[
1
]);
auto
skip
=
static_cast
<
float
*>
(
outputs
[
1
]);
float
**
mIdsEmbDev_float
;
if
(
nbLookupTables_
==
2
)
{
cudaMalloc
(
reinterpret_cast
<
void
**>
(
&
mIdsEmbDev_float
),
return
embSkipLayerNormMTron_2
<
float
>
(
sizeof
(
void
*
)
*
nbLookupTables_
);
stream
,
cudaMemcpy
(
mIdsEmbDev_float
,
static_cast
<
int32_t
>
(
mLd
),
&
(
mIdsEmbDev
[
0
]),
batchSize
,
sizeof
(
void
*
)
*
nbLookupTables_
,
S
,
cudaMemcpyHostToDevice
);
static_cast
<
int32_t
const
*>
(
inputs
[
0
]),
return
embSkipLayerNormMTron
<
float
>
(
stream
,
static_cast
<
int32_t
const
*>
(
inputs
[
1
]),
nbLookupTables_
,
beta
,
gamma
,
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
0
]),
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
1
]),
mIdsVocabSize
[
0
],
mIdsVocabSize
[
1
],
output
,
skip
);
}
else
if
(
nbLookupTables_
==
3
)
{
return
embSkipLayerNormMTron_3
<
float
>
(
stream
,
static_cast
<
int32_t
>
(
mLd
),
static_cast
<
int32_t
>
(
mLd
),
batchSize
,
batchSize
,
S
,
S
,
tem_inputs_ptr_dev
,
static_cast
<
int32_t
const
*>
(
inputs
[
0
]),
static_cast
<
int32_t
const
*>
(
inputs
[
1
]),
static_cast
<
int32_t
const
*>
(
inputs
[
2
]),
nbLookupTables_
,
nbLookupTables_
,
beta
,
beta
,
gamma
,
gamma
,
mIdsEmbDev_float
,
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
0
]),
mIdsVocabSize_dev
,
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
1
]),
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
2
]),
mIdsVocabSize
[
0
],
mIdsVocabSize
[
1
],
mIdsVocabSize
[
2
],
output
,
output
,
skip
);
skip
);
}
else
if
(
nbLookupTables_
==
4
)
{
return
embSkipLayerNormMTron_4
<
float
>
(
stream
,
static_cast
<
int32_t
>
(
mLd
),
batchSize
,
S
,
static_cast
<
int32_t
const
*>
(
inputs
[
0
]),
static_cast
<
int32_t
const
*>
(
inputs
[
1
]),
static_cast
<
int32_t
const
*>
(
inputs
[
2
]),
static_cast
<
int32_t
const
*>
(
inputs
[
3
]),
nbLookupTables_
,
beta
,
gamma
,
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
0
]),
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
1
]),
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
2
]),
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
3
]),
mIdsVocabSize
[
0
],
mIdsVocabSize
[
1
],
mIdsVocabSize
[
2
],
mIdsVocabSize
[
3
],
output
,
skip
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Only support 2,3,4 lookup_tables fused "
));
}
}
else
if
(
mType
==
nvinfer1
::
DataType
::
kHALF
)
{
}
else
if
(
mType
==
nvinfer1
::
DataType
::
kHALF
)
{
auto
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
auto
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
auto
skip
=
static_cast
<
half
*>
(
outputs
[
1
]);
auto
skip
=
static_cast
<
half
*>
(
outputs
[
1
]);
half
**
mIdsEmbDev_half
;
if
(
nbLookupTables_
==
2
)
{
cudaMalloc
(
reinterpret_cast
<
void
**>
(
&
mIdsEmbDev_half
),
return
embSkipLayerNormMTron_2
<
half
>
(
sizeof
(
void
*
)
*
nbLookupTables_
);
stream
,
cudaMemcpy
(
mIdsEmbDev_half
,
static_cast
<
int32_t
>
(
mLd
),
&
(
mIdsEmbDev
[
0
]),
batchSize
,
sizeof
(
void
*
)
*
nbLookupTables_
,
S
,
cudaMemcpyHostToDevice
);
static_cast
<
int32_t
const
*>
(
inputs
[
0
]),
return
embSkipLayerNormMTron
<
half
>
(
stream
,
static_cast
<
int32_t
const
*>
(
inputs
[
1
]),
nbLookupTables_
,
beta
,
gamma
,
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
0
]),
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
1
]),
mIdsVocabSize
[
0
],
mIdsVocabSize
[
1
],
output
,
skip
);
}
else
if
(
nbLookupTables_
==
3
)
{
return
embSkipLayerNormMTron_3
<
half
>
(
stream
,
static_cast
<
int32_t
>
(
mLd
),
static_cast
<
int32_t
>
(
mLd
),
batchSize
,
batchSize
,
S
,
S
,
tem_inputs_ptr_dev
,
static_cast
<
int32_t
const
*>
(
inputs
[
0
]),
static_cast
<
int32_t
const
*>
(
inputs
[
1
]),
static_cast
<
int32_t
const
*>
(
inputs
[
2
]),
nbLookupTables_
,
nbLookupTables_
,
beta
,
beta
,
gamma
,
gamma
,
mIdsEmbDev_half
,
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
0
]),
mIdsVocabSize_dev
,
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
1
]),
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
2
]),
mIdsVocabSize
[
0
],
mIdsVocabSize
[
1
],
mIdsVocabSize
[
2
],
output
,
skip
);
}
else
if
(
nbLookupTables_
==
4
)
{
return
embSkipLayerNormMTron_4
<
half
>
(
stream
,
static_cast
<
int32_t
>
(
mLd
),
batchSize
,
S
,
static_cast
<
int32_t
const
*>
(
inputs
[
0
]),
static_cast
<
int32_t
const
*>
(
inputs
[
1
]),
static_cast
<
int32_t
const
*>
(
inputs
[
2
]),
static_cast
<
int32_t
const
*>
(
inputs
[
3
]),
nbLookupTables_
,
beta
,
gamma
,
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
0
]),
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
1
]),
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
2
]),
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
3
]),
mIdsVocabSize
[
0
],
mIdsVocabSize
[
1
],
mIdsVocabSize
[
2
],
mIdsVocabSize
[
3
],
output
,
output
,
skip
);
skip
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Only support 2,3,4 lookup_tables fused "
));
}
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Unsupported type error, expected [kHALF,kFLOAT]"
));
"Unsupported type error, expected [kHALF,kFLOAT]"
));
...
@@ -566,9 +715,9 @@ void EmbLayerNormVarSeqlenPluginBase::serialize(void* buffer) const noexcept {
...
@@ -566,9 +715,9 @@ void EmbLayerNormVarSeqlenPluginBase::serialize(void* buffer) const noexcept {
size_t
const
wordSize
=
getElementSize
(
mType
);
size_t
const
wordSize
=
getElementSize
(
mType
);
serFromDev
(
&
d
,
mBetaDev
.
get
(),
mLd
);
serFromDev
(
&
d
,
mBetaDev
.
get
(),
mLd
);
serFromDev
(
&
d
,
mGammaDev
.
get
(),
mLd
);
serFromDev
(
&
d
,
mGammaDev
.
get
(),
mLd
);
for
(
size_t
i
=
0
;
i
<
mIdsEmb
Dev
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
mIdsEmb
Ptrs
.
size
();
++
i
)
{
serFromDev
(
&
d
,
serFromDev
(
&
d
,
static_cast
<
char
*>
(
mIdsEmb
Dev
[
i
]),
static_cast
<
char
*>
(
mIdsEmb
Ptrs
[
i
]),
mLd
*
mIdsVocabSize
[
i
]
*
wordSize
);
mLd
*
mIdsVocabSize
[
i
]
*
wordSize
);
}
}
}
}
...
@@ -577,8 +726,8 @@ void EmbLayerNormVarSeqlenPluginBase::destroy() noexcept {
...
@@ -577,8 +726,8 @@ void EmbLayerNormVarSeqlenPluginBase::destroy() noexcept {
// This gets called when the network containing plugin is destroyed
// This gets called when the network containing plugin is destroyed
mBetaDev
.
reset
(
nullptr
);
mBetaDev
.
reset
(
nullptr
);
mGammaDev
.
reset
(
nullptr
);
mGammaDev
.
reset
(
nullptr
);
for
(
size_t
i
=
0
;
i
<
mIdsEmb
Dev
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
mIdsEmb
Ptrs
.
size
();
++
i
)
{
cudaFree
(
mIdsEmb
Dev
[
i
]);
cudaFree
(
mIdsEmb
Ptrs
[
i
]);
}
}
delete
this
;
delete
this
;
}
}
...
@@ -680,7 +829,6 @@ nvinfer1::IPluginV2* EmbLayerNormVarSeqlenPluginHFaceCreator::createPlugin(
...
@@ -680,7 +829,6 @@ nvinfer1::IPluginV2* EmbLayerNormVarSeqlenPluginHFaceCreator::createPlugin(
beta
,
beta
,
gamma
,
gamma
,
IdsEmb
);
IdsEmb
);
return
p
;
return
p
;
}
}
...
...
paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h
浏览文件 @
6512e087
...
@@ -31,32 +31,121 @@ namespace tensorrt {
...
@@ -31,32 +31,121 @@ namespace tensorrt {
namespace
plugin
{
namespace
plugin
{
template
<
typename
T
>
template
<
typename
T
>
int32_t
embSkipLayerNormHFace
(
cudaStream_t
stream
,
int32_t
embSkipLayerNormHFace_2
(
cudaStream_t
,
int32_t
ld
,
int32_t
,
int32_t
B
,
int32_t
,
int32_t
S
,
int32_t
,
int32_t
**
inputIds
,
int32_t
const
*
,
int32_t
const
nbLookupTables
,
int32_t
const
*
,
float
const
*
beta
,
int32_t
,
float
const
*
gamma
,
float
const
*
,
T
**
idsEmb
,
float
const
*
,
int32_t
*
,
T
const
*
,
T
*
output
);
T
const
*
,
int32_t
,
int32_t
,
T
*
);
template
<
typename
T
>
template
<
typename
T
>
int32_t
embSkipLayerNormMTron
(
cudaStream_t
stream
,
int32_t
embSkipLayerNormHFace_3
(
cudaStream_t
,
int32_t
ld
,
int32_t
,
int32_t
B
,
int32_t
,
int32_t
S
,
int32_t
,
int32_t
**
inputIds
,
int32_t
const
*
,
int32_t
const
nbLookupTables
,
int32_t
const
*
,
float
const
*
beta
,
int32_t
const
*
,
float
const
*
gamma
,
int32_t
,
T
**
idsEmb
,
float
const
*
,
int32_t
*
,
float
const
*
,
T
*
output
,
T
const
*
,
T
*
skip
);
T
const
*
,
T
const
*
,
int32_t
,
int32_t
,
int32_t
,
T
*
);
template
<
typename
T
>
int32_t
embSkipLayerNormHFace_4
(
cudaStream_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
,
float
const
*
,
float
const
*
,
T
const
*
,
T
const
*
,
T
const
*
,
T
const
*
,
int32_t
,
int32_t
,
int32_t
,
int32_t
,
T
*
);
template
<
typename
T
>
int32_t
embSkipLayerNormMTron_2
(
cudaStream_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
,
float
const
*
,
float
const
*
,
T
const
*
,
T
const
*
,
int32_t
,
int32_t
,
T
*
,
T
*
);
template
<
typename
T
>
int32_t
embSkipLayerNormMTron_3
(
cudaStream_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
,
float
const
*
,
float
const
*
,
T
const
*
,
T
const
*
,
T
const
*
,
int32_t
,
int32_t
,
int32_t
,
T
*
,
T
*
);
template
<
typename
T
>
int32_t
embSkipLayerNormMTron_4
(
cudaStream_t
,
int32_t
,
int32_t
,
int32_t
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
const
*
,
int32_t
,
float
const
*
,
float
const
*
,
T
const
*
,
T
const
*
,
T
const
*
,
T
const
*
,
int32_t
,
int32_t
,
int32_t
,
int32_t
,
T
*
,
T
*
);
class
EmbLayerNormVarSeqlenPluginBase
:
public
nvinfer1
::
IPluginV2DynamicExt
{
class
EmbLayerNormVarSeqlenPluginBase
:
public
nvinfer1
::
IPluginV2DynamicExt
{
public:
public:
EmbLayerNormVarSeqlenPluginBase
(
EmbLayerNormVarSeqlenPluginBase
(
...
@@ -104,7 +193,8 @@ class EmbLayerNormVarSeqlenPluginBase : public nvinfer1::IPluginV2DynamicExt {
...
@@ -104,7 +193,8 @@ class EmbLayerNormVarSeqlenPluginBase : public nvinfer1::IPluginV2DynamicExt {
std
::
string
mNamespace
;
std
::
string
mNamespace
;
cuda_unique_ptr
<
float
>
mGammaDev
;
cuda_unique_ptr
<
float
>
mGammaDev
;
cuda_unique_ptr
<
float
>
mBetaDev
;
cuda_unique_ptr
<
float
>
mBetaDev
;
std
::
vector
<
void
*>
mIdsEmbDev
;
std
::
vector
<
void
*>
mIdsEmbPtrs
;
// std::vector<void*> mIdsEmbDev;
size_t
mLd
;
// leading dim = hidden size
size_t
mLd
;
// leading dim = hidden size
std
::
vector
<
int32_t
>
mIdsVocabSize
;
std
::
vector
<
int32_t
>
mIdsVocabSize
;
WeightsWithOwnership
mBeta
;
WeightsWithOwnership
mBeta
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录