Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
50db9490
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 2 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
50db9490
编写于
6月 24, 2022
作者:
C
ceci3
提交者:
GitHub
6月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimization train config (#1187)
上级
8c6e3ab9
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
128 addition
and
67 deletion
+128
-67
paddleslim/auto_compression/auto_strategy.py
paddleslim/auto_compression/auto_strategy.py
+9
-6
paddleslim/auto_compression/compressor.py
paddleslim/auto_compression/compressor.py
+119
-61
未找到文件。
paddleslim/auto_compression/auto_strategy.py
浏览文件 @
50db9490
...
@@ -46,12 +46,13 @@ default_hpo_config = {
...
@@ -46,12 +46,13 @@ default_hpo_config = {
# default quant config, can be used by ptq&hpo and qat&distillation
# default quant config, can be used by ptq&hpo and qat&distillation
default_quant_config
=
{
default_quant_config
=
{
'quantize_op_types'
:
[
'conv2d'
,
'depthwise_conv2d'
,
'mul'
,
'matmul'
],
'quantize_op_types'
:
[
'conv2d'
,
'depthwise_conv2d'
,
'mul'
,
'matmul'
,
'matmul_v2'
],
'weight_bits'
:
8
,
'weight_bits'
:
8
,
'activation_bits'
:
8
,
'activation_bits'
:
8
,
"is_full_quantize"
:
False
,
"is_full_quantize"
:
False
,
"activation_quantize_type"
:
'
ran
ge_abs_max'
,
"activation_quantize_type"
:
'
moving_avera
ge_abs_max'
,
"weight_quantize_type"
:
'abs_max'
,
"weight_quantize_type"
:
'
channel_wise_
abs_max'
,
"not_quant_pattern"
:
[
"skip_quant"
],
"not_quant_pattern"
:
[
"skip_quant"
],
}
}
...
@@ -60,10 +61,12 @@ DefaultTrainConfig = {
...
@@ -60,10 +61,12 @@ DefaultTrainConfig = {
"epochs"
:
1
,
"epochs"
:
1
,
"eval_iter"
:
500
,
"eval_iter"
:
500
,
"learning_rate"
:
0.0001
,
"learning_rate"
:
0.0001
,
"optimizer
"
:
"Momentum"
,
"optimizer
_builder"
:
{
"optim_args
"
:
{
"optimizer
"
:
{
"weight_decay"
:
4.0e-05
"type"
:
"Momentum"
,
},
},
"weight_decay"
:
4.0e-05
}
}
}
EXPERIENCE_STRATEGY_WITHOUT_LOSS
=
[
EXPERIENCE_STRATEGY_WITHOUT_LOSS
=
[
...
...
paddleslim/auto_compression/compressor.py
浏览文件 @
50db9490
...
@@ -16,6 +16,7 @@ import logging
...
@@ -16,6 +16,7 @@ import logging
import
os
import
os
import
sys
import
sys
import
numpy
as
np
import
numpy
as
np
import
copy
import
inspect
import
inspect
import
shutil
import
shutil
from
time
import
gmtime
,
strftime
from
time
import
gmtime
,
strftime
...
@@ -28,7 +29,7 @@ from ..common import get_logger
...
@@ -28,7 +29,7 @@ from ..common import get_logger
from
..common.patterns
import
get_patterns
from
..common.patterns
import
get_patterns
from
..analysis
import
TableLatencyPredictor
from
..analysis
import
TableLatencyPredictor
from
.create_compressed_program
import
build_distill_program
,
build_quant_program
,
build_prune_program
,
remove_unused_var_nodes
from
.create_compressed_program
import
build_distill_program
,
build_quant_program
,
build_prune_program
,
remove_unused_var_nodes
from
.strategy_config
import
ProgramInfo
,
merge_config
from
.strategy_config
import
TrainConfig
,
ProgramInfo
,
merge_config
from
.auto_strategy
import
prepare_strategy
,
get_final_quant_config
,
create_strategy_config
,
create_train_config
from
.auto_strategy
import
prepare_strategy
,
get_final_quant_config
,
create_strategy_config
,
create_train_config
from
.utils.predict
import
with_variable_shape
from
.utils.predict
import
with_variable_shape
...
@@ -127,7 +128,6 @@ class AutoCompression:
...
@@ -127,7 +128,6 @@ class AutoCompression:
if
not
os
.
path
.
exists
(
self
.
final_dir
):
if
not
os
.
path
.
exists
(
self
.
final_dir
):
os
.
makedirs
(
self
.
final_dir
)
os
.
makedirs
(
self
.
final_dir
)
self
.
strategy_config
=
strategy_config
self
.
strategy_config
=
strategy_config
self
.
train_config
=
train_config
self
.
train_dataloader
=
train_dataloader
self
.
train_dataloader
=
train_dataloader
self
.
target_speedup
=
target_speedup
self
.
target_speedup
=
target_speedup
self
.
eval_function
=
eval_callback
self
.
eval_function
=
eval_callback
...
@@ -142,7 +142,7 @@ class AutoCompression:
...
@@ -142,7 +142,7 @@ class AutoCompression:
self
.
model_type
=
self
.
_get_model_type
(
self
.
_exe
,
model_dir
,
self
.
model_type
=
self
.
_get_model_type
(
self
.
_exe
,
model_dir
,
model_filename
,
params_filename
)
model_filename
,
params_filename
)
if
self
.
train_config
is
not
None
and
self
.
train_config
.
use_fleet
:
if
train_config
is
not
None
and
train_config
.
use_fleet
:
fleet
.
init
(
is_collective
=
True
)
fleet
.
init
(
is_collective
=
True
)
if
with_variable_shape
(
if
with_variable_shape
(
...
@@ -173,10 +173,48 @@ class AutoCompression:
...
@@ -173,10 +173,48 @@ class AutoCompression:
self
.
_strategy
,
self
.
_config
=
self
.
_prepare_strategy
(
self
.
_strategy
,
self
.
_config
=
self
.
_prepare_strategy
(
self
.
strategy_config
)
self
.
strategy_config
)
self
.
train_config
=
self
.
_get_final_train_config
(
train_config
,
self
.
_strategy
,
self
.
model_type
)
def
_get_final_train_config
(
self
,
train_config
,
strategy_config
,
model_type
):
# If train_config is None, set default train_config
# If train_config is None, set default train_config
if
self
.
train_config
is
None
:
if
train_config
is
None
:
self
.
train_config
=
create_train_config
(
self
.
strategy_config
,
train_config
=
create_train_config
(
strategy_config
,
model_type
)
self
.
model_type
)
train_configs
=
[
train_config
]
for
idx
in
range
(
1
,
len
(
self
.
_strategy
)):
if
'qat'
in
self
.
_strategy
[
idx
]:
### if compress strategy more than one, the train config in the yaml set for prune
### the train config for quantization is extrapolate from the yaml
tmp_train_config
=
copy
.
deepcopy
(
train_config
.
__dict__
)
### the epoch, train_iter, learning rate of quant is 10% of the prune compress
tmp_train_config
[
'epochs'
]
=
max
(
int
(
train_config
.
epochs
*
0.1
),
1
)
if
train_config
.
train_iter
is
not
None
:
tmp_train_config
[
'train_iter'
]
=
int
(
train_config
.
train_iter
*
0.1
)
if
isinstance
(
train_config
.
learning_rate
,
float
):
tmp_train_config
[
'learning_rate'
]
=
train_config
.
learning_rate
*
0.1
else
:
if
'learning_rate'
in
train_config
.
learning_rate
:
tmp_train_config
[
'learning_rate'
][
'learning_rate'
]
=
train_config
.
learning_rate
[
'learning_rate'
]
*
0.1
else
:
### learning rate decay is PiecewiseDecay
tmp_train_config
[
'learning_rate'
][
'values'
]
=
list
(
map
(
lambda
x
:
x
*
0.1
,
train_config
.
learning_rate
[
'values'
]))
train_cfg
=
TrainConfig
(
**
tmp_train_config
)
elif
'ptq'
in
self
.
_strategy
[
idx
]:
train_cfg
=
None
else
:
tmp_train_config
=
copy
.
deepcopy
(
train_config
.
__dict__
)
train_cfg
=
TrainConfig
(
**
tmp_train_config
)
train_configs
.
append
(
train_cfg
)
return
train_configs
def
_infer_shape
(
self
,
model_dir
,
model_filename
,
params_filename
,
def
_infer_shape
(
self
,
model_dir
,
model_filename
,
params_filename
,
input_shapes
,
save_path
):
input_shapes
,
save_path
):
...
@@ -285,42 +323,50 @@ class AutoCompression:
...
@@ -285,42 +323,50 @@ class AutoCompression:
single_teacher_distill_config
is
not
None
else
\
single_teacher_distill_config
is
not
None
else
\
multi_teacher_distill_config
multi_teacher_distill_config
### case1: quant_config & hpo_config ==> PTQ & HPO
only_distillation
=
True
if
quant_config
is
not
None
and
hpo_config
is
not
None
:
strategy
.
append
(
'ptq_hpo'
)
config
.
append
(
merge_config
(
quant_config
,
hpo_config
))
### case2: quant_config & distill config ==> QAT & Distill
elif
quant_config
is
not
None
and
self
.
_distill_config
is
not
None
:
strategy
.
append
(
'qat_dis'
)
config
.
append
(
merge_config
(
quant_config
,
self
.
_distill_config
))
### case3: prune_config & distill config
### case1: prune_config & distill config
elif
prune_config
is
not
None
and
self
.
_distill_config
is
not
None
:
if
prune_config
is
not
None
and
self
.
_distill_config
is
not
None
:
only_distillation
=
False
strategy
.
append
(
'channel_prune_dis'
)
strategy
.
append
(
'channel_prune_dis'
)
config
.
append
(
merge_config
(
prune_config
,
self
.
_distill_config
))
config
.
append
(
merge_config
(
prune_config
,
self
.
_distill_config
))
### case4: asp_config & distill config
### case2: asp_config & distill config
elif
asp_config
is
not
None
and
self
.
_distill_config
is
not
None
:
if
asp_config
is
not
None
and
self
.
_distill_config
is
not
None
:
only_distillation
=
False
strategy
.
append
(
'asp_prune_dis'
)
strategy
.
append
(
'asp_prune_dis'
)
config
.
append
(
merge_config
(
asp_config
,
self
.
_distill_config
))
config
.
append
(
merge_config
(
asp_config
,
self
.
_distill_config
))
### case5: transformer_prune_config & distill config
### case3: transformer_prune_config & distill config
elif
transformer_prune_config
is
not
None
and
self
.
_distill_config
is
not
None
:
if
transformer_prune_config
is
not
None
and
self
.
_distill_config
is
not
None
:
only_distillation
=
False
strategy
.
append
(
'transformer_prune_dis'
)
strategy
.
append
(
'transformer_prune_dis'
)
config
.
append
(
config
.
append
(
merge_config
(
transformer_prune_config
,
merge_config
(
transformer_prune_config
,
self
.
_distill_config
))
self
.
_distill_config
))
### case6: unstructure_config & distill config
### case4: unstructure_config & distill config
elif
unstructure_prune_config
is
not
None
and
self
.
_distill_config
is
not
None
:
if
unstructure_prune_config
is
not
None
and
self
.
_distill_config
is
not
None
:
only_distillation
=
False
strategy
.
append
(
'unstructure_prune_dis'
)
strategy
.
append
(
'unstructure_prune_dis'
)
config
.
append
(
config
.
append
(
merge_config
(
unstructure_prune_config
,
merge_config
(
unstructure_prune_config
,
self
.
_distill_config
))
self
.
_distill_config
))
### case5: quant_config & hpo_config ==> PTQ & HPO
if
quant_config
is
not
None
and
hpo_config
is
not
None
:
only_distillation
=
False
strategy
.
append
(
'ptq_hpo'
)
config
.
append
(
merge_config
(
quant_config
,
hpo_config
))
### case6: quant_config & distill config ==> QAT & Distill
if
quant_config
is
not
None
and
self
.
_distill_config
is
not
None
:
only_distillation
=
False
strategy
.
append
(
'qat_dis'
)
config
.
append
(
merge_config
(
quant_config
,
self
.
_distill_config
))
### case7: distill_config
### case7: distill_config
elif
self
.
_distill_config
is
not
None
:
if
only_distillation
==
True
and
self
.
_distill_config
is
not
None
:
if
single_teacher_distill_config
is
not
None
:
if
single_teacher_distill_config
is
not
None
:
strategy
.
append
(
'single_teacher_dis'
)
strategy
.
append
(
'single_teacher_dis'
)
config
.
append
(
single_teacher_distill_config
)
config
.
append
(
single_teacher_distill_config
)
...
@@ -328,11 +374,18 @@ class AutoCompression:
...
@@ -328,11 +374,18 @@ class AutoCompression:
strategy
.
append
(
'multi_teacher_dis'
)
strategy
.
append
(
'multi_teacher_dis'
)
config
.
append
(
multi_teacher_distill_config
)
config
.
append
(
multi_teacher_distill_config
)
### case N: todo
### NOTE: keep quantation in the last step
else
:
idx
=
-
1
raise
NotImplementedError
(
if
'qat_dis'
in
strategy
and
strategy
.
index
(
'qat_dis'
)
!=
(
"Not Implemented {} be set at the same time now"
.
format
(
len
(
strategy
)
-
1
):
strategy_c
.
keys
()))
idx
=
strategy
.
index
(
'qat_dis'
)
elif
'ptq_hpo'
in
strategy
and
strategy
.
index
(
'ptq_hpo'
)
!=
(
len
(
strategy
)
-
1
):
idx
=
strategy
.
index
(
'ptq_hpo'
)
if
idx
!=
-
1
:
strategy
=
strategy
[:
idx
]
+
strategy
[
idx
+
1
:]
+
[
strategy
[
idx
]]
config
=
config
[:
idx
]
+
config
[
idx
+
1
:]
+
[
config
[
idx
]]
return
strategy
,
config
return
strategy
,
config
...
@@ -356,7 +409,8 @@ class AutoCompression:
...
@@ -356,7 +409,8 @@ class AutoCompression:
return
strategy
return
strategy
def
_prepare_program
(
self
,
program
,
feed_target_names
,
fetch_targets
,
def
_prepare_program
(
self
,
program
,
feed_target_names
,
fetch_targets
,
patterns
,
default_distill_node_pair
,
strategy
,
config
):
patterns
,
default_distill_node_pair
,
strategy
,
config
,
train_config
):
train_program
=
recover_inference_program
(
program
)
train_program
=
recover_inference_program
(
program
)
startup_program
=
paddle
.
static
.
Program
()
startup_program
=
paddle
.
static
.
Program
()
train_program_info
=
ProgramInfo
(
startup_program
,
train_program
,
train_program_info
=
ProgramInfo
(
startup_program
,
train_program
,
...
@@ -369,11 +423,11 @@ class AutoCompression:
...
@@ -369,11 +423,11 @@ class AutoCompression:
_logger
.
info
(
_logger
.
info
(
"Calculating the iterations per epoch……(It will take some time)"
)
"Calculating the iterations per epoch……(It will take some time)"
)
# NOTE:XXX: This way of calculating the iters needs to be improved.
# NOTE:XXX: This way of calculating the iters needs to be improved.
if
self
.
train_config
.
epochs
:
if
train_config
.
epochs
:
iters_per_epoch
=
len
(
list
(
self
.
train_dataloader
()))
iters_per_epoch
=
len
(
list
(
self
.
train_dataloader
()))
total_iters
=
self
.
train_config
.
epochs
*
iters_per_epoch
total_iters
=
train_config
.
epochs
*
iters_per_epoch
elif
self
.
train_config
.
train_iter
:
elif
train_config
.
train_iter
:
total_iters
=
self
.
train_config
.
train_iter
total_iters
=
train_config
.
train_iter
else
:
else
:
raise
RuntimeError
(
raise
RuntimeError
(
'train_config must has `epochs` or `train_iter` field.'
)
'train_config must has `epochs` or `train_iter` field.'
)
...
@@ -392,8 +446,8 @@ class AutoCompression:
...
@@ -392,8 +446,8 @@ class AutoCompression:
self
.
_exe
,
self
.
_places
,
config_dict
,
train_program_info
,
self
.
_exe
,
self
.
_places
,
config_dict
,
train_program_info
,
strategy
,
patterns
,
self
.
eval_dataloader
)
strategy
,
patterns
,
self
.
eval_dataloader
)
if
self
.
train_config
.
use_fleet
:
if
train_config
.
use_fleet
:
dist_strategy
=
_prepare_fleet_strategy
(
self
.
train_config
)
dist_strategy
=
_prepare_fleet_strategy
(
train_config
)
else
:
else
:
dist_strategy
=
None
dist_strategy
=
None
...
@@ -403,7 +457,7 @@ class AutoCompression:
...
@@ -403,7 +457,7 @@ class AutoCompression:
self
.
_exe
,
self
.
_exe
,
self
.
_places
,
self
.
_places
,
config_dict
,
config_dict
,
self
.
train_config
.
__dict__
,
train_config
.
__dict__
,
train_program_info
,
train_program_info
,
pruner
=
self
.
_pruner
,
pruner
=
self
.
_pruner
,
dist_strategy
=
dist_strategy
,
dist_strategy
=
dist_strategy
,
...
@@ -415,7 +469,7 @@ class AutoCompression:
...
@@ -415,7 +469,7 @@ class AutoCompression:
train_program_info
,
test_program_info
,
self
.
_quant_config
=
build_quant_program
(
train_program_info
,
test_program_info
,
self
.
_quant_config
=
build_quant_program
(
self
.
_exe
,
self
.
_places
,
config_dict
,
train_program_info
,
self
.
_exe
,
self
.
_places
,
config_dict
,
train_program_info
,
test_program_info
)
test_program_info
)
if
self
.
train_config
.
sparse_model
:
if
train_config
.
sparse_model
:
from
..prune.unstructured_pruner
import
UnstructuredPruner
from
..prune.unstructured_pruner
import
UnstructuredPruner
# NOTE: The initialization parameter of this pruner doesn't work, it is only used to call the 'set_static_masks' function
# NOTE: The initialization parameter of this pruner doesn't work, it is only used to call the 'set_static_masks' function
self
.
_pruner
=
UnstructuredPruner
(
self
.
_pruner
=
UnstructuredPruner
(
...
@@ -428,10 +482,10 @@ class AutoCompression:
...
@@ -428,10 +482,10 @@ class AutoCompression:
self
.
_exe
.
run
(
train_program_info
.
startup_program
)
self
.
_exe
.
run
(
train_program_info
.
startup_program
)
if
(
not
self
.
train_config
.
use_fleet
if
(
not
train_config
.
use_fleet
)
and
train_config
.
amp_config
is
not
None
:
)
and
self
.
train_config
.
amp_config
is
not
None
:
if
hasattr
(
if
hasattr
(
self
.
train_config
.
amp_config
,
'use_pure_fp16'
train_config
.
amp_config
,
)
and
self
.
train_config
.
amp_config
.
use_pure_fp16
:
'use_pure_fp16'
)
and
train_config
.
amp_config
.
use_pure_fp16
:
train_program_info
.
optimizer
.
amp_init
(
train_program_info
.
optimizer
.
amp_init
(
self
.
_places
,
scope
=
paddle
.
static
.
global_scope
())
self
.
_places
,
scope
=
paddle
.
static
.
global_scope
())
...
@@ -439,7 +493,7 @@ class AutoCompression:
...
@@ -439,7 +493,7 @@ class AutoCompression:
### prune weight in scope
### prune weight in scope
self
.
_pruner
.
prune_model
(
train_program_info
.
program
)
self
.
_pruner
.
prune_model
(
train_program_info
.
program
)
if
not
self
.
train_config
.
use_fleet
:
if
not
train_config
.
use_fleet
:
train_program_info
=
self
.
_compiled_program
(
train_program_info
,
train_program_info
=
self
.
_compiled_program
(
train_program_info
,
strategy
)
strategy
)
test_program_info
=
self
.
_compiled_program
(
test_program_info
,
test_program_info
=
self
.
_compiled_program
(
test_program_info
,
...
@@ -475,9 +529,10 @@ class AutoCompression:
...
@@ -475,9 +529,10 @@ class AutoCompression:
def
compress
(
self
):
def
compress
(
self
):
self
.
tmp_dir
=
self
.
create_tmp_dir
(
self
.
final_dir
)
self
.
tmp_dir
=
self
.
create_tmp_dir
(
self
.
final_dir
)
for
strategy_idx
,
(
for
strategy_idx
,
(
strategy
,
strategy
,
config
,
train_config
config
)
in
enumerate
(
zip
(
self
.
_strategy
,
self
.
_config
)):
)
in
enumerate
(
zip
(
self
.
_strategy
,
self
.
_config
,
self
.
train_config
)):
self
.
single_strategy_compress
(
strategy
,
config
,
strategy_idx
)
self
.
single_strategy_compress
(
strategy
,
config
,
strategy_idx
,
train_config
)
if
strategy
==
'ptq_hpo'
and
config
.
max_quant_count
==
1
and
platform
.
system
(
if
strategy
==
'ptq_hpo'
and
config
.
max_quant_count
==
1
and
platform
.
system
(
).
lower
()
==
'linux'
:
).
lower
()
==
'linux'
:
...
@@ -488,7 +543,8 @@ class AutoCompression:
...
@@ -488,7 +543,8 @@ class AutoCompression:
quant_strategy
,
quant_config
=
self
.
_prepare_strategy
(
quant_strategy
,
quant_config
=
self
.
_prepare_strategy
(
final_quant_config
)
final_quant_config
)
self
.
single_strategy_compress
(
quant_strategy
[
0
],
self
.
single_strategy_compress
(
quant_strategy
[
0
],
quant_config
[
0
],
strategy_idx
)
quant_config
[
0
],
strategy_idx
,
train_config
)
tmp_model_path
=
os
.
path
.
join
(
tmp_model_path
=
os
.
path
.
join
(
self
.
tmp_dir
,
'strategy_{}'
.
format
(
str
(
strategy_idx
+
1
)))
self
.
tmp_dir
,
'strategy_{}'
.
format
(
str
(
strategy_idx
+
1
)))
final_model_path
=
os
.
path
.
join
(
self
.
final_dir
)
final_model_path
=
os
.
path
.
join
(
self
.
final_dir
)
...
@@ -507,7 +563,8 @@ class AutoCompression:
...
@@ -507,7 +563,8 @@ class AutoCompression:
format
(
final_model_path
))
format
(
final_model_path
))
os
.
_exit
(
0
)
os
.
_exit
(
0
)
def
single_strategy_compress
(
self
,
strategy
,
config
,
strategy_idx
):
def
single_strategy_compress
(
self
,
strategy
,
config
,
strategy_idx
,
train_config
):
# start compress, including train/eval model
# start compress, including train/eval model
# TODO: add the emd loss of evaluation model.
# TODO: add the emd loss of evaluation model.
if
strategy
==
'quant_post'
:
if
strategy
==
'quant_post'
:
...
@@ -581,19 +638,19 @@ class AutoCompression:
...
@@ -581,19 +638,19 @@ class AutoCompression:
### used to check whether the dataloader is right
### used to check whether the dataloader is right
self
.
metric_before_compressed
=
None
self
.
metric_before_compressed
=
None
if
self
.
eval_function
is
not
None
and
self
.
train_config
.
origin_metric
is
not
None
:
if
self
.
eval_function
is
not
None
and
train_config
.
origin_metric
is
not
None
:
_logger
.
info
(
"start to test metric before compress"
)
_logger
.
info
(
"start to test metric before compress"
)
metric
=
self
.
eval_function
(
self
.
_exe
,
inference_program
,
metric
=
self
.
eval_function
(
self
.
_exe
,
inference_program
,
feed_target_names
,
fetch_targets
)
feed_target_names
,
fetch_targets
)
_logger
.
info
(
"metric of compressed model is: {}"
.
format
(
metric
))
_logger
.
info
(
"metric of compressed model is: {}"
.
format
(
metric
))
buf
=
0.05
buf
=
0.05
if
metric
<
(
float
(
self
.
train_config
.
origin_metric
)
-
buf
)
or
\
if
metric
<
(
float
(
train_config
.
origin_metric
)
-
buf
)
or
\
metric
>
(
float
(
self
.
train_config
.
origin_metric
)
+
buf
):
metric
>
(
float
(
train_config
.
origin_metric
)
+
buf
):
raise
RuntimeError
(
"target metric of pretrained model is {},
\
raise
RuntimeError
(
"target metric of pretrained model is {},
\
but now is {}, Please check the format of evaluation dataset
\
but now is {}, Please check the format of evaluation dataset
\
or check the origin_metric in train_config"
or check the origin_metric in train_config"
.
format
(
\
.
format
(
\
self
.
train_config
.
origin_metric
,
metric
))
train_config
.
origin_metric
,
metric
))
self
.
metric_before_compressed
=
metric
self
.
metric_before_compressed
=
metric
patterns
,
default_distill_node_pair
,
_
=
get_patterns
(
patterns
,
default_distill_node_pair
,
_
=
get_patterns
(
...
@@ -601,15 +658,16 @@ class AutoCompression:
...
@@ -601,15 +658,16 @@ class AutoCompression:
train_program_info
,
test_program_info
=
self
.
_prepare_program
(
train_program_info
,
test_program_info
=
self
.
_prepare_program
(
inference_program
,
feed_target_names
,
fetch_targets
,
patterns
,
inference_program
,
feed_target_names
,
fetch_targets
,
patterns
,
default_distill_node_pair
,
strategy
,
config
)
default_distill_node_pair
,
strategy
,
config
,
train_config
)
if
'unstructure'
in
self
.
_strategy
:
if
'unstructure'
in
self
.
_strategy
:
test_program_info
.
program
.
_program
=
remove_unused_var_nodes
(
test_program_info
.
program
.
_program
=
remove_unused_var_nodes
(
test_program_info
.
program
.
_program
)
test_program_info
.
program
.
_program
)
test_program_info
=
self
.
_start_train
(
train_program_info
,
test_program_info
=
self
.
_start_train
(
test_program_info
,
strategy
)
train_program_info
,
test_program_info
,
strategy
,
train_config
)
self
.
_save_model
(
test_program_info
,
strategy
,
strategy_idx
)
self
.
_save_model
(
test_program_info
,
strategy
,
strategy_idx
)
def
_start_train
(
self
,
train_program_info
,
test_program_info
,
strategy
):
def
_start_train
(
self
,
train_program_info
,
test_program_info
,
strategy
,
train_config
):
best_metric
=
-
1.0
best_metric
=
-
1.0
total_epochs
=
self
.
train_config
.
epochs
if
self
.
train_config
.
epochs
else
100
total_epochs
=
self
.
train_config
.
epochs
if
self
.
train_config
.
epochs
else
100
total_train_iter
=
0
total_train_iter
=
0
...
@@ -623,10 +681,10 @@ class AutoCompression:
...
@@ -623,10 +681,10 @@ class AutoCompression:
if
'unstructure'
in
strategy
:
if
'unstructure'
in
strategy
:
self
.
_pruner
.
step
()
self
.
_pruner
.
step
()
if
self
.
train_config
.
logging_iter
is
None
:
if
train_config
.
logging_iter
is
None
:
logging_iter
=
10
logging_iter
=
10
else
:
else
:
logging_iter
=
self
.
train_config
.
logging_iter
logging_iter
=
train_config
.
logging_iter
if
batch_id
%
int
(
logging_iter
)
==
0
:
if
batch_id
%
int
(
logging_iter
)
==
0
:
_logger
.
info
(
_logger
.
info
(
"Total iter: {}, epoch: {}, batch: {}, loss: {}"
.
format
(
"Total iter: {}, epoch: {}, batch: {}, loss: {}"
.
format
(
...
@@ -661,8 +719,8 @@ class AutoCompression:
...
@@ -661,8 +719,8 @@ class AutoCompression:
self
.
metric_before_compressed
)
self
.
metric_before_compressed
)
)
/
self
.
metric_before_compressed
<=
0.005
:
)
/
self
.
metric_before_compressed
<=
0.005
:
break
break
if
self
.
train_config
.
target_metric
is
not
None
:
if
train_config
.
target_metric
is
not
None
:
if
metric
>
float
(
self
.
train_config
.
target_metric
):
if
metric
>
float
(
train_config
.
target_metric
):
break
break
else
:
else
:
...
@@ -672,7 +730,7 @@ class AutoCompression:
...
@@ -672,7 +730,7 @@ class AutoCompression:
if
self
.
train_config
.
train_iter
and
total_train_iter
>=
self
.
train_config
.
train_iter
:
if
self
.
train_config
.
train_iter
and
total_train_iter
>=
self
.
train_config
.
train_iter
:
break
break
if
'unstructure'
in
self
.
_strategy
or
self
.
train_config
.
sparse_model
:
if
'unstructure'
in
self
.
_strategy
or
train_config
.
sparse_model
:
self
.
_pruner
.
update_params
()
self
.
_pruner
.
update_params
()
return
test_program_info
return
test_program_info
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录