From 8fad8d4192e0657ebeb045ba5875ff1ee1d4c015 Mon Sep 17 00:00:00 2001 From: minghaoBD <79566150+minghaoBD@users.noreply.github.com> Date: Wed, 14 Apr 2021 17:42:05 +0800 Subject: [PATCH] Unstructured pruning (#710) --- demo/dygraph/unstructured_pruning/README.md | 106 ++++++++ demo/dygraph/unstructured_pruning/evaluate.py | 111 ++++++++ .../unstructured_pruning/evaluate_cifar10.sh | 5 + .../unstructured_pruning/evaluate_imagenet.sh | 5 + demo/dygraph/unstructured_pruning/train.py | 213 +++++++++++++++ .../unstructured_pruning/train_cifar10.sh | 9 + .../unstructured_pruning/train_imagenet.sh | 9 + demo/unstructured_prune/README.md | 128 +++++++++ demo/unstructured_prune/evaluate.py | 137 ++++++++++ demo/unstructured_prune/evaluate_imagenet.sh | 6 + demo/unstructured_prune/evaluate_mnist.sh | 5 + demo/unstructured_prune/train.py | 249 ++++++++++++++++++ demo/unstructured_prune/train_imagenet.sh | 10 + demo/unstructured_prune/train_mnist.sh | 10 + paddleslim/dygraph/prune/__init__.py | 3 + .../dygraph/prune/unstructured_pruner.py | 152 +++++++++++ paddleslim/prune/__init__.py | 3 + paddleslim/prune/unstructured_pruner.py | 202 ++++++++++++++ tests/dygraph/test_unstructured_prune.py | 42 +++ tests/test_unstructured_pruner.py | 86 ++++++ 20 files changed, 1491 insertions(+) create mode 100644 demo/dygraph/unstructured_pruning/README.md create mode 100644 demo/dygraph/unstructured_pruning/evaluate.py create mode 100644 demo/dygraph/unstructured_pruning/evaluate_cifar10.sh create mode 100644 demo/dygraph/unstructured_pruning/evaluate_imagenet.sh create mode 100644 demo/dygraph/unstructured_pruning/train.py create mode 100644 demo/dygraph/unstructured_pruning/train_cifar10.sh create mode 100644 demo/dygraph/unstructured_pruning/train_imagenet.sh create mode 100644 demo/unstructured_prune/README.md create mode 100644 demo/unstructured_prune/evaluate.py create mode 100644 demo/unstructured_prune/evaluate_imagenet.sh create mode 100644 demo/unstructured_prune/evaluate_mnist.sh create mode 100644 demo/unstructured_prune/train.py create mode 100644 demo/unstructured_prune/train_imagenet.sh create mode 100644 demo/unstructured_prune/train_mnist.sh create mode 100644 paddleslim/dygraph/prune/unstructured_pruner.py create mode 100644 paddleslim/prune/unstructured_pruner.py create mode 100644 tests/dygraph/test_unstructured_prune.py create mode 100644 tests/test_unstructured_pruner.py diff --git a/demo/dygraph/unstructured_pruning/README.md b/demo/dygraph/unstructured_pruning/README.md new file mode 100644 index 00000000..d343bfa5 --- /dev/null +++ b/demo/dygraph/unstructured_pruning/README.md @@ -0,0 +1,106 @@ +# 非结构化稀疏 -- 动态图剪裁(包括按照阈值和比例剪裁两种模式) + +## 简介 + +在模型压缩中,常见的稀疏方式为结构化和非结构化稀疏,前者在某个特定维度(特征通道、卷积核等等)上进行稀疏化操作;后者以每一个参数为单元进行稀疏化,所以更加依赖于硬件对稀疏后矩阵运算的加速能力。本目录即在PaddlePaddle和PaddleSlim框架下开发的非结构化稀疏算法,MobileNetV1在ImageNet上的稀疏化实验中,剪裁率55.19%,达到无损的表现。 + +## 版本要求 +```bash +python3.5+ +paddlepaddle>=2.0.0 +paddleslim>=2.1.0 +``` + +请参照github安装[paddlepaddle](https://github.com/PaddlePaddle/Paddle)和[paddleslim](https://github.com/PaddlePaddle/PaddleSlim)。 + +## 使用 + +训练前: +- 训练数据下载后,可以通过重写../imagenet_reader.py文件,并在train.py/evaluate.py文件中调用实现。 +- 开发者可以通过重写paddleslim.dygraph.prune.unstructured_pruner.py中的UnstructuredPruner.mask_parameters()和UnstructuredPruner.update_threshold()来定义自己的非结构化稀疏策略(目前为剪裁掉绝对值小的parameters)。 +- 开发可以在初始化UnstructuredPruner时,传入自定义的skip_params_func,来定义哪些参数不参与剪裁。skip_params_func示例代码如下(路径:paddleslim.dygraph.prune.unstructured_pruner._get_skip_params())。默认为所有的归一化层的参数不参与剪裁。 + +```python +def _get_skip_params(model): + """ + This function is used to check whether the given model's layers are valid to be pruned. + Usually, the convolutions are to be pruned while we skip the normalization-related parameters. + Deverlopers could replace this function by passing their own when initializing the UnstructuredPuner instance. + + Args: + - model(Paddle.nn.Layer): the current model waiting to be checked. + Return: + - skip_params(set): a set of parameters' names + """ + skip_params = set() + for _, sub_layer in model.named_sublayers(): + if type(sub_layer).__name__.split('.')[-1] in paddle.nn.norm.__all__: + skip_params.add(sub_layer.full_name()) + return skip_params +``` + +训练: +```bash +python3 train.py --data cifar10 --lr 0.1 --pruning_mode ratio --ratio=0.5 +``` + +推理: +```bash +python3 eval --pruned_model models/ --data cifar10 +``` + +剪裁训练代码示例: +```python +model = mobilenet_v1(num_classes=class_dim, pretrained=True) +#STEP1: initialize the pruner +pruner = UnstructuredPruner(model, mode='ratio', ratio=0.5) + +for epoch in range(epochs): + for batch_id, data in enumerate(train_loader): + loss = calculate_loss() + loss.backward() + opt.step() + opt.clear_grad() + #STEP2: update the pruner's threshold given the updated parameters + pruner.step() + + if epoch % args.test_period == 0: + #STEP3: before evaluation during training, eliminate the non-zeros generated by opt.step(), which, however, the cached masks setting to be zeros. + pruner.update_params() + eval(epoch) + + if epoch % args.model_period == 0: + # STEP4: same purpose as STEP3 + pruner.update_params() + paddle.save(model.state_dict(), "model-pruned.pdparams") + paddle.save(opt.state_dict(), "opt-pruned.pdopt") +``` + +剪裁后测试代码示例: +```python +model = mobilenet_v1(num_classes=class_dim, pretrained=True) +model.set_state_dict(paddle.load("model-pruned.pdparams")) +print(UnstructuredPruner.total_sparse(model)) #注意,total_sparse为静态方法(static method),可以不创建实例(instance)直接调用,方便只做测试的写法。 +test() +``` + +更多使用参数请参照shell文件或者运行如下命令查看: +```bash +python train --h +python evaluate --h +``` + +## 实验结果 (刚开始在动态图代码验证,以下为静态图代码上的结果) + +| 模型 | 数据集 | 压缩方法 | 压缩率| Top-1/Top-5 Acc | lr | threshold | epoch | +|:--:|:---:|:--:|:--:|:--:|:--:|:--:|:--:| +| MobileNetV1 | ImageNet | Baseline | - | 70.99%/89.68% | - | - | - | +| MobileNetV1 | ImageNet | ratio | -55.19% | 70.87%/89.80% (-0.12%/+0.12%) | 0.005 | - | 68 | +| YOLO v3 | VOC | - | - |76.24% | - | - | - | +| YOLO v3 | VOC |threshold | -41.35% | 75.29%(-0.95%) | 0.005 | 0.05 | 10w | +| YOLO v3 | VOC |threshold | -53.00% | 75.00%(-1.24%) | 0.005 | 0.075 | 10w | + +## TODO + +- [ ] 完成实验,验证动态图下的效果,并得到压缩模型。 +- [ ] 扩充衡量parameter重要性的方法(目前仅为绝对值)。 diff --git a/demo/dygraph/unstructured_pruning/evaluate.py b/demo/dygraph/unstructured_pruning/evaluate.py new file mode 100644 index 00000000..bd24a4b8 --- /dev/null +++ b/demo/dygraph/unstructured_pruning/evaluate.py @@ -0,0 +1,111 @@ +import paddle +import os +import sys +import argparse +import numpy as np +sys.path.append( + os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir)) +from paddleslim.dygraph.prune.unstructured_pruner import UnstructuredPruner +from utility import add_arguments, print_arguments +import paddle.vision.transforms as T +import paddle.nn.functional as F +import functools +from paddle.vision.models import mobilenet_v1 +import time +import logging +from paddleslim.common import get_logger + +_logger = get_logger(__name__, level=logging.INFO) + +parser = argparse.ArgumentParser(description=__doc__) +add_arg = functools.partial(add_arguments, argparser=parser) +# yapf: disable +add_arg('batch_size', int, 64, "Minibatch size.") +add_arg('use_gpu', bool, True, "Whether to use GPU or not.") +add_arg('model', str, "MobileNet", "The target model.") +add_arg('pruned_model', str, "dymodels/model-pruned.pdparams", "Whether to use pretrained model.") +add_arg('data', str, "cifar10", "Which data to use. 'cifar10' or 'imagenet'.") +add_arg('log_period', int, 100, "Log period in batches.") +# yapf: enable + + +def compress(args): + test_reader = None + if args.data == "imagenet": + import imagenet_reader as reader + val_dataset = reader.ImageNetDataset(data_dir='/data', mode='val') + class_dim = 1000 + elif args.data == "cifar10": + normalize = T.Normalize( + mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format='CHW') + transform = T.Compose([T.Transpose(), normalize]) + val_dataset = paddle.vision.datasets.Cifar10( + mode='test', backend='cv2', transform=transform) + class_dim = 10 + else: + raise ValueError("{} is not supported.".format(args.data)) + + places = paddle.static.cuda_places( + ) if args.use_gpu else paddle.static.cpu_places() + batch_size_per_card = int(args.batch_size / len(places)) + valid_loader = paddle.io.DataLoader( + val_dataset, + places=places, + drop_last=False, + return_list=True, + batch_size=batch_size_per_card, + shuffle=False, + use_shared_memory=True) + + # model definition + model = mobilenet_v1(num_classes=class_dim, pretrained=True) + + def test(epoch): + model.eval() + acc_top1_ns = [] + acc_top5_ns = [] + for batch_id, data in enumerate(valid_loader): + start_time = time.time() + x_data = data[0] + y_data = paddle.to_tensor(data[1]) + if args.data == 'cifar10': + y_data = paddle.unsqueeze(y_data, 1) + end_time = time.time() + + logits = model(x_data) + loss = F.cross_entropy(logits, y_data) + acc_top1 = paddle.metric.accuracy(logits, y_data, k=1) + acc_top5 = paddle.metric.accuracy(logits, y_data, k=5) + + acc_top1_ns.append(acc_top1.numpy()) + acc_top5_ns.append(acc_top5.numpy()) + if batch_id % args.log_period == 0: + _logger.info( + "Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}". + format(epoch, batch_id, + np.mean(acc_top1.numpy()), + np.mean(acc_top5.numpy()), end_time - start_time)) + acc_top1_ns.append(np.mean(acc_top1.numpy())) + acc_top5_ns.append(np.mean(acc_top5.numpy())) + + _logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format( + epoch, + np.mean(np.array( + acc_top1_ns, dtype="object")), + np.mean(np.array( + acc_top5_ns, dtype="object")))) + + model.set_state_dict(paddle.load(args.pruned_model)) + _logger.info("The current density of the pruned model is: {}%".format( + round(100 * UnstructuredPruner.total_sparse(model), 2))) + test(0) + + +def main(): + args = parser.parse_args() + print_arguments(args) + compress(args) + + +if __name__ == '__main__': + main() diff --git a/demo/dygraph/unstructured_pruning/evaluate_cifar10.sh b/demo/dygraph/unstructured_pruning/evaluate_cifar10.sh new file mode 100644 index 00000000..b07ed269 --- /dev/null +++ b/demo/dygraph/unstructured_pruning/evaluate_cifar10.sh @@ -0,0 +1,5 @@ +#!/bin/bash +export CUDA_VISIBLE_DEVICES=3 +python3.7 evaluate.py \ + --pruned_model="models/model-pruned.pdparams" \ + --data="cifar10" diff --git a/demo/dygraph/unstructured_pruning/evaluate_imagenet.sh b/demo/dygraph/unstructured_pruning/evaluate_imagenet.sh new file mode 100644 index 00000000..be03e1fd --- /dev/null +++ b/demo/dygraph/unstructured_pruning/evaluate_imagenet.sh @@ -0,0 +1,5 @@ +#!/bin/bash +export CUDA_VISIBLE_DEVICES=3 +python3.7 evaluate.py \ + --pruned_model="models/model-pruned.pdparams" \ + --data="imagenet" diff --git a/demo/dygraph/unstructured_pruning/train.py b/demo/dygraph/unstructured_pruning/train.py new file mode 100644 index 00000000..30343e0f --- /dev/null +++ b/demo/dygraph/unstructured_pruning/train.py @@ -0,0 +1,213 @@ +import paddle +import os +import sys +import argparse +import numpy as np +from paddleslim.dygraph.prune.unstructured_pruner import UnstructuredPruner +sys.path.append( + os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir)) +from utility import add_arguments, print_arguments +import paddle.vision.transforms as T +import paddle.nn.functional as F +import functools +from paddle.vision.models import mobilenet_v1 +import time +import logging +from paddleslim.common import get_logger +import paddle.distributed as dist + +_logger = get_logger(__name__, level=logging.INFO) + +parser = argparse.ArgumentParser(description=__doc__) +add_arg = functools.partial(add_arguments, argparser=parser) +# yapf: disable +add_arg('batch_size', int, 64, "Minibatch size.") +add_arg('use_gpu', bool, True, "Whether to use GPU or not.") +add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.") +add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.") +add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.") +add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.") +add_arg('ratio', float, 0.3, "The ratio to set zeros, the smaller part bounded by the ratio will be zeros.") +add_arg('pruning_mode', str, 'ratio', "the pruning mode: whether by ratio or by threshold.") +add_arg('threshold', float, 0.001, "The threshold to set zeros.") +add_arg('num_epochs', int, 120, "The number of total epochs.") +parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step") +add_arg('data', str, "cifar10", "Which data to use. 'cifar10' or 'imagenet'.") +add_arg('log_period', int, 100, "Log period in batches.") +add_arg('test_period', int, 1, "Test period in epoches.") +add_arg('model_path', str, "./models", "The path to save model.") +add_arg('model_period', int, 10, "The period to save model in epochs.") +add_arg('resume_epoch', int, -1, "The epoch to resume training.") +add_arg('num_workers', int, 4, "number of workers when loading dataset.") +# yapf: enable + + +def piecewise_decay(args, step_per_epoch, model): + bd = [step_per_epoch * e for e in args.step_epochs] + lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)] + learning_rate = paddle.optimizer.lr.PiecewiseDecay(boundaries=bd, values=lr) + + optimizer = paddle.optimizer.Momentum( + learning_rate=learning_rate, + momentum=args.momentum_rate, + weight_decay=paddle.regularizer.L2Decay(args.l2_decay), + parameters=model.parameters()) + return optimizer, learning_rate + + +def cosine_decay(args, step_per_epoch, model): + learning_rate = paddle.optimizer.lr.CosineAnnealingDecay( + learning_rate=args.lr, T_max=args.num_epochs * step_per_epoch) + optimizer = paddle.optimizer.Momentum( + learning_rate=learning_rate, + momentum=args.momentum_rate, + weight_decay=paddle.regularizer.L2Decay(args.l2_decay), + parameters=model.parameters()) + return optimizer, learning_rate + + +def create_optimizer(args, step_per_epoch, model): + if args.lr_strategy == "piecewise_decay": + return piecewise_decay(args, step_per_epoch, model) + elif args.lr_strategy == "cosine_decay": + return cosine_decay(args, step_per_epoch, model) + + +def compress(args): + dist.init_parallel_env() + train_reader = None + test_reader = None + if args.data == "imagenet": + import imagenet_reader as reader + train_dataset = reader.ImageNetDataset(data_dir='/data', mode='train') + val_dataset = reader.ImageNetDataset(data_dir='/data', mode='val') + class_dim = 1000 + elif args.data == "cifar10": + normalize = T.Normalize( + mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format='CHW') + transform = T.Compose([T.Transpose(), normalize]) + train_dataset = paddle.vision.datasets.Cifar10( + mode='train', backend='cv2', transform=transform) + val_dataset = paddle.vision.datasets.Cifar10( + mode='test', backend='cv2', transform=transform) + class_dim = 10 + else: + raise ValueError("{} is not supported.".format(args.data)) + places = paddle.static.cuda_places( + ) if args.use_gpu else paddle.static.cpu_places() + batch_size_per_card = int(args.batch_size / len(places)) + train_loader = paddle.io.DataLoader( + train_dataset, + places=places, + drop_last=True, + batch_size=args.batch_size, + shuffle=True, + return_list=True, + num_workers=args.num_workers, + use_shared_memory=True) + valid_loader = paddle.io.DataLoader( + val_dataset, + places=places, + drop_last=False, + return_list=True, + batch_size=args.batch_size, + shuffle=False, + use_shared_memory=True) + step_per_epoch = int(np.ceil(len(train_dataset) * 1. / args.batch_size)) + + # model definition + model = mobilenet_v1(num_classes=class_dim, pretrained=True) + dp_model = paddle.DataParallel(model) + + opt, learning_rate = create_optimizer(args, step_per_epoch, dp_model) + + def test(epoch): + dp_model.eval() + acc_top1_ns = [] + acc_top5_ns = [] + for batch_id, data in enumerate(valid_loader): + start_time = time.time() + x_data = data[0] + y_data = paddle.to_tensor(data[1]) + if args.data == 'cifar10': + y_data = paddle.unsqueeze(y_data, 1) + end_time = time.time() + + logits = dp_model(x_data) + loss = F.cross_entropy(logits, y_data) + acc_top1 = paddle.metric.accuracy(logits, y_data, k=1) + acc_top5 = paddle.metric.accuracy(logits, y_data, k=5) + + acc_top1_ns.append(acc_top1.numpy()) + acc_top5_ns.append(acc_top5.numpy()) + if batch_id % args.log_period == 0: + _logger.info( + "Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}". + format(epoch, batch_id, + np.mean(acc_top1.numpy()), + np.mean(acc_top5.numpy()), end_time - start_time)) + acc_top1_ns.append(np.mean(acc_top1.numpy())) + acc_top5_ns.append(np.mean(acc_top5.numpy())) + + _logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format( + epoch, + np.mean(np.array( + acc_top1_ns, dtype="object")), + np.mean(np.array( + acc_top5_ns, dtype="object")))) + + def train(epoch): + dp_model.train() + for batch_id, data in enumerate(train_loader): + start_time = time.time() + x_data = data[0] + y_data = paddle.to_tensor(data[1]) + if args.data == 'cifar10': + y_data = paddle.unsqueeze(y_data, 1) + + logits = dp_model(x_data) + loss = F.cross_entropy(logits, y_data) + acc_top1 = paddle.metric.accuracy(logits, y_data, k=1) + acc_top5 = paddle.metric.accuracy(logits, y_data, k=5) + end_time = time.time() + if batch_id % args.log_period == 0: + _logger.info( + "epoch[{}]-batch[{}] lr: {:.6f} - loss: {}; acc_top1: {}; acc_top5: {}; time: {}". + format(epoch, batch_id, args.lr, + np.mean(loss.numpy()), + np.mean(acc_top1.numpy()), + np.mean(acc_top5.numpy()), end_time - start_time)) + loss.backward() + opt.step() + opt.clear_grad() + pruner.step() + + pruner = UnstructuredPruner( + dp_model, + mode=args.pruning_mode, + ratio=args.ratio, + threshold=args.threshold) + for i in range(args.resume_epoch + 1, args.num_epochs): + train(i) + if i % args.test_period == 0: + pruner.update_params() + _logger.info( + "The current density of the pruned model is: {}%".format( + round(100 * UnstructuredPruner.total_sparse(dp_model), 2))) + test(i) + if i > args.resume_epoch and i % args.model_period == 0: + pruner.update_params() + paddle.save(dp_model.state_dict(), + os.path.join(args.model_path, "model-pruned.pdparams")) + paddle.save(opt.state_dict(), + os.path.join(args.model_path, "opt-pruned.pdopt")) + + +def main(): + args = parser.parse_args() + print_arguments(args) + compress(args) + + +if __name__ == '__main__': + main() diff --git a/demo/dygraph/unstructured_pruning/train_cifar10.sh b/demo/dygraph/unstructured_pruning/train_cifar10.sh new file mode 100644 index 00000000..38cbd970 --- /dev/null +++ b/demo/dygraph/unstructured_pruning/train_cifar10.sh @@ -0,0 +1,9 @@ +#!/bin/bash +export CUDA_VISIBLE_DEVICES=3 +python3.7 train.py \ + --batch_size=128 \ + --lr=0.05 \ + --ratio=0.45 \ + --threshold=1e-5 \ + --pruning_mode="threshold" \ + --data="cifar10" \ diff --git a/demo/dygraph/unstructured_pruning/train_imagenet.sh b/demo/dygraph/unstructured_pruning/train_imagenet.sh new file mode 100644 index 00000000..2b7c4d68 --- /dev/null +++ b/demo/dygraph/unstructured_pruning/train_imagenet.sh @@ -0,0 +1,9 @@ +#!/bin/bash +export CUDA_VISIBLE_DEVICES=3 +python3.7 train.py \ + --batch_size=64 \ + --lr=0.05 \ + --ratio=0.45 \ + --threshold=1e-5 \ + --pruning_mode="threshold" \ + --data="imagenet" \ diff --git a/demo/unstructured_prune/README.md b/demo/unstructured_prune/README.md new file mode 100644 index 00000000..17d095af --- /dev/null +++ b/demo/unstructured_prune/README.md @@ -0,0 +1,128 @@ +# 非结构化稀疏 -- 静态图剪裁(包括按照阈值和比例剪裁两种模式) + +## 简介 + +在模型压缩中,常见的稀疏方式为结构化和非结构化稀疏,前者在某个特定维度(特征通道、卷积核等等)上进行稀疏化操作;后者以每一个参数为单元进行稀疏化,所以更加依赖于硬件对稀疏后矩阵运算的加速能力。本目录即在PaddlePaddle和PaddleSlim框架下开发的非结构化稀疏算法,MobileNetV1在ImageNet上的稀疏化实验中,剪裁率55.19%,达到无损的表现。 + +## 版本要求 +```bash +python3.5+ +paddlepaddle>=2.0.0 +paddleslim>=2.1.0 +``` + +请参照github安装[paddlepaddle](https://github.com/PaddlePaddle/Paddle)和[paddleslim](https://github.com/PaddlePaddle/PaddleSlim)。 + +## 使用 + +训练前: +- 预训练模型下载,并放到某目录下,通过train.py中的--pretrained_model设置。 +- 训练数据下载后,可以通过重写../imagenet_reader.py文件,并在train.py文件中调用实现。 +- 开发者可以通过重写paddleslim.prune.unstructured_pruner.py中的UnstructuredPruner.update_threshold()来定义自己的非结构化稀疏策略(目前为剪裁掉绝对值小的parameters)。 +- 开发可以在初始化UnstructuredPruner时,传入自定义的skip_params_func,来定义哪些参数不参与剪裁。skip_params_func示例代码如下(路径:paddleslim.prune.unstructured_pruner._get_skip_params())。默认为所有的归一化层的参数不参与剪裁。 + +```python +def _get_skip_params(program): + """ + The function is used to get a set of all the skipped parameters when performing pruning. + By default, the normalization-related ones will not be pruned. + Developers could replace it by passing their own function when initializing the UnstructuredPruner instance. + Args: + - program(paddle.static.Program): the current model. + Returns: + - skip_params(Set): a set of parameters' names. + """ + skip_params = set() + graph = paddleslim.core.GraphWrapper(program) + for op in graph.ops(): + if 'norm' in op.type() and 'grad' not in op.type(): + for input in op.all_inputs(): + skip_params.add(input.name()) + return skip_params +``` + +训练: +```bash +CUDA_VISIBLE_DEVICES=2,3 python3.7 train.py --data mnist --lr 0.1 --pruning_mode ratio --ratio=0.5 +``` + +推理: +```bash +CUDA_VISIBLE_DEVICES=0 python3.7 evaluate.py --pruned_model models/ --data imagenet +``` + +剪裁训练代码示例: +```python +# model definition +places = paddle.static.cuda_places() +place = places[0] +exe = paddle.static.Executor(place) +model = models.__dict__[args.model]() +out = model.net(input=image, class_dim=class_dim) +cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label) +avg_cost = paddle.mean(x=cost) +acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1) +acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5) + +val_program = paddle.static.default_main_program().clone(for_test=True) + +opt, learning_rate = create_optimizer(args, step_per_epoch) +opt.minimize(avg_cost) + +#STEP1: initialize the pruner +pruner = UnstructuredPruner(paddle.static.default_main_program(), mode='ratio', ratio=0.5, place=place) + +exe.run(paddle.static.default_startup_program()) +paddle.fluid.io.load_vars(exe, args.pretrained_model) + +for epoch in range(epochs): + for batch_id, data in enumerate(train_loader): + loss_n, acc_top1_n, acc_top5_n = exe.run( + train_program, + feed={ + "image": data[0].get('image'), + "label": data[0].get('label') + }, + fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name]) + learning_rate.step() + #STEP2: update the pruner's threshold given the updated parameters + pruner.step() + + if epoch % args.test_period == 0: + #STEP3: before evaluation during training, eliminate the non-zeros generated by opt.step(), which, however, the cached masks setting to be zeros. + pruner.update_params() + eval(epoch) + + if epoch % args.model_period == 0: + # STEP4: same purpose as STEP3 + pruner.update_params() + save(epoch) +``` + +剪裁后测试代码示例: +```python +# intialize the model instance in static mode +# load weights +print(UnstructuredPruner.total_sparse(paddle.static.default_main_program())) #注意,total_sparse为静态方法(static method),可以不创建实例(instance)直接调用,方便只做测试的写法。 +test() +``` + +更多使用参数请参照shell文件,或者通过运行以下命令查看: +```bash +python3.7 train.py --h +python3.7 evaluate.py --h +``` + +## 实验结果 + +| 模型 | 数据集 | 压缩方法 | 压缩率| Top-1/Top-5 Acc | lr | threshold | epoch | +|:--:|:---:|:--:|:--:|:--:|:--:|:--:|:--:| +| MobileNetV1 | ImageNet | Baseline | - | 70.99%/89.68% | - | - | - | +| MobileNetV1 | ImageNet | ratio | -55.19% | 70.87%/89.80% (-0.12%/+0.12%) | 0.005 | - | 68 | +| YOLO v3 | VOC | - | - |76.24% | - | - | - | +| YOLO v3 | VOC |threshold | -55.15% | 75.45%(-0.79%) | 0.005 | 0.05 |12.8w| + +## TODO + +- [ ] 完成实验,验证动态图下的效果,并得到压缩模型。 +- [ ] 扩充衡量parameter重要性的方法(目前仅为绝对值)。 diff --git a/demo/unstructured_prune/evaluate.py b/demo/unstructured_prune/evaluate.py new file mode 100644 index 00000000..62a08622 --- /dev/null +++ b/demo/unstructured_prune/evaluate.py @@ -0,0 +1,137 @@ +import os +import sys +import logging +import paddle +import argparse +import functools +import math +import time +import numpy as np +import paddle.fluid as fluid +sys.path.append(os.path.join(os.path.dirname("__file__"), os.path.pardir)) +from paddleslim.prune.unstructured_pruner import UnstructuredPruner +from paddleslim.common import get_logger +import models +from utility import add_arguments, print_arguments +import paddle.vision.transforms as T + +_logger = get_logger(__name__, level=logging.INFO) + +parser = argparse.ArgumentParser(description=__doc__) +add_arg = functools.partial(add_arguments, argparser=parser) +# yapf: disable +add_arg('batch_size', int, 64*12, "Minibatch size.") +add_arg('use_gpu', bool, True, "Whether to use GPU or not.") +add_arg('model', str, "MobileNet", "The target model.") +add_arg('pruned_model', str, "models", "Whether to use pretrained model.") +add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'.") +add_arg('log_period', int, 100, "Log period in batches.") +# yapf: enable + +model_list = models.__all__ + + +def compress(args): + train_reader = None + test_reader = None + if args.data == "mnist": + transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])]) + train_dataset = paddle.vision.datasets.MNIST( + mode='train', backend="cv2", transform=transform) + val_dataset = paddle.vision.datasets.MNIST( + mode='test', backend="cv2", transform=transform) + class_dim = 10 + image_shape = "1,28,28" + elif args.data == "imagenet": + import imagenet_reader as reader + train_dataset = reader.ImageNetDataset(data_dir='/data', mode='train') + val_dataset = reader.ImageNetDataset(data_dir='/data', mode='val') + class_dim = 1000 + image_shape = "3,224,224" + else: + raise ValueError("{} is not supported.".format(args.data)) + image_shape = [int(m) for m in image_shape.split(",")] + assert args.model in model_list, "{} is not in lists: {}".format(args.model, + model_list) + places = paddle.static.cuda_places( + ) if args.use_gpu else paddle.static.cpu_places() + place = places[0] + exe = paddle.static.Executor(place) + image = paddle.static.data( + name='image', shape=[None] + image_shape, dtype='float32') + label = paddle.static.data(name='label', shape=[None, 1], dtype='int64') + + batch_size_per_card = int(args.batch_size / len(places)) + valid_loader = paddle.io.DataLoader( + val_dataset, + places=place, + feed_list=[image, label], + drop_last=False, + return_list=False, + use_shared_memory=True, + batch_size=batch_size_per_card, + shuffle=False) + step_per_epoch = int(np.ceil(len(train_dataset) * 1. / args.batch_size)) + + # model definition + model = models.__dict__[args.model]() + out = model.net(input=image, class_dim=class_dim) + cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label) + avg_cost = paddle.mean(x=cost) + acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1) + acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5) + + val_program = paddle.static.default_main_program().clone(for_test=True) + + exe.run(paddle.static.default_startup_program()) + + if args.pruned_model: + + def if_exist(var): + return os.path.exists(os.path.join(args.pruned_model, var.name)) + + _logger.info("Load pruned model from {}".format(args.pruned_model)) + paddle.fluid.io.load_vars(exe, args.pruned_model, predicate=if_exist) + + def test(epoch, program): + acc_top1_ns = [] + acc_top5_ns = [] + + _logger.info("The current density of the inference model is {}%".format( + round(100 * UnstructuredPruner.total_sparse( + paddle.static.default_main_program()), 2))) + for batch_id, data in enumerate(valid_loader): + start_time = time.time() + acc_top1_n, acc_top5_n = exe.run( + program, + feed={ + "image": data[0].get('image'), + "label": data[0].get('label'), + }, + fetch_list=[acc_top1.name, acc_top5.name]) + end_time = time.time() + if batch_id % args.log_period == 0: + _logger.info( + "Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}". + format(epoch, batch_id, + np.mean(acc_top1_n), + np.mean(acc_top5_n), end_time - start_time)) + acc_top1_ns.append(np.mean(acc_top1_n)) + acc_top5_ns.append(np.mean(acc_top5_n)) + + _logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format( + epoch, + np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns)))) + + test(0, val_program) + + +def main(): + paddle.enable_static() + args = parser.parse_args() + print_arguments(args) + compress(args) + + +if __name__ == '__main__': + main() diff --git a/demo/unstructured_prune/evaluate_imagenet.sh b/demo/unstructured_prune/evaluate_imagenet.sh new file mode 100644 index 00000000..03e6dac5 --- /dev/null +++ b/demo/unstructured_prune/evaluate_imagenet.sh @@ -0,0 +1,6 @@ +#!/bin/bash +export CUDA_VISIBLE_DEVICES=3 +export FLAGS_fraction_of_gpu_memory_to_use=0.98 +python3.7 evaluate.py \ + --pruned_model="models" \ + --data="imagenet" diff --git a/demo/unstructured_prune/evaluate_mnist.sh b/demo/unstructured_prune/evaluate_mnist.sh new file mode 100644 index 00000000..12166bc9 --- /dev/null +++ b/demo/unstructured_prune/evaluate_mnist.sh @@ -0,0 +1,5 @@ +#!/bin/bash +export CUDA_VISIBLE_DEVICES=3 +python3.7 evaluate.py \ + --pruned_model="models" \ + --data="mnist" diff --git a/demo/unstructured_prune/train.py b/demo/unstructured_prune/train.py new file mode 100644 index 00000000..384845f3 --- /dev/null +++ b/demo/unstructured_prune/train.py @@ -0,0 +1,249 @@ +import os +import sys +import logging +import paddle +import argparse +import functools +import time +import numpy as np +import paddle.fluid as fluid +from paddleslim.prune.unstructured_pruner import UnstructuredPruner +from paddleslim.common import get_logger +sys.path.append(os.path.join(os.path.dirname("__file__"), os.path.pardir)) +import models +from utility import add_arguments, print_arguments +import paddle.vision.transforms as T + +_logger = get_logger(__name__, level=logging.INFO) + +parser = argparse.ArgumentParser(description=__doc__) +add_arg = functools.partial(add_arguments, argparser=parser) +# yapf: disable +add_arg('batch_size', int, 64, "Minibatch size.") +add_arg('use_gpu', bool, True, "Whether to use GPU or not.") +add_arg('model', str, "MobileNet", "The target model.") +add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretrained", "Whether to use pretrained model.") +add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.") +add_arg('lr_strategy', str, "cosine_decay", "The learning rate decay strategy.") +add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.") +add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.") +add_arg('threshold', float, 1e-5, "The threshold to set zeros, the abs(weights) lower than which will be zeros.") +add_arg('pruning_mode', str, 'ratio', "the pruning mode: whether by ratio or by threshold.") +add_arg('ratio', float, 0.5, "The ratio to set zeros, the smaller portion will be zeros.") +add_arg('num_epochs', int, 120, "The number of total epochs.") +parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step") +add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'.") +add_arg('log_period', int, 100, "Log period in batches.") +add_arg('test_period', int, 10, "Test period in epoches.") +add_arg('model_path', str, "./models", "The path to save model.") +add_arg('model_period', int, 10, "The period to save model in epochs.") +add_arg('resume_epoch', int, -1, "The epoch to resume training.") +# yapf: enable + +model_list = models.__all__ + + +def piecewise_decay(args, step_per_epoch): + bd = [step_per_epoch * e for e in args.step_epochs] + lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)] + learning_rate = paddle.optimizer.lr.PiecewiseDecay(boundaries=bd, values=lr) + + optimizer = paddle.optimizer.Momentum( + learning_rate=learning_rate, + momentum=args.momentum_rate, + weight_decay=paddle.regularizer.L2Decay(args.l2_decay)) + return optimizer, learning_rate + + +def cosine_decay(args, step_per_epoch): + learning_rate = paddle.optimizer.lr.CosineAnnealingDecay( + learning_rate=args.lr, T_max=args.num_epochs * step_per_epoch) + optimizer = paddle.optimizer.Momentum( + learning_rate=learning_rate, + momentum=args.momentum_rate, + weight_decay=paddle.regularizer.L2Decay(args.l2_decay)) + return optimizer, learning_rate + + +def create_optimizer(args, step_per_epoch): + if args.lr_strategy == "piecewise_decay": + return piecewise_decay(args, step_per_epoch) + elif args.lr_strategy == "cosine_decay": + return cosine_decay(args, step_per_epoch) + + +def compress(args): + train_reader = None + test_reader = None + if args.data == "mnist": + transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])]) + train_dataset = paddle.vision.datasets.MNIST( + mode='train', backend="cv2", transform=transform) + val_dataset = paddle.vision.datasets.MNIST( + mode='test', backend="cv2", transform=transform) + class_dim = 10 + image_shape = "1,28,28" + args.pretrained_model = False + elif args.data == "imagenet": + import imagenet_reader as reader + train_dataset = reader.ImageNetDataset(data_dir='/data', mode='train') + val_dataset = reader.ImageNetDataset(data_dir='/data', mode='val') + class_dim = 1000 + image_shape = "3,224,224" + else: + raise ValueError("{} is not supported.".format(args.data)) + image_shape = [int(m) for m in image_shape.split(",")] + assert args.model in model_list, "{} is not in lists: {}".format(args.model, + model_list) + places = paddle.static.cuda_places( + ) if args.use_gpu else paddle.static.cpu_places() + place = places[0] + exe = paddle.static.Executor(place) + image = paddle.static.data( + name='image', shape=[None] + image_shape, dtype='float32') + label = paddle.static.data(name='label', shape=[None, 1], dtype='int64') + + batch_size_per_card = int(args.batch_size / len(places)) + train_loader = paddle.io.DataLoader( + train_dataset, + places=places, + feed_list=[image, label], + drop_last=True, + batch_size=batch_size_per_card, + shuffle=True, + return_list=False, + use_shared_memory=True, + num_workers=32) + valid_loader = paddle.io.DataLoader( + val_dataset, + places=place, + feed_list=[image, label], + drop_last=False, + return_list=False, + use_shared_memory=True, + batch_size=batch_size_per_card, + shuffle=False) + step_per_epoch = int(np.ceil(len(train_dataset) * 1. / args.batch_size)) + + # model definition + model = models.__dict__[args.model]() + out = model.net(input=image, class_dim=class_dim) + cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label) + avg_cost = paddle.mean(x=cost) + acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1) + acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5) + + val_program = paddle.static.default_main_program().clone(for_test=True) + + opt, learning_rate = create_optimizer(args, step_per_epoch) + opt.minimize(avg_cost) + + pruner = UnstructuredPruner( + paddle.static.default_main_program(), + batch_size=args.batch_size, + mode=args.pruning_mode, + ratio=args.ratio, + threshold=args.threshold, + place=place) + + exe.run(paddle.static.default_startup_program()) + + if args.pretrained_model: + + def if_exist(var): + return os.path.exists(os.path.join(args.pretrained_model, var.name)) + + _logger.info("Load pretrained model from {}".format( + args.pretrained_model)) + # NOTE: We are using fluid.io.load_vars() because the pretrained model is from an older version which requires this API. + # Please consider using paddle.static.load(program, model_path) when possible + paddle.fluid.io.load_vars( + exe, args.pretrained_model, predicate=if_exist) + + def test(epoch, program): + acc_top1_ns = [] + acc_top5_ns = [] + + _logger.info("The current density of the inference model is {}%".format( + round(100 * UnstructuredPruner.total_sparse( + paddle.static.default_main_program()), 2))) + for batch_id, data in enumerate(valid_loader): + start_time = time.time() + acc_top1_n, acc_top5_n = exe.run( + program, + feed={ + "image": data[0].get('image'), + "label": data[0].get('label') + }, + fetch_list=[acc_top1.name, acc_top5.name]) + end_time = time.time() + if batch_id % args.log_period == 0: + _logger.info( + "Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}". + format(epoch, batch_id, + np.mean(acc_top1_n), + np.mean(acc_top5_n), end_time - start_time)) + acc_top1_ns.append(np.mean(acc_top1_n)) + acc_top5_ns.append(np.mean(acc_top5_n)) + + _logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format( + epoch, + np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns)))) + + def train(epoch, program): + for batch_id, data in enumerate(train_loader): + start_time = time.time() + loss_n, acc_top1_n, acc_top5_n = exe.run( + train_program, + feed={ + "image": data[0].get('image'), + "label": data[0].get('label') + }, + fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name]) + end_time = time.time() + loss_n = np.mean(loss_n) + acc_top1_n = np.mean(acc_top1_n) + acc_top5_n = np.mean(acc_top5_n) + if batch_id % args.log_period == 0: + _logger.info( + "epoch[{}]-batch[{}] lr: {:.6f} - loss: {}; acc_top1: {}; acc_top5: {}; time: {}". + format(epoch, batch_id, + learning_rate.get_lr(), loss_n, acc_top1_n, + acc_top5_n, end_time - start_time)) + learning_rate.step() + pruner.step() + batch_id += 1 + + build_strategy = paddle.static.BuildStrategy() + exec_strategy = paddle.static.ExecutionStrategy() + + train_program = paddle.static.CompiledProgram( + paddle.static.default_main_program()).with_data_parallel( + loss_name=avg_cost.name, + build_strategy=build_strategy, + exec_strategy=exec_strategy) + for i in range(args.resume_epoch + 1, args.num_epochs): + train(i, train_program) + _logger.info("The current density of the pruned model is: {}%".format( + round(100 * UnstructuredPruner.total_sparse( + paddle.static.default_main_program()), 2))) + + if i % args.test_period == 0: + pruner.update_params() + test(i, val_program) + if i > args.resume_epoch and i % args.model_period == 0: + pruner.update_params() + # NOTE: We are using fluid.io.save_params() because the pretrained model is from an older version which requires this API. + # Please consider using paddle.static.save(program, model_path) as long as it becomes possible. + fluid.io.save_params(executor=exe, dirname=args.model_path) + + +def main(): + paddle.enable_static() + args = parser.parse_args() + print_arguments(args) + compress(args) + + +if __name__ == '__main__': + main() diff --git a/demo/unstructured_prune/train_imagenet.sh b/demo/unstructured_prune/train_imagenet.sh new file mode 100644 index 00000000..c8b6d671 --- /dev/null +++ b/demo/unstructured_prune/train_imagenet.sh @@ -0,0 +1,10 @@ +#!/bin/bash +export CUDA_VISIBLE_DEVICES=2,3 +export FLAGS_fraction_of_gpu_memory_to_use=0.98 +python3.7 train.py \ + --batch_size 256 \ + --data imagenet \ + --pruning_mode ratio \ + --ratio 0.45 \ + --lr 0.075 \ + --pretrained_model /PaddleSlim/demo/pretrained_model/MobileNetV1_pretrained diff --git a/demo/unstructured_prune/train_mnist.sh b/demo/unstructured_prune/train_mnist.sh new file mode 100644 index 00000000..375045f9 --- /dev/null +++ b/demo/unstructured_prune/train_mnist.sh @@ -0,0 +1,10 @@ +#!/bin/bash +export CUDA_VISIBLE_DEVICES=2,3 +export FLAGS_fraction_of_gpu_memory_to_use=0.98 +python3.7 train.py \ + --batch_size=256 \ + --data="mnist" \ + --pruning_mode="threshold" \ + --ratio=0.45 \ + --threshold=1e-5 \ + --lr=0.075 \ diff --git a/paddleslim/dygraph/prune/__init__.py b/paddleslim/dygraph/prune/__init__.py index fa2d29e9..9ab11f2f 100644 --- a/paddleslim/dygraph/prune/__init__.py +++ b/paddleslim/dygraph/prune/__init__.py @@ -10,6 +10,8 @@ from . import l2norm_pruner from .l2norm_pruner import * from . import fpgm_pruner from .fpgm_pruner import * +from . import unstructured_pruner +from .unstructured_pruner import * __all__ = [] @@ -19,3 +21,4 @@ __all__ += l2norm_pruner.__all__ __all__ += fpgm_pruner.__all__ __all__ += pruner.__all__ __all__ += filter_pruner.__all__ +__all__ += unstructured_pruner.__all__ diff --git a/paddleslim/dygraph/prune/unstructured_pruner.py b/paddleslim/dygraph/prune/unstructured_pruner.py new file mode 100644 index 00000000..44116fc8 --- /dev/null +++ b/paddleslim/dygraph/prune/unstructured_pruner.py @@ -0,0 +1,152 @@ +import numpy as np +import paddle +import logging +from paddleslim.common import get_logger + +__all__ = ["UnstructuredPruner"] + +_logger = get_logger(__name__, level=logging.INFO) + + +class UnstructuredPruner(): + """ + The unstructure pruner. + Args: + - model(Paddle.nn.Layer): The model to be pruned. + - mode(str): Pruning mode, must be selected from 'ratio' and 'threshold'. + - threshold(float): The parameters whose absolute values are smaller than the THRESHOLD will be zeros. Default: 0.01 + - ratio(float): The parameters whose absolute values are in the smaller part decided by the ratio will be zeros. Default: 0.3 + - skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params. + """ + + def __init__(self, + model, + mode, + threshold=0.01, + ratio=0.3, + skip_params_func=None): + assert mode in ('ratio', 'threshold' + ), "mode must be selected from 'ratio' and 'threshold'" + self.model = model + self.mode = mode + self.threshold = threshold + self.ratio = ratio + if skip_params_func is None: skip_params_func = self._get_skip_params + self.skip_params = skip_params_func(model) + self._apply_masks() + + def mask_parameters(self, param, mask): + """ + Update masks and parameters. It is executed to each layer before each iteration. + User can overwrite this function in subclass to implememt different pruning stragies. + Args: + - parameters(list): The parameters to be pruned. + - masks(list): The masks used to keep zero values in parameters. + """ + bool_tmp = (paddle.abs(param) >= self.threshold) + paddle.assign(mask * bool_tmp, output=mask) + param_tmp = param * mask + param_tmp.stop_gradient = True + paddle.assign(param_tmp, output=param) + + def _apply_masks(self): + self.masks = {} + for name, sub_layer in self.model.named_sublayers(): + for param in sub_layer.parameters(include_sublayers=False): + tmp_array = np.ones(param.shape, dtype=np.float32) + mask_name = "_".join([param.name.replace(".", "_"), "mask"]) + if mask_name not in sub_layer._buffers: + sub_layer.register_buffer(mask_name, + paddle.to_tensor(tmp_array)) + self.masks[param.name] = sub_layer._buffers[mask_name] + for name, sub_layer in self.model.named_sublayers(): + sub_layer.register_forward_pre_hook(self._forward_pre_hook) + + def update_threshold(self): + ''' + Update the threshold after each optimization step. + User should overwrite this method togther with self.mask_parameters() + ''' + params_flatten = [] + for name, sub_layer in self.model.named_sublayers(): + if not self._should_prune_layer(sub_layer): + continue + for param in sub_layer.parameters(include_sublayers=False): + t_param = param.value().get_tensor() + v_param = np.array(t_param) + params_flatten.append(v_param.flatten()) + params_flatten = np.concatenate(params_flatten, axis=0) + total_length = params_flatten.size + self.threshold = np.sort(np.abs(params_flatten))[max( + 0, round(self.ratio * total_length) - 1)].item() + + def step(self): + """ + Update the threshold after each optimization step. + """ + if self.mode == 'ratio': + self.update_threshold() + elif self.mode == 'threshold': + return + + def _forward_pre_hook(self, layer, input): + if not self._should_prune_layer(layer): + return input + for param in layer.parameters(include_sublayers=False): + mask = self.masks.get(param.name) + self.mask_parameters(param, mask) + return input + + def update_params(self): + """ + Update the parameters given self.masks, usually called before saving models and evaluation step during training. + If you load a sparse model and only want to inference, no need to call the method. + """ + for name, sub_layer in self.model.named_sublayers(): + for param in sub_layer.parameters(include_sublayers=False): + mask = self.masks.get(param.name) + param_tmp = param * mask + param_tmp.stop_gradient = True + paddle.assign(param_tmp, output=param) + + @staticmethod + def total_sparse(model): + """ + This static function is used to get the whole model's density (1-sparsity). + It is static because during testing, we can calculate sparsity without initializing a pruner instance. + + Args: + - model(Paddle.Model): The sparse model. + Returns: + - ratio(float): The model's density. + """ + total = 0 + values = 0 + for name, sub_layer in model.named_sublayers(): + for param in sub_layer.parameters(include_sublayers=False): + total += np.product(param.shape) + values += len(paddle.nonzero(param)) + ratio = float(values) / total + return ratio + + def _get_skip_params(self, model): + """ + This function is used to check whether the given model's layers are valid to be pruned. + Usually, the convolutions are to be pruned while we skip the normalization-related parameters. + Deverlopers could replace this function by passing their own when initializing the UnstructuredPuner instance. + + Args: + - model(Paddle.nn.Layer): the current model waiting to be checked. + Return: + - skip_params(set): a set of parameters' names + """ + skip_params = set() + for _, sub_layer in model.named_sublayers(): + if type(sub_layer).__name__.split('.')[ + -1] in paddle.nn.norm.__all__: + skip_params.add(sub_layer.full_name()) + return skip_params + + def _should_prune_layer(self, layer): + should_prune = layer.full_name() not in self.skip_params + return should_prune diff --git a/paddleslim/prune/__init__.py b/paddleslim/prune/__init__.py index 01062078..7542031c 100644 --- a/paddleslim/prune/__init__.py +++ b/paddleslim/prune/__init__.py @@ -27,6 +27,8 @@ from .group_param import * from ..prune import group_param from .criterion import * from ..prune import criterion +from .unstructured_pruner import * +from ..prune import unstructured_pruner from .idx_selector import * from ..prune import idx_selector @@ -39,4 +41,5 @@ __all__ += prune_walker.__all__ __all__ += prune_io.__all__ __all__ += group_param.__all__ __all__ += criterion.__all__ +__all__ += unstructured_pruner.__all__ __all__ += idx_selector.__all__ diff --git a/paddleslim/prune/unstructured_pruner.py b/paddleslim/prune/unstructured_pruner.py new file mode 100644 index 00000000..bc710457 --- /dev/null +++ b/paddleslim/prune/unstructured_pruner.py @@ -0,0 +1,202 @@ +import numpy as np +from ..common import get_logger +from ..core import GraphWrapper +import paddle + +__all__ = ["UnstructuredPruner"] + + +class UnstructuredPruner(): + """ + The unstructure pruner. + + Args: + - program(paddle.static.Program): The model to be pruned. + - batch_size(int): batch size. + - mode(str): the mode to prune the model, must be selected from 'ratio' and 'threshold'. + - ratio(float): the ratio to prune the model. Only set it when mode=='ratio'. Default: 0.5. + - threshold(float): the threshold to prune the model. Only set it when mode=='threshold'. Default: 1e-5. + - scope(paddle.static.Scope): The scope storing values of all variables. None means paddle.static.global_scope. Default: None. + - place(CPUPlace | CUDAPlace): The device place used to execute model. None means CPUPlace. Default: None. + - skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params. + """ + + def __init__(self, + program, + batch_size, + mode, + ratio=0.5, + threshold=1e-5, + scope=None, + place=None, + skip_params_func=None): + self.mode = mode + self.ratio = ratio + self.threshold = threshold + assert self.mode in [ + 'ratio', 'threshold' + ], "mode must be selected from 'ratio' and 'threshold'" + self.scope = paddle.static.global_scope() if scope == None else scope + self.place = paddle.static.CPUPlace() if place is None else place + if skip_params_func is None: skip_params_func = self._get_skip_params + self.skip_params = skip_params_func(program) + self.masks = self._apply_masks(program) + + def _apply_masks(self, program): + params = [] + masks = [] + for param in program.all_parameters(): + mask = program.global_block().create_var( + name=param.name + "_mask", + shape=param.shape, + dtype=param.dtype, + type=param.type, + persistable=param.persistable, + stop_gradient=True) + + self.scope.var(param.name + "_mask").get_tensor().set( + np.ones(mask.shape).astype("float32"), self.place) + params.append(param) + masks.append(mask) + + d_masks = {} + for _param, _mask in zip(params, masks): + d_masks[_param.name] = _mask.name + return d_masks + + def summarize_weights(self, program, ratio=0.1): + """ + The function is used to get the weights corresponding to a given ratio + when you are uncertain about the threshold in __init__() function above. + For example, when given 0.1 as ratio, the function will print the weight value, + the abs(weights) lower than which count for 10% of the total numbers. + + Args: + - program(paddle.static.Program): The model which have all the parameters. + - ratio(float): The ratio illustrated above. + Return: + - threshold(float): a threshold corresponding to the input ratio. + """ + data = [] + for param in program.all_parameters(): + data.append( + np.array(paddle.static.global_scope().find_var(param.name) + .get_tensor()).flatten()) + data = np.concatenate(data, axis=0) + threshold = np.sort(np.abs(data))[max(0, int(ratio * len(data) - 1))] + return threshold + + def sparse_by_layer(self, program): + """ + The function is used to get the density at each layer, usually called for debuggings. + + Args: + - program(paddle.static.Program): The current model. + Returns: + - layer_sparse(Dict): sparsity for each parameter. + """ + layer_sparse = {} + total = 0 + values = 0 + for param in program.all_parameters(): + value = np.count_nonzero( + np.array(paddle.static.global_scope().find_var(param.name) + .get_tensor())) + layer_sparse[param.name] = value / np.product(param.shape) + return layer_sparse + + def update_threshold(self): + ''' + Update the threshold after each optimization step in RATIO mode. + User should overwrite this method to define their own weight importance (Default is based on their absolute values). + ''' + params_flatten = [] + for param in self.masks: + if not self._should_prune_param(param): + continue + t_param = self.scope.find_var(param).get_tensor() + v_param = np.array(t_param) + params_flatten.append(v_param.flatten()) + params_flatten = np.concatenate(params_flatten, axis=0) + total_len = len(params_flatten) + self.threshold = np.sort(np.abs(params_flatten))[max( + 0, int(self.ratio * total_len) - 1)] + + def _update_params_masks(self): + for param in self.masks: + if not self._should_prune_param(param): + continue + mask_name = self.masks[param] + t_param = self.scope.find_var(param).get_tensor() + t_mask = self.scope.find_var(mask_name).get_tensor() + v_param = np.array(t_param) + v_param[np.abs(v_param) < self.threshold] = 0 + v_mask = (v_param != 0).astype(v_param.dtype) + t_mask.set(v_mask, self.place) + v_param = np.array(t_param) * np.array(t_mask) + t_param.set(v_param, self.place) + + def step(self): + """ + Update the threshold after each optimization step. + """ + if self.mode == 'threshold': + pass + elif self.mode == 'ratio': + self.update_threshold() + self._update_params_masks() + + def update_params(self): + """ + Update the parameters given self.masks, usually called before saving models. + """ + for param in self.masks: + mask = self.masks[param] + t_param = self.scope.find_var(param).get_tensor() + t_mask = self.scope.find_var(mask).get_tensor() + v_param = np.array(t_param) * np.array(t_mask) + t_param.set(v_param, self.place) + + @staticmethod + def total_sparse(program): + """ + The function is used to get the whole model's density (1-sparsity). + It is static because during testing, we can calculate sparsity without initializing a pruner instance. + + Args: + - program(paddle.static.Program): The current model. + Returns: + - density(float): the model's density. + """ + total = 0 + values = 0 + for param in program.all_parameters(): + total += np.product(param.shape) + values += np.count_nonzero( + np.array(paddle.static.global_scope().find_var(param.name) + .get_tensor())) + density = float(values) / total + return density + + def _get_skip_params(self, program): + """ + The function is used to get a set of all the skipped parameters when performing pruning. + By default, the normalization-related ones will not be pruned. + Developers could replace it by passing their own function when initializing the UnstructuredPruner instance. + + Args: + - program(paddle.static.Program): the current model. + Returns: + - skip_params(Set): a set of parameters' names. + """ + skip_params = set() + graph = GraphWrapper(program) + for op in graph.ops(): + if 'norm' in op.type() and 'grad' not in op.type(): + for input in op.all_inputs(): + skip_params.add(input.name()) + return skip_params + + def _should_prune_param(self, param): + should_prune = param not in self.skip_params + return should_prune diff --git a/tests/dygraph/test_unstructured_prune.py b/tests/dygraph/test_unstructured_prune.py new file mode 100644 index 00000000..5fe74dbe --- /dev/null +++ b/tests/dygraph/test_unstructured_prune.py @@ -0,0 +1,42 @@ +import sys +sys.path.append("../../") +import unittest +import paddle +import numpy as np +from paddleslim.dygraph.prune.unstructured_pruner import UnstructuredPruner +from paddle.vision.models import mobilenet_v1 + + +class TestUnstructuredPruner(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(TestUnstructuredPruner, self).__init__(*args, **kwargs) + paddle.disable_static() + self._gen_model() + + def _gen_model(self): + self.net = mobilenet_v1(num_classes=10, pretrained=False) + self.pruner = UnstructuredPruner( + self.net, mode='ratio', ratio=0.98, threshold=0.0) + + def test_prune(self): + ori_density = UnstructuredPruner.total_sparse(self.net) + ori_threshold = self.pruner.threshold + self.pruner.step() + self.net( + paddle.to_tensor( + np.random.uniform(0, 1, [16, 3, 32, 32]), dtype='float32')) + cur_density = UnstructuredPruner.total_sparse(self.net) + cur_threshold = self.pruner.threshold + print("Original threshold: {}".format(ori_threshold)) + print("Current threshold: {}".format(cur_threshold)) + print("Original density: {}".format(ori_density)) + print("Current density: {}".format(cur_density)) + self.assertLessEqual(ori_threshold, cur_threshold) + self.assertLessEqual(cur_density, ori_density) + + self.pruner.update_params() + self.assertEqual(cur_density, UnstructuredPruner.total_sparse(self.net)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_unstructured_pruner.py b/tests/test_unstructured_pruner.py new file mode 100644 index 00000000..cdfc68e8 --- /dev/null +++ b/tests/test_unstructured_pruner.py @@ -0,0 +1,86 @@ +import sys +sys.path.append("../") +import unittest +from static_case import StaticCase +import paddle.fluid as fluid +import paddle +from paddleslim.prune import UnstructuredPruner +from layers import conv_bn_layer +import numpy as np + + +class TestUnstructuredPruner(StaticCase): + def __init__(self, *args, **kwargs): + super(TestUnstructuredPruner, self).__init__(*args, **kwargs) + paddle.enable_static() + self._gen_model() + + def _gen_model(self): + self.main_program = paddle.static.default_main_program() + self.startup_program = paddle.static.default_startup_program() + # conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6 + # | ^ | ^ + # |____________| |____________________| + with paddle.static.program_guard(self.main_program, + self.startup_program): + input = paddle.static.data(name='image', shape=[None, 3, 16, 16]) + conv1 = conv_bn_layer(input, 8, 3, "conv1") + conv2 = conv_bn_layer(conv1, 8, 3, "conv2") + sum1 = conv1 + conv2 + conv3 = conv_bn_layer(sum1, 8, 3, "conv3") + conv4 = conv_bn_layer(conv3, 8, 3, "conv4") + sum2 = conv4 + sum1 + conv5 = conv_bn_layer(sum2, 8, 3, "conv5") + conv6 = conv_bn_layer(conv5, 8, 3, "conv6") + + conv7 = fluid.layers.conv2d_transpose( + input=conv6, num_filters=16, filter_size=2, stride=2) + + place = paddle.static.cpu_places()[0] + exe = paddle.static.Executor(place) + self.scope = paddle.static.global_scope() + exe.run(self.startup_program, scope=self.scope) + + self.pruner = UnstructuredPruner( + self.main_program, 16, 'ratio', scope=self.scope, place=place) + + def test_unstructured_prune(self): + for param in self.main_program.global_block().all_parameters(): + mask_name = param.name + "_mask" + mask_shape = self.scope.find_var(mask_name).get_tensor().shape() + self.assertTrue(tuple(mask_shape) == param.shape) + + def test_sparsity(self): + ori_density = UnstructuredPruner.total_sparse(self.main_program) + self.pruner.step() + cur_density = UnstructuredPruner.total_sparse(self.main_program) + cur_layer_density = self.pruner.sparse_by_layer(self.main_program) + print('original density: {}.'.format(ori_density)) + print('current density: {}.'.format(cur_density)) + total = 0 + non_zeros = 0 + for param in self.main_program.all_parameters(): + total += np.product(param.shape) + non_zeros += np.count_nonzero( + np.array(self.scope.find_var(param.name).get_tensor())) + self.assertEqual(cur_density, non_zeros / total) + self.assertLessEqual(cur_density, ori_density) + + self.pruner.update_params() + self.assertEqual(cur_density, + UnstructuredPruner.total_sparse(self.main_program)) + + def test_summarize_weights(self): + max_value = -float("inf") + threshold = self.pruner.summarize_weights(self.main_program, 1.0) + for param in self.main_program.global_block().all_parameters(): + max_value = max( + max_value, + np.max(np.array(self.scope.find_var(param.name).get_tensor()))) + print("The returned threshold is {}.".format(threshold)) + print("The max_value is {}.".format(max_value)) + self.assertEqual(max_value, threshold) + + +if __name__ == '__main__': + unittest.main() -- GitLab