Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
2e3ec66b
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2e3ec66b
编写于
8月 27, 2019
作者:
J
joanna.wozna.intel
提交者:
Tao Luo
8月 27, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add conv dequant squash for int8 (#18905)
上级
482ce818
变更
5
显示空白变更内容
内联
并排
Showing
5 changed files
with
129 additions
and
0 deletions
+129
-0
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+18
-0
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+17
-0
paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
+33
-0
paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h
paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h
+5
-0
paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc
...id/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc
+56
-0
未找到文件。
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
2e3ec66b
...
@@ -1267,6 +1267,24 @@ PDNode *patterns::ConvRequant::operator()() {
...
@@ -1267,6 +1267,24 @@ PDNode *patterns::ConvRequant::operator()() {
return
requant_out
;
return
requant_out
;
}
}
// Describes the subgraph
//   conv2d --(Output)--> conv_out --> dequantize --(Output)--> dequant_out
// and returns the pattern node that stands for the dequantize output.
PDNode *patterns::ConvDequant::operator()() {
  // Operator nodes.
  auto *conv = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
  auto *dequant =
      pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");

  // Variable nodes: the conv2d result and the dequantize result. Only the
  // latter is marked as an output of the pattern.
  auto *conv_result = pattern->NewNode(conv_out_repr())
                          ->assert_is_op_output("conv2d", "Output");
  auto *dequant_result = pattern->NewNode(dequant_out_repr())
                             ->AsOutput()
                             ->assert_is_op_output("dequantize", "Output");

  // Wire the chain: conv -> conv_result -> dequant -> dequant_result.
  conv->LinksTo({conv_result});
  dequant->LinksFrom({conv_result});
  dequant->LinksTo({dequant_result});

  return dequant_result;
}
PDNode
*
patterns
::
PriorBox
::
operator
()()
{
PDNode
*
patterns
::
PriorBox
::
operator
()()
{
auto
prior_box_op
=
auto
prior_box_op
=
pattern
->
NewNode
(
prior_box_op_repr
())
->
assert_is_op
(
"prior_box"
);
pattern
->
NewNode
(
prior_box_op_repr
())
->
assert_is_op
(
"prior_box"
);
...
...
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
2e3ec66b
...
@@ -793,6 +793,23 @@ struct ConvRequant : public PatternBase {
...
@@ -793,6 +793,23 @@ struct ConvRequant : public PatternBase {
PATTERN_DECL_NODE
(
requant_out
);
PATTERN_DECL_NODE
(
requant_out
);
};
};
// Pattern: conv2d immediately followed by dequantize.
// Nodes declared for this pattern:
//   operators: conv_op, dequant_op
//   variables: conv_out, dequant_out
struct ConvDequant : public PatternBase {
  ConvDequant(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "conv_dequant") {}

  // Builds the pattern; defined in graph_pattern_detector.cc.
  PDNode* operator()();

  // Operator nodes.
  PATTERN_DECL_NODE(conv_op);
  PATTERN_DECL_NODE(dequant_op);
  // Variable nodes.
  PATTERN_DECL_NODE(conv_out);
  PATTERN_DECL_NODE(dequant_out);
};
// PriorBox operator
// PriorBox operator
// operator: prior_box_op
// operator: prior_box_op
// inputs: prior_box_input, prior_box_image
// inputs: prior_box_input, prior_box_image
...
...
paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
浏览文件 @
2e3ec66b
...
@@ -160,6 +160,38 @@ void CPUQuantizeSquashPass::ConvRequantSquash(Graph* graph) const {
...
@@ -160,6 +160,38 @@ void CPUQuantizeSquashPass::ConvRequantSquash(Graph* graph) const {
found_requant_squash_count
);
found_requant_squash_count
);
}
}
// Folds a dequantize op into the conv2d that feeds it. For every matched
// conv2d -> conv_out -> dequantize -> dequant_out chain in which conv_out has
// exactly one consumer, the conv2d gets "force_fp32_output" set to true and
// writes straight into dequant_out; conv_out and the dequantize op are then
// removed from the graph. The number of squashed pairs is reported through
// AddStatis and PrettyLogDetail.
void CPUQuantizeSquashPass::ConvDequantSquash(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::ConvDequant conv_dequant_pattern{gpd.mutable_pattern(),
                                             "conv_dequant"};
  conv_dequant_pattern();

  int squashed_pairs = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "squash conv-dequant ops pair";

    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_dequant_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_dequant_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, conv_dequant_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, conv_dequant_pattern);

    // Leave the pair untouched unless the conv2d output has exactly one
    // consumer (the matched dequantize).
    if (conv_out->outputs.size() != 1) {
      return;
    }

    // Redirect the conv2d so that it produces the dequantize output itself.
    const std::vector<std::string> new_output{dequant_out->Name()};
    conv_op->Op()->SetAttr("force_fp32_output", true);
    conv_op->Op()->SetOutput("Output", new_output);
    IR_NODE_LINK_TO(conv_op, dequant_out);

    // The old conv2d output and the dequantize op are no longer needed.
    GraphSafeRemoveNodes(graph, {conv_out, dequant_op});
    ++squashed_pairs;
  };
  gpd(graph, handler);

  AddStatis(squashed_pairs);
  PrettyLogDetail("--- squashed %d dequant with convs", squashed_pairs);
}
void
CPUQuantizeSquashPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
void
CPUQuantizeSquashPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE
(
graph
);
PADDLE_ENFORCE
(
graph
);
FusePassBase
::
Init
(
"cpu_quantize_squash_pass"
,
graph
);
FusePassBase
::
Init
(
"cpu_quantize_squash_pass"
,
graph
);
...
@@ -168,6 +200,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
...
@@ -168,6 +200,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
FindNodesToKeep
(
graph
,
&
nodes_keep_counter
);
FindNodesToKeep
(
graph
,
&
nodes_keep_counter
);
DequantQuantSquash
(
graph
,
&
nodes_keep_counter
);
DequantQuantSquash
(
graph
,
&
nodes_keep_counter
);
ConvRequantSquash
(
graph
);
ConvRequantSquash
(
graph
);
ConvDequantSquash
(
graph
);
}
}
}
// namespace ir
}
// namespace ir
...
...
paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h
浏览文件 @
2e3ec66b
...
@@ -55,6 +55,11 @@ class CPUQuantizeSquashPass : public FusePassBase {
...
@@ -55,6 +55,11 @@ class CPUQuantizeSquashPass : public FusePassBase {
*/
*/
void
ConvRequantSquash
(
Graph
*
graph
)
const
;
void
ConvRequantSquash
(
Graph
*
graph
)
const
;
/*
* Squash conv2d with dequantize when the dequantize op is the only consumer
* of the conv2d output. The conv2d is switched to "force_fp32_output" and
* writes directly to the dequantize output; the intermediate conv2d output
* variable and the dequantize op are removed from the graph.
*/
void
ConvDequantSquash
(
Graph
*
graph
)
const
;
const
std
::
string
name_scope_
{
"squash"
};
const
std
::
string
name_scope_
{
"squash"
};
};
};
...
...
paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc
浏览文件 @
2e3ec66b
...
@@ -161,6 +161,36 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out,
...
@@ -161,6 +161,36 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out,
return
prog
;
return
prog
;
}
}
// Builds a program in which the conv2d output feeds only a dequantize op:
//   a -> Conv1 -> b
//   b -> Dequant1 -> c
//   c -> Concat1 -> d
// scale_out is forwarded to Conv1 and scale to Dequant1 through SetOp.
ProgramDesc BuildConvDequantConcatProgramDesc(bool use_mkldnn, float scale_out,
                                              float scale) {
  ProgramDesc prog;

  // Declare every known variable in block 0 before adding the operators.
  auto block = prog.MutableBlock(0);
  for (auto& name : variable_names) {
    block->Var(name);
  }

  SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
  SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale);
  SetOp(&prog, "concat", "Concat1", {"c"}, {"d"}, use_mkldnn);

  return prog;
}
// Builds a program in which the conv2d output has two consumers:
//   a -> Conv1 -> b
//   b -> Dequant1 -> c
//   b -> Conv2 -> d
// scale_out is forwarded to Conv1 and scale to Dequant1 through SetOp.
ProgramDesc BuildConvDequantConvProgramDesc(bool use_mkldnn, float scale_out,
                                            float scale) {
  ProgramDesc prog;

  // Declare every known variable in block 0 before adding the operators.
  auto block = prog.MutableBlock(0);
  for (auto& name : variable_names) {
    block->Var(name);
  }

  SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
  SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale);
  SetOp(&prog, "conv2d", "Conv2", {"b"}, {"d"}, use_mkldnn);

  return prog;
}
void
InitTensorHolder
(
Scope
*
scope
,
const
paddle
::
platform
::
Place
&
place
,
void
InitTensorHolder
(
Scope
*
scope
,
const
paddle
::
platform
::
Place
&
place
,
const
char
*
var_name
)
{
const
char
*
var_name
)
{
auto
x
=
scope
->
Var
(
var_name
);
auto
x
=
scope
->
Var
(
var_name
);
...
@@ -217,6 +247,7 @@ void EqualScaleOutTest(const ProgramDesc& prog, const std::string& name,
...
@@ -217,6 +247,7 @@ void EqualScaleOutTest(const ProgramDesc& prog, const std::string& name,
void
CheckRequantScalesTest
(
const
ProgramDesc
&
prog
,
float
scale_in
,
void
CheckRequantScalesTest
(
const
ProgramDesc
&
prog
,
float
scale_in
,
float
scale_out
)
{
float
scale_out
)
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
PrepareGraph
(
&
graph
,
prog
);
PrepareGraph
(
&
graph
,
prog
);
RegisterPass
(
&
graph
);
RegisterPass
(
&
graph
);
...
@@ -238,6 +269,7 @@ TEST(CpuQuantizeSquashPass, equal_scales) {
...
@@ -238,6 +269,7 @@ TEST(CpuQuantizeSquashPass, equal_scales) {
auto
use_mkldnn
=
true
;
auto
use_mkldnn
=
true
;
// Remove 4 nodes: Dequant, Quant, e, f
// Remove 4 nodes: Dequant, Quant, e, f
auto
remove_nodes
=
4
;
auto
remove_nodes
=
4
;
CountNodeTest
(
CountNodeTest
(
BuildConvRequantProgramDesc
(
use_mkldnn
,
scale_out
,
scale
,
scale
),
BuildConvRequantProgramDesc
(
use_mkldnn
,
scale_out
,
scale
,
scale
),
remove_nodes
);
remove_nodes
);
...
@@ -253,6 +285,7 @@ TEST(CpuQuantizeSquashPass, unequal_scales) {
...
@@ -253,6 +285,7 @@ TEST(CpuQuantizeSquashPass, unequal_scales) {
auto
use_mkldnn
=
true
;
auto
use_mkldnn
=
true
;
// Remove 4 nodes: Dequant, Quant, e, d
// Remove 4 nodes: Dequant, Quant, e, d
auto
remove_nodes
=
4
;
auto
remove_nodes
=
4
;
CountNodeTest
(
CountNodeTest
(
BuildConvRequantProgramDesc
(
use_mkldnn
,
scale_out
,
scale1
,
scale2
),
BuildConvRequantProgramDesc
(
use_mkldnn
,
scale_out
,
scale1
,
scale2
),
remove_nodes
);
remove_nodes
);
...
@@ -280,6 +313,7 @@ TEST(CpuQuantizeSquashPass, branch_to_equal_unequal_and_fp32) {
...
@@ -280,6 +313,7 @@ TEST(CpuQuantizeSquashPass, branch_to_equal_unequal_and_fp32) {
// Remove 3 nodes: Quant1, c, Quant2,
// Remove 3 nodes: Quant1, c, Quant2,
// Insert 1 node: Requant
// Insert 1 node: Requant
auto
remove_nodes
=
2
;
auto
remove_nodes
=
2
;
CountNodeTest
(
BuildConvMultiOutputProgramDesc
(
use_mkldnn
,
scale_out
,
scale
,
CountNodeTest
(
BuildConvMultiOutputProgramDesc
(
use_mkldnn
,
scale_out
,
scale
,
scale
,
scale2
),
scale
,
scale2
),
remove_nodes
);
remove_nodes
);
...
@@ -322,6 +356,7 @@ TEST(CpuQuantizeSquashPass,
...
@@ -322,6 +356,7 @@ TEST(CpuQuantizeSquashPass,
// Remove 3 nodes: Dequant1, c, Quant
// Remove 3 nodes: Dequant1, c, Quant
// Insert 1 node: Requant
// Insert 1 node: Requant
auto
remove_nodes
=
2
;
auto
remove_nodes
=
2
;
CountNodeTest
(
CountNodeTest
(
BuildConcatDequantQuantProgramDesc
(
use_mkldnn
,
scale_out
,
scale
,
scale2
),
BuildConcatDequantQuantProgramDesc
(
use_mkldnn
,
scale_out
,
scale
,
scale2
),
remove_nodes
);
remove_nodes
);
...
@@ -345,6 +380,27 @@ TEST(CpuQuantizeSquashPass, more_than_one_conv_out_outputs) {
...
@@ -345,6 +380,27 @@ TEST(CpuQuantizeSquashPass, more_than_one_conv_out_outputs) {
remove_nodes
);
remove_nodes
);
}
}
// Input graph:    a -> Conv1 -> b -> Dequant1 -> c -> Concat1
// Expected graph: a -> Conv1 -> c -> Concat1
TEST(CpuQuantizeSquashPass, conv_dequant_only_one_output) {
  const float scale_out = 1.0f;
  const float scale = 1.2345f;
  const bool use_mkldnn = true;

  // Two nodes disappear: the Dequant1 op and the intermediate conv output b
  // (Conv1 is rewired to write directly into c).
  int remove_nodes = 2;

  CountNodeTest(
      BuildConvDequantConcatProgramDesc(use_mkldnn, scale_out, scale),
      remove_nodes);
}
// The output of Conv1 is read by both Dequant1 and Conv2, so the
// conv-dequant squash must not be applied.
TEST(CpuQuantizeSquashPass, conv_dequant_more_than_one_op_after_conv) {
  const float scale_out = 1.0f;
  const float scale = 1.2345f;
  const bool use_mkldnn = true;

  // The graph is expected to stay unchanged, i.e. no nodes are removed.
  int remove_nodes = 0;

  CountNodeTest(BuildConvDequantConvProgramDesc(use_mkldnn, scale_out, scale),
                remove_nodes);
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录