BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)

Commit 6512e087 (unverified) — authored Oct 10, 2022 by Wangzheee; committed via GitHub on Oct 10, 2022.
[Paddle Inference]fix embedding fused (#46789)
* fix embedding fused
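Two things change in this commit, as the hunks below show: the plugin version strings are renumbered ("2" → "1" for the HFace variant, "3" → "2" for the MTron variant), and the generic embSkipLayerNormHFace/embSkipLayerNormMTron launchers are replaced by fixed-arity _2/_3/_4 variants that take each lookup table as a direct argument. A minimal sketch of how converter code resolves the renumbered creator through the standard TensorRT registry; the plugin name and version come from the diff, while the wrapper function itself is hypothetical:

    #include <NvInfer.h>

    // Hypothetical helper, not part of the diff: resolve the creator under
    // the renumbered version string and build the plugin, as the converter
    // hunks below now do.
    nvinfer1::IPluginV2* MakeManyEmbLayerNormPlugin(
        nvinfer1::PluginFieldCollection const* fields) {
      auto* creator = getPluginRegistry()->getPluginCreator(
          "ManyEmbLayerNormPluginDynamic", "1");  // was looked up as "2"
      if (creator == nullptr) {
        return nullptr;  // plugin library not registered under the new version
      }
      return creator->createPlugin("ManyEmbLayerNormPluginDynamic", fields);
    }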
Parent: ae6b4713

Showing 6 changed files with 1,203 additions and 275 deletions (+1203 −275):
  paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc                    +3    -3
  paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc              +1    -1
  paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu  +403  -61
  paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu  +419  -71
  paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu       +263  -115
  paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h        +114  -24
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
@@ -210,14 +210,14 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
           "max_seqlen_tensor"));  // max_seqlen, eval_placeholder_3
       auto creator = GetPluginRegistry()->getPluginCreator(
-          "ManyEmbLayerNormPluginDynamic", "2");
+          "ManyEmbLayerNormPluginDynamic", "1");
       auto plugin_obj =
           creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr);
       auto plugin_layer = engine_->network()->addPluginV2(
           plugin_inputs.data(), plugin_inputs.size(), *plugin_obj);
-      plugin_layer->setName(("ManyEmbLayerNormPluginDynamic_V2(Output: " +
+      plugin_layer->setName(("ManyEmbLayerNormPluginDynamic_V1(Output: " +
                              op_desc.Output("Out")[0] + ")")
                                 .c_str());
       free(plugin_ptr);
@@ -248,7 +248,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
       layer = plugin_layer;
       auto output_name = op_desc.Output("Out")[0];
       RreplenishLayerAndOutput(layer,
-                               "ManyEmbLayerNormPluginDynamic_V2",
+                               "ManyEmbLayerNormPluginDynamic_V1",
                                {output_name, std::string("qkv_plugin_mask")},
                                test_mode);
     }
paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc
@@ -194,7 +194,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
           "max_seqlen_tensor"));  // max_seqlen, eval_placeholder_3
       auto creator = GetPluginRegistry()->getPluginCreator(
-          "ManyEmbLayerNormPluginDynamic", "3");
+          "ManyEmbLayerNormPluginDynamic", "2");
       auto plugin_obj =
           creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr);
paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu
(Diff collapsed in the captured page and not expanded. Changes: +403 −61.)
paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu
(Diff collapsed in the captured page and not expanded. Changes: +419 −71.)
paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu
@@ -37,8 +37,8 @@ constexpr size_t xmmasM384 = 24;
 constexpr size_t packedMaskSize128 = xmmasM128 * threadsPerCta128;
 constexpr size_t packedMaskSize256 = xmmasM256 * threadsPerCta256;
 constexpr size_t packedMaskSize384 = xmmasM384 * threadsPerCta384;
-char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE{"2"};
-char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON{"3"};
+char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE{"1"};
+char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON{"2"};
 char const* EMB_LAYER_NORM_VAR_SEQLEN_NAME{"ManyEmbLayerNormPluginDynamic"};
 // Static class fields initialization
 nvinfer1::PluginFieldCollection EmbLayerNormVarSeqlenPluginBaseCreator::mFC{};
@@ -74,7 +74,7 @@ EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase(
                             tem_weight.values,
                             getWeightsSize(tem_weight, mType),
                             cudaMemcpyHostToDevice));
-      mIdsEmbDev.push_back(cudaMem);
+      mIdsEmbPtrs.push_back(cudaMem);
     }
   }
@@ -83,7 +83,7 @@ EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase(
     : mLayerName(name),
       mGammaDev(nullptr),
       mBetaDev(nullptr),
-      mIdsEmbDev{},
+      mIdsEmbPtrs{},
       mIdsEmb_{} {
   // Deserialize in the same order as serialization
   deserialize_value(&data, &length, &mType);
@@ -141,8 +141,8 @@ EmbLayerNormVarSeqlenPluginMTron::EmbLayerNormVarSeqlenPluginMTron(
 // IPluginV2DynamicExt Methods
 nvinfer1::IPluginV2DynamicExt* EmbLayerNormVarSeqlenPluginHFace::clone()
     const noexcept {
-  TRANSFORMER_DEBUG_MSG("EmbLayerNormVarSeqlenPluginMTron clone");
-  auto p = new EmbLayerNormVarSeqlenPluginMTron(
+  TRANSFORMER_DEBUG_MSG("EmbLayerNormVarSeqlenPluginHFace clone");
+  auto p = new EmbLayerNormVarSeqlenPluginHFace(
       mLayerName, mType, mBeta, mGamma, mIdsEmb_);
   p->setPluginNamespace(mNamespace.c_str());
   return p;
@@ -333,7 +333,7 @@ int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue(
     void* const* outputs,
     void* workspace,
     cudaStream_t stream) noexcept {
-  int32_t const batchSize = inputDesc[0].dims.d[0] - 1;
+  int32_t batchSize = inputDesc[0].dims.d[0] - 1;
   // read out the maximum sequence length from the dummy input
   int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[0];
   int32_t S = 384;
@@ -346,60 +346,132 @@ int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue(
   }
   const float* beta = mBetaDev.get();
   const float* gamma = mGammaDev.get();
-  int32_t** tem_inputs_ptr_dev;
-  cudaMalloc(reinterpret_cast<void**>(&tem_inputs_ptr_dev),
-             sizeof(void*) * nbLookupTables_);
-  cudaMemcpy(tem_inputs_ptr_dev,
-             inputs,
-             sizeof(void*) * nbLookupTables_,
-             cudaMemcpyHostToDevice);
-  int32_t* mIdsVocabSize_dev;
-  cudaMalloc(reinterpret_cast<void**>(&mIdsVocabSize_dev),
-             sizeof(int32_t) * mIdsVocabSize.size());
-  cudaMemcpy(mIdsVocabSize_dev,
-             &(mIdsVocabSize[0]),
-             sizeof(int32_t) * mIdsVocabSize.size(),
-             cudaMemcpyHostToDevice);
   if (mType == nvinfer1::DataType::kFLOAT) {
     auto output = static_cast<float*>(outputs[0]);
-    float** mIdsEmbDev_float;
-    cudaMalloc(reinterpret_cast<void**>(&mIdsEmbDev_float),
-               sizeof(void*) * nbLookupTables_);
-    cudaMemcpy(mIdsEmbDev_float,
-               &(mIdsEmbDev[0]),
-               sizeof(void*) * nbLookupTables_,
-               cudaMemcpyHostToDevice);
-    return embSkipLayerNormHFace<float>(stream,
-                                        static_cast<int32_t>(mLd),
-                                        batchSize,
-                                        S,
-                                        tem_inputs_ptr_dev,
-                                        nbLookupTables_,
-                                        beta,
-                                        gamma,
-                                        mIdsEmbDev_float,
-                                        mIdsVocabSize_dev,
-                                        output);
+    if (nbLookupTables_ == 2) {
+      return embSkipLayerNormHFace_2<float>(
+          stream, static_cast<int32_t>(mLd), batchSize, S,
+          static_cast<int32_t const*>(inputs[0]),
+          static_cast<int32_t const*>(inputs[1]),
+          nbLookupTables_, beta, gamma,
+          static_cast<float const*>(mIdsEmbPtrs[0]),
+          static_cast<float const*>(mIdsEmbPtrs[1]),
+          mIdsVocabSize[0], mIdsVocabSize[1], output);
+    } else if (nbLookupTables_ == 3) {
+      return embSkipLayerNormHFace_3<float>(
+          stream, static_cast<int32_t>(mLd), batchSize, S,
+          static_cast<int32_t const*>(inputs[0]),
+          static_cast<int32_t const*>(inputs[1]),
+          static_cast<int32_t const*>(inputs[2]),
+          nbLookupTables_, beta, gamma,
+          static_cast<float const*>(mIdsEmbPtrs[0]),
+          static_cast<float const*>(mIdsEmbPtrs[1]),
+          static_cast<float const*>(mIdsEmbPtrs[2]),
+          mIdsVocabSize[0], mIdsVocabSize[1], mIdsVocabSize[2], output);
+    } else if (nbLookupTables_ == 4) {
+      return embSkipLayerNormHFace_4<float>(
+          stream, static_cast<int32_t>(mLd), batchSize, S,
+          static_cast<int32_t const*>(inputs[0]),
+          static_cast<int32_t const*>(inputs[1]),
+          static_cast<int32_t const*>(inputs[2]),
+          static_cast<int32_t const*>(inputs[3]),
+          nbLookupTables_, beta, gamma,
+          static_cast<float const*>(mIdsEmbPtrs[0]),
+          static_cast<float const*>(mIdsEmbPtrs[1]),
+          static_cast<float const*>(mIdsEmbPtrs[2]),
+          static_cast<float const*>(mIdsEmbPtrs[3]),
+          mIdsVocabSize[0], mIdsVocabSize[1], mIdsVocabSize[2],
+          mIdsVocabSize[3], output);
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Only support 2,3,4 lookup_tables fused "));
+    }
   } else if (mType == nvinfer1::DataType::kHALF) {
     auto output = static_cast<half*>(outputs[0]);
-    half** mIdsEmbDev_half;
-    cudaMalloc(reinterpret_cast<void**>(&mIdsEmbDev_half),
-               sizeof(void*) * nbLookupTables_);
-    cudaMemcpy(mIdsEmbDev_half,
-               &(mIdsEmbDev[0]),
-               sizeof(void*) * nbLookupTables_,
-               cudaMemcpyHostToDevice);
-    return embSkipLayerNormHFace<half>(stream,
-                                       static_cast<int32_t>(mLd),
-                                       batchSize,
-                                       S,
-                                       tem_inputs_ptr_dev,
-                                       nbLookupTables_,
-                                       beta,
-                                       gamma,
-                                       mIdsEmbDev_half,
-                                       mIdsVocabSize_dev,
-                                       output);
+    if (nbLookupTables_ == 2) {
+      return embSkipLayerNormHFace_2<half>(
+          stream, static_cast<int32_t>(mLd), batchSize, S,
+          static_cast<int32_t const*>(inputs[0]),
+          static_cast<int32_t const*>(inputs[1]),
+          nbLookupTables_, beta, gamma,
+          static_cast<half const*>(mIdsEmbPtrs[0]),
+          static_cast<half const*>(mIdsEmbPtrs[1]),
+          mIdsVocabSize[0], mIdsVocabSize[1], output);
+    } else if (nbLookupTables_ == 3) {
+      return embSkipLayerNormHFace_3<half>(
+          stream, static_cast<int32_t>(mLd), batchSize, S,
+          static_cast<int32_t const*>(inputs[0]),
+          static_cast<int32_t const*>(inputs[1]),
+          static_cast<int32_t const*>(inputs[2]),
+          nbLookupTables_, beta, gamma,
+          static_cast<half const*>(mIdsEmbPtrs[0]),
+          static_cast<half const*>(mIdsEmbPtrs[1]),
+          static_cast<half const*>(mIdsEmbPtrs[2]),
+          mIdsVocabSize[0], mIdsVocabSize[1], mIdsVocabSize[2], output);
+    } else if (nbLookupTables_ == 4) {
+      return embSkipLayerNormHFace_4<half>(
+          stream, static_cast<int32_t>(mLd), batchSize, S,
+          static_cast<int32_t const*>(inputs[0]),
+          static_cast<int32_t const*>(inputs[1]),
+          static_cast<int32_t const*>(inputs[2]),
+          static_cast<int32_t const*>(inputs[3]),
+          nbLookupTables_, beta, gamma,
+          static_cast<half const*>(mIdsEmbPtrs[0]),
+          static_cast<half const*>(mIdsEmbPtrs[1]),
+          static_cast<half const*>(mIdsEmbPtrs[2]),
+          static_cast<half const*>(mIdsEmbPtrs[3]),
+          mIdsVocabSize[0], mIdsVocabSize[1], mIdsVocabSize[2],
+          mIdsVocabSize[3], output);
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Only support 2,3,4 lookup_tables fused "));
+    }
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "Unsupported type error, expected [kHALF,kFLOAT]"));
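The rewritten enqueue drops the per-call cudaMalloc/cudaMemcpy staging of pointer tables (which the old code also never freed) and instead branches once on nbLookupTables_. A reduced, self-contained sketch of that dispatch shape, with toy stand-ins for the fixed-arity launchers; only the shape matches the diff, everything else is illustrative:

    #include <cstdint>
    #include <cstdio>
    #include <stdexcept>

    // Toy stand-ins for the fixed-arity launchers introduced by this diff;
    // the real embSkipLayerNormHFace_2/_3 launch CUDA kernels.
    template <typename T>
    int32_t hface_2(int32_t const*, int32_t const*, T*) {
      std::puts("dispatched to the 2-table kernel");
      return 0;
    }
    template <typename T>
    int32_t hface_3(int32_t const*, int32_t const*, int32_t const*, T*) {
      std::puts("dispatched to the 3-table kernel");
      return 0;
    }

    // The dispatch shape the new enqueue uses: branch once on the table
    // count, then pass each input pointer as an ordinary argument. Nothing
    // on this hot path calls cudaMalloc/cudaMemcpy, unlike the removed
    // staging code.
    template <typename T>
    int32_t dispatch(void const* const* inputs, int32_t n, T* out) {
      auto in = [&](int i) { return static_cast<int32_t const*>(inputs[i]); };
      switch (n) {
        case 2: return hface_2<T>(in(0), in(1), out);
        case 3: return hface_3<T>(in(0), in(1), in(2), out);
        default:
          throw std::invalid_argument("Only support 2,3,4 lookup_tables fused");
      }
    }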
@@ -414,7 +486,7 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue(
     void* const* outputs,
     void* workspace,
     cudaStream_t stream) noexcept {
-  int32_t const batchSize = inputDesc[0].dims.d[0] - 1;
+  int32_t batchSize = inputDesc[0].dims.d[0] - 1;
   // read out the maximum sequence length from the dummy input
   int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[0];
   int32_t S = 384;
@@ -427,64 +499,141 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue(
   }
   const float* beta = mBetaDev.get();
   const float* gamma = mGammaDev.get();
-  int32_t** tem_inputs_ptr_dev;
-  cudaMalloc(reinterpret_cast<void**>(&tem_inputs_ptr_dev),
-             sizeof(void*) * nbLookupTables_);
-  cudaMemcpy(tem_inputs_ptr_dev,
-             inputs,
-             sizeof(void*) * nbLookupTables_,
-             cudaMemcpyHostToDevice);
-  int32_t* mIdsVocabSize_dev;
-  cudaMalloc(reinterpret_cast<void**>(&mIdsVocabSize_dev),
-             sizeof(int32_t) * mIdsVocabSize.size());
-  cudaMemcpy(mIdsVocabSize_dev,
-             &(mIdsVocabSize[0]),
-             sizeof(int32_t) * mIdsVocabSize.size(),
-             cudaMemcpyHostToDevice);
   if (mType == nvinfer1::DataType::kFLOAT) {
     auto output = static_cast<float*>(outputs[0]);
     auto skip = static_cast<float*>(outputs[1]);
-    float** mIdsEmbDev_float;
-    cudaMalloc(reinterpret_cast<void**>(&mIdsEmbDev_float),
-               sizeof(void*) * nbLookupTables_);
-    cudaMemcpy(mIdsEmbDev_float,
-               &(mIdsEmbDev[0]),
-               sizeof(void*) * nbLookupTables_,
-               cudaMemcpyHostToDevice);
-    return embSkipLayerNormMTron<float>(stream,
-                                        static_cast<int32_t>(mLd),
-                                        batchSize,
-                                        S,
-                                        tem_inputs_ptr_dev,
-                                        nbLookupTables_,
-                                        beta,
-                                        gamma,
-                                        mIdsEmbDev_float,
-                                        mIdsVocabSize_dev,
-                                        output,
-                                        skip);
+    if (nbLookupTables_ == 2) {
+      return embSkipLayerNormMTron_2<float>(
+          stream, static_cast<int32_t>(mLd), batchSize, S,
+          static_cast<int32_t const*>(inputs[0]),
+          static_cast<int32_t const*>(inputs[1]),
+          nbLookupTables_, beta, gamma,
+          static_cast<float const*>(mIdsEmbPtrs[0]),
+          static_cast<float const*>(mIdsEmbPtrs[1]),
+          mIdsVocabSize[0], mIdsVocabSize[1], output, skip);
+    } else if (nbLookupTables_ == 3) {
+      return embSkipLayerNormMTron_3<float>(
+          stream, static_cast<int32_t>(mLd), batchSize, S,
+          static_cast<int32_t const*>(inputs[0]),
+          static_cast<int32_t const*>(inputs[1]),
+          static_cast<int32_t const*>(inputs[2]),
+          nbLookupTables_, beta, gamma,
+          static_cast<float const*>(mIdsEmbPtrs[0]),
+          static_cast<float const*>(mIdsEmbPtrs[1]),
+          static_cast<float const*>(mIdsEmbPtrs[2]),
+          mIdsVocabSize[0], mIdsVocabSize[1], mIdsVocabSize[2],
+          output, skip);
+    } else if (nbLookupTables_ == 4) {
+      return embSkipLayerNormMTron_4<float>(
+          stream, static_cast<int32_t>(mLd), batchSize, S,
+          static_cast<int32_t const*>(inputs[0]),
+          static_cast<int32_t const*>(inputs[1]),
+          static_cast<int32_t const*>(inputs[2]),
+          static_cast<int32_t const*>(inputs[3]),
+          nbLookupTables_, beta, gamma,
+          static_cast<float const*>(mIdsEmbPtrs[0]),
+          static_cast<float const*>(mIdsEmbPtrs[1]),
+          static_cast<float const*>(mIdsEmbPtrs[2]),
+          static_cast<float const*>(mIdsEmbPtrs[3]),
+          mIdsVocabSize[0], mIdsVocabSize[1], mIdsVocabSize[2],
+          mIdsVocabSize[3], output, skip);
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Only support 2,3,4 lookup_tables fused "));
+    }
   } else if (mType == nvinfer1::DataType::kHALF) {
     auto output = static_cast<half*>(outputs[0]);
     auto skip = static_cast<half*>(outputs[1]);
-    half** mIdsEmbDev_half;
-    cudaMalloc(reinterpret_cast<void**>(&mIdsEmbDev_half),
-               sizeof(void*) * nbLookupTables_);
-    cudaMemcpy(mIdsEmbDev_half,
-               &(mIdsEmbDev[0]),
-               sizeof(void*) * nbLookupTables_,
-               cudaMemcpyHostToDevice);
-    return embSkipLayerNormMTron<half>(stream,
-                                       static_cast<int32_t>(mLd),
-                                       batchSize,
-                                       S,
-                                       tem_inputs_ptr_dev,
-                                       nbLookupTables_,
-                                       beta,
-                                       gamma,
-                                       mIdsEmbDev_half,
-                                       mIdsVocabSize_dev,
-                                       output,
-                                       skip);
+    if (nbLookupTables_ == 2) {
+      return embSkipLayerNormMTron_2<half>(
+          stream, static_cast<int32_t>(mLd), batchSize, S,
+          static_cast<int32_t const*>(inputs[0]),
+          static_cast<int32_t const*>(inputs[1]),
+          nbLookupTables_, beta, gamma,
+          static_cast<half const*>(mIdsEmbPtrs[0]),
+          static_cast<half const*>(mIdsEmbPtrs[1]),
+          mIdsVocabSize[0], mIdsVocabSize[1], output, skip);
+    } else if (nbLookupTables_ == 3) {
+      return embSkipLayerNormMTron_3<half>(
+          stream, static_cast<int32_t>(mLd), batchSize, S,
+          static_cast<int32_t const*>(inputs[0]),
+          static_cast<int32_t const*>(inputs[1]),
+          static_cast<int32_t const*>(inputs[2]),
+          nbLookupTables_, beta, gamma,
+          static_cast<half const*>(mIdsEmbPtrs[0]),
+          static_cast<half const*>(mIdsEmbPtrs[1]),
+          static_cast<half const*>(mIdsEmbPtrs[2]),
+          mIdsVocabSize[0], mIdsVocabSize[1], mIdsVocabSize[2],
+          output, skip);
+    } else if (nbLookupTables_ == 4) {
+      return embSkipLayerNormMTron_4<half>(
+          stream, static_cast<int32_t>(mLd), batchSize, S,
+          static_cast<int32_t const*>(inputs[0]),
+          static_cast<int32_t const*>(inputs[1]),
+          static_cast<int32_t const*>(inputs[2]),
+          static_cast<int32_t const*>(inputs[3]),
+          nbLookupTables_, beta, gamma,
+          static_cast<half const*>(mIdsEmbPtrs[0]),
+          static_cast<half const*>(mIdsEmbPtrs[1]),
+          static_cast<half const*>(mIdsEmbPtrs[2]),
+          static_cast<half const*>(mIdsEmbPtrs[3]),
+          mIdsVocabSize[0], mIdsVocabSize[1], mIdsVocabSize[2],
+          mIdsVocabSize[3], output, skip);
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Only support 2,3,4 lookup_tables fused "));
+    }
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "Unsupported type error, expected [kHALF,kFLOAT]"));
@@ -566,9 +715,9 @@ void EmbLayerNormVarSeqlenPluginBase::serialize(void* buffer) const noexcept {
   size_t const wordSize = getElementSize(mType);
   serFromDev(&d, mBetaDev.get(), mLd);
   serFromDev(&d, mGammaDev.get(), mLd);
-  for (size_t i = 0; i < mIdsEmbDev.size(); ++i) {
+  for (size_t i = 0; i < mIdsEmbPtrs.size(); ++i) {
     serFromDev(&d,
-               static_cast<char*>(mIdsEmbDev[i]),
+               static_cast<char*>(mIdsEmbPtrs[i]),
                mLd * mIdsVocabSize[i] * wordSize);
   }
 }
@@ -577,8 +726,8 @@ void EmbLayerNormVarSeqlenPluginBase::destroy() noexcept {
   // This gets called when the network containing plugin is destroyed
   mBetaDev.reset(nullptr);
   mGammaDev.reset(nullptr);
-  for (size_t i = 0; i < mIdsEmbDev.size(); ++i) {
-    cudaFree(mIdsEmbDev[i]);
+  for (size_t i = 0; i < mIdsEmbPtrs.size(); ++i) {
+    cudaFree(mIdsEmbPtrs[i]);
   }
   delete this;
 }
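With the rename, construction and teardown now operate on the same mIdsEmbPtrs vector: the constructor cudaMallocs one device copy per embedding table and push_backs it, and destroy() frees exactly that list. A self-contained sketch of the pairing; the function names and sizes here are made up, only the member name comes from the diff:

    #include <cuda_runtime.h>
    #include <cstddef>
    #include <vector>

    // Sketch of the ownership pattern after the rename.
    std::vector<void*> mIdsEmbPtrs;

    void BuildDeviceTables(const std::vector<size_t>& table_bytes) {
      for (size_t n : table_bytes) {
        void* mem = nullptr;
        cudaMalloc(&mem, n);          // one device buffer per lookup table
        mIdsEmbPtrs.push_back(mem);
      }
    }

    void Destroy() {
      for (void* p : mIdsEmbPtrs) {
        cudaFree(p);                  // freed from the same vector it was built in
      }
      mIdsEmbPtrs.clear();
    }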
@@ -680,7 +829,6 @@ nvinfer1::IPluginV2* EmbLayerNormVarSeqlenPluginHFaceCreator::createPlugin(
       beta,
       gamma,
       IdsEmb);
   return p;
 }
paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h
@@ -31,32 +31,121 @@ namespace tensorrt {
 namespace plugin {

 template <typename T>
-int32_t embSkipLayerNormHFace(cudaStream_t stream,
-                              int32_t ld,
-                              int32_t B,
-                              int32_t S,
-                              int32_t** inputIds,
-                              int32_t const nbLookupTables,
-                              float const* beta,
-                              float const* gamma,
-                              T** idsEmb,
-                              int32_t*,
-                              T* output);
+int32_t embSkipLayerNormHFace_2(cudaStream_t, int32_t, int32_t, int32_t,
+                                int32_t const*, int32_t const*, int32_t,
+                                float const*, float const*, T const*,
+                                T const*, int32_t, int32_t, T*);

 template <typename T>
-int32_t embSkipLayerNormMTron(cudaStream_t stream,
-                              int32_t ld,
-                              int32_t B,
-                              int32_t S,
-                              int32_t** inputIds,
-                              int32_t const nbLookupTables,
-                              float const* beta,
-                              float const* gamma,
-                              T** idsEmb,
-                              int32_t*,
-                              T* output,
-                              T* skip);
+int32_t embSkipLayerNormHFace_3(cudaStream_t, int32_t, int32_t, int32_t,
+                                int32_t const*, int32_t const*,
+                                int32_t const*, int32_t, float const*,
+                                float const*, T const*, T const*, T const*,
+                                int32_t, int32_t, int32_t, T*);
+
+template <typename T>
+int32_t embSkipLayerNormHFace_4(cudaStream_t, int32_t, int32_t, int32_t,
+                                int32_t const*, int32_t const*,
+                                int32_t const*, int32_t const*, int32_t,
+                                float const*, float const*, T const*,
+                                T const*, T const*, T const*, int32_t,
+                                int32_t, int32_t, int32_t, T*);
+
+template <typename T>
+int32_t embSkipLayerNormMTron_2(cudaStream_t, int32_t, int32_t, int32_t,
+                                int32_t const*, int32_t const*, int32_t,
+                                float const*, float const*, T const*,
+                                T const*, int32_t, int32_t, T*, T*);
+
+template <typename T>
+int32_t embSkipLayerNormMTron_3(cudaStream_t, int32_t, int32_t, int32_t,
+                                int32_t const*, int32_t const*,
+                                int32_t const*, int32_t, float const*,
+                                float const*, T const*, T const*, T const*,
+                                int32_t, int32_t, int32_t, T*, T*);
+
+template <typename T>
+int32_t embSkipLayerNormMTron_4(cudaStream_t, int32_t, int32_t, int32_t,
+                                int32_t const*, int32_t const*,
+                                int32_t const*, int32_t const*, int32_t,
+                                float const*, float const*, T const*,
+                                T const*, T const*, T const*, int32_t,
+                                int32_t, int32_t, int32_t, T*, T*);

 class EmbLayerNormVarSeqlenPluginBase : public nvinfer1::IPluginV2DynamicExt {
  public:
  EmbLayerNormVarSeqlenPluginBase(
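The point of the fixed-arity declarations above is that each table pointer becomes an ordinary kernel argument. A toy CUDA kernel in that shape, explicitly not the diff's implementation: it only sums two embedding rows per token, whereas the real HFace kernels also apply layer norm with beta/gamma and the MTron variants write a second pre-norm "skip" output:

    #include <cstdint>

    // Toy fixed-arity kernel: two id/table pairs as plain arguments, so there
    // is no device-side pointer array to chase.
    template <typename T>
    __global__ void embSum2(int32_t const* ids0, int32_t const* ids1,
                            T const* emb0, T const* emb1, int32_t ld, T* out) {
      int32_t token = blockIdx.x;  // one block per token
      int32_t i = threadIdx.x;     // one thread per hidden element
      if (i < ld) {
        out[token * ld + i] =
            emb0[ids0[token] * ld + i] + emb1[ids1[token] * ld + i];
      }
    }
    // launch sketch: embSum2<float><<<numTokens, ld, 0, stream>>>(...)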
@@ -104,7 +193,8 @@ class EmbLayerNormVarSeqlenPluginBase : public nvinfer1::IPluginV2DynamicExt {
   std::string mNamespace;
   cuda_unique_ptr<float> mGammaDev;
   cuda_unique_ptr<float> mBetaDev;
-  std::vector<void*> mIdsEmbDev;
+  std::vector<void*> mIdsEmbPtrs;
+  // std::vector<void*> mIdsEmbDev;
   size_t mLd;  // leading dim = hidden size
   std::vector<int32_t> mIdsVocabSize;
   WeightsWithOwnership mBeta;