Unverified commit 50d69ece, authored by whs, committed by GitHub

Add save and load API for pruned model (#38)

Parent 6cde8f01
# Convolution Channel Pruning Example
This example demonstrates how to prune the number of channels in each convolution layer by a specified pruning ratio. By default, the example automatically downloads and uses the MNIST dataset.
The following classification models are currently supported:
- MobileNetV1
- MobileNetV2
- ResNet50
- PVANet
## API Introduction
This example uses the `paddleslim.Pruner` class. For details on the user-facing API, see the [API documentation](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/).
## Identifying the Parameters to Prune
Parameter names differ from model to model, so before pruning you need to determine the names of the convolution-layer parameters to prune. All parameter names can be listed as follows:
```
for param in program.global_block().all_parameters():
    print("param name: {}; shape: {}".format(param.name, param.shape))
```
The `train.py` script provides a `get_pruned_params` method, which selects the parameters to prune according to the user-specified `--model` option; a rough sketch of the idea follows.
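The sketch below is only a hypothetical illustration of such a selection, filtering parameters by a naming pattern; the helper name and the `pattern` argument are assumptions, not the actual logic of `train.py`:
```
def get_pruned_params_sketch(program, pattern="weights"):
    # Hypothetical helper: collect the names of all parameters whose
    # name contains `pattern`. The real get_pruned_params in train.py
    # dispatches on the --model option instead.
    params = []
    for param in program.global_block().all_parameters():
        if pattern in param.name:
            params.append(param.name)
    return params
```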
## Launching the Pruning Task
Launch the pruning task with the following commands:
```
export CUDA_VISIBLE_DEVICES=0
python train.py
```
In this example, a model is saved to the filesystem after every training epoch.
Run `python train.py --help` to see more options.
## Notes
1. In the arguments of `paddleslim.Pruner.prune`, `params` and `ratios` must have the same length; see the sketch below.
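A minimal sketch of such a call, assuming the `Pruner.prune` signature described in the API documentation linked above; the parameter names here are placeholders and must be replaced with real convolution weights from your program:
```
import paddle.fluid as fluid
from paddleslim.prune import Pruner

# Placeholder names; list the real conv weights with the earlier snippet.
params = ["conv2_weights", "conv3_weights"]
ratios = [0.5, 0.5]  # one pruning ratio per parameter name
assert len(params) == len(ratios)

pruner = Pruner()
# Assumed to return the pruned program plus optional parameter backups.
pruned_program, _, _ = pruner.prune(
    fluid.default_main_program(),
    fluid.global_scope(),
    params=params,
    ratios=ratios,
    place=fluid.CPUPlace())
```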
## Loading and Evaluating the Model
This section describes how to load a model saved during training.
Run the following command to load the model and evaluate its metrics on the test set:
```
python eval.py \
    --model "mobilenet" \
    --data "mnist" \
    --model_path "./models/0"
```
The `eval.py` script uses the `paddleslim.prune.load_model` API to load the pruned model. The full script is shown below:
```
import os
import sys
import logging
import argparse
import functools
import math
import time
import numpy as np
import paddle
import paddle.fluid as fluid
from paddleslim.prune import load_model
from paddleslim.common import get_logger
from paddleslim.analysis import flops
sys.path.append(sys.path[0] + "/../")
import models
from utility import add_arguments, print_arguments

_logger = get_logger(__name__, level=logging.INFO)

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int,  64 * 4,       "Minibatch size.")
add_arg('use_gpu',    bool, True,         "Whether to use GPU or not.")
add_arg('model',      str,  "MobileNet",  "The target model.")
add_arg('model_path', str,  "./models/0", "The path of the model to evaluate.")
add_arg('data',       str,  "mnist",      "Which data to use. 'mnist' or 'imagenet'")
add_arg('log_period', int,  10,           "Log period in batches.")
# yapf: enable

model_list = models.__all__


def eval(args):
    train_reader = None
    test_reader = None
    if args.data == "mnist":
        import paddle.dataset.mnist as reader
        train_reader = reader.train()
        val_reader = reader.test()
        class_dim = 10
        image_shape = "1,28,28"
    elif args.data == "imagenet":
        import imagenet_reader as reader
        train_reader = reader.train()
        val_reader = reader.val()
        class_dim = 1000
        image_shape = "3,224,224"
    else:
        raise ValueError("{} is not supported.".format(args.data))
    image_shape = [int(m) for m in image_shape.split(",")]
    assert args.model in model_list, "{} is not in lists: {}".format(
        args.model, model_list)
    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    # model definition
    model = models.__dict__[args.model]()
    out = model.net(input=image, class_dim=class_dim)
    acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
    acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
    val_program = fluid.default_main_program().clone(for_test=True)
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    val_reader = paddle.batch(val_reader, batch_size=args.batch_size)
    val_feeder = fluid.DataFeeder([image, label], place, program=val_program)

    # Load the pruned model from the user-specified directory.
    load_model(val_program, args.model_path)

    batch_id = 0
    acc_top1_ns = []
    acc_top5_ns = []
    for data in val_reader():
        start_time = time.time()
        acc_top1_n, acc_top5_n = exe.run(
            val_program,
            feed=val_feeder.feed(data),
            fetch_list=[acc_top1.name, acc_top5.name])
        end_time = time.time()
        if batch_id % args.log_period == 0:
            _logger.info(
                "Eval batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".format(
                    batch_id,
                    np.mean(acc_top1_n),
                    np.mean(acc_top5_n), end_time - start_time))
        acc_top1_ns.append(np.mean(acc_top1_n))
        acc_top5_ns.append(np.mean(acc_top5_n))
        batch_id += 1
    _logger.info("Final eval - acc_top1: {}; acc_top5: {}".format(
        np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))


def main():
    args = parser.parse_args()
    print_arguments(args)
    eval(args)


if __name__ == '__main__':
    main()
```
Changes to `train.py`:
```
@@ -35,9 +35,10 @@ add_arg('config_file', str, None, "The config file for comp
 add_arg('data',        str, "mnist",    "Which data to use. 'mnist' or 'imagenet'")
 add_arg('log_period',  int, 10,         "Log period in batches.")
 add_arg('test_period', int, 10,         "Test period in epochs.")
+add_arg('model_path',  str, "./models", "The path to save model.")
 # yapf: enable
-model_list = [m for m in dir(models) if "__" not in m]
+model_list = models.__all__
 def get_pruned_params(args, program):
@@ -221,6 +222,8 @@ def compress(args):
         train(i, pruned_program)
         if i % args.test_period == 0:
             test(i, pruned_val_program)
+        save_model(pruned_val_program,
+                   os.path.join(args.model_path, str(i)))
 def main():
```
Changes to `paddleslim/prune/__init__.py`:
```
@@ -25,6 +25,8 @@ from .sensitive import *
 import sensitive
 from prune_walker import *
 import prune_walker
+from io import *
+import io
 __all__ = []
@@ -35,3 +37,4 @@ __all__ += controller_client.__all__
 __all__ += sensitive_pruner.__all__
 __all__ += sensitive.__all__
 __all__ += prune_walker.__all__
+__all__ += io.__all__
The new module implementing `save_model` and `load_model` (`paddleslim/prune/io.py`):
```
import os
import json
import logging
import paddle.fluid as fluid
from paddle.fluid import Program
from ..core import GraphWrapper
from ..common import get_logger

__all__ = ["save_model", "load_model"]

_logger = get_logger(__name__, level=logging.INFO)

PARAMS_FILE = "__params__"
SHAPES_FILE = "__shapes__"


def save_model(graph, dirname):
    """
    Save the weights of a model and the shapes of its parameters into the filesystem.

    Args:
        graph(Program|Graph): The graph to be saved.
        dirname(str): The directory that the model is saved into.
    """
    assert graph is not None and dirname is not None
    graph = GraphWrapper(graph) if isinstance(graph, Program) else graph
    exe = fluid.Executor(fluid.CPUPlace())
    fluid.io.save_params(
        executor=exe,
        dirname=dirname,
        main_program=graph.program,
        filename=PARAMS_FILE)
    weights_file = os.path.join(dirname, PARAMS_FILE)
    _logger.info("Save model weights into {}".format(weights_file))
    shapes = {}
    for var in graph.all_parameters():
        shapes[var.name()] = var.shape()
    # Use a local name so the module-level SHAPES_FILE constant is not shadowed.
    shapes_file = os.path.join(dirname, SHAPES_FILE)
    with open(shapes_file, "w") as f:
        json.dump(shapes, f)
    _logger.info("Save shapes of weights into {}".format(shapes_file))


def load_model(graph, dirname):
    """
    Load the weights of a model and the shapes of its parameters from the filesystem.

    Args:
        graph(Program|Graph): The graph to be loaded into.
        dirname(str): The directory that the model was saved into.
    """
    assert graph is not None and dirname is not None
    graph = GraphWrapper(graph) if isinstance(graph, Program) else graph

    # Restore the pruned shapes before loading the weights.
    shapes_file = os.path.join(dirname, SHAPES_FILE)
    _logger.info("Load shapes of weights from {}".format(shapes_file))
    with open(shapes_file, "r") as f:
        shapes = json.load(f)
        for param, shape in shapes.items():
            graph.var(param).set_shape(shape)

    exe = fluid.Executor(fluid.CPUPlace())
    fluid.io.load_params(
        executor=exe,
        dirname=dirname,
        main_program=graph.program,
        filename=PARAMS_FILE)
    graph.update_groups_of_conv()
    graph.infer_shape()
    _logger.info("Load weights from {}".format(
        os.path.join(dirname, PARAMS_FILE)))
```
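To tie the pieces together, here is a minimal round-trip sketch of the new API in a plain fluid program; the toy network and the directory name are arbitrary placeholders:
```
import paddle.fluid as fluid
from paddleslim.prune import save_model, load_model

# Toy network so the program has parameters to persist (placeholder only).
image = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
out = fluid.layers.fc(input=image, size=10)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# Persist weights plus parameter shapes, then restore both into the program.
save_model(fluid.default_main_program(), "./models/demo")
load_model(fluid.default_main_program(), "./models/demo")
```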