Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
4c1ec41d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4c1ec41d
编写于
3月 29, 2019
作者:
Z
Zhen Wang
提交者:
GitHub
3月 29, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #16531 from wanghaoshuang/quan_ck
[slim] Fix checkpoint of quantization strategy.
上级
e18ab78f
d41b623a
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
76 addition
and
32 deletion
+76
-32
python/paddle/fluid/contrib/slim/graph/graph_wrapper.py
python/paddle/fluid/contrib/slim/graph/graph_wrapper.py
+11
-2
python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py
.../fluid/contrib/slim/quantization/quantization_strategy.py
+65
-30
未找到文件。
python/paddle/fluid/contrib/slim/graph/graph_wrapper.py
浏览文件 @
4c1ec41d
...
...
@@ -204,6 +204,10 @@ class GraphWrapper(object):
"""
super
(
GraphWrapper
,
self
).
__init__
()
self
.
program
=
Program
()
if
program
is
None
else
program
self
.
persistables
=
{}
for
var
in
self
.
program
.
list_vars
():
if
var
.
persistable
:
self
.
persistables
[
var
.
name
]
=
var
self
.
compiled_graph
=
None
self
.
in_nodes
=
OrderedDict
(
in_nodes
)
self
.
out_nodes
=
OrderedDict
(
out_nodes
)
...
...
@@ -467,7 +471,12 @@ class GraphWrapper(object):
path(str): The path to save the persistables.
exe(framework.Executor): The executor used to save the persistables.
"""
io
.
save_persistables
(
exe
.
exe
,
path
,
main_program
=
self
.
program
)
# update persistables from program
for
var
in
self
.
program
.
list_vars
():
if
var
.
persistable
and
var
.
name
not
in
self
.
persistables
:
self
.
persistables
[
var
.
name
]
=
var
io
.
save_vars
(
exe
.
exe
,
path
,
vars
=
self
.
persistables
.
values
())
def
load_persistables
(
self
,
path
,
exe
):
"""
...
...
@@ -481,7 +490,7 @@ class GraphWrapper(object):
return
os
.
path
.
exists
(
os
.
path
.
join
(
path
,
var
.
name
))
io
.
load_vars
(
exe
.
exe
,
path
,
main_program
=
self
.
program
,
predicate
=
if_exist
)
exe
.
exe
,
path
,
vars
=
self
.
persistables
.
values
()
,
predicate
=
if_exist
)
def
update_param_shape
(
self
,
scope
):
"""
...
...
python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py
浏览文件 @
4c1ec41d
...
...
@@ -20,7 +20,7 @@ from .... import io
from
....
import
core
from
....compiler
import
CompiledProgram
from
....compiler
import
BuildStrategy
from
....framework
import
IrGraph
from
....framework
import
IrGraph
,
Variable
,
Program
from
..core.strategy
import
Strategy
from
.quantization_pass
import
*
...
...
@@ -88,41 +88,76 @@ class QuantizationStrategy(Strategy):
self
.
save_out_nodes
=
save_out_nodes
self
.
save_in_nodes
=
save_in_nodes
def
on_compression_begin
(
self
,
context
):
"""
Restore graph when the compressoin task is inited from checkpoint.
"""
# It is inited from checkpoint and has missed start epoch.
if
context
.
epoch_id
!=
0
and
context
.
epoch_id
>
self
.
start_epoch
:
_logger
.
info
(
"Restore quantization task from checkpoint"
)
self
.
_modify_graph_for_quantization
(
context
)
_logger
.
info
(
"Finish restoring quantization task from checkpoint"
)
def
_modify_graph_for_quantization
(
self
,
context
):
"""
Insert fake_quantize_op and fake_dequantize_op before trainging and testing.
"""
train_ir_graph
=
IrGraph
(
core
.
Graph
(
context
.
optimize_graph
.
program
.
clone
().
desc
),
for_test
=
False
)
test_ir_graph
=
IrGraph
(
core
.
Graph
(
context
.
eval_graph
.
program
.
clone
().
desc
),
for_test
=
True
)
transform_pass
=
QuantizationTransformPass
(
scope
=
context
.
scope
,
place
=
context
.
place
,
weight_bits
=
self
.
weight_bits
,
activation_bits
=
self
.
activation_bits
,
activation_quantize_type
=
self
.
activation_quantize_type
,
weight_quantize_type
=
self
.
weight_quantize_type
)
transform_pass
.
apply
(
train_ir_graph
)
transform_pass
.
apply
(
test_ir_graph
)
# Put persistables created by transform_pass into context.optimize_graph.persistables
# for saving checkpoint.
program_persistables
=
set
()
for
var
in
context
.
optimize_graph
.
program
.
list_vars
():
if
var
.
persistable
:
program_persistables
.
add
(
var
.
name
)
program
=
Program
()
for
var_node
in
train_ir_graph
.
all_persistable_nodes
():
if
var_node
.
name
()
not
in
program_persistables
:
var_desc
=
var_node
.
var
()
var
=
program
.
global_block
().
create_var
(
name
=
var_node
.
name
(),
shape
=
var_desc
.
shape
(),
dtype
=
var_desc
.
dtype
(),
type
=
var_desc
.
type
(),
lod_level
=
var_desc
.
lod_level
())
context
.
optimize_graph
.
persistables
[
var
.
name
]
=
var
build_strategy
=
BuildStrategy
()
build_strategy
.
enable_inplace
=
False
build_strategy
.
memory_optimize
=
False
# for quantization training
context
.
optimize_graph
.
compiled_graph
=
CompiledProgram
(
train_ir_graph
.
graph
).
with_data_parallel
(
loss_name
=
context
.
optimize_graph
.
out_nodes
[
'loss'
],
build_strategy
=
build_strategy
)
# for evaluation. And program compiled from ir graph must be with data parallel.
context
.
eval_graph
.
compiled_graph
=
CompiledProgram
(
test_ir_graph
.
graph
).
with_data_parallel
(
build_strategy
=
build_strategy
)
# for saving inference model after training
context
.
put
(
'quantization_test_ir_graph_backup'
,
test_ir_graph
)
def
on_epoch_begin
(
self
,
context
):
"""
Insert fake_quantize_op and fake_dequantize_op before trainging and testing.
"""
super
(
QuantizationStrategy
,
self
).
on_
compression
_begin
(
context
)
super
(
QuantizationStrategy
,
self
).
on_
epoch
_begin
(
context
)
if
self
.
start_epoch
==
context
.
epoch_id
:
_logger
.
info
(
'QuantizationStrategy::on_epoch_begin'
)
train_ir_graph
=
IrGraph
(
core
.
Graph
(
context
.
optimize_graph
.
program
.
desc
),
for_test
=
False
)
test_ir_graph
=
IrGraph
(
core
.
Graph
(
context
.
eval_graph
.
program
.
desc
),
for_test
=
True
)
transform_pass
=
QuantizationTransformPass
(
scope
=
context
.
scope
,
place
=
context
.
place
,
weight_bits
=
self
.
weight_bits
,
activation_bits
=
self
.
activation_bits
,
activation_quantize_type
=
self
.
activation_quantize_type
,
weight_quantize_type
=
self
.
weight_quantize_type
)
transform_pass
.
apply
(
train_ir_graph
)
transform_pass
.
apply
(
test_ir_graph
)
build_strategy
=
BuildStrategy
()
build_strategy
.
enable_inplace
=
False
build_strategy
.
memory_optimize
=
False
# for quantization training
context
.
optimize_graph
.
compiled_graph
=
CompiledProgram
(
train_ir_graph
.
graph
).
with_data_parallel
(
loss_name
=
context
.
optimize_graph
.
out_nodes
[
'loss'
],
build_strategy
=
build_strategy
)
# for evaluation. And program compiled from ir graph must be with data parallel.
context
.
eval_graph
.
compiled_graph
=
CompiledProgram
(
test_ir_graph
.
graph
).
with_data_parallel
(
build_strategy
=
build_strategy
)
# for saving inference model after training
context
.
put
(
'quantization_test_ir_graph_backup'
,
test_ir_graph
)
self
.
_modify_graph_for_quantization
(
context
)
_logger
.
info
(
'Finish QuantizationStrategy::on_epoch_begin'
)
def
on_epoch_end
(
self
,
context
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录