Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
cf8fbaef
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 2 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cf8fbaef
编写于
12月 06, 2019
作者:
W
wanghaoshuang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update pruner and api doc.
上级
50ff5124
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
110 addition
and
9 deletion
+110
-9
doc/prune_api.md
doc/prune_api.md
+97
-3
paddleslim/prune/pruner.py
paddleslim/prune/pruner.py
+12
-5
tests/test_prune.py
tests/test_prune.py
+1
-1
未找到文件。
doc/prune_api.md
浏览文件 @
cf8fbaef
...
...
@@ -11,7 +11,7 @@
**参数:**
- **criterion:** 评估一个卷积层内通道重要性所参考的指标。目前仅支持`l1_norm`。默认为`l1_norm`。
-
**criterion:**
评估一个卷积层内通道重要性所参考的指标。目前仅支持
`l1_norm`
。默认为
`l1_norm`
。
**返回:**
一个Pruner类的实例
...
...
@@ -32,7 +32,7 @@ pruner = Pruner()
-
**program(paddle.fluid.Program):**
要裁剪的目标网络。更多关于Program的介绍请参考:
[
Program概念介绍
](
https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/Program_cn.html#program
)
。
-
**scope(paddle.fluid.Scope):**
要裁剪的权重所在的
`scope`
,Paddle中用
`scope`
实例存放模型参数和运行时变量的值。更多介绍请参考
[
Scope概念介绍
](
)
-
**scope(paddle.fluid.Scope):**
要裁剪的权重所在的
`scope`
,Paddle中用
`scope`
实例存放模型参数和运行时变量的值。
Scope中的参数值会被
`inplace`
的裁剪。
更多介绍请参考
[
Scope概念介绍
](
)
-
**params(list<str>):**
需要被裁剪的卷积层的参数的名称列表。可以通过以下方式查看模型中所有参数的名称:
```
...
...
@@ -43,6 +43,100 @@ for block in program.blocks:
-
**ratios(list<float>):**
用于裁剪
`params`
的剪切率,类型为列表。该列表长度必须与
`params`
的长度一致。
-
**place(paddle.fluid.Place):**
-
**place(paddle.fluid.Place):**
待裁剪参数所在的设备位置,可以是
`CUDAPlace`
或
`CPUPLace`
。
[
Place概念介绍
](
)
-
**lazy(bool):**
`lazy`
为True时,通过将指定通道的参数置零达到裁剪的目的,参数的
`shape保持不变`
;
`lazy`
为False时,直接将要裁的通道的参数删除,参数的
`shape`
会发生变化。
-
**only_graph(bool):**
是否只裁剪网络结构。在Paddle中,Program定义了网络结构,Scope存储参数的数值。一个Scope实例可以被多个Program使用,比如定义了训练网络的Program和定义了测试网络的Program是使用同一个Scope实例的。
`only_graph`
为True时,只对Program中定义的卷积的通道进行剪裁;
`only_graph`
为false时,Scope中卷积参数的数值也会被剪裁。默认为False。
-
**param_backup(bool):**
是否返回对参数值的备份。默认为False。
-
**param_shape_backup(bool):**
是否返回对参数
`shape`
的备份。
**返回:**
-
**pruned_program(paddle.fluid.Program):**
被裁剪后的Program。
-
**param_backup(dict):**
对参数数值的备份,用于恢复Scope中的参数数值。
-
**param_shape_backup(dict):**
对参数形状的备份。
**示例:**
```
import paddle.fluid as fluid
from paddleslim.prune import Pruner
def conv_bn_layer(input,
num_filters,
filter_size,
name,
stride=1,
groups=1,
act=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False,
name=name + "_out")
bn_name = name + "_bn"
return fluid.layers.batch_norm(
input=conv,
act=act,
name=bn_name + '_output',
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance', )
main_program = fluid.Program()
startup_program = fluid.Program()
# X X O X O
# conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
# | ^ | ^
# |____________| |____________________|
#
# X: prune output channels
# O: prune input channels
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
exe.run(startup_program, scope=scope)
pruner = Pruner()
main_program, _, _ = pruner.prune(
main_program,
scope,
params=["conv4_weights"],
ratios=[0.5],
place=place,
lazy=False,
only_graph=False,
param_backup=None,
param_shape_backup=None)
for param in main_program.global_block().all_parameters():
if "weights" in param.name:
print("param name: {}; param shape: {}".format(param.name, param.shape))
```
---
paddleslim/prune/pruner.py
浏览文件 @
cf8fbaef
...
...
@@ -41,8 +41,8 @@ class Pruner():
place
=
None
,
lazy
=
False
,
only_graph
=
False
,
param_backup
=
Non
e
,
param_shape_backup
=
Non
e
):
param_backup
=
Fals
e
,
param_shape_backup
=
Fals
e
):
"""
Pruning the given parameters.
Args:
...
...
@@ -55,14 +55,18 @@ class Pruner():
False means cutting down the pruned elements. Default: False.
only_graph(bool): True means only modifying the graph.
False means modifying graph and variables in scope. Default: False.
param_backup(
dict): A dict to backup the values of parameters. Default: Non
e.
param_shape_backup(
dict): A dict to backup the shapes of parameters. Default: Non
e.
param_backup(
bool): Whether to return a dict to backup the values of parameters. Default: Fals
e.
param_shape_backup(
bool): Whether to return a dict to backup the shapes of parameters. Default: Fals
e.
Returns:
Program: The pruned program.
param_backup: A dict to backup the values of parameters.
param_shape_backup: A dict to backup the shapes of parameters.
"""
self
.
pruned_list
=
[]
graph
=
GraphWrapper
(
program
.
clone
())
param_backup
=
{}
if
param_backup
else
None
param_shape_backup
=
{}
if
param_shape_backup
else
None
self
.
_prune_parameters
(
graph
,
scope
,
...
...
@@ -77,7 +81,7 @@ class Pruner():
if
op
.
type
()
==
'depthwise_conv2d'
or
op
.
type
(
)
==
'depthwise_conv2d_grad'
:
op
.
set_attr
(
'groups'
,
op
.
inputs
(
'Filter'
)[
0
].
shape
()[
0
])
return
graph
.
program
return
graph
.
program
,
param_backup
,
param_shape_backup
def
_prune_filters_by_ratio
(
self
,
scope
,
...
...
@@ -531,6 +535,9 @@ class Pruner():
self
.
pruned_list
=
[[],
[]]
for
param
,
ratio
in
zip
(
params
,
ratios
):
assert
isinstance
(
param
,
str
)
or
isinstance
(
param
,
unicode
)
if
param
in
self
.
pruned_list
[
0
]:
_logger
.
info
(
"Skip {}"
.
format
(
param
))
continue
_logger
.
info
(
"pruning param: {}"
.
format
(
param
))
param
=
graph
.
var
(
param
)
self
.
_forward_pruning_ralated_params
(
...
...
tests/test_prune.py
浏览文件 @
cf8fbaef
...
...
@@ -50,7 +50,7 @@ class TestPrune(unittest.TestCase):
scope
=
fluid
.
Scope
()
exe
.
run
(
startup_program
,
scope
=
scope
)
pruner
=
Pruner
()
main_program
=
pruner
.
prune
(
main_program
,
_
,
_
=
pruner
.
prune
(
main_program
,
scope
,
params
=
[
"conv4_weights"
],
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录