未验证 提交 8fad8d41 编写于 作者: M minghaoBD 提交者: GitHub

Unstructured pruning (#710)

上级 065f6444
# 非结构化稀疏 -- 动态图剪裁(包括按照阈值和比例剪裁两种模式)
## 简介
在模型压缩中,常见的稀疏方式为结构化和非结构化稀疏,前者在某个特定维度(特征通道、卷积核等等)上进行稀疏化操作;后者以每一个参数为单元进行稀疏化,所以更加依赖于硬件对稀疏后矩阵运算的加速能力。本目录即在PaddlePaddle和PaddleSlim框架下开发的非结构化稀疏算法,MobileNetV1在ImageNet上的稀疏化实验中,剪裁率55.19%,达到无损的表现。
## 版本要求
```bash
python3.5+
paddlepaddle>=2.0.0
paddleslim>=2.1.0
```
请参照github安装[paddlepaddle](https://github.com/PaddlePaddle/Paddle)[paddleslim](https://github.com/PaddlePaddle/PaddleSlim)
## 使用
训练前:
- 训练数据下载后,可以通过重写../imagenet_reader.py文件,并在train.py/evaluate.py文件中调用实现。
- 开发者可以通过重写paddleslim.dygraph.prune.unstructured_pruner.py中的UnstructuredPruner.mask_parameters()和UnstructuredPruner.update_threshold()来定义自己的非结构化稀疏策略(目前为剪裁掉绝对值小的parameters)。
- 开发可以在初始化UnstructuredPruner时,传入自定义的skip_params_func,来定义哪些参数不参与剪裁。skip_params_func示例代码如下(路径:paddleslim.dygraph.prune.unstructured_pruner._get_skip_params())。默认为所有的归一化层的参数不参与剪裁。
```python
def _get_skip_params(model):
"""
This function is used to check whether the given model's layers are valid to be pruned.
Usually, the convolutions are to be pruned while we skip the normalization-related parameters.
Deverlopers could replace this function by passing their own when initializing the UnstructuredPuner instance.
Args:
- model(Paddle.nn.Layer): the current model waiting to be checked.
Return:
- skip_params(set<String>): a set of parameters' names
"""
skip_params = set()
for _, sub_layer in model.named_sublayers():
if type(sub_layer).__name__.split('.')[-1] in paddle.nn.norm.__all__:
skip_params.add(sub_layer.full_name())
return skip_params
```
训练:
```bash
python3 train.py --data cifar10 --lr 0.1 --pruning_mode ratio --ratio=0.5
```
推理:
```bash
python3 eval --pruned_model models/ --data cifar10
```
剪裁训练代码示例:
```python
model = mobilenet_v1(num_classes=class_dim, pretrained=True)
#STEP1: initialize the pruner
pruner = UnstructuredPruner(model, mode='ratio', ratio=0.5)
for epoch in range(epochs):
for batch_id, data in enumerate(train_loader):
loss = calculate_loss()
loss.backward()
opt.step()
opt.clear_grad()
#STEP2: update the pruner's threshold given the updated parameters
pruner.step()
if epoch % args.test_period == 0:
#STEP3: before evaluation during training, eliminate the non-zeros generated by opt.step(), which, however, the cached masks setting to be zeros.
pruner.update_params()
eval(epoch)
if epoch % args.model_period == 0:
# STEP4: same purpose as STEP3
pruner.update_params()
paddle.save(model.state_dict(), "model-pruned.pdparams")
paddle.save(opt.state_dict(), "opt-pruned.pdopt")
```
剪裁后测试代码示例:
```python
model = mobilenet_v1(num_classes=class_dim, pretrained=True)
model.set_state_dict(paddle.load("model-pruned.pdparams"))
print(UnstructuredPruner.total_sparse(model)) #注意,total_sparse为静态方法(static method),可以不创建实例(instance)直接调用,方便只做测试的写法。
test()
```
更多使用参数请参照shell文件或者运行如下命令查看:
```bash
python train --h
python evaluate --h
```
## 实验结果 (刚开始在动态图代码验证,以下为静态图代码上的结果)
| 模型 | 数据集 | 压缩方法 | 压缩率| Top-1/Top-5 Acc | lr | threshold | epoch |
|:--:|:---:|:--:|:--:|:--:|:--:|:--:|:--:|
| MobileNetV1 | ImageNet | Baseline | - | 70.99%/89.68% | - | - | - |
| MobileNetV1 | ImageNet | ratio | -55.19% | 70.87%/89.80% (-0.12%/+0.12%) | 0.005 | - | 68 |
| YOLO v3 | VOC | - | - |76.24% | - | - | - |
| YOLO v3 | VOC |threshold | -41.35% | 75.29%(-0.95%) | 0.005 | 0.05 | 10w |
| YOLO v3 | VOC |threshold | -53.00% | 75.00%(-1.24%) | 0.005 | 0.075 | 10w |
## TODO
- [ ] 完成实验,验证动态图下的效果,并得到压缩模型。
- [ ] 扩充衡量parameter重要性的方法(目前仅为绝对值)。
import paddle
import os
import sys
import argparse
import numpy as np
sys.path.append(
os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir))
from paddleslim.dygraph.prune.unstructured_pruner import UnstructuredPruner
from utility import add_arguments, print_arguments
import paddle.vision.transforms as T
import paddle.nn.functional as F
import functools
from paddle.vision.models import mobilenet_v1
import time
import logging
from paddleslim.common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model', str, "MobileNet", "The target model.")
add_arg('pruned_model', str, "dymodels/model-pruned.pdparams", "Whether to use pretrained model.")
add_arg('data', str, "cifar10", "Which data to use. 'cifar10' or 'imagenet'.")
add_arg('log_period', int, 100, "Log period in batches.")
# yapf: enable
def compress(args):
test_reader = None
if args.data == "imagenet":
import imagenet_reader as reader
val_dataset = reader.ImageNetDataset(data_dir='/data', mode='val')
class_dim = 1000
elif args.data == "cifar10":
normalize = T.Normalize(
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format='CHW')
transform = T.Compose([T.Transpose(), normalize])
val_dataset = paddle.vision.datasets.Cifar10(
mode='test', backend='cv2', transform=transform)
class_dim = 10
else:
raise ValueError("{} is not supported.".format(args.data))
places = paddle.static.cuda_places(
) if args.use_gpu else paddle.static.cpu_places()
batch_size_per_card = int(args.batch_size / len(places))
valid_loader = paddle.io.DataLoader(
val_dataset,
places=places,
drop_last=False,
return_list=True,
batch_size=batch_size_per_card,
shuffle=False,
use_shared_memory=True)
# model definition
model = mobilenet_v1(num_classes=class_dim, pretrained=True)
def test(epoch):
model.eval()
acc_top1_ns = []
acc_top5_ns = []
for batch_id, data in enumerate(valid_loader):
start_time = time.time()
x_data = data[0]
y_data = paddle.to_tensor(data[1])
if args.data == 'cifar10':
y_data = paddle.unsqueeze(y_data, 1)
end_time = time.time()
logits = model(x_data)
loss = F.cross_entropy(logits, y_data)
acc_top1 = paddle.metric.accuracy(logits, y_data, k=1)
acc_top5 = paddle.metric.accuracy(logits, y_data, k=5)
acc_top1_ns.append(acc_top1.numpy())
acc_top5_ns.append(acc_top5.numpy())
if batch_id % args.log_period == 0:
_logger.info(
"Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".
format(epoch, batch_id,
np.mean(acc_top1.numpy()),
np.mean(acc_top5.numpy()), end_time - start_time))
acc_top1_ns.append(np.mean(acc_top1.numpy()))
acc_top5_ns.append(np.mean(acc_top5.numpy()))
_logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format(
epoch,
np.mean(np.array(
acc_top1_ns, dtype="object")),
np.mean(np.array(
acc_top5_ns, dtype="object"))))
model.set_state_dict(paddle.load(args.pruned_model))
_logger.info("The current density of the pruned model is: {}%".format(
round(100 * UnstructuredPruner.total_sparse(model), 2)))
test(0)
def main():
args = parser.parse_args()
print_arguments(args)
compress(args)
if __name__ == '__main__':
main()
#!/bin/bash
export CUDA_VISIBLE_DEVICES=3
python3.7 evaluate.py \
--pruned_model="models/model-pruned.pdparams" \
--data="cifar10"
#!/bin/bash
export CUDA_VISIBLE_DEVICES=3
python3.7 evaluate.py \
--pruned_model="models/model-pruned.pdparams" \
--data="imagenet"
import paddle
import os
import sys
import argparse
import numpy as np
from paddleslim.dygraph.prune.unstructured_pruner import UnstructuredPruner
sys.path.append(
os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir))
from utility import add_arguments, print_arguments
import paddle.vision.transforms as T
import paddle.nn.functional as F
import functools
from paddle.vision.models import mobilenet_v1
import time
import logging
from paddleslim.common import get_logger
import paddle.distributed as dist
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('ratio', float, 0.3, "The ratio to set zeros, the smaller part bounded by the ratio will be zeros.")
add_arg('pruning_mode', str, 'ratio', "the pruning mode: whether by ratio or by threshold.")
add_arg('threshold', float, 0.001, "The threshold to set zeros.")
add_arg('num_epochs', int, 120, "The number of total epochs.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
add_arg('data', str, "cifar10", "Which data to use. 'cifar10' or 'imagenet'.")
add_arg('log_period', int, 100, "Log period in batches.")
add_arg('test_period', int, 1, "Test period in epoches.")
add_arg('model_path', str, "./models", "The path to save model.")
add_arg('model_period', int, 10, "The period to save model in epochs.")
add_arg('resume_epoch', int, -1, "The epoch to resume training.")
add_arg('num_workers', int, 4, "number of workers when loading dataset.")
# yapf: enable
def piecewise_decay(args, step_per_epoch, model):
bd = [step_per_epoch * e for e in args.step_epochs]
lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
learning_rate = paddle.optimizer.lr.PiecewiseDecay(boundaries=bd, values=lr)
optimizer = paddle.optimizer.Momentum(
learning_rate=learning_rate,
momentum=args.momentum_rate,
weight_decay=paddle.regularizer.L2Decay(args.l2_decay),
parameters=model.parameters())
return optimizer, learning_rate
def cosine_decay(args, step_per_epoch, model):
learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(
learning_rate=args.lr, T_max=args.num_epochs * step_per_epoch)
optimizer = paddle.optimizer.Momentum(
learning_rate=learning_rate,
momentum=args.momentum_rate,
weight_decay=paddle.regularizer.L2Decay(args.l2_decay),
parameters=model.parameters())
return optimizer, learning_rate
def create_optimizer(args, step_per_epoch, model):
if args.lr_strategy == "piecewise_decay":
return piecewise_decay(args, step_per_epoch, model)
elif args.lr_strategy == "cosine_decay":
return cosine_decay(args, step_per_epoch, model)
def compress(args):
dist.init_parallel_env()
train_reader = None
test_reader = None
if args.data == "imagenet":
import imagenet_reader as reader
train_dataset = reader.ImageNetDataset(data_dir='/data', mode='train')
val_dataset = reader.ImageNetDataset(data_dir='/data', mode='val')
class_dim = 1000
elif args.data == "cifar10":
normalize = T.Normalize(
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format='CHW')
transform = T.Compose([T.Transpose(), normalize])
train_dataset = paddle.vision.datasets.Cifar10(
mode='train', backend='cv2', transform=transform)
val_dataset = paddle.vision.datasets.Cifar10(
mode='test', backend='cv2', transform=transform)
class_dim = 10
else:
raise ValueError("{} is not supported.".format(args.data))
places = paddle.static.cuda_places(
) if args.use_gpu else paddle.static.cpu_places()
batch_size_per_card = int(args.batch_size / len(places))
train_loader = paddle.io.DataLoader(
train_dataset,
places=places,
drop_last=True,
batch_size=args.batch_size,
shuffle=True,
return_list=True,
num_workers=args.num_workers,
use_shared_memory=True)
valid_loader = paddle.io.DataLoader(
val_dataset,
places=places,
drop_last=False,
return_list=True,
batch_size=args.batch_size,
shuffle=False,
use_shared_memory=True)
step_per_epoch = int(np.ceil(len(train_dataset) * 1. / args.batch_size))
# model definition
model = mobilenet_v1(num_classes=class_dim, pretrained=True)
dp_model = paddle.DataParallel(model)
opt, learning_rate = create_optimizer(args, step_per_epoch, dp_model)
def test(epoch):
dp_model.eval()
acc_top1_ns = []
acc_top5_ns = []
for batch_id, data in enumerate(valid_loader):
start_time = time.time()
x_data = data[0]
y_data = paddle.to_tensor(data[1])
if args.data == 'cifar10':
y_data = paddle.unsqueeze(y_data, 1)
end_time = time.time()
logits = dp_model(x_data)
loss = F.cross_entropy(logits, y_data)
acc_top1 = paddle.metric.accuracy(logits, y_data, k=1)
acc_top5 = paddle.metric.accuracy(logits, y_data, k=5)
acc_top1_ns.append(acc_top1.numpy())
acc_top5_ns.append(acc_top5.numpy())
if batch_id % args.log_period == 0:
_logger.info(
"Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".
format(epoch, batch_id,
np.mean(acc_top1.numpy()),
np.mean(acc_top5.numpy()), end_time - start_time))
acc_top1_ns.append(np.mean(acc_top1.numpy()))
acc_top5_ns.append(np.mean(acc_top5.numpy()))
_logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format(
epoch,
np.mean(np.array(
acc_top1_ns, dtype="object")),
np.mean(np.array(
acc_top5_ns, dtype="object"))))
def train(epoch):
dp_model.train()
for batch_id, data in enumerate(train_loader):
start_time = time.time()
x_data = data[0]
y_data = paddle.to_tensor(data[1])
if args.data == 'cifar10':
y_data = paddle.unsqueeze(y_data, 1)
logits = dp_model(x_data)
loss = F.cross_entropy(logits, y_data)
acc_top1 = paddle.metric.accuracy(logits, y_data, k=1)
acc_top5 = paddle.metric.accuracy(logits, y_data, k=5)
end_time = time.time()
if batch_id % args.log_period == 0:
_logger.info(
"epoch[{}]-batch[{}] lr: {:.6f} - loss: {}; acc_top1: {}; acc_top5: {}; time: {}".
format(epoch, batch_id, args.lr,
np.mean(loss.numpy()),
np.mean(acc_top1.numpy()),
np.mean(acc_top5.numpy()), end_time - start_time))
loss.backward()
opt.step()
opt.clear_grad()
pruner.step()
pruner = UnstructuredPruner(
dp_model,
mode=args.pruning_mode,
ratio=args.ratio,
threshold=args.threshold)
for i in range(args.resume_epoch + 1, args.num_epochs):
train(i)
if i % args.test_period == 0:
pruner.update_params()
_logger.info(
"The current density of the pruned model is: {}%".format(
round(100 * UnstructuredPruner.total_sparse(dp_model), 2)))
test(i)
if i > args.resume_epoch and i % args.model_period == 0:
pruner.update_params()
paddle.save(dp_model.state_dict(),
os.path.join(args.model_path, "model-pruned.pdparams"))
paddle.save(opt.state_dict(),
os.path.join(args.model_path, "opt-pruned.pdopt"))
def main():
args = parser.parse_args()
print_arguments(args)
compress(args)
if __name__ == '__main__':
main()
#!/bin/bash
export CUDA_VISIBLE_DEVICES=3
python3.7 train.py \
--batch_size=128 \
--lr=0.05 \
--ratio=0.45 \
--threshold=1e-5 \
--pruning_mode="threshold" \
--data="cifar10" \
#!/bin/bash
export CUDA_VISIBLE_DEVICES=3
python3.7 train.py \
--batch_size=64 \
--lr=0.05 \
--ratio=0.45 \
--threshold=1e-5 \
--pruning_mode="threshold" \
--data="imagenet" \
# 非结构化稀疏 -- 静态图剪裁(包括按照阈值和比例剪裁两种模式)
## 简介
在模型压缩中,常见的稀疏方式为结构化和非结构化稀疏,前者在某个特定维度(特征通道、卷积核等等)上进行稀疏化操作;后者以每一个参数为单元进行稀疏化,所以更加依赖于硬件对稀疏后矩阵运算的加速能力。本目录即在PaddlePaddle和PaddleSlim框架下开发的非结构化稀疏算法,MobileNetV1在ImageNet上的稀疏化实验中,剪裁率55.19%,达到无损的表现。
## 版本要求
```bash
python3.5+
paddlepaddle>=2.0.0
paddleslim>=2.1.0
```
请参照github安装[paddlepaddle](https://github.com/PaddlePaddle/Paddle)[paddleslim](https://github.com/PaddlePaddle/PaddleSlim)
## 使用
训练前:
- 预训练模型下载,并放到某目录下,通过train.py中的--pretrained_model设置。
- 训练数据下载后,可以通过重写../imagenet_reader.py文件,并在train.py文件中调用实现。
- 开发者可以通过重写paddleslim.prune.unstructured_pruner.py中的UnstructuredPruner.update_threshold()来定义自己的非结构化稀疏策略(目前为剪裁掉绝对值小的parameters)。
- 开发可以在初始化UnstructuredPruner时,传入自定义的skip_params_func,来定义哪些参数不参与剪裁。skip_params_func示例代码如下(路径:paddleslim.prune.unstructured_pruner._get_skip_params())。默认为所有的归一化层的参数不参与剪裁。
```python
def _get_skip_params(program):
"""
The function is used to get a set of all the skipped parameters when performing pruning.
By default, the normalization-related ones will not be pruned.
Developers could replace it by passing their own function when initializing the UnstructuredPruner instance.
Args:
- program(paddle.static.Program): the current model.
Returns:
- skip_params(Set<String>): a set of parameters' names.
"""
skip_params = set()
graph = paddleslim.core.GraphWrapper(program)
for op in graph.ops():
if 'norm' in op.type() and 'grad' not in op.type():
for input in op.all_inputs():
skip_params.add(input.name())
return skip_params
```
训练:
```bash
CUDA_VISIBLE_DEVICES=2,3 python3.7 train.py --data mnist --lr 0.1 --pruning_mode ratio --ratio=0.5
```
推理:
```bash
CUDA_VISIBLE_DEVICES=0 python3.7 evaluate.py --pruned_model models/ --data imagenet
```
剪裁训练代码示例:
```python
# model definition
places = paddle.static.cuda_places()
place = places[0]
exe = paddle.static.Executor(place)
model = models.__dict__[args.model]()
out = model.net(input=image, class_dim=class_dim)
cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
avg_cost = paddle.mean(x=cost)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
val_program = paddle.static.default_main_program().clone(for_test=True)
opt, learning_rate = create_optimizer(args, step_per_epoch)
opt.minimize(avg_cost)
#STEP1: initialize the pruner
pruner = UnstructuredPruner(paddle.static.default_main_program(), mode='ratio', ratio=0.5, place=place)
exe.run(paddle.static.default_startup_program())
paddle.fluid.io.load_vars(exe, args.pretrained_model)
for epoch in range(epochs):
for batch_id, data in enumerate(train_loader):
loss_n, acc_top1_n, acc_top5_n = exe.run(
train_program,
feed={
"image": data[0].get('image'),
"label": data[0].get('label')
},
fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
learning_rate.step()
#STEP2: update the pruner's threshold given the updated parameters
pruner.step()
if epoch % args.test_period == 0:
#STEP3: before evaluation during training, eliminate the non-zeros generated by opt.step(), which, however, the cached masks setting to be zeros.
pruner.update_params()
eval(epoch)
if epoch % args.model_period == 0:
# STEP4: same purpose as STEP3
pruner.update_params()
save(epoch)
```
剪裁后测试代码示例:
```python
# intialize the model instance in static mode
# load weights
print(UnstructuredPruner.total_sparse(paddle.static.default_main_program())) #注意,total_sparse为静态方法(static method),可以不创建实例(instance)直接调用,方便只做测试的写法。
test()
```
更多使用参数请参照shell文件,或者通过运行以下命令查看:
```bash
python3.7 train.py --h
python3.7 evaluate.py --h
```
## 实验结果
| 模型 | 数据集 | 压缩方法 | 压缩率| Top-1/Top-5 Acc | lr | threshold | epoch |
|:--:|:---:|:--:|:--:|:--:|:--:|:--:|:--:|
| MobileNetV1 | ImageNet | Baseline | - | 70.99%/89.68% | - | - | - |
| MobileNetV1 | ImageNet | ratio | -55.19% | 70.87%/89.80% (-0.12%/+0.12%) | 0.005 | - | 68 |
| YOLO v3 | VOC | - | - |76.24% | - | - | - |
| YOLO v3 | VOC |threshold | -55.15% | 75.45%(-0.79%) | 0.005 | 0.05 |12.8w|
## TODO
- [ ] 完成实验,验证动态图下的效果,并得到压缩模型。
- [ ] 扩充衡量parameter重要性的方法(目前仅为绝对值)。
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np
import paddle.fluid as fluid
sys.path.append(os.path.join(os.path.dirname("__file__"), os.path.pardir))
from paddleslim.prune.unstructured_pruner import UnstructuredPruner
from paddleslim.common import get_logger
import models
from utility import add_arguments, print_arguments
import paddle.vision.transforms as T
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64*12, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model', str, "MobileNet", "The target model.")
add_arg('pruned_model', str, "models", "Whether to use pretrained model.")
add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'.")
add_arg('log_period', int, 100, "Log period in batches.")
# yapf: enable
model_list = models.__all__
def compress(args):
train_reader = None
test_reader = None
if args.data == "mnist":
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.MNIST(
mode='train', backend="cv2", transform=transform)
val_dataset = paddle.vision.datasets.MNIST(
mode='test', backend="cv2", transform=transform)
class_dim = 10
image_shape = "1,28,28"
elif args.data == "imagenet":
import imagenet_reader as reader
train_dataset = reader.ImageNetDataset(data_dir='/data', mode='train')
val_dataset = reader.ImageNetDataset(data_dir='/data', mode='val')
class_dim = 1000
image_shape = "3,224,224"
else:
raise ValueError("{} is not supported.".format(args.data))
image_shape = [int(m) for m in image_shape.split(",")]
assert args.model in model_list, "{} is not in lists: {}".format(args.model,
model_list)
places = paddle.static.cuda_places(
) if args.use_gpu else paddle.static.cpu_places()
place = places[0]
exe = paddle.static.Executor(place)
image = paddle.static.data(
name='image', shape=[None] + image_shape, dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
batch_size_per_card = int(args.batch_size / len(places))
valid_loader = paddle.io.DataLoader(
val_dataset,
places=place,
feed_list=[image, label],
drop_last=False,
return_list=False,
use_shared_memory=True,
batch_size=batch_size_per_card,
shuffle=False)
step_per_epoch = int(np.ceil(len(train_dataset) * 1. / args.batch_size))
# model definition
model = models.__dict__[args.model]()
out = model.net(input=image, class_dim=class_dim)
cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
avg_cost = paddle.mean(x=cost)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
val_program = paddle.static.default_main_program().clone(for_test=True)
exe.run(paddle.static.default_startup_program())
if args.pruned_model:
def if_exist(var):
return os.path.exists(os.path.join(args.pruned_model, var.name))
_logger.info("Load pruned model from {}".format(args.pruned_model))
paddle.fluid.io.load_vars(exe, args.pruned_model, predicate=if_exist)
def test(epoch, program):
acc_top1_ns = []
acc_top5_ns = []
_logger.info("The current density of the inference model is {}%".format(
round(100 * UnstructuredPruner.total_sparse(
paddle.static.default_main_program()), 2)))
for batch_id, data in enumerate(valid_loader):
start_time = time.time()
acc_top1_n, acc_top5_n = exe.run(
program,
feed={
"image": data[0].get('image'),
"label": data[0].get('label'),
},
fetch_list=[acc_top1.name, acc_top5.name])
end_time = time.time()
if batch_id % args.log_period == 0:
_logger.info(
"Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".
format(epoch, batch_id,
np.mean(acc_top1_n),
np.mean(acc_top5_n), end_time - start_time))
acc_top1_ns.append(np.mean(acc_top1_n))
acc_top5_ns.append(np.mean(acc_top5_n))
_logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format(
epoch,
np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))
test(0, val_program)
def main():
paddle.enable_static()
args = parser.parse_args()
print_arguments(args)
compress(args)
if __name__ == '__main__':
main()
#!/bin/bash
export CUDA_VISIBLE_DEVICES=3
export FLAGS_fraction_of_gpu_memory_to_use=0.98
python3.7 evaluate.py \
--pruned_model="models" \
--data="imagenet"
#!/bin/bash
export CUDA_VISIBLE_DEVICES=3
python3.7 evaluate.py \
--pruned_model="models" \
--data="mnist"
import os
import sys
import logging
import paddle
import argparse
import functools
import time
import numpy as np
import paddle.fluid as fluid
from paddleslim.prune.unstructured_pruner import UnstructuredPruner
from paddleslim.common import get_logger
sys.path.append(os.path.join(os.path.dirname("__file__"), os.path.pardir))
import models
from utility import add_arguments, print_arguments
import paddle.vision.transforms as T
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model', str, "MobileNet", "The target model.")
add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretrained", "Whether to use pretrained model.")
add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy', str, "cosine_decay", "The learning rate decay strategy.")
add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('threshold', float, 1e-5, "The threshold to set zeros, the abs(weights) lower than which will be zeros.")
add_arg('pruning_mode', str, 'ratio', "the pruning mode: whether by ratio or by threshold.")
add_arg('ratio', float, 0.5, "The ratio to set zeros, the smaller portion will be zeros.")
add_arg('num_epochs', int, 120, "The number of total epochs.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'.")
add_arg('log_period', int, 100, "Log period in batches.")
add_arg('test_period', int, 10, "Test period in epoches.")
add_arg('model_path', str, "./models", "The path to save model.")
add_arg('model_period', int, 10, "The period to save model in epochs.")
add_arg('resume_epoch', int, -1, "The epoch to resume training.")
# yapf: enable
model_list = models.__all__
def piecewise_decay(args, step_per_epoch):
bd = [step_per_epoch * e for e in args.step_epochs]
lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
learning_rate = paddle.optimizer.lr.PiecewiseDecay(boundaries=bd, values=lr)
optimizer = paddle.optimizer.Momentum(
learning_rate=learning_rate,
momentum=args.momentum_rate,
weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
return optimizer, learning_rate
def cosine_decay(args, step_per_epoch):
learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(
learning_rate=args.lr, T_max=args.num_epochs * step_per_epoch)
optimizer = paddle.optimizer.Momentum(
learning_rate=learning_rate,
momentum=args.momentum_rate,
weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
return optimizer, learning_rate
def create_optimizer(args, step_per_epoch):
if args.lr_strategy == "piecewise_decay":
return piecewise_decay(args, step_per_epoch)
elif args.lr_strategy == "cosine_decay":
return cosine_decay(args, step_per_epoch)
def compress(args):
train_reader = None
test_reader = None
if args.data == "mnist":
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.MNIST(
mode='train', backend="cv2", transform=transform)
val_dataset = paddle.vision.datasets.MNIST(
mode='test', backend="cv2", transform=transform)
class_dim = 10
image_shape = "1,28,28"
args.pretrained_model = False
elif args.data == "imagenet":
import imagenet_reader as reader
train_dataset = reader.ImageNetDataset(data_dir='/data', mode='train')
val_dataset = reader.ImageNetDataset(data_dir='/data', mode='val')
class_dim = 1000
image_shape = "3,224,224"
else:
raise ValueError("{} is not supported.".format(args.data))
image_shape = [int(m) for m in image_shape.split(",")]
assert args.model in model_list, "{} is not in lists: {}".format(args.model,
model_list)
places = paddle.static.cuda_places(
) if args.use_gpu else paddle.static.cpu_places()
place = places[0]
exe = paddle.static.Executor(place)
image = paddle.static.data(
name='image', shape=[None] + image_shape, dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
batch_size_per_card = int(args.batch_size / len(places))
train_loader = paddle.io.DataLoader(
train_dataset,
places=places,
feed_list=[image, label],
drop_last=True,
batch_size=batch_size_per_card,
shuffle=True,
return_list=False,
use_shared_memory=True,
num_workers=32)
valid_loader = paddle.io.DataLoader(
val_dataset,
places=place,
feed_list=[image, label],
drop_last=False,
return_list=False,
use_shared_memory=True,
batch_size=batch_size_per_card,
shuffle=False)
step_per_epoch = int(np.ceil(len(train_dataset) * 1. / args.batch_size))
# model definition
model = models.__dict__[args.model]()
out = model.net(input=image, class_dim=class_dim)
cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
avg_cost = paddle.mean(x=cost)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
val_program = paddle.static.default_main_program().clone(for_test=True)
opt, learning_rate = create_optimizer(args, step_per_epoch)
opt.minimize(avg_cost)
pruner = UnstructuredPruner(
paddle.static.default_main_program(),
batch_size=args.batch_size,
mode=args.pruning_mode,
ratio=args.ratio,
threshold=args.threshold,
place=place)
exe.run(paddle.static.default_startup_program())
if args.pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(args.pretrained_model, var.name))
_logger.info("Load pretrained model from {}".format(
args.pretrained_model))
# NOTE: We are using fluid.io.load_vars() because the pretrained model is from an older version which requires this API.
# Please consider using paddle.static.load(program, model_path) when possible
paddle.fluid.io.load_vars(
exe, args.pretrained_model, predicate=if_exist)
def test(epoch, program):
acc_top1_ns = []
acc_top5_ns = []
_logger.info("The current density of the inference model is {}%".format(
round(100 * UnstructuredPruner.total_sparse(
paddle.static.default_main_program()), 2)))
for batch_id, data in enumerate(valid_loader):
start_time = time.time()
acc_top1_n, acc_top5_n = exe.run(
program,
feed={
"image": data[0].get('image'),
"label": data[0].get('label')
},
fetch_list=[acc_top1.name, acc_top5.name])
end_time = time.time()
if batch_id % args.log_period == 0:
_logger.info(
"Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".
format(epoch, batch_id,
np.mean(acc_top1_n),
np.mean(acc_top5_n), end_time - start_time))
acc_top1_ns.append(np.mean(acc_top1_n))
acc_top5_ns.append(np.mean(acc_top5_n))
_logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format(
epoch,
np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))
def train(epoch, program):
for batch_id, data in enumerate(train_loader):
start_time = time.time()
loss_n, acc_top1_n, acc_top5_n = exe.run(
train_program,
feed={
"image": data[0].get('image'),
"label": data[0].get('label')
},
fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
end_time = time.time()
loss_n = np.mean(loss_n)
acc_top1_n = np.mean(acc_top1_n)
acc_top5_n = np.mean(acc_top5_n)
if batch_id % args.log_period == 0:
_logger.info(
"epoch[{}]-batch[{}] lr: {:.6f} - loss: {}; acc_top1: {}; acc_top5: {}; time: {}".
format(epoch, batch_id,
learning_rate.get_lr(), loss_n, acc_top1_n,
acc_top5_n, end_time - start_time))
learning_rate.step()
pruner.step()
batch_id += 1
build_strategy = paddle.static.BuildStrategy()
exec_strategy = paddle.static.ExecutionStrategy()
train_program = paddle.static.CompiledProgram(
paddle.static.default_main_program()).with_data_parallel(
loss_name=avg_cost.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
for i in range(args.resume_epoch + 1, args.num_epochs):
train(i, train_program)
_logger.info("The current density of the pruned model is: {}%".format(
round(100 * UnstructuredPruner.total_sparse(
paddle.static.default_main_program()), 2)))
if i % args.test_period == 0:
pruner.update_params()
test(i, val_program)
if i > args.resume_epoch and i % args.model_period == 0:
pruner.update_params()
# NOTE: We are using fluid.io.save_params() because the pretrained model is from an older version which requires this API.
# Please consider using paddle.static.save(program, model_path) as long as it becomes possible.
fluid.io.save_params(executor=exe, dirname=args.model_path)
def main():
paddle.enable_static()
args = parser.parse_args()
print_arguments(args)
compress(args)
if __name__ == '__main__':
main()
#!/bin/bash
export CUDA_VISIBLE_DEVICES=2,3
export FLAGS_fraction_of_gpu_memory_to_use=0.98
python3.7 train.py \
--batch_size 256 \
--data imagenet \
--pruning_mode ratio \
--ratio 0.45 \
--lr 0.075 \
--pretrained_model /PaddleSlim/demo/pretrained_model/MobileNetV1_pretrained
#!/bin/bash
export CUDA_VISIBLE_DEVICES=2,3
export FLAGS_fraction_of_gpu_memory_to_use=0.98
python3.7 train.py \
--batch_size=256 \
--data="mnist" \
--pruning_mode="threshold" \
--ratio=0.45 \
--threshold=1e-5 \
--lr=0.075 \
......@@ -10,6 +10,8 @@ from . import l2norm_pruner
from .l2norm_pruner import *
from . import fpgm_pruner
from .fpgm_pruner import *
from . import unstructured_pruner
from .unstructured_pruner import *
__all__ = []
......@@ -19,3 +21,4 @@ __all__ += l2norm_pruner.__all__
__all__ += fpgm_pruner.__all__
__all__ += pruner.__all__
__all__ += filter_pruner.__all__
__all__ += unstructured_pruner.__all__
import numpy as np
import paddle
import logging
from paddleslim.common import get_logger
__all__ = ["UnstructuredPruner"]
_logger = get_logger(__name__, level=logging.INFO)
class UnstructuredPruner():
"""
The unstructure pruner.
Args:
- model(Paddle.nn.Layer): The model to be pruned.
- mode(str): Pruning mode, must be selected from 'ratio' and 'threshold'.
- threshold(float): The parameters whose absolute values are smaller than the THRESHOLD will be zeros. Default: 0.01
- ratio(float): The parameters whose absolute values are in the smaller part decided by the ratio will be zeros. Default: 0.3
- skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params.
"""
def __init__(self,
model,
mode,
threshold=0.01,
ratio=0.3,
skip_params_func=None):
assert mode in ('ratio', 'threshold'
), "mode must be selected from 'ratio' and 'threshold'"
self.model = model
self.mode = mode
self.threshold = threshold
self.ratio = ratio
if skip_params_func is None: skip_params_func = self._get_skip_params
self.skip_params = skip_params_func(model)
self._apply_masks()
def mask_parameters(self, param, mask):
"""
Update masks and parameters. It is executed to each layer before each iteration.
User can overwrite this function in subclass to implememt different pruning stragies.
Args:
- parameters(list<Tensor>): The parameters to be pruned.
- masks(list<Tensor>): The masks used to keep zero values in parameters.
"""
bool_tmp = (paddle.abs(param) >= self.threshold)
paddle.assign(mask * bool_tmp, output=mask)
param_tmp = param * mask
param_tmp.stop_gradient = True
paddle.assign(param_tmp, output=param)
def _apply_masks(self):
self.masks = {}
for name, sub_layer in self.model.named_sublayers():
for param in sub_layer.parameters(include_sublayers=False):
tmp_array = np.ones(param.shape, dtype=np.float32)
mask_name = "_".join([param.name.replace(".", "_"), "mask"])
if mask_name not in sub_layer._buffers:
sub_layer.register_buffer(mask_name,
paddle.to_tensor(tmp_array))
self.masks[param.name] = sub_layer._buffers[mask_name]
for name, sub_layer in self.model.named_sublayers():
sub_layer.register_forward_pre_hook(self._forward_pre_hook)
def update_threshold(self):
'''
Update the threshold after each optimization step.
User should overwrite this method togther with self.mask_parameters()
'''
params_flatten = []
for name, sub_layer in self.model.named_sublayers():
if not self._should_prune_layer(sub_layer):
continue
for param in sub_layer.parameters(include_sublayers=False):
t_param = param.value().get_tensor()
v_param = np.array(t_param)
params_flatten.append(v_param.flatten())
params_flatten = np.concatenate(params_flatten, axis=0)
total_length = params_flatten.size
self.threshold = np.sort(np.abs(params_flatten))[max(
0, round(self.ratio * total_length) - 1)].item()
def step(self):
"""
Update the threshold after each optimization step.
"""
if self.mode == 'ratio':
self.update_threshold()
elif self.mode == 'threshold':
return
def _forward_pre_hook(self, layer, input):
if not self._should_prune_layer(layer):
return input
for param in layer.parameters(include_sublayers=False):
mask = self.masks.get(param.name)
self.mask_parameters(param, mask)
return input
def update_params(self):
"""
Update the parameters given self.masks, usually called before saving models and evaluation step during training.
If you load a sparse model and only want to inference, no need to call the method.
"""
for name, sub_layer in self.model.named_sublayers():
for param in sub_layer.parameters(include_sublayers=False):
mask = self.masks.get(param.name)
param_tmp = param * mask
param_tmp.stop_gradient = True
paddle.assign(param_tmp, output=param)
@staticmethod
def total_sparse(model):
"""
This static function is used to get the whole model's density (1-sparsity).
It is static because during testing, we can calculate sparsity without initializing a pruner instance.
Args:
- model(Paddle.Model): The sparse model.
Returns:
- ratio(float): The model's density.
"""
total = 0
values = 0
for name, sub_layer in model.named_sublayers():
for param in sub_layer.parameters(include_sublayers=False):
total += np.product(param.shape)
values += len(paddle.nonzero(param))
ratio = float(values) / total
return ratio
def _get_skip_params(self, model):
"""
This function is used to check whether the given model's layers are valid to be pruned.
Usually, the convolutions are to be pruned while we skip the normalization-related parameters.
Deverlopers could replace this function by passing their own when initializing the UnstructuredPuner instance.
Args:
- model(Paddle.nn.Layer): the current model waiting to be checked.
Return:
- skip_params(set<String>): a set of parameters' names
"""
skip_params = set()
for _, sub_layer in model.named_sublayers():
if type(sub_layer).__name__.split('.')[
-1] in paddle.nn.norm.__all__:
skip_params.add(sub_layer.full_name())
return skip_params
def _should_prune_layer(self, layer):
should_prune = layer.full_name() not in self.skip_params
return should_prune
......@@ -27,6 +27,8 @@ from .group_param import *
from ..prune import group_param
from .criterion import *
from ..prune import criterion
from .unstructured_pruner import *
from ..prune import unstructured_pruner
from .idx_selector import *
from ..prune import idx_selector
......@@ -39,4 +41,5 @@ __all__ += prune_walker.__all__
__all__ += prune_io.__all__
__all__ += group_param.__all__
__all__ += criterion.__all__
__all__ += unstructured_pruner.__all__
__all__ += idx_selector.__all__
import numpy as np
from ..common import get_logger
from ..core import GraphWrapper
import paddle
__all__ = ["UnstructuredPruner"]
class UnstructuredPruner():
"""
The unstructure pruner.
Args:
- program(paddle.static.Program): The model to be pruned.
- batch_size(int): batch size.
- mode(str): the mode to prune the model, must be selected from 'ratio' and 'threshold'.
- ratio(float): the ratio to prune the model. Only set it when mode=='ratio'. Default: 0.5.
- threshold(float): the threshold to prune the model. Only set it when mode=='threshold'. Default: 1e-5.
- scope(paddle.static.Scope): The scope storing values of all variables. None means paddle.static.global_scope. Default: None.
- place(CPUPlace | CUDAPlace): The device place used to execute model. None means CPUPlace. Default: None.
- skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params.
"""
def __init__(self,
program,
batch_size,
mode,
ratio=0.5,
threshold=1e-5,
scope=None,
place=None,
skip_params_func=None):
self.mode = mode
self.ratio = ratio
self.threshold = threshold
assert self.mode in [
'ratio', 'threshold'
], "mode must be selected from 'ratio' and 'threshold'"
self.scope = paddle.static.global_scope() if scope == None else scope
self.place = paddle.static.CPUPlace() if place is None else place
if skip_params_func is None: skip_params_func = self._get_skip_params
self.skip_params = skip_params_func(program)
self.masks = self._apply_masks(program)
def _apply_masks(self, program):
params = []
masks = []
for param in program.all_parameters():
mask = program.global_block().create_var(
name=param.name + "_mask",
shape=param.shape,
dtype=param.dtype,
type=param.type,
persistable=param.persistable,
stop_gradient=True)
self.scope.var(param.name + "_mask").get_tensor().set(
np.ones(mask.shape).astype("float32"), self.place)
params.append(param)
masks.append(mask)
d_masks = {}
for _param, _mask in zip(params, masks):
d_masks[_param.name] = _mask.name
return d_masks
def summarize_weights(self, program, ratio=0.1):
"""
The function is used to get the weights corresponding to a given ratio
when you are uncertain about the threshold in __init__() function above.
For example, when given 0.1 as ratio, the function will print the weight value,
the abs(weights) lower than which count for 10% of the total numbers.
Args:
- program(paddle.static.Program): The model which have all the parameters.
- ratio(float): The ratio illustrated above.
Return:
- threshold(float): a threshold corresponding to the input ratio.
"""
data = []
for param in program.all_parameters():
data.append(
np.array(paddle.static.global_scope().find_var(param.name)
.get_tensor()).flatten())
data = np.concatenate(data, axis=0)
threshold = np.sort(np.abs(data))[max(0, int(ratio * len(data) - 1))]
return threshold
def sparse_by_layer(self, program):
"""
The function is used to get the density at each layer, usually called for debuggings.
Args:
- program(paddle.static.Program): The current model.
Returns:
- layer_sparse(Dict<string, float>): sparsity for each parameter.
"""
layer_sparse = {}
total = 0
values = 0
for param in program.all_parameters():
value = np.count_nonzero(
np.array(paddle.static.global_scope().find_var(param.name)
.get_tensor()))
layer_sparse[param.name] = value / np.product(param.shape)
return layer_sparse
def update_threshold(self):
'''
Update the threshold after each optimization step in RATIO mode.
User should overwrite this method to define their own weight importance (Default is based on their absolute values).
'''
params_flatten = []
for param in self.masks:
if not self._should_prune_param(param):
continue
t_param = self.scope.find_var(param).get_tensor()
v_param = np.array(t_param)
params_flatten.append(v_param.flatten())
params_flatten = np.concatenate(params_flatten, axis=0)
total_len = len(params_flatten)
self.threshold = np.sort(np.abs(params_flatten))[max(
0, int(self.ratio * total_len) - 1)]
def _update_params_masks(self):
for param in self.masks:
if not self._should_prune_param(param):
continue
mask_name = self.masks[param]
t_param = self.scope.find_var(param).get_tensor()
t_mask = self.scope.find_var(mask_name).get_tensor()
v_param = np.array(t_param)
v_param[np.abs(v_param) < self.threshold] = 0
v_mask = (v_param != 0).astype(v_param.dtype)
t_mask.set(v_mask, self.place)
v_param = np.array(t_param) * np.array(t_mask)
t_param.set(v_param, self.place)
def step(self):
"""
Update the threshold after each optimization step.
"""
if self.mode == 'threshold':
pass
elif self.mode == 'ratio':
self.update_threshold()
self._update_params_masks()
def update_params(self):
"""
Update the parameters given self.masks, usually called before saving models.
"""
for param in self.masks:
mask = self.masks[param]
t_param = self.scope.find_var(param).get_tensor()
t_mask = self.scope.find_var(mask).get_tensor()
v_param = np.array(t_param) * np.array(t_mask)
t_param.set(v_param, self.place)
@staticmethod
def total_sparse(program):
"""
The function is used to get the whole model's density (1-sparsity).
It is static because during testing, we can calculate sparsity without initializing a pruner instance.
Args:
- program(paddle.static.Program): The current model.
Returns:
- density(float): the model's density.
"""
total = 0
values = 0
for param in program.all_parameters():
total += np.product(param.shape)
values += np.count_nonzero(
np.array(paddle.static.global_scope().find_var(param.name)
.get_tensor()))
density = float(values) / total
return density
def _get_skip_params(self, program):
"""
The function is used to get a set of all the skipped parameters when performing pruning.
By default, the normalization-related ones will not be pruned.
Developers could replace it by passing their own function when initializing the UnstructuredPruner instance.
Args:
- program(paddle.static.Program): the current model.
Returns:
- skip_params(Set<String>): a set of parameters' names.
"""
skip_params = set()
graph = GraphWrapper(program)
for op in graph.ops():
if 'norm' in op.type() and 'grad' not in op.type():
for input in op.all_inputs():
skip_params.add(input.name())
return skip_params
def _should_prune_param(self, param):
should_prune = param not in self.skip_params
return should_prune
import sys
sys.path.append("../../")
import unittest
import paddle
import numpy as np
from paddleslim.dygraph.prune.unstructured_pruner import UnstructuredPruner
from paddle.vision.models import mobilenet_v1
class TestUnstructuredPruner(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestUnstructuredPruner, self).__init__(*args, **kwargs)
paddle.disable_static()
self._gen_model()
def _gen_model(self):
self.net = mobilenet_v1(num_classes=10, pretrained=False)
self.pruner = UnstructuredPruner(
self.net, mode='ratio', ratio=0.98, threshold=0.0)
def test_prune(self):
ori_density = UnstructuredPruner.total_sparse(self.net)
ori_threshold = self.pruner.threshold
self.pruner.step()
self.net(
paddle.to_tensor(
np.random.uniform(0, 1, [16, 3, 32, 32]), dtype='float32'))
cur_density = UnstructuredPruner.total_sparse(self.net)
cur_threshold = self.pruner.threshold
print("Original threshold: {}".format(ori_threshold))
print("Current threshold: {}".format(cur_threshold))
print("Original density: {}".format(ori_density))
print("Current density: {}".format(cur_density))
self.assertLessEqual(ori_threshold, cur_threshold)
self.assertLessEqual(cur_density, ori_density)
self.pruner.update_params()
self.assertEqual(cur_density, UnstructuredPruner.total_sparse(self.net))
if __name__ == "__main__":
unittest.main()
import sys
sys.path.append("../")
import unittest
from static_case import StaticCase
import paddle.fluid as fluid
import paddle
from paddleslim.prune import UnstructuredPruner
from layers import conv_bn_layer
import numpy as np
class TestUnstructuredPruner(StaticCase):
def __init__(self, *args, **kwargs):
super(TestUnstructuredPruner, self).__init__(*args, **kwargs)
paddle.enable_static()
self._gen_model()
def _gen_model(self):
self.main_program = paddle.static.default_main_program()
self.startup_program = paddle.static.default_startup_program()
# conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
# | ^ | ^
# |____________| |____________________|
with paddle.static.program_guard(self.main_program,
self.startup_program):
input = paddle.static.data(name='image', shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
conv7 = fluid.layers.conv2d_transpose(
input=conv6, num_filters=16, filter_size=2, stride=2)
place = paddle.static.cpu_places()[0]
exe = paddle.static.Executor(place)
self.scope = paddle.static.global_scope()
exe.run(self.startup_program, scope=self.scope)
self.pruner = UnstructuredPruner(
self.main_program, 16, 'ratio', scope=self.scope, place=place)
def test_unstructured_prune(self):
for param in self.main_program.global_block().all_parameters():
mask_name = param.name + "_mask"
mask_shape = self.scope.find_var(mask_name).get_tensor().shape()
self.assertTrue(tuple(mask_shape) == param.shape)
def test_sparsity(self):
ori_density = UnstructuredPruner.total_sparse(self.main_program)
self.pruner.step()
cur_density = UnstructuredPruner.total_sparse(self.main_program)
cur_layer_density = self.pruner.sparse_by_layer(self.main_program)
print('original density: {}.'.format(ori_density))
print('current density: {}.'.format(cur_density))
total = 0
non_zeros = 0
for param in self.main_program.all_parameters():
total += np.product(param.shape)
non_zeros += np.count_nonzero(
np.array(self.scope.find_var(param.name).get_tensor()))
self.assertEqual(cur_density, non_zeros / total)
self.assertLessEqual(cur_density, ori_density)
self.pruner.update_params()
self.assertEqual(cur_density,
UnstructuredPruner.total_sparse(self.main_program))
def test_summarize_weights(self):
max_value = -float("inf")
threshold = self.pruner.summarize_weights(self.main_program, 1.0)
for param in self.main_program.global_block().all_parameters():
max_value = max(
max_value,
np.max(np.array(self.scope.find_var(param.name).get_tensor())))
print("The returned threshold is {}.".format(threshold))
print("The max_value is {}.".format(max_value))
self.assertEqual(max_value, threshold)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册