Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
51e9898d
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
332
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
51e9898d
编写于
4月 03, 2020
作者:
C
cc
提交者:
GitHub
4月 03, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Modify quant_dequant_fuse_pass to process quant_dequant_op, test=develop (#3341)
上级
add162dc
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
127 addition
and
204 deletion
+127
-204
lite/api/benchmark.cc
lite/api/benchmark.cc
+7
-13
lite/api/cxx_api.cc
lite/api/cxx_api.cc
+4
-1
lite/core/mir/fusion/quant_dequant_fuse_pass.cc
lite/core/mir/fusion/quant_dequant_fuse_pass.cc
+3
-5
lite/core/mir/fusion/quant_dequant_op_fuser.cc
lite/core/mir/fusion/quant_dequant_op_fuser.cc
+85
-175
lite/core/mir/fusion/quant_dequant_op_fuser.h
lite/core/mir/fusion/quant_dequant_op_fuser.h
+2
-10
lite/core/op_lite.h
lite/core/op_lite.h
+26
-0
未找到文件。
lite/api/benchmark.cc
浏览文件 @
51e9898d
...
@@ -44,7 +44,10 @@ DEFINE_string(input_shape,
...
@@ -44,7 +44,10 @@ DEFINE_string(input_shape,
"set input shapes according to the model, "
"set input shapes according to the model, "
"separated by colon and comma, "
"separated by colon and comma, "
"such as 1,3,244,244"
);
"such as 1,3,244,244"
);
DEFINE_string
(
input_img_path
,
""
,
"the path of input image"
);
DEFINE_string
(
input_img_path
,
""
,
"the path of input image, if not set "
"input_img_path, the input of model will be 1.0."
);
DEFINE_int32
(
warmup
,
0
,
"warmup times"
);
DEFINE_int32
(
warmup
,
0
,
"warmup times"
);
DEFINE_int32
(
repeats
,
1
,
"repeats times"
);
DEFINE_int32
(
repeats
,
1
,
"repeats times"
);
DEFINE_int32
(
power_mode
,
DEFINE_int32
(
power_mode
,
...
@@ -57,16 +60,11 @@ DEFINE_int32(power_mode,
...
@@ -57,16 +60,11 @@ DEFINE_int32(power_mode,
DEFINE_int32
(
threads
,
1
,
"threads num"
);
DEFINE_int32
(
threads
,
1
,
"threads num"
);
DEFINE_string
(
result_filename
,
DEFINE_string
(
result_filename
,
"result.txt"
,
"result.txt"
,
"save benchmark "
"save the inference time to the file."
);
"result to the file"
);
DEFINE_bool
(
run_model_optimize
,
DEFINE_bool
(
run_model_optimize
,
false
,
false
,
"if set true, apply model_optimize_tool to "
"if set true, apply model_optimize_tool to "
"model and use optimized model to test. "
);
"model and use optimized model to test. "
);
DEFINE_bool
(
is_quantized_model
,
false
,
"if set true, "
"test the performance of the quantized model. "
);
namespace
paddle
{
namespace
paddle
{
namespace
lite_api
{
namespace
lite_api
{
...
@@ -87,10 +85,6 @@ void OutputOptModel(const std::string& save_optimized_model_dir) {
...
@@ -87,10 +85,6 @@ void OutputOptModel(const std::string& save_optimized_model_dir) {
std
::
vector
<
Place
>
vaild_places
=
{
std
::
vector
<
Place
>
vaild_places
=
{
Place
{
TARGET
(
kARM
),
PRECISION
(
kFloat
)},
Place
{
TARGET
(
kARM
),
PRECISION
(
kFloat
)},
};
};
if
(
FLAGS_is_quantized_model
)
{
vaild_places
.
insert
(
vaild_places
.
begin
(),
Place
{
TARGET
(
kARM
),
PRECISION
(
kInt8
)});
}
config
.
set_valid_places
(
vaild_places
);
config
.
set_valid_places
(
vaild_places
);
auto
predictor
=
lite_api
::
CreatePaddlePredictor
(
config
);
auto
predictor
=
lite_api
::
CreatePaddlePredictor
(
config
);
...
@@ -181,8 +175,8 @@ void Run(const std::vector<int64_t>& input_shape,
...
@@ -181,8 +175,8 @@ void Run(const std::vector<int64_t>& input_shape,
int
main
(
int
argc
,
char
**
argv
)
{
int
main
(
int
argc
,
char
**
argv
)
{
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
true
);
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
true
);
if
(
FLAGS_model_dir
==
""
||
FLAGS_result_filename
==
""
)
{
if
(
FLAGS_model_dir
==
""
)
{
LOG
(
INFO
)
<<
"
p
lease run ./benchmark_bin --help to obtain usage."
;
LOG
(
INFO
)
<<
"
P
lease run ./benchmark_bin --help to obtain usage."
;
exit
(
0
);
exit
(
0
);
}
}
...
...
lite/api/cxx_api.cc
浏览文件 @
51e9898d
...
@@ -295,6 +295,8 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
...
@@ -295,6 +295,8 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
inner_places
.
emplace_back
(
inner_places
.
emplace_back
(
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
));
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
));
// Analysis whether the modle is quantized.
// For quantized model, add place(arm, int8) to inner_places
const
std
::
vector
<
std
::
string
>
quant_dequant_op
=
{
const
std
::
vector
<
std
::
string
>
quant_dequant_op
=
{
"fake_quantize_abs_max"
,
"fake_quantize_abs_max"
,
"fake_quantize_range_abs_max"
,
"fake_quantize_range_abs_max"
,
...
@@ -317,7 +319,8 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
...
@@ -317,7 +319,8 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
}
}
}
}
if
(
is_quantized_model
)
{
if
(
is_quantized_model
)
{
inner_places
.
emplace_back
(
Place
{
TARGET
(
kARM
),
PRECISION
(
kInt8
)});
inner_places
.
insert
(
inner_places
.
begin
(),
Place
{
TARGET
(
kARM
),
PRECISION
(
kInt8
)});
}
}
Program
program
(
desc
,
scope_
,
inner_places
);
Program
program
(
desc
,
scope_
,
inner_places
);
...
...
lite/core/mir/fusion/quant_dequant_fuse_pass.cc
浏览文件 @
51e9898d
...
@@ -44,11 +44,9 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
...
@@ -44,11 +44,9 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fuser
(
graph
.
get
());
fuser
(
graph
.
get
());
}
}
// delete quant_dequant_node
// process quant_dequant_node
for
(
auto
op_type
:
{
"pool2d"
,
"softmax"
,
"elementwise_add"
})
{
fusion
::
DeleteQuantDequantOpFuser
dqd_fuser
;
fusion
::
DeleteQuantDequantOpFuser
fuser
(
op_type
);
dqd_fuser
(
graph
.
get
());
fuser
(
graph
.
get
());
}
}
}
}
// namespace mir
}
// namespace mir
...
...
lite/core/mir/fusion/quant_dequant_op_fuser.cc
浏览文件 @
51e9898d
...
@@ -50,7 +50,7 @@ void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph,
...
@@ -50,7 +50,7 @@ void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph,
auto
*
output_scale_node
=
matched
.
at
(
"output_scale_node"
);
auto
*
output_scale_node
=
matched
.
at
(
"output_scale_node"
);
auto
*
output_act_node
=
matched
.
at
(
"output_act_node"
);
auto
*
output_act_node
=
matched
.
at
(
"output_act_node"
);
// obtain
values, save value
s and relink node
// obtain
scale, save attr
s and relink node
int
bit_length
=
quant_node
->
stmt
()
->
op_info
()
->
GetAttr
<
int
>
(
"bit_length"
);
int
bit_length
=
quant_node
->
stmt
()
->
op_info
()
->
GetAttr
<
int
>
(
"bit_length"
);
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
auto
*
scope
=
quant_node
->
stmt
()
->
op
()
->
scope
();
auto
*
scope
=
quant_node
->
stmt
()
->
op
()
->
scope
();
...
@@ -58,11 +58,22 @@ void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph,
...
@@ -58,11 +58,22 @@ void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph,
->
GetMutable
<
lite
::
Tensor
>
();
->
GetMutable
<
lite
::
Tensor
>
();
float
scale_value
=
scale_tensor
->
data
<
float
>
()[
0
]
/
range
;
float
scale_value
=
scale_tensor
->
data
<
float
>
()[
0
]
/
range
;
auto
in_act_name
=
input_act_node
->
arg
()
->
name
;
auto
out_act_name
=
output_act_node
->
arg
()
->
name
;
auto
outlinks
=
output_act_node
->
outlinks
;
auto
outlinks
=
output_act_node
->
outlinks
;
for
(
auto
*
quantized_node
:
outlinks
)
{
for
(
auto
*
quantized_node
:
outlinks
)
{
auto
*
op_desc
=
quantized_node
->
stmt
()
->
mutable_op_info
();
// save input scale in quantized op by input argname + index
op_desc
->
SetAttr
<
int
>
(
"bit_length"
,
bit_length
);
auto
op_desc
=
*
quantized_node
->
stmt
()
->
mutable_op_info
();
op_desc
->
SetAttr
<
float
>
(
"input_scale"
,
scale_value
);
std
::
string
argname
;
int
index
;
op_desc
.
GetInputArgname
(
out_act_name
,
&
argname
);
op_desc
.
GetInputIndex
(
out_act_name
,
&
index
);
op_desc
.
SetAttr
<
float
>
(
argname
+
std
::
to_string
(
index
)
+
"_input_scale"
,
scale_value
);
op_desc
.
SetAttr
<
float
>
(
"input_scale"
,
scale_value
);
// save it for now
op_desc
.
SetAttr
<
int
>
(
"bit_length"
,
bit_length
);
op_desc
.
UpdateAllInputs
(
out_act_name
,
in_act_name
);
quantized_node
->
stmt
()
->
ResetOp
(
op_desc
,
graph
->
valid_places
());
IR_NODE_LINK_TO
(
input_act_node
,
quantized_node
)
IR_NODE_LINK_TO
(
input_act_node
,
quantized_node
)
}
}
...
@@ -125,19 +136,18 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
...
@@ -125,19 +136,18 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
auto
*
dequant_op
=
matched
.
at
(
"dequant_op"
);
auto
*
dequant_op
=
matched
.
at
(
"dequant_op"
);
auto
*
dequant_op_out
=
matched
.
at
(
"dequant_op_out"
);
auto
*
dequant_op_out
=
matched
.
at
(
"dequant_op_out"
);
// obtain
input_scale and weight_scal
e
// obtain
weight_scale from max_rang
e
auto
*
scope
=
quantized_op
->
stmt
()
->
op
()
->
scope
();
auto
*
scope
=
quantized_op
->
stmt
()
->
op
()
->
scope
();
auto
&
valid_places
=
quantized_op
->
stmt
()
->
op
()
->
valid_places
();
auto
&
valid_places
=
quantized_op
->
stmt
()
->
op
()
->
valid_places
();
int
bit_length
=
quantized_op
->
stmt
()
->
op_info
()
->
GetAttr
<
int
>
(
"bit_length"
);
int
bit_length
=
quantized_op
->
stmt
()
->
op_info
()
->
GetAttr
<
int
>
(
"bit_length"
);
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
float
input_scale
=
quantized_op
->
stmt
()
->
op_info
()
->
GetAttr
<
float
>
(
"input_scale"
);
float
max_range
=
dequant_op
->
stmt
()
->
op_info
()
->
GetAttr
<
float
>
(
"max_range"
);
float
max_range
=
dequant_op
->
stmt
()
->
op_info
()
->
GetAttr
<
float
>
(
"max_range"
);
float
whole_weight_scale
=
float
whole_weight_scale
=
static_cast
<
float
>
(
range
*
range
)
/
max_range
/
range
;
static_cast
<
float
>
(
range
*
range
)
/
max_range
/
range
;
// max_range = range * range / max(abs(weight))
// As: max_range = range * range / max(abs(weight))
// weight_scale = range * range / (range * range / max(abs(weight))) / range
// So: whole_weight_scale
// = max(abs(weight)) / range
// = range * range / (range * range / max(abs(weight))) / range
// = max(abs(weight)) / range
// set op desc
// set op desc
cpp
::
OpDesc
op_desc
=
*
quantized_op
->
stmt
()
->
op_info
();
cpp
::
OpDesc
op_desc
=
*
quantized_op
->
stmt
()
->
op_info
();
...
@@ -153,7 +163,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
...
@@ -153,7 +163,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
// Conv weight shape: Cout * Cin * kh * hw, the weight_scale_size should
// Conv weight shape: Cout * Cin * kh * hw, the weight_scale_size should
// be Cout.
// be Cout.
weight_scale_size
=
quantized_weight_t
->
dims
()[
0
];
weight_scale_size
=
quantized_weight_t
->
dims
()[
0
];
}
else
if
(
quantized_op_type_
==
"mul"
)
{
}
else
if
(
quantized_op_type_
==
"mul"
||
quantized_op_type_
==
"matmul"
)
{
op_desc
.
SetInput
(
"X"
,
{
quantized_op_input
->
arg
()
->
name
});
op_desc
.
SetInput
(
"X"
,
{
quantized_op_input
->
arg
()
->
name
});
op_desc
.
SetOutput
(
"Out"
,
{
dequant_op_out
->
arg
()
->
name
});
op_desc
.
SetOutput
(
"Out"
,
{
dequant_op_out
->
arg
()
->
name
});
// Fc weight: Cin * Cout, the weight_scale_size should be Cout.
// Fc weight: Cin * Cout, the weight_scale_size should be Cout.
...
@@ -163,7 +173,6 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
...
@@ -163,7 +173,6 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
weight_scale
.
push_back
(
whole_weight_scale
);
weight_scale
.
push_back
(
whole_weight_scale
);
}
}
op_desc
.
SetAttr
(
"enable_int8"
,
true
);
op_desc
.
SetAttr
(
"enable_int8"
,
true
);
op_desc
.
SetAttr
(
"input_scale"
,
input_scale
);
op_desc
.
SetAttr
(
"weight_scale"
,
weight_scale
);
op_desc
.
SetAttr
(
"weight_scale"
,
weight_scale
);
// change the weight from the float type to int8 type.
// change the weight from the float type to int8 type.
...
@@ -209,6 +218,7 @@ void ChannelWiseDequantOpFuser::BuildPattern() {
...
@@ -209,6 +218,7 @@ void ChannelWiseDequantOpFuser::BuildPattern() {
->
assert_is_op_output
(
quantized_op_type_
)
->
assert_is_op_output
(
quantized_op_type_
)
->
assert_is_op_input
(
dequant_op_type
,
"X"
)
->
assert_is_op_input
(
dequant_op_type
,
"X"
)
->
AsIntermediate
();
->
AsIntermediate
();
// The scale var_node of input activation is deleted in DeleteQuantOpFuser
auto
*
dequant_op_channel_scale
=
VarNode
(
"dequant_op_channel_scale"
)
auto
*
dequant_op_channel_scale
=
VarNode
(
"dequant_op_channel_scale"
)
->
assert_is_op_input
(
dequant_op_type
)
->
assert_is_op_input
(
dequant_op_type
)
->
AsIntermediate
();
->
AsIntermediate
();
...
@@ -237,11 +247,9 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
...
@@ -237,11 +247,9 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
auto
*
dequant_op
=
matched
.
at
(
"dequant_op"
);
auto
*
dequant_op
=
matched
.
at
(
"dequant_op"
);
auto
*
dequant_op_out
=
matched
.
at
(
"dequant_op_out"
);
auto
*
dequant_op_out
=
matched
.
at
(
"dequant_op_out"
);
// obtain input
_scale and weight_scale
// obtain input
weight_scale from fake_dequant op
auto
*
scope
=
quantized_op
->
stmt
()
->
op
()
->
scope
();
auto
*
scope
=
quantized_op
->
stmt
()
->
op
()
->
scope
();
auto
&
valid_places
=
quantized_op
->
stmt
()
->
op
()
->
valid_places
();
auto
&
valid_places
=
quantized_op
->
stmt
()
->
op
()
->
valid_places
();
float
input_scale
=
quantized_op
->
stmt
()
->
op_info
()
->
GetAttr
<
float
>
(
"input_scale"
);
std
::
vector
<
float
>
weight_scale
;
std
::
vector
<
float
>
weight_scale
;
std
::
vector
<
int
>
quant_bits
=
std
::
vector
<
int
>
quant_bits
=
...
@@ -258,11 +266,15 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
...
@@ -258,11 +266,15 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
// set op desc
// set op desc
cpp
::
OpDesc
op_desc
=
*
quantized_op
->
stmt
()
->
op_info
();
cpp
::
OpDesc
op_desc
=
*
quantized_op
->
stmt
()
->
op_info
();
op_desc
.
SetInput
(
"Input"
,
{
quantized_op_input
->
arg
()
->
name
});
if
(
quantized_op_type_
==
"conv2d"
||
op_desc
.
SetOutput
(
"Output"
,
{
dequant_op_out
->
arg
()
->
name
});
quantized_op_type_
==
"depthwise_conv2d"
)
{
op_desc
.
SetInput
(
"Input"
,
{
quantized_op_input
->
arg
()
->
name
});
op_desc
.
SetOutput
(
"Output"
,
{
dequant_op_out
->
arg
()
->
name
});
}
else
if
(
quantized_op_type_
==
"mul"
||
quantized_op_type_
==
"matmul"
)
{
op_desc
.
SetInput
(
"X"
,
{
quantized_op_input
->
arg
()
->
name
});
op_desc
.
SetOutput
(
"Out"
,
{
dequant_op_out
->
arg
()
->
name
});
}
op_desc
.
SetAttr
(
"enable_int8"
,
true
);
op_desc
.
SetAttr
(
"enable_int8"
,
true
);
op_desc
.
SetAttr
(
"input_scale"
,
input_scale
);
op_desc
.
SetAttr
(
"weight_scale"
,
weight_scale
);
op_desc
.
SetAttr
(
"weight_scale"
,
weight_scale
);
// change the weight from the float type to int8 type.
// change the weight from the float type to int8 type.
...
@@ -297,167 +309,65 @@ cpp::OpDesc ChannelWiseDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
...
@@ -297,167 +309,65 @@ cpp::OpDesc ChannelWiseDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
void
DeleteQuantDequantOpFuser
::
BuildPattern
()
{
void
DeleteQuantDequantOpFuser
::
BuildPattern
()
{
std
::
string
quant_dequant_op_type
=
std
::
string
quant_dequant_op_type
=
"fake_quantize_dequantize_moving_average_abs_max"
;
"fake_quantize_dequantize_moving_average_abs_max"
;
if
(
quantized_op_type_
==
"pool2d"
||
quantized_op_type_
==
"softmax"
)
{
auto
*
input_scale_node
=
auto
*
input_scale_node
=
VarNode
(
"input_scale_node"
)
VarNode
(
"input_scale_node"
)
->
assert_is_op_input
(
quant_dequant_op_type
,
"InScale"
);
->
assert_is_op_input
(
quant_dequant_op_type
,
"InScale"
);
auto
*
input_act_node
=
auto
*
input_act_node
=
VarNode
(
"input_act_node"
)
VarNode
(
"input_act_node"
)
->
assert_is_op_input
(
quant_dequant_op_type
,
"X"
);
->
assert_is_op_input
(
quant_dequant_op_type
,
"X"
);
auto
*
quant_dequant_node
=
OpNode
(
"quant_dequant_node"
,
quant_dequant_op_type
)
auto
*
quant_dequant_node
=
->
assert_is_op
(
quant_dequant_op_type
);
OpNode
(
"quant_dequant_node"
,
quant_dequant_op_type
)
auto
*
output_scale_node
=
->
assert_is_op
(
quant_dequant_op_type
);
VarNode
(
"output_scale_node"
)
auto
*
output_scale_node
=
->
assert_is_op_output
(
quant_dequant_op_type
,
"OutScale"
);
VarNode
(
"output_scale_node"
)
auto
*
output_act_node
=
->
assert_is_op_output
(
quant_dequant_op_type
,
"OutScale"
);
VarNode
(
"output_act_node"
)
auto
*
output_act_node
=
->
assert_is_op_output
(
quant_dequant_op_type
,
"Out"
);
VarNode
(
"output_act_node"
)
->
assert_is_op_output
(
quant_dequant_op_type
,
"Out"
);
quant_dequant_node
->
LinksFrom
({
input_scale_node
,
input_act_node
});
auto
*
quantized_node
=
OpNode
(
"quantized_node"
,
quantized_op_type_
)
output_scale_node
->
LinksFrom
({
quant_dequant_node
});
->
assert_is_op
(
quantized_op_type_
);
output_act_node
->
LinksFrom
({
quant_dequant_node
});
quant_dequant_node
->
LinksFrom
({
input_scale_node
,
input_act_node
});
output_scale_node
->
LinksFrom
({
quant_dequant_node
});
output_act_node
->
LinksFrom
({
quant_dequant_node
});
quantized_node
->
LinksFrom
({
output_act_node
});
}
else
if
(
quantized_op_type_
==
"elementwise_add"
)
{
auto
*
input_scale_left_node
=
VarNode
(
"input_scale_left_node"
)
->
assert_is_op_input
(
quant_dequant_op_type
,
"InScale"
);
auto
*
input_act_left_node
=
VarNode
(
"input_act_left_node"
)
->
assert_is_op_input
(
quant_dequant_op_type
,
"X"
);
auto
*
quant_dequant_left_node
=
OpNode
(
"quant_dequant_left_node"
,
quant_dequant_op_type
)
->
assert_is_op
(
quant_dequant_op_type
);
auto
*
output_scale_left_node
=
VarNode
(
"output_scale_left_node"
)
->
assert_is_op_output
(
quant_dequant_op_type
,
"OutScale"
);
auto
*
output_act_left_node
=
VarNode
(
"output_act_left_node"
)
->
assert_is_op_output
(
quant_dequant_op_type
,
"Out"
)
->
assert_is_op_input
(
quantized_op_type_
,
"X"
);
quant_dequant_left_node
->
LinksFrom
(
{
input_scale_left_node
,
input_act_left_node
});
output_scale_left_node
->
LinksFrom
({
quant_dequant_left_node
});
output_act_left_node
->
LinksFrom
({
quant_dequant_left_node
});
auto
*
input_scale_right_node
=
VarNode
(
"input_scale_right_node"
)
->
assert_is_op_input
(
quant_dequant_op_type
,
"InScale"
);
auto
*
input_act_right_node
=
VarNode
(
"input_act_right_node"
)
->
assert_is_op_input
(
quant_dequant_op_type
,
"X"
);
auto
*
quant_dequant_right_node
=
OpNode
(
"quant_dequant_right_node"
,
quant_dequant_op_type
)
->
assert_is_op
(
quant_dequant_op_type
);
auto
*
output_scale_right_node
=
VarNode
(
"output_scale_right_node"
)
->
assert_is_op_output
(
quant_dequant_op_type
,
"OutScale"
);
auto
*
output_act_right_node
=
VarNode
(
"output_act_right_node"
)
->
assert_is_op_output
(
quant_dequant_op_type
,
"Out"
)
->
assert_is_op_input
(
quantized_op_type_
,
"Y"
);
quant_dequant_right_node
->
LinksFrom
(
{
input_scale_right_node
,
input_act_right_node
});
output_scale_right_node
->
LinksFrom
({
quant_dequant_right_node
});
output_act_right_node
->
LinksFrom
({
quant_dequant_right_node
});
auto
*
quantized_node
=
OpNode
(
"quantized_node"
,
quantized_op_type_
)
->
assert_is_op
(
quantized_op_type_
);
quantized_node
->
LinksFrom
({
output_act_left_node
,
output_act_right_node
});
}
else
{
LOG
(
FATAL
)
<<
"No support quantized_op_type:"
<<
quantized_op_type_
;
}
VLOG
(
4
)
<<
"DeleteQuantDequantOpFuser BuildPattern op_type:"
<<
quantized_op_type_
;
}
}
void
DeleteQuantDequantOpFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
void
DeleteQuantDequantOpFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
const
key2nodes_t
&
matched
)
{
if
(
quantized_op_type_
==
"pool2d"
||
quantized_op_type_
==
"softmax"
)
{
auto
*
input_scale_node
=
matched
.
at
(
"input_scale_node"
);
auto
*
input_scale_node
=
matched
.
at
(
"input_scale_node"
);
auto
*
input_act_node
=
matched
.
at
(
"input_act_node"
);
auto
*
input_act_node
=
matched
.
at
(
"input_act_node"
);
auto
*
quant_dequant_node
=
matched
.
at
(
"quant_dequant_node"
);
auto
*
quant_dequant_node
=
matched
.
at
(
"quant_dequant_node"
);
auto
*
output_scale_node
=
matched
.
at
(
"output_scale_node"
);
auto
*
output_scale_node
=
matched
.
at
(
"output_scale_node"
);
auto
*
output_act_node
=
matched
.
at
(
"output_act_node"
);
auto
*
output_act_node
=
matched
.
at
(
"output_act_node"
);
auto
input_act_name
=
input_act_node
->
arg
()
->
name
;
auto
*
quantized_node
=
matched
.
at
(
"quantized_node"
);
auto
output_act_name
=
output_act_node
->
arg
()
->
name
;
// obtain values, save values and relink node
// Get scale value from scale var node
int
bit_length
=
int
bit_length
=
quant_dequant_node
->
stmt
()
->
op_info
()
->
GetAttr
<
int
>
(
"bit_length"
);
quant_dequant_node
->
stmt
()
->
op_info
()
->
GetAttr
<
int
>
(
"bit_length"
);
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
auto
*
scope
=
quant_dequant_node
->
stmt
()
->
op
()
->
scope
();
auto
*
scope
=
quant_dequant_node
->
stmt
()
->
op
()
->
scope
();
auto
*
scale_tensor
=
scope
->
FindVar
(
output_scale_node
->
arg
()
->
name
)
auto
*
scale_tensor
=
scope
->
FindVar
(
output_scale_node
->
arg
()
->
name
)
->
GetMutable
<
lite
::
Tensor
>
();
->
GetMutable
<
lite
::
Tensor
>
();
float
scale_value
=
scale_tensor
->
data
<
float
>
()[
0
]
/
range
;
float
scale_value
=
scale_tensor
->
data
<
float
>
()[
0
]
/
range
;
auto
*
op_desc
=
quantized_node
->
stmt
()
->
mutable_op_info
();
auto
quantized_nodes
=
output_act_node
->
outlinks
;
op_desc
->
SetAttr
<
int
>
(
"bit_length"
,
bit_length
);
for
(
auto
*
quantized_node
:
quantized_nodes
)
{
op_desc
->
SetAttr
<
float
>
(
"input_scale"
,
scale_value
);
// Save quantization info in op_info attr
op_desc
->
SetInput
(
"X"
,
{
input_act_node
->
arg
()
->
name
});
auto
op_info
=
*
quantized_node
->
stmt
()
->
op_info
();
IR_NODE_LINK_TO
(
input_act_node
,
quantized_node
)
std
::
string
argname
;
auto
update_op_desc
=
*
quantized_node
->
stmt
()
->
mutable_op_info
();
int
index
;
quantized_node
->
stmt
()
->
ResetOp
(
update_op_desc
,
graph
->
valid_places
());
op_info
.
GetInputArgname
(
output_act_name
,
&
argname
);
op_info
.
GetInputIndex
(
output_act_name
,
&
index
);
// delete nodes and edges
op_info
.
SetAttr
<
float
>
(
argname
+
std
::
to_string
(
index
)
+
"_input_scale"
,
std
::
unordered_set
<
const
Node
*>
nodes2rm
=
{
input_scale_node
,
scale_value
);
quant_dequant_node
,
op_info
.
SetAttr
<
float
>
(
"input_scale"
,
scale_value
);
// Save it for now
output_scale_node
,
op_info
.
SetAttr
<
int
>
(
"bit_length"
,
bit_length
);
output_act_node
};
GraphSafeRemoveNodes
(
graph
,
nodes2rm
);
op_info
.
UpdateAllInputs
(
output_act_name
,
input_act_name
);
}
else
if
(
quantized_op_type_
==
"elementwise_add"
)
{
quantized_node
->
stmt
()
->
ResetOp
(
op_info
,
graph
->
valid_places
());
auto
*
input_scale_left_node
=
matched
.
at
(
"input_scale_left_node"
);
IR_NODE_LINK_TO
(
input_act_node
,
quantized_node
);
auto
*
input_act_left_node
=
matched
.
at
(
"input_act_left_node"
);
auto
*
quant_dequant_left_node
=
matched
.
at
(
"quant_dequant_left_node"
);
auto
*
output_scale_left_node
=
matched
.
at
(
"output_scale_left_node"
);
auto
*
output_act_left_node
=
matched
.
at
(
"output_act_left_node"
);
auto
*
input_scale_right_node
=
matched
.
at
(
"input_scale_right_node"
);
auto
*
input_act_right_node
=
matched
.
at
(
"input_act_right_node"
);
auto
*
quant_dequant_right_node
=
matched
.
at
(
"quant_dequant_right_node"
);
auto
*
output_scale_right_node
=
matched
.
at
(
"output_scale_right_node"
);
auto
*
output_act_right_node
=
matched
.
at
(
"output_act_right_node"
);
auto
*
quantized_node
=
matched
.
at
(
"quantized_node"
);
// obtain values, save values and relink node
int
bit_length
=
quant_dequant_left_node
->
stmt
()
->
op_info
()
->
GetAttr
<
int
>
(
"bit_length"
);
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
auto
*
scope
=
quant_dequant_left_node
->
stmt
()
->
op
()
->
scope
();
auto
*
left_scale_tensor
=
scope
->
FindVar
(
output_scale_left_node
->
arg
()
->
name
)
->
GetMutable
<
lite
::
Tensor
>
();
float
left_scale_value
=
left_scale_tensor
->
data
<
float
>
()[
0
]
/
range
;
auto
*
right_scale_tensor
=
scope
->
FindVar
(
output_scale_right_node
->
arg
()
->
name
)
->
GetMutable
<
lite
::
Tensor
>
();
float
right_scale_value
=
right_scale_tensor
->
data
<
float
>
()[
0
]
/
range
;
auto
*
op_desc
=
quantized_node
->
stmt
()
->
mutable_op_info
();
op_desc
->
SetAttr
<
int
>
(
"bit_length"
,
bit_length
);
op_desc
->
SetAttr
<
float
>
(
"x_input_scale"
,
left_scale_value
);
op_desc
->
SetAttr
<
float
>
(
"y_input_scale"
,
right_scale_value
);
op_desc
->
SetInput
(
"X"
,
{
input_act_left_node
->
arg
()
->
name
});
op_desc
->
SetInput
(
"Y"
,
{
input_act_right_node
->
arg
()
->
name
});
IR_NODE_LINK_TO
(
input_act_left_node
,
quantized_node
)
IR_NODE_LINK_TO
(
input_act_right_node
,
quantized_node
)
auto
update_op_desc
=
*
quantized_node
->
stmt
()
->
mutable_op_info
();
quantized_node
->
stmt
()
->
ResetOp
(
update_op_desc
,
graph
->
valid_places
());
// delete nodes and edges
std
::
unordered_set
<
const
Node
*>
nodes2rm
=
{
input_scale_left_node
,
quant_dequant_left_node
,
output_scale_left_node
,
output_act_left_node
,
input_scale_right_node
,
quant_dequant_right_node
,
output_scale_right_node
,
output_act_right_node
};
GraphSafeRemoveNodes
(
graph
,
nodes2rm
);
}
else
{
LOG
(
FATAL
)
<<
"No support quantized_op_type:"
<<
quantized_op_type_
;
}
}
// delete nodes and edges
std
::
unordered_set
<
const
Node
*>
nodes2rm
=
{
input_scale_node
,
quant_dequant_node
,
output_scale_node
,
output_act_node
};
GraphSafeRemoveNodes
(
graph
,
nodes2rm
);
}
}
cpp
::
OpDesc
DeleteQuantDequantOpFuser
::
GenOpDesc
(
const
key2nodes_t
&
matched
)
{
cpp
::
OpDesc
DeleteQuantDequantOpFuser
::
GenOpDesc
(
const
key2nodes_t
&
matched
)
{
...
...
lite/core/mir/fusion/quant_dequant_op_fuser.h
浏览文件 @
51e9898d
...
@@ -87,24 +87,16 @@ class ChannelWiseDequantOpFuser : public FuseBase {
...
@@ -87,24 +87,16 @@ class ChannelWiseDequantOpFuser : public FuseBase {
};
};
/* The pattern like "fake_quantize_dequantize_moving_average_abs_max +
/* The pattern like "fake_quantize_dequantize_moving_average_abs_max +
* pooled/elementwise_add" can be deteted by this fuser. The fuser
* quantized_op" can be deteted by this fuser. The fuser modifies the input
* extract the input_scale form fake_quant_dequant_op and save into
* scale for the quantized_op and deletes the fake_quant_dequant_op.
* the quantized_op. Besides, the fuser delete fake_quant_dequant_op in
* the graph.
*/
*/
class
DeleteQuantDequantOpFuser
:
public
FuseBase
{
class
DeleteQuantDequantOpFuser
:
public
FuseBase
{
public:
public:
explicit
DeleteQuantDequantOpFuser
(
const
std
::
string
&
quantized_op_type
)
:
quantized_op_type_
(
quantized_op_type
)
{}
void
BuildPattern
()
override
;
void
BuildPattern
()
override
;
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
override
;
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
override
;
private:
private:
cpp
::
OpDesc
GenOpDesc
(
const
key2nodes_t
&
matched
)
override
;
cpp
::
OpDesc
GenOpDesc
(
const
key2nodes_t
&
matched
)
override
;
private:
std
::
string
quantized_op_type_
{};
};
};
}
// namespace fusion
}
// namespace fusion
...
...
lite/core/op_lite.h
浏览文件 @
51e9898d
...
@@ -225,6 +225,32 @@ class OpInfo : public cpp::OpDesc {
...
@@ -225,6 +225,32 @@ class OpInfo : public cpp::OpDesc {
return
false
;
return
false
;
}
}
// For the input variable name, find the index of the corresponding
// input argname
bool
GetInputIndex
(
const
std
::
string
&
value_name
,
int
*
out
)
const
{
for
(
auto
&
item
:
inputs_
)
{
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
if
(
it
!=
item
.
second
.
end
())
{
*
out
=
it
-
item
.
second
.
begin
();
return
true
;
}
}
return
false
;
}
// For the output variable name, find the index of the corresponding
// output argname
bool
GetOutputIndex
(
const
std
::
string
&
value_name
,
int
*
out
)
const
{
for
(
auto
&
item
:
outputs_
)
{
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
if
(
it
!=
item
.
second
.
end
())
{
*
out
=
it
-
item
.
second
.
begin
();
return
true
;
}
}
return
false
;
}
void
UpdateAllInputs
(
const
std
::
string
&
from
,
const
std
::
string
&
to
)
{
void
UpdateAllInputs
(
const
std
::
string
&
from
,
const
std
::
string
&
to
)
{
for
(
auto
&
item
:
inputs_
)
{
for
(
auto
&
item
:
inputs_
)
{
for
(
auto
&
var
:
item
.
second
)
{
for
(
auto
&
var
:
item
.
second
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录