Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
fa06d9c3
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
fa06d9c3
编写于
8月 26, 2022
作者:
W
Wangzheee
提交者:
GitHub
8月 26, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix_multihead (#45429)
上级
a5e9ccda
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
319 addition
and
310 deletion
+319
-310
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
...e/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
+319
-310
未找到文件。
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
浏览文件 @
fa06d9c3
...
@@ -291,7 +291,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -291,7 +291,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin
);
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin
);
layer
=
plugin_layer
;
layer
=
plugin_layer
;
}
}
}
}
else
{
if
(
input_dims
.
d
[
1
]
<=
384
&&
!
bias_qk_attr
&&
if
(
input_dims
.
d
[
1
]
<=
384
&&
!
bias_qk_attr
&&
engine_
->
precision
()
!=
AnalysisConfig
::
Precision
::
kFloat32
)
{
engine_
->
precision
()
!=
AnalysisConfig
::
Precision
::
kFloat32
)
{
/*
/*
...
@@ -392,12 +392,14 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -392,12 +392,14 @@ class MultiheadMatMulOpConverter : public OpConverter {
reshape_before_fc_layer
->
setInput
(
reshape_before_fc_layer
->
setInput
(
1
,
*
Concat
(
reshape_before_fc_shape_tensor
));
1
,
*
Concat
(
reshape_before_fc_shape_tensor
));
reshape_before_fc_layer
->
setName
(
reshape_before_fc_layer
->
setName
(
(
"shuffle_before_fc_multihead_matmul(Output: "
+
output_name
+
")"
)
(
"shuffle_before_fc_multihead_matmul(Output: "
+
output_name
+
")"
)
.
c_str
());
.
c_str
());
// add fc layer
// add fc layer
nvinfer1
::
ILayer
*
fc_layer
=
nullptr
;
nvinfer1
::
ILayer
*
fc_layer
=
nullptr
;
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
FullyConnected
,
*
reshape_before_fc_layer
->
getOutput
(
0
),
*
reshape_before_fc_layer
->
getOutput
(
0
),
n
,
n
,
...
@@ -427,14 +429,20 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -427,14 +429,20 @@ class MultiheadMatMulOpConverter : public OpConverter {
int
var_seqlen
=
1
;
int
var_seqlen
=
1
;
bool
has_mask
=
true
;
bool
has_mask
=
true
;
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
{
"hidden_size"
,
&
hidden_out
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"hidden_size"
,
&
hidden_out
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"num_heads"
,
&
head_number
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"num_heads"
,
&
head_number
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"type_id"
,
&
type
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"type_id"
,
&
type
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"has_mask"
,
&
has_mask
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"has_mask"
,
&
has_mask
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"var_seqlen"
,
&
var_seqlen
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
}};
{
"var_seqlen"
,
&
var_seqlen
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
}};
nvinfer1
::
PluginFieldCollection
*
plugin_collection
=
nvinfer1
::
PluginFieldCollection
*
plugin_collection
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
malloc
(
sizeof
(
*
plugin_collection
)
+
sizeof
(
*
plugin_collection
)
+
fields
.
size
()
*
fields
.
size
()
*
sizeof
(
nvinfer1
::
PluginField
)));
// remember to free
sizeof
(
nvinfer1
::
PluginField
)));
// remember to free
plugin_collection
->
nbFields
=
static_cast
<
int
>
(
fields
.
size
());
plugin_collection
->
nbFields
=
static_cast
<
int
>
(
fields
.
size
());
...
@@ -506,8 +514,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -506,8 +514,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin
);
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin
);
// add shuffle
// add shuffle
auto
*
reshape_after_mha_layer
=
auto
*
reshape_after_mha_layer
=
TRT_ENGINE_ADD_LAYER
(
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
plugin_layer
->
getOutput
(
0
));
engine_
,
Shuffle
,
*
plugin_layer
->
getOutput
(
0
));
std
::
vector
<
nvinfer1
::
ITensor
*>
reshape_tensor
;
std
::
vector
<
nvinfer1
::
ITensor
*>
reshape_tensor
;
reshape_tensor
.
push_back
(
batch_tensor
);
reshape_tensor
.
push_back
(
batch_tensor
);
reshape_tensor
.
push_back
(
length_tensor
);
reshape_tensor
.
push_back
(
length_tensor
);
...
@@ -554,8 +562,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -554,8 +562,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
auto
*
reshape_before_fc_layer
=
auto
*
reshape_before_fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
input
);
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
input
);
if
(
op_desc
.
HasAttr
(
"Input_scale"
))
{
if
(
op_desc
.
HasAttr
(
"Input_scale"
))
{
engine_
->
SetTensorDynamicRange
(
reshape_before_fc_layer
->
getOutput
(
0
),
engine_
->
SetTensorDynamicRange
(
in_scale
);
reshape_before_fc_layer
->
getOutput
(
0
),
in_scale
);
}
}
reshape_before_fc_layer
->
setInput
(
reshape_before_fc_layer
->
setInput
(
1
,
*
Concat
(
reshape_before_fc_shape_tensor
));
1
,
*
Concat
(
reshape_before_fc_shape_tensor
));
...
@@ -586,11 +594,11 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -586,11 +594,11 @@ class MultiheadMatMulOpConverter : public OpConverter {
}
}
if
(
op_desc
.
HasAttr
(
"fc_out_threshold"
))
{
if
(
op_desc
.
HasAttr
(
"fc_out_threshold"
))
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
op_desc
.
HasAttr
(
"fc_out_threshold"
),
op_desc
.
HasAttr
(
"fc_out_threshold"
),
true
,
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"must have out threshold in multihead layers in int8 mode"
));
"must have out threshold in multihead layers "
"in int8 mode"
));
float
out_scale
=
float
out_scale
=
PADDLE_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"fc_out_threshold"
));
PADDLE_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"fc_out_threshold"
));
engine_
->
SetTensorDynamicRange
(
fc_layer
->
getOutput
(
0
),
out_scale
);
engine_
->
SetTensorDynamicRange
(
fc_layer
->
getOutput
(
0
),
out_scale
);
...
@@ -619,6 +627,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -619,6 +627,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
hidden_in
,
head_number
,
head_size
,
scale
,
with_fp16
);
hidden_in
,
head_number
,
head_size
,
scale
,
with_fp16
);
layer
=
engine_
->
AddDynamicPlugin
(
plugin_inputs
.
data
(),
2
,
plugin
);
layer
=
engine_
->
AddDynamicPlugin
(
plugin_inputs
.
data
(),
2
,
plugin
);
}
}
}
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"You are running the Ernie(Bert) model in static shape mode, which "
"You are running the Ernie(Bert) model in static shape mode, which "
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录