提交 bb0f8fbb 编写于 作者: W wanghaoshuang

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleSlim into base

......@@ -16,13 +16,6 @@ import imagenet_reader
_logger = get_logger(__name__, level=logging.INFO)
reduce_rate = 0.85
init_temperature = 10.24
max_flops = 321208544
server_address = ""
port = 8979
retain_epoch = 5
def create_data_loader(image_shape):
data_shape = [None] + image_shape
......@@ -71,17 +64,13 @@ def search_mobilenetv2_block(config, args, image_size):
if args.is_server:
sa_nas = SANAS(
config,
server_addr=("", port),
init_temperature=init_temperature,
reduce_rate=reduce_rate,
server_addr=(args.server_address, args.port),
search_steps=args.search_steps,
is_server=True)
else:
sa_nas = SANAS(
config,
server_addr=(server_address, port),
init_temperature=init_temperature,
reduce_rate=reduce_rate,
server_addr=(args.server_address, args.port),
search_steps=args.search_steps,
is_server=False)
......@@ -140,7 +129,7 @@ def search_mobilenetv2_block(config, args, image_size):
current_flops = flops(train_program)
print('step: {}, current_flops: {}'.format(step, current_flops))
if current_flops > max_flops:
if current_flops > int(321208544):
continue
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
......@@ -178,7 +167,7 @@ def search_mobilenetv2_block(config, args, image_size):
train_compiled_program = fluid.CompiledProgram(
train_program).with_data_parallel(
loss_name=avg_cost.name, build_strategy=build_strategy)
for epoch_id in range(retain_epoch):
for epoch_id in range(args.retain_epoch):
for batch_id, data in enumerate(train_loader()):
fetches = [avg_cost.name]
s_time = time.time()
......@@ -243,6 +232,11 @@ if __name__ == '__main__':
type=int,
default=100,
help='controller server number.')
parser.add_argument(
'--server_address', type=str, default="", help='server ip.')
parser.add_argument('--port', type=int, default=8881, help='server port')
parser.add_argument(
'--retain_epoch', type=int, default=5, help='epoch for each token.')
parser.add_argument('--lr', type=float, default=0.1, help='learning rate.')
args = parser.parse_args()
print(args)
......@@ -257,7 +251,7 @@ if __name__ == '__main__':
args.data))
# block_mask has one entry per block: 1 means the block downsamples, 0 means the feature map size does not change after this block
config_info = {'block_mask': [0, 1, 1, 1, 1, 0, 1, 0]}
config_info = {'block_mask': [0, 1, 1, 1, 0]}
config = [('MobileNetV2BlockSpace', config_info)]
search_mobilenetv2_block(config, args, image_size)
......@@ -18,13 +18,6 @@ import imagenet_reader
_logger = get_logger(__name__, level=logging.INFO)
reduce_rate = 0.85
init_temperature = 10.24
max_flops = 321208544
server_address = ""
port = 8989
retain_epoch = 5
def create_data_loader(image_shape):
data_shape = [None] + image_shape
......@@ -66,18 +59,14 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
### start a server and a client
sa_nas = SANAS(
config,
server_addr=("", port),
init_temperature=init_temperature,
reduce_rate=reduce_rate,
server_addr=(args.server_address, args.port),
search_steps=args.search_steps,
is_server=True)
else:
### start a client
sa_nas = SANAS(
config,
server_addr=(server_address, port),
init_temperature=init_temperature,
reduce_rate=reduce_rate,
server_addr=(args.server_address, args.port),
search_steps=args.search_steps,
is_server=False)
......@@ -93,7 +82,7 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
current_flops = flops(train_program)
print('step: {}, current_flops: {}'.format(step, current_flops))
if current_flops > max_flops:
if current_flops > int(321208544):
continue
test_loader, test_avg_cost, test_acc_top1, test_acc_top5 = build_program(
......@@ -139,7 +128,7 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
train_compiled_program = fluid.CompiledProgram(
train_program).with_data_parallel(
loss_name=avg_cost.name, build_strategy=build_strategy)
for epoch_id in range(retain_epoch):
for epoch_id in range(args.retain_epoch):
for batch_id, data in enumerate(train_loader()):
fetches = [avg_cost.name]
s_time = time.time()
......@@ -179,7 +168,7 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
def test_search_result(tokens, image_size, args, config):
sa_nas = SANAS(
config,
server_addr=("", 8887),
server_addr=(args.server_address, args.port),
init_temperature=args.init_temperature,
reduce_rate=args.reduce_rate,
search_steps=args.search_steps,
......@@ -234,7 +223,7 @@ def test_search_result(tokens, image_size, args, config):
train_compiled_program = fluid.CompiledProgram(
train_program).with_data_parallel(
loss_name=avg_cost.name, build_strategy=build_strategy)
for epoch_id in range(retain_epoch):
for epoch_id in range(args.retain_epoch):
for batch_id, data in enumerate(train_loader()):
fetches = [avg_cost.name]
s_time = time.time()
......@@ -298,6 +287,11 @@ if __name__ == '__main__':
type=int,
default=100,
help='controller server number.')
parser.add_argument(
'--server_address', type=str, default="", help='server ip.')
parser.add_argument('--port', type=int, default=8881, help='server port')
parser.add_argument(
'--retain_epoch', type=int, default=5, help='epoch for each token.')
parser.add_argument('--lr', type=float, default=0.1, help='learning rate.')
args = parser.parse_args()
print(args)
......
......@@ -80,8 +80,8 @@ def run(args):
student.start()
if args.test_send_recv:
for t in xrange(2):
for i in xrange(3):
for t in range(2):
for i in range(3):
print(student.recv(t))
student.send("message from student!")
......
......@@ -17,7 +17,20 @@
1). 根据分类模型中[ImageNet数据准备文档](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E6%95%B0%E6%8D%AE%E5%87%86%E5%A4%87)下载数据到`PaddleSlim/demo/data/ILSVRC2012`路径下。
2). 使用`train.py`脚本时,指定`--data`选项为`imagenet`.
## 2. 启动剪裁任务
## 2. 下载预训练模型
如果使用`ImageNet`数据,建议在预训练模型的基础上进行剪裁,请从[分类库](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E5%B7%B2%E5%8F%91%E5%B8%83%E6%A8%A1%E5%9E%8B%E5%8F%8A%E5%85%B6%E6%80%A7%E8%83%BD)中下载合适的预训练模型。
这里以`MobileNetV1`为例,下载并解压预训练模型到当前路径:
```
wget http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar
tar -xf MobileNetV1_pretrained.tar
```
使用`train.py`脚本时,指定`--pretrained_model`加载预训练模型。
## 3. 启动剪裁任务
通过以下命令启动裁剪任务:
......@@ -25,8 +38,8 @@
export CUDA_VISIBLE_DEVICES=0
python train.py \
--model "MobileNet" \
--pruned_ratio 0.33 \
--data "imagenet"
--pruned_ratio 0.31 \
--data "mnist"
```
其中,`model`用于指定待裁剪的模型。`pruned_ratio`用于指定各个卷积层通道数被裁剪的比例。`data`选项用于指定使用的数据集。
......@@ -35,7 +48,7 @@ python train.py \
在本示例中,会在日志中输出剪裁前后的`FLOPs`,并且每训练一轮就会保存一个模型到文件系统。
## 3. 加载和评估模型
## 4. 加载和评估模型
本节介绍如何加载训练过程中保存的模型。
......@@ -43,14 +56,14 @@ python train.py \
```
python eval.py \
--model "mobilenet" \
--model "MobileNet" \
--data "mnist" \
--model_path "./models/0"
```
在脚本`eval.py`中,使用`paddleslim.prune.load_model`接口加载剪裁得到的模型。
## 4. 接口介绍
## 5. 接口介绍
该示例使用了`paddleslim.Pruner`工具类,用户接口使用介绍请参考:[API文档](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/)
......
......@@ -68,7 +68,7 @@ def eval(args):
val_feeder = feeder = fluid.DataFeeder(
[image, label], place, program=val_program)
load_model(val_program, "./model/mobilenetv1_prune_50")
load_model(exe, val_program, args.model_path)
batch_id = 0
acc_top1_ns = []
......
......@@ -8,6 +8,7 @@ import math
import time
import numpy as np
import paddle.fluid as fluid
sys.path.append("../../")
from paddleslim.prune import Pruner, save_model
from paddleslim.common import get_logger
from paddleslim.analysis import flops
......@@ -37,6 +38,7 @@ add_arg('log_period', int, 10, "Log period in batches.")
add_arg('test_period', int, 10, "Test period in epoches.")
add_arg('model_path', str, "./models", "The path to save model.")
add_arg('pruned_ratio', float, None, "The ratios to be pruned.")
add_arg('criterion', str, "l1_norm", "The prune criterion to be used, support l1_norm and batch_norm_scale.")
# yapf: enable
model_list = models.__all__
......@@ -136,6 +138,8 @@ def compress(args):
return os.path.exists(
os.path.join(args.pretrained_model, var.name))
_logger.info("Load pretrained model from {}".format(
args.pretrained_model))
fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
val_reader = paddle.batch(val_reader, batch_size=args.batch_size)
......@@ -200,10 +204,12 @@ def compress(args):
end_time - start_time))
batch_id += 1
test(0, val_program)
params = get_pruned_params(args, fluid.default_main_program())
_logger.info("FLOPs before pruning: {}".format(
flops(fluid.default_main_program())))
pruner = Pruner()
pruner = Pruner(args.criterion)
pruned_val_program, _, _ = pruner.prune(
val_program,
fluid.global_scope(),
......
此差异已折叠。
......@@ -128,7 +128,7 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火
- **tokens(list):** - 一组tokens。tokens的长度和范围取决于搜索空间。
**返回:**
根据传入的token得到一个模型结构实例。
根据传入的token得到一个模型结构实例列表
**示例代码:**
......@@ -153,8 +153,10 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火
**示例代码:**
.. code-block:: python
import paddle.fluid as fluid
from paddleslim.nas import SANAS
config = [('MobileNetV2Space')]
sanas = SANAS(configs=config)
print(sanas.current_info())
# 多进程蒸馏
# 大规模可扩展知识蒸馏框架 Pantheon
## Teacher
......@@ -100,7 +100,8 @@ pantheon.Teacher.start\_knowledge\_service(feed\_list, schema, program, reader\_
- **times (int):** The maximum number of repeated serving rounds, default 1. Each time
the public method **get\_knowledge\_generator()** of a **Student**
object is called, the serving count is incremented by one,
until reaching the maximum and ending the service.
until reaching the maximum and ending the service. Only
valid in online mode, and will be ignored in offline mode.
**Return:** None
......
......@@ -378,7 +378,7 @@ load_sensitivities
}
}
sensitivities_file = "sensitive_api_demo.data"
with open(sensitivities_file, 'w') as f:
with open(sensitivities_file, 'wb') as f:
pickle.dump(sen, f)
sensitivities = load_sensitivities(sensitivities_file)
print(sensitivities)
......
# 模型库
## 1. 图分类
## 1. 图分类
数据集:ImageNet1000类
......@@ -16,7 +16,7 @@
| MobileNetV2 | quant_aware |72.05%/90.63% (-0.1%/-0.02%)| 4.0 | - | [下载链接](https://paddlemodels.bj.bcebos.com/PaddleSlim/MobileNetV2_quant_aware.tar) |
|ResNet50|-|76.50%/93.00%| 99 | 2.71 | [下载链接](http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar) |
|ResNet50|quant_post|76.33%/93.02% (-0.17%/+0.02%)| 25.1| 1.19 | [下载链接](https://paddlemodels.bj.bcebos.com/PaddleSlim/ResNet50_quant_post.tar) |
|ResNet50|quant_aware| 76.48%/93.11% (-0.02%/+0.11%)| 25.1 | 1.17 | [下载链接](https://paddlemodels.bj.bcebos.com/PaddleSlim/ResNet50_quant_awre.tar) |
|ResNet50|quant_aware| 76.48%/93.11% (-0.02%/+0.11%)| 25.1 | 1.17 | [下载链接](https://paddlemodels.bj.bcebos.com/PaddleSlim/ResNet50_quant_awre.tar) |
分类模型Lite时延(ms)
......@@ -89,6 +89,12 @@
<a name="trans1">[1]</a>:带_vd后缀代表该预训练模型使用了Mixup,Mixup相关介绍参考[mixup: Beyond Empirical Risk Minimization](https://arxiv.org/abs/1710.09412)
### 1.4 搜索
| 模型 | 压缩方法 | Top-1/Top-5 Acc | 模型体积(MB) | GFLOPs | 下载 |
|:--:|:---:|:--:|:--:|:--:|:--:|
| MobileNetV2 | - | 72.15%/90.65% | 15 | 0.59 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_pretrained.tar) |
| MobileNetV2 | SANAS | 71.518%/90.208% (-0.632%/-0.442%) | 14 | 0.295 | [下载链接](https://paddlemodels.cdn.bcebos.com/PaddleSlim/MobileNetV2_sanas.tar) |
## 2. 目标检测
......@@ -99,8 +105,8 @@
| 模型 | 压缩方法 | 数据集 | Image/GPU | 输入608 Box AP | 输入416 Box AP | 输入320 Box AP | 模型体积(MB) | TensorRT时延(V100, ms) | 下载 |
| :----------------------------: | :---------: | :----: | :-------: | :------------: | :------------: | :------------: | :------------: | :----------: |:----------: |
| MobileNet-V1-YOLOv3 | - | COCO | 8 | 29.3 | 29.3 | 27.1 | 95 | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1.tar) |
| MobileNet-V1-YOLOv3 | quant_post | COCO | 8 | 27.9 (-1.4)| 28.0 (-1.3) | 26.0 (-1.0) | 25 | - | [下载链接](https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenetv1_coco_quant_post.tar) |
| MobileNet-V1-YOLOv3 | quant_aware | COCO | 8 | 28.1 (-1.2)| 28.2 (-1.1) | 25.8 (-1.2) | 26.3 | - | [下载链接](https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenet_coco_quant_aware.tar) |
| MobileNet-V1-YOLOv3 | quant_post | COCO | 8 | 27.9 (-1.4)| 28.0 (-1.3) | 26.0 (-1.0) | 25 | - | [下载链接](https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenetv1_coco_quant_post.tar) |
| MobileNet-V1-YOLOv3 | quant_aware | COCO | 8 | 28.1 (-1.2)| 28.2 (-1.1) | 25.8 (-1.2) | 26.3 | - | [下载链接](https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenet_coco_quant_aware.tar) |
| R34-YOLOv3 | - | COCO | 8 | 36.2 | 34.3 | 31.4 | 162 | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34.tar) |
| R34-YOLOv3 | quant_post | COCO | 8 | 35.7 (-0.5) | - | - | 42.7 | - | [下载链接](https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_r34_coco_quant_post.tar) |
| R34-YOLOv3 | quant_aware | COCO | 8 | 35.2 (-1.0) | 33.3 (-1.0) | 30.3 (-1.1)| 44 | - | [下载链接](https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_r34_coco_quant_aware.tar) |
......@@ -157,6 +163,20 @@
| MobileNet-V1-YOLOv3 | ResNet34-YOLOv3 distill | COCO | 8 | 31.4 (+2.1) | 30.0 (+0.7) | 27.1 (+0.1) | 95 | [下载链接](https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenetv1_coco_distilled.tar) |
### 2.4 搜索
数据集:WIDER-FACE
| 模型 | 压缩方法 | Image/GPU | 输入尺寸 | Easy/Medium/Hard | 模型体积(KB) | 硬件延时(ms)| 下载 |
| :------------: | :---------: | :-------: | :------: | :-----------------------------: | :------------: | :------------: | :----------------------------------------------------------: |
| BlazeFace | - | 8 | 640 | 91.5/89.2/79.7 | 815 | 71.862 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/blazeface_original.tar) |
| BlazeFace-NAS | - | 8 | 640 | 83.7/80.7/65.8 | 244 | 21.117 |[下载链接](https://paddlemodels.bj.bcebos.com/object_detection/blazeface_nas.tar) |
| BlazeFace-NAS1 | SANAS | 8 | 640 | 87.0/83.7/68.5 | 389 | 22.558 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/blazeface_nas2.tar) |
!!! note "Note"
<a name="trans1">[1]</a>: 硬件延时时间是利用提供的硬件延时表得到的,硬件延时表是在855芯片上基于PaddleLite测试的结果。
## 3. 图像分割
数据集:Cityscapes
......
......@@ -19,4 +19,5 @@ from paddleslim import nas
from paddleslim import analysis
from paddleslim import dist
from paddleslim import quant
__all__ = ['models', 'prune', 'nas', 'analysis', 'dist', 'quant']
from paddleslim import pantheon
__all__ = ['models', 'prune', 'nas', 'analysis', 'dist', 'quant', 'pantheon']
......@@ -190,7 +190,10 @@ class SANAS(object):
self._iter = 0
def _get_host_ip(self):
return socket.gethostbyname(socket.gethostname())
if os.name == 'posix':
return socket.gethostbyname('localhost')
else:
return socket.gethostbyname(socket.gethostname())
def tokens2arch(self, tokens):
"""
......
......@@ -13,7 +13,7 @@ The illustration below shows an application of Pantheon, where the student model
## Prerequisites
- Python 2.7.x or 3.x
- PaddlePaddle >= 1.6.0
- PaddlePaddle >= 1.7.0
## APIs
......
......@@ -158,7 +158,7 @@ class Student(object):
if end_recved:
break
with open(in_path, 'r') as fin:
with open(in_path, 'rb') as fin:
# get knowledge desc
desc = pickle.load(fin)
out_queue.put(desc)
......@@ -222,7 +222,7 @@ class Student(object):
self._started = True
def _merge_knowledge(self, knowledge):
for k, tensors in knowledge.items():
for k, tensors in list(knowledge.items()):
if len(tensors) == 0:
del knowledge[k]
elif len(tensors) == 1:
......@@ -308,7 +308,7 @@ class Student(object):
print("Knowledge merging strategy: {}".format(
self._merge_strategy))
print("Knowledge description after merging:")
for schema, desc in knowledge_desc.items():
for schema, desc in list(knowledge_desc.items()):
print("{}: {}".format(schema, desc))
self._knowledge_desc = knowledge_desc
......@@ -426,13 +426,13 @@ class Student(object):
end_received = [0] * len(queues)
while True:
knowledge = OrderedDict(
[(k, []) for k, v in self._knowledge_desc.items()])
[(k, []) for k, v in list(self._knowledge_desc.items())])
for idx, receiver in enumerate(data_receivers):
if not end_received[idx]:
batch_samples = receiver.next(
) if six.PY2 else receiver.__next__()
if not isinstance(batch_samples, EndSignal):
for k, v in batch_samples.items():
for k, v in list(batch_samples.items()):
knowledge[k].append(v)
else:
end_received[idx] = 1
......
......@@ -151,7 +151,7 @@ class Teacher(object):
self._t2s_queue = None
self._cmd_queue = None
self._out_file = open(self._out_path, "w") if self._out_path else None
self._out_file = open(self._out_path, "wb") if self._out_path else None
if self._out_file:
return
......@@ -231,7 +231,7 @@ class Teacher(object):
"The knowledge data should be a dict or OrderedDict!")
knowledge_desc = {}
for name, value in knowledge.items():
for name, value in list(knowledge.items()):
knowledge_desc[name] = {
"shape": [-1] + list(value.shape[1:]),
"dtype": str(value.dtype),
......@@ -294,7 +294,8 @@ class Teacher(object):
times (int): The maximum repeated serving times. Default 1. Whenever
the public method 'get_knowledge_generator()' in Student
object called once, the serving times will be added one,
until reaching the maximum and ending the service.
until reaching the maximum and ending the service. Only
valid in online mode, and will be ignored in offline mode.
"""
if not self._started:
raise ValueError("The method start() should be called first!")
......@@ -339,9 +340,12 @@ class Teacher(object):
if not times > 0:
raise ValueError("Repeated serving times should be positive!")
self._times = times
if self._times > 1 and self._out_file:
self._times = 1
print("WARNING: args 'times' will be ignored in offline mode")
desc = {}
for name, var in schema.items():
for name, var in list(schema.items()):
if not isinstance(var, fluid.framework.Variable):
raise ValueError(
"The member of schema must be fluid Variable.")
......@@ -412,10 +416,14 @@ class Teacher(object):
else:
if self._knowledge_queue:
self._knowledge_queue.put(EndSignal())
# should close file in child thread to wait for all
# writing finished
if self._out_file:
self._out_file.close()
# Asynchronous output
out_buf_queue = Queue.Queue(self._buf_size)
schema_keys, schema_vars = zip(*self._schema.items())
schema_keys, schema_vars = zip(*list(self._schema.items()))
out_thread = Thread(target=writer, args=(out_buf_queue, schema_keys))
out_thread.daemon = True
out_thread.start()
......@@ -424,8 +432,9 @@ class Teacher(object):
self._program).with_data_parallel()
print("Knowledge description {}".format(self._knowledge_desc))
print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
" Teacher begins to serve ...")
print(
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
" Teacher begins to serve ...")
# For offline dump, write the knowledge description to the head of file
if self._out_file:
self._out_file.write(pickle.dumps(self._knowledge_desc))
......@@ -491,11 +500,10 @@ class Teacher(object):
if self._knowledge_queue:
self._knowledge_queue.join()
print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
" Teacher ends serving.")
print(
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
" Teacher ends serving.")
def __del__(self):
if self._manager:
self._manager.shutdown()
if self._out_file:
self._out_file.close()
......@@ -13,6 +13,7 @@
# limitations under the License.
import logging
import sys
import numpy as np
import paddle.fluid as fluid
import copy
......@@ -79,8 +80,8 @@ class Pruner():
pruned_num = int(round(param_v.shape()[0] * ratio))
pruned_idx = [0] * pruned_num
else:
param_t = np.array(scope.find_var(param).get_tensor())
pruned_idx = self._cal_pruned_idx(param_t, ratio, axis=0)
pruned_idx = self._cal_pruned_idx(
graph, scope, param, ratio, axis=0)
param = graph.var(param)
conv_op = param.outputs()[0]
walker = conv2d_walker(
......@@ -130,7 +131,7 @@ class Pruner():
graph.infer_shape()
return graph.program, param_backup, param_shape_backup
def _cal_pruned_idx(self, param, ratio, axis):
def _cal_pruned_idx(self, graph, scope, param, ratio, axis):
"""
Calculate the index to be pruned on axis by given pruning ratio.
......@@ -145,11 +146,26 @@ class Pruner():
Returns:
list<int>: The indexes to be pruned on axis.
"""
prune_num = int(round(param.shape[axis] * ratio))
reduce_dims = [i for i in range(len(param.shape)) if i != axis]
if self.criterion == 'l1_norm':
criterions = np.sum(np.abs(param), axis=tuple(reduce_dims))
pruned_idx = criterions.argsort()[:prune_num]
param_t = np.array(scope.find_var(param).get_tensor())
prune_num = int(round(param_t.shape[axis] * ratio))
reduce_dims = [i for i in range(len(param_t.shape)) if i != axis]
criterions = np.sum(np.abs(param_t), axis=tuple(reduce_dims))
pruned_idx = criterions.argsort()[:prune_num]
elif self.criterion == "batch_norm_scale":
param_var = graph.var(param)
conv_op = param_var.outputs()[0]
conv_output = conv_op.outputs("Output")[0]
bn_op = conv_output.outputs()[0]
if bn_op is not None:
bn_scale_param = bn_op.inputs("Scale")[0].name()
bn_scale_np = np.array(
scope.find_var(bn_scale_param).get_tensor())
prune_num = int(round(bn_scale_np.shape[axis] * ratio))
pruned_idx = np.abs(bn_scale_np).argsort()[:prune_num]
else:
raise SystemExit(
"Can't find BatchNorm op after Conv op in Network.")
return pruned_idx
def _prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from paddleslim.prune import Pruner
from layers import conv_bn_layer
class TestPrune(unittest.TestCase):
    """Check channel pruning driven by the ``batch_norm_scale`` criterion."""

    def test_prune(self):
        """Prune half of ``conv4_weights`` and verify the conv weight shapes.

        Builds a six-convolution network with two residual sums, prunes
        ``conv4_weights`` with ratio 0.5 using ``Pruner('batch_norm_scale')``,
        and asserts the shape of every parameter whose name contains
        ``"weights"`` against a hard-coded expectation table.
        """
        main_program = fluid.Program()
        startup_program = fluid.Program()
        #   X       X              O       X              O
        # conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
        #     |            ^ |                    ^
        #     |____________| |____________________|
        #
        # X: prune output channels
        # O: prune input channels
        with fluid.program_guard(main_program, startup_program):
            image = fluid.data(name="image", shape=[None, 3, 16, 16])
            conv1 = conv_bn_layer(image, 8, 3, "conv1")
            conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
            sum1 = conv1 + conv2
            conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
            conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
            sum2 = conv4 + sum1
            conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
            # conv6 is never read, but the call still adds the layer to the
            # program, so it must stay.
            conv_bn_layer(conv5, 8, 3, "conv6")

        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        scope = fluid.Scope()
        exe.run(startup_program, scope=scope)

        criterion = 'batch_norm_scale'
        pruner = Pruner(criterion)
        main_program, _, _ = pruner.prune(
            main_program,
            scope,
            params=["conv4_weights"],
            ratios=[0.5],
            place=place,
            lazy=False,
            only_graph=False,
            param_backup=None,
            param_shape_backup=None)

        # Plain int literals: the Python 2 only ``4L`` long suffix is a
        # SyntaxError on Python 3, and ints compare equal to longs on
        # Python 2, so this table works on both interpreters.
        expected_shapes = {
            "conv1_weights": (4, 3, 3, 3),
            "conv2_weights": (4, 4, 3, 3),
            "conv3_weights": (8, 4, 3, 3),
            "conv4_weights": (4, 8, 3, 3),
            "conv5_weights": (8, 4, 3, 3),
            "conv6_weights": (8, 8, 3, 3),
        }

        for param in main_program.global_block().all_parameters():
            if "weights" in param.name:
                print("param: {}; param shape: {}".format(param.name,
                                                          param.shape))
                # assertEqual reports both values on failure.
                self.assertEqual(param.shape, expected_shapes[param.name])


if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册