Unverified · Commit 186960cc · authored by whs · committed by GitHub

Add demo for pruning dygraph (#538)

Parent: d216cdc9
# Pruning Demo
This demo shows how to use PaddleSlim to prune PaddlePaddle dygraph models.
The supported datasets are ImageNet1K and Cifar10; the supported models are the MobileNet and ResNet families of classification models.
## 1. Data preparation
### 1.1 ImageNet1K
Download link: http://www.image-net.org/challenges/LSVRC/2012/
After downloading, organize the data as follows:
```
PaddleClas/dataset/ILSVRC2012/
|_ train/
| |_ n01440764
| | |_ n01440764_10026.JPEG
| | |_ ...
| |_ ...
| |
| |_ n15075141
| |_ ...
| |_ n15075141_9993.JPEG
|_ val/
| |_ ILSVRC2012_val_00000001.JPEG
| |_ ...
| |_ ILSVRC2012_val_00050000.JPEG
|_ train_list.txt
|_ val_list.txt
```
The `train_list.txt` and `val_list.txt` files in the layout above look like this:
```
# delimiter: "space"
# content of train_list.txt
train/n01440764/n01440764_10026.JPEG 0
...
# content of val_list.txt
val/ILSVRC2012_val_00000001.JPEG 65
...
```
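If you need to build these list files yourself, a minimal sketch is shown below. It assumes the directory layout above and assigns labels by the sorted order of the synset folders (the usual ImageNet convention); this helper is not part of the demo, so adjust it if your label mapping differs.
```python
import os

# Hypothetical helper for generating train_list.txt; not part of this demo.
root = "PaddleClas/dataset/ILSVRC2012"
synsets = sorted(os.listdir(os.path.join(root, "train")))
label_map = {name: idx for idx, name in enumerate(synsets)}

with open(os.path.join(root, "train_list.txt"), "w") as f:
    for name in synsets:
        for fname in sorted(os.listdir(os.path.join(root, "train", name))):
            f.write("train/{}/{} {}\n".format(name, fname, label_map[name]))
```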
### 1.2 Cifar10
For `Cifar10`, this demo uses the data loading API provided by `paddle.vision.datasets.Cifar10`, which downloads the dataset automatically and caches it on the local file system, so you do not need to care about how the data is stored or formatted.
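The sketch below mirrors how the `cifar10` branch of `train.py` builds these datasets (the transform values come from that script):
```python
import paddle.vision.transforms as T
from paddle.vision.datasets import Cifar10

# Same preprocessing as the cifar10 branch of train.py.
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = Cifar10(mode="train", backend="cv2", transform=transform)
val_dataset = Cifar10(mode="test", backend="cv2", transform=transform)
```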
## 2. Pruning and training
Experience shows that pruning a model pretrained on the target task gives a better final result than pruning an untrained model, so this demo directly uses the `ImageNet1K` pretrained classification models provided by the `paddle.vision.models` module.
After the pretrained model is pruned, it must be retrained on the target dataset to recover the accuracy lost by pruning.
The `train.py` script implements both the pruning and the retraining step described above; run `python train.py --help` to see its configurable options.
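The sketch below condenses the pruning-related calls made inside `train.py` for the `resnet34` / `FPGM` / 25% configuration used in the commands that follow; it is illustrative only, and the full script also handles data, the optimizer and the training loop:
```python
import paddle
import paddle.vision.models as models
import paddleslim

# Load the ImageNet1K-pretrained model and build an FPGM filter pruner.
net = models.resnet34(pretrained=True)
net.eval()
pruner = paddleslim.dygraph.FPGMFilterPruner(net, [1, 3, 224, 224])

# Prune 25% of the filters (axis 0 of each weight) of every Conv2D layer.
ratios = {
    param.name: 0.25
    for layer in net.sublayers()
    for param in layer.parameters(include_sublayers=False)
    if isinstance(layer, paddle.nn.Conv2D)
}
plan = pruner.prune_vars(ratios, [0])
print("pruned FLOPs ratio:", plan.pruned_flops)
```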
### 2.1 Training on CPU or a single GPU
Run the following command to prune and retrain on a single GPU. The arguments mean: take the `resnet34` model pretrained on `ImageNet1K`, prune 25% of the `filters` from every convolution layer, and rank filter importance within each convolution using `FPGM`. The pruned model is then retrained for 120 epochs, and the model produced after each epoch is saved under `./fpgm_resnet34_025_120_models`.
```
python train.py \
--use_gpu=True \
--model="resnet34" \
--data="imagenet" \
--pruned_ratio=0.25 \
--num_epochs=120 \
--batch_size=256 \
--lr_strategy="cosine_decay" \
--criterion="fpgm" \
--model_path="./fpgm_resnet34_025_120_models"
```
To train on CPU only, change `--use_gpu` in the command above to `False`.
### 2.2 Multi-GPU training
The following command launches the same pruning and retraining task as section 2.1 on multiple GPUs. Note that `batch_size` is the total batch size across all cards; see the sketch after the command.
```
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch \
--gpus="0,1,2,3" \
--log_dir="fpgm_resnet34_f-42_train_log" \
train.py \
--use_gpu=True \
--model="resnet34" \
--data="imagenet" \
--pruned_ratio=0.25 \
--batch_size=256 \
--num_epochs=120 \
--lr_strategy="cosine_decay" \
--criterion="fpgm" \
--model_path="./fpgm_resnet34_025_120_models"
```
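Internally, `train.py` derives the per-card batch size from this total; a minimal sketch of that computation, with the `--batch_size` value hard-coded for illustration:
```python
from paddle.distributed import ParallelEnv

total_batch_size = 256  # value passed via --batch_size
per_card_batch_size = total_batch_size // ParallelEnv().nranks  # 64 with 4 cards
```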
### 2.3 Resuming training
Set the `checkpoint` option to resume training:
```
python train.py \
--use_gpu=True \
--model="resnet34" \
--data="imagenet" \
--pruned_ratio=0.25 \
--num_epochs=120 \
--batch_size=256 \
--lr_strategy="cosine_decay" \
--criterion="fpgm" \
--model_path="./fpgm_resnet34_025_120_models" \
--checkpoint="./fpgm_resnet34_025_120_models/0"
```
## 3. Evaluation
Use the `eval.py` script to evaluate the accuracy of the pruned and retrained model on the test data:
```
python eval.py \
--checkpoint=./fpgm_resnet34_025_120_models/1 \
--model="resnet34" \
--pruned_ratio=0.25
```
## 4. Exporting the model
Run the following command to export the model for inference:
```
python export_model.py \
--checkpoint=./fpgm_resnet34_025_120_models/1 \
--model="resnet34" \
--pruned_ratio=0.25 \
--output_path=./infer/resnet
```
As shown above, if `--output_path=./infer/resnet` is specified, three files are generated under `./infer`: `resnet.pdiparams`, `resnet.pdmodel` and `resnet.pdiparams.info`. These files can be loaded by Paddle Lite or Paddle Inference.
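A minimal sketch of loading the exported files with the standard Paddle Inference Python API (the input shape assumes the `resnet34` configuration above):
```python
import numpy as np
from paddle.inference import Config, create_predictor

# Build a predictor from the exported program and parameters.
config = Config("./infer/resnet.pdmodel", "./infer/resnet.pdiparams")
predictor = create_predictor(config)

# Feed one dummy image and fetch the classification output.
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
input_handle.copy_from_cpu(np.random.rand(1, 3, 224, 224).astype("float32"))
predictor.run()
output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
print(output_handle.copy_to_cpu().shape)  # expected: (1, 1000)
```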
## 5. Selected experimental results
| Model | Baseline accuracy (Top1/Top5) | FLOPs reduction | Pruned accuracy (Top1/Top5) | Script |
| ----------- | --------------------------- | ---------------- | --------------------------- | ------------------------------ |
| MobileNetV1 | 70.99/89.68 | -50% | 69.23/88.71 | fpgm_mobilenetv1_f-50_train.sh |
| MobileNetV2 | 72.15/90.65 | -50% | 67.00/87.56 | fpgm_mobilenetv2_f-50_train.sh |
| ResNet34 | 74.57/92.14 | -42% | 73.20/91.21 | fpgm_resnet34_f-42_train.sh |
#!/bin/bash
CUDA_VISIBLE_DEVICES=0,1 \
python -m paddle.distributed.launch \
--gpus="0,1" \
--log_dir="fpgm_mobilenetv1_train_log" \
train.py \
--model="mobilenet_v1" \
--data="imagenet" \
--pruned_ratio=0.3125 \
--lr=0.1 \
--num_epochs=120 \
--test_period=10 \
--step_epochs 30 60 90 \
--l2_decay=3e-5 \
--lr_strategy="piecewise_decay" \
--criterion="fpgm" \
--model_path="./fpgm_mobilenetv1_models"
#!/bin/bash
CUDA_VISIBLE_DEVICES=0,1 \
python -m paddle.distributed.launch \
--gpus="0,1" \
--log_dir="fpgm_mobilenetv2_train_log" \
train.py \
--model="mobilenet_v2" \
--data="imagenet" \
--pruned_ratio=0.325 \
--lr=0.001 \
--num_epochs=90 \
--test_period=5 \
--step_epochs 30 60 80 \
--l2_decay=1e-4 \
--lr_strategy="piecewise_decay" \
--criterion="fpgm" \
--model_path="./fpgm_mobilenetv2_models"
#!/bin/bash
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python -m paddle.distributed.launch \
--gpus="0,1,2,3" \
--log_dir="fpgm_resnet34_f-42_train_log" \
train.py \
--model="resnet34" \
--data="imagenet" \
--pruned_ratio=0.25 \
--num_epochs=120 \
--lr_strategy="cosine_decay" \
--criterion="fpgm" \
--model_path="./fpgm_resnet34_025_120_models"
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import math
import random
import numpy as np
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
class ImageNetDataset(DatasetFolder):
def __init__(self,
path,
mode='train',
image_size=224,
resize_short_size=256):
super(ImageNetDataset, self).__init__(path)
self.mode = mode
self.samples = []
list_file = "train_list.txt" if self.mode == "train" else "val_list.txt"
        with open(os.path.join(path, list_file), 'r') as f:
            for line in f:
                _image, _label = line.strip().split(" ")
                # store absolute image paths so __getitem__ works from any cwd
                self.samples.append((os.path.join(path, _image), int(_label)))
normalize = transforms.Normalize(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375])
if self.mode == 'train':
self.transform = transforms.Compose([
transforms.RandomResizedCrop(image_size),
transforms.RandomHorizontalFlip(), transforms.Transpose(),
normalize
])
else:
self.transform = transforms.Compose([
transforms.Resize(resize_short_size),
transforms.CenterCrop(image_size), transforms.Transpose(),
normalize
])
def __getitem__(self, idx):
img_path, label = self.samples[idx]
img = Image.open(img_path).convert('RGB')
label = np.array([label]).astype(np.int64)
return self.transform(img), label
def __len__(self):
return len(self.samples)
from __future__ import division
from __future__ import print_function
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np
sys.path[0] = os.path.join(
    os.path.dirname(__file__), os.path.pardir, os.path.pardir)
import paddleslim
from paddleslim.common import get_logger
from paddleslim.analysis import dygraph_flops as flops
import paddle.vision.models as models
from utility import add_arguments, print_arguments
import paddle.vision.transforms as T
from paddle.static import InputSpec as Input
from imagenet import ImageNetDataset
from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler
from paddle.distributed import ParallelEnv
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64 * 4, "Minibatch size.")
add_arg('model', str, "MobileNet", "The target model.")
add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('num_epochs', int, 120, "The number of total epochs.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'")
add_arg('log_period', int, 10, "Log period in batches.")
add_arg('test_period', int, 10, "Test period in epoches.")
add_arg('model_path', str, "./models", "The path to save model.")
add_arg('pruned_ratio', float, None, "The ratios to be pruned.")
add_arg('criterion', str, "l1_norm", "The pruning criterion to use; supports 'l1_norm' and 'fpgm'.")
add_arg('use_gpu', bool, True, "Whether to use GPU.")
add_arg('checkpoint', str, None, "The path of checkpoint which is used for resume training.")
# yapf: enable
model_list = models.__all__
def get_pruned_params(args, model):
params = []
if args.model == "mobilenet_v1":
skip_vars = ['linear_0.b_0',
'conv2d_0.w_0'] # skip the first conv2d and last linear
for sublayer in model.sublayers():
for param in sublayer.parameters(include_sublayers=False):
if isinstance(
sublayer, paddle.nn.Conv2D
) and sublayer._groups == 1 and param.name not in skip_vars:
params.append(param.name)
elif args.model == "mobilenet_v2":
for sublayer in model.sublayers():
for param in sublayer.parameters(include_sublayers=False):
if isinstance(sublayer, paddle.nn.Conv2D):
params.append(param.name)
return params
elif args.model == "resnet34":
for sublayer in model.sublayers():
for param in sublayer.parameters(include_sublayers=False):
if isinstance(sublayer, paddle.nn.Conv2D):
params.append(param.name)
return params
else:
raise NotImplementedError(
"Current demo only support for mobilenet_v1, mobilenet_v2, resnet34")
return params
def piecewise_decay(args, parameters, steps_per_epoch):
bd = [steps_per_epoch * e for e in args.step_epochs]
lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
learning_rate = paddle.optimizer.lr.PiecewiseDecay(boundaries=bd, values=lr)
optimizer = paddle.optimizer.Momentum(
learning_rate=learning_rate,
momentum=args.momentum_rate,
weight_decay=paddle.regularizer.L2Decay(args.l2_decay),
parameters=parameters)
return optimizer
def cosine_decay(args, parameters, steps_per_epoch):
learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(
learning_rate=args.lr, T_max=args.num_epochs * steps_per_epoch)
optimizer = paddle.optimizer.Momentum(
learning_rate=learning_rate,
momentum=args.momentum_rate,
weight_decay=paddle.regularizer.L2Decay(args.l2_decay),
parameters=parameters)
return optimizer
def create_optimizer(args, parameters, steps_per_epoch):
if args.lr_strategy == "piecewise_decay":
return piecewise_decay(args, parameters, steps_per_epoch)
elif args.lr_strategy == "cosine_decay":
return cosine_decay(args, parameters, steps_per_epoch)
def compress(args):
paddle.set_device('gpu' if args.use_gpu else 'cpu')
train_reader = None
test_reader = None
if args.data == "cifar10":
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.Cifar10(
mode="train", backend="cv2", transform=transform)
val_dataset = paddle.vision.datasets.Cifar10(
mode="test", backend="cv2", transform=transform)
class_dim = 10
image_shape = [3, 32, 32]
pretrain = False
elif args.data == "imagenet":
train_dataset = ImageNetDataset(
"data/ILSVRC2012",
mode='train',
image_size=224,
resize_short_size=256)
val_dataset = ImageNetDataset(
"data/ILSVRC2012",
mode='val',
image_size=224,
resize_short_size=256)
class_dim = 1000
image_shape = [3, 224, 224]
pretrain = True
else:
raise ValueError("{} is not supported.".format(args.data))
assert args.model in model_list, "{} is not in lists: {}".format(args.model,
model_list)
inputs = [Input([None] + image_shape, 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
# model definition
net = models.__dict__[args.model](pretrained=pretrain,
num_classes=class_dim)
_logger.info("FLOPs before pruning: {}GFLOPs".format(
flops(net, [1] + image_shape) / 1000))
net.eval()
if args.criterion == 'fpgm':
pruner = paddleslim.dygraph.FPGMFilterPruner(net, [1] + image_shape)
elif args.criterion == 'l1_norm':
pruner = paddleslim.dygraph.L1NormFilterPruner(net, [1] + image_shape)
params = get_pruned_params(args, net)
ratios = {}
for param in params:
ratios[param] = args.pruned_ratio
plan = pruner.prune_vars(ratios, [0])
_logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format(
flops(net, [1] + image_shape) / 1000, plan.pruned_flops))
for param in net.parameters():
if "conv2d" in param.name:
print(f"{param.name}\t{param.shape}")
net.train()
model = paddle.Model(net, inputs, labels)
steps_per_epoch = int(np.ceil(len(train_dataset) * 1. / args.batch_size))
opt = create_optimizer(args, net.parameters(), steps_per_epoch)
model.prepare(
opt, paddle.nn.CrossEntropyLoss(), paddle.metric.Accuracy(topk=(1, 5)))
if args.checkpoint is not None:
model.load(args.checkpoint)
model.fit(train_data=train_dataset,
eval_data=val_dataset,
epochs=args.num_epochs,
batch_size=args.batch_size // ParallelEnv().nranks,
verbose=1,
save_dir=args.model_path,
num_workers=8)
def main():
args = parser.parse_args()
print_arguments(args)
compress(args)
if __name__ == '__main__':
main()
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 图像分类模型通道剪裁-快速开始\n",
"\n",
"该教程以图像分类模型MobileNetV1为例,说明如何快速使用[PaddleSlim的卷积通道剪裁接口]()。\n",
"该示例包含以下步骤:\n",
"\n",
"1. 导入依赖\n",
"2. 构建模型\n",
"3. 剪裁\n",
"4. 训练剪裁后的模型\n",
"\n",
"以下章节依次次介绍每个步骤的内容。\n",
"\n",
"## 1. 导入依赖\n",
"\n",
"PaddleSlim依赖Paddle1.7版本,请确认已正确安装Paddle,然后按以下方式导入Paddle和PaddleSlim:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"import paddle.fluid as fluid\n",
"import paddleslim as slim"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. 构建网络\n",
"\n",
"该章节构造一个用于对MNIST数据进行分类的分类模型,选用`MobileNetV1`,并将输入大小设置为`[1, 28, 28]`,输出类别数为10。\n",
"为了方便展示示例,我们在`paddleslim.models`下预定义了用于构建分类模型的方法,执行以下代码构建分类模型:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"exe, train_program, val_program, inputs, outputs = slim.models.image_classification(\"MobileNet\", [1, 28, 28], 10, use_gpu=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
">注意:paddleslim.models下的API并非PaddleSlim常规API,是为了简化示例而封装预定义的一系列方法,比如:模型结构的定义、Program的构建等。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. 剪裁卷积层通道\n",
"\n",
"### 3.1 计算剪裁之前的FLOPs"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"FLOPs: 10907072.0\n"
]
}
],
"source": [
"FLOPs = slim.analysis.flops(train_program)\n",
"print(\"FLOPs: {}\".format(FLOPs))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.2 剪裁\n",
"\n",
"我们这里对参数名为`conv2_1_sep_weights`和`conv2_2_sep_weights`的卷积层进行剪裁,分别剪掉20%和30%的通道数。\n",
"代码如下所示:"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"pruner = slim.prune.Pruner()\n",
"pruned_program, _, _ = pruner.prune(\n",
" train_program,\n",
" fluid.global_scope(),\n",
" params=[\"conv2_1_sep_weights\", \"conv2_2_sep_weights\"],\n",
" ratios=[0.33] * 2,\n",
" place=fluid.CPUPlace())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"以上操作会修改`train_program`中对应卷积层参数的定义,同时对`fluid.global_scope()`中存储的参数数组进行裁剪。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.3 计算剪裁之后的FLOPs"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"FLOPs: 10907072.0\n"
]
}
],
"source": [
"FLOPs = paddleslim.analysis.flops(train_program)\n",
"print(\"FLOPs: {}\".format(FLOPs))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. 训练剪裁后的模型\n",
"\n",
"### 4.1 定义输入数据\n",
"\n",
"为了快速执行该示例,我们选取简单的MNIST数据,Paddle框架的`paddle.dataset.mnist`包定义了MNIST数据的下载和读取。\n",
"代码如下:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"import paddle.dataset.mnist as reader\n",
"train_reader = paddle.fluid.io.batch(\n",
" reader.train(), batch_size=128, drop_last=True)\n",
"train_feeder = fluid.DataFeeder(inputs, fluid.CPUPlace())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.2 执行训练\n",
"以下代码执行了一个`epoch`的训练:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.1484375] [0.4921875] [2.6727316]\n",
"[0.125] [0.546875] [2.6547904]\n",
"[0.125] [0.5546875] [2.795205]\n",
"[0.1171875] [0.578125] [2.8561475]\n",
"[0.1875] [0.59375] [2.470603]\n",
"[0.1796875] [0.578125] [2.8031898]\n",
"[0.1484375] [0.6015625] [2.7530417]\n",
"[0.1953125] [0.640625] [2.711596]\n",
"[0.125] [0.59375] [2.8637898]\n",
"[0.1796875] [0.53125] [2.9473038]\n",
"[0.25] [0.671875] [2.3943179]\n",
"[0.25] [0.6953125] [2.632146]\n",
"[0.2578125] [0.7265625] [2.723265]\n",
"[0.359375] [0.765625] [2.4263484]\n",
"[0.3828125] [0.8203125] [2.226284]\n",
"[0.421875] [0.8203125] [1.8042578]\n",
"[0.4765625] [0.890625] [1.6841211]\n",
"[0.53125] [0.8671875] [2.1971617]\n",
"[0.5546875] [0.8984375] [1.5361531]\n",
"[0.53125] [0.890625] [1.7211896]\n",
"[0.5078125] [0.8984375] [1.6586945]\n",
"[0.53125] [0.9140625] [1.8980236]\n",
"[0.546875] [0.9453125] [1.5279069]\n",
"[0.5234375] [0.8828125] [1.7356458]\n",
"[0.6015625] [0.9765625] [1.0375824]\n",
"[0.5546875] [0.921875] [1.639497]\n",
"[0.6015625] [0.9375] [1.5469061]\n",
"[0.578125] [0.96875] [1.3573356]\n",
"[0.65625] [0.9453125] [1.3787829]\n",
"[0.640625] [0.9765625] [0.9946856]\n",
"[0.65625] [0.96875] [1.1651027]\n",
"[0.625] [0.984375] [1.0487883]\n",
"[0.7265625] [0.9609375] [1.2526855]\n",
"[0.7265625] [0.9765625] [1.2954011]\n",
"[0.65625] [0.96875] [1.1181556]\n",
"[0.71875] [0.9765625] [0.97891223]\n",
"[0.640625] [0.9609375] [1.2135172]\n",
"[0.7265625] [0.9921875] [0.8950747]\n",
"[0.7578125] [0.96875] [1.0864108]\n",
"[0.734375] [0.9921875] [0.8392239]\n",
"[0.796875] [0.9609375] [0.7012155]\n",
"[0.7734375] [0.9765625] [0.7409136]\n",
"[0.8046875] [0.984375] [0.6108341]\n",
"[0.796875] [0.9765625] [0.63867176]\n",
"[0.7734375] [0.984375] [0.64099216]\n",
"[0.7578125] [0.9453125] [0.83827704]\n",
"[0.8046875] [0.9921875] [0.5311729]\n",
"[0.8984375] [0.9921875] [0.36445504]\n",
"[0.859375] [0.9921875] [0.40577835]\n",
"[0.8125] [0.9765625] [0.64629185]\n",
"[0.84375] [1.] [0.38400555]\n",
"[0.890625] [0.9765625] [0.45866236]\n",
"[0.8828125] [0.9921875] [0.3711415]\n",
"[0.7578125] [0.9921875] [0.6650479]\n",
"[0.7578125] [0.984375] [0.9030752]\n",
"[0.8671875] [0.9921875] [0.3678714]\n",
"[0.7421875] [0.9765625] [0.7424855]\n",
"[0.7890625] [1.] [0.6212543]\n",
"[0.8359375] [1.] [0.58529043]\n",
"[0.8203125] [0.96875] [0.5860813]\n",
"[0.8671875] [0.9921875] [0.415236]\n",
"[0.8125] [1.] [0.60501564]\n",
"[0.796875] [0.9765625] [0.60677457]\n",
"[0.8515625] [1.] [0.5338207]\n",
"[0.8046875] [0.9921875] [0.54180473]\n",
"[0.875] [0.9921875] [0.7293667]\n",
"[0.84375] [0.9765625] [0.5581689]\n",
"[0.8359375] [1.] [0.50712734]\n",
"[0.8671875] [0.9921875] [0.55217856]\n",
"[0.765625] [0.96875] [0.8076792]\n",
"[0.953125] [1.] [0.17031987]\n",
"[0.890625] [0.9921875] [0.42383268]\n",
"[0.828125] [0.9765625] [0.49300486]\n",
"[0.8671875] [0.96875] [0.57985115]\n",
"[0.8515625] [1.] [0.4901033]\n",
"[0.921875] [1.] [0.34583277]\n",
"[0.8984375] [0.984375] [0.41139168]\n",
"[0.9296875] [1.] [0.20420414]\n",
"[0.921875] [0.984375] [0.24322833]\n",
"[0.921875] [0.9921875] [0.30570173]\n",
"[0.875] [0.9921875] [0.3866225]\n",
"[0.9140625] [0.9921875] [0.20813875]\n",
"[0.9140625] [1.] [0.17933217]\n",
"[0.8984375] [0.9921875] [0.32508463]\n",
"[0.9375] [1.] [0.24799153]\n",
"[0.9140625] [1.] [0.26146784]\n",
"[0.90625] [1.] [0.24672262]\n",
"[0.8828125] [1.] [0.34094217]\n",
"[0.90625] [1.] [0.2964819]\n",
"[0.9296875] [1.] [0.18237087]\n",
"[0.84375] [1.] [0.7182543]\n",
"[0.8671875] [0.984375] [0.508474]\n",
"[0.8828125] [0.9921875] [0.367172]\n",
"[0.9453125] [1.] [0.2366665]\n",
"[0.9375] [1.] [0.12494276]\n",
"[0.8984375] [1.] [0.3395289]\n",
"[0.890625] [0.984375] [0.30877113]\n",
"[0.90625] [1.] [0.29763448]\n",
"[0.8828125] [0.984375] [0.4845504]\n",
"[0.8515625] [1.] [0.45548072]\n",
"[0.8828125] [1.] [0.33331633]\n",
"[0.90625] [1.] [0.4024018]\n",
"[0.890625] [0.984375] [0.73405886]\n",
"[0.9609375] [0.9921875] [0.15409982]\n",
"[0.9140625] [0.984375] [0.37103674]\n",
"[0.953125] [1.] [0.17628372]\n",
"[0.890625] [1.] [0.36522508]\n",
"[0.8828125] [1.] [0.407708]\n",
"[0.9375] [0.984375] [0.25090045]\n",
"[0.890625] [0.984375] [0.35742313]\n",
"[0.921875] [0.9921875] [0.2751101]\n",
"[0.890625] [0.984375] [0.43053097]\n",
"[0.875] [0.9921875] [0.34412643]\n",
"[0.90625] [1.] [0.35595697]\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-21-92f72657bddc>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtrain_reader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0macc1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0macc5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mexe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpruned_program\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrain_feeder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeed\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0macc1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0macc5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.5/dist-packages/paddle/fluid/executor.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, program, feed, fetch_list, feed_var_name, fetch_var_name, scope, return_numpy, use_program_cache)\u001b[0m\n\u001b[1;32m 776\u001b[0m \u001b[0mscope\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mscope\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 777\u001b[0m \u001b[0mreturn_numpy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreturn_numpy\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 778\u001b[0;31m use_program_cache=use_program_cache)\n\u001b[0m\u001b[1;32m 779\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 780\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mEOFException\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.5/dist-packages/paddle/fluid/executor.py\u001b[0m in \u001b[0;36m_run_impl\u001b[0;34m(self, program, feed, fetch_list, feed_var_name, fetch_var_name, scope, return_numpy, use_program_cache)\u001b[0m\n\u001b[1;32m 829\u001b[0m \u001b[0mscope\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mscope\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 830\u001b[0m \u001b[0mreturn_numpy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreturn_numpy\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 831\u001b[0;31m use_program_cache=use_program_cache)\n\u001b[0m\u001b[1;32m 832\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 833\u001b[0m \u001b[0mprogram\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscope\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplace\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.5/dist-packages/paddle/fluid/executor.py\u001b[0m in \u001b[0;36m_run_program\u001b[0;34m(self, program, feed, fetch_list, feed_var_name, fetch_var_name, scope, return_numpy, use_program_cache)\u001b[0m\n\u001b[1;32m 903\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0muse_program_cache\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 904\u001b[0m self._default_executor.run(program.desc, scope, 0, True, True,\n\u001b[0;32m--> 905\u001b[0;31m fetch_var_name)\n\u001b[0m\u001b[1;32m 906\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 907\u001b[0m self._default_executor.run_prepared_ctx(ctx, scope, False, False,\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"for data in train_reader():\n",
" acc1, acc5, loss = exe.run(pruned_program, feed=train_feeder.feed(data), fetch_list=outputs)\n",
" print(acc1, acc5, loss)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
@@ -13,6 +13,7 @@ from paddleslim.common import get_logger
 from paddleslim.analysis import flops
 import models
 from utility import add_arguments, print_arguments
+import paddle.vision.transforms as T
 _logger = get_logger(__name__, level=logging.INFO)
@@ -99,8 +100,11 @@ def compress(args):
     train_reader = None
     test_reader = None
     if args.data == "mnist":
-        train_dataset = paddle.vision.datasets.MNIST(mode='train')
-        val_dataset = paddle.vision.datasets.MNIST(mode='test')
+        transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
+        train_dataset = paddle.vision.datasets.MNIST(
+            mode='train', backend="cv2", transform=transform)
+        val_dataset = paddle.vision.datasets.MNIST(
+            mode='test', backend="cv2", transform=transform)
         class_dim = 10
         image_shape = "1,28,28"
     elif args.data == "imagenet":
@@ -231,6 +235,7 @@ def compress(args):
         ratios=[args.pruned_ratio] * len(params),
         place=place)
     _logger.info("FLOPs after pruning: {}".format(flops(pruned_program)))
     for i in range(args.num_epochs):
         train(i, pruned_program)
         if i % args.test_period == 0:
@@ -20,4 +20,7 @@ from paddleslim import analysis
 from paddleslim import dist
 from paddleslim import quant
 from paddleslim import pantheon
-__all__ = ['models', 'prune', 'nas', 'analysis', 'dist', 'quant', 'pantheon']
+from paddleslim import dygraph
+__all__ = [
+    'models', 'prune', 'nas', 'analysis', 'dist', 'quant', 'pantheon', 'dygraph'
+]
@@ -84,6 +84,7 @@ def _graph_flops(graph, only_conv=True, detail=False):
 def dygraph_flops(model, input_shape, only_conv=False, detail=False):
     data = np.ones(tuple(input_shape)).astype("float32")
     in_var = paddle.to_tensor(data)
     _, traced = paddle.jit.TracedLayer.trace(model, [in_var])
@@ -310,22 +310,32 @@ class FilterPruner(Pruner):
         group = self.var_group.find_group(var_name, pruned_dims)
         _logger.debug("found group with {}: {}".format(var_name, group))
         plan = PruningPlan(self.model.full_name)
+        group_dict = {}
         for sub_layer in self.model.sublayers():
             for param in sub_layer.parameters(include_sublayers=False):
                 if param.name in group:
-                    group[param.name]['layer'] = sub_layer
-                    group[param.name]['var'] = param
-                    group[param.name]['value'] = np.array(param.value()
-                                                          .get_tensor())
+                    group_dict[param.name] = group[param.name]
+                    group_dict[param.name].update({
+                        'layer': sub_layer,
+                        'var': param,
+                        'value': np.array(param.value().get_tensor())
+                    })
                     _logger.debug(f"set value of {param.name} into group")
-        mask = self.cal_mask(var_name, pruned_ratio, group)
-        for _name in group:
-            dims = group[_name]['pruned_dims']
+        mask = self.cal_mask(var_name, pruned_ratio, group_dict)
+        for _name in group_dict:
+            dims = group_dict[_name]['pruned_dims']
+            stride = group_dict[_name]['stride']
+            var_shape = group_dict[_name]['var'].shape
             if isinstance(dims, int):
                 dims = [dims]
-            plan.add(_name, PruningMask(dims, mask, pruned_ratio))
+            current_mask = mask.repeat(stride[0]) if stride[0] > 1 else mask
+            assert len(current_mask) == var_shape[dims[
+                0]], "The length of current_mask must be equal to the size of dimension to be pruned on."
+            plan.add(_name, PruningMask(dims, current_mask, pruned_ratio))
         if apply == "lazy":
             plan.apply(self.model, lazy=True)
         elif apply == "impretive":
@@ -26,10 +26,9 @@ class PruningMask():
                 "The dims of PruningMask must be instance of collections.Iterator."
             )
         if self._mask is not None:
-            assert (
-                len(self._mask.shape) == len(value),
-                "The length of value must be same with shape of mask in current PruningMask instance."
-            )
+            assert len(self._mask.shape) == len(
+                value
+            ), "The length of value must be same with shape of mask in current PruningMask instance."
         self._dims = list(value)
     @property
@@ -40,10 +39,9 @@ class PruningMask():
     def mask(self, value):
         assert (isinstance(value, PruningMask))
         if self._dims is not None:
-            assert (
-                len(self._mask.shape) == len(value),
-                "The length of value must be same with shape of mask in current PruningMask instance."
-            )
+            assert len(self._mask.shape) == len(
+                value
+            ), "The length of value must be same with shape of mask in current PruningMask instance."
         self._mask = value
     def __str__(self):
@@ -158,11 +156,10 @@ class PruningPlan():
                 for _mask in self._masks[param.name]:
                     dims = _mask.dims
                     mask = _mask.mask
-                    assert (
-                        len(dims) == 1,
-                        "Imperative mode only support for pruning"
-                        "on one dimension, but get dims {} when pruning parameter {}".
-                        format(dims, param.name))
+                    assert len(
+                        dims
+                    ) == 1, "Imperative mode only support for pruning on one dimension, but get dims {} when pruning parameter {}".format(
+                        dims, param.name)
                     t_value = param.value().get_tensor()
                     value = np.array(t_value).astype("float32")
                     # The name of buffer can not contains "."
@@ -188,9 +185,10 @@ class PruningPlan():
                         t_value.set(pruned_value, place)
                         if isinstance(sub_layer, paddle.nn.layer.conv.Conv2D):
-                            if sub_layer._groups > 1:
+                            if sub_layer._groups > 1 and pruned_value.shape[
+                                    1] == 1:  # depthwise conv2d
                                 _logger.debug(
-                                    "Update groups of conv form {} to {}".
+                                    "Update groups of depthwise conv2d form {} to {}".
                                     format(sub_layer._groups,
                                            pruned_value.shape[0]))
                                 sub_layer._groups = pruned_value.shape[0]
@@ -18,15 +18,15 @@ class VarGroup():
     def _to_dict(self, group):
         ret = {}
-        for _name, _axis, _idx in group:
+        for _name, _axis, _stride in group:
             if isinstance(_axis, int):
                 _axis = [_axis]  # TODO: fix
-            ret[_name] = {'pruned_dims': _axis, 'pruned_idx': _idx}
+            ret[_name] = {'pruned_dims': _axis, 'stride': _stride}
         return ret
     def find_group(self, var_name, axis):
         for group in self.groups:
-            for _name, _axis, _ in group:
+            for _name, _axis, _stride in group:
                 if isinstance(_axis, int):
                     _axis = [_axis]  # TODO: fix
                 if _name == var_name and _axis == axis:
@@ -36,6 +36,7 @@ class VarGroup():
         _logger.debug("Parsing model with input: {}".format(input_shape))
         data = np.ones(tuple(input_shape)).astype("float32")
         in_var = paddle.to_tensor(data)
+        model.eval()
         out_dygraph, static_layer = TracedLayer.trace(model, inputs=[in_var])
         graph = GraphWrapper(static_layer.program)
@@ -45,7 +46,7 @@ class VarGroup():
                 visited)[0]  # [(name, axis, pruned_idx)]
             if len(group) > 0:
                 self.groups.append(group)
-        _logger.debug("Found {} groups.".format(len(self.groups)))
+        _logger.info("Found {} groups.".format(len(self.groups)))
     def __str__(self):
         return "\n".join([str(group) for group in self.groups])
@@ -587,6 +587,20 @@ class mul(PruneWorker):
             self._prune_op(op, param_var, 0, pruned_idx)
+@PRUNE_WORKER.register
+class matmul(PruneWorker):
+    def __init__(self, op, pruned_params, visited):
+        super(matmul, self).__init__(op, pruned_params, visited)
+
+    def _prune(self, var, pruned_axis, pruned_idx):
+        if var in self.op.inputs("X") and pruned_axis == 1:
+            param_var = self.op.inputs("Y")[0]
+            self.pruned_params.append((param_var, 0, pruned_idx))
+            for op in param_var.outputs():
+                self._prune_op(op, param_var, 0, pruned_idx)
+
 @PRUNE_WORKER.register
 class scale(PruneWorker):
     def __init__(self, op, pruned_params, visited={}):
@@ -657,3 +671,33 @@ class affine_channel(PruneWorker):
         next_ops = out_var.outputs()
         for op in next_ops:
             self._prune_op(op, out_var, pruned_axis, pruned_idx)
+@PRUNE_WORKER.register
+class flatten_contiguous_range(PruneWorker):
+    def __init__(self, op, pruned_params, visited):
+        super(flatten_contiguous_range, self).__init__(op, pruned_params,
+                                                        visited)
+
+    def _prune(self, var, pruned_axis, pruned_idx):
+        start_axis = self.op.attr("start_axis")
+        stop_axis = self.op.attr("stop_axis")
+        if var in self.op.inputs("X"):
+            out_var = self.op.outputs("Out")[0]
+            in_var = self.op.inputs("X")[0]
+            stride = 1
+            out_pruned_axis = pruned_axis
+            out_pruned_idx = pruned_idx
+            if pruned_axis >= start_axis and pruned_axis <= stop_axis:
+                out_pruned_axis = start_axis
+                for i in range(pruned_axis + 1, stop_axis + 1):
+                    stride *= in_var.shape()[i]
+            elif pruned_axis > stop_axis:
+                out_pruned_axis = start_axis + pruned_axis - stop_axis
+
+            self._visit(in_var, pruned_axis)
+            self._visit(out_var, out_pruned_axis)
+            next_ops = out_var.outputs()
+            for op in next_ops:
+                self._prune_op(op, out_var, out_pruned_axis, [stride])