Commit 356f5ee2 (unverified)
Authored May 06, 2020 by joanna.wozna.intel
Committed via GitHub on May 06, 2020

[Refactoring] Unify op-dequant squashes (#24277)

Parent: ac9a7eee
4 changed files with 58 additions and 190 deletions (+58, -190):

- paddle/fluid/framework/ir/graph_pattern_detector.cc (+12, -44)
- paddle/fluid/framework/ir/graph_pattern_detector.h (+7, -36)
- paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc (+32, -93)
- paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h (+7, -17)
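The commit unifies three nearly identical squashes (ConvDequantSquash, FcDequantSquash, MatmulDequantSquash) into a single OpDequantSquash driven by one OpDequant pattern. As a rough illustration of what that squash does, the standalone sketch below uses toy types only (ToyOp and SquashOpDequant are invented for this sketch, not Paddle's ir::Graph/ir::Node API): when a dequantize node is the sole consumer, the producing op is marked force_fp32_output and its output is redirected to the dequantize output, mirroring the rewrite the pass performs on the real graph.

// Standalone illustration only -- not Paddle's IR API.
#include <iostream>
#include <map>
#include <string>

struct ToyOp {
  std::string type;                            // "conv2d", "fc", "matmul", "dequantize"
  std::map<std::string, bool> attrs;           // e.g. force_fp32_output
  std::map<std::string, std::string> outputs;  // output slot name -> variable name
};

// Mirrors the squash: drop the dequantize that follows conv2d/fc/matmul and
// tell the producer to emit fp32 directly into the dequantize output variable.
bool SquashOpDequant(ToyOp* any_op, const ToyOp& dequant_op) {
  if (any_op->type != "conv2d" && any_op->type != "fc" && any_op->type != "matmul")
    return false;
  std::string output_name = "Out";
  if (any_op->type == "conv2d") {
    // residual fusion does not support forcing fp32 output
    if (any_op->attrs.count("fuse_residual_connection") &&
        any_op->attrs.at("fuse_residual_connection"))
      return false;
    output_name = "Output";
  }
  any_op->attrs["force_fp32_output"] = true;
  any_op->outputs[output_name] = dequant_op.outputs.at("Output");
  return true;
}

int main() {
  ToyOp conv{"conv2d", {}, {{"Output", "conv_out_int8"}}};
  ToyOp dequant{"dequantize", {}, {{"Output", "conv_out_fp32"}}};
  if (SquashOpDequant(&conv, dequant))
    std::cout << "conv2d now writes fp32 directly to " << conv.outputs["Output"] << "\n";
}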
paddle/fluid/framework/ir/graph_pattern_detector.cc

@@ -1529,39 +1529,24 @@ PDNode *patterns::RequantOp::operator()() {
   return any_op;
 }
 
-PDNode *patterns::ConvDequant::operator()() {
-  // Create Operators
-  auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
-  auto dequant_op =
-      pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
-
-  auto conv_out = pattern->NewNode(conv_out_repr())
-                      ->assert_is_op_output("conv2d", "Output");
-  auto dequant_out = pattern->NewNode(dequant_out_repr())
-                         ->AsOutput()
-                         ->assert_is_op_output("dequantize", "Output");
-
-  conv_op->LinksTo({conv_out});
-  dequant_op->LinksFrom({conv_out}).LinksTo({dequant_out});
-
-  return dequant_out;
-}
-
-PDNode *patterns::FcDequant::operator()() {
-  // Create Operators
-  auto fc_op = pattern->NewNode(fc_op_repr())->assert_is_op("fc");
+PDNode *patterns::OpDequant::operator()() {
+  auto any_op = pattern->NewNode(any_op_repr())
+                    ->assert_is_op()
+                    ->assert_more([&](Node *node) {
+                      return (node->Op()->Type() == "matmul" ||
+                              node->Op()->Type() == "conv2d" ||
+                              node->Op()->Type() == "fc");
+                    });
+  auto dequant_in = pattern->NewNode(dequant_in_repr())
+                        ->assert_is_op_input("dequantize", "Input");
   auto dequant_op =
       pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
-
-  auto fc_out = pattern->NewNode(fc_out_repr())->assert_is_op_output("fc", "Out");
   auto dequant_out = pattern->NewNode(dequant_out_repr())
                          ->AsOutput()
                          ->assert_is_op_output("dequantize", "Output");
 
-  fc_op->LinksTo({fc_out});
-  dequant_op->LinksFrom({fc_out}).LinksTo({dequant_out});
+  any_op->LinksTo({dequant_in});
+  dequant_op->LinksFrom({dequant_in}).LinksTo({dequant_out});
 
   return dequant_out;
 }

@@ -1584,23 +1569,6 @@ PDNode *patterns::DequantScale::operator()() {
   return scale_out;
 }
 
-PDNode *patterns::MatmulDequant::operator()() {
-  auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
-  auto dequant_op =
-      pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
-
-  auto matmul_out = pattern->NewNode(matmul_out_repr())
-                        ->AsOutput()
-                        ->assert_is_op_output("matmul", "Out");
-  auto dequant_out = pattern->NewNode(dequant_out_repr())
-                         ->AsOutput()
-                         ->assert_is_op_output("dequantize", "Output");
-
-  matmul_op->LinksTo({matmul_out});
-  dequant_op->LinksFrom({matmul_out}).LinksTo({dequant_out});
-
-  return dequant_out;
-}
-
 PDNode *patterns::ScaleMatmul::operator()() {
   auto scale_in = pattern->NewNode(scale_in_repr())->AsInput()
 ...
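One detail worth noting in the new pattern above: instead of three per-op patterns, OpDequant accepts any producer and narrows it with an assert_more predicate over the op type. The snippet below is a standalone analogue of that predicate (a plain std::function invented for this sketch, not the PDPattern/PDNode API), just to show the matching rule in isolation.

// Standalone analogue of the assert_more check in patterns::OpDequant above.
// Illustrative only; it does not use Paddle's graph pattern detector.
#include <functional>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // One predicate covers every producer the unified pattern accepts.
  std::function<bool(const std::string&)> is_dequant_producer =
      [](const std::string& op_type) {
        return op_type == "matmul" || op_type == "conv2d" || op_type == "fc";
      };

  const std::vector<std::string> ops = {"conv2d", "fc", "matmul", "relu"};
  for (const auto& op : ops)
    std::cout << op << (is_dequant_producer(op) ? " matches" : " is skipped")
              << "\n";
}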
paddle/fluid/framework/ir/graph_pattern_detector.h

@@ -929,33 +929,18 @@ struct RequantOp : public PatternBase {
   PATTERN_DECL_NODE(requant_out);
 };
 
-// Conv + Dequant
+// Op + Dequant
 // named nodes:
-// conv_op, conv_out
+// any_op, dequant_in
 // dequant_op, dequant_out
-struct ConvDequant : public PatternBase {
-  ConvDequant(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "conv_dequant") {}
-
-  PDNode* operator()();
-
-  PATTERN_DECL_NODE(conv_op);
-  PATTERN_DECL_NODE(conv_out);
-  PATTERN_DECL_NODE(dequant_op);
-  PATTERN_DECL_NODE(dequant_out);
-};
-
-// Fc + Dequant
-struct FcDequant : public PatternBase {
-  FcDequant(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "fc_dequant") {}
+struct OpDequant : public PatternBase {
+  OpDequant(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "op_dequant") {}
 
   PDNode* operator()();
 
-  PATTERN_DECL_NODE(fc_op);
-  PATTERN_DECL_NODE(fc_out);
+  PATTERN_DECL_NODE(any_op);
+  PATTERN_DECL_NODE(dequant_in);
   PATTERN_DECL_NODE(dequant_op);
   PATTERN_DECL_NODE(dequant_out);
 };

@@ -974,20 +959,6 @@ struct DequantScale : public PatternBase {
   PATTERN_DECL_NODE(scale_out);
 };
 
-// Matmul + Dequantize
-struct MatmulDequant : public PatternBase {
-  MatmulDequant(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "matmul_dequant") {}
-
-  PDNode* operator()();
-
-  PATTERN_DECL_NODE(matmul_op);
-  PATTERN_DECL_NODE(matmul_out);
-  PATTERN_DECL_NODE(dequant_op);
-  PATTERN_DECL_NODE(dequant_out);
-};
-
 // Scale + Matmul
 struct ScaleMatmul : public PatternBase {
   ScaleMatmul(PDPattern* pattern, const std::string& name_scope)
 ...
paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc

@@ -223,71 +223,44 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const {
                   found_requant_squash_count);
 }
 
-void CPUQuantizeSquashPass::ConvDequantSquash(Graph* graph) const {
-  GraphPatternDetector gpd;
-  patterns::ConvDequant conv_dequant_pattern{gpd.mutable_pattern(),
-                                             "conv_dequant"};
-  conv_dequant_pattern();
-
-  int found_conv_dequant_squash_count = 0;
-  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
-                     Graph* g) {
-    VLOG(4) << "squash conv-dequant ops pair";
-
-    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_dequant_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_dequant_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, conv_dequant_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, conv_dequant_pattern);
-
-    // if conv2d has one output
-    // and there is no fuse residual connection
-    // because residual fusion does not support force output with fp32
-    if (conv_out->outputs.size() == 1 &&
-        !(conv_op->Op()->GetAttrIfExists<bool>("fuse_residual_connection"))) {
-      conv_op->Op()->SetAttr("force_fp32_output", true);
-      conv_op->Op()->SetOutput("Output",
-                               std::vector<std::string>({dequant_out->Name()}));
-      IR_NODE_LINK_TO(conv_op, dequant_out);
-      GraphSafeRemoveNodes(graph, {conv_out, dequant_op});
-      found_conv_dequant_squash_count++;
-    }
-  };
-  gpd(graph, handler);
-  AddStatis(found_conv_dequant_squash_count);
-  PrettyLogDetail("---    squashed %d dequant with convs",
-                  found_conv_dequant_squash_count);
-}
-
-// squash fc with dequant
-void CPUQuantizeSquashPass::FcDequantSquash(Graph* graph) const {
+// squash dequant with previous op if that op has force_fp32_output attr
+// conv2d, fc, matmul
+void CPUQuantizeSquashPass::OpDequantSquash(Graph* graph) const {
   GraphPatternDetector gpd;
-  patterns::FcDequant fc_dequant_pattern{gpd.mutable_pattern(), "fc_dequant"};
-  fc_dequant_pattern();
+  patterns::OpDequant op_dequant_pattern{gpd.mutable_pattern(), "op_dequant"};
+  op_dequant_pattern();
 
-  int found_fc_dequant_squash_count = 0;
+  int found_op_dequant_squash_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    VLOG(4) << "squash fc-dequant ops pair";
+    VLOG(4) << "squash op-dequant ops pair";
 
-    GET_IR_NODE_FROM_SUBGRAPH(fc_op, fc_op, fc_dequant_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc_dequant_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, fc_dequant_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, fc_dequant_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(any_op, any_op, op_dequant_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(dequant_in, dequant_in, op_dequant_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, op_dequant_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, op_dequant_pattern);
 
-    // if fc has force_fp32_output attribute
-    if (fc_out->outputs.size() == 1) {
-      fc_op->Op()->SetAttr("force_fp32_output", true);
-      fc_op->Op()->SetOutput("Out",
-                             std::vector<std::string>({dequant_out->Name()}));
-      IR_NODE_LINK_TO(fc_op, dequant_out);
-      GraphSafeRemoveNodes(graph, {fc_out, dequant_op});
-      found_fc_dequant_squash_count++;
+    if (dequant_in->outputs.size() == 1) {
+      auto output_name = "Out";
+      if (any_op->Op()->Type() == "conv2d") {
+        // do not squash if fuse residual connection is true
+        // because residual fusion does not support force output with fp32
+        if (any_op->Op()->GetAttrIfExists<bool>("fuse_residual_connection"))
+          return;
+        output_name = "Output";
+      }
+      any_op->Op()->SetAttr("force_fp32_output", true);
+      any_op->Op()->SetOutput(output_name,
+                              std::vector<std::string>({dequant_out->Name()}));
+      IR_NODE_LINK_TO(any_op, dequant_out);
+      GraphSafeRemoveNodes(graph, {dequant_in, dequant_op});
+      found_op_dequant_squash_count++;
     }
   };
   gpd(graph, handler);
-  AddStatis(found_fc_dequant_squash_count);
-  PrettyLogDetail("---    squashed %d dequant with fcs",
-                  found_fc_dequant_squash_count);
+  AddStatis(found_op_dequant_squash_count);
+  PrettyLogDetail("---    squashed %d dequant with ops",
+                  found_op_dequant_squash_count);
 }
 
 void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
 ...

@@ -389,38 +362,6 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
                   found_dequant_scale_squash_count);
 }
 
-// squash dequant with dequant
-void CPUQuantizeSquashPass::MatmulDequantSquash(Graph* graph) const {
-  GraphPatternDetector gpd;
-  patterns::MatmulDequant matmul_dequant_pattern{gpd.mutable_pattern(),
-                                                 "matmul_dequant"};
-  matmul_dequant_pattern();
-
-  int found_matmul_dequant_squash_count = 0;
-  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
-                     Graph* g) {
-    VLOG(4) << "squash matmul-dequant ops pair";
-
-    GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_dequant_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_dequant_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, matmul_dequant_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, matmul_dequant_pattern);
-
-    if (matmul_out->outputs.size() == 1) {
-      matmul_op->Op()->SetAttr("force_fp32_output", true);
-      matmul_op->Op()->SetOutput(
-          "Out", std::vector<std::string>({dequant_out->Name()}));
-      IR_NODE_LINK_TO(matmul_op, dequant_out);
-      GraphSafeRemoveNodes(graph, {matmul_out, dequant_op});
-      found_matmul_dequant_squash_count++;
-    }
-  };
-  gpd(graph, handler);
-  AddStatis(found_matmul_dequant_squash_count);
-  PrettyLogDetail("---    squashed %d dequant with matmul",
-                  found_matmul_dequant_squash_count);
-}
-
 void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(graph,
 ...

@@ -433,11 +374,9 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
   DequantQuantSquash(graph, &nodes_keep_counter);
   OpRequantSquash(graph);
   RequantOpSquash(graph);
-  ConvDequantSquash(graph);
-  FcDequantSquash(graph);
+  OpDequantSquash(graph);
   MultipleQuantizeSquash(graph);
   DequantScaleSquash(graph);
-  MatmulDequantSquash(graph);
 }
 
 }  // namespace ir
 ...
paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h

@@ -61,29 +61,19 @@ class CPUQuantizeSquashPass : public FusePassBase {
   void RequantOpSquash(Graph* graph) const;
 
   /*
-   * Squash conv2d with dequant when dequant is the only op after conv2d
-   */
-  void ConvDequantSquash(Graph* graph) const;
-
-  /*
-   * Squash fc with dequant when dequant is the next op after fc
+   * Squash dequant if the previous operator has force_fp32_output attribute
    */
-  void FcDequantSquash(Graph* graph) const;
+  void OpDequantSquash(Graph* graph) const;
 
   /*
    * Squash quantize if several quatize ops have the same scale
    */
   void MultipleQuantizeSquash(Graph* graph) const;
 
   /*
    * Squash scale if dequantize is before scale
    */
   void DequantScaleSquash(Graph* graph) const;
 
-  /*
-   * Squash dequantize if it is after matmul
-   */
-  void MatmulDequantSquash(Graph* graph) const;
-
   const std::string name_scope_{"squash"};
 };
 ...