未验证 提交 d72acebb 编写于 作者: M minghaoBD 提交者: GitHub

Improve unstructured pruner easeofuse (#738)

上级 8eacc16b
......@@ -67,7 +67,7 @@ python3.7 train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshol
按照比例剪裁(训练速度较慢,推荐按照阈值剪裁):
```bash
python3.7 train.py --data imagenet --lr 0.05 --pruning_mode ratio --ratio 0.5
python3.7 train.py --data imagenet --lr 0.05 --pruning_mode ratio --ratio 0.55
```
GPU多卡训练:
......@@ -76,9 +76,11 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
python3.7 -m paddle.distributed.launch \
--gpus="0,1,2,3" \
--log_dir="train_mbv1_imagenet_threshold_001_log" \
train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01
train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01 --batch_size 256
```
**注意**:这里的`batch_size`为单张卡上的`batch_size`。
恢复训练(请替换命令中的`dir/to/the/saved/pruned/model`和`INTERRUPTED_EPOCH`):
```bash
python3.7 train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01 \
......@@ -87,7 +89,7 @@ python3.7 train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshol
## 推理:
```bash
python3.7 eval --pruned_model models/ --data imagenet
python3.7 evaluate.py --pruned_model models/ --data imagenet
```
剪裁训练代码示例:
......@@ -101,6 +103,7 @@ for epoch in range(epochs):
loss = calculate_loss()
loss.backward()
opt.step()
learning_rate.step()
opt.clear_grad()
#STEP2: update the pruner's threshold given the updated parameters
pruner.step()
......@@ -128,8 +131,8 @@ test()
更多使用参数请参照shell文件或者运行如下命令查看:
```bash
python3.7 train --h
python3.7 evaluate --h
python3.7 train.py --h
python3.7 evaluate.py --h
```
## 实验结果
......@@ -138,5 +141,6 @@ python3.7 evaluate --h
|:--:|:---:|:--:|:--:|:--:|:--:|:--:|:--:|
| MobileNetV1 | ImageNet | Baseline | - | 70.99%/89.68% | - | - | - |
| MobileNetV1 | ImageNet | ratio | -55.19% | 70.87%/89.80% (-0.12%/+0.12%) | 0.005 | - | 68 |
| MobileNetV1 | ImageNet | threshold | -49.49% | 71.22%/89.78% (+0.23%/+0.10%) | 0.05 | 0.01 | 93 |
| YOLO v3 | VOC | - | - |76.24% | - | - | - |
| YOLO v3 | VOC |threshold | -56.50% | 77.02%(+0.78%) | 0.001 | 0.01 | 102k iterations |
| YOLO v3 | VOC |threshold | -56.50% | 77.21% (+0.97%) | 0.001 | 0.01 | 150k iterations |
......@@ -15,13 +15,14 @@ import time
import logging
from paddleslim.common import get_logger
import paddle.distributed as dist
from paddle.distributed import ParallelEnv
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64, "Minibatch size.")
add_arg('batch_size', int, 64 * 4, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
......@@ -39,7 +40,7 @@ add_arg('pretrained_model', str, None, "The pretrained model the lo
add_arg('model_path', str, "./models", "The path to save model.")
add_arg('model_period', int, 10, "The period to save model in epochs.")
add_arg('resume_epoch', int, -1, "The epoch to resume training.")
add_arg('num_workers', int, 4, "number of workers when loading dataset.")
add_arg('num_workers', int, 16, "number of workers when loading dataset.")
# yapf: enable
......@@ -75,13 +76,22 @@ def create_optimizer(args, step_per_epoch, model):
def compress(args):
dist.init_parallel_env()
if args.use_gpu:
place = paddle.set_device('gpu')
else:
place = paddle.set_device('cpu')
trainer_num = paddle.distributed.get_world_size()
use_data_parallel = trainer_num != 1
if use_data_parallel:
dist.init_parallel_env()
train_reader = None
test_reader = None
if args.data == "imagenet":
import imagenet_reader as reader
train_dataset = reader.ImageNetDataset(data_dir='/data', mode='train')
val_dataset = reader.ImageNetDataset(data_dir='/data', mode='val')
train_dataset = reader.ImageNetDataset(mode='train')
val_dataset = reader.ImageNetDataset(mode='val')
class_dim = 1000
elif args.data == "cifar10":
normalize = T.Normalize(
......@@ -94,30 +104,33 @@ def compress(args):
class_dim = 10
else:
raise ValueError("{} is not supported.".format(args.data))
places = paddle.static.cuda_places(
) if args.use_gpu else paddle.static.cpu_places()
batch_size_per_card = int(args.batch_size / len(places))
batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
train_loader = paddle.io.DataLoader(
train_dataset,
places=places,
drop_last=True,
batch_size=args.batch_size,
shuffle=True,
places=place,
batch_sampler=batch_sampler,
return_list=True,
num_workers=args.num_workers,
use_shared_memory=True)
valid_loader = paddle.io.DataLoader(
val_dataset,
places=places,
places=place,
drop_last=False,
return_list=True,
batch_size=args.batch_size,
batch_size=64,
shuffle=False,
use_shared_memory=True)
step_per_epoch = int(np.ceil(len(train_dataset) * 1. / args.batch_size))
step_per_epoch = int(
np.ceil(len(train_dataset) / args.batch_size / ParallelEnv().nranks))
# model definition
model = mobilenet_v1(num_classes=class_dim, pretrained=True)
if ParallelEnv().nranks > 1:
model = paddle.DataParallel(model)
if args.pretrained_model is not None:
model.set_state_dict(paddle.load(args.pretrained_model))
......@@ -160,44 +173,65 @@ def compress(args):
def train(epoch):
model.train()
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
reader_start = time.time()
for batch_id, data in enumerate(train_loader):
start_time = time.time()
train_reader_cost += time.time() - reader_start
x_data = data[0]
y_data = paddle.to_tensor(data[1])
if args.data == 'cifar10':
y_data = paddle.unsqueeze(y_data, 1)
train_start = time.time()
logits = model(x_data)
loss = F.cross_entropy(logits, y_data)
acc_top1 = paddle.metric.accuracy(logits, y_data, k=1)
acc_top5 = paddle.metric.accuracy(logits, y_data, k=5)
end_time = time.time()
if batch_id % args.log_period == 0:
_logger.info(
"epoch[{}]-batch[{}] lr: {:.6f} - loss: {}; acc_top1: {}; acc_top5: {}; time: {}".
format(epoch, batch_id, args.lr,
np.mean(loss.numpy()),
np.mean(acc_top1.numpy()),
np.mean(acc_top5.numpy()), end_time - start_time))
loss.backward()
opt.step()
learning_rate.step()
opt.clear_grad()
pruner.step()
train_run_cost += time.time() - train_start
total_samples += args.batch_size * ParallelEnv().nranks
if batch_id % args.log_period == 0:
_logger.info(
"epoch[{}]-batch[{}] lr: {:.6f} - loss: {}; acc_top1: {}; acc_top5: {}; avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
format(epoch, batch_id,
opt.get_lr(),
np.mean(loss.numpy()),
np.mean(acc_top1.numpy()),
np.mean(acc_top5.numpy()), train_reader_cost /
args.log_period, (train_reader_cost + train_run_cost
) / args.log_period, total_samples
/ args.log_period, total_samples / (
train_reader_cost + train_run_cost)))
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
reader_start = time.time()
pruner = UnstructuredPruner(
model,
mode=args.pruning_mode,
ratio=args.ratio,
threshold=args.threshold)
for i in range(args.resume_epoch + 1, args.num_epochs):
train(i)
if i % args.test_period == 0:
if (i + 1) % args.test_period == 0:
pruner.update_params()
_logger.info(
"The current density of the pruned model is: {}%".format(
round(100 * UnstructuredPruner.total_sparse(model), 2)))
test(i)
if i > args.resume_epoch and i % args.model_period == 0:
if (i + 1) % args.model_period == 0:
pruner.update_params()
paddle.save(model.state_dict(),
os.path.join(args.model_path, "model-pruned.pdparams"))
......
#!/bin/bash
export CUDA_VISIBLE_DEVICES=3
python3.7 train.py \
--batch_size=128 \
--batch_size=256 \
--lr=0.05 \
--ratio=0.45 \
--threshold=1e-5 \
--threshold=0.01 \
--pruning_mode="threshold" \
--data="cifar10" \
#!/bin/bash
export CUDA_VISIBLE_DEVICES=3
python3.7 train.py \
--batch_size=64 \
--batch_size=256 \
--lr=0.05 \
--ratio=0.45 \
--threshold=1e-5 \
--threshold=0.01 \
--pruning_mode="threshold" \
--data="imagenet" \
......@@ -15,7 +15,7 @@ DATA_DIM = 224
THREAD = 16
BUF_SIZE = 10240
DATA_DIR = './data/ILSVRC2012/'
DATA_DIR = 'data/ILSVRC2012/'
DATA_DIR = os.path.join(os.path.split(os.path.realpath(__file__))[0], DATA_DIR)
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
......
......@@ -68,20 +68,22 @@ def _get_skip_params(program):
按照阈值剪裁:
```bash
CUDA_VISIBLE_DEVICES=2,3 python3.7 train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01
CUDA_VISIBLE_DEVICES=2,3 python3.7 train.py --batch_size 512 --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01
```
按照比例剪裁(训练速度较慢,推荐按照阈值剪裁):
```bash
CUDA_VISIBLE_DEVICES=2,3 python3.7 train.py --data imagenet --lr 0.05 --pruning_mode ratio --ratio 0.5
CUDA_VISIBLE_DEVICES=2,3 python3.7 train.py --batch_size 512 --data imagenet --lr 0.05 --pruning_mode ratio --ratio 0.55
```
恢复训练(请替换命令中的`dir/to/the/saved/pruned/model`和`INTERRUPTED_EPOCH`):
```
CUDA_VISIBLE_DEVICES=2,3 python3.7 train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01 \
CUDA_VISIBLE_DEVICES=2,3 python3.7 train.py --batch_size 512 --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01 \
--pretrained_model dir/to/the/saved/pruned/model --resume_epoch INTERRUPTED_EPOCH
```
**注意**,上述命令中的`batch_size`为多张卡上总的`batch_size`,即一张卡的`batch_size`为256。
## 推理
```bash
CUDA_VISIBLE_DEVICES=0 python3.7 evaluate.py --pruned_model models/ --data imagenet
......@@ -107,7 +109,7 @@ opt.minimize(avg_cost)
#STEP1: initialize the pruner
pruner = UnstructuredPruner(paddle.static.default_main_program(), mode='threshold', threshold=0.01, place=place) # 按照阈值剪裁
# pruner = UnstructuredPruner(paddle.static.default_main_program(), mode='ratio', ratio=0.5, place=place) # 按照比例剪裁
# pruner = UnstructuredPruner(paddle.static.default_main_program(), mode='ratio', ratio=0.55, place=place) # 按照比例剪裁
exe.run(paddle.static.default_startup_program())
paddle.fluid.io.load_vars(exe, args.pretrained_model)
......@@ -116,10 +118,7 @@ for epoch in range(epochs):
for batch_id, data in enumerate(train_loader):
loss_n, acc_top1_n, acc_top5_n = exe.run(
train_program,
feed={
"image": data[0].get('image'),
"label": data[0].get('label')
},
feed=data,
fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
learning_rate.step()
#STEP2: update the pruner's threshold given the updated parameters
......@@ -157,5 +156,6 @@ python3.7 evaluate.py --h
|:--:|:---:|:--:|:--:|:--:|:--:|:--:|:--:|
| MobileNetV1 | ImageNet | Baseline | - | 70.99%/89.68% | - | - | - |
| MobileNetV1 | ImageNet | ratio | -55.19% | 70.87%/89.80% (-0.12%/+0.12%) | 0.05 | - | 68 |
| MobileNetV1 | ImageNet | threshold | -49.49% | 71.22%/89.78% (+0.23%/+0.10%) | 0.05 | 0.01 | 93 |
| YOLO v3 | VOC | - | - |76.24% | - | - | - |
| YOLO v3 | VOC |threshold | -56.50% | 77.02%(+0.78%) | 0.001 | 0.01 |102k iterations|
| YOLO v3 | VOC |threshold | -56.50% | 77.21%(+0.97%) | 0.001 | 0.01 |150k iterations|
......@@ -19,12 +19,12 @@ _logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64, "Minibatch size.")
add_arg('batch_size', int, 64 * 4, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model', str, "MobileNet", "The target model.")
add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretrained", "Whether to use pretrained model.")
add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy', str, "cosine_decay", "The learning rate decay strategy.")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('threshold', float, 1e-5, "The threshold to set zeros, the abs(weights) lower than which will be zeros.")
......@@ -86,8 +86,8 @@ def compress(args):
args.pretrained_model = False
elif args.data == "imagenet":
import imagenet_reader as reader
train_dataset = reader.ImageNetDataset(data_dir='/data', mode='train')
val_dataset = reader.ImageNetDataset(data_dir='/data', mode='val')
train_dataset = reader.ImageNetDataset(mode='train')
val_dataset = reader.ImageNetDataset(mode='val')
class_dim = 1000
image_shape = "3,224,224"
else:
......@@ -95,14 +95,16 @@ def compress(args):
image_shape = [int(m) for m in image_shape.split(",")]
assert args.model in model_list, "{} is not in lists: {}".format(args.model,
model_list)
places = paddle.static.cuda_places(
) if args.use_gpu else paddle.static.cpu_places()
if args.use_gpu:
places = paddle.static.cuda_places()
else:
places = paddle.static.cpu_places()
place = places[0]
exe = paddle.static.Executor(place)
image = paddle.static.data(
name='image', shape=[None] + image_shape, dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
batch_size_per_card = int(args.batch_size / len(places))
train_loader = paddle.io.DataLoader(
train_dataset,
......@@ -148,6 +150,10 @@ def compress(args):
exe.run(paddle.static.default_startup_program())
if args.pretrained_model:
assert os.path.exists(
args.
pretrained_model), "Pretrained model path {} doesn't exist".format(
args.pretrained_model)
def if_exist(var):
    """Tell whether a file named after ``var`` exists in the pretrained model directory.

    The directory is read from ``args.pretrained_model`` in the enclosing
    scope; the file name is ``var.name``.
    """
    candidate_path = os.path.join(args.pretrained_model, var.name)
    return os.path.exists(candidate_path)
......@@ -169,12 +175,7 @@ def compress(args):
for batch_id, data in enumerate(valid_loader):
start_time = time.time()
acc_top1_n, acc_top5_n = exe.run(
program,
feed={
"image": data[0].get('image'),
"label": data[0].get('label')
},
fetch_list=[acc_top1.name, acc_top5.name])
program, feed=data, fetch_list=[acc_top1.name, acc_top5.name])
end_time = time.time()
if batch_id % args.log_period == 0:
_logger.info(
......@@ -190,28 +191,38 @@ def compress(args):
np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))
def train(epoch, program):
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
reader_start = time.time()
for batch_id, data in enumerate(train_loader):
start_time = time.time()
train_reader_cost += time.time() - reader_start
train_start = time.time()
loss_n, acc_top1_n, acc_top5_n = exe.run(
train_program,
feed={
"image": data[0].get('image'),
"label": data[0].get('label')
},
feed=data,
fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
end_time = time.time()
pruner.step()
train_run_cost += time.time() - train_start
total_samples += args.batch_size
loss_n = np.mean(loss_n)
acc_top1_n = np.mean(acc_top1_n)
acc_top5_n = np.mean(acc_top5_n)
if batch_id % args.log_period == 0:
_logger.info(
"epoch[{}]-batch[{}] lr: {:.6f} - loss: {}; acc_top1: {}; acc_top5: {}; time: {}".
"epoch[{}]-batch[{}] lr: {:.6f} - loss: {}; acc_top1: {}; acc_top5: {}; avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
format(epoch, batch_id,
learning_rate.get_lr(), loss_n, acc_top1_n,
acc_top5_n, end_time - start_time))
acc_top5_n, train_reader_cost / args.log_period, (
train_reader_cost + train_run_cost
) / args.log_period, total_samples / args.log_period,
total_samples / (train_reader_cost + train_run_cost
)))
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
learning_rate.step()
pruner.step()
batch_id += 1
reader_start = time.time()
build_strategy = paddle.static.BuildStrategy()
exec_strategy = paddle.static.ExecutionStrategy()
......@@ -227,10 +238,10 @@ def compress(args):
round(100 * UnstructuredPruner.total_sparse(
paddle.static.default_main_program()), 2)))
if i % args.test_period == 0:
if (i + 1) % args.test_period == 0:
pruner.update_params()
test(i, val_program)
if i > args.resume_epoch and i % args.model_period == 0:
if (i + 1) % args.model_period == 0:
pruner.update_params()
# NOTE: We are using fluid.io.save_params() because the pretrained model is from an older version which requires this API.
# Please consider using paddle.static.save(program, model_path) as long as it becomes possible.
......
......@@ -2,9 +2,9 @@
export CUDA_VISIBLE_DEVICES=2,3
export FLAGS_fraction_of_gpu_memory_to_use=0.98
python3.7 train.py \
--batch_size 256 \
--batch_size 512 \
--data imagenet \
--pruning_mode ratio \
--ratio 0.45 \
--lr 0.075 \
--pretrained_model /PaddleSlim/demo/pretrained_model/MobileNetV1_pretrained
--ratio 0.55 \
--lr 0.05 \
--pretrained_model ./MobileNetV1_pretrained
......@@ -2,9 +2,8 @@
export CUDA_VISIBLE_DEVICES=2,3
export FLAGS_fraction_of_gpu_memory_to_use=0.98
python3.7 train.py \
--batch_size=256 \
--batch_size=512 \
--data="mnist" \
--pruning_mode="threshold" \
--ratio=0.45 \
--threshold=1e-5 \
--lr=0.075 \
--threshold=0.01 \
--lr=0.05 \
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册