Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
9a17d02c
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9a17d02c
编写于
10月 12, 2019
作者:
W
whs
提交者:
GitHub
10月 12, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[PaddleSlim] Add pruning demo for yolov3 (#3414)
上级
fe350028
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
480 addition
and
0 deletion
+480
-0
slim/prune/README.md
slim/prune/README.md
+199
-0
slim/prune/compress.py
slim/prune/compress.py
+258
-0
slim/prune/images/MobileNetV1-YoloV3.pdf
slim/prune/images/MobileNetV1-YoloV3.pdf
+0
-0
slim/prune/yolov3_mobilenet_v1_slim.yaml
slim/prune/yolov3_mobilenet_v1_slim.yaml
+23
-0
未找到文件。
slim/prune/README.md
0 → 100644
浏览文件 @
9a17d02c
>运行该示例前请安装Paddle1.6或更高版本
# 检测模型卷积通道剪裁示例
## 概述
该示例使用PaddleSlim提供的
[
卷积通道剪裁压缩策略
](
https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/tutorial.md#2-%E5%8D%B7%E7%A7%AF%E6%A0%B8%E5%89%AA%E8%A3%81%E5%8E%9F%E7%90%86
)
对检测库中的模型进行压缩。
在阅读该示例前,建议您先了解以下内容:
-
<a
href=
"../..README_cn.md"
>
检测库的常规训练方法
</a>
-
[
PaddleSlim使用文档
](
https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md
)
## 配置文件说明
关于配置文件如何编写您可以参考:
-
[
PaddleSlim配置文件编写说明
](
https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#122-%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6%E7%9A%84%E4%BD%BF%E7%94%A8
)
-
[
裁剪策略配置文件编写说明
](
https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#22-%E6%A8%A1%E5%9E%8B%E9%80%9A%E9%81%93%E5%89%AA%E8%A3%81
)
其中,配置文件中的
`pruned_params`
需要根据当前模型的网络结构特点设置,它用来指定要裁剪的parameters.
这里以MobileNetV1-YoloV3模型为例,其卷积可以三种:主干网络中的普通卷积,主干网络中的
`depthwise convolution`
和
`yolo block`
里的普通卷积。PaddleSlim暂时无法对
`depthwise convolution`
直接进行剪裁, 因为
`depthwise convolution`
的
`channel`
的变化会同时影响到前后的卷积层。我们这里只对主干网络中的普通卷积和
`yolo block`
里的普通卷积做裁剪。
通过以下方式可视化模型结构:
```
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
graph = IrGraph(core.Graph(train_prog.desc), for_test=True)
marked_nodes = set()
for op in graph.all_op_nodes():
print(op.name())
if op.name().find('conv') > -1:
marked_nodes.add(op)
graph.draw('.', 'forward', marked_nodes)
```
该示例中MobileNetV1-YoloV3模型结构的可视化结果:
<a
href=
"./images/MobileNetV1-YoloV3.pdf"
>
MobileNetV1-YoloV3.pdf
</a>
同时通过以下命令观察目标卷积层的参数(parameters)的名称和shape:
```
for param in fluid.default_main_program().global_block().all_parameters():
if 'weights' in param.name:
print param.name, param.shape
```
从可视化结果,我们可以排除后续会做concat的卷积层,最终得到如下要裁剪的参数名称:
```
conv2_1_sep_weights
conv2_2_sep_weights
conv3_1_sep_weights
conv4_1_sep_weights
conv5_1_sep_weights
conv5_2_sep_weights
conv5_3_sep_weights
conv5_4_sep_weights
conv5_5_sep_weights
conv5_6_sep_weights
yolo_block.0.0.0.conv.weights
yolo_block.0.0.1.conv.weights
yolo_block.0.1.0.conv.weights
yolo_block.0.1.1.conv.weights
yolo_block.1.0.0.conv.weights
yolo_block.1.0.1.conv.weights
yolo_block.1.1.0.conv.weights
yolo_block.1.1.1.conv.weights
yolo_block.1.2.conv.weights
yolo_block.2.0.0.conv.weights
yolo_block.2.0.1.conv.weights
yolo_block.2.1.1.conv.weights
yolo_block.2.2.conv.weights
yolo_block.2.tip.conv.weights
```
```
(conv2_1_sep_weights)|(conv2_2_sep_weights)|(conv3_1_sep_weights)|(conv4_1_sep_weights)|(conv5_1_sep_weights)|(conv5_2_sep_weights)|(conv5_3_sep_weights)|(conv5_4_sep_weights)|(conv5_5_sep_weights)|(conv5_6_sep_weights)|(yolo_block.0.0.0.conv.weights)|(yolo_block.0.0.1.conv.weights)|(yolo_block.0.1.0.conv.weights)|(yolo_block.0.1.1.conv.weights)|(yolo_block.1.0.0.conv.weights)|(yolo_block.1.0.1.conv.weights)|(yolo_block.1.1.0.conv.weights)|(yolo_block.1.1.1.conv.weights)|(yolo_block.1.2.conv.weights)|(yolo_block.2.0.0.conv.weights)|(yolo_block.2.0.1.conv.weights)|(yolo_block.2.1.1.conv.weights)|(yolo_block.2.2.conv.weights)|(yolo_block.2.tip.conv.weights)
```
综上,我们将MobileNetV2配置文件中的
`pruned_params`
设置为以下正则表达式:
```
(conv2_1_sep_weights)|(conv2_2_sep_weights)|(conv3_1_sep_weights)|(conv4_1_sep_weights)|(conv5_1_sep_weights)|(conv5_2_sep_weights)|(conv5_3_sep_weights)|(conv5_4_sep_weights)|(conv5_5_sep_weights)|(conv5_6_sep_weights)|(yolo_block.0.0.0.conv.weights)|(yolo_block.0.0.1.conv.weights)|(yolo_block.0.1.0.conv.weights)|(yolo_block.0.1.1.conv.weights)|(yolo_block.1.0.0.conv.weights)|(yolo_block.1.0.1.conv.weights)|(yolo_block.1.1.0.conv.weights)|(yolo_block.1.1.1.conv.weights)|(yolo_block.1.2.conv.weights)|(yolo_block.2.0.0.conv.weights)|(yolo_block.2.0.1.conv.weights)|(yolo_block.2.1.1.conv.weights)|(yolo_block.2.2.conv.weights)|(yolo_block.2.tip.conv.weights)
```
我们可以用上述操作观察其它检测模型的参数名称规律,然后设置合适的正则表达式来剪裁合适的参数。
## 训练
根据
<a
href=
"../../tools/train.py"
>
PaddleDetection/tools/train.py
</a>
编写压缩脚本compress.py。
在该脚本中定义了Compressor对象,用于执行压缩任务。
### 执行示例
step1: 设置gpu卡
```
export CUDA_VISIBLE_DEVICES=0
```
step2: 开始训练
使用PaddleDetection提供的配置文件在用8卡进行训练:
```
python compress.py \
-s yolov3_mobilenet_v1_slim.yaml \
-c ../../configs/yolov3_mobilenet_v1_voc.yml \
-o max_iters=258 \
-d "../../dataset/voc"
```
>通过命令行覆盖设置max_iters选项,因为PaddleDetection中训练是以`batch`为单位迭代的,并没有涉及`epoch`的概念,但是PaddleSlim需要知道当前训练进行到第几个`epoch`, 所以需要将`max_iters`设置为一个`epoch`内的`batch`的数量。
如果要调整训练卡数,需要调整配置文件
`yolov3_mobilenet_v1_voc.yml`
中的以下参数:
-
**max_iters:**
一个
`epoch`
中batch的数量,需要设置为
`total_num / batch_size`
, 其中
`total_num`
为训练样本总数量,
`batch_size`
为多卡上总的batch size.
-
**YoloTrainFeed.batch_size:**
单张卡上的batch size, 受限于显存大小。
-
**LeaningRate.base_lr:**
根据多卡的总
`batch_size`
调整
`base_lr`
,两者大小正相关,可以简单的按比例进行调整。
-
**LearningRate.schedulers.PiecewiseDecay.milestones:**
请根据batch size的变化对其调整。
-
**LearningRate.schedulers.PiecewiseDecay.LinearWarmup.steps:**
请根据batch size的变化对其进行调整。
以下为4卡训练示例,通过命令行覆盖
`yolov3_mobilenet_v1_voc.yml`
中的参数:
```
python compress.py \
-s yolov3_mobilenet_v1_slim.yaml \
-c ../../configs/yolov3_mobilenet_v1_voc.yml \
-o max_iters=258 \
-o YoloTrainFeed.batch_size = 16 \
-d "../../dataset/voc"
```
以下为2卡训练示例,受显存所制,单卡
`batch_size`
不变,总
`batch_size`
减小,
`base_lr`
减小,一个epoch内batch数量增加,同时需要调整学习率相关参数,如下:
```
python compress.py \
-s yolov3_mobilenet_v1_slim.yaml \
-c ../../configs/yolov3_mobilenet_v1_voc.yml \
-o max_iters=516 \
-o LeaningRate.base_lr=0.005 \ # 0.001 /2
-o YoloTrainFeed.batch_size = 16 \
-o LearningRate.schedulers='[!PiecewiseDecay {gamma: 0.1, milestones: [110000, 124000]}, !LinearWarmup {start_factor: 0., steps: 2000}]' \
-d "../../dataset/voc"
```
通过
`python compress.py --help`
查看可配置参数。
通过
`python ../../tools/configure.py ${option_name} help`
查看如何通过命令行覆盖配置文件
`yolov3_mobilenet_v1_voc.yml`
中的参数。
### 保存断点(checkpoint)
如果在配置文件中设置了
`checkpoint_path`
, 则在压缩任务执行过程中会自动保存断点,当任务异常中断时,
重启任务会自动从
`checkpoint_path`
路径下按数字顺序加载最新的checkpoint文件。如果不想让重启的任务从断点恢复,
需要修改配置文件中的
`checkpoint_path`
,或者将
`checkpoint_path`
路径下文件清空。
>注意:配置文件中的信息不会保存在断点中,重启前对配置文件的修改将会生效。
## 评估
如果在配置文件中设置了
`checkpoint_path`
,则每个epoch会保存一个压缩后的用于评估的模型,
该模型会保存在
`${checkpoint_path}/${epoch_id}/eval_model/`
路径下,包含
`__model__`
和
`__params__`
两个文件。
其中,
`__model__`
用于保存模型结构信息,
`__params__`
用于保存参数(parameters)信息。
如果不需要保存评估模型,可以在定义Compressor对象时,将
`save_eval_model`
选项设置为False(默认为True)。
## 预测
如果在配置文件中设置了
`checkpoint_path`
,并且在定义Compressor对象时指定了
`prune_infer_model`
选项,则每个epoch都会
保存一个
`inference model`
。该模型是通过删除eval_program中多余的operators而得到的。
该模型会保存在
`${checkpoint_path}/${epoch_id}/eval_model/`
路径下,包含
`__model__.infer`
和
`__params__`
两个文件。
其中,
`__model__.infer`
用于保存模型结构信息,
`__params__`
用于保存参数(parameters)信息。
更多关于
`prune_infer_model`
选项的介绍,请参考:
[
Compressor介绍
](
https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#121-%E5%A6%82%E4%BD%95%E6%94%B9%E5%86%99%E6%99%AE%E9%80%9A%E8%AE%AD%E7%BB%83%E8%84%9A%E6%9C%AC
)
### python预测
在脚本
<a
href=
"../infer.py"
>
PaddleDetection/tools/infer.py
</a>
中展示了如何使用fluid python API加载使用预测模型进行预测。
### PaddleLite
该示例中产出的预测(inference)模型可以直接用PaddleLite进行加载使用。
关于PaddleLite如何使用,请参考:
[
PaddleLite使用文档
](
https://github.com/PaddlePaddle/Paddle-Lite/wiki#%E4%BD%BF%E7%94%A8
)
## 示例结果
### MobileNetV1-YOLO-V3
| FLOPS |top1_acc/top5_acc| model_size |Paddle Fluid inference time(ms)| Paddle Lite inference time(ms)|
|---|---|---|---|---|
|baseline|- |- |- |-|
|-10%|- |- |- |-|
|-30%|- |- |- |-|
|-50%|- |- |- |-|
## FAQ
slim/prune/compress.py
0 → 100644
浏览文件 @
9a17d02c
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
time
import
multiprocessing
import
numpy
as
np
import
sys
sys
.
path
.
append
(
"../../"
)
from
paddle.fluid.contrib.slim
import
Compressor
def
set_paddle_flags
(
**
kwargs
):
for
key
,
value
in
kwargs
.
items
():
if
os
.
environ
.
get
(
key
,
None
)
is
None
:
os
.
environ
[
key
]
=
str
(
value
)
# NOTE(paddle-dev): All of these flags should be set before
# `import paddle`. Otherwise, it would not take any effect.
set_paddle_flags
(
FLAGS_eager_delete_tensor_gb
=
0
,
# enable GC to save memory
)
from
paddle
import
fluid
from
ppdet.core.workspace
import
load_config
,
merge_config
,
create
from
ppdet.data.data_feed
import
create_reader
from
ppdet.utils.eval_utils
import
parse_fetches
,
eval_results
from
ppdet.utils.cli
import
ArgsParser
from
ppdet.utils.check
import
check_gpu
import
ppdet.utils.checkpoint
as
checkpoint
from
ppdet.modeling.model_input
import
create_feed
import
logging
FORMAT
=
'%(asctime)s-%(levelname)s: %(message)s'
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
FORMAT
)
logger
=
logging
.
getLogger
(
__name__
)
def
eval_run
(
exe
,
compile_program
,
reader
,
keys
,
values
,
cls
,
test_feed
):
"""
Run evaluation program, return program outputs.
"""
iter_id
=
0
results
=
[]
if
len
(
cls
)
!=
0
:
values
=
[]
for
i
in
range
(
len
(
cls
)):
_
,
accum_map
=
cls
[
i
].
get_map_var
()
cls
[
i
].
reset
(
exe
)
values
.
append
(
accum_map
)
images_num
=
0
start_time
=
time
.
time
()
has_bbox
=
'bbox'
in
keys
for
data
in
reader
():
data
=
test_feed
.
feed
(
data
)
feed_data
=
{
'image'
:
data
[
'image'
],
'im_size'
:
data
[
'im_size'
]}
outs
=
exe
.
run
(
compile_program
,
feed
=
feed_data
,
fetch_list
=
[
values
[
0
]],
return_numpy
=
False
)
outs
.
append
(
data
[
'gt_box'
])
outs
.
append
(
data
[
'gt_label'
])
outs
.
append
(
data
[
'is_difficult'
])
res
=
{
k
:
(
np
.
array
(
v
),
v
.
recursive_sequence_lengths
())
for
k
,
v
in
zip
(
keys
,
outs
)
}
results
.
append
(
res
)
if
iter_id
%
100
==
0
:
logger
.
info
(
'Test iter {}'
.
format
(
iter_id
))
iter_id
+=
1
images_num
+=
len
(
res
[
'bbox'
][
1
][
0
])
if
has_bbox
else
1
logger
.
info
(
'Test finish iter {}'
.
format
(
iter_id
))
end_time
=
time
.
time
()
fps
=
images_num
/
(
end_time
-
start_time
)
if
has_bbox
:
logger
.
info
(
'Total number of images: {}, inference time: {} fps.'
.
format
(
images_num
,
fps
))
else
:
logger
.
info
(
'Total iteration: {}, inference time: {} batch/s.'
.
format
(
images_num
,
fps
))
return
results
def
main
():
cfg
=
load_config
(
FLAGS
.
config
)
if
'architecture'
in
cfg
:
main_arch
=
cfg
.
architecture
else
:
raise
ValueError
(
"'architecture' not specified in config file."
)
merge_config
(
FLAGS
.
opt
)
if
'log_iter'
not
in
cfg
:
cfg
.
log_iter
=
20
# check if set use_gpu=True in paddlepaddle cpu version
check_gpu
(
cfg
.
use_gpu
)
if
cfg
.
use_gpu
:
devices_num
=
fluid
.
core
.
get_cuda_device_count
()
else
:
devices_num
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
multiprocessing
.
cpu_count
()))
if
'train_feed'
not
in
cfg
:
train_feed
=
create
(
main_arch
+
'TrainFeed'
)
else
:
train_feed
=
create
(
cfg
.
train_feed
)
if
'eval_feed'
not
in
cfg
:
eval_feed
=
create
(
main_arch
+
'EvalFeed'
)
else
:
eval_feed
=
create
(
cfg
.
eval_feed
)
place
=
fluid
.
CUDAPlace
(
0
)
if
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
lr_builder
=
create
(
'LearningRate'
)
optim_builder
=
create
(
'OptimizerBuilder'
)
# build program
startup_prog
=
fluid
.
Program
()
train_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
train_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
model
=
create
(
main_arch
)
train_loader
,
feed_vars
=
create_feed
(
train_feed
,
iterable
=
True
)
train_fetches
=
model
.
train
(
feed_vars
)
loss
=
train_fetches
[
'loss'
]
lr
=
lr_builder
()
optimizer
=
optim_builder
(
lr
)
optimizer
.
minimize
(
loss
)
train_reader
=
create_reader
(
train_feed
,
cfg
.
max_iters
*
devices_num
,
FLAGS
.
dataset_dir
)
train_loader
.
set_sample_list_generator
(
train_reader
,
place
)
# parse train fetches
train_keys
,
train_values
,
_
=
parse_fetches
(
train_fetches
)
train_keys
.
append
(
"lr"
)
train_values
.
append
(
lr
.
name
)
train_fetch_list
=
[]
for
k
,
v
in
zip
(
train_keys
,
train_values
):
train_fetch_list
.
append
((
k
,
v
))
eval_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
eval_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
model
=
create
(
main_arch
)
_
,
test_feed_vars
=
create_feed
(
eval_feed
,
iterable
=
True
)
fetches
=
model
.
eval
(
test_feed_vars
)
eval_prog
=
eval_prog
.
clone
(
True
)
eval_reader
=
create_reader
(
eval_feed
,
args_path
=
FLAGS
.
dataset_dir
)
#eval_pyreader.decorate_sample_list_generator(eval_reader, place)
test_data_feed
=
fluid
.
DataFeeder
(
test_feed_vars
.
values
(),
place
)
# parse eval fetches
extra_keys
=
[]
if
cfg
.
metric
==
'COCO'
:
extra_keys
=
[
'im_info'
,
'im_id'
,
'im_shape'
]
if
cfg
.
metric
==
'VOC'
:
extra_keys
=
[
'gt_box'
,
'gt_label'
,
'is_difficult'
]
eval_keys
,
eval_values
,
eval_cls
=
parse_fetches
(
fetches
,
eval_prog
,
extra_keys
)
eval_fetch_list
=
[]
for
k
,
v
in
zip
(
eval_keys
,
eval_values
):
eval_fetch_list
.
append
((
k
,
v
))
exe
.
run
(
startup_prog
)
checkpoint
.
load_params
(
exe
,
train_prog
,
cfg
.
pretrain_weights
)
best_box_ap_list
=
[]
def
eval_func
(
program
,
scope
):
#place = fluid.CPUPlace()
#exe = fluid.Executor(place)
results
=
eval_run
(
exe
,
program
,
eval_reader
,
eval_keys
,
eval_values
,
eval_cls
,
test_data_feed
)
resolution
=
None
if
'mask'
in
results
[
0
]:
resolution
=
model
.
mask_head
.
resolution
box_ap_stats
=
eval_results
(
results
,
eval_feed
,
cfg
.
metric
,
cfg
.
num_classes
,
resolution
,
False
,
FLAGS
.
output_eval
)
if
len
(
best_box_ap_list
)
==
0
:
best_box_ap_list
.
append
(
box_ap_stats
[
0
])
elif
box_ap_stats
[
0
]
>
best_box_ap_list
[
0
]:
best_box_ap_list
[
0
]
=
box_ap_stats
[
0
]
checkpoint
.
save
(
exe
,
train_prog
,
os
.
path
.
join
(
save_dir
,
"best_model"
))
logger
.
info
(
"Best test box ap: {}"
.
format
(
best_box_ap_list
[
0
]))
return
best_box_ap_list
[
0
]
test_feed
=
[(
'image'
,
test_feed_vars
[
'image'
].
name
),
(
'im_size'
,
test_feed_vars
[
'im_size'
].
name
)]
com
=
Compressor
(
place
,
fluid
.
global_scope
(),
train_prog
,
train_reader
=
train_reader
,
train_feed_list
=
[(
key
,
value
.
name
)
for
key
,
value
in
feed_vars
.
items
()],
train_fetch_list
=
train_fetch_list
,
eval_program
=
eval_prog
,
eval_reader
=
eval_reader
,
eval_feed_list
=
test_feed
,
eval_func
=
{
'map'
:
eval_func
},
eval_fetch_list
=
[
eval_fetch_list
[
0
]],
save_eval_model
=
True
,
prune_infer_model
=
[[
"image"
,
"im_size"
],[
"multiclass_nms_0.tmp_0"
]],
train_optimizer
=
None
)
com
.
config
(
FLAGS
.
slim_file
)
com
.
run
()
if
__name__
==
'__main__'
:
parser
=
ArgsParser
()
parser
.
add_argument
(
"-s"
,
"--slim_file"
,
default
=
None
,
type
=
str
,
help
=
"Config file of PaddleSlim."
)
parser
.
add_argument
(
"--output_eval"
,
default
=
None
,
type
=
str
,
help
=
"Evaluation directory, default is current directory."
)
parser
.
add_argument
(
"-d"
,
"--dataset_dir"
,
default
=
None
,
type
=
str
,
help
=
"Dataset path, same as DataFeed.dataset.dataset_dir"
)
FLAGS
=
parser
.
parse_args
()
main
()
slim/prune/images/MobileNetV1-YoloV3.pdf
0 → 100644
浏览文件 @
9a17d02c
文件已添加
slim/prune/yolov3_mobilenet_v1_slim.yaml
0 → 100644
浏览文件 @
9a17d02c
version
:
1.0
pruners
:
pruner_1
:
class
:
'
StructurePruner'
pruning_axis
:
'
*'
:
0
criterions
:
'
*'
:
'
l1_norm'
strategies
:
uniform_pruning_strategy
:
class
:
'
UniformPruneStrategy'
pruner
:
'
pruner_1'
start_epoch
:
0
target_ratio
:
0.5
pruned_params
:
'
(conv2_1_sep_weights)|(conv2_2_sep_weights)|(conv3_1_sep_weights)|(conv4_1_sep_weights)|(conv5_1_sep_weights)|(conv5_2_sep_weights)|(conv5_3_sep_weights)|(conv5_4_sep_weights)|(conv5_5_sep_weights)|(conv5_6_sep_weights)|(yolo_block.0.0.0.conv.weights)|(yolo_block.0.0.1.conv.weights)|(yolo_block.0.1.0.conv.weights)|(yolo_block.0.1.1.conv.weights)|(yolo_block.1.0.0.conv.weights)|(yolo_block.1.0.1.conv.weights)|(yolo_block.1.1.0.conv.weights)|(yolo_block.1.1.1.conv.weights)|(yolo_block.1.2.conv.weights)|(yolo_block.2.0.0.conv.weights)|(yolo_block.2.0.1.conv.weights)|(yolo_block.2.1.1.conv.weights)|(yolo_block.2.2.conv.weights)|(yolo_block.2.tip.conv.weights)'
metric_name
:
'
acc_top1'
compressor
:
epoch
:
271
eval_epoch
:
10
#init_model: './checkpoints/0' # Please enable this option for loading checkpoint.
checkpoint_path
:
'
./checkpoints/'
strategies
:
-
uniform_pruning_strategy
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录