Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c797e64d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c797e64d
编写于
7月 12, 2022
作者:
J
joanna.wozna.intel
提交者:
GitHub
7月 12, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add pool avg to quantization and concat scales correction (#44186)
上级
015532b4
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
18 addition
and
30 deletion
+18
-30
paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc
...amework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc
+7
-0
paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc
...le/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc
+0
-18
python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
...luid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
+11
-12
未找到文件。
paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc
浏览文件 @
c797e64d
...
...
@@ -390,6 +390,13 @@ std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales(
}
else
if
(
out_iter
!=
var_quant_scales
->
end
())
{
(
*
var_quant_scales
)[
input_name
]
=
out_iter
->
second
;
}
}
else
if
(
op_name
==
"concat"
)
{
auto
out_iter
=
var_quant_scales
->
find
(
op_node
->
Op
()
->
Output
(
"Out"
)[
0
]);
if
(
out_iter
!=
var_quant_scales
->
end
())
{
std
::
vector
<
std
::
string
>
input_names
=
op_node
->
Op
()
->
Input
(
"X"
);
for
(
auto
input_name
:
input_names
)
(
*
var_quant_scales
)[
input_name
]
=
out_iter
->
second
;
}
}
else
if
(
op_name
==
"scale"
)
{
const
std
::
string
output_name
=
op_node
->
Op
()
->
Output
(
"Out"
)[
0
];
auto
out_iter
=
var_quant_scales
->
find
(
output_name
);
...
...
paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc
浏览文件 @
c797e64d
...
...
@@ -55,23 +55,6 @@ void QuantDequantMkldnnPass::MarkSkipQuantizedOps(
}
}
void
QuantDequantMkldnnPass
::
MarkSkipQuantizedPool2d
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
3
)
<<
"mark avg pool2d as skip quantized op"
;
for
(
auto
*
op_node
:
ir
::
TopologyVarientSort
(
*
graph
,
static_cast
<
ir
::
SortKind
>
(
0
)))
{
if
(
!
op_node
->
IsOp
())
continue
;
if
(
op_node
->
Name
()
==
"pool2d"
)
{
auto
*
op_desc
=
op_node
->
Op
();
auto
pool_type
=
BOOST_GET_CONST
(
std
::
string
,
op_desc
->
GetAttr
(
"pooling_type"
));
if
(
pool_type
==
"avg"
)
{
op_node
->
Op
()
->
SetAttr
(
"skip_quant"
,
1
);
}
}
}
}
void
QuantDequantMkldnnPass
::
CollectInfoFromFake
(
ir
::
Graph
*
graph
,
Scope
*
scope
,
...
...
@@ -548,7 +531,6 @@ void QuantDequantMkldnnPass::ApplyImpl(ir::Graph* graph) const {
auto
*
scope
=
param_scope
();
MarkSkipQuantizedOps
(
graph
,
skip_ops
);
MarkSkipQuantizedPool2d
(
graph
);
CollectInfoFromFake
(
graph
,
scope
,
fake_dequantize_types
,
&
weight_thresholds
);
CollectInputScalesFromFake
(
graph
,
scope
,
fake_quantize_types
,
&
var_quant_scales
);
...
...
python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
浏览文件 @
c797e64d
...
...
@@ -264,6 +264,14 @@ class Quant2Int8MkldnnPass(object):
elif
output_name
in
self
.
_var_quant_scales
:
self
.
_var_quant_scales
[
input_name
]
=
self
.
_var_quant_scales
[
output_name
]
elif
op
.
name
()
==
'concat'
:
output_name
=
op
.
output
(
"Out"
)[
0
]
if
output_name
in
self
.
_var_quant_scales
:
input_names
=
op
.
input
(
"X"
)
for
input_name
in
input_names
:
self
.
_var_quant_scales
[
input_name
]
=
self
.
_var_quant_scales
[
output_name
]
elif
op
.
name
()
in
self
.
_scale_ops
:
input_name
=
op
.
input
(
"X"
)[
0
]
output_name
=
op
.
output
(
"Out"
)[
0
]
...
...
@@ -595,13 +603,6 @@ class Quant2Int8MkldnnPass(object):
_compute_lstm_weight_scales
(
"WeightX"
,
"WeightH"
)
return
graph
def
_find_avg_pooling_ids
(
self
,
graph
):
for
op
in
graph
.
all_op_nodes
():
if
op
.
name
()
in
self
.
_pool_ops
:
if
op
.
op
().
attr
(
"pooling_type"
)
==
"avg"
:
self
.
_op_ids_to_skip
.
add
(
op
.
id
())
return
self
.
_op_ids_to_skip
def
_update_relu_output_scales
(
self
,
graph
):
def
_set_unsigned_scale
(
graph
,
ops
,
op_out_name
,
predicate
):
...
...
@@ -651,11 +652,9 @@ class Quant2Int8MkldnnPass(object):
'reshape_transpose_matmul_mkldnn_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'reshape_transpose_matmul_v2_mkldnn_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'cpu_quantize_placement_pass'
,
[
'quantize_enabled_op_types'
,
'quantize_excluded_op_ids'
],
[
self
.
_ops_to_quantize
,
self
.
_find_avg_pooling_ids
(
graph
)])
graph
=
self
.
_apply_pass
(
graph
,
'cpu_quantize_placement_pass'
,
[
'quantize_enabled_op_types'
],
[
self
.
_ops_to_quantize
])
graph
=
self
.
_apply_pass
(
graph
,
'cpu_quantize_pass'
,
[
'quant_var_scales'
,
'data_layout'
],
[
self
.
_var_quant_scales
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录