Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
51e9898d
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
51e9898d
编写于
4月 03, 2020
作者:
C
cc
提交者:
GitHub
4月 03, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Modify quant_dequant_fuse_pass to process quant_dequant_op, test=develop (#3341)
上级
add162dc
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
127 addition
and
204 deletion
+127
-204
lite/api/benchmark.cc
lite/api/benchmark.cc
+7
-13
lite/api/cxx_api.cc
lite/api/cxx_api.cc
+4
-1
lite/core/mir/fusion/quant_dequant_fuse_pass.cc
lite/core/mir/fusion/quant_dequant_fuse_pass.cc
+3
-5
lite/core/mir/fusion/quant_dequant_op_fuser.cc
lite/core/mir/fusion/quant_dequant_op_fuser.cc
+85
-175
lite/core/mir/fusion/quant_dequant_op_fuser.h
lite/core/mir/fusion/quant_dequant_op_fuser.h
+2
-10
lite/core/op_lite.h
lite/core/op_lite.h
+26
-0
未找到文件。
lite/api/benchmark.cc
浏览文件 @
51e9898d
...
...
@@ -44,7 +44,10 @@ DEFINE_string(input_shape,
"set input shapes according to the model, "
"separated by colon and comma, "
"such as 1,3,244,244"
);
DEFINE_string
(
input_img_path
,
""
,
"the path of input image"
);
DEFINE_string
(
input_img_path
,
""
,
"the path of input image, if not set "
"input_img_path, the input of model will be 1.0."
);
DEFINE_int32
(
warmup
,
0
,
"warmup times"
);
DEFINE_int32
(
repeats
,
1
,
"repeats times"
);
DEFINE_int32
(
power_mode
,
...
...
@@ -57,16 +60,11 @@ DEFINE_int32(power_mode,
DEFINE_int32
(
threads
,
1
,
"threads num"
);
DEFINE_string
(
result_filename
,
"result.txt"
,
"save benchmark "
"result to the file"
);
"save the inference time to the file."
);
DEFINE_bool
(
run_model_optimize
,
false
,
"if set true, apply model_optimize_tool to "
"model and use optimized model to test. "
);
DEFINE_bool
(
is_quantized_model
,
false
,
"if set true, "
"test the performance of the quantized model. "
);
namespace
paddle
{
namespace
lite_api
{
...
...
@@ -87,10 +85,6 @@ void OutputOptModel(const std::string& save_optimized_model_dir) {
std
::
vector
<
Place
>
vaild_places
=
{
Place
{
TARGET
(
kARM
),
PRECISION
(
kFloat
)},
};
if
(
FLAGS_is_quantized_model
)
{
vaild_places
.
insert
(
vaild_places
.
begin
(),
Place
{
TARGET
(
kARM
),
PRECISION
(
kInt8
)});
}
config
.
set_valid_places
(
vaild_places
);
auto
predictor
=
lite_api
::
CreatePaddlePredictor
(
config
);
...
...
@@ -181,8 +175,8 @@ void Run(const std::vector<int64_t>& input_shape,
int
main
(
int
argc
,
char
**
argv
)
{
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
true
);
if
(
FLAGS_model_dir
==
""
||
FLAGS_result_filename
==
""
)
{
LOG
(
INFO
)
<<
"
p
lease run ./benchmark_bin --help to obtain usage."
;
if
(
FLAGS_model_dir
==
""
)
{
LOG
(
INFO
)
<<
"
P
lease run ./benchmark_bin --help to obtain usage."
;
exit
(
0
);
}
...
...
lite/api/cxx_api.cc
浏览文件 @
51e9898d
...
...
@@ -295,6 +295,8 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
inner_places
.
emplace_back
(
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
));
// Analysis whether the modle is quantized.
// For quantized model, add place(arm, int8) to inner_places
const
std
::
vector
<
std
::
string
>
quant_dequant_op
=
{
"fake_quantize_abs_max"
,
"fake_quantize_range_abs_max"
,
...
...
@@ -317,7 +319,8 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
}
}
if
(
is_quantized_model
)
{
inner_places
.
emplace_back
(
Place
{
TARGET
(
kARM
),
PRECISION
(
kInt8
)});
inner_places
.
insert
(
inner_places
.
begin
(),
Place
{
TARGET
(
kARM
),
PRECISION
(
kInt8
)});
}
Program
program
(
desc
,
scope_
,
inner_places
);
...
...
lite/core/mir/fusion/quant_dequant_fuse_pass.cc
浏览文件 @
51e9898d
...
...
@@ -44,11 +44,9 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fuser
(
graph
.
get
());
}
// delete quant_dequant_node
for
(
auto
op_type
:
{
"pool2d"
,
"softmax"
,
"elementwise_add"
})
{
fusion
::
DeleteQuantDequantOpFuser
fuser
(
op_type
);
fuser
(
graph
.
get
());
}
// process quant_dequant_node
fusion
::
DeleteQuantDequantOpFuser
dqd_fuser
;
dqd_fuser
(
graph
.
get
());
}
}
// namespace mir
...
...
lite/core/mir/fusion/quant_dequant_op_fuser.cc
浏览文件 @
51e9898d
...
...
@@ -50,7 +50,7 @@ void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph,
auto
*
output_scale_node
=
matched
.
at
(
"output_scale_node"
);
auto
*
output_act_node
=
matched
.
at
(
"output_act_node"
);
// obtain
values, save value
s and relink node
// obtain
scale, save attr
s and relink node
int
bit_length
=
quant_node
->
stmt
()
->
op_info
()
->
GetAttr
<
int
>
(
"bit_length"
);
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
auto
*
scope
=
quant_node
->
stmt
()
->
op
()
->
scope
();
...
...
@@ -58,11 +58,22 @@ void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph,
->
GetMutable
<
lite
::
Tensor
>
();
float
scale_value
=
scale_tensor
->
data
<
float
>
()[
0
]
/
range
;
auto
in_act_name
=
input_act_node
->
arg
()
->
name
;
auto
out_act_name
=
output_act_node
->
arg
()
->
name
;
auto
outlinks
=
output_act_node
->
outlinks
;
for
(
auto
*
quantized_node
:
outlinks
)
{
auto
*
op_desc
=
quantized_node
->
stmt
()
->
mutable_op_info
();
op_desc
->
SetAttr
<
int
>
(
"bit_length"
,
bit_length
);
op_desc
->
SetAttr
<
float
>
(
"input_scale"
,
scale_value
);
// save input scale in quantized op by input argname + index
auto
op_desc
=
*
quantized_node
->
stmt
()
->
mutable_op_info
();
std
::
string
argname
;
int
index
;
op_desc
.
GetInputArgname
(
out_act_name
,
&
argname
);
op_desc
.
GetInputIndex
(
out_act_name
,
&
index
);
op_desc
.
SetAttr
<
float
>
(
argname
+
std
::
to_string
(
index
)
+
"_input_scale"
,
scale_value
);
op_desc
.
SetAttr
<
float
>
(
"input_scale"
,
scale_value
);
// save it for now
op_desc
.
SetAttr
<
int
>
(
"bit_length"
,
bit_length
);
op_desc
.
UpdateAllInputs
(
out_act_name
,
in_act_name
);
quantized_node
->
stmt
()
->
ResetOp
(
op_desc
,
graph
->
valid_places
());
IR_NODE_LINK_TO
(
input_act_node
,
quantized_node
)
}
...
...
@@ -125,19 +136,18 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
auto
*
dequant_op
=
matched
.
at
(
"dequant_op"
);
auto
*
dequant_op_out
=
matched
.
at
(
"dequant_op_out"
);
// obtain
input_scale and weight_scal
e
// obtain
weight_scale from max_rang
e
auto
*
scope
=
quantized_op
->
stmt
()
->
op
()
->
scope
();
auto
&
valid_places
=
quantized_op
->
stmt
()
->
op
()
->
valid_places
();
int
bit_length
=
quantized_op
->
stmt
()
->
op_info
()
->
GetAttr
<
int
>
(
"bit_length"
);
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
float
input_scale
=
quantized_op
->
stmt
()
->
op_info
()
->
GetAttr
<
float
>
(
"input_scale"
);
float
max_range
=
dequant_op
->
stmt
()
->
op_info
()
->
GetAttr
<
float
>
(
"max_range"
);
float
whole_weight_scale
=
static_cast
<
float
>
(
range
*
range
)
/
max_range
/
range
;
// max_range = range * range / max(abs(weight))
// weight_scale = range * range / (range * range / max(abs(weight))) / range
// = max(abs(weight)) / range
// As: max_range = range * range / max(abs(weight))
// So: whole_weight_scale
// = range * range / (range * range / max(abs(weight))) / range
// = max(abs(weight)) / range
// set op desc
cpp
::
OpDesc
op_desc
=
*
quantized_op
->
stmt
()
->
op_info
();
...
...
@@ -153,7 +163,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
// Conv weight shape: Cout * Cin * kh * hw, the weight_scale_size should
// be Cout.
weight_scale_size
=
quantized_weight_t
->
dims
()[
0
];
}
else
if
(
quantized_op_type_
==
"mul"
)
{
}
else
if
(
quantized_op_type_
==
"mul"
||
quantized_op_type_
==
"matmul"
)
{
op_desc
.
SetInput
(
"X"
,
{
quantized_op_input
->
arg
()
->
name
});
op_desc
.
SetOutput
(
"Out"
,
{
dequant_op_out
->
arg
()
->
name
});
// Fc weight: Cin * Cout, the weight_scale_size should be Cout.
...
...
@@ -163,7 +173,6 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
weight_scale
.
push_back
(
whole_weight_scale
);
}
op_desc
.
SetAttr
(
"enable_int8"
,
true
);
op_desc
.
SetAttr
(
"input_scale"
,
input_scale
);
op_desc
.
SetAttr
(
"weight_scale"
,
weight_scale
);
// change the weight from the float type to int8 type.
...
...
@@ -209,6 +218,7 @@ void ChannelWiseDequantOpFuser::BuildPattern() {
->
assert_is_op_output
(
quantized_op_type_
)
->
assert_is_op_input
(
dequant_op_type
,
"X"
)
->
AsIntermediate
();
// The scale var_node of input activation is deleted in DeleteQuantOpFuser
auto
*
dequant_op_channel_scale
=
VarNode
(
"dequant_op_channel_scale"
)
->
assert_is_op_input
(
dequant_op_type
)
->
AsIntermediate
();
...
...
@@ -237,11 +247,9 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
auto
*
dequant_op
=
matched
.
at
(
"dequant_op"
);
auto
*
dequant_op_out
=
matched
.
at
(
"dequant_op_out"
);
// obtain input
_scale and weight_scale
// obtain input
weight_scale from fake_dequant op
auto
*
scope
=
quantized_op
->
stmt
()
->
op
()
->
scope
();
auto
&
valid_places
=
quantized_op
->
stmt
()
->
op
()
->
valid_places
();
float
input_scale
=
quantized_op
->
stmt
()
->
op_info
()
->
GetAttr
<
float
>
(
"input_scale"
);
std
::
vector
<
float
>
weight_scale
;
std
::
vector
<
int
>
quant_bits
=
...
...
@@ -258,11 +266,15 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
// set op desc
cpp
::
OpDesc
op_desc
=
*
quantized_op
->
stmt
()
->
op_info
();
op_desc
.
SetInput
(
"Input"
,
{
quantized_op_input
->
arg
()
->
name
});
op_desc
.
SetOutput
(
"Output"
,
{
dequant_op_out
->
arg
()
->
name
});
if
(
quantized_op_type_
==
"conv2d"
||
quantized_op_type_
==
"depthwise_conv2d"
)
{
op_desc
.
SetInput
(
"Input"
,
{
quantized_op_input
->
arg
()
->
name
});
op_desc
.
SetOutput
(
"Output"
,
{
dequant_op_out
->
arg
()
->
name
});
}
else
if
(
quantized_op_type_
==
"mul"
||
quantized_op_type_
==
"matmul"
)
{
op_desc
.
SetInput
(
"X"
,
{
quantized_op_input
->
arg
()
->
name
});
op_desc
.
SetOutput
(
"Out"
,
{
dequant_op_out
->
arg
()
->
name
});
}
op_desc
.
SetAttr
(
"enable_int8"
,
true
);
op_desc
.
SetAttr
(
"input_scale"
,
input_scale
);
op_desc
.
SetAttr
(
"weight_scale"
,
weight_scale
);
// change the weight from the float type to int8 type.
...
...
@@ -297,167 +309,65 @@ cpp::OpDesc ChannelWiseDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
void
DeleteQuantDequantOpFuser
::
BuildPattern
()
{
std
::
string
quant_dequant_op_type
=
"fake_quantize_dequantize_moving_average_abs_max"
;
if
(
quantized_op_type_
==
"pool2d"
||
quantized_op_type_
==
"softmax"
)
{
auto
*
input_scale_node
=
VarNode
(
"input_scale_node"
)
->
assert_is_op_input
(
quant_dequant_op_type
,
"InScale"
);
auto
*
input_act_node
=
VarNode
(
"input_act_node"
)
->
assert_is_op_input
(
quant_dequant_op_type
,
"X"
);
auto
*
quant_dequant_node
=
OpNode
(
"quant_dequant_node"
,
quant_dequant_op_type
)
->
assert_is_op
(
quant_dequant_op_type
);
auto
*
output_scale_node
=
VarNode
(
"output_scale_node"
)
->
assert_is_op_output
(
quant_dequant_op_type
,
"OutScale"
);
auto
*
output_act_node
=
VarNode
(
"output_act_node"
)
->
assert_is_op_output
(
quant_dequant_op_type
,
"Out"
);
auto
*
quantized_node
=
OpNode
(
"quantized_node"
,
quantized_op_type_
)
->
assert_is_op
(
quantized_op_type_
);
quant_dequant_node
->
LinksFrom
({
input_scale_node
,
input_act_node
});
output_scale_node
->
LinksFrom
({
quant_dequant_node
});
output_act_node
->
LinksFrom
({
quant_dequant_node
});
quantized_node
->
LinksFrom
({
output_act_node
});
}
else
if
(
quantized_op_type_
==
"elementwise_add"
)
{
auto
*
input_scale_left_node
=
VarNode
(
"input_scale_left_node"
)
->
assert_is_op_input
(
quant_dequant_op_type
,
"InScale"
);
auto
*
input_act_left_node
=
VarNode
(
"input_act_left_node"
)
->
assert_is_op_input
(
quant_dequant_op_type
,
"X"
);
auto
*
quant_dequant_left_node
=
OpNode
(
"quant_dequant_left_node"
,
quant_dequant_op_type
)
->
assert_is_op
(
quant_dequant_op_type
);
auto
*
output_scale_left_node
=
VarNode
(
"output_scale_left_node"
)
->
assert_is_op_output
(
quant_dequant_op_type
,
"OutScale"
);
auto
*
output_act_left_node
=
VarNode
(
"output_act_left_node"
)
->
assert_is_op_output
(
quant_dequant_op_type
,
"Out"
)
->
assert_is_op_input
(
quantized_op_type_
,
"X"
);
quant_dequant_left_node
->
LinksFrom
(
{
input_scale_left_node
,
input_act_left_node
});
output_scale_left_node
->
LinksFrom
({
quant_dequant_left_node
});
output_act_left_node
->
LinksFrom
({
quant_dequant_left_node
});
auto
*
input_scale_right_node
=
VarNode
(
"input_scale_right_node"
)
->
assert_is_op_input
(
quant_dequant_op_type
,
"InScale"
);
auto
*
input_act_right_node
=
VarNode
(
"input_act_right_node"
)
->
assert_is_op_input
(
quant_dequant_op_type
,
"X"
);
auto
*
quant_dequant_right_node
=
OpNode
(
"quant_dequant_right_node"
,
quant_dequant_op_type
)
->
assert_is_op
(
quant_dequant_op_type
);
auto
*
output_scale_right_node
=
VarNode
(
"output_scale_right_node"
)
->
assert_is_op_output
(
quant_dequant_op_type
,
"OutScale"
);
auto
*
output_act_right_node
=
VarNode
(
"output_act_right_node"
)
->
assert_is_op_output
(
quant_dequant_op_type
,
"Out"
)
->
assert_is_op_input
(
quantized_op_type_
,
"Y"
);
quant_dequant_right_node
->
LinksFrom
(
{
input_scale_right_node
,
input_act_right_node
});
output_scale_right_node
->
LinksFrom
({
quant_dequant_right_node
});
output_act_right_node
->
LinksFrom
({
quant_dequant_right_node
});
auto
*
quantized_node
=
OpNode
(
"quantized_node"
,
quantized_op_type_
)
->
assert_is_op
(
quantized_op_type_
);
quantized_node
->
LinksFrom
({
output_act_left_node
,
output_act_right_node
});
}
else
{
LOG
(
FATAL
)
<<
"No support quantized_op_type:"
<<
quantized_op_type_
;
}
VLOG
(
4
)
<<
"DeleteQuantDequantOpFuser BuildPattern op_type:"
<<
quantized_op_type_
;
auto
*
input_scale_node
=
VarNode
(
"input_scale_node"
)
->
assert_is_op_input
(
quant_dequant_op_type
,
"InScale"
);
auto
*
input_act_node
=
VarNode
(
"input_act_node"
)
->
assert_is_op_input
(
quant_dequant_op_type
,
"X"
);
auto
*
quant_dequant_node
=
OpNode
(
"quant_dequant_node"
,
quant_dequant_op_type
)
->
assert_is_op
(
quant_dequant_op_type
);
auto
*
output_scale_node
=
VarNode
(
"output_scale_node"
)
->
assert_is_op_output
(
quant_dequant_op_type
,
"OutScale"
);
auto
*
output_act_node
=
VarNode
(
"output_act_node"
)
->
assert_is_op_output
(
quant_dequant_op_type
,
"Out"
);
quant_dequant_node
->
LinksFrom
({
input_scale_node
,
input_act_node
});
output_scale_node
->
LinksFrom
({
quant_dequant_node
});
output_act_node
->
LinksFrom
({
quant_dequant_node
});
}
void
DeleteQuantDequantOpFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
if
(
quantized_op_type_
==
"pool2d"
||
quantized_op_type_
==
"softmax"
)
{
auto
*
input_scale_node
=
matched
.
at
(
"input_scale_node"
);
auto
*
input_act_node
=
matched
.
at
(
"input_act_node"
);
auto
*
quant_dequant_node
=
matched
.
at
(
"quant_dequant_node"
);
auto
*
output_scale_node
=
matched
.
at
(
"output_scale_node"
);
auto
*
output_act_node
=
matched
.
at
(
"output_act_node"
);
auto
*
quantized_node
=
matched
.
at
(
"quantized_node"
);
// obtain values, save values and relink node
int
bit_length
=
quant_dequant_node
->
stmt
()
->
op_info
()
->
GetAttr
<
int
>
(
"bit_length"
);
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
auto
*
scope
=
quant_dequant_node
->
stmt
()
->
op
()
->
scope
();
auto
*
scale_tensor
=
scope
->
FindVar
(
output_scale_node
->
arg
()
->
name
)
->
GetMutable
<
lite
::
Tensor
>
();
float
scale_value
=
scale_tensor
->
data
<
float
>
()[
0
]
/
range
;
auto
*
op_desc
=
quantized_node
->
stmt
()
->
mutable_op_info
();
op_desc
->
SetAttr
<
int
>
(
"bit_length"
,
bit_length
);
op_desc
->
SetAttr
<
float
>
(
"input_scale"
,
scale_value
);
op_desc
->
SetInput
(
"X"
,
{
input_act_node
->
arg
()
->
name
});
IR_NODE_LINK_TO
(
input_act_node
,
quantized_node
)
auto
update_op_desc
=
*
quantized_node
->
stmt
()
->
mutable_op_info
();
quantized_node
->
stmt
()
->
ResetOp
(
update_op_desc
,
graph
->
valid_places
());
// delete nodes and edges
std
::
unordered_set
<
const
Node
*>
nodes2rm
=
{
input_scale_node
,
quant_dequant_node
,
output_scale_node
,
output_act_node
};
GraphSafeRemoveNodes
(
graph
,
nodes2rm
);
}
else
if
(
quantized_op_type_
==
"elementwise_add"
)
{
auto
*
input_scale_left_node
=
matched
.
at
(
"input_scale_left_node"
);
auto
*
input_act_left_node
=
matched
.
at
(
"input_act_left_node"
);
auto
*
quant_dequant_left_node
=
matched
.
at
(
"quant_dequant_left_node"
);
auto
*
output_scale_left_node
=
matched
.
at
(
"output_scale_left_node"
);
auto
*
output_act_left_node
=
matched
.
at
(
"output_act_left_node"
);
auto
*
input_scale_right_node
=
matched
.
at
(
"input_scale_right_node"
);
auto
*
input_act_right_node
=
matched
.
at
(
"input_act_right_node"
);
auto
*
quant_dequant_right_node
=
matched
.
at
(
"quant_dequant_right_node"
);
auto
*
output_scale_right_node
=
matched
.
at
(
"output_scale_right_node"
);
auto
*
output_act_right_node
=
matched
.
at
(
"output_act_right_node"
);
auto
*
quantized_node
=
matched
.
at
(
"quantized_node"
);
// obtain values, save values and relink node
int
bit_length
=
quant_dequant_left_node
->
stmt
()
->
op_info
()
->
GetAttr
<
int
>
(
"bit_length"
);
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
auto
*
scope
=
quant_dequant_left_node
->
stmt
()
->
op
()
->
scope
();
auto
*
left_scale_tensor
=
scope
->
FindVar
(
output_scale_left_node
->
arg
()
->
name
)
->
GetMutable
<
lite
::
Tensor
>
();
float
left_scale_value
=
left_scale_tensor
->
data
<
float
>
()[
0
]
/
range
;
auto
*
right_scale_tensor
=
scope
->
FindVar
(
output_scale_right_node
->
arg
()
->
name
)
->
GetMutable
<
lite
::
Tensor
>
();
float
right_scale_value
=
right_scale_tensor
->
data
<
float
>
()[
0
]
/
range
;
auto
*
op_desc
=
quantized_node
->
stmt
()
->
mutable_op_info
();
op_desc
->
SetAttr
<
int
>
(
"bit_length"
,
bit_length
);
op_desc
->
SetAttr
<
float
>
(
"x_input_scale"
,
left_scale_value
);
op_desc
->
SetAttr
<
float
>
(
"y_input_scale"
,
right_scale_value
);
op_desc
->
SetInput
(
"X"
,
{
input_act_left_node
->
arg
()
->
name
});
op_desc
->
SetInput
(
"Y"
,
{
input_act_right_node
->
arg
()
->
name
});
IR_NODE_LINK_TO
(
input_act_left_node
,
quantized_node
)
IR_NODE_LINK_TO
(
input_act_right_node
,
quantized_node
)
auto
update_op_desc
=
*
quantized_node
->
stmt
()
->
mutable_op_info
();
quantized_node
->
stmt
()
->
ResetOp
(
update_op_desc
,
graph
->
valid_places
());
// delete nodes and edges
std
::
unordered_set
<
const
Node
*>
nodes2rm
=
{
input_scale_left_node
,
quant_dequant_left_node
,
output_scale_left_node
,
output_act_left_node
,
input_scale_right_node
,
quant_dequant_right_node
,
output_scale_right_node
,
output_act_right_node
};
GraphSafeRemoveNodes
(
graph
,
nodes2rm
);
}
else
{
LOG
(
FATAL
)
<<
"No support quantized_op_type:"
<<
quantized_op_type_
;
auto
*
input_scale_node
=
matched
.
at
(
"input_scale_node"
);
auto
*
input_act_node
=
matched
.
at
(
"input_act_node"
);
auto
*
quant_dequant_node
=
matched
.
at
(
"quant_dequant_node"
);
auto
*
output_scale_node
=
matched
.
at
(
"output_scale_node"
);
auto
*
output_act_node
=
matched
.
at
(
"output_act_node"
);
auto
input_act_name
=
input_act_node
->
arg
()
->
name
;
auto
output_act_name
=
output_act_node
->
arg
()
->
name
;
// Get scale value from scale var node
int
bit_length
=
quant_dequant_node
->
stmt
()
->
op_info
()
->
GetAttr
<
int
>
(
"bit_length"
);
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
auto
*
scope
=
quant_dequant_node
->
stmt
()
->
op
()
->
scope
();
auto
*
scale_tensor
=
scope
->
FindVar
(
output_scale_node
->
arg
()
->
name
)
->
GetMutable
<
lite
::
Tensor
>
();
float
scale_value
=
scale_tensor
->
data
<
float
>
()[
0
]
/
range
;
auto
quantized_nodes
=
output_act_node
->
outlinks
;
for
(
auto
*
quantized_node
:
quantized_nodes
)
{
// Save quantization info in op_info attr
auto
op_info
=
*
quantized_node
->
stmt
()
->
op_info
();
std
::
string
argname
;
int
index
;
op_info
.
GetInputArgname
(
output_act_name
,
&
argname
);
op_info
.
GetInputIndex
(
output_act_name
,
&
index
);
op_info
.
SetAttr
<
float
>
(
argname
+
std
::
to_string
(
index
)
+
"_input_scale"
,
scale_value
);
op_info
.
SetAttr
<
float
>
(
"input_scale"
,
scale_value
);
// Save it for now
op_info
.
SetAttr
<
int
>
(
"bit_length"
,
bit_length
);
op_info
.
UpdateAllInputs
(
output_act_name
,
input_act_name
);
quantized_node
->
stmt
()
->
ResetOp
(
op_info
,
graph
->
valid_places
());
IR_NODE_LINK_TO
(
input_act_node
,
quantized_node
);
}
// delete nodes and edges
std
::
unordered_set
<
const
Node
*>
nodes2rm
=
{
input_scale_node
,
quant_dequant_node
,
output_scale_node
,
output_act_node
};
GraphSafeRemoveNodes
(
graph
,
nodes2rm
);
}
cpp
::
OpDesc
DeleteQuantDequantOpFuser
::
GenOpDesc
(
const
key2nodes_t
&
matched
)
{
...
...
lite/core/mir/fusion/quant_dequant_op_fuser.h
浏览文件 @
51e9898d
...
...
@@ -87,24 +87,16 @@ class ChannelWiseDequantOpFuser : public FuseBase {
};
/* The pattern like "fake_quantize_dequantize_moving_average_abs_max +
* pooled/elementwise_add" can be deteted by this fuser. The fuser
* extract the input_scale form fake_quant_dequant_op and save into
* the quantized_op. Besides, the fuser delete fake_quant_dequant_op in
* the graph.
* quantized_op" can be deteted by this fuser. The fuser modifies the input
* scale for the quantized_op and deletes the fake_quant_dequant_op.
*/
class
DeleteQuantDequantOpFuser
:
public
FuseBase
{
public:
explicit
DeleteQuantDequantOpFuser
(
const
std
::
string
&
quantized_op_type
)
:
quantized_op_type_
(
quantized_op_type
)
{}
void
BuildPattern
()
override
;
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
override
;
private:
cpp
::
OpDesc
GenOpDesc
(
const
key2nodes_t
&
matched
)
override
;
private:
std
::
string
quantized_op_type_
{};
};
}
// namespace fusion
...
...
lite/core/op_lite.h
浏览文件 @
51e9898d
...
...
@@ -225,6 +225,32 @@ class OpInfo : public cpp::OpDesc {
return
false
;
}
// For the input variable name, find the index of the corresponding
// input argname
bool
GetInputIndex
(
const
std
::
string
&
value_name
,
int
*
out
)
const
{
for
(
auto
&
item
:
inputs_
)
{
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
if
(
it
!=
item
.
second
.
end
())
{
*
out
=
it
-
item
.
second
.
begin
();
return
true
;
}
}
return
false
;
}
// For the output variable name, find the index of the corresponding
// output argname
bool
GetOutputIndex
(
const
std
::
string
&
value_name
,
int
*
out
)
const
{
for
(
auto
&
item
:
outputs_
)
{
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
if
(
it
!=
item
.
second
.
end
())
{
*
out
=
it
-
item
.
second
.
begin
();
return
true
;
}
}
return
false
;
}
void
UpdateAllInputs
(
const
std
::
string
&
from
,
const
std
::
string
&
to
)
{
for
(
auto
&
item
:
inputs_
)
{
for
(
auto
&
var
:
item
.
second
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录