Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
bb0f8fbb
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 2 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bb0f8fbb
编写于
2月 20, 2020
作者:
W
wanghaoshuang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/PaddleSlim
into base
上级
ae093d57
2cd6b44d
变更
18
展开全部
隐藏空白更改
内联
并排
Showing
18 changed file
with
402 addition
and
178 deletion
+402
-178
demo/nas/block_sa_nas_mobilenetv2.py
demo/nas/block_sa_nas_mobilenetv2.py
+10
-16
demo/nas/sa_nas_mobilenetv2.py
demo/nas/sa_nas_mobilenetv2.py
+11
-17
demo/pantheon/run_student.py
demo/pantheon/run_student.py
+2
-2
demo/prune/README.md
demo/prune/README.md
+19
-6
demo/prune/eval.py
demo/prune/eval.py
+1
-1
demo/prune/train.py
demo/prune/train.py
+7
-1
docs/en/model_zoo_en.md
docs/en/model_zoo_en.md
+185
-101
docs/zh_cn/api_cn/nas_api.rst
docs/zh_cn/api_cn/nas_api.rst
+3
-1
docs/zh_cn/api_cn/pantheon_api.md
docs/zh_cn/api_cn/pantheon_api.md
+3
-2
docs/zh_cn/api_cn/prune_api.rst
docs/zh_cn/api_cn/prune_api.rst
+1
-1
docs/zh_cn/model_zoo.md
docs/zh_cn/model_zoo.md
+24
-4
paddleslim/__init__.py
paddleslim/__init__.py
+2
-1
paddleslim/nas/sa_nas.py
paddleslim/nas/sa_nas.py
+4
-1
paddleslim/pantheon/README.md
paddleslim/pantheon/README.md
+1
-1
paddleslim/pantheon/student.py
paddleslim/pantheon/student.py
+5
-5
paddleslim/pantheon/teacher.py
paddleslim/pantheon/teacher.py
+19
-11
paddleslim/prune/pruner.py
paddleslim/prune/pruner.py
+23
-7
tests/test_slim_prune.py
tests/test_slim_prune.py
+82
-0
未找到文件。
demo/nas/block_sa_nas_mobilenetv2.py
浏览文件 @
bb0f8fbb
...
@@ -16,13 +16,6 @@ import imagenet_reader
...
@@ -16,13 +16,6 @@ import imagenet_reader
_logger
=
get_logger
(
__name__
,
level
=
logging
.
INFO
)
_logger
=
get_logger
(
__name__
,
level
=
logging
.
INFO
)
reduce_rate
=
0.85
init_temperature
=
10.24
max_flops
=
321208544
server_address
=
""
port
=
8979
retain_epoch
=
5
def
create_data_loader
(
image_shape
):
def
create_data_loader
(
image_shape
):
data_shape
=
[
None
]
+
image_shape
data_shape
=
[
None
]
+
image_shape
...
@@ -71,17 +64,13 @@ def search_mobilenetv2_block(config, args, image_size):
...
@@ -71,17 +64,13 @@ def search_mobilenetv2_block(config, args, image_size):
if
args
.
is_server
:
if
args
.
is_server
:
sa_nas
=
SANAS
(
sa_nas
=
SANAS
(
config
,
config
,
server_addr
=
(
""
,
port
),
server_addr
=
(
args
.
server_address
,
args
.
port
),
init_temperature
=
init_temperature
,
reduce_rate
=
reduce_rate
,
search_steps
=
args
.
search_steps
,
search_steps
=
args
.
search_steps
,
is_server
=
True
)
is_server
=
True
)
else
:
else
:
sa_nas
=
SANAS
(
sa_nas
=
SANAS
(
config
,
config
,
server_addr
=
(
server_address
,
port
),
server_addr
=
(
args
.
server_address
,
args
.
port
),
init_temperature
=
init_temperature
,
reduce_rate
=
reduce_rate
,
search_steps
=
args
.
search_steps
,
search_steps
=
args
.
search_steps
,
is_server
=
False
)
is_server
=
False
)
...
@@ -140,7 +129,7 @@ def search_mobilenetv2_block(config, args, image_size):
...
@@ -140,7 +129,7 @@ def search_mobilenetv2_block(config, args, image_size):
current_flops
=
flops
(
train_program
)
current_flops
=
flops
(
train_program
)
print
(
'step: {}, current_flops: {}'
.
format
(
step
,
current_flops
))
print
(
'step: {}, current_flops: {}'
.
format
(
step
,
current_flops
))
if
current_flops
>
max_flops
:
if
current_flops
>
int
(
321208544
)
:
continue
continue
place
=
fluid
.
CUDAPlace
(
0
)
if
args
.
use_gpu
else
fluid
.
CPUPlace
()
place
=
fluid
.
CUDAPlace
(
0
)
if
args
.
use_gpu
else
fluid
.
CPUPlace
()
...
@@ -178,7 +167,7 @@ def search_mobilenetv2_block(config, args, image_size):
...
@@ -178,7 +167,7 @@ def search_mobilenetv2_block(config, args, image_size):
train_compiled_program
=
fluid
.
CompiledProgram
(
train_compiled_program
=
fluid
.
CompiledProgram
(
train_program
).
with_data_parallel
(
train_program
).
with_data_parallel
(
loss_name
=
avg_cost
.
name
,
build_strategy
=
build_strategy
)
loss_name
=
avg_cost
.
name
,
build_strategy
=
build_strategy
)
for
epoch_id
in
range
(
retain_epoch
):
for
epoch_id
in
range
(
args
.
retain_epoch
):
for
batch_id
,
data
in
enumerate
(
train_loader
()):
for
batch_id
,
data
in
enumerate
(
train_loader
()):
fetches
=
[
avg_cost
.
name
]
fetches
=
[
avg_cost
.
name
]
s_time
=
time
.
time
()
s_time
=
time
.
time
()
...
@@ -243,6 +232,11 @@ if __name__ == '__main__':
...
@@ -243,6 +232,11 @@ if __name__ == '__main__':
type
=
int
,
type
=
int
,
default
=
100
,
default
=
100
,
help
=
'controller server number.'
)
help
=
'controller server number.'
)
parser
.
add_argument
(
'--server_address'
,
type
=
str
,
default
=
""
,
help
=
'server ip.'
)
parser
.
add_argument
(
'--port'
,
type
=
int
,
default
=
8881
,
help
=
'server port'
)
parser
.
add_argument
(
'--retain_epoch'
,
type
=
int
,
default
=
5
,
help
=
'epoch for each token.'
)
parser
.
add_argument
(
'--lr'
,
type
=
float
,
default
=
0.1
,
help
=
'learning rate.'
)
parser
.
add_argument
(
'--lr'
,
type
=
float
,
default
=
0.1
,
help
=
'learning rate.'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
print
(
args
)
print
(
args
)
...
@@ -257,7 +251,7 @@ if __name__ == '__main__':
...
@@ -257,7 +251,7 @@ if __name__ == '__main__':
args
.
data
))
args
.
data
))
# block mask means block number, 1 mean downsample, 0 means the size of feature map don't change after this block
# block mask means block number, 1 mean downsample, 0 means the size of feature map don't change after this block
config_info
=
{
'block_mask'
:
[
0
,
1
,
1
,
1
,
1
,
0
,
1
,
0
]}
config_info
=
{
'block_mask'
:
[
0
,
1
,
1
,
1
,
0
]}
config
=
[(
'MobileNetV2BlockSpace'
,
config_info
)]
config
=
[(
'MobileNetV2BlockSpace'
,
config_info
)]
search_mobilenetv2_block
(
config
,
args
,
image_size
)
search_mobilenetv2_block
(
config
,
args
,
image_size
)
demo/nas/sa_nas_mobilenetv2.py
浏览文件 @
bb0f8fbb
...
@@ -18,13 +18,6 @@ import imagenet_reader
...
@@ -18,13 +18,6 @@ import imagenet_reader
_logger
=
get_logger
(
__name__
,
level
=
logging
.
INFO
)
_logger
=
get_logger
(
__name__
,
level
=
logging
.
INFO
)
reduce_rate
=
0.85
init_temperature
=
10.24
max_flops
=
321208544
server_address
=
""
port
=
8989
retain_epoch
=
5
def
create_data_loader
(
image_shape
):
def
create_data_loader
(
image_shape
):
data_shape
=
[
None
]
+
image_shape
data_shape
=
[
None
]
+
image_shape
...
@@ -66,18 +59,14 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
...
@@ -66,18 +59,14 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
### start a server and a client
### start a server and a client
sa_nas
=
SANAS
(
sa_nas
=
SANAS
(
config
,
config
,
server_addr
=
(
""
,
port
),
server_addr
=
(
args
.
server_address
,
args
.
port
),
init_temperature
=
init_temperature
,
reduce_rate
=
reduce_rate
,
search_steps
=
args
.
search_steps
,
search_steps
=
args
.
search_steps
,
is_server
=
True
)
is_server
=
True
)
else
:
else
:
### start a client
### start a client
sa_nas
=
SANAS
(
sa_nas
=
SANAS
(
config
,
config
,
server_addr
=
(
server_address
,
port
),
server_addr
=
(
args
.
server_address
,
args
.
port
),
init_temperature
=
init_temperature
,
reduce_rate
=
reduce_rate
,
search_steps
=
args
.
search_steps
,
search_steps
=
args
.
search_steps
,
is_server
=
False
)
is_server
=
False
)
...
@@ -93,7 +82,7 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
...
@@ -93,7 +82,7 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
current_flops
=
flops
(
train_program
)
current_flops
=
flops
(
train_program
)
print
(
'step: {}, current_flops: {}'
.
format
(
step
,
current_flops
))
print
(
'step: {}, current_flops: {}'
.
format
(
step
,
current_flops
))
if
current_flops
>
max_flops
:
if
current_flops
>
int
(
321208544
)
:
continue
continue
test_loader
,
test_avg_cost
,
test_acc_top1
,
test_acc_top5
=
build_program
(
test_loader
,
test_avg_cost
,
test_acc_top1
,
test_acc_top5
=
build_program
(
...
@@ -139,7 +128,7 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
...
@@ -139,7 +128,7 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
train_compiled_program
=
fluid
.
CompiledProgram
(
train_compiled_program
=
fluid
.
CompiledProgram
(
train_program
).
with_data_parallel
(
train_program
).
with_data_parallel
(
loss_name
=
avg_cost
.
name
,
build_strategy
=
build_strategy
)
loss_name
=
avg_cost
.
name
,
build_strategy
=
build_strategy
)
for
epoch_id
in
range
(
retain_epoch
):
for
epoch_id
in
range
(
args
.
retain_epoch
):
for
batch_id
,
data
in
enumerate
(
train_loader
()):
for
batch_id
,
data
in
enumerate
(
train_loader
()):
fetches
=
[
avg_cost
.
name
]
fetches
=
[
avg_cost
.
name
]
s_time
=
time
.
time
()
s_time
=
time
.
time
()
...
@@ -179,7 +168,7 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
...
@@ -179,7 +168,7 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
def
test_search_result
(
tokens
,
image_size
,
args
,
config
):
def
test_search_result
(
tokens
,
image_size
,
args
,
config
):
sa_nas
=
SANAS
(
sa_nas
=
SANAS
(
config
,
config
,
server_addr
=
(
""
,
8887
),
server_addr
=
(
args
.
server_address
,
args
.
port
),
init_temperature
=
args
.
init_temperature
,
init_temperature
=
args
.
init_temperature
,
reduce_rate
=
args
.
reduce_rate
,
reduce_rate
=
args
.
reduce_rate
,
search_steps
=
args
.
search_steps
,
search_steps
=
args
.
search_steps
,
...
@@ -234,7 +223,7 @@ def test_search_result(tokens, image_size, args, config):
...
@@ -234,7 +223,7 @@ def test_search_result(tokens, image_size, args, config):
train_compiled_program
=
fluid
.
CompiledProgram
(
train_compiled_program
=
fluid
.
CompiledProgram
(
train_program
).
with_data_parallel
(
train_program
).
with_data_parallel
(
loss_name
=
avg_cost
.
name
,
build_strategy
=
build_strategy
)
loss_name
=
avg_cost
.
name
,
build_strategy
=
build_strategy
)
for
epoch_id
in
range
(
retain_epoch
):
for
epoch_id
in
range
(
args
.
retain_epoch
):
for
batch_id
,
data
in
enumerate
(
train_loader
()):
for
batch_id
,
data
in
enumerate
(
train_loader
()):
fetches
=
[
avg_cost
.
name
]
fetches
=
[
avg_cost
.
name
]
s_time
=
time
.
time
()
s_time
=
time
.
time
()
...
@@ -298,6 +287,11 @@ if __name__ == '__main__':
...
@@ -298,6 +287,11 @@ if __name__ == '__main__':
type
=
int
,
type
=
int
,
default
=
100
,
default
=
100
,
help
=
'controller server number.'
)
help
=
'controller server number.'
)
parser
.
add_argument
(
'--server_address'
,
type
=
str
,
default
=
""
,
help
=
'server ip.'
)
parser
.
add_argument
(
'--port'
,
type
=
int
,
default
=
8881
,
help
=
'server port'
)
parser
.
add_argument
(
'--retain_epoch'
,
type
=
int
,
default
=
5
,
help
=
'epoch for each token.'
)
parser
.
add_argument
(
'--lr'
,
type
=
float
,
default
=
0.1
,
help
=
'learning rate.'
)
parser
.
add_argument
(
'--lr'
,
type
=
float
,
default
=
0.1
,
help
=
'learning rate.'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
print
(
args
)
print
(
args
)
...
...
demo/pantheon/run_student.py
浏览文件 @
bb0f8fbb
...
@@ -80,8 +80,8 @@ def run(args):
...
@@ -80,8 +80,8 @@ def run(args):
student
.
start
()
student
.
start
()
if
args
.
test_send_recv
:
if
args
.
test_send_recv
:
for
t
in
x
range
(
2
):
for
t
in
range
(
2
):
for
i
in
x
range
(
3
):
for
i
in
range
(
3
):
print
(
student
.
recv
(
t
))
print
(
student
.
recv
(
t
))
student
.
send
(
"message from student!"
)
student
.
send
(
"message from student!"
)
...
...
demo/prune/README.md
浏览文件 @
bb0f8fbb
...
@@ -17,7 +17,20 @@
...
@@ -17,7 +17,20 @@
1). 根据分类模型中
[
ImageNet数据准备文档
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E6%95%B0%E6%8D%AE%E5%87%86%E5%A4%87
)
下载数据到
`PaddleSlim/demo/data/ILSVRC2012`
路径下。
1). 根据分类模型中
[
ImageNet数据准备文档
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E6%95%B0%E6%8D%AE%E5%87%86%E5%A4%87
)
下载数据到
`PaddleSlim/demo/data/ILSVRC2012`
路径下。
2). 使用
`train.py`
脚本时,指定
`--data`
选项为
`imagenet`
.
2). 使用
`train.py`
脚本时,指定
`--data`
选项为
`imagenet`
.
## 2. 启动剪裁任务
## 2. 下载预训练模型
如果使用
`ImageNet`
数据,建议在预训练模型的基础上进行剪裁,请从
[
分类库
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E5%B7%B2%E5%8F%91%E5%B8%83%E6%A8%A1%E5%9E%8B%E5%8F%8A%E5%85%B6%E6%80%A7%E8%83%BD
)
中下载合适的预训练模型。
这里以
`MobileNetV1`
为例,下载并解压预训练模型到当前路径:
```
wget http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar
tar -xf MobileNetV1_pretrained.tar
```
使用
`train.py`
脚本时,指定
`--pretrained_model`
加载预训练模型。
## 3. 启动剪裁任务
通过以下命令启动裁剪任务:
通过以下命令启动裁剪任务:
...
@@ -25,8 +38,8 @@
...
@@ -25,8 +38,8 @@
export CUDA_VISIBLE_DEVICES=0
export CUDA_VISIBLE_DEVICES=0
python train.py \
python train.py \
--model "MobileNet" \
--model "MobileNet" \
--pruned_ratio 0.3
3
\
--pruned_ratio 0.3
1
\
--data "
imagene
t"
--data "
mnis
t"
```
```
其中,
`model`
用于指定待裁剪的模型。
`pruned_ratio`
用于指定各个卷积层通道数被裁剪的比例。
`data`
选项用于指定使用的数据集。
其中,
`model`
用于指定待裁剪的模型。
`pruned_ratio`
用于指定各个卷积层通道数被裁剪的比例。
`data`
选项用于指定使用的数据集。
...
@@ -35,7 +48,7 @@ python train.py \
...
@@ -35,7 +48,7 @@ python train.py \
在本示例中,会在日志中输出剪裁前后的
`FLOPs`
,并且每训练一轮就会保存一个模型到文件系统。
在本示例中,会在日志中输出剪裁前后的
`FLOPs`
,并且每训练一轮就会保存一个模型到文件系统。
##
3
. 加载和评估模型
##
4
. 加载和评估模型
本节介绍如何加载训练过程中保存的模型。
本节介绍如何加载训练过程中保存的模型。
...
@@ -43,14 +56,14 @@ python train.py \
...
@@ -43,14 +56,14 @@ python train.py \
```
```
python eval.py \
python eval.py \
--model "
mobilen
et" \
--model "
MobileN
et" \
--data "mnist" \
--data "mnist" \
--model_path "./models/0"
--model_path "./models/0"
```
```
在脚本
`eval.py`
中,使用
`paddleslim.prune.load_model`
接口加载剪裁得到的模型。
在脚本
`eval.py`
中,使用
`paddleslim.prune.load_model`
接口加载剪裁得到的模型。
##
4
. 接口介绍
##
5
. 接口介绍
该示例使用了
`paddleslim.Pruner`
工具类,用户接口使用介绍请参考:
[
API文档
](
https://paddlepaddle.github.io/PaddleSlim/api/prune_api/
)
该示例使用了
`paddleslim.Pruner`
工具类,用户接口使用介绍请参考:
[
API文档
](
https://paddlepaddle.github.io/PaddleSlim/api/prune_api/
)
...
...
demo/prune/eval.py
浏览文件 @
bb0f8fbb
...
@@ -68,7 +68,7 @@ def eval(args):
...
@@ -68,7 +68,7 @@ def eval(args):
val_feeder
=
feeder
=
fluid
.
DataFeeder
(
val_feeder
=
feeder
=
fluid
.
DataFeeder
(
[
image
,
label
],
place
,
program
=
val_program
)
[
image
,
label
],
place
,
program
=
val_program
)
load_model
(
val_program
,
"./model/mobilenetv1_prune_50"
)
load_model
(
exe
,
val_program
,
args
.
model_path
)
batch_id
=
0
batch_id
=
0
acc_top1_ns
=
[]
acc_top1_ns
=
[]
...
...
demo/prune/train.py
浏览文件 @
bb0f8fbb
...
@@ -8,6 +8,7 @@ import math
...
@@ -8,6 +8,7 @@ import math
import
time
import
time
import
numpy
as
np
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
sys
.
path
.
append
(
"../../"
)
from
paddleslim.prune
import
Pruner
,
save_model
from
paddleslim.prune
import
Pruner
,
save_model
from
paddleslim.common
import
get_logger
from
paddleslim.common
import
get_logger
from
paddleslim.analysis
import
flops
from
paddleslim.analysis
import
flops
...
@@ -37,6 +38,7 @@ add_arg('log_period', int, 10, "Log period in batches.")
...
@@ -37,6 +38,7 @@ add_arg('log_period', int, 10, "Log period in batches.")
add_arg
(
'test_period'
,
int
,
10
,
"Test period in epoches."
)
add_arg
(
'test_period'
,
int
,
10
,
"Test period in epoches."
)
add_arg
(
'model_path'
,
str
,
"./models"
,
"The path to save model."
)
add_arg
(
'model_path'
,
str
,
"./models"
,
"The path to save model."
)
add_arg
(
'pruned_ratio'
,
float
,
None
,
"The ratios to be pruned."
)
add_arg
(
'pruned_ratio'
,
float
,
None
,
"The ratios to be pruned."
)
add_arg
(
'criterion'
,
str
,
"l1_norm"
,
"The prune criterion to be used, support l1_norm and batch_norm_scale."
)
# yapf: enable
# yapf: enable
model_list
=
models
.
__all__
model_list
=
models
.
__all__
...
@@ -136,6 +138,8 @@ def compress(args):
...
@@ -136,6 +138,8 @@ def compress(args):
return
os
.
path
.
exists
(
return
os
.
path
.
exists
(
os
.
path
.
join
(
args
.
pretrained_model
,
var
.
name
))
os
.
path
.
join
(
args
.
pretrained_model
,
var
.
name
))
_logger
.
info
(
"Load pretrained model from {}"
.
format
(
args
.
pretrained_model
))
fluid
.
io
.
load_vars
(
exe
,
args
.
pretrained_model
,
predicate
=
if_exist
)
fluid
.
io
.
load_vars
(
exe
,
args
.
pretrained_model
,
predicate
=
if_exist
)
val_reader
=
paddle
.
batch
(
val_reader
,
batch_size
=
args
.
batch_size
)
val_reader
=
paddle
.
batch
(
val_reader
,
batch_size
=
args
.
batch_size
)
...
@@ -200,10 +204,12 @@ def compress(args):
...
@@ -200,10 +204,12 @@ def compress(args):
end_time
-
start_time
))
end_time
-
start_time
))
batch_id
+=
1
batch_id
+=
1
test
(
0
,
val_program
)
params
=
get_pruned_params
(
args
,
fluid
.
default_main_program
())
params
=
get_pruned_params
(
args
,
fluid
.
default_main_program
())
_logger
.
info
(
"FLOPs before pruning: {}"
.
format
(
_logger
.
info
(
"FLOPs before pruning: {}"
.
format
(
flops
(
fluid
.
default_main_program
())))
flops
(
fluid
.
default_main_program
())))
pruner
=
Pruner
()
pruner
=
Pruner
(
args
.
criterion
)
pruned_val_program
,
_
,
_
=
pruner
.
prune
(
pruned_val_program
,
_
,
_
=
pruner
.
prune
(
val_program
,
val_program
,
fluid
.
global_scope
(),
fluid
.
global_scope
(),
...
...
docs/en/model_zoo_en.md
浏览文件 @
bb0f8fbb
此差异已折叠。
点击以展开。
docs/zh_cn/api_cn/nas_api.rst
浏览文件 @
bb0f8fbb
...
@@ -128,7 +128,7 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火
...
@@ -128,7 +128,7 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火
- **tokens(list):** - 一组tokens。tokens的长度和范围取决于搜索空间。
- **tokens(list):** - 一组tokens。tokens的长度和范围取决于搜索空间。
**返回:**
**返回:**
根据传入的token得到一个模型结构实例。
根据传入的token得到一个模型结构实例
列表
。
**示例代码:**
**示例代码:**
...
@@ -153,8 +153,10 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火
...
@@ -153,8 +153,10 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火
**示例代码:**
**示例代码:**
.. code-block:: python
.. code-block:: python
import paddle.fluid as fluid
import paddle.fluid as fluid
from paddleslim.nas import SANAS
from paddleslim.nas import SANAS
config = [('MobileNetV2Space')]
config = [('MobileNetV2Space')]
sanas = SANAS(configs=config)
sanas = SANAS(configs=config)
print(sanas.current_info())
print(sanas.current_info())
docs/zh_cn/api_cn/pantheon_api.md
浏览文件 @
bb0f8fbb
#
多进程蒸馏
#
大规模可扩展知识蒸馏框架 Pantheon
## Teacher
## Teacher
...
@@ -100,7 +100,8 @@ pantheon.Teacher.start\_knowledge\_service(feed\_list, schema, program, reader\_
...
@@ -100,7 +100,8 @@ pantheon.Teacher.start\_knowledge\_service(feed\_list, schema, program, reader\_
-
**times (int):**
The maximum repeated serving times, default 1. Whenever
-
**times (int):**
The maximum repeated serving times, default 1. Whenever
the public method
**get\_knowledge\_generator()**
in
**Student**
the public method
**get\_knowledge\_generator()**
in
**Student**
object called once, the serving times will be added one,
object called once, the serving times will be added one,
until reaching the maximum and ending the service.
until reaching the maximum and ending the service. Only
valid in online mode, and will be ignored in offline mode.
**Return:**
None
**Return:**
None
...
...
docs/zh_cn/api_cn/prune_api.rst
浏览文件 @
bb0f8fbb
...
@@ -378,7 +378,7 @@ load_sensitivities
...
@@ -378,7 +378,7 @@ load_sensitivities
}
}
}
}
sensitivities_file = "sensitive_api_demo.data"
sensitivities_file = "sensitive_api_demo.data"
with open(sensitivities_file, 'w') as f:
with open(sensitivities_file, 'w
b
') as f:
pickle.dump(sen, f)
pickle.dump(sen, f)
sensitivities = load_sensitivities(sensitivities_file)
sensitivities = load_sensitivities(sensitivities_file)
print(sensitivities)
print(sensitivities)
...
...
docs/zh_cn/model_zoo.md
浏览文件 @
bb0f8fbb
# 模型库
# 模型库
## 1. 图
象
分类
## 1. 图
像
分类
数据集:ImageNet1000类
数据集:ImageNet1000类
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
| MobileNetV2 | quant_aware |72.05%/90.63% (-0.1%/-0.02%)| 4.0 | - |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/MobileNetV2_quant_aware.tar
)
|
| MobileNetV2 | quant_aware |72.05%/90.63% (-0.1%/-0.02%)| 4.0 | - |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/MobileNetV2_quant_aware.tar
)
|
|ResNet50|-|76.50%/93.00%| 99 | 2.71 |
[
下载链接
](
http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar
)
|
|ResNet50|-|76.50%/93.00%| 99 | 2.71 |
[
下载链接
](
http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar
)
|
|ResNet50|quant_post|76.33%/93.02% (-0.17%/+0.02%)| 25.1| 1.19 |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/ResNet50_quant_post.tar
)
|
|ResNet50|quant_post|76.33%/93.02% (-0.17%/+0.02%)| 25.1| 1.19 |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/ResNet50_quant_post.tar
)
|
|ResNet50|quant_aware|
76.48%/93.11% (-0.02%/+0.11%)| 25.1 | 1.17 |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/ResNet50_quant_awre.tar
)
|
|ResNet50|quant_aware|
76.48%/93.11% (-0.02%/+0.11%)| 25.1 | 1.17 |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/ResNet50_quant_awre.tar
)
|
分类模型Lite时延(ms)
分类模型Lite时延(ms)
...
@@ -89,6 +89,12 @@
...
@@ -89,6 +89,12 @@
<a name="trans1">[1]</a>:带_vd后缀代表该预训练模型使用了Mixup,Mixup相关介绍参考[mixup: Beyond Empirical Risk Minimization](https://arxiv.org/abs/1710.09412)
<a name="trans1">[1]</a>:带_vd后缀代表该预训练模型使用了Mixup,Mixup相关介绍参考[mixup: Beyond Empirical Risk Minimization](https://arxiv.org/abs/1710.09412)
### 1.4 搜索
| 模型 | 压缩方法 | Top-1/Top-5 Acc | 模型体积(MB) | GFLOPs | 下载 |
|:--:|:---:|:--:|:--:|:--:|:--:|
| MobileNetV2 | - | 72.15%/90.65% | 15 | 0.59 |
[
下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_pretrained.tar
)
|
| MobileNetV2 | SANAS | 71.518%/90.208% (-0.632%/-0.442%) | 14 | 0.295 |
[
下载链接
](
https://paddlemodels.cdn.bcebos.com/PaddleSlim/MobileNetV2_sanas.tar
)
|
## 2. 目标检测
## 2. 目标检测
...
@@ -99,8 +105,8 @@
...
@@ -99,8 +105,8 @@
| 模型 | 压缩方法 | 数据集 | Image/GPU | 输入608 Box AP | 输入416 Box AP | 输入320 Box AP | 模型体积(MB) | TensorRT时延(V100, ms) | 下载 |
| 模型 | 压缩方法 | 数据集 | Image/GPU | 输入608 Box AP | 输入416 Box AP | 输入320 Box AP | 模型体积(MB) | TensorRT时延(V100, ms) | 下载 |
| :----------------------------: | :---------: | :----: | :-------: | :------------: | :------------: | :------------: | :------------: | :----------: |:----------: |
| :----------------------------: | :---------: | :----: | :-------: | :------------: | :------------: | :------------: | :------------: | :----------: |:----------: |
| MobileNet-V1-YOLOv3 | - | COCO | 8 | 29.3 | 29.3 | 27.1 | 95 | - |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1.tar
)
|
| MobileNet-V1-YOLOv3 | - | COCO | 8 | 29.3 | 29.3 | 27.1 | 95 | - |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1.tar
)
|
| MobileNet-V1-YOLOv3 | quant_post | COCO | 8 | 27.9 (-1.4)|
28.0 (-1.3)
| 26.0 (-1.0) | 25 | - |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenetv1_coco_quant_post.tar
)
|
| MobileNet-V1-YOLOv3 | quant_post | COCO | 8 | 27.9 (-1.4)|
28.0 (-1.3)
| 26.0 (-1.0) | 25 | - |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenetv1_coco_quant_post.tar
)
|
| MobileNet-V1-YOLOv3 | quant_aware | COCO | 8 | 28.1 (-1.2)| 28.2 (-1.1)
| 25.8 (-1.2) | 26.3 | - |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenet_coco_quant_aware.tar
)
|
| MobileNet-V1-YOLOv3 | quant_aware | COCO | 8 | 28.1 (-1.2)| 28.2 (-1.1)
| 25.8 (-1.2) | 26.3 | - |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenet_coco_quant_aware.tar
)
|
| R34-YOLOv3 | - | COCO | 8 | 36.2 | 34.3 | 31.4 | 162 | - |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34.tar
)
|
| R34-YOLOv3 | - | COCO | 8 | 36.2 | 34.3 | 31.4 | 162 | - |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34.tar
)
|
| R34-YOLOv3 | quant_post | COCO | 8 | 35.7 (-0.5) | - | - | 42.7 | - |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_r34_coco_quant_post.tar
)
|
| R34-YOLOv3 | quant_post | COCO | 8 | 35.7 (-0.5) | - | - | 42.7 | - |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_r34_coco_quant_post.tar
)
|
| R34-YOLOv3 | quant_aware | COCO | 8 | 35.2 (-1.0) | 33.3 (-1.0) | 30.3 (-1.1)| 44 | - |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_r34_coco_quant_aware.tar
)
|
| R34-YOLOv3 | quant_aware | COCO | 8 | 35.2 (-1.0) | 33.3 (-1.0) | 30.3 (-1.1)| 44 | - |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_r34_coco_quant_aware.tar
)
|
...
@@ -157,6 +163,20 @@
...
@@ -157,6 +163,20 @@
| MobileNet-V1-YOLOv3 | ResNet34-YOLOv3 distill | COCO | 8 | 31.4 (+2.1) | 30.0 (+0.7) | 27.1 (+0.1) | 95 |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenetv1_coco_distilled.tar
)
|
| MobileNet-V1-YOLOv3 | ResNet34-YOLOv3 distill | COCO | 8 | 31.4 (+2.1) | 30.0 (+0.7) | 27.1 (+0.1) | 95 |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenetv1_coco_distilled.tar
)
|
### 2.4 搜索
数据集:WIDER-FACE
| 模型 | 压缩方法 | Image/GPU | 输入尺寸 | Easy/Medium/Hard | 模型体积(KB) | 硬件延时(ms)| 下载 |
| :------------: | :---------: | :-------: | :------: | :-----------------------------: | :------------: | :------------: | :----------------------------------------------------------: |
| BlazeFace | - | 8 | 640 | 91.5/89.2/79.7 | 815 | 71.862 |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/blazeface_original.tar
)
|
| BlazeFace-NAS | - | 8 | 640 | 83.7/80.7/65.8 | 244 | 21.117 |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/blazeface_nas.tar
)
|
| BlazeFace-NAS1 | SANAS | 8 | 640 | 87.0/83.7/68.5 | 389 | 22.558 |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/blazeface_nas2.tar
)
|
!!! note "Note"
<a name="trans1">[1]</a>: 硬件延时时间是利用提供的硬件延时表得到的,硬件延时表是在855芯片上基于PaddleLite测试的结果。
## 3. 图像分割
## 3. 图像分割
数据集:Cityscapes
数据集:Cityscapes
...
...
paddleslim/__init__.py
浏览文件 @
bb0f8fbb
...
@@ -19,4 +19,5 @@ from paddleslim import nas
...
@@ -19,4 +19,5 @@ from paddleslim import nas
from
paddleslim
import
analysis
from
paddleslim
import
analysis
from
paddleslim
import
dist
from
paddleslim
import
dist
from
paddleslim
import
quant
from
paddleslim
import
quant
__all__
=
[
'models'
,
'prune'
,
'nas'
,
'analysis'
,
'dist'
,
'quant'
]
from
paddleslim
import
pantheon
__all__
=
[
'models'
,
'prune'
,
'nas'
,
'analysis'
,
'dist'
,
'quant'
,
'pantheon'
]
paddleslim/nas/sa_nas.py
浏览文件 @
bb0f8fbb
...
@@ -190,7 +190,10 @@ class SANAS(object):
...
@@ -190,7 +190,10 @@ class SANAS(object):
self
.
_iter
=
0
self
.
_iter
=
0
def
_get_host_ip
(
self
):
def
_get_host_ip
(
self
):
return
socket
.
gethostbyname
(
socket
.
gethostname
())
if
os
.
name
==
'posix'
:
return
socket
.
gethostbyname
(
'localhost'
)
else
:
return
socket
.
gethostbyname
(
socket
.
gethostname
())
def
tokens2arch
(
self
,
tokens
):
def
tokens2arch
(
self
,
tokens
):
"""
"""
...
...
paddleslim/pantheon/README.md
浏览文件 @
bb0f8fbb
...
@@ -13,7 +13,7 @@ The illustration below shows an application of Pantheon, where the sudent model
...
@@ -13,7 +13,7 @@ The illustration below shows an application of Pantheon, where the sudent model
## Prerequisites
## Prerequisites
-
Python 2.7.x or 3.x
-
Python 2.7.x or 3.x
-
PaddlePaddle >= 1.
6
.0
-
PaddlePaddle >= 1.
7
.0
## APIs
## APIs
...
...
paddleslim/pantheon/student.py
浏览文件 @
bb0f8fbb
...
@@ -158,7 +158,7 @@ class Student(object):
...
@@ -158,7 +158,7 @@ class Student(object):
if
end_recved
:
if
end_recved
:
break
break
with
open
(
in_path
,
'r'
)
as
fin
:
with
open
(
in_path
,
'r
b
'
)
as
fin
:
# get knowledge desc
# get knowledge desc
desc
=
pickle
.
load
(
fin
)
desc
=
pickle
.
load
(
fin
)
out_queue
.
put
(
desc
)
out_queue
.
put
(
desc
)
...
@@ -222,7 +222,7 @@ class Student(object):
...
@@ -222,7 +222,7 @@ class Student(object):
self
.
_started
=
True
self
.
_started
=
True
def
_merge_knowledge
(
self
,
knowledge
):
def
_merge_knowledge
(
self
,
knowledge
):
for
k
,
tensors
in
knowledge
.
items
(
):
for
k
,
tensors
in
list
(
knowledge
.
items
()
):
if
len
(
tensors
)
==
0
:
if
len
(
tensors
)
==
0
:
del
knowledge
[
k
]
del
knowledge
[
k
]
elif
len
(
tensors
)
==
1
:
elif
len
(
tensors
)
==
1
:
...
@@ -308,7 +308,7 @@ class Student(object):
...
@@ -308,7 +308,7 @@ class Student(object):
print
(
"Knowledge merging strategy: {}"
.
format
(
print
(
"Knowledge merging strategy: {}"
.
format
(
self
.
_merge_strategy
))
self
.
_merge_strategy
))
print
(
"Knowledge description after merging:"
)
print
(
"Knowledge description after merging:"
)
for
schema
,
desc
in
knowledge_desc
.
items
(
):
for
schema
,
desc
in
list
(
knowledge_desc
.
items
()
):
print
(
"{}: {}"
.
format
(
schema
,
desc
))
print
(
"{}: {}"
.
format
(
schema
,
desc
))
self
.
_knowledge_desc
=
knowledge_desc
self
.
_knowledge_desc
=
knowledge_desc
...
@@ -426,13 +426,13 @@ class Student(object):
...
@@ -426,13 +426,13 @@ class Student(object):
end_received
=
[
0
]
*
len
(
queues
)
end_received
=
[
0
]
*
len
(
queues
)
while
True
:
while
True
:
knowledge
=
OrderedDict
(
knowledge
=
OrderedDict
(
[(
k
,
[])
for
k
,
v
in
self
.
_knowledge_desc
.
items
(
)])
[(
k
,
[])
for
k
,
v
in
list
(
self
.
_knowledge_desc
.
items
()
)])
for
idx
,
receiver
in
enumerate
(
data_receivers
):
for
idx
,
receiver
in
enumerate
(
data_receivers
):
if
not
end_received
[
idx
]:
if
not
end_received
[
idx
]:
batch_samples
=
receiver
.
next
(
batch_samples
=
receiver
.
next
(
)
if
six
.
PY2
else
receiver
.
__next__
()
)
if
six
.
PY2
else
receiver
.
__next__
()
if
not
isinstance
(
batch_samples
,
EndSignal
):
if
not
isinstance
(
batch_samples
,
EndSignal
):
for
k
,
v
in
batch_samples
.
items
(
):
for
k
,
v
in
list
(
batch_samples
.
items
()
):
knowledge
[
k
].
append
(
v
)
knowledge
[
k
].
append
(
v
)
else
:
else
:
end_received
[
idx
]
=
1
end_received
[
idx
]
=
1
...
...
paddleslim/pantheon/teacher.py
浏览文件 @
bb0f8fbb
...
@@ -151,7 +151,7 @@ class Teacher(object):
...
@@ -151,7 +151,7 @@ class Teacher(object):
self
.
_t2s_queue
=
None
self
.
_t2s_queue
=
None
self
.
_cmd_queue
=
None
self
.
_cmd_queue
=
None
self
.
_out_file
=
open
(
self
.
_out_path
,
"w"
)
if
self
.
_out_path
else
None
self
.
_out_file
=
open
(
self
.
_out_path
,
"w
b
"
)
if
self
.
_out_path
else
None
if
self
.
_out_file
:
if
self
.
_out_file
:
return
return
...
@@ -231,7 +231,7 @@ class Teacher(object):
...
@@ -231,7 +231,7 @@ class Teacher(object):
"The knowledge data should be a dict or OrderedDict!"
)
"The knowledge data should be a dict or OrderedDict!"
)
knowledge_desc
=
{}
knowledge_desc
=
{}
for
name
,
value
in
knowledge
.
items
(
):
for
name
,
value
in
list
(
knowledge
.
items
()
):
knowledge_desc
[
name
]
=
{
knowledge_desc
[
name
]
=
{
"shape"
:
[
-
1
]
+
list
(
value
.
shape
[
1
:]),
"shape"
:
[
-
1
]
+
list
(
value
.
shape
[
1
:]),
"dtype"
:
str
(
value
.
dtype
),
"dtype"
:
str
(
value
.
dtype
),
...
@@ -294,7 +294,8 @@ class Teacher(object):
...
@@ -294,7 +294,8 @@ class Teacher(object):
times (int): The maximum repeated serving times. Default 1. Whenever
times (int): The maximum repeated serving times. Default 1. Whenever
the public method 'get_knowledge_generator()' in Student
the public method 'get_knowledge_generator()' in Student
object called once, the serving times will be added one,
object called once, the serving times will be added one,
until reaching the maximum and ending the service.
until reaching the maximum and ending the service. Only
valid in online mode, and will be ignored in offline mode.
"""
"""
if
not
self
.
_started
:
if
not
self
.
_started
:
raise
ValueError
(
"The method start() should be called first!"
)
raise
ValueError
(
"The method start() should be called first!"
)
...
@@ -339,9 +340,12 @@ class Teacher(object):
...
@@ -339,9 +340,12 @@ class Teacher(object):
if
not
times
>
0
:
if
not
times
>
0
:
raise
ValueError
(
"Repeated serving times should be positive!"
)
raise
ValueError
(
"Repeated serving times should be positive!"
)
self
.
_times
=
times
self
.
_times
=
times
if
self
.
_times
>
1
and
self
.
_out_file
:
self
.
_times
=
1
print
(
"WARNING: args 'times' will be ignored in offline mode"
)
desc
=
{}
desc
=
{}
for
name
,
var
in
schema
.
items
(
):
for
name
,
var
in
list
(
schema
.
items
()
):
if
not
isinstance
(
var
,
fluid
.
framework
.
Variable
):
if
not
isinstance
(
var
,
fluid
.
framework
.
Variable
):
raise
ValueError
(
raise
ValueError
(
"The member of schema must be fluid Variable."
)
"The member of schema must be fluid Variable."
)
...
@@ -412,10 +416,14 @@ class Teacher(object):
...
@@ -412,10 +416,14 @@ class Teacher(object):
else
:
else
:
if
self
.
_knowledge_queue
:
if
self
.
_knowledge_queue
:
self
.
_knowledge_queue
.
put
(
EndSignal
())
self
.
_knowledge_queue
.
put
(
EndSignal
())
# should close file in child thread to wait for all
# writing finished
if
self
.
_out_file
:
self
.
_out_file
.
close
()
# Asynchronous output
# Asynchronous output
out_buf_queue
=
Queue
.
Queue
(
self
.
_buf_size
)
out_buf_queue
=
Queue
.
Queue
(
self
.
_buf_size
)
schema_keys
,
schema_vars
=
zip
(
*
self
.
_schema
.
items
(
))
schema_keys
,
schema_vars
=
zip
(
*
list
(
self
.
_schema
.
items
()
))
out_thread
=
Thread
(
target
=
writer
,
args
=
(
out_buf_queue
,
schema_keys
))
out_thread
=
Thread
(
target
=
writer
,
args
=
(
out_buf_queue
,
schema_keys
))
out_thread
.
daemon
=
True
out_thread
.
daemon
=
True
out_thread
.
start
()
out_thread
.
start
()
...
@@ -424,8 +432,9 @@ class Teacher(object):
...
@@ -424,8 +432,9 @@ class Teacher(object):
self
.
_program
).
with_data_parallel
()
self
.
_program
).
with_data_parallel
()
print
(
"Knowledge description {}"
.
format
(
self
.
_knowledge_desc
))
print
(
"Knowledge description {}"
.
format
(
self
.
_knowledge_desc
))
print
(
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
+
print
(
" Teacher begins to serve ..."
)
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
+
" Teacher begins to serve ..."
)
# For offline dump, write the knowledge description to the head of file
# For offline dump, write the knowledge description to the head of file
if
self
.
_out_file
:
if
self
.
_out_file
:
self
.
_out_file
.
write
(
pickle
.
dumps
(
self
.
_knowledge_desc
))
self
.
_out_file
.
write
(
pickle
.
dumps
(
self
.
_knowledge_desc
))
...
@@ -491,11 +500,10 @@ class Teacher(object):
...
@@ -491,11 +500,10 @@ class Teacher(object):
if
self
.
_knowledge_queue
:
if
self
.
_knowledge_queue
:
self
.
_knowledge_queue
.
join
()
self
.
_knowledge_queue
.
join
()
print
(
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
+
print
(
" Teacher ends serving."
)
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
+
" Teacher ends serving."
)
def
__del__
(
self
):
def
__del__
(
self
):
if
self
.
_manager
:
if
self
.
_manager
:
self
.
_manager
.
shutdown
()
self
.
_manager
.
shutdown
()
if
self
.
_out_file
:
self
.
_out_file
.
close
()
paddleslim/prune/pruner.py
浏览文件 @
bb0f8fbb
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
import
logging
import
logging
import
sys
import
numpy
as
np
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
copy
import
copy
...
@@ -79,8 +80,8 @@ class Pruner():
...
@@ -79,8 +80,8 @@ class Pruner():
pruned_num
=
int
(
round
(
param_v
.
shape
()[
0
]
*
ratio
))
pruned_num
=
int
(
round
(
param_v
.
shape
()[
0
]
*
ratio
))
pruned_idx
=
[
0
]
*
pruned_num
pruned_idx
=
[
0
]
*
pruned_num
else
:
else
:
p
aram_t
=
np
.
array
(
scope
.
find_var
(
param
).
get_tensor
())
p
runed_idx
=
self
.
_cal_pruned_idx
(
pruned_idx
=
self
.
_cal_pruned_idx
(
param_t
,
ratio
,
axis
=
0
)
graph
,
scope
,
param
,
ratio
,
axis
=
0
)
param
=
graph
.
var
(
param
)
param
=
graph
.
var
(
param
)
conv_op
=
param
.
outputs
()[
0
]
conv_op
=
param
.
outputs
()[
0
]
walker
=
conv2d_walker
(
walker
=
conv2d_walker
(
...
@@ -130,7 +131,7 @@ class Pruner():
...
@@ -130,7 +131,7 @@ class Pruner():
graph
.
infer_shape
()
graph
.
infer_shape
()
return
graph
.
program
,
param_backup
,
param_shape_backup
return
graph
.
program
,
param_backup
,
param_shape_backup
def
_cal_pruned_idx
(
self
,
param
,
ratio
,
axis
):
def
_cal_pruned_idx
(
self
,
graph
,
scope
,
param
,
ratio
,
axis
):
"""
"""
Calculate the index to be pruned on axis by given pruning ratio.
Calculate the index to be pruned on axis by given pruning ratio.
...
@@ -145,11 +146,26 @@ class Pruner():
...
@@ -145,11 +146,26 @@ class Pruner():
Returns:
Returns:
list<int>: The indexes to be pruned on axis.
list<int>: The indexes to be pruned on axis.
"""
"""
prune_num
=
int
(
round
(
param
.
shape
[
axis
]
*
ratio
))
reduce_dims
=
[
i
for
i
in
range
(
len
(
param
.
shape
))
if
i
!=
axis
]
if
self
.
criterion
==
'l1_norm'
:
if
self
.
criterion
==
'l1_norm'
:
criterions
=
np
.
sum
(
np
.
abs
(
param
),
axis
=
tuple
(
reduce_dims
))
param_t
=
np
.
array
(
scope
.
find_var
(
param
).
get_tensor
())
pruned_idx
=
criterions
.
argsort
()[:
prune_num
]
prune_num
=
int
(
round
(
param_t
.
shape
[
axis
]
*
ratio
))
reduce_dims
=
[
i
for
i
in
range
(
len
(
param_t
.
shape
))
if
i
!=
axis
]
criterions
=
np
.
sum
(
np
.
abs
(
param_t
),
axis
=
tuple
(
reduce_dims
))
pruned_idx
=
criterions
.
argsort
()[:
prune_num
]
elif
self
.
criterion
==
"batch_norm_scale"
:
param_var
=
graph
.
var
(
param
)
conv_op
=
param_var
.
outputs
()[
0
]
conv_output
=
conv_op
.
outputs
(
"Output"
)[
0
]
bn_op
=
conv_output
.
outputs
()[
0
]
if
bn_op
is
not
None
:
bn_scale_param
=
bn_op
.
inputs
(
"Scale"
)[
0
].
name
()
bn_scale_np
=
np
.
array
(
scope
.
find_var
(
bn_scale_param
).
get_tensor
())
prune_num
=
int
(
round
(
bn_scale_np
.
shape
[
axis
]
*
ratio
))
pruned_idx
=
np
.
abs
(
bn_scale_np
).
argsort
()[:
prune_num
]
else
:
raise
SystemExit
(
"Can't find BatchNorm op after Conv op in Network."
)
return
pruned_idx
return
pruned_idx
def
_prune_tensor
(
self
,
tensor
,
pruned_idx
,
pruned_axis
,
lazy
=
False
):
def
_prune_tensor
(
self
,
tensor
,
pruned_idx
,
pruned_axis
,
lazy
=
False
):
...
...
tests/test_slim_prune.py
0 → 100644
浏览文件 @
bb0f8fbb
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
sys
.
path
.
append
(
"../"
)
import
unittest
import
paddle.fluid
as
fluid
from
paddleslim.prune
import
Pruner
from
layers
import
conv_bn_layer
class TestPrune(unittest.TestCase):
    """Shape checks for `Pruner` with the 'batch_norm_scale' criterion."""

    def test_prune(self):
        """Prune half of conv4's filters and verify every conv weight shape.

        Only "conv4_weights" is passed to the pruner (ratio 0.5), but the
        network has residual additions, so the expected shapes below also
        show reduced channels on conv1, conv2, conv3 and conv5.
        """
        main_program = fluid.Program()
        startup_program = fluid.Program()
        #   X       X              O       X              O
        # conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
        #     |            ^ |                    ^
        #     |____________| |____________________|
        #
        # X: prune output channels
        # O: prune input channels
        with fluid.program_guard(main_program, startup_program):
            input = fluid.data(name="image", shape=[None, 3, 16, 16])
            conv1 = conv_bn_layer(input, 8, 3, "conv1")
            conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
            sum1 = conv1 + conv2
            conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
            conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
            sum2 = conv4 + sum1
            conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
            conv6 = conv_bn_layer(conv5, 8, 3, "conv6")

        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        scope = fluid.Scope()
        # Initialise the parameters in `scope`; the pruner is handed this
        # same scope together with the program.
        exe.run(startup_program, scope=scope)

        criterion = 'batch_norm_scale'
        pruner = Pruner(criterion)
        main_program, _, _ = pruner.prune(
            main_program,
            scope,
            params=["conv4_weights"],
            ratios=[0.5],
            place=place,
            lazy=False,
            only_graph=False,
            param_backup=None,
            param_shape_backup=None)

        # Plain int literals: the Python 2 long suffix (e.g. ``4L``) is a
        # syntax error on Python 3, and ints compare equal to longs on
        # Python 2, so these expectations hold on both.
        expected_shapes = {
            "conv1_weights": (4, 3, 3, 3),
            "conv2_weights": (4, 4, 3, 3),
            "conv3_weights": (8, 4, 3, 3),
            "conv4_weights": (4, 8, 3, 3),
            "conv5_weights": (8, 4, 3, 3),
            "conv6_weights": (8, 8, 3, 3),
        }

        for param in main_program.global_block().all_parameters():
            if "weights" in param.name:
                print("param: {}; param shape: {}".format(param.name,
                                                          param.shape))
                # assertEqual reports both shapes on failure, unlike
                # assertTrue(a == b).
                self.assertEqual(
                    tuple(param.shape), expected_shapes[param.name])
# Run the unittest runner only when this file is executed directly as a script.
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录