Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
0043fa8c
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0043fa8c
编写于
8月 30, 2021
作者:
C
ceci3
提交者:
GitHub
8月 30, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[paddle-TRT]support matmul set to int8 in multihead (#34917)
* update ernie int8
上级
c0bdef5d
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
47 addition
and
15 deletion
+47
-15
paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
+26
-7
paddle/fluid/inference/tensorrt/convert/fc_op.cc
paddle/fluid/inference/tensorrt/convert/fc_op.cc
+1
-2
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
...e/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
+20
-6
未找到文件。
paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
浏览文件 @
0043fa8c
...
...
@@ -758,7 +758,9 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
Node
*
input0
,
Node
*
mul0
,
Node
*
mul1
,
Node
*
mul2
,
Node
*
mul0_out
,
Node
*
mul1_out
,
Node
*
mul2_out
,
Node
*
mul0_w
,
Node
*
mul1_w
,
Node
*
mul2_w
,
Node
*
eltadd0_b
,
Node
*
eltadd1_b
,
Node
*
eltadd2_b
,
Node
*
eltadd_qk_b
,
Node
*
reshape2
,
Node
*
reshape2_qkv_out
,
Node
*
scale
,
Node
*
scale_out
)
{
Node
*
reshape2
,
Node
*
reshape2_qkv_out
,
Node
*
scale
,
Node
*
scale_out
,
Node
*
softmax_qk
,
Node
*
eltadd0
,
Node
*
eltadd1
,
Node
*
eltadd2
,
Node
*
matmul_qk
)
{
auto
scale_attr
=
BOOST_GET_CONST
(
float
,
scale
->
Op
()
->
GetAttr
(
"scale"
));
// mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)
...
...
@@ -876,19 +878,35 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
weight_max
=
std
::
max
(
weight_max
,
weight_scale2
);
multihead_op_desc
.
SetAttr
(
"weight_scale"
,
weight_max
);
if
(
mul0_op_desc
->
HasAttr
(
"out_threshold"
))
{
auto
*
add0_op_desc
=
eltadd0
->
Op
();
auto
*
add1_op_desc
=
eltadd1
->
Op
();
auto
*
add2_op_desc
=
eltadd2
->
Op
();
if
(
add0_op_desc
->
HasAttr
(
"out_threshold"
))
{
auto
out_scale0
=
BOOST_GET_CONST
(
float
,
mul
0_op_desc
->
GetAttr
(
"out_threshold"
));
BOOST_GET_CONST
(
float
,
add
0_op_desc
->
GetAttr
(
"out_threshold"
));
auto
out_scale1
=
BOOST_GET_CONST
(
float
,
mul
1_op_desc
->
GetAttr
(
"out_threshold"
));
BOOST_GET_CONST
(
float
,
add
1_op_desc
->
GetAttr
(
"out_threshold"
));
auto
out_scale2
=
BOOST_GET_CONST
(
float
,
mul
2_op_desc
->
GetAttr
(
"out_threshold"
));
BOOST_GET_CONST
(
float
,
add
2_op_desc
->
GetAttr
(
"out_threshold"
));
auto
out_scale_max
=
std
::
max
(
out_scale0
,
out_scale1
);
out_scale_max
=
std
::
max
(
out_scale_max
,
out_scale2
);
multihead_op_desc
.
SetAttr
(
"out_threshold"
,
out_scale_max
);
multihead_op_desc
.
SetAttr
(
"
fc_
out_threshold"
,
out_scale_max
);
}
}
auto
*
softmax_qk_op_desc
=
softmax_qk
->
Op
();
auto
*
matmul_qk_op_desc
=
matmul_qk
->
Op
();
if
(
matmul_qk_op_desc
->
HasAttr
(
"X_scale"
))
{
multihead_op_desc
.
SetAttr
(
"qkv2context_plugin_int8"
,
true
);
if
(
softmax_qk_op_desc
->
HasAttr
(
"out_threshold"
))
{
auto
qkv_plugin_scale
=
BOOST_GET_CONST
(
float
,
softmax_qk_op_desc
->
GetAttr
(
"out_threshold"
));
multihead_op_desc
.
SetAttr
(
"dp_probs"
,
qkv_plugin_scale
);
}
}
else
{
multihead_op_desc
.
SetAttr
(
"qkv2context_plugin_int8"
,
false
);
}
auto
*
multihead
=
graph
->
CreateOpNode
(
&
multihead_op_desc
);
IR_NODE_LINK_TO
(
input0
,
multihead
);
...
...
@@ -990,7 +1008,8 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
}
fuse_creater
(
input0
,
mul0
,
mul1
,
mul2
,
mul0_out
,
mul1_out
,
mul2_out
,
mul0_w
,
mul1_w
,
mul2_w
,
eltadd0_b
,
eltadd1_b
,
eltadd2_b
,
eltadd_qk_b
,
reshape2_0
,
reshape2_qkv_out
,
scale
,
scale_out
);
reshape2_0
,
reshape2_qkv_out
,
scale
,
scale_out
,
softmax_qk
,
eltadd0
,
eltadd1
,
eltadd2
,
matmul_qk
);
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
eltadd0
,
eltadd1
,
...
...
paddle/fluid/inference/tensorrt/convert/fc_op.cc
浏览文件 @
0043fa8c
...
...
@@ -164,10 +164,9 @@ class FcOpConverter : public OpConverter {
auto
*
fc_layer_int8
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Convolution
,
*
inputs
,
n_output
,
nv_ksize
,
weight
.
get
(),
bias
.
get
());
engine_
->
SetTensorDynamicRange
(
fc_layer_int8
->
getOutput
(
0
),
out_scale
);
auto
*
fc_after_reshape_int8
=
reshape_after_fc
(
fc_layer_int8
->
getOutput
(
0
),
x_dim
,
x_num_col_dims
);
engine_
->
SetTensorDynamicRange
(
fc_after_reshape_int8
->
getOutput
(
0
),
out_scale
);
if
(
activation_type
==
"relu"
)
{
nvinfer1
::
IActivationLayer
*
relu_layer_int8
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Activation
,
*
(
fc_after_reshape_int8
->
getOutput
(
0
)),
...
...
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
浏览文件 @
0043fa8c
...
...
@@ -42,6 +42,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
float
*
weight_data
=
nullptr
;
bool
enable_int8
=
op_desc
.
HasAttr
(
"enable_int8"
);
bool
qkv2context_plugin_int8
=
BOOST_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"qkv2context_plugin_int8"
));
float
in_scale
=
0.
;
if
(
enable_int8
)
{
...
...
@@ -147,13 +149,16 @@ class MultiheadMatMulOpConverter : public OpConverter {
if
(
enable_int8
)
{
PADDLE_ENFORCE_EQ
(
op_desc
.
HasAttr
(
"out_threshold"
),
true
,
op_desc
.
HasAttr
(
"
fc_
out_threshold"
),
true
,
platform
::
errors
::
InvalidArgument
(
"must have out threshold in multihead layers in int8 mode"
));
float
out_scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"out_threshold"
));
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"
fc_
out_threshold"
));
engine_
->
SetTensorDynamicRange
(
fc_layer
->
getOutput
(
0
),
out_scale
);
dp_probs
=
out_scale
/
127.0
;
if
(
qkv2context_plugin_int8
)
{
dp_probs
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"dp_probs"
))
/
127.0
;
}
}
auto
mask_tensor
=
engine_
->
GetITensor
(
"qkv_plugin_mask"
);
...
...
@@ -166,16 +171,25 @@ class MultiheadMatMulOpConverter : public OpConverter {
:
nvinfer1
::
DataType
::
kFLOAT
);
if
(
enable_int8
)
{
type
=
static_cast
<
int
>
(
nvinfer1
::
DataType
::
kHALF
);
if
(
qkv2context_plugin_int8
)
{
type
=
static_cast
<
int
>
(
nvinfer1
::
DataType
::
kINT8
);
}
}
bool
has_mask
=
true
;
int
var_seqlen
=
1
;
const
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
{
"type_id"
,
&
type
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"hidden_size"
,
&
hidden_out
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"num_heads"
,
&
head_number
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"has_mask"
,
&
has_mask
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"var_seqlen"
,
&
var_seqlen
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"dq_probs"
,
&
dp_probs
,
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
1
}};
{
"var_seqlen"
,
&
var_seqlen
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
}};
if
(
qkv2context_plugin_int8
)
{
fields
.
push_back
(
{
"dq_probs"
,
&
dp_probs
,
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
1
});
}
nvinfer1
::
PluginFieldCollection
*
plugin_collection
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
sizeof
(
*
plugin_collection
)
+
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录