Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
123fcc57
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 1 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
123fcc57
编写于
5月 18, 2021
作者:
M
minghaoBD
提交者:
GitHub
5月 18, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix api docs easeof use (#740) (#743)
上级
aef70340
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
168 addition
and
59 deletion
+168
-59
demo/dygraph/unstructured_pruning/evaluate.py
demo/dygraph/unstructured_pruning/evaluate.py
+3
-7
demo/dygraph/unstructured_pruning/train.py
demo/dygraph/unstructured_pruning/train.py
+3
-5
demo/unstructured_prune/evaluate.py
demo/unstructured_prune/evaluate.py
+4
-10
demo/unstructured_prune/train.py
demo/unstructured_prune/train.py
+2
-1
docs/zh_cn/api_cn/dygraph/pruners/unstructured_pruner.rst
docs/zh_cn/api_cn/dygraph/pruners/unstructured_pruner.rst
+44
-14
docs/zh_cn/api_cn/static/prune/unstructured_prune_api.rst
docs/zh_cn/api_cn/static/prune/unstructured_prune_api.rst
+112
-22
未找到文件。
demo/dygraph/unstructured_pruning/evaluate.py
浏览文件 @
123fcc57
...
...
@@ -33,7 +33,7 @@ def compress(args):
test_reader
=
None
if
args
.
data
==
"imagenet"
:
import
imagenet_reader
as
reader
val_dataset
=
reader
.
ImageNetDataset
(
data_dir
=
'/data'
,
mode
=
'val'
)
val_dataset
=
reader
.
ImageNetDataset
(
mode
=
'val'
)
class_dim
=
1000
elif
args
.
data
==
"cifar10"
:
normalize
=
T
.
Normalize
(
...
...
@@ -47,13 +47,12 @@ def compress(args):
places
=
paddle
.
static
.
cuda_places
(
)
if
args
.
use_gpu
else
paddle
.
static
.
cpu_places
()
batch_size_per_card
=
int
(
args
.
batch_size
/
len
(
places
))
valid_loader
=
paddle
.
io
.
DataLoader
(
val_dataset
,
places
=
places
,
drop_last
=
False
,
return_list
=
True
,
batch_size
=
batch_size_per_card
,
batch_size
=
args
.
batch_size
,
shuffle
=
False
,
use_shared_memory
=
True
)
...
...
@@ -70,15 +69,12 @@ def compress(args):
y_data
=
paddle
.
to_tensor
(
data
[
1
])
if
args
.
data
==
'cifar10'
:
y_data
=
paddle
.
unsqueeze
(
y_data
,
1
)
end_time
=
time
.
time
()
logits
=
model
(
x_data
)
loss
=
F
.
cross_entropy
(
logits
,
y_data
)
acc_top1
=
paddle
.
metric
.
accuracy
(
logits
,
y_data
,
k
=
1
)
acc_top5
=
paddle
.
metric
.
accuracy
(
logits
,
y_data
,
k
=
5
)
acc_top1_ns
.
append
(
acc_top1
.
numpy
())
acc_top5_ns
.
append
(
acc_top5
.
numpy
())
end_time
=
time
.
time
()
if
batch_id
%
args
.
log_period
==
0
:
_logger
.
info
(
"Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}"
.
...
...
demo/dygraph/unstructured_pruning/train.py
浏览文件 @
123fcc57
...
...
@@ -23,6 +23,7 @@ parser = argparse.ArgumentParser(description=__doc__)
add_arg
=
functools
.
partial
(
add_arguments
,
argparser
=
parser
)
# yapf: disable
add_arg
(
'batch_size'
,
int
,
64
*
4
,
"Minibatch size."
)
add_arg
(
'batch_size_for_validation'
,
int
,
64
,
"Minibatch size for validation."
)
add_arg
(
'use_gpu'
,
bool
,
True
,
"Whether to use GPU or not."
)
add_arg
(
'lr'
,
float
,
0.1
,
"The learning rate used to fine-tune pruned model."
)
add_arg
(
'lr_strategy'
,
str
,
"piecewise_decay"
,
"The learning rate decay strategy."
)
...
...
@@ -121,7 +122,7 @@ def compress(args):
places
=
place
,
drop_last
=
False
,
return_list
=
True
,
batch_size
=
64
,
batch_size
=
args
.
batch_size_for_validation
,
shuffle
=
False
,
use_shared_memory
=
True
)
step_per_epoch
=
int
(
...
...
@@ -146,15 +147,12 @@ def compress(args):
y_data
=
paddle
.
to_tensor
(
data
[
1
])
if
args
.
data
==
'cifar10'
:
y_data
=
paddle
.
unsqueeze
(
y_data
,
1
)
end_time
=
time
.
time
()
logits
=
model
(
x_data
)
loss
=
F
.
cross_entropy
(
logits
,
y_data
)
acc_top1
=
paddle
.
metric
.
accuracy
(
logits
,
y_data
,
k
=
1
)
acc_top5
=
paddle
.
metric
.
accuracy
(
logits
,
y_data
,
k
=
5
)
acc_top1_ns
.
append
(
acc_top1
.
numpy
())
acc_top5_ns
.
append
(
acc_top5
.
numpy
())
end_time
=
time
.
time
()
if
batch_id
%
args
.
log_period
==
0
:
_logger
.
info
(
"Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}"
.
...
...
demo/unstructured_prune/evaluate.py
浏览文件 @
123fcc57
...
...
@@ -20,7 +20,7 @@ _logger = get_logger(__name__, level=logging.INFO)
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
add_arg
=
functools
.
partial
(
add_arguments
,
argparser
=
parser
)
# yapf: disable
add_arg
(
'batch_size'
,
int
,
64
*
12
,
"Minibatch size."
)
add_arg
(
'batch_size'
,
int
,
64
,
"Minibatch size."
)
add_arg
(
'use_gpu'
,
bool
,
True
,
"Whether to use GPU or not."
)
add_arg
(
'model'
,
str
,
"MobileNet"
,
"The target model."
)
add_arg
(
'pruned_model'
,
str
,
"models"
,
"Whether to use pretrained model."
)
...
...
@@ -44,8 +44,8 @@ def compress(args):
image_shape
=
"1,28,28"
elif
args
.
data
==
"imagenet"
:
import
imagenet_reader
as
reader
train_dataset
=
reader
.
ImageNetDataset
(
data_dir
=
'/data'
,
mode
=
'train'
)
val_dataset
=
reader
.
ImageNetDataset
(
data_dir
=
'/data'
,
mode
=
'val'
)
train_dataset
=
reader
.
ImageNetDataset
(
mode
=
'train'
)
val_dataset
=
reader
.
ImageNetDataset
(
mode
=
'val'
)
class_dim
=
1000
image_shape
=
"3,224,224"
else
:
...
...
@@ -71,7 +71,6 @@ def compress(args):
use_shared_memory
=
True
,
batch_size
=
batch_size_per_card
,
shuffle
=
False
)
step_per_epoch
=
int
(
np
.
ceil
(
len
(
train_dataset
)
*
1.
/
args
.
batch_size
))
# model definition
model
=
models
.
__dict__
[
args
.
model
]()
...
...
@@ -103,12 +102,7 @@ def compress(args):
for
batch_id
,
data
in
enumerate
(
valid_loader
):
start_time
=
time
.
time
()
acc_top1_n
,
acc_top5_n
=
exe
.
run
(
program
,
feed
=
{
"image"
:
data
[
0
].
get
(
'image'
),
"label"
:
data
[
0
].
get
(
'label'
),
},
fetch_list
=
[
acc_top1
.
name
,
acc_top5
.
name
])
program
,
feed
=
data
,
fetch_list
=
[
acc_top1
.
name
,
acc_top5
.
name
])
end_time
=
time
.
time
()
if
batch_id
%
args
.
log_period
==
0
:
_logger
.
info
(
...
...
demo/unstructured_prune/train.py
浏览文件 @
123fcc57
...
...
@@ -20,6 +20,7 @@ parser = argparse.ArgumentParser(description=__doc__)
add_arg
=
functools
.
partial
(
add_arguments
,
argparser
=
parser
)
# yapf: disable
add_arg
(
'batch_size'
,
int
,
64
*
4
,
"Minibatch size."
)
add_arg
(
'batch_size_for_validation'
,
int
,
64
,
"Minibatch size for validation."
)
add_arg
(
'use_gpu'
,
bool
,
True
,
"Whether to use GPU or not."
)
add_arg
(
'model'
,
str
,
"MobileNet"
,
"The target model."
)
add_arg
(
'pretrained_model'
,
str
,
"../pretrained_model/MobileNetV1_pretrained"
,
"Whether to use pretrained model."
)
...
...
@@ -123,7 +124,7 @@ def compress(args):
drop_last
=
False
,
return_list
=
False
,
use_shared_memory
=
True
,
batch_size
=
batch_size_per_card
,
batch_size
=
args
.
batch_size_for_validation
,
shuffle
=
False
)
step_per_epoch
=
int
(
np
.
ceil
(
len
(
train_dataset
)
*
1.
/
args
.
batch_size
))
...
...
docs/zh_cn/api_cn/dygraph/pruners/unstructured_pruner.rst
浏览文件 @
123fcc57
...
...
@@ -7,7 +7,7 @@ UnstructuredPruner
..
py
:
class
::
paddleslim
.
UnstructuredPruner
(
model
,
mode
,
threshold
=
0.01
,
ratio
=
0.3
,
skip_params_func
=
None
)
`源代码 <https://github.com/
minghaoBD/PaddleSlim/blob/update_unstructured_pruning_docs
/paddleslim/dygraph/prune/unstructured_pruner.py>`_
`
源代码
<
https
://
github
.
com
/
PaddlePaddle
/
PaddleSlim
/
blob
/
develop
/
paddleslim
/
dygraph
/
prune
/
unstructured_pruner
.
py
>`
_
对于神经网络中的参数进行非结构化稀疏。非结构化稀疏是指,根据某些衡量指标,将不重要的参数置
0
。其不按照固定结构剪裁(例如一个通道等),这是和结构化剪枝的主要区别。
...
...
@@ -23,11 +23,16 @@ UnstructuredPruner
**
示例代码:
**
此示例不能直接运行,因为需要定义和加载模型,详细用法请参考 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/dygraph/unstructured_pruning>`_
..
code
-
block
::
python
import
paddle
from
paddleslim
import
UnstructuredPruner
from
paddle
.
vision
.
models
import
LeNet
as
net
import
numpy
as
np
place
=
paddle
.
set_device
(
'cpu'
)
model
=
net
(
num_classes
=
10
)
pruner
=
UnstructuredPruner
(
model
,
mode
=
'ratio'
,
ratio
=
0.5
)
..
...
...
@@ -38,13 +43,19 @@ UnstructuredPruner
**
示例代码:
**
此示例不能直接运行,因为需要定义和加载模型,详细用法请参考 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/dygraph/unstructured_pruning>`_
..
code
-
block
::
python
from
paddleslim
import
UnstructuredPruner
from
paddle
.
vision
.
models
import
LeNet
as
net
import
numpy
as
np
place
=
paddle
.
set_device
(
'cpu'
)
model
=
net
(
num_classes
=
10
)
pruner
=
UnstructuredPruner
(
model
,
mode
=
'ratio'
,
ratio
=
0.5
)
print
(
pruner
.
threshold
)
pruner
.
step
()
print
(
pruner
.
threshold
)
#
可以看出,这里的
threshold
和上面打印的不同,这是因为
step
函数根据设定的
ratio
更新了
threshold
数值,便于剪裁操作。
..
...
...
@@ -54,13 +65,23 @@ UnstructuredPruner
**
示例代码:
**
此示例不能直接运行,因为需要定义和加载模型,详细用法请参考 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/dygraph/unstructured_pruning>`_
..
code
-
block
::
python
from
paddleslim
import
UnstructuredPruner
pruner = UnstructuredPruner(model, mode='ratio', ratio=0.5)
from
paddle
.
vision
.
models
import
LeNet
as
net
import
numpy
as
np
place
=
paddle
.
set_device
(
'cpu'
)
model
=
net
(
num_classes
=
10
)
pruner
=
UnstructuredPruner
(
model
,
mode
=
'threshold'
,
threshold
=
0.5
)
density
=
UnstructuredPruner
.
total_sparse
(
model
)
print
(
density
)
model
(
paddle
.
to_tensor
(
np
.
random
.
uniform
(
0
,
1
,
[
16
,
1
,
28
,
28
]),
dtype
=
'float32'
))
pruner
.
update_params
()
density
=
UnstructuredPruner
.
total_sparse
(
model
)
print
(
density
)
#
可以看出,这里打印的模型稠密度与上述不同,这是因为
update_params
()
函数置零了所有绝对值小于
0.5
的权重。
..
...
...
@@ -78,13 +99,17 @@ UnstructuredPruner
**
示例代码:
**
此示例不能直接运行,因为需要定义和加载模型,详细用法请参考 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/dygraph/unstructured_pruning>`_
..
code
-
block
::
python
from
paddleslim
import
UnstructuredPruner
density = UnstructuredPruner.total_sparse(model)
from
paddle
.
vision
.
models
import
LeNet
as
net
import
numpy
as
np
place
=
paddle
.
set_device
(
'cpu'
)
model
=
net
(
num_classes
=
10
)
density
=
UnstructuredPruner
.
total_sparse
(
model
)
print
(
density
)
..
..
py
:
method
::
paddleslim
.
UnstructuredPruner
.
summarize_weights
(
model
,
ratio
=
0.1
)
...
...
@@ -102,12 +127,17 @@ UnstructuredPruner
**
示例代码:
**
此示例不能直接运行,因为需要定义和加载模型,详细用法请参考 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/dygraph/unstructured_pruning>`_
..
code
-
block
::
python
from
paddleslim
import
UnstructuredPruner
from
paddle
.
vision
.
models
import
LeNet
as
net
import
numpy
as
np
place
=
paddle
.
set_device
(
'cpu'
)
model
=
net
(
num_classes
=
10
)
pruner
=
UnstructuredPruner
(
model
,
mode
=
'ratio'
,
ratio
=
0.5
)
threshold = pruner.summarize_weights(model, ratio=0.1)
threshold
=
pruner
.
summarize_weights
(
model
,
0.5
)
print
(
threshold
)
..
docs/zh_cn/api_cn/static/prune/unstructured_prune_api.rst
浏览文件 @
123fcc57
...
...
@@ -24,13 +24,30 @@ UnstrucuturedPruner
**示例代码:**
此示例不能直接运行,因为需要加载数据和模型,详细demo请参照 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/unstructured_prune>`_
.. code-block:: python
from paddleslim.prune import UnstructuredPruner
pruner = UnstructuredPruner()
import paddle as paddle
import paddle.fluid as fluid
from paddleslim.prune import UnstructuredPruner
paddle.enable_static()
train_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()
with fluid.program_guard(train_program, startup_program):
image = fluid.data(name='x', shape=[None, 1, 28, 28])
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
conv = fluid.layers.conv2d(image, 32, 1)
feature = fluid.layers.fc(conv, 10, act='softmax')
cost = fluid.layers.cross_entropy(input=feature, label=label)
avg_cost = fluid.layers.mean(x=cost)
place = paddle.static.cpu_places()[0]
exe = paddle.static.Executor(place)
exe.run(startup_program)
pruner = UnstructuredPruner(paddle.static.default_main_program(), 'ratio', ratio=0.5, place=place)
..
.. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.step()
...
...
@@ -39,33 +56,71 @@ UnstrucuturedPruner
**示例代码:**
此示例不能直接运行,因为需要加载数据和模型,详细demo请参照 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/unstructured_prune>`_
.. code-block:: python
import paddle as paddle
import paddle.fluid as fluid
from paddleslim.prune import UnstructuredPruner
pruner = UnstructuredPruner(
paddle.static.default_main_program(), 'ratio', scope=paddle.static.global_scope(), place=paddle.static.cpu_places()[0])
pruner.step()
paddle.enable_static()
train_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()
with fluid.program_guard(train_program, startup_program):
image = fluid.data(name='x', shape=[None, 1, 28, 28])
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
conv = fluid.layers.conv2d(image, 32, 1)
feature = fluid.layers.fc(conv, 10, act='softmax')
cost = fluid.layers.cross_entropy(input=feature, label=label)
avg_cost = fluid.layers.mean(x=cost)
place = paddle.static.cpu_places()[0]
exe = paddle.static.Executor(place)
exe.run(startup_program)
pruner = UnstructuredPruner(paddle.static.default_main_program(), 'ratio', ratio=0.5, place=place)
print(pruner.threshold)
pruner.step()
print(pruner.threshold) # 可以看出,这里的threshold和上面打印的不同,这是因为step函数根据设定的ratio更新了threshold数值,便于剪裁操作。
..
.. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.update_params()
每一步优化后,重制模型中本来是0的权重。这一步通常用于模型evaluation和save之前,确保模型的稀疏率。
每一步优化后,重制模型中本来是0的权重。这一步通常用于模型evaluation和save之前,确保模型的稀疏率。
但是,在训练过程中,由于step()函数会调用该方法,故不需要开发者在训练过程中额外调用了。
**示例代码:**
此示例不能直接运行,因为需要加载数据和模型,详细demo请参照 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/unstructured_prune>`_
.. code-block:: python
import paddle as paddle
import paddle.fluid as fluid
from paddleslim.prune import UnstructuredPruner
pruner = UnstructuredPruner(
paddle.static.default_main_program(), 'ratio', scope=paddle.static.global_scope(), place=paddle.static.cpu_places()[0])
paddle.enable_static()
train_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()
with fluid.program_guard(train_program, startup_program):
image = fluid.data(name='x', shape=[None, 1, 28, 28])
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
conv = fluid.layers.conv2d(image, 32, 1)
feature = fluid.layers.fc(conv, 10, act='softmax')
cost = fluid.layers.cross_entropy(input=feature, label=label)
avg_cost = fluid.layers.mean(x=cost)
place = paddle.static.cpu_places()[0]
exe = paddle.static.Executor(place)
exe.run(startup_program)
pruner = UnstructuredPruner(paddle.static.default_main_program(), 'threshold', threshold=0.5, place=place)
density = UnstructuredPruner.total_sparse(paddle.static.default_main_program())
print(density)
pruner.step()
pruner.update_params()
density = UnstructuredPruner.total_sparse(paddle.static.default_main_program())
print(density) # 可以看出,这里打印的模型稠密度与上述不同,这是因为update_params()函数置零了所有绝对值小于0.5的权重。
..
...
...
@@ -83,13 +138,31 @@ UnstrucuturedPruner
**示例代码:**
此示例不能直接运行,因为需要加载数据和模型,详细demo请参照 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/unstructured_prune>`_
.. code-block:: python
import paddle as paddle
import paddle.fluid as fluid
from paddleslim.prune import UnstructuredPruner
paddle.enable_static()
train_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()
with fluid.program_guard(train_program, startup_program):
image = fluid.data(name='x', shape=[None, 1, 28, 28])
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
conv = fluid.layers.conv2d(image, 32, 1)
feature = fluid.layers.fc(conv, 10, act='softmax')
cost = fluid.layers.cross_entropy(input=feature, label=label)
avg_cost = fluid.layers.mean(x=cost)
place = paddle.static.cpu_places()[0]
exe = paddle.static.Executor(place)
exe.run(startup_program)
density = UnstructuredPruner.total_sparse(paddle.static.default_main_program())
print(density)
..
...
...
@@ -108,14 +181,31 @@ UnstrucuturedPruner
**示例代码:**
此示例不能直接运行,因为需要加载数据和模型,详细demo请参照 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/unstructured_prune>`_
.. code-block:: python
import paddle as paddle
import paddle.fluid as fluid
from paddleslim.prune import UnstructuredPruner
pruner = UnstructuredPruner(
paddle.static.default_main_program(), 'ratio', scope=paddle.static.global_scope(), place=paddle.static.cpu_places()[0])
threshold = pruner.summarize_weights(paddle.static.default_main_program(), 1.0)
paddle.enable_static()
train_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()
with fluid.program_guard(train_program, startup_program):
image = fluid.data(name='x', shape=[None, 1, 28, 28])
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
conv = fluid.layers.conv2d(image, 32, 1)
feature = fluid.layers.fc(conv, 10, act='softmax')
cost = fluid.layers.cross_entropy(input=feature, label=label)
avg_cost = fluid.layers.mean(x=cost)
place = paddle.static.cpu_places()[0]
exe = paddle.static.Executor(place)
exe.run(startup_program)
threshold = pruner.summarize_weights(paddle.static.default_main_program(), ratio=0.5)
print(threshold)
..
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录