Commit e720f617 authored by ceci3

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleSlim into develop
# PaddleSlim
As a submodule of the PaddlePaddle framework, PaddleSlim is an open-source library for deep-model compression and architecture search. PaddleSlim supports popular compression techniques such as pruning, quantization, and knowledge distillation, and it also automates the search for hyperparameters and the design of lightweight deep architectures. In the future, we will develop more practical compression techniques for industrial applications and transfer these techniques to NLP models.
## Outline
- Key Features
- Architecture of PaddleSlim
- Methods
- Experimental Results
## Key Features
The key features of PaddleSlim are:
### Simple APIs
- It provides simple APIs for building and deploying lightweight, energy-efficient deep models on different platforms. Experimental hyperparameters can be set in a simple configuration file.
- Compressing a model requires only a small amount of code.
### Outstanding Performance
- Even for a model with limited redundancy, such as MobileNetV1, channel-based pruning can achieve lossless compression.
- Knowledge distillation improves baseline models by a clear margin.
- Quantization applied after knowledge distillation reduces model size while increasing accuracy.
### Flexible APIs
- The pruning process is automated.
- The pruning strategy can be applied to a wide range of deep architectures.
- Multiple kinds of knowledge can be distilled from teacher models to student models, and user-defined losses for the corresponding knowledge distillation are supported.
- Multiple compression strategies can be deployed in combination.
## Architecture of PaddleSlim
To make PaddleSlim easier to understand and use, we briefly introduce how the library is implemented.
The architecture of PaddleSlim is shown in **Figure 1**. The high-level APIs often depend on several low-level APIs; knowledge distillation, quantization, and pruning are all built, indirectly, on the Paddle framework. PaddleSlim currently ships as part of PaddlePaddle, so after downloading and installing the Paddle framework, users can apply it to model compression and architecture search.
<p align="center">
<img src="docs/images/framework_0.png" height=452 width=900 hspace='10'/> <br />
<strong>Figure 1</strong>
</p>
As shown in **Figure 1**, the top-level module, marked in yellow, is the API exposed to users. To deploy compression methods in Python, we only need to construct an instance of Compressor.
We encapsulate each compression and search method in a compression strategy class. When the deep model to be compressed is trained, the strategy classes are instantiated from the configuration information registered by users, as shown in **Figure 2**. The logic of the training process is encapsulated in our compression method; the jobs left to users are to define the structure of the deep model, prepare the training data, and choose an optimization strategy. This saves users considerable effort. A sketch of this workflow follows **Figure 2**.
<p align="center">
<img src="docs/images/framework_1.png" height=255 width=646 hspace='10'/> <br />
<strong>Figure 2</strong>
</p>
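As a concrete illustration of this workflow, the sketch below follows the Compressor pattern from PaddlePaddle's built-in slim module; the import path, constructor arguments, and the `compress.yaml` file are assumptions for illustration, not part of this repository:

```python
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.core import Compressor  # assumed import path

# `train_program`, readers, feeds, and fetches are defined by the user;
# the compression strategies themselves are registered in a YAML config.
place = fluid.CUDAPlace(0)
compressor = Compressor(
    place,
    fluid.global_scope(),
    train_program,
    train_reader=train_reader,
    train_feed_list=train_feeds,
    train_fetch_list=[loss],
    eval_program=eval_program,
    eval_reader=eval_reader,
    eval_feed_list=eval_feeds,
    eval_fetch_list=[accuracy])
compressor.config('./compress.yaml')  # hypothetical strategy configuration file
compressor.run()
```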
## Methods
### Pruning
- PaddleSlim supports uniform pruning, sensitivity-based pruning, and automated model pruning methods.
- PaddleSlim supports pruning of various deep architectures such as VGG, ResNet, and MobileNet.
- PaddleSlim supports a user-defined pruning range, i.e., which layers to prune (a usage sketch follows this list).
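A minimal sketch of the pruning APIs used in the demos later on this page; `main_program`, `params`, and `place` are assumed to come from the user's training script:

```python
import paddle.fluid as fluid
from paddleslim.analysis import flops
from paddleslim.prune import Pruner

print("FLOPs before pruning: {}".format(flops(main_program)))
pruner = Pruner()
pruned_program, _, _ = pruner.prune(
    main_program,
    fluid.global_scope(),
    params=params,                # e.g. the '*_sep_weights' of MobileNetV1
    ratios=[0.33] * len(params),  # prune 33% of the channels of each layer
    place=place)
print("FLOPs after pruning: {}".format(flops(pruned_program)))
```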
### Quantization
- PaddleSlim supports training-aware quantization with both static and dynamic estimation of quantization hyperparameters such as the scale:
  - Dynamic strategy: during inference, quantization hyperparameters are estimated dynamically from small batches of samples.
  - Static strategy: during inference, quantization hyperparameters are fixed to values estimated from the training data.
- PaddleSlim supports layer-wise and channel-wise quantization.
- PaddleSlim provides models compatible with Paddle Mobile for final inference (a usage sketch follows this list).
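The demos at the end of this page use the following pattern for quantization-aware training; this is a minimal sketch of it, where `train_prog`, `val_program`, and `place` are assumed to exist:

```python
from paddleslim.quant import quant_aware, convert

quant_config = {
    'weight_quantize_type': 'channel_wise_abs_max',
    'activation_quantize_type': 'moving_average_abs_max',
}
# insert fake quantize/dequantize ops into the training and eval graphs
val_quant_prog = quant_aware(val_program, place, quant_config, scope=None, for_test=True)
train_quant_prog = quant_aware(train_prog, place, quant_config, scope=None, for_test=False)

# ... train as usual, then freeze the eval graph for deployment:
float_prog, int8_prog = convert(val_quant_prog, place, quant_config, scope=None, save_int8=True)
```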
### Knowledge Distillation
- PaddleSlim supports the following losses, which can be added between any paired layers of teacher and student models (a usage sketch follows this list):
  - Flow of the solution procedure (FSP) loss.
  - L2 loss.
  - Softmax with cross-entropy loss.
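A minimal sketch of how these losses are wired up with the distillation APIs documented later on this page; the program and variable names are assumptions taken from the demo code:

```python
import paddle.fluid as fluid
from paddleslim.dist import merge, l2_loss

# `teacher_program`, `student_program`, `s_startup`, `avg_cost`, and `place`
# are assumed to come from the user's training script.
data_name_map = {'image': 'image'}  # teacher feed name -> student feed name
merge(teacher_program, student_program, data_name_map, place)

with fluid.program_guard(student_program, s_startup):
    # distillation loss between a paired teacher/student layer
    distill_loss = l2_loss("teacher_fc_0.tmp_0", "fc_0.tmp_0", student_program)
    loss = avg_cost + distill_loss
    fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9).minimize(loss)
```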
### Lightweight Network Architecture Search (Light-NAS)
- PaddleSlim provides a Simulated Annealing (SA)-based lightweight network architecture search method (a usage sketch follows this list).
- PaddleSlim supports distributed search.
- PaddleSlim supports FLOPs- and latency-constrained search.
- PaddleSlim supports latency estimation on different hardware and platforms.
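A minimal sketch of the SA-based search loop using the SANAS API documented later on this page; `train_and_eval` is an assumed user-provided routine that trains a candidate and returns its reward (e.g. validation accuracy):

```python
from paddleslim.nas import SANAS

sanas = SANAS(
    configs=[('MobileNetV2Space')],
    server_addr=("", 8881),  # empty ip -> local machine
    init_temperature=100,
    reduce_rate=0.85,
    search_steps=300,
    is_server=True)

for step in range(300):
    archs = sanas.next_archs()      # candidate architectures for this step
    reward = train_and_eval(archs)  # assumed user function
    sanas.reward(float(reward))     # report the score back to the searcher
```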
## Experimental Results
In this section, we present experimental results obtained with PaddleSlim.
### Quantization
We evaluate the quantized models on ImageNet2012. The top-5/top-1 accuracies are compared below:

| Model | FP32 | int8 (A:abs_max, W:abs_max) | int8 (A:moving_average_abs_max, W:abs_max) | int8 (A:abs_max, W:channel_wise_abs_max) |
|:---|:---:|:---:|:---:|:---:|
|MobileNetV1|89.54%/70.91%|89.64%/71.01%|89.58%/70.86%|89.75%/71.13%|
|ResNet50|92.80%/76.35%|93.12%/76.77%|93.07%/76.65%|93.15%/76.80%|

Model sizes before and after quantization:

| Model | FP32 | int8 (A:abs_max, W:abs_max) | int8 (A:moving_average_abs_max, W:abs_max) | int8 (A:abs_max, W:channel_wise_abs_max) |
| :--- | :---: | :---: | :---: | :---: |
| MobileNetV1 | 17M | 4.8M(-71.76%) | 4.9M(-71.18%) | 4.9M(-71.18%) |
| ResNet50 | 99M | 26M(-73.74%) | 27M(-72.73%) | 27M(-72.73%) |

Note: abs_max denotes the dynamic strategy, moving_average_abs_max denotes the static strategy, and channel_wise_abs_max denotes channel-wise quantization of the weights in convolutional layers.
### Pruning
Data: ImageNet2012
Baseline model: MobileNetV1
Model size: 17M
Top-5/top-1 accuracies: 89.54% / 70.91%
#### Uniform pruning
| FLOPs | Model size | Decrease in accuracy (top5/top1) | Accuracy (top5/top1) |
|---|---|---|---|
| -50%|-47.0%(9.0M)|-0.41% / -1.08%|88.92% / 69.66%|
| -60%|-55.9%(7.5M)|-1.34% / -2.67%|88.22% / 68.24%|
| -70%|-65.3%(5.9M)|-2.55% / -4.34%|86.99% / 66.57%|
#### Sensitivity-based pruning
| FLOPs | Accuracy (top5/top1) |
|---|---|
| -0% |89.54% / 70.91% |
| -20% |90.08% / 71.48% |
| -36% |89.62% / 70.83%|
| -50% |88.77% / 69.31%|
### Knowledge distillation
Data: ImageNet2012
Baseline model: MobileNetV1

| - | Accuracy (top5/top1) | Gain in accuracy (top5/top1) |
|---|---|---|
| Train from scratch | 89.54% / 70.91%| - |
| Distilled from ResNet50 | 90.92% / 71.97%| +1.28% / +1.06%|
### Hybrid methods
Data: ImageNet2012
Baseline model: MobileNetV1

| Methods | Accuracy (top5/top1) | Model size |
|---|---|---|
| Baseline|89.54% / 70.91%|17.0M|
| Distilled from ResNet50|90.92% / 71.97%|17.0M|
| Distilled from ResNet50 + Quantization |90.94% / 72.01%|4.8M|
| Pruning -50% FLOPS|89.13% / 69.83%|9.0M|
| Pruning -50% FLOPS + Quantization|89.11% / 69.20%|2.3M|
### Light-NAS
Data: ImageNet2012

| - | FLOPs | Top-1/Top-5 accuracy | GPU cost |
|------------------|-------|--------------------|----------------------|
| MobileNetV2 | 0% | 71.90% / 90.55% | - |
| Light-NAS-model0 | -3% | 72.45% / 90.70% | 1.2K GPU hours(V100) |
| Light-NAS-model1 | -17% | 71.84% / 90.45% | 1.2K GPU hours(V100) |

**Hardware-aware latency-constrained Light-NAS**

| - | Latency | Top-1/Top-5 accuracy | GPU cost |
|---------------|---------|--------------------|---------------------|
| MobileNetV2 | 0% | 71.90% / 90.55% | - |
| RK3288 | -23% | 71.97% / 90.35% | 1.2K GPU hours(V100) |
| Android cellphone | -20% | 72.06% / 90.36% | 1.2K GPU hours(V100) |
| iPhone 6s | -17% | 72.22% / 90.47% | 1.2K GPU hours(V100) |
......
```diff
@@ -150,7 +150,9 @@ def compress(args):
     # print(v.name, v.shape)
     exe.run(t_startup)
-    _download('http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar', '.')
+    _download(
+        'http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar',
+        '.')
     _decompress('./ResNet50_pretrained.tar')
     assert args.teacher_pretrained_model and os.path.exists(
         args.teacher_pretrained_model
@@ -168,21 +170,17 @@ def compress(args):
         predicate=if_exist)

     data_name_map = {'image': 'image'}
-    main = merge(
-        teacher_program,
-        student_program,
-        data_name_map,
-        place)
+    merge(teacher_program, student_program, data_name_map, place)

-    with fluid.program_guard(main, s_startup):
-        l2_loss = l2_loss("teacher_fc_0.tmp_0", "fc_0.tmp_0", main)
+    with fluid.program_guard(student_program, s_startup):
+        l2_loss = l2_loss("teacher_fc_0.tmp_0", "fc_0.tmp_0", student_program)
         loss = avg_cost + l2_loss
         opt = create_optimizer(args)
         opt.minimize(loss)
     exe.run(s_startup)

     build_strategy = fluid.BuildStrategy()
     build_strategy.fuse_all_reduce_ops = False
-    parallel_main = fluid.CompiledProgram(main).with_data_parallel(
+    parallel_main = fluid.CompiledProgram(student_program).with_data_parallel(
         loss_name=loss.name, build_strategy=build_strategy)

     for epoch_id in range(args.num_epochs):
@@ -190,9 +188,7 @@ def compress(args):
             loss_1, loss_2, loss_3 = exe.run(
                 parallel_main,
                 feed=data,
-                fetch_list=[
-                    loss.name, avg_cost.name, l2_loss.name
-                ])
+                fetch_list=[loss.name, avg_cost.name, l2_loss.name])
             if step_id % args.log_period == 0:
                 _logger.info(
                     "train_epoch {} step {} loss {:.6f}, class loss {:.6f}, l2 loss {:.6f}".
```
......
```diff
 from .mobilenet import MobileNet
 from .resnet import ResNet34, ResNet50
 from .mobilenet_v2 import MobileNetV2
+from .pvanet import PVANet

-__all__ = ['MobileNet', 'ResNet34', 'ResNet50', 'MobileNetV2']
+__all__ = ['MobileNet', 'ResNet34', 'ResNet50', 'MobileNetV2', 'PVANet']
```
The commit also adds a PVANet model definition (`pvanet.py`):

```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import MSRA
import os, sys, time, math
import numpy as np
from collections import namedtuple

BLOCK_TYPE_MCRELU = 'BLOCK_TYPE_MCRELU'
BLOCK_TYPE_INCEP = 'BLOCK_TYPE_INCEP'
BlockConfig = namedtuple('BlockConfig',
'stride, num_outputs, preact_bn, block_type')
__all__ = ['PVANet']
class PVANet():
def __init__(self):
pass
def net(self, input, include_last_bn_relu=True, class_dim=1000):
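        # Backbone: a C.ReLU stem (conv1), two mCReLU residual stages
        # (conv2, conv3), and two Inception stages (conv4, conv5); the
        # intermediate feature maps are recorded in end_points for reuse
        # by detection heads.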
conv1 = self._conv_bn_crelu(input, 16, 7, stride=2, name="conv1_1")
pool1 = fluid.layers.pool2d(
input=conv1,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max',
name='pool1')
end_points = {}
conv2 = self._conv_stage(
pool1,
block_configs=[
BlockConfig(1, (24, 24, 48), False, BLOCK_TYPE_MCRELU),
BlockConfig(1, (24, 24, 48), True, BLOCK_TYPE_MCRELU),
BlockConfig(1, (24, 24, 48), True, BLOCK_TYPE_MCRELU)
],
name='conv2',
end_points=end_points)
conv3 = self._conv_stage(
conv2,
block_configs=[
BlockConfig(2, (48, 48, 96), True, BLOCK_TYPE_MCRELU),
BlockConfig(1, (48, 48, 96), True, BLOCK_TYPE_MCRELU),
BlockConfig(1, (48, 48, 96), True, BLOCK_TYPE_MCRELU),
BlockConfig(1, (48, 48, 96), True, BLOCK_TYPE_MCRELU)
],
name='conv3',
end_points=end_points)
conv4 = self._conv_stage(
conv3,
block_configs=[
BlockConfig(2, '64 48-96 24-48-48 96 128', True,
BLOCK_TYPE_INCEP),
BlockConfig(1, '64 64-96 24-48-48 128', True,
BLOCK_TYPE_INCEP),
BlockConfig(1, '64 64-96 24-48-48 128', True,
BLOCK_TYPE_INCEP),
BlockConfig(1, '64 64-96 24-48-48 128', True, BLOCK_TYPE_INCEP)
],
name='conv4',
end_points=end_points)
conv5 = self._conv_stage(
conv4,
block_configs=[
BlockConfig(2, '64 96-128 32-64-64 128 196', True,
BLOCK_TYPE_INCEP),
BlockConfig(1, '64 96-128 32-64-64 196', True,
BLOCK_TYPE_INCEP),
BlockConfig(1, '64 96-128 32-64-64 196', True,
BLOCK_TYPE_INCEP), BlockConfig(
1, '64 96-128 32-64-64 196', True,
BLOCK_TYPE_INCEP)
],
name='conv5',
end_points=end_points)
if include_last_bn_relu:
conv5 = self._bn(conv5, 'relu', 'conv5_4_last_bn')
end_points['conv5'] = conv5
        # classification head on the final backbone feature map
        output = fluid.layers.fc(input=conv5,
                                 size=class_dim,
                                 act='softmax',
                                 param_attr=ParamAttr(
                                     initializer=MSRA(), name="fc_weights"),
                                 bias_attr=ParamAttr(name="fc_offset"))
return output
def _conv_stage(self, input, block_configs, name, end_points):
net = input
for idx, bc in enumerate(block_configs):
if bc.block_type == BLOCK_TYPE_MCRELU:
block_scope = '{}_{}'.format(name, idx + 1)
fn = self._mCReLU
elif bc.block_type == BLOCK_TYPE_INCEP:
block_scope = '{}_{}_incep'.format(name, idx + 1)
fn = self._inception_block
net = fn(net, bc, block_scope)
end_points[block_scope] = net
end_points[name] = net
return net
def _mCReLU(self, input, mc_config, name):
"""
every cReLU has at least three conv steps:
conv_bn_relu, conv_bn_crelu, conv_bn_relu
if the inputs has a different number of channels as crelu output,
an extra 1x1 conv is added before sum.
"""
if mc_config.preact_bn:
conv1_fn = self._bn_relu_conv
conv1_scope = name + '_1'
else:
conv1_fn = self._conv
conv1_scope = name + '_1_conv'
sub_conv1 = conv1_fn(input, mc_config.num_outputs[0], 1, conv1_scope,
mc_config.stride)
sub_conv2 = self._bn_relu_conv(sub_conv1, mc_config.num_outputs[1], 3,
name + '_2')
sub_conv3 = self._bn_crelu_conv(sub_conv2, mc_config.num_outputs[2], 1,
name + '_3')
if int(input.shape[1]) == mc_config.num_outputs[2]:
conv_proj = input
else:
conv_proj = self._conv(input, mc_config.num_outputs[2], 1,
name + '_proj', mc_config.stride)
conv = sub_conv3 + conv_proj
return conv
def _inception_block(self, input, block_config, name):
num_outputs = block_config.num_outputs.split() # e.g. 64 24-48-48 128
        num_outputs = [list(map(int, s.split('-'))) for s in num_outputs]
inception_outputs = num_outputs[-1][0]
num_outputs = num_outputs[:-1]
stride = block_config.stride
pool_path_outputs = None
if stride > 1:
pool_path_outputs = num_outputs[-1][0]
num_outputs = num_outputs[:-1]
scopes = [['_0']] # follow the name style of caffe pva
kernel_sizes = [[1]]
for path_idx, path_outputs in enumerate(num_outputs[1:]):
path_idx += 1
path_scopes = ['_{}_reduce'.format(path_idx)]
path_scopes.extend([
'_{}_{}'.format(path_idx, i - 1)
for i in range(1, len(path_outputs))
])
scopes.append(path_scopes)
path_kernel_sizes = [1, 3, 3][:len(path_outputs)]
kernel_sizes.append(path_kernel_sizes)
paths = []
if block_config.preact_bn:
preact = self._bn(input, 'relu', name + '_bn')
else:
preact = input
path_params = zip(num_outputs, scopes, kernel_sizes)
for path_idx, path_param in enumerate(path_params):
path_net = preact
for conv_idx, (num_output, scope,
kernel_size) in enumerate(zip(*path_param)):
if conv_idx == 0:
conv_stride = stride
else:
conv_stride = 1
path_net = self._conv_bn_relu(path_net, num_output,
kernel_size, name + scope,
conv_stride)
paths.append(path_net)
if stride > 1:
path_net = fluid.layers.pool2d(
input,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max',
name=name + '_pool')
path_net = self._conv_bn_relu(path_net, pool_path_outputs, 1,
name + '_poolproj')
paths.append(path_net)
block_net = fluid.layers.concat(paths, axis=1)
block_net = self._conv(block_net, inception_outputs, 1,
name + '_out_conv')
if int(input.shape[1]) == inception_outputs:
proj = input
else:
proj = self._conv(input, inception_outputs, 1, name + '_proj',
stride)
return block_net + proj
def _scale(self, input, name, axis=1, num_axes=1):
assert num_axes == 1, "layer scale not support this num_axes[%d] now" % (
num_axes)
prefix = name + '_'
scale_shape = input.shape[axis:axis + num_axes]
param_attr = fluid.ParamAttr(name=prefix + 'gamma')
scale_param = fluid.layers.create_parameter(
shape=scale_shape,
dtype=input.dtype,
name=name,
attr=param_attr,
is_bias=True,
default_initializer=fluid.initializer.Constant(value=1.0))
offset_attr = fluid.ParamAttr(name=prefix + 'beta')
offset_param = fluid.layers.create_parameter(
shape=scale_shape,
dtype=input.dtype,
name=name,
attr=offset_attr,
is_bias=True,
default_initializer=fluid.initializer.Constant(value=0.0))
output = fluid.layers.elementwise_mul(
input, scale_param, axis=axis, name=prefix + 'mul')
output = fluid.layers.elementwise_add(
output, offset_param, axis=axis, name=prefix + 'add')
return output
def _conv(self,
input,
num_filters,
filter_size,
name,
stride=1,
groups=1,
act=None):
net = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=act,
use_cudnn=True,
param_attr=ParamAttr(name=name + '_weights'),
bias_attr=ParamAttr(name=name + '_bias'),
name=name)
return net
def _bn(self, input, act, name):
net = fluid.layers.batch_norm(
input=input,
act=act,
name=name,
moving_mean_name=name + '_mean',
moving_variance_name=name + '_variance',
param_attr=ParamAttr(name=name + '_scale'),
bias_attr=ParamAttr(name=name + '_offset'))
return net
def _bn_relu_conv(self,
input,
num_filters,
filter_size,
name,
stride=1,
groups=1):
net = self._bn(input, 'relu', name + '_bn')
net = self._conv(net, num_filters, filter_size, name + '_conv', stride,
groups)
return net
def _conv_bn_relu(self,
input,
num_filters,
filter_size,
name,
stride=1,
groups=1):
net = self._conv(input, num_filters, filter_size, name + '_conv',
stride, groups)
net = self._bn(net, 'relu', name + '_bn')
return net
def _bn_crelu(self, input, name):
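        # C.ReLU: batch-norm, then concatenate the activation with its
        # negation (so one convolution yields both signs), followed by a
        # learned per-channel scale/shift and a ReLU.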
net = self._bn(input, None, name + '_bn_1')
neg_net = fluid.layers.scale(net, scale=-1.0, name=name + '_neg')
net = fluid.layers.concat([net, neg_net], axis=1)
net = self._scale(net, name + '_scale')
net = fluid.layers.relu(net, name=name + '_relu')
return net
def _conv_bn_crelu(self,
input,
num_filters,
filter_size,
name,
stride=1,
groups=1,
act=None):
net = self._conv(input, num_filters, filter_size, name + '_conv',
stride, groups)
net = self._bn_crelu(net, name)
return net
def _bn_crelu_conv(self,
input,
num_filters,
filter_size,
name,
stride=1,
groups=1,
act=None):
net = self._bn_crelu(input, name)
net = self._conv(net, num_filters, filter_size, name + '_conv', stride,
groups)
return net
def deconv_bn_layer(self,
input,
num_filters,
filter_size=4,
stride=2,
padding=1,
act='relu',
name=None):
"""Deconv bn layer."""
deconv = fluid.layers.conv2d_transpose(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
act=None,
param_attr=ParamAttr(name=name + '_weights'),
bias_attr=ParamAttr(name=name + '_bias'),
name=name + 'deconv')
return self._bn(deconv, act, name + '_bn')
def conv_bn_layer(self,
input,
num_filters,
filter_size,
name,
stride=1,
groups=1):
return self._conv_bn_relu(input, num_filters, filter_size, name,
stride, groups)
def Fpn_Fusion(blocks, net):
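    # FPN-style top-down fusion: project conv2-conv5 to a common channel
    # width, then repeatedly upsample the coarser map with a deconv and add
    # the next finer map.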
f = [blocks['conv5'], blocks['conv4'], blocks['conv3'], blocks['conv2']]
num_outputs = [64] * len(f)
g = [None] * len(f)
h = [None] * len(f)
for i in range(len(f)):
h[i] = net.conv_bn_layer(f[i], num_outputs[i], 1, 'fpn_pre_' + str(i))
for i in range(len(f) - 1):
if i == 0:
g[i] = net.deconv_bn_layer(h[i], num_outputs[i], name='fpn_0')
else:
out = fluid.layers.elementwise_add(x=g[i - 1], y=h[i])
out = net.conv_bn_layer(out, num_outputs[i], 1,
'fpn_trans_' + str(i))
g[i] = net.deconv_bn_layer(
out, num_outputs[i], name='fpn_' + str(i))
out = fluid.layers.elementwise_add(x=g[-2], y=h[-1])
out = net.conv_bn_layer(out, num_outputs[-1], 1, 'fpn_post_0')
out = net.conv_bn_layer(out, num_outputs[-1], 3, 'fpn_post_1')
return out
def Detector_Header(f_common, net, class_num):
"""Detector header."""
f_geo = net.conv_bn_layer(f_common, 64, 1, name='geo_1')
f_geo = net.conv_bn_layer(f_geo, 64, 3, name='geo_2')
f_geo = net.conv_bn_layer(f_geo, 64, 1, name='geo_3')
f_geo = fluid.layers.conv2d(
f_geo,
8,
1,
use_cudnn=True,
param_attr=ParamAttr(name='geo_4_conv_weights'),
bias_attr=ParamAttr(name='geo_4_conv_bias'),
name='geo_4_conv')
name = 'score_class_num' + str(class_num + 1)
f_score = net.conv_bn_layer(f_common, 64, 1, 'score_1')
f_score = net.conv_bn_layer(f_score, 64, 3, 'score_2')
f_score = net.conv_bn_layer(f_score, 64, 1, 'score_3')
f_score = fluid.layers.conv2d(
f_score,
class_num + 1,
1,
use_cudnn=True,
param_attr=ParamAttr(name=name + '_conv_weights'),
bias_attr=ParamAttr(name=name + '_conv_bias'),
name=name + '_conv')
f_score = fluid.layers.transpose(f_score, perm=[0, 2, 3, 1])
f_score = fluid.layers.reshape(f_score, shape=[-1, class_num + 1])
f_score = fluid.layers.softmax(input=f_score)
return f_score, f_geo
def east(input, class_num=31):
net = PVANet()
out = net.net(input)
blocks = []
for i, j, k in zip(['conv2', 'conv3', 'conv4', 'conv5'], [1, 2, 4, 8],
[64, 64, 64, 64]):
if j == 1:
conv = net.conv_bn_layer(
out[i], k, 1, name='fusion_' + str(len(blocks)))
elif j <= 4:
conv = net.deconv_bn_layer(
out[i], k, 2 * j, j, j // 2,
name='fusion_' + str(len(blocks)))
else:
conv = net.deconv_bn_layer(
out[i], 32, 8, 4, 2, name='fusion_' + str(len(blocks)) + '_1')
conv = net.deconv_bn_layer(
conv,
k,
j // 2,
j // 4,
j // 8,
name='fusion_' + str(len(blocks)) + '_2')
blocks.append(conv)
conv = fluid.layers.concat(blocks, axis=1)
f_score, f_geo = Detector_Header(conv, net, class_num)
return f_score, f_geo
def inference(input, class_num=1, nms_thresh=0.2, score_thresh=0.5):
f_score, f_geo = east(input, class_num)
print("f_geo shape={}".format(f_geo.shape))
print("f_score shape={}".format(f_score.shape))
f_score = fluid.layers.transpose(f_score, perm=[1, 0])
return f_score, f_geo
def loss(f_score, f_geo, l_score, l_geo, l_mask, class_num=1):
    '''
    predictions: f_score: -1 x 1 x H x W; f_geo: -1 x 8 x H x W
    targets: l_score: -1 x 1 x H x W; l_geo: -1 x 9 x H x W (the last channel
        is the short-edge norm); l_mask: -1 x 1 x H x W
    returns: softmax (cross-entropy) loss on scores and smooth-L1 loss on geometry
    '''
#smooth_l1_loss
channels = 8
l_geo_split, l_short_edge = fluid.layers.split(
l_geo, num_or_sections=[channels, 1],
dim=1) #last channel is short_edge_norm
f_geo_split = fluid.layers.split(f_geo, num_or_sections=[channels], dim=1)
f_geo_split = f_geo_split[0]
geo_diff = l_geo_split - f_geo_split
abs_geo_diff = fluid.layers.abs(geo_diff)
l_flag = l_score >= 1
l_flag = fluid.layers.cast(x=l_flag, dtype="float32")
l_flag = fluid.layers.expand(x=l_flag, expand_times=[1, channels, 1, 1])
smooth_l1_sign = abs_geo_diff < l_flag
smooth_l1_sign = fluid.layers.cast(x=smooth_l1_sign, dtype="float32")
in_loss = abs_geo_diff * abs_geo_diff * smooth_l1_sign + (
abs_geo_diff - 0.5) * (1.0 - smooth_l1_sign)
l_short_edge = fluid.layers.expand(
x=l_short_edge, expand_times=[1, channels, 1, 1])
out_loss = l_short_edge * in_loss * l_flag
out_loss = out_loss * l_flag
smooth_l1_loss = fluid.layers.reduce_mean(out_loss)
##softmax_loss
l_score.stop_gradient = True
l_score = fluid.layers.transpose(l_score, perm=[0, 2, 3, 1])
l_score.stop_gradient = True
l_score = fluid.layers.reshape(l_score, shape=[-1, 1])
l_score.stop_gradient = True
l_score = fluid.layers.cast(x=l_score, dtype="int64")
l_score.stop_gradient = True
softmax_loss = fluid.layers.cross_entropy(input=f_score, label=l_score)
softmax_loss = fluid.layers.reduce_mean(softmax_loss)
    return softmax_loss, smooth_l1_loss
```
```diff
@@ -20,7 +20,6 @@ import math
 import paddle.fluid as fluid
 import paddle.fluid.layers.ops as ops
-from paddle.fluid.initializer import init_on_cpu
 from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter

 lr_strategy = 'cosine_decay'
@@ -40,10 +39,9 @@ def cosine_decay(learning_rate, step_each_epoch, epochs=120):
     """
     global_step = _decay_step_counter()

-    with init_on_cpu():
-        epoch = ops.floor(global_step / step_each_epoch)
-        decayed_lr = learning_rate * \
-                (ops.cos(epoch * (math.pi / epochs)) + 1)/2
+    epoch = ops.floor(global_step / step_each_epoch)
+    decayed_lr = learning_rate * \
+            (ops.cos(epoch * (math.pi / epochs)) + 1)/2

     return decayed_lr
@@ -63,17 +61,16 @@ def cosine_decay_with_warmup(learning_rate, step_each_epoch, epochs=120):
     warmup_epoch = fluid.layers.fill_constant(
         shape=[1], dtype='float32', value=float(5), force_cpu=True)

-    with init_on_cpu():
-        epoch = ops.floor(global_step / step_each_epoch)
-        with fluid.layers.control_flow.Switch() as switch:
-            with switch.case(epoch < warmup_epoch):
-                decayed_lr = learning_rate * (global_step /
-                                              (step_each_epoch * warmup_epoch))
-                fluid.layers.tensor.assign(input=decayed_lr, output=lr)
-            with switch.default():
-                decayed_lr = learning_rate * \
-                        (ops.cos((global_step - warmup_epoch * step_each_epoch) * (math.pi / (epochs * step_each_epoch))) + 1)/2
-                fluid.layers.tensor.assign(input=decayed_lr, output=lr)
+    epoch = ops.floor(global_step / step_each_epoch)
+    with fluid.layers.control_flow.Switch() as switch:
+        with switch.case(epoch < warmup_epoch):
+            decayed_lr = learning_rate * (global_step /
+                                          (step_each_epoch * warmup_epoch))
+            fluid.layers.tensor.assign(input=decayed_lr, output=lr)
+        with switch.default():
+            decayed_lr = learning_rate * \
+                    (ops.cos((global_step - warmup_epoch * step_each_epoch) * (math.pi / (epochs * step_each_epoch))) + 1)/2
+            fluid.layers.tensor.assign(input=decayed_lr, output=lr)

     return lr
@@ -95,19 +92,18 @@ def exponential_decay_with_warmup(learning_rate,
     warmup_epoch = fluid.layers.fill_constant(
         shape=[1], dtype='float32', value=float(warm_up_epoch), force_cpu=True)

-    with init_on_cpu():
-        epoch = ops.floor(global_step / step_each_epoch)
-        with fluid.layers.control_flow.Switch() as switch:
-            with switch.case(epoch < warmup_epoch):
-                decayed_lr = learning_rate * (global_step /
-                                              (step_each_epoch * warmup_epoch))
-                fluid.layers.assign(input=decayed_lr, output=lr)
-            with switch.default():
-                div_res = (global_step - warmup_epoch * step_each_epoch
-                           ) / decay_epochs
-                div_res = ops.floor(div_res)
-                decayed_lr = learning_rate * (decay_rate**div_res)
-                fluid.layers.assign(input=decayed_lr, output=lr)
+    epoch = ops.floor(global_step / step_each_epoch)
+    with fluid.layers.control_flow.Switch() as switch:
+        with switch.case(epoch < warmup_epoch):
+            decayed_lr = learning_rate * (global_step /
+                                          (step_each_epoch * warmup_epoch))
+            fluid.layers.assign(input=decayed_lr, output=lr)
+        with switch.default():
+            div_res = (global_step - warmup_epoch * step_each_epoch
+                       ) / decay_epochs
+            div_res = ops.floor(div_res)
+            decayed_lr = learning_rate * (decay_rate**div_res)
+            fluid.layers.assign(input=decayed_lr, output=lr)

     return lr
```
......
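Outside the Paddle graph, the cosine schedule above reduces to a simple closed-form function; this plain-Python sketch mirrors the formula and is handy for checking the shape of the curve:

```python
import math

def cosine_decay_value(base_lr, step_each_epoch, epochs, global_step):
    # mirrors the graph-mode formula: lr = base_lr * (cos(pi * epoch / epochs) + 1) / 2
    epoch = global_step // step_each_epoch
    return base_lr * (math.cos(epoch * math.pi / epochs) + 1) / 2

# e.g. base_lr=0.1 over 120 epochs: epoch 0 -> 0.1, epoch 60 -> 0.05, epoch 120 -> 0.0
print(cosine_decay_value(0.1, 100, 120, 0))     # 0.1
print(cosine_decay_value(0.1, 100, 120, 6000))  # ~0.05
```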
```diff
@@ -40,11 +40,33 @@ add_arg('test_period', int, 10, "Test period in epoches.")
 model_list = [m for m in dir(models) if "__" not in m]

+def get_pruned_params(args, program):
+    params = []
+    if args.model == "MobileNet":
+        for param in program.global_block().all_parameters():
+            if "_sep_weights" in param.name:
+                params.append(param.name)
+    elif args.model == "MobileNetV2":
+        for param in program.global_block().all_parameters():
+            if "linear_weights" in param.name or "expand_weights" in param.name:
+                params.append(param.name)
+    elif args.model == "ResNet34":
+        for param in program.global_block().all_parameters():
+            if "weights" in param.name and "branch" in param.name:
+                params.append(param.name)
+    elif args.model == "PVANet":
+        for param in program.global_block().all_parameters():
+            if "conv_weights" in param.name:
+                params.append(param.name)
+    return params

 def piecewise_decay(args):
     step = int(math.ceil(float(args.total_images) / args.batch_size))
     bd = [step * e for e in args.step_epochs]
     lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
     learning_rate = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
     optimizer = fluid.optimizer.Momentum(
         learning_rate=learning_rate,
         momentum=args.momentum_rate,
@@ -176,14 +198,11 @@ def compress(args):
                 end_time - start_time))
             batch_id += 1

-    params = []
-    for param in fluid.default_main_program().global_block().all_parameters():
-        if "_sep_weights" in param.name:
-            params.append(param.name)
-    _logger.info("fops before pruning: {}".format(
+    params = get_pruned_params(args, fluid.default_main_program())
+    _logger.info("FLOPs before pruning: {}".format(
         flops(fluid.default_main_program())))
     pruner = Pruner()
-    pruned_val_program = pruner.prune(
+    pruned_val_program, _, _ = pruner.prune(
         val_program,
         fluid.global_scope(),
         params=params,
@@ -191,19 +210,13 @@ def compress(args):
         place=place,
         only_graph=True)

-    pruned_program = pruner.prune(
+    pruned_program, _, _ = pruner.prune(
         fluid.default_main_program(),
         fluid.global_scope(),
         params=params,
         ratios=[0.33] * len(params),
         place=place)
-
-    for param in pruned_program[0].global_block().all_parameters():
-        if "weights" in param.name:
-            print param.name, param.shape
-    return
-    _logger.info("fops after pruning: {}".format(flops(pruned_program)))
+    _logger.info("FLOPs after pruning: {}".format(flops(pruned_program)))

     for i in range(args.num_epochs):
         train(i, pruned_program)
         if i % args.test_period == 0:
```
......
````diff
@@ -20,8 +20,7 @@ quant_config = {
     'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
     'dtype': 'int8',
     'window_size': 10000,
-    'moving_rate': 0.9,
-    'quant_weight_only': False
+    'moving_rate': 0.9
 }
 ```
@@ -49,7 +48,7 @@ compiled_train_prog = compiled_train_prog.with_data_parallel(
 ### 4. freeze program
 ```
 float_program, int8_program = convert(val_program,
                                       place,
                                       quant_config,
                                       scope=None,
````
......
```diff
@@ -78,27 +78,24 @@ def compress(args):
     # 1. quantization configs
     ############################################################################################################
     quant_config = {
-        # weight quantize type, default is 'abs_max'
-        'weight_quantize_type': 'abs_max',
-        # activation quantize type, default is 'abs_max'
+        # weight quantize type, default is 'channel_wise_abs_max'
+        'weight_quantize_type': 'channel_wise_abs_max',
+        # activation quantize type, default is 'moving_average_abs_max'
         'activation_quantize_type': 'moving_average_abs_max',
         # weight quantize bit num, default is 8
         'weight_bits': 8,
         # activation quantize bit num, default is 8
         'activation_bits': 8,
-        # op of name_scope in not_quant_pattern list, will not quantized
+        # ops of name_scope in not_quant_pattern list, will not be quantized
         'not_quant_pattern': ['skip_quant'],
-        # op of types in quantize_op_types, will quantized
+        # ops of type in quantize_op_types, will be quantized
         'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
-        # data type after quantization, default is 'int8'
+        # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
         'dtype': 'int8',
         # window size for 'range_abs_max' quantization. default is 10000
         'window_size': 10000,
         # The decay coefficient of moving average, default is 0.9
         'moving_rate': 0.9,
-        # if set quant_weight_only True, then only quantize parameters of layers which need quantization,
-        # and insert anti-quantization op for parameters of these layers.
-        'quant_weight_only': False
     }

     train_reader = None
@@ -141,8 +138,10 @@ def compress(args):
     # According to the weight and activation quantization type, the graph will be added
     # some fake quantize operators and fake dequantize operators.
     ############################################################################################################
-    val_program = quant_aware(val_program, place, quant_config, scope=None, for_test=True)
-    compiled_train_prog = quant_aware(train_prog, place, quant_config, scope=None, for_test=False)
+    val_program = quant_aware(
+        val_program, place, quant_config, scope=None, for_test=True)
+    compiled_train_prog = quant_aware(
+        train_prog, place, quant_config, scope=None, for_test=False)

     opt = create_optimizer(args)
     opt.minimize(avg_cost)
@@ -152,7 +151,8 @@ def compress(args):
     if args.pretrained_model:

         def if_exist(var):
-            return os.path.exists(os.path.join(args.pretrained_model, var.name))
+            return os.path.exists(
+                os.path.join(args.pretrained_model, var.name))

         fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
@@ -199,9 +199,9 @@ def compress(args):
     build_strategy.sync_batch_norm = False
     exec_strategy = fluid.ExecutionStrategy()
     compiled_train_prog = compiled_train_prog.with_data_parallel(
-            loss_name=avg_cost.name,
-            build_strategy=build_strategy,
-            exec_strategy=exec_strategy)
+        loss_name=avg_cost.name,
+        build_strategy=build_strategy,
+        exec_strategy=exec_strategy)

     batch_id = 0
     for data in train_reader():
@@ -242,8 +242,8 @@ def compress(args):
     # 4. Save inference model
     ############################################################################################################
     model_path = os.path.join(quantization_model_save_dir, args.model,
-                              'act_' + quant_config['activation_quantize_type'] + '_w_' + quant_config[
-                                  'weight_quantize_type'])
+                              'act_' + quant_config['activation_quantize_type']
+                              + '_w_' + quant_config['weight_quantize_type'])
     float_path = os.path.join(model_path, 'float')
     int8_path = os.path.join(model_path, 'int8')
     if not os.path.isdir(model_path):
@@ -252,7 +252,8 @@ def compress(args):
     fluid.io.save_inference_model(
         dirname=float_path,
         feeded_var_names=[image.name],
-        target_vars=[out], executor=exe,
+        target_vars=[out],
+        executor=exe,
         main_program=float_program,
         model_filename=float_path + '/model',
         params_filename=float_path + '/params')
@@ -260,7 +261,8 @@ def compress(args):
     fluid.io.save_inference_model(
         dirname=int8_path,
         feeded_var_names=[image.name],
-        target_vars=[out], executor=exe,
+        target_vars=[out],
+        executor=exe,
         main_program=int8_program,
         model_filename=int8_path + '/model',
         params_filename=int8_path + '/params')
```
......
````diff
-# paddleslim.nas API documentation
+## Configuring the search space
+The search space is configured through its parameters. For more on the available search spaces, see [search_space](../search_space.md).
+
+**Parameters:**
+
+- **input_size(int|None)** - the size of the input feature map.
+- **output_size(int|None)** - the size of the output feature map.
+- **block_num(int|None)** - the number of blocks in the search space.
+- **block_mask(list|None)** - a list of 0s and 1s, where 0 marks a normal block and 1 marks a reduction block. If `block_mask` is set, it is taken as the primary configuration and `input_size`, `output_size`, and `block_num` are ignored.
+
+Note:<br>
+1. A reduction block halves the feature-map size, while a normal block keeps it unchanged.<br>
+2. `input_size` and `output_size` are used to compute the number of reduction blocks in the whole architecture.

-## class SANAS
-SANAS (Simulated Annealing Neural Architecture Search) searches model architectures with a simulated-annealing algorithm and is generally used for discrete search tasks.
-
----
->paddleslim.nas.SANAS(configs, server_addr, init_temperature, reduce_rate, search_steps, save_checkpoint, load_checkpoint, is_server)
+## SANAS
+paddleslim.nas.SANAS(configs, server_addr=("", 8881), init_temperature=100, reduce_rate=0.85, search_steps=300, save_checkpoint='./nas_checkpoint', load_checkpoint=None, is_server=True) [source](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/nas/sa_nas.py#L36)
+: SANAS (Simulated Annealing Neural Architecture Search) searches model architectures with a simulated-annealing algorithm and is generally used for discrete search tasks.

 **Parameters:**

 - **configs(list<tuple>)** - search space configuration list, in the format `[(key, {input_size, output_size, block_num, block_mask})]` or `[(key)]` (the MobileNetV2, MobileNetV1, and ResNet search spaces reuse the structure of the original networks, so only `key` needs to be specified). `input_size` and `output_size` are the sizes of the input and output feature maps, `block_num` is the number of blocks in the searched network, and `block_mask` is a list of 0s and 1s, where 0 marks a block without downsampling and 1 marks a downsampling block. See the PaddleSlim docs for more predefined search spaces.
 - **server_addr(tuple)** - address (ip, port) of the SANAS server; if the ip is None or "", the local ip is used. Default: ("", 8881).
 - **init_temperature(float)** - initial temperature of the simulated-annealing search. Default: 100.
 - **reduce_rate(float)** - decay rate of the simulated-annealing search. Default: 0.85.
 - **search_steps(int)** - number of search iterations. Default: 300.
 - **save_checkpoint(str|None)** - directory for saving checkpoints; if None, no checkpoint is saved. Default: `./nas_checkpoint`.
 - **load_checkpoint(str|None)** - directory to load a checkpoint from; if None, no checkpoint is loaded. Default: None.
 - **is_server(bool)** - whether the current instance should start a server. Default: True.

 **Returns:**
 An instance of the SANAS class.

 **Example:**
@@ -29,16 +39,19 @@ config = [('MobileNetV2Space')]
 sanas = SANAS(config=config)
 ```

->tokens2arch(tokens)
-Converts a group of tokens into the corresponding model architecture; usually used to turn the best searched tokens into the network for final training.
+paddleslim.nas.SANAS.tokens2arch(tokens)
+: Converts a group of tokens into the corresponding model architecture; usually used to turn the best searched tokens into the network for final training.
+
+Note:<br>
+tokens is a list; it is mapped through the search space into a network architecture, and one group of tokens corresponds to exactly one architecture.

 **Parameters:**

 - **tokens(list)** - a group of tokens.

 **Returns:**
-A model architecture instance.
+A model architecture instance built from the given tokens.

 **Example:**
 ```
@@ -49,12 +62,11 @@ for arch in archs:
     output = arch(input)
     input = output
 ```

->next_archs()
-Gets the next group of model architectures.
+paddleslim.nas.SANAS.next_archs()
+: Gets the next group of model architectures.

 **Returns:**
 A list of model architecture instances.

 **Example:**
@@ -67,116 +79,19 @@ for arch in archs:
     input = output
 ```

->reward(score)
-Reports the score of the current architecture back to the searcher.
+paddleslim.nas.SANAS.reward(score)
+: Reports the score of the current architecture back to the searcher.

 **Parameters:**

 - **score<float>** - the score of the current model; the higher the better.

 **Returns:**
 `True` if the architecture update succeeded, `False` otherwise.

-**Example:**
-```python
-import numpy as np
-import paddle
-import paddle.fluid as fluid
-from paddleslim.nas import SANAS
-from paddleslim.analysis import flops
-
-max_flops = 321208544
-batch_size = 256
-
-# search space configuration
-config=[('MobileNetV2Space')]
-# instantiate SANAS
-sa_nas = SANAS(config, server_addr=("", 8887), init_temperature=10.24, reduce_rate=0.85, search_steps=100, is_server=True)
-
-for step in range(100):
-    archs = sa_nas.next_archs()
-    train_program = fluid.Program()
-    test_program = fluid.Program()
-    startup_program = fluid.Program()
-
-    ### build the training program
-    with fluid.program_guard(train_program, startup_program):
-        image = fluid.data(name='image', shape=[None, 3, 32, 32], dtype='float32')
-        label = fluid.data(name='label', shape=[None, 1], dtype='int64')
-        for arch in archs:
-            output = arch(image)
-        out = fluid.layers.fc(output, size=10, act="softmax")
-        softmax_out = fluid.layers.softmax(input=out, use_cudnn=False)
-        cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
-        avg_cost = fluid.layers.mean(cost)
-        acc_top1 = fluid.layers.accuracy(input=softmax_out, label=label, k=1)
-
-        ### build the test program
-        test_program = train_program.clone(for_test=True)
-
-        ### define the optimizer
-        sgd = fluid.optimizer.SGD(learning_rate=1e-3)
-        sgd.minimize(avg_cost)
-
-    ### add a constraint; without it the search is unconstrained
-    if flops(train_program) > max_flops:
-        continue
-
-    ### run on CPU
-    place = fluid.CPUPlace()
-    exe = fluid.Executor(place)
-    exe.run(startup_program)
-
-    ### define the training data reader
-    train_reader = paddle.batch(
-        paddle.reader.shuffle(
-            paddle.dataset.cifar.train10(cycle=False), buf_size=1024),
-        batch_size=batch_size,
-        drop_last=True)
-    ### define the test data reader
-    test_reader = paddle.batch(
-        paddle.dataset.cifar.test10(cycle=False),
-        batch_size=batch_size,
-        drop_last=False)
-    train_feeder = fluid.DataFeeder([image, label], place, program=train_program)
-    test_feeder = fluid.DataFeeder([image, label], place, program=test_program)
-
-    ### train; each candidate is trained for 5 epochs
-    for epoch_id in range(5):
-        for batch_id, data in enumerate(train_reader()):
-            fetches = [avg_cost.name]
-            outs = exe.run(train_program,
-                           feed=train_feeder.feed(data),
-                           fetch_list=fetches)[0]
-            if batch_id % 10 == 0:
-                print('TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}'.format(step, epoch_id, batch_id, outs[0]))
-
-    ### evaluate; the final test result is sent back to sa_nas as the score
-    reward = []
-    for batch_id, data in enumerate(test_reader()):
-        test_fetches = [avg_cost.name, acc_top1.name]
-        batch_reward = exe.run(test_program,
-                               feed=test_feeder.feed(data),
-                               fetch_list=test_fetches)
-        reward_avg = np.mean(np.array(batch_reward), axis=1)
-        reward.append(reward_avg)
-        print('TEST: step: {}, batch: {}, avg_cost: {}, acc_top1: {}'.
-              format(step, batch_id, batch_reward[0], batch_reward[1]))
-    finally_reward = np.mean(np.array(reward), axis=0)
-    print('FINAL TEST: avg_cost: {}, acc_top1: {}'.format(
-        finally_reward[0], finally_reward[1]))
-
-    ### report the score
-    sa_nas.reward(float(finally_reward[1]))
-```
+paddleslim.nas.SANAS.current_info()
+: Returns the current token, plus the best token and reward found so far in the search.
+
+**Returns:**
+The best token and reward found during the search and the token currently being trained, as a dict.
````
````diff
@@ -4,29 +4,50 @@
 Quantization parameters are configured through a dict.

 ```
-quant_config_default = {
-    'weight_quantize_type': 'abs_max',
-    'activation_quantize_type': 'abs_max',
+TENSORRT_OP_TYPES = [
+    'mul', 'conv2d', 'pool2d', 'depthwise_conv2d', 'elementwise_add',
+    'leaky_relu'
+]
+TRANSFORM_PASS_OP_TYPES = ['conv2d', 'depthwise_conv2d', 'mul']
+
+QUANT_DEQUANT_PASS_OP_TYPES = [
+    "pool2d", "elementwise_add", "concat", "softmax", "argmax", "transpose",
+    "equal", "gather", "greater_equal", "greater_than", "less_equal",
+    "less_than", "mean", "not_equal", "reshape", "reshape2",
+    "bilinear_interp", "nearest_interp", "trilinear_interp", "slice",
+    "squeeze", "elementwise_sub", "relu", "relu6", "leaky_relu", "tanh", "swish"
+]
+
+_quant_config_default = {
+    # weight quantize type, default is 'channel_wise_abs_max'
+    'weight_quantize_type': 'channel_wise_abs_max',
+    # activation quantize type, default is 'moving_average_abs_max'
+    'activation_quantize_type': 'moving_average_abs_max',
+    # weight quantize bit num, default is 8
     'weight_bits': 8,
+    # activation quantize bit num, default is 8
     'activation_bits': 8,
     # ops of name_scope in not_quant_pattern list, will not be quantized
     'not_quant_pattern': ['skip_quant'],
     # ops of type in quantize_op_types, will be quantized
-    'quantize_op_types':
-    ['conv2d', 'depthwise_conv2d', 'mul', 'elementwise_add', 'pool2d'],
+    'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
     # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
     'dtype': 'int8',
     # window size for 'range_abs_max' quantization. default is 10000
     'window_size': 10000,
     # The decay coefficient of moving average, default is 0.9
     'moving_rate': 0.9,
+    # if True, 'quantize_op_types' will be TENSORRT_OP_TYPES
+    'for_tensorrt': False,
+    # if True, 'quantize_op_types' will be TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES
+    'is_full_quantize': False
 }
 ```

 **Parameters:**

-- **weight_quantize_type(str)** - the weight quantization type; one of ``'abs_max'``, ``'channel_wise_abs_max'``, ``'range_abs_max'``, ``'moving_average_abs_max'``. Default: ``'abs_max'``.
+- **weight_quantize_type(str)** - the weight quantization type; one of ``'abs_max'``, ``'channel_wise_abs_max'``, ``'range_abs_max'``, ``'moving_average_abs_max'``. If the quantized model will be loaded with ``TensorRT`` for inference, use ``'channel_wise_abs_max'``. Default: ``'channel_wise_abs_max'``.
-- **activation_quantize_type(str)** - the activation quantization type; one of ``'abs_max'``, ``'range_abs_max'``, ``'moving_average_abs_max'``. Default: ``'abs_max'``.
+- **activation_quantize_type(str)** - the activation quantization type; one of ``'abs_max'``, ``'range_abs_max'``, ``'moving_average_abs_max'``. If the quantized model will be loaded with ``TensorRT`` for inference, use ``'range_abs_max'`` or ``'moving_average_abs_max'``. Default: ``'moving_average_abs_max'``.
 - **weight_bits(int)** - the number of bits for weight quantization; default 8, 8 is recommended.
 - **activation_bits(int)** - the number of bits for activation quantization; default 8, 8 is recommended.
 - **not_quant_pattern(str | list[str])** - any ``op`` whose ``name_scope`` contains a pattern from ``not_quant_pattern`` is not quantized; see [*fluid.name_scope*](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/name_scope_cn.html#name-scope) for how to set it.
@@ -34,7 +55,12 @@ quant_config_default = {
 - **dtype(int8)** - the data type after quantization; default ``int8``, currently only ``int8`` is supported.
 - **window_size(int)** - the ``window size`` of ``'range_abs_max'`` quantization; default 10000.
 - **moving_rate(int)** - the decay coefficient of ``'moving_average_abs_max'`` quantization; default 0.9.
+- **for_tensorrt(bool)** - whether the quantized model will be used for ``TensorRT`` inference; if True, the quantized op types are ``TENSORRT_OP_TYPES``. Default: False.
+- **is_full_quantize(bool)** - whether to quantize all supported op types. Default: False.
+
+!!! note "Note"
+    - Currently ``Paddle-Lite`` only has int8 kernels to accelerate ``['conv2d', 'depthwise_conv2d', 'mul']``; int8 kernels for other ops will be supported gradually.

 ## quant_aware
 paddleslim.quant.quant_aware(program, place, config, scope=None, for_test=False) [[source]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py)
@@ -67,7 +93,7 @@ paddleslim.quant.quant_aware(program, place, config, scope=None, for_test=False)
 ## convert
 paddleslim.quant.convert(program, place, config, scope=None, save_int8=False) [[source]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py)
@@ -135,7 +161,7 @@ inference_prog = quant.convert(quant_eval_program, place, config)
 For more detailed usage, see the <a href='https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/quant/quant_aware'>quantization-aware training demo</a>.

 ## quant_post
-paddleslim.quant.quant_post(executor, model_dir, quantize_model_path, sample_generator, model_filename=None, params_filename=None, batch_size=16, batch_nums=None, scope=None, algo='KL', quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"]) [[source]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py)
+paddleslim.quant.quant_post(executor, model_dir, quantize_model_path, sample_generator, model_filename=None, params_filename=None, batch_size=16, batch_nums=None, scope=None, algo='KL', quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], is_full_quantize=False, is_use_cache_file=False, cache_dir="./temp_post_training") [[source]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py)

 : Quantizes the model saved under ``${model_dir}``, using data from ``sample_generator`` for calibration.
@@ -152,6 +178,9 @@ paddleslim.quant.quant_post(executor, model_dir, quantize_model_path, sample_gen
 - **scope(fluid.Scope, optional)** - the scope used to read and write ``Variables``; if ``None``, [*fluid.global_scope()*](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html) is used. Default: ``None``.
 - **algo(str)** - the algorithm used during quantization, ``'KL'`` or ``'direct'``. This only affects activations, since weights are always quantized with ``'channel_wise_abs_max'``. With ``'direct'``, the maximum absolute activation value on the calibration data is used as the ``Scale``; with ``'KL'``, the ``Scale`` is computed via ``KL`` divergence. Default: ``'KL'``.
 - **quantizable_op_type(list[str])** - the ``op`` types to quantize. Default: ``["conv2d", "depthwise_conv2d", "mul"]``.
+- **is_full_quantize(bool)** - whether to quantize all supported op types; if False, quantization follows ``'quantizable_op_type'``.
+- **is_use_cache_file(bool)** - whether to cache intermediate results on disk; if False, they are kept in memory.
+- **cache_dir(str)** - if ``'is_use_cache_file'`` is True, intermediate results are stored under this path.

 **Returns:**
@@ -159,7 +188,8 @@ paddleslim.quant.quant_post(executor, model_dir, quantize_model_path, sample_gen
 !!! note "Note"
-    Because this API collects all activation values on the calibration data, do not use too many calibration images. Computing the ``'KL'`` divergence is also time-consuming.
+    - Because this API collects all activation values on the calibration data, set ``'is_use_cache_file'`` to True to store intermediate results on disk when there are many calibration images. Also, computing the ``'KL'`` divergence is time-consuming.
+    - Currently ``Paddle-Lite`` only has int8 kernels to accelerate ``['conv2d', 'depthwise_conv2d', 'mul']``; int8 kernels for other ops will be supported gradually.

 **Example:**
````
......
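A minimal post-training quantization sketch based on the `quant_post` signature above; the model directory is an assumed path and the MNIST reader is just a stand-in calibration source:

```python
import paddle
import paddle.fluid as fluid
from paddleslim.quant import quant_post

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# `sample_generator` must yield single samples matching the model's feed vars;
# './inference_model' is an assumed path to a saved float inference model.
sample_generator = paddle.dataset.mnist.test()

quant_post(
    executor=exe,
    model_dir='./inference_model',
    quantize_model_path='./quant_post_model',
    sample_generator=sample_generator,
    batch_size=16,
    batch_nums=10,
    algo='KL',
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
    is_full_quantize=False,
    is_use_cache_file=False)
```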
## merge ## merge
paddleslim.dist.merge(teacher_program, student_program, data_name_map, place, scope=fluid.global_scope(), name_prefix='teacher_') [[源代码]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dist/single_distiller.py#L19) paddleslim.dist.merge(teacher_program, student_program, data_name_map, place, scope=fluid.global_scope(), name_prefix='teacher_') [[源代码]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dist/single_distiller.py#L19)
: merge将两个paddle program(teacher_program, student_program)融合为一个program,并将融合得到的program返回。在融合的program中,可以为其中合适的teacher特征图和student特征图添加蒸馏损失函数,从而达到用teacher模型的暗知识(Dark Knowledge)指导student模型学习的目的。 : merge将teacher_program融合到student_program中。在融合的program中,可以为其中合适的teacher特征图和student特征图添加蒸馏损失函数,从而达到用teacher模型的暗知识(Dark Knowledge)指导student模型学习的目的。
**参数:** **参数:**
...@@ -12,7 +12,7 @@ paddleslim.dist.merge(teacher_program, student_program, data_name_map, place, sc ...@@ -12,7 +12,7 @@ paddleslim.dist.merge(teacher_program, student_program, data_name_map, place, sc
- **scope**(Scope)-该参数表示程序使用的变量作用域,如果不指定将使用默认的全局作用域。默认值:[*fluid.global_scope()*](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/global_scope_cn.html#global-scope) - **scope**(Scope)-该参数表示程序使用的变量作用域,如果不指定将使用默认的全局作用域。默认值:[*fluid.global_scope()*](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/global_scope_cn.html#global-scope)
- **name_prefix**(str)-merge操作将统一为teacher的[*Variables*](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.3/api_guides/low_level/program.html#variable)添加的名称前缀name_prefix。默认值:'teacher_' - **name_prefix**(str)-merge操作将统一为teacher的[*Variables*](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.3/api_guides/low_level/program.html#variable)添加的名称前缀name_prefix。默认值:'teacher_'
**返回:** 由student_program和teacher_program merge得到的program **返回:**
!!! note "Note" !!! note "Note"
*data_name_map***teacher_var name到student_var name的映射**,如果写反可能无法正确进行merge *data_name_map***teacher_var name到student_var name的映射**,如果写反可能无法正确进行merge
...@@ -37,8 +37,8 @@ with fluid.program_guard(teacher_program): ...@@ -37,8 +37,8 @@ with fluid.program_guard(teacher_program):
data_name_map = {'y':'x'} data_name_map = {'y':'x'}
USE_GPU = False USE_GPU = False
place = fluid.CUDAPlace(0) if USE_GPU else fluid.CPUPlace() place = fluid.CUDAPlace(0) if USE_GPU else fluid.CPUPlace()
main_program = dist.merge(teacher_program, student_program, dist.merge(teacher_program, student_program,
data_name_map, place) data_name_map, place)
``` ```
@@ -76,10 +76,10 @@ with fluid.program_guard(teacher_program):
data_name_map = {'y':'x'}
USE_GPU = False
place = fluid.CUDAPlace(0) if USE_GPU else fluid.CPUPlace()
merge(teacher_program, student_program, data_name_map, place)
with fluid.program_guard(student_program):
    distillation_loss = dist.fsp_loss('teacher_t1.tmp_1', 'teacher_t2.tmp_1',
                                      's1.tmp_1', 's2.tmp_1', student_program)
```
@@ -91,7 +91,7 @@ paddleslim.dist.l2_loss(teacher_var_name, student_var_name, program=fluid.defaul
**Parameters:**
- **teacher_var_name**(str): Name of the teacher_var.
- **student_var_name**(str): Name of the student_var.
- **program**(Program): The fluid program used for distillation training. Default: [*fluid.default_main_program()*](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.3/api_cn/fluid_cn.html#default-main-program)
@@ -116,10 +116,10 @@ with fluid.program_guard(teacher_program):
data_name_map = {'y':'x'}
USE_GPU = False
place = fluid.CUDAPlace(0) if USE_GPU else fluid.CPUPlace()
merge(teacher_program, student_program, data_name_map, place)
with fluid.program_guard(student_program):
    distillation_loss = dist.l2_loss('teacher_t2.tmp_1', 's2.tmp_1',
                                     student_program)
```
@@ -131,11 +131,11 @@ paddleslim.dist.soft_label_loss(teacher_var_name, student_var_name, program=flui
**Parameters:**
- **teacher_var_name**(str): Name of the teacher_var.
- **student_var_name**(str): Name of the student_var.
- **program**(Program): The fluid program used for distillation training. Default: [*fluid.default_main_program()*](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.3/api_cn/fluid_cn.html#default-main-program)
- **teacher_temperature**(float): The temperature used to soften teacher_var; the larger the temperature, the smoother the resulting feature map.
- **student_temperature**(float): The temperature used to soften student_var; the larger the temperature, the smoother the resulting feature map.
**Returns:** The soft_label_loss built from teacher_var and student_var
@@ -158,10 +158,10 @@ with fluid.program_guard(teacher_program):
data_name_map = {'y':'x'}
USE_GPU = False
place = fluid.CUDAPlace(0) if USE_GPU else fluid.CPUPlace()
merge(teacher_program, student_program, data_name_map, place)
with fluid.program_guard(student_program):
    distillation_loss = dist.soft_label_loss('teacher_t2.tmp_1',
                                             's2.tmp_1', student_program, 1., 1.)
```
@@ -173,7 +173,7 @@ paddleslim.dist.loss(loss_func, program=fluid.default_main_program(), **kwargs)
**Parameters:**
- **loss_func**(python function): A user-defined loss function; its inputs are teacher vars and student vars, and its output is the custom loss.
- **program**(Program): The fluid program used for distillation training. Default: [*fluid.default_main_program()*](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.3/api_cn/fluid_cn.html#default-main-program)
- **\**kwargs**: The mapping from loss_func input names to the corresponding variable names.
@@ -198,15 +198,15 @@ with fluid.program_guard(teacher_program):
data_name_map = {'y':'x'}
USE_GPU = False
place = fluid.CUDAPlace(0) if USE_GPU else fluid.CPUPlace()
merge(teacher_program, student_program, data_name_map, place)
def adaptation_loss(t_var, s_var):
    teacher_channel = t_var.shape[1]
    s_hint = fluid.layers.conv2d(s_var, teacher_channel, 1)
    hint_loss = fluid.layers.reduce_mean(fluid.layers.square(s_hint - t_var))
    return hint_loss
with fluid.program_guard(student_program):
    distillation_loss = dist.loss(adaptation_loss, student_program,
                                  t_var='teacher_t2.tmp_1', s_var='s2.tmp_1')
```
!!! note "Notes"
......
## 1. Image Classification
Dataset: ImageNet (1000 classes)
### 1.1 Quantization
| Model | Compression Method | Top-1/Top-5 Acc | Model Size (MB) | Download |
|:--:|:---:|:--:|:--:|:--:|
| MobileNetV1 | - | 70.99%/89.68% | xx | [Download]() |
| MobileNetV1 | quant_post | xx%/xx% | xx | [Download]() |
| MobileNetV1 | quant_aware | xx%/xx% | xx | [Download]() |
| MobileNetV2 | - | 72.15%/90.65% | xx | [Download]() |
| MobileNetV2 | quant_post | xx%/xx% | xx | [Download]() |
| MobileNetV2 | quant_aware | xx%/xx% | xx | [Download]() |
| ResNet50 | - | 76.50%/93.00% | xx | [Download]() |
| ResNet50 | quant_post | xx%/xx% | xx | [Download]() |
| ResNet50 | quant_aware | xx%/xx% | xx | [Download]() |
### 1.2 Pruning
| Model | Compression Method | Top-1/Top-5 Acc | Model Size (MB) | GFLOPs | Download |
|:--:|:---:|:--:|:--:|:--:|:--:|
| MobileNetV1 | Baseline | 70.99%/89.68% | 17 | 1.11 | [Download](http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar) |
| MobileNetV1 | uniform -50% | 69.4%/88.66% (-1.59%/-1.02%) | 9 | 0.56 | [Download](https://paddlemodels.bj.bcebos.com/PaddleSlim/MobileNetV1_uniform-50.tar) |
| MobileNetV1 | sensitive -30% | 70.4%/89.3% (-0.59%/-0.38%) | 12 | 0.74 | [Download](https://paddlemodels.bj.bcebos.com/PaddleSlim/MobileNetV1_sensitive-30.tar) |
| MobileNetV1 | sensitive -50% | 69.8%/88.9% (-1.19%/-0.78%) | 9 | 0.56 | [Download](https://paddlemodels.bj.bcebos.com/PaddleSlim/MobileNetV1_sensitive-50.tar) |
| MobileNetV2 | - | 72.15%/90.65% | 15 | 0.59 | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_pretrained.tar) |
| MobileNetV2 | uniform -50% | 65.79%/86.11% (-6.35%/-4.47%) | 11 | 0.296 | [Download](https://paddlemodels.bj.bcebos.com/PaddleSlim/MobileNetV2_uniform-50.tar) |
| ResNet34 | - | 72.15%/90.65% | 84 | 7.36 | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet34_pretrained.tar) |
| ResNet34 | uniform -50% | 70.99%/89.95% (-1.36%/-0.87%) | 41 | 3.67 | [Download](https://paddlemodels.bj.bcebos.com/PaddleSlim/ResNet34_uniform-50.tar) |
| ResNet34 | auto -55.05% | 70.24%/89.63% (-2.04%/-1.06%) | 33 | 3.31 | [Download](https://paddlemodels.bj.bcebos.com/PaddleSlim/ResNet34_auto-55.tar) |
### 1.3 Distillation
| Model | Compression Method | Top-1/Top-5 Acc | Model Size (MB) | Download |
|:--:|:---:|:--:|:--:|:--:|
| MobileNetV1 | student | 70.99%/89.68% | 17 | [Download](http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar) |
| ResNet50_vd | teacher | 79.12%/94.44% | 99 | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar) |
| MobileNetV1 | ResNet50_vd<sup>[1](#trans1)</sup> distill | 72.77%/90.68% (+1.78%/+1.00%) | 17 | [Download](https://paddlemodels.bj.bcebos.com/PaddleSlim/MobileNetV1_distilled.tar) |
| MobileNetV2 | student | 72.15%/90.65% | 15 | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_pretrained.tar) |
| MobileNetV2 | ResNet50_vd distill | 74.28%/91.53% (+2.13%/+0.88%) | 15 | [Download](https://paddlemodels.bj.bcebos.com/PaddleSlim/MobileNetV2_distilled.tar) |
| ResNet50 | student | 76.50%/93.00% | 99 | [Download](http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar) |
| ResNet101 | teacher | 77.56%/93.64% | 173 | [Download](http://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.tar) |
| ResNet50 | ResNet101 distill | 77.29%/93.65% (+0.79%/+0.65%) | 99 | [Download](https://paddlemodels.bj.bcebos.com/PaddleSlim/ResNet50_distilled.tar) |
!!! note "Note"
    <a name="trans1">[1]</a>: The `_vd` suffix indicates that the pretrained model was trained with Mixup. For an introduction to Mixup, see [mixup: Beyond Empirical Risk Minimization](https://arxiv.org/abs/1710.09412)
## 2. Object Detection
### 2.1 Quantization
Dataset: COCO 2017
| Model | Compression Method | Dataset | Images/GPU | Box AP (input 608) | Box AP (input 416) | Box AP (input 320) | Model Size (MB) | Download |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| MobileNet-V1-YOLOv3 | - | COCO | 8 | 29.3 | 29.3 | 27.1 | xx | [Download]() |
| MobileNet-V1-YOLOv3 | quant_post | COCO | 8 | xx | xx | xx | xx | [Download]() |
| MobileNet-V1-YOLOv3 | quant_aware | COCO | 8 | xx | xx | xx | xx | [Download]() |
| R50-dcn-YOLOv3 obj365_pretrain | - | COCO | 8 | 41.4 | xx | xx | xx | [Download]() |
| R50-dcn-YOLOv3 obj365_pretrain | quant_post | COCO | 8 | xx | xx | xx | xx | [Download]() |
| R50-dcn-YOLOv3 obj365_pretrain | quant_aware | COCO | 8 | xx | xx | xx | xx | [Download]() |

Dataset: WIDER FACE
| Model | Compression Method | Images/GPU | Input Size | Easy/Medium/Hard | Model Size (MB) | Download |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| BlazeFace | - | 8 | 640 | 0.915/0.892/0.797 | xx | [Download]() |
| BlazeFace | quant_post | 8 | 640 | xx/xx/xx | xx | [Download]() |
| BlazeFace | quant_aware | 8 | 640 | xx/xx/xx | xx | [Download]() |
| BlazeFace-Lite | - | 8 | 640 | 0.909/0.885/0.781 | xx | [Download]() |
| BlazeFace-Lite | quant_post | 8 | 640 | xx/xx/xx | xx | [Download]() |
| BlazeFace-Lite | quant_aware | 8 | 640 | xx/xx/xx | xx | [Download]() |
| BlazeFace-NAS | - | 8 | 640 | 0.837/0.807/0.658 | xx | [Download]() |
| BlazeFace-NAS | quant_post | 8 | 640 | xx/xx/xx | xx | [Download]() |
| BlazeFace-NAS | quant_aware | 8 | 640 | xx/xx/xx | xx | [Download]() |
### 2.2 Pruning
Dataset: Pascal VOC & COCO 2017
| Model | Compression Method | Dataset | Images/GPU | Box AP (input 608) | Box AP (input 416) | Box AP (input 320) | Model Size (MB) | GFLOPs (608*608) | Download |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| MobileNet-V1-YOLOv3 | Baseline | Pascal VOC | 8 | 76.2 | 76.7 | 75.3 | 94 | 40.49 | [Download](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar) |
| MobileNet-V1-YOLOv3 | sensitive -52.88% | Pascal VOC | 8 | 77.6 (+1.4) | 77.7 (+1.0) | 75.5 (+0.2) | 31 | 19.08 | [Download](https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenet_v1_voc_prune.tar) |
| MobileNet-V1-YOLOv3 | - | COCO | 8 | 29.3 | 29.3 | 27.0 | 95 | 41.35 | [Download](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1.tar) |
| MobileNet-V1-YOLOv3 | sensitive -51.77% | COCO | 8 | 26.0 (-3.3) | 25.1 (-4.2) | 22.6 (-4.4) | 32 | 19.94 | [Download](https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenet_v1_prune.tar) |
| R50-dcn-YOLOv3 | - | COCO | 8 | 39.1 | - | - | 177 | 89.60 | [Download](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r50vd_dcn.tar) |
| R50-dcn-YOLOv3 | sensitive -9.37% | COCO | 8 | 39.3 (+0.2) | - | - | 150 | 81.20 | [Download](https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_r50vd_dcn_prune.tar) |
| R50-dcn-YOLOv3 | sensitive -24.68% | COCO | 8 | 37.3 (-1.8) | - | - | 113 | 67.48 | [Download](https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_r50vd_dcn_prune578.tar) |
| R50-dcn-YOLOv3 obj365_pretrain | - | COCO | 8 | 41.4 | - | - | 177 | 89.60 | [Download](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r50vd_dcn_obj365_pretrained_coco.tar) |
| R50-dcn-YOLOv3 obj365_pretrain | sensitive -9.37% | COCO | 8 | 40.5 (-0.9) | - | - | 150 | 81.20 | [Download](https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_r50vd_dcn_obj365_pretrained_coco_prune.tar) |
| R50-dcn-YOLOv3 obj365_pretrain | sensitive -24.68% | COCO | 8 | 37.8 (-3.3) | - | - | 113 | 67.48 | [Download](https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_r50vd_dcn_obj365_pretrained_coco_prune578.tar) |
### 2.3 Distillation
Dataset: Pascal VOC & COCO 2017
| Model | Compression Method | Dataset | Images/GPU | Box AP (input 608) | Box AP (input 416) | Box AP (input 320) | Model Size (MB) | Download |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| MobileNet-V1-YOLOv3 | - | Pascal VOC | 8 | 76.2 | 76.7 | 75.3 | 94 | [Download](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar) |
| ResNet34-YOLOv3 | - | Pascal VOC | 8 | 82.6 | 81.9 | 80.1 | 162 | [Download](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34_voc.tar) |
| MobileNet-V1-YOLOv3 | ResNet34-YOLOv3 distill | Pascal VOC | 8 | 79.0 (+2.8) | 78.2 (+1.5) | 75.5 (+0.2) | 94 | [Download](https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenetv1_voc_distilled.tar) |
| MobileNet-V1-YOLOv3 | - | COCO | 8 | 29.3 | 29.3 | 27.0 | 95 | [Download](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1.tar) |
| ResNet34-YOLOv3 | - | COCO | 8 | 36.2 | 34.3 | 31.4 | 163 | [Download](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34.tar) |
| MobileNet-V1-YOLOv3 | ResNet34-YOLOv3 distill | COCO | 8 | 31.4 (+2.1) | 30.0 (+0.7) | 27.1 (+0.1) | 95 | [Download](https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenetv1_coco_distilled.tar) |
## 3. Image Segmentation
Dataset: Cityscapes
### 3.1 Quantization
| Model | Compression Method | mIoU | Model Size (MB) | Download |
| :---: | :---: | :---: | :---: | :---: |
| DeepLabv3+/MobileNetv1 | - | 63.26 | xx | [Download]() |
| DeepLabv3+/MobileNetv1 | quant_post | xx | xx | [Download]() |
| DeepLabv3+/MobileNetv1 | quant_aware | xx | xx | [Download]() |
| DeepLabv3+/MobileNetv2 | - | 69.81 | xx | [Download]() |
| DeepLabv3+/MobileNetv2 | quant_post | xx | xx | [Download]() |
| DeepLabv3+/MobileNetv2 | quant_aware | xx | xx | [Download]() |
### 3.2 Pruning
| Model | Compression Method | mIoU | Model Size (MB) | GFLOPs | Download |
| :---: | :---: | :---: | :---: | :---: | :---: |
| fast-scnn | baseline | 69.64 | 11 | 14.41 | [Download](https://paddlemodels.bj.bcebos.com/PaddleSlim/fast_scnn_cityscape.tar) |
| fast-scnn | uniform -17.07% | 69.58 (-0.06) | 8.5 | 11.95 | [Download](https://paddlemodels.bj.bcebos.com/PaddleSlim/fast_scnn_cityscape_uniform-17.tar) |
| fast-scnn | sensitive -47.60% | 66.68 (-2.96) | 5.7 | 7.55 | [Download](https://paddlemodels.bj.bcebos.com/PaddleSlim/fast_scnn_cityscape_sensitive-47.tar) |
## Introduction to Search Spaces
A search space is a concept in neural architecture search: it is a collection of model architectures. SANAS uses the idea of simulated annealing to find, within the search space, a relatively small model architecture or one with relatively high accuracy.
## Search spaces provided by paddleslim.nas
##### Search spaces built from an initial model architecture
1. MobileNetV2Space<br>
&emsp; The MobileNetV2 architecture: [code](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/mobilenet_v2.py#L29), [paper](https://arxiv.org/abs/1801.04381)

2. MobileNetV1Space<br>
&emsp; The MobileNetV1 architecture: [code](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/mobilenet_v1.py#L29), [paper](https://arxiv.org/abs/1704.04861)

3. ResNetSpace<br>
&emsp; The ResNet architecture: [code](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/resnet.py#L30), [paper](https://arxiv.org/pdf/1512.03385.pdf)

##### Search spaces built from the blocks of a model
1. MobileNetV1BlockSpace<br>
&emsp; The MobileNetV1 block structure: [code](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/mobilenet_v1.py#L173)

2. MobileNetV2BlockSpace<br>
&emsp; The MobileNetV2 block structure: [code](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/mobilenet_v2.py#L174)

3. ResNetBlockSpace<br>
&emsp; The ResNet block structure: [code](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/resnet.py#L148)

4. InceptionABlockSpace<br>
&emsp; The InceptionA block structure: [code](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/inception_v4.py#L140)

5. InceptionCBlockSpace<br>
&emsp; The InceptionC block structure: [code](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/inception_v4.py#L291)

## Search space configuration
- **input_size(int|None)**: `input_size` is the size of the input feature map.
- **output_size(int|None)**: `output_size` is the size of the output feature map.
- **block_num(int|None)**: `block_num` is the number of blocks in the search space.
- **block_mask(list|None)**: `block_mask` is a list of 0s and 1s that marks each block as a normal block (0) or a reduction block (1). If `block_mask` is set, it is the primary configuration, and `input_size`, `output_size`, and `block_num` are ignored.

**Note:**
1. A reduction block halves the feature map size; a normal block leaves it unchanged.
2. `input_size` and `output_size` are used to compute the number of reduction blocks in the overall architecture.

## Search space examples
1. To build a search space from one of the provided initial model architectures, only the search space name needs to be specified. For example, to search within the original MobileNetV2 space, set the config passed to SANAS directly to [('MobileNetV2Space')].
2. To build a search space from the provided block search spaces (see the sketch after this list):<br>
2.1 Use `input_size`, `output_size`, and `block_num`. For example, the config passed to SANAS can be [('MobileNetV2BlockSpace', {'input_size': 224, 'output_size': 32, 'block_num': 10})].<br>
2.2 Use `block_mask`. For example, the config passed to SANAS can be [('MobileNetV2BlockSpace', {'block_mask': [0, 1, 1, 1, 1, 0, 1, 0]})].
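As referenced above, a minimal sketch of passing such a config to SANAS (the port and step count are illustrative values, and a controller server is assumed to run on the local host):

```python
from paddleslim.nas import SANAS

# block_mask config as in example 2.2 above
config = [('MobileNetV2BlockSpace', {'block_mask': [0, 1, 1, 1, 1, 0, 1, 0]})]
sanas = SANAS(config, server_addr=("", 8881), search_steps=300, is_server=True)

archs = sanas.next_archs()  # functions that build the sampled sub-network
```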
## Custom search spaces
A custom search space class must inherit from the search space base class and override the following:<br>
&emsp; 1. The initial tokens (the `init_tokens` function): a token list of your choosing, where each number is an index into the corresponding search list. In this example, tokens=[0, 3, 5] means the searched channel numbers are [8, 40, 128].<br>
&emsp; 2. The length of the search list for each number in the tokens (the `range_table` function): the index range of each token.<br>
&emsp; 3. Producing a model architecture from tokens (the `token2arch` function): generate the model architecture from the searched token list.<br>
The following takes adding a new ResNet block as an example of how to build your own search space. A custom search space must not have the same name as an existing one.
@@ -70,17 +67,18 @@ class ResNetBlockSpace2(SearchSpaceBase):
    def init_tokens(self):
        return [0] * 3 * len(self.block_mask)

    ### Define the range of each token index
    def range_table(self):
        return [len(self.filter_num)] * 3 * len(self.block_mask)

    ### Convert tokens into a model architecture
    def token2arch(self, tokens=None):
        if tokens is None:
            tokens = self.init_tokens()
        self.bottleneck_params_list = []
        for i in range(len(self.block_mask)):
            self.bottleneck_params_list.append(
                (self.filter_num[tokens[i * 3 + 0]],
                 self.filter_num[tokens[i * 3 + 1]],
                 self.filter_num[tokens[i * 3 + 2]],
                 2 if self.block_mask[i] == 1 else 1))
@@ -113,4 +111,4 @@ class ResNetBlockSpace2(SearchSpaceBase):
        conv = fluid.layers.conv2d(input, num_filters, filter_size, stride, name=name+'_conv')
        bn = fluid.layers.batch_norm(conv, act=act, name=name+'_bn')
        return bn
```
@@ -86,7 +86,7 @@ The merge step involves quite a few operations; for details see the [merge API documentation](https://paddlep
```python
data_name_map = {'data': 'image'}
merge(teacher_program, student_program, data_name_map, place)
```
### 5. Add the distillation loss
......
# Convolution Channel Pruning Example
This example demonstrates how to prune the number of channels of each convolution layer at a specified pruning ratio. By default, the example automatically downloads and uses the MNIST dataset.
The example currently supports the following classification models:
- MobileNetV1
- MobileNetV2
- ResNet50
- PVANet
## API
This example uses the `paddleslim.Pruner` utility class; for its user API, see the [API documentation](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/).
## Identifying the parameters to prune
Parameter names differ between models, so the parameter names of the convolution layers to be pruned must be determined before pruning. All parameter names can be listed as follows:
```
for param in program.global_block().all_parameters():
print("param name: {}; shape: {}".format(param.name, param.shape))
```
`train.py`脚本中,提供了`get_pruned_params`方法,根据用户设置的选项`--model`确定要裁剪的参数。
## Launching a pruning task
Start a pruning task with the following commands:
```
export CUDA_VISIBLE_DEVICES=0
python train.py
```
Run `python train.py --help` to see more options.
## Notes
1. In the `paddleslim.Pruner.prune` API, the arguments `params` and `ratios` must have the same length.
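A minimal sketch of the pruning call itself, assuming `train_program` and `place` are built as in `train.py`; the parameter names are hypothetical, and `prune` is taken to return the pruned program first, per the API documentation linked above:

```python
import paddle.fluid as fluid
from paddleslim.prune import Pruner

pruner = Pruner()
pruned_program = pruner.prune(
    train_program,
    fluid.global_scope(),
    params=['conv2_1_sep_weights', 'conv2_2_sep_weights'],  # hypothetical names
    ratios=[0.5, 0.5],  # must match ``params`` in length (see note 1)
    place=place)[0]
```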
@@ -2,6 +2,7 @@ site_name: PaddleSlim Docs
repo_url: https://github.com/PaddlePaddle/PaddleSlim
nav:
    - Home: index.md
    - Model Zoo: model_zoo.md
    - Tutorials:
        - Post-training quantization: tutorials/quant_post_demo.md
        - Quantization-aware training: tutorials/quant_aware_demo.md
@@ -14,7 +15,7 @@ nav:
        - Model analysis: api/analysis_api.md
        - Knowledge distillation: api/single_distiller_api.md
        - SA-based search: api/nas_api.md
        - Search space: search_space.md
    - Hardware latency lookup table: table_latency.md
    - Algorithm principles: algo/algo.md
......
@@ -36,7 +36,7 @@ def flops(program, only_conv=True, detail=False):
    return _graph_flops(graph, only_conv=only_conv, detail=detail)


def _graph_flops(graph, only_conv=True, detail=False):
    assert isinstance(graph, GraphWrapper)
    flops = 0
    params2flops = {}
@@ -66,12 +66,14 @@ def _graph_flops(graph, only_conv=True, detail=False):
            y_shape = op.inputs("Y")[0].shape()
            if x_shape[0] == -1:
                x_shape[0] = 1
            op_flops = x_shape[0] * x_shape[1] * y_shape[1]
            flops += op_flops
            params2flops[op.inputs("Y")[0].name()] = op_flops
        elif op.type() in ['relu', 'sigmoid', 'batch_norm', 'relu6'] and not only_conv:
            input_shape = list(op.inputs("X")[0].shape())
            if input_shape[0] == -1:
                input_shape[0] = 1
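# For instance, under this counting rule a fully connected (`mul`) op with
# X of shape (1, 1024) and Y of shape (1024, 1000) contributes
# op_flops = 1 * 1024 * 1000 = 1,024,000 to the running total.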
......
@@ -26,17 +26,22 @@ class ControllerClient(object):
    Controller client.
    """

    def __init__(self,
                 server_ip=None,
                 server_port=None,
                 key=None,
                 client_name=None):
        """
        Args:
            server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None.
            server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0.
            key(str): The key used to identify legal agent for controller server. Default: "light-nas"
            client_name(str): Current client name, randomly generated for counting the number of clients. Default: None.
        """
        self.server_ip = server_ip
        self.server_port = server_port
        self._key = key
        self._client_name = client_name

    def update(self, tokens, reward, iter):
        """
@@ -48,8 +53,8 @@ class ControllerClient(object):
        socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        socket_client.connect((self.server_ip, self.server_port))
        tokens = ",".join([str(token) for token in tokens])
        socket_client.send("{}\t{}\t{}\t{}\t{}".format(
            self._key, tokens, reward, iter, self._client_name).encode())
        response = socket_client.recv(1024).decode()
        if response.strip('\n').split("\t")[0] == "ok":
            return True
......
@@ -15,6 +15,7 @@
import os
import logging
import socket
import time
from .log_helper import get_logger
from threading import Thread
from .lock_utils import lock, unlock
@@ -41,7 +42,8 @@ class ControllerServer(object):
            address(tuple): The address of current server binding with format (ip, port). Default: ('', 0),
                            which means setting the ip automatically.
            max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100.
            search_steps(int|None): The total steps of searching. None means never stopping. Default: None.
            key(str|None): Config information. Default: None.
        """
        self._controller = controller
        self._address = address
@@ -51,6 +53,9 @@ class ControllerServer(object):
        self._port = address[1]
        self._ip = address[0]
        self._key = key
        self._client_num = 0
        self._client = dict()
        self._compare_time = 172800  ### 48 hours
    def start(self):
        self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -93,15 +98,43 @@ class ControllerServer(object):
                _logger.debug("recv message from {}: [{}]".format(addr, message))
                messages = message.strip('\n').split("\t")
                if (len(messages) < 5) or (messages[0] != self._key):
                    _logger.debug("recv noise from {}: [{}]".format(addr, message))
                    continue
                tokens = messages[1]
                reward = messages[2]
                iter = messages[3]
                client_name = messages[4]
                one_step_time = -1
                if client_name in self._client.keys():
                    current_time = time.time() - self._client[client_name]
                    if current_time > one_step_time:
                        one_step_time = current_time
                        self._compare_time = 2 * one_step_time
                if client_name not in self._client.keys():
                    self._client[client_name] = time.time()
                    self._client_num += 1
                self._client[client_name] = time.time()
                for key_client in list(self._client.keys()):
                    ### if a client has not requested tokens within twice the
                    ### time of one training round, we assume it has stopped
                    if (time.time() - self._client[key_client]
                        ) > self._compare_time and len(self._client.keys()) > 1:
                        self._client.pop(key_client)
                        self._client_num -= 1
                _logger.info(
                    "client: {}, client_num: {}, compare_time: {}".format(
                        self._client, self._client_num, self._compare_time))
                tokens = [int(token) for token in tokens.split(",")]
                self._controller.update(tokens,
                                        float(reward),
                                        int(iter), int(self._client_num))
                response = "ok"
                conn.send(response.encode())
                _logger.debug("send message to {}: [{}]".format(addr,
......
@@ -34,7 +34,7 @@ class SAController(EvolutionaryController):
    def __init__(self,
                 range_table=None,
                 reduce_rate=0.85,
                 init_temperature=None,
                 max_try_times=300,
                 init_tokens=None,
                 reward=-1,
@@ -68,12 +68,20 @@ class SAController(EvolutionaryController):
        self._max_try_times = max_try_times
        self._reward = reward
        self._tokens = init_tokens
        if init_temperature is None:
            if init_tokens is None:
                self._init_temperature = 10.0
            else:
                self._init_temperature = 1.0
        self._constrain_func = constrain_func
        self._max_reward = max_reward
        self._best_tokens = best_tokens
        self._iter = iters
        self._checkpoints = checkpoints
        self._searched = searched if searched is not None else dict()
        self._current_tokens = init_tokens
    def __getstate__(self):
        d = {}
@@ -92,9 +100,9 @@ class SAController(EvolutionaryController):
    @property
    def current_tokens(self):
        return self._current_tokens

    def update(self, tokens, reward, iter, client_num):
        """
        Update the controller according to latest tokens and reward.
        Args:
@@ -105,7 +113,9 @@ class SAController(EvolutionaryController):
        if iter > self._iter:
            self._iter = iter
        self._searched[str(tokens)] = reward
        temperature = self._init_temperature * self._reduce_rate**(client_num *
                                                                   self._iter)
        self._current_tokens = tokens
        if (reward > self._reward) or (np.random.random() <= math.exp(
            (reward - self._reward) / temperature)):
            self._reward = reward
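# In equation form, the acceptance rule above is standard simulated annealing,
# with the temperature schedule scaled by the number of clients
# (T0 = init_temperature, r = reduce_rate, k = iter, n = client_num):
#
#     T_k = T0 * r**(n * k)
#     P(accept) = 1                            if reward > best reward so far
#               = exp((reward - best) / T_k)   otherwise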
@@ -117,6 +127,9 @@ class SAController(EvolutionaryController):
            "Controller - iter: {}; best_reward: {}, best tokens: {}, current_reward: {}; current tokens: {}".
            format(self._iter, self._max_reward, self._best_tokens, reward,
                   tokens))
        _logger.debug(
            'Controller - iter: {}, controller current tokens: {}, controller current reward: {}'.
            format(self._iter, self._tokens, self._reward))
        if self._checkpoints is not None:
            self._save_checkpoint(self._checkpoints)
@@ -137,7 +150,7 @@ class SAController(EvolutionaryController):
        _logger.debug("change index[{}] from {} to {}".format(
            index, tokens[index], new_tokens[index]))
        if str(new_tokens) in self._searched.keys():
            _logger.debug('get next tokens including searched tokens: {}'.
                          format(new_tokens))
            continue
......
@@ -93,6 +93,8 @@ class VarWrapper(object):
            ops.append(op)
        return ops

    def is_parameter(self):
        return isinstance(self._var, Parameter)


class OpWrapper(object):
    def __init__(self, op, graph):
......
@@ -34,7 +34,6 @@ def merge(teacher_program,
                                                  the device that paddle runs on.
        scope(Scope): The input scope.
        name_prefix(str): Name prefix added for all vars of the teacher program.
    """
    teacher_program = teacher_program.clone(for_test=True)
    for teacher_var in teacher_program.list_vars():
@@ -51,7 +50,7 @@ def merge(teacher_program,
            old_var = scope.var(teacher_var.name).get_tensor()
            renamed_var = scope.var(new_name).get_tensor()
            renamed_var.set(np.array(old_var), place)

            # program var rename
            renamed_var = teacher_program.global_block()._rename_var(
                teacher_var.name, new_name)
@@ -84,11 +83,13 @@ def merge(teacher_program,
            attrs[attr_name] = op.attr(attr_name)
        student_program.global_block().append_op(
            type=op.type, inputs=inputs, outputs=outputs, attrs=attrs)


def fsp_loss(teacher_var1_name,
             teacher_var2_name,
             student_var1_name,
             student_var2_name,
             program=fluid.default_main_program()):
    """
    Combine variables from student model and teacher model by fsp-loss.
    Args:
@@ -115,7 +116,8 @@ def fsp_loss(teacher_var1_name,
    return fsp_loss


def l2_loss(teacher_var_name,
            student_var_name,
            program=fluid.default_main_program()):
    """
    Combine variables from student model and teacher model by l2-loss.
......
@@ -18,6 +18,7 @@ import logging
import numpy as np
import json
import hashlib
import time
import paddle.fluid as fluid
from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import SAController
@@ -37,12 +38,13 @@ class SANAS(object):
    def __init__(self,
                 configs,
                 server_addr=("", 8881),
                 init_temperature=None,
                 reduce_rate=0.85,
                 search_steps=300,
                 init_tokens=None,
                 save_checkpoint='nas_checkpoint',
                 load_checkpoint=None,
                 is_server=True):
        """
        Search for a network architecture within the given search space.
        Args:
@@ -50,9 +52,10 @@ class SANAS(object):
            `key` is the name of search space with data type str. `input_size` and `output_size` are
            input size and output size of searched sub-network. `block_num` is the number of blocks in searched network, `block_mask` is a list consisting of 0s and 1s, where 0 means a normal block and 1 means a reduction block.
            server_addr(tuple): A tuple of server ip and server port for controller server.
            init_temperature(float|None): The initial temperature used in the simulated annealing search strategy. Default: None.
            reduce_rate(float): The decay rate used in the simulated annealing search strategy. Default: 0.85.
            search_steps(int): The number of search steps. Default: 300.
            init_tokens(list): Initial tokens that the user can set. Default: None.
            save_checkpoint(string|None): The directory of checkpoint to save; if set to None, no checkpoint is saved. Default: 'nas_checkpoint'.
            load_checkpoint(string|None): The directory of checkpoint to load; if set to None, no checkpoint is loaded. Default: None.
            is_server(bool): Whether the current host is the controller server. Default: True.
@@ -64,7 +67,12 @@ class SANAS(object):
        self._init_temperature = init_temperature
        self._is_server = is_server
        self._configs = configs
        self._init_tokens = init_tokens
        self._client_name = hashlib.md5(
            str(time.time() + np.random.randint(1, 10000)).encode(
                "utf-8")).hexdigest()
        self._key = str(self._configs)
        self._current_tokens = init_tokens

        server_ip, server_port = server_addr
        if server_ip is None or server_ip == "":
@@ -75,7 +83,7 @@ class SANAS(object):
        # create controller server
        if self._is_server:
            init_tokens = self._search_space.init_tokens(self._init_tokens)
            range_table = self._search_space.range_table()
            range_table = (len(range_table) * [0], range_table)
            _logger.info("range table: {}".format(range_table))
@@ -127,7 +135,10 @@ class SANAS(object):
        server_port = self._controller_server.port()

        self._controller_client = ControllerClient(
            server_ip,
            server_port,
            key=self._key,
            client_name=self._client_name)

        if is_server and load_checkpoint is not None:
            self._iter = scene['_iter']
@@ -138,6 +149,11 @@ class SANAS(object):
        return socket.gethostbyname(socket.gethostname())

    def tokens2arch(self, tokens):
        """
        Convert tokens to network architectures.
        Returns:
            list<function>: A list of functions that define networks.
        """
        return self._search_space.token2arch(tokens)

    def current_info(self):
@@ -159,6 +175,7 @@ class SANAS(object):
            list<function>: A list of functions that define networks.
        """
        self._current_tokens = self._controller_client.next_tokens()
        _logger.info("current tokens: {}".format(self._current_tokens))
        archs = self._search_space.token2arch(self._current_tokens)
        return archs
......
@@ -97,16 +97,19 @@ class CombineSearchSpace(object):
            space = cls(input_size, output_size, block_num, block_mask=block_mask)
        return space

    def init_tokens(self, tokens=None):
        """
        Combine init tokens.
        """
        if tokens is None:
            tokens = []
            self.single_token_num = []
            for space in self.spaces:
                tokens.extend(space.init_tokens())
                self.single_token_num.append(len(space.init_tokens()))
            return tokens
        else:
            return tokens

    def range_table(self):
        """
......
@@ -22,7 +22,7 @@ from paddle.fluid.param_attr import ParamAttr
from .search_space_base import SearchSpaceBase
from .base_layer import conv_bn_layer
from .search_space_registry import SEARCHSPACE
from .utils import compute_downsample_num, check_points, get_random_tokens

__all__ = ["InceptionABlockSpace", "InceptionCBlockSpace"]
### TODO add asymmetric kernel of conv when paddle-lite support
@@ -58,10 +58,7 @@ class InceptionABlockSpace(SearchSpaceBase):
        """
        The initial tokens.
        """
        return get_random_tokens(self.range_table())

    def range_table(self):
        """
@@ -290,10 +287,7 @@ class InceptionCBlockSpace(SearchSpaceBase):
        """
        The initial tokens.
        """
        return get_random_tokens(self.range_table())

    def range_table(self):
        """
......
@@ -22,7 +22,7 @@ from paddle.fluid.param_attr import ParamAttr
from .search_space_base import SearchSpaceBase
from .base_layer import conv_bn_layer
from .search_space_registry import SEARCHSPACE
from .utils import compute_downsample_num, check_points, get_random_tokens

__all__ = ["MobileNetV1BlockSpace", "MobileNetV2BlockSpace"]
@@ -60,10 +60,7 @@ class MobileNetV2BlockSpace(SearchSpaceBase):
        self.scale = scale

    def init_tokens(self):
        return get_random_tokens(self.range_table())

    def range_table(self):
        range_table_base = []
@@ -308,10 +305,7 @@ class MobileNetV1BlockSpace(SearchSpaceBase):
        self.scale = scale

    def init_tokens(self):
        return get_random_tokens(self.range_table())

    def range_table(self):
        range_table_base = []
......
@@ -22,7 +22,7 @@ from paddle.fluid.param_attr import ParamAttr
from .search_space_base import SearchSpaceBase
from .base_layer import conv_bn_layer
from .search_space_registry import SEARCHSPACE
from .utils import check_points, get_random_tokens

__all__ = ["ResNetSpace"]
@@ -47,8 +47,7 @@ class ResNetSpace(SearchSpaceBase):
        """
        The initial tokens.
        """
        return [1, 1, 2, 2, 3, 4, 3, 1]

    def range_table(self):
        """
......
@@ -22,7 +22,7 @@ from paddle.fluid.param_attr import ParamAttr
from .search_space_base import SearchSpaceBase
from .base_layer import conv_bn_layer
from .search_space_registry import SEARCHSPACE
from .utils import compute_downsample_num, check_points, get_random_tokens

__all__ = ["ResNetBlockSpace"]
@@ -40,14 +40,11 @@ class ResNetBlockSpace(SearchSpaceBase):
                                                      self.downsample_num, self.block_num)
        self.filter_num = np.array(
            [48, 64, 96, 128, 160, 192, 224, 256, 320, 384, 512, 640])
        self.repeat = np.array([0, 1, 2, 3, 4, 6, 7, 8, 10, 12, 14, 16])
        self.k_size = np.array([3, 5])

    def init_tokens(self):
        return get_random_tokens(self.range_table())

    def range_table(self):
        range_table_base = []
......
@@ -13,6 +13,7 @@
# limitations under the License.

import math
import numpy as np


def compute_downsample_num(input_size, output_size):
@@ -36,3 +37,11 @@ def check_points(count, points):
        return (True if count in points else False)
    else:
        return (True if count == points else False)


def get_random_tokens(range_table):
    tokens = []
    for idx, max_value in enumerate(range_table):
        tokens_idx = int(np.floor(max_value * np.random.rand()))
        tokens.append(tokens_idx)
    return tokens
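# Illustrative usage (output is random): token i is drawn uniformly from
# [0, range_table[i]), e.g.
#
#     get_random_tokens([3, 5, 2])  # -> e.g. [1, 4, 0]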
@@ -23,6 +23,8 @@ from .sensitive_pruner import *
import sensitive_pruner
from .sensitive import *
import sensitive
from .prune_walker import *
import prune_walker

__all__ = []
@@ -32,3 +34,4 @@ __all__ += controller_server.__all__
__all__ += controller_client.__all__
__all__ += sensitive_pruner.__all__
__all__ += sensitive.__all__
__all__ += prune_walker.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
from ..core import Registry
from ..common import get_logger
__all__ = ["PRUNE_WORKER", "conv2d"]
_logger = get_logger(__name__, level=logging.INFO)
PRUNE_WORKER = Registry('prune_worker')
class PruneWorker(object):
def __init__(self, op, pruned_params=[], visited={}):
"""
A wrapper of operator used to infer the information of all the related variables.
Args:
op(Operator): The operator to be pruned.
pruned_params(list): The list to store the information of pruning that infered by walker.
visited(dict): The auxiliary dict to record the visited operators and variables. The key is a encoded string of operator id and variable name.
Return: A instance of PruneWalker.
"""
self.op = op
self.pruned_params = pruned_params
self.visited = visited
def prune(self, var, pruned_axis, pruned_idx):
"""
Infer the shapes of the variables related to the current operator, its predecessors and successors.
It searches the graph for all variables related to `var` and records the pruning information.
Args:
var(Variable): The root variable of the search. It can be an input or output of the current operator.
pruned_axis(int): The axis of the root variable to be pruned.
pruned_idx(list): The indexes to be pruned along `pruned_axis` of the root variable.
"""
if self._visit(var, pruned_axis):
self._prune(var, pruned_axis, pruned_idx)
def _visit(self, var, pruned_axis):
key = "_".join([str(self.op.idx()), var.name()])
if pruned_axis not in self.visited:
self.visited[pruned_axis] = {}
if key in self.visited[pruned_axis]:
return False
else:
self.visited[pruned_axis][key] = True
return True
def _prune(self, var, pruned_axis, pruned_idx):
raise NotImplementedError('Abstract method.')
def _prune_op(self, op, var, pruned_axis, pruned_idx, visited=None):
if op.type().endswith("_grad"):
return
if visited is not None:
self.visited = visited
cls = PRUNE_WORKER.get(op.type())
assert cls is not None, "The walker of {} is not registered.".format(
op.type())
_logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format(
self.op, op, pruned_axis, var.name()))
walker = cls(op,
pruned_params=self.pruned_params,
visited=self.visited)
walker.prune(var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class conv2d(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(conv2d, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
data_format = self.op.attr("data_format")
channel_axis = 1
if data_format == "NHWC":
channel_axis = 3
if var in self.op.inputs("Input"):
assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}; var: {}".format(
pruned_axis, var.name())
filter_var = self.op.inputs("Filter")[0]
self._visit(filter_var, 1)
self.pruned_params.append((filter_var, 1, pruned_idx))
for op in filter_var.outputs():
self._prune_op(op, filter_var, 1, pruned_idx)
elif var in self.op.inputs("Filter"):
assert pruned_axis in [0, 1]
self.pruned_params.append((var, pruned_axis, pruned_idx))
for op in var.outputs():
self._prune_op(op, var, pruned_axis, pruned_idx)
if pruned_axis == 0:
if len(self.op.inputs("Bias")) > 0:
self.pruned_params.append(
(self.op.inputs("Bias")[0], channel_axis, pruned_idx))
output_var = self.op.outputs("Output")[0]
self._visit(output_var, channel_axis)
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
elif pruned_axis == 1:
input_var = self.op.inputs("Input")[0]
self._visit(input_var, channel_axis)
pre_ops = input_var.inputs()
for op in pre_ops:
self._prune_op(op, input_var, channel_axis, pruned_idx)
elif var in self.op.outputs("Output"):
assert pruned_axis == channel_axis, "pruned_axis: {}; var: {}".format(
pruned_axis, var.name())
filter_var = self.op.inputs("Filter")[0]
self._visit(filter_var, 0)
self.pruned_params.append((filter_var, 0, pruned_idx))
for op in filter_var.outputs():
self._prune_op(op, filter_var, 0, pruned_idx)
if len(self.op.inputs("Bias")) > 0:
self.pruned_params.append(
(self.op.inputs("Bias")[0], channel_axis, pruned_idx))
output_var = self.op.outputs("Output")[0]
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
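# A rough usage sketch (hypothetical; assumes `op` is a conv2d OpWrapper taken
# from a GraphWrapper): pruning the output channels of a conv layer kicks off
# the graph traversal like this:
#
#     pruned_params = []
#     worker = PRUNE_WORKER.get(op.type())(op, pruned_params=pruned_params, visited={})
#     worker.prune(op.outputs("Output")[0], pruned_axis=1, pruned_idx=[0, 2, 5])
#     # pruned_params now holds (variable, axis, indices) triples for every
#     # parameter affected by removing those output channels.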
@PRUNE_WORKER.register
class batch_norm(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(batch_norm, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if (var not in self.op.outputs("Y")) and (
var not in self.op.inputs("X")):
return
if var in self.op.outputs("Y"):
in_var = self.op.inputs("X")[0]
self._visit(in_var, pruned_axis)
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
for param in ["Scale", "Bias", "Mean", "Variance"]:
param_var = self.op.inputs(param)[0]
for op in param_var.outputs():
self._prune_op(op, param_var, 0, pruned_idx)
self.pruned_params.append((param_var, 0, pruned_idx))
out_var = self.op.outputs("Y")[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
class elementwise_op(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(elementwise_op, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
axis = self.op.attr("axis")
if axis == -1: # TODO
axis = 0
if var in self.op.outputs("Out"):
for name in ["X", "Y"]:
actual_axis = pruned_axis
if name == "Y":
actual_axis = pruned_axis - axis
in_var = self.op.inputs(name)[0]
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, actual_axis, pruned_idx)
else:
if var in self.op.inputs("X"):
in_var = self.op.inputs("Y")[0]
if in_var.is_parameter():
self.pruned_params.append(
(in_var, pruned_axis - axis, pruned_idx))
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis - axis, pruned_idx)
elif var in self.op.inputs("Y"):
in_var = self.op.inputs("X")[0]
pre_ops = in_var.inputs()
pruned_axis = pruned_axis + axis
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class elementwise_add(elementwise_op):
def __init__(self, op, pruned_params, visited):
super(elementwise_add, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class elementwise_sub(elementwise_op):
def __init__(self, op, pruned_params, visited):
super(elementwise_sub, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class elementwise_mul(elementwise_op):
def __init__(self, op, pruned_params, visited):
super(elementwise_mul, self).__init__(op, pruned_params, visited)
class activation(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(activation, self).__init__(op, pruned_params, visited)
self.input_name = "X"
self.output_name = "Out"
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.outputs(self.output_name):
in_var = self.op.inputs(self.input_name)[0]
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
out_var = self.op.outputs(self.output_name)[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class uniform_random_batch_size_like(activation):
def __init__(self, op, pruned_params, visited):
super(uniform_random_batch_size_like, self).__init__(op, pruned_params,
visited)
self.input_name = "Input"
self.output_name = "Out"
@PRUNE_WORKER.register
class bilinear_interp(activation):
def __init__(self, op, pruned_params, visited):
super(bilinear_interp, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class nearest_interp(activation):
def __init__(self, op, pruned_params, visited):
super(nearest_interp, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class relu(activation):
def __init__(self, op, pruned_params, visited):
super(relu, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class leaky_relu(activation):
def __init__(self, op, pruned_params, visited):
super(leaky_relu, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class floor(activation):
def __init__(self, op, pruned_params, visited):
super(floor, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class relu6(activation):
def __init__(self, op, pruned_params, visited):
super(relu6, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class pool2d(activation):
def __init__(self, op, pruned_params, visited):
super(pool2d, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class sum(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(sum, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.outputs("Out"):
for in_var in self.op.inputs("X"):
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
elif var in self.op.inputs("X"):
for in_var in self.op.inputs("X"):
if in_var != var:
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class concat(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(concat, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
idx = []
axis = self.op.attr("axis")
if var in self.op.outputs("Out"):
start = 0
if axis == pruned_axis:
for _, in_var in enumerate(self.op.inputs("X")):
idx = []
for i in pruned_idx:
r_idx = i - start
if r_idx < in_var.shape()[pruned_axis] and r_idx >= 0:
idx.append(r_idx)
start += in_var.shape()[pruned_axis]
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, idx)
idx = pruned_idx[:]
else:
for _, in_var in enumerate(self.op.inputs("X")):
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
elif var in self.op.inputs("X"):
if axis == pruned_axis:
idx = []
start = 0
for v in self.op.inputs("X"):
if v.name() == var.name():
idx = [i + start for i in pruned_idx]
else:
start += v.shape()[pruned_axis]
out_var = self.op.outputs("Out")[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, idx, visited={})
else:
for v in self.op.inputs("X"):
for op in v.inputs():
self._prune_op(op, v, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
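
# A minimal, self-contained sketch (a hypothetical helper, not part of this
# file) of the index bookkeeping in the concat worker above: pruned channel
# indices on the concat output are mapped back to local indices in each input
# by subtracting the running channel offset.
def _split_concat_idx_example(input_channels, pruned_idx):
    result, start = [], 0
    for num in input_channels:
        result.append(
            [i - start for i in pruned_idx if start <= i < start + num])
        start += num
    return result

# Two inputs with 4 and 6 channels: global indices 3, 4 and 8 map to [3] in
# the first input and [0, 4] in the second.
assert _split_concat_idx_example([4, 6], [3, 4, 8]) == [[3], [0, 4]]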
@PRUNE_WORKER.register
class depthwise_conv2d(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(depthwise_conv2d, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
data_format = self.op.attr("data_format")
channel_axis = 1
if data_format == "NHWC":
channel_axis = 3
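        # A depthwise conv couples input and output channels one-to-one, so any
        # channel pruning also shrinks the filter along axis 0, and the
        # `groups` attribute must be reduced accordingly (see set_attr below).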
if var in self.op.inputs("Input"):
assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format(
pruned_axis)
filter_var = self.op.inputs("Filter")[0]
self.pruned_params.append((filter_var, 0, pruned_idx))
self._visit(filter_var, 0)
new_groups = filter_var.shape()[0] - len(pruned_idx)
self.op.set_attr("groups", new_groups)
for op in filter_var.outputs():
self._prune_op(op, filter_var, 0, pruned_idx)
output_var = self.op.outputs("Output")[0]
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
elif var in self.op.inputs("Filter"):
assert pruned_axis in [0]
if pruned_axis == 0:
if len(self.op.inputs("Bias")) > 0:
                    self.pruned_params.append(
                        (self.op.inputs("Bias")[0], channel_axis, pruned_idx))
self.pruned_params.append((var, 0, pruned_idx))
new_groups = var.shape()[0] - len(pruned_idx)
self.op.set_attr("groups", new_groups)
for op in var.outputs():
self._prune_op(op, var, 0, pruned_idx)
output_var = self.op.outputs("Output")[0]
self._visit(output_var, channel_axis)
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
for op in var.outputs():
self._prune_op(op, var, pruned_axis, pruned_idx)
elif var in self.op.outputs("Output"):
assert pruned_axis == channel_axis
filter_var = self.op.inputs("Filter")[0]
self.pruned_params.append((filter_var, 0, pruned_idx))
self._visit(filter_var, 0)
new_groups = filter_var.shape()[0] - len(pruned_idx)
op.set_attr("groups", new_groups)
for op in filter_var.outputs():
self._prune_op(op, filter_var, 0, pruned_idx)
if len(self.op.inputs("Bias")) > 0:
self.pruned_params.append(
(self.op.inputs("Bias")[0], channel_axis, pruned_idx))
in_var = self.op.inputs("Input")[0]
self._visit(in_var, channel_axis)
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, channel_axis, pruned_idx)
output_var = self.op.outputs("Output")[0]
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
@PRUNE_WORKER.register
class mul(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(mul, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("X"):
            assert pruned_axis == 1, "The Input of mul op can only be pruned at axis 1, but got {}".format(
                pruned_axis)
idx = []
feature_map_size = var.shape()[2] * var.shape()[3]
range_idx = np.array(range(feature_map_size))
for i in pruned_idx:
idx += list(range_idx + i * feature_map_size)
param_var = self.op.inputs("Y")[0]
self.pruned_params.append((param_var, 0, idx))
for op in param_var.outputs():
self._prune_op(op, param_var, 0, pruned_idx)
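
# A standalone illustration (a hypothetical helper, not part of this file) of
# the expansion above: when an NxCxHxW input feeds a `mul` (fully-connected)
# op, pruning channel c removes the H*W consecutive rows [c*H*W, (c+1)*H*W)
# of the weight matrix.
def _expand_channel_idx_example(pruned_idx, feature_map_size):
    idx = []
    range_idx = np.array(range(feature_map_size))
    for i in pruned_idx:
        idx += list(range_idx + i * feature_map_size)
    return idx

# Pruning channel 1 in front of a 2x2 feature map removes weight rows 4..7.
assert _expand_channel_idx_example([1], 4) == [4, 5, 6, 7]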
@PRUNE_WORKER.register
class scale(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(scale, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("X"):
out_var = self.op.outputs("Out")[0]
for op in out_var.outputs():
self._prune_op(op, out_var, pruned_axis, pruned_idx)
elif var in self.op.outputs("Out"):
in_var = self.op.inputs("X")[0]
for op in in_var.inputs():
self._prune_op(op, in_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class momentum(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(momentum, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("Param"):
_logger.debug("pruning momentum, var:{}".format(var.name()))
velocity_var = self.op.inputs("Velocity")[0]
self.pruned_params.append((velocity_var, pruned_axis, pruned_idx))
@PRUNE_WORKER.register
class adam(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(adam, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("Param"):
_logger.debug("pruning momentum, var:{}".format(var.name()))
moment1_var = self.op.inputs("Moment1")[0]
self.pruned_params.append((moment1_var, pruned_axis, pruned_idx))
moment2_var = self.op.inputs("Moment2")[0]
self.pruned_params.append((moment2_var, pruned_axis, pruned_idx))
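
# A hedged sketch of how a new operator type could be supported, following the
# registration pattern used throughout this file. `my_op` and its input/output
# names "X"/"Out" are illustrative assumptions, not an existing Paddle op: the
# worker just forwards pruning through the op, the same strategy the
# `activation` worker uses.
@PRUNE_WORKER.register
class my_op(PruneWorker):
    def __init__(self, op, pruned_params, visited):
        super(my_op, self).__init__(op, pruned_params, visited)

    def _prune(self, var, pruned_axis, pruned_idx):
        if var in self.op.outputs("Out"):
            in_var = self.op.inputs("X")[0]
            # walk backward to the producers of the input
            for op in in_var.inputs():
                self._prune_op(op, in_var, pruned_axis, pruned_idx)
        out_var = self.op.outputs("Out")[0]
        self._visit(out_var, pruned_axis)
        # walk forward to the consumers of the output
        for op in out_var.outputs():
            self._prune_op(op, out_var, pruned_axis, pruned_idx)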
@@ -17,6 +17,7 @@ import numpy as np
import paddle.fluid as fluid
import copy
from ..core import VarWrapper, OpWrapper, GraphWrapper
from .prune_walker import conv2d as conv2d_walker
from ..common import get_logger

__all__ = ["Pruner"]
@@ -67,561 +68,66 @@ class Pruner():
        graph = GraphWrapper(program.clone())
        param_backup = {} if param_backup else None
        param_shape_backup = {} if param_shape_backup else None

    # --- removed code: the old implementation of Pruner (deleted by this commit) follows ---
self._prune_parameters(
graph,
scope,
params,
ratios,
place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
for op in graph.ops():
if op.type() == 'depthwise_conv2d' or op.type(
) == 'depthwise_conv2d_grad':
op.set_attr('groups', op.inputs('Filter')[0].shape()[0])
return graph.program, param_backup, param_shape_backup
def _prune_filters_by_ratio(self,
scope,
params,
ratio,
place,
lazy=False,
only_graph=False,
param_shape_backup=None,
param_backup=None):
"""
Pruning filters by given ratio.
Args:
            scope(fluid.core.Scope): The scope used to prune filters.
params(list<VarWrapper>): A list of filter parameters.
ratio(float): The ratio to be pruned.
place(fluid.Place): The device place of filter parameters.
lazy(bool): True means setting the pruned elements to zero.
False means cutting down the pruned elements.
only_graph(bool): True means only modifying the graph.
False means modifying graph and variables in scope.
"""
if params[0].name() in self.pruned_list[0]:
return
if only_graph:
pruned_num = int(round(params[0].shape()[0] * ratio))
for param in params:
ori_shape = param.shape()
if param_backup is not None and (
param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(ori_shape)
new_shape = list(ori_shape)
new_shape[0] -= pruned_num
param.set_shape(new_shape)
_logger.debug("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[0].append(param.name())
return range(pruned_num)
else:
param_t = scope.find_var(params[0].name()).get_tensor()
pruned_idx = self._cal_pruned_idx(
params[0].name(), np.array(param_t), ratio, axis=0)
for param in params:
assert isinstance(param, VarWrapper)
param_t = scope.find_var(param.name()).get_tensor()
if param_backup is not None and (
param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(
np.array(param_t))
try:
pruned_param = self._prune_tensor(
np.array(param_t),
pruned_idx,
pruned_axis=0,
lazy=lazy)
except IndexError as e:
_logger.error("Pruning {}, but get [{}]".format(param.name(
), e))
param_t.set(pruned_param, place)
ori_shape = param.shape()
if param_shape_backup is not None and (
param.name() not in param_shape_backup):
param_shape_backup[param.name()] = copy.deepcopy(
param.shape())
new_shape = list(param.shape())
new_shape[0] = pruned_param.shape[0]
param.set_shape(new_shape)
_logger.debug("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[0].append(param.name())
return pruned_idx
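
    # `_cal_pruned_idx` itself is truncated in this diff; the following sketch
    # shows one plausible magnitude-based criterion (smallest L1 norm per
    # filter along `axis`). It is an assumption for illustration, not the
    # exact truncated implementation.
    @staticmethod
    def _cal_pruned_idx_example(param, ratio, axis=0):
        # param is a numpy array of filter weights
        prune_num = int(round(param.shape[axis] * ratio))
        reduce_dims = tuple(i for i in range(param.ndim) if i != axis)
        criterions = np.sum(np.abs(param), axis=reduce_dims)
        return np.argsort(criterions)[:prune_num].tolist()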
def _prune_parameter_by_idx(self,
scope,
params,
pruned_idx,
pruned_axis,
place,
lazy=False,
only_graph=False,
param_shape_backup=None,
param_backup=None):
"""
Pruning parameters in given axis.
Args:
            scope(fluid.core.Scope): The scope storing parameters to be pruned.
            params(list<VarWrapper>): The parameters to be pruned.
pruned_idx(list): The index of elements to be pruned.
pruned_axis(int): The pruning axis.
place(fluid.Place): The device place of filter parameters.
lazy(bool): True means setting the pruned elements to zero.
False means cutting down the pruned elements.
only_graph(bool): True means only modifying the graph.
False means modifying graph and variables in scope.
"""
if params[0].name() in self.pruned_list[pruned_axis]:
return
if only_graph:
pruned_num = len(pruned_idx)
for param in params:
ori_shape = param.shape()
if param_backup is not None and (
param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(ori_shape)
new_shape = list(ori_shape)
new_shape[pruned_axis] -= pruned_num
param.set_shape(new_shape)
_logger.debug("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[pruned_axis].append(param.name())
else:
for param in params:
assert isinstance(param, VarWrapper)
param_t = scope.find_var(param.name()).get_tensor()
if param_backup is not None and (
param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(
np.array(param_t))
pruned_param = self._prune_tensor(
np.array(param_t), pruned_idx, pruned_axis, lazy=lazy)
param_t.set(pruned_param, place)
ori_shape = param.shape()
if param_shape_backup is not None and (
param.name() not in param_shape_backup):
param_shape_backup[param.name()] = copy.deepcopy(
param.shape())
new_shape = list(param.shape())
new_shape[pruned_axis] = pruned_param.shape[pruned_axis]
param.set_shape(new_shape)
_logger.debug("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[pruned_axis].append(param.name())
def _forward_search_related_op(self, graph, node):
"""
Forward search operators that will be affected by pruning of param.
Args:
graph(GraphWrapper): The graph to be searched.
node(VarWrapper|OpWrapper): The current pruned parameter or operator.
Returns:
list<OpWrapper>: A list of operators.
"""
        visited = {}
        for op in graph.ops():
visited[op.idx()] = False
stack = []
visit_path = []
if isinstance(node, VarWrapper):
for op in graph.ops():
if (not op.is_bwd_op()) and (node in op.all_inputs()):
next_ops = self._get_next_unvisited_op(graph, visited, op)
# visit_path.append(op)
visited[op.idx()] = True
for next_op in next_ops:
if visited[next_op.idx()] == False:
stack.append(next_op)
visit_path.append(next_op)
visited[next_op.idx()] = True
elif isinstance(node, OpWrapper):
next_ops = self._get_next_unvisited_op(graph, visited, node)
for next_op in next_ops:
if visited[next_op.idx()] == False:
stack.append(next_op)
visit_path.append(next_op)
visited[next_op.idx()] = True
while len(stack) > 0:
#top_op = stack[len(stack) - 1]
top_op = stack.pop(0)
next_ops = None
if top_op.type() in ["conv2d", "deformable_conv"]:
next_ops = None
elif top_op.type() in ["mul", "concat"]:
next_ops = None
else:
next_ops = self._get_next_unvisited_op(graph, visited, top_op)
if next_ops != None:
for op in next_ops:
if visited[op.idx()] == False:
stack.append(op)
visit_path.append(op)
visited[op.idx()] = True
return visit_path
def _get_next_unvisited_op(self, graph, visited, top_op):
"""
Get next unvisited adjacent operators of given operators.
Args:
graph(GraphWrapper): The graph used to search.
            visited(dict): Maps operator ids to whether they have been visited.
top_op: The given operator.
Returns:
list<OpWrapper>: A list of operators.
"""
assert isinstance(top_op, OpWrapper)
next_ops = []
for op in graph.next_ops(top_op):
if (visited[op.idx()] == False) and (not op.is_bwd_op()):
next_ops.append(op)
return next_ops
def _get_accumulator(self, graph, param):
"""
        Get the accumulators of the given parameter. The accumulators are created by the optimizer.
Args:
graph(GraphWrapper): The graph used to search.
param(VarWrapper): The given parameter.
Returns:
list<VarWrapper>: A list of accumulators which are variables.
"""
assert isinstance(param, VarWrapper)
params = []
for op in param.outputs():
if op.is_opt_op():
for out_var in op.all_outputs():
if graph.is_persistable(out_var) and out_var.name(
) != param.name():
params.append(out_var)
return params
def _forward_pruning_ralated_params(self,
graph,
scope,
param,
place,
ratio=None,
pruned_idxs=None,
lazy=False,
only_graph=False,
param_backup=None,
param_shape_backup=None):
"""
Pruning all the parameters affected by the pruning of given parameter.
Args:
graph(GraphWrapper): The graph to be searched.
            scope(fluid.core.Scope): The scope storing parameters to be pruned.
param(VarWrapper): The given parameter.
place(fluid.Place): The device place of filter parameters.
ratio(float): The target ratio to be pruned.
            pruned_idxs(list): The indices of elements to be pruned.
lazy(bool): True means setting the pruned elements to zero.
False means cutting down the pruned elements.
only_graph(bool): True means only modifying the graph.
False means modifying graph and variables in scope.
"""
assert isinstance(
graph,
GraphWrapper), "graph must be instance of slim.core.GraphWrapper"
assert isinstance(
param,
VarWrapper), "param must be instance of slim.core.VarWrapper"
if param.name() in self.pruned_list[0]:
return
related_ops = self._forward_search_related_op(graph, param)
for op in related_ops:
_logger.debug("relate op: {};".format(op))
if ratio is None:
assert pruned_idxs is not None
self._prune_parameter_by_idx(
scope, [param] + self._get_accumulator(graph, param),
pruned_idxs,
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
else:
pruned_idxs = self._prune_filters_by_ratio(
scope, [param] + self._get_accumulator(graph, param),
ratio,
place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
self._prune_ops(related_ops, pruned_idxs, graph, scope, place, lazy,
only_graph, param_backup, param_shape_backup)
def _prune_ops(self, ops, pruned_idxs, graph, scope, place, lazy,
only_graph, param_backup, param_shape_backup):
for idx, op in enumerate(ops):
if op.type() in ["conv2d", "deformable_conv"]:
for in_var in op.all_inputs():
if graph.is_parameter(in_var):
conv_param = in_var
self._prune_parameter_by_idx(
scope, [conv_param] + self._get_accumulator(
graph, conv_param),
pruned_idxs,
pruned_axis=1,
place=place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
if op.type() == "depthwise_conv2d":
for in_var in op.all_inputs():
if graph.is_parameter(in_var):
conv_param = in_var
self._prune_parameter_by_idx(
scope, [conv_param] + self._get_accumulator(
graph, conv_param),
pruned_idxs,
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
elif op.type() == "elementwise_add":
# pruning bias
for in_var in op.all_inputs():
if graph.is_parameter(in_var):
bias_param = in_var
self._prune_parameter_by_idx(
scope, [bias_param] + self._get_accumulator(
graph, bias_param),
pruned_idxs,
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
elif op.type() == "mul": # pruning fc layer
fc_input = None
fc_param = None
for in_var in op.all_inputs():
if graph.is_parameter(in_var):
fc_param = in_var
else:
fc_input = in_var
idx = []
feature_map_size = fc_input.shape()[2] * fc_input.shape()[3]
range_idx = np.array(range(feature_map_size))
for i in pruned_idxs:
idx += list(range_idx + i * feature_map_size)
corrected_idxs = idx
self._prune_parameter_by_idx(
scope, [fc_param] + self._get_accumulator(graph, fc_param),
corrected_idxs,
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
elif op.type() == "concat":
concat_inputs = op.all_inputs()
last_op = ops[idx - 1]
concat_idx = None
for last_op in reversed(ops):
for out_var in last_op.all_outputs():
if out_var in concat_inputs:
concat_idx = concat_inputs.index(out_var)
break
if concat_idx is not None:
break
offset = 0
for ci in range(concat_idx):
offset += concat_inputs[ci].shape()[1]
corrected_idxs = [x + offset for x in pruned_idxs]
related_ops = self._forward_search_related_op(graph, op)
for op in related_ops:
_logger.debug("concat relate op: {};".format(op))
self._prune_ops(related_ops, corrected_idxs, graph, scope,
place, lazy, only_graph, param_backup,
param_shape_backup)
elif op.type() == "batch_norm":
bn_inputs = op.all_inputs()
in_num = len(bn_inputs)
beta = bn_inputs[0]
mean = bn_inputs[1]
alpha = bn_inputs[2]
variance = bn_inputs[3]
self._prune_parameter_by_idx(
scope, [mean] + self._get_accumulator(graph, mean),
pruned_idxs,
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
self._prune_parameter_by_idx(
scope, [variance] + self._get_accumulator(graph, variance),
pruned_idxs,
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
self._prune_parameter_by_idx(
scope, [alpha] + self._get_accumulator(graph, alpha),
pruned_idxs,
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
self._prune_parameter_by_idx(
scope, [beta] + self._get_accumulator(graph, beta),
pruned_idxs,
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
def _prune_parameters(self,
graph,
scope,
params,
ratios,
place,
lazy=False,
only_graph=False,
param_backup=None,
param_shape_backup=None):
"""
Pruning the given parameters.
Args:
graph(GraphWrapper): The graph to be searched.
            scope(fluid.core.Scope): The scope storing parameters to be pruned.
            params(list<str>): A list of parameter names to be pruned.
            ratios(list<float>): A list of ratios used to prune the parameters.
            place(fluid.Place): The device place of filter parameters.
lazy(bool): True means setting the pruned elements to zero.
False means cutting down the pruned elements.
only_graph(bool): True means only modifying the graph.
False means modifying graph and variables in scope.
"""
assert len(params) == len(ratios)
self.pruned_list = [[], []]
        for param, ratio in zip(params, ratios):
            assert isinstance(param, str) or isinstance(param, unicode)
            if param in self.pruned_list[0]:
                _logger.info("Skip {}".format(param))
                continue
            _logger.info("pruning param: {}".format(param))
            param = graph.var(param)
            self._forward_pruning_ralated_params(
                graph,
                scope,
                param,
                place,
                ratio=ratio,
                lazy=lazy,
                only_graph=only_graph,
                param_backup=param_backup,
                param_shape_backup=param_shape_backup)
            ops = param.outputs()
            for op in ops:
                if op.type() in ['conv2d', 'deformable_conv']:
                    brother_ops = self._search_brother_ops(graph, op)
                    for brother in brother_ops:
                        _logger.debug("pruning brother: {}".format(brother))
                        for p in graph.get_param_by_op(brother):
                            self._forward_pruning_ralated_params(
                                graph,
                                scope,
                                p,
                                place,
                                ratio=ratio,
                                lazy=lazy,
                                only_graph=only_graph,
                                param_backup=param_backup,
                                param_shape_backup=param_shape_backup)

    def _search_brother_ops(self, graph, op_node):
        """
        Search brother operators that were affected by the pruning of the given operator.
        Args:
            graph(GraphWrapper): The graph to be searched.
            op_node(OpWrapper): The start node for searching.
        Returns:
            list<OpWrapper>: A list of operators.
        """
        _logger.debug("######################search: {}######################".
                      format(op_node))
        visited = [op_node.idx()]
        stack = []
        brothers = []
        for op in graph.next_ops(op_node):
            if ("conv2d" not in op.type()) and (
                    "concat" not in op.type()) and (
                        "deformable_conv" not in op.type()) and (
                            op.type() != 'fc') and (
                                not op.is_bwd_op()) and (not op.is_opt_op()):
                stack.append(op)
                visited.append(op.idx())
        while len(stack) > 0:
            top_op = stack.pop()
            for parent in graph.pre_ops(top_op):
                if parent.idx() not in visited and (
                        not parent.is_bwd_op()) and (not parent.is_opt_op()):
                    _logger.debug("----------go back from {} to {}----------".
                                  format(top_op, parent))
                    if (('conv2d' in parent.type()) or
                        ("deformable_conv" in parent.type()) or
                            (parent.type() == 'fc')):
                        brothers.append(parent)
                    else:
                        stack.append(parent)
                    visited.append(parent.idx())
            for child in graph.next_ops(top_op):
                if ('conv2d' not in child.type()) and (
                        "concat" not in child.type()) and (
                            'deformable_conv' not in child.type()) and (
                                child.type() != 'fc') and (
                                    child.idx() not in visited) and (
                                        not child.is_bwd_op()) and (
                                            not child.is_opt_op()):
                    stack.append(child)
                    visited.append(child.idx())
        _logger.debug("brothers: {}".format(brothers))
        _logger.debug(
            "######################Finish search######################")
        return brothers

    # --- added code: the new walker-based implementation of Pruner.prune() resumes from the preamble above ---
        visited = {}
        pruned_params = []
        for param, ratio in zip(params, ratios):
            if only_graph:
                param_v = graph.var(param)
                pruned_num = int(round(param_v.shape()[0] * ratio))
                pruned_idx = [0] * pruned_num
            else:
                param_t = np.array(scope.find_var(param).get_tensor())
                pruned_idx = self._cal_pruned_idx(param_t, ratio, axis=0)
            param = graph.var(param)
            conv_op = param.outputs()[0]
            walker = conv2d_walker(
                conv_op, pruned_params=pruned_params, visited=visited)
            walker.prune(param, pruned_axis=0, pruned_idx=pruned_idx)

        merge_pruned_params = {}
        for param, pruned_axis, pruned_idx in pruned_params:
            if param.name() not in merge_pruned_params:
                merge_pruned_params[param.name()] = {}
            if pruned_axis not in merge_pruned_params[param.name()]:
                merge_pruned_params[param.name()][pruned_axis] = []
            merge_pruned_params[param.name()][pruned_axis].append(pruned_idx)

        for param_name in merge_pruned_params:
            for pruned_axis in merge_pruned_params[param_name]:
                pruned_idx = np.concatenate(merge_pruned_params[param_name][
                    pruned_axis])
                param = graph.var(param_name)
                if not lazy:
                    _logger.debug("{}\t{}\t{}".format(
                        param.name(), pruned_axis, len(pruned_idx)))
                    if param_shape_backup is not None:
                        origin_shape = copy.deepcopy(param.shape())
                        param_shape_backup[param.name()] = origin_shape
                    new_shape = list(param.shape())
                    new_shape[pruned_axis] -= len(pruned_idx)
                    param.set_shape(new_shape)
                if not only_graph:
                    param_t = scope.find_var(param.name()).get_tensor()
                    if param_backup is not None and (
                            param.name() not in param_backup):
                        param_backup[param.name()] = copy.deepcopy(
                            np.array(param_t))
                    try:
                        pruned_param = self._prune_tensor(
                            np.array(param_t),
                            pruned_idx,
                            pruned_axis=pruned_axis,
                            lazy=lazy)
                    except IndexError as e:
                        _logger.error("Pruning {}, but get [{}]".format(
                            param.name(), e))
                    param_t.set(pruned_param, place)
        graph.update_groups_of_conv()
        return graph.program, param_backup, param_shape_backup

    def _cal_pruned_idx(self, param, ratio, axis):
        """
        Calculate the index to be pruned on axis by given pruning ratio.
        Args:
......
@@ -13,6 +13,8 @@
# limitations under the License.

import copy
import logging

import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
@@ -24,22 +26,37 @@ from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
from paddle.fluid import core
from ..common import get_logger

_logger = get_logger(__name__, level=logging.INFO)

WEIGHT_QUANTIZATION_TYPES = [
    'abs_max', 'channel_wise_abs_max', 'range_abs_max',
    'moving_average_abs_max'
]
WEIGHT_QUANTIZATION_TYPES_TENSORRT = ['channel_wise_abs_max']

ACTIVATION_QUANTIZATION_TYPES = [
    'abs_max', 'range_abs_max', 'moving_average_abs_max'
]
ACTIVATION_QUANTIZATION_TYPES_TENSORRT = [
    'range_abs_max', 'moving_average_abs_max'
]

VALID_DTYPES = ['int8']
TRANSFORM_PASS_OP_TYPES = QuantizationTransformPass._supported_quantizable_op_type
QUANT_DEQUANT_PASS_OP_TYPES = AddQuantDequantPass._supported_quantizable_op_type + \
    AddQuantDequantPass._activation_type

TENSORRT_OP_TYPES = [
    'mul', 'conv2d', 'pool2d', 'depthwise_conv2d', 'elementwise_add',
    'leaky_relu'
]

_quant_config_default = {
    # weight quantize type, default is 'channel_wise_abs_max'
    'weight_quantize_type': 'channel_wise_abs_max',
    # activation quantize type, default is 'moving_average_abs_max'
    'activation_quantize_type': 'moving_average_abs_max',
    # weight quantize bit num, default is 8
    'weight_bits': 8,
    # activation quantize bit num, default is 8
@@ -47,25 +64,25 @@ _quant_config_default = {
    # ops of name_scope in not_quant_pattern list, will not be quantized
    'not_quant_pattern': ['skip_quant'],
    # ops of type in quantize_op_types, will be quantized
    'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
    # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
    'dtype': 'int8',
    # window size for 'range_abs_max' quantization. default is 10000
    'window_size': 10000,
    # The decay coefficient of moving average, default is 0.9
    'moving_rate': 0.9,
    # if True, 'quantize_op_types' will be TENSORRT_OP_TYPES
    'for_tensorrt': False,
    # if True, 'quantize_op_types' will be TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES
    'is_full_quantize': False
}
def _parse_configs(user_config):
    """
    Check if user's configs are valid, and fill in defaults where not configured.
    Args:
        user_config(dict): user's config.
    Return:
        configs(dict): final configs will be used.
    """
@@ -73,12 +90,26 @@ def _parse_configs(user_config):
    configs = copy.deepcopy(_quant_config_default)
    configs.update(user_config)

    assert isinstance(configs['for_tensorrt'], bool) and isinstance(
        configs['is_full_quantize'],
        bool), "'for_tensorrt' and 'is_full_quantize' must both be bool"

    # check if configs is valid
    if configs['for_tensorrt']:
        weight_types = WEIGHT_QUANTIZATION_TYPES_TENSORRT
        activation_types = ACTIVATION_QUANTIZATION_TYPES_TENSORRT
        platform = 'TensorRT'
    else:
        weight_types = WEIGHT_QUANTIZATION_TYPES
        activation_types = ACTIVATION_QUANTIZATION_TYPES
        platform = 'PaddleLite'
    assert configs['weight_quantize_type'] in weight_types, \
        "Unknown weight_quantize_type: {}. {} only supports {}".format(
            configs['weight_quantize_type'], platform, weight_types)

    assert configs['activation_quantize_type'] in activation_types, \
        "Unknown activation_quantize_type: {}. {} only supports {}".format(
            configs['activation_quantize_type'], platform, activation_types)

    assert isinstance(configs['weight_bits'], int), \
        "weight_bits must be int value."
@@ -92,17 +123,24 @@ def _parse_configs(user_config):
    assert (configs['activation_bits'] >= 1 and configs['activation_bits'] <= 16), \
        "activation_bits should be between 1 and 16."

    assert isinstance(configs['not_quant_pattern'], (list, str)), \
        "not_quant_pattern must be list or str"

    assert isinstance(configs['quantize_op_types'], list), \
        "quantize_op_types must be a list"

    if configs['for_tensorrt']:
        configs['quantize_op_types'] = TENSORRT_OP_TYPES
    elif configs['is_full_quantize']:
        configs[
            'quantize_op_types'] = TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES
    else:
        for op_type in configs['quantize_op_types']:
            assert (op_type in QUANT_DEQUANT_PASS_OP_TYPES) or (
                op_type in TRANSFORM_PASS_OP_TYPES), "{} is not supported, \
                now supported op types are {}".format(
                    op_type,
                    TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES)

    assert isinstance(configs['dtype'], str), \
        "dtype must be a str."
@@ -116,36 +154,31 @@ def _parse_configs(user_config):
    assert isinstance(configs['moving_rate'], float), \
        "moving_rate must be float value, The decay coefficient of moving average, default is 0.9."

    return configs
def quant_aware(program, place, config=None, scope=None, for_test=False):
    """
    Add trainable quantization ops into program.
    Args:
        program(fluid.Program): program to quant
        place(fluid.CPUPlace or fluid.CUDAPlace): CPU or CUDA device
        config(dict, optional): configs for quantization. if None, will use default config. Default is None.
        scope(fluid.Scope): the scope to store var, it should be program's scope. if None, will use fluid.global_scope().
            default is None.
        for_test(bool): set True when the program is a test program, False when it is for training. Default is False.
    Return:
        fluid.Program: user can finetune this quantization program to enhance the accuracy.
    """

    scope = fluid.global_scope() if not scope else scope
    if config is None:
        config = _quant_config_default
    else:
        assert isinstance(config, dict), "config must be dict"
        config = _parse_configs(config)
    _logger.info("quant_aware config {}".format(config))

    main_graph = IrGraph(core.Graph(program.desc), for_test=for_test)

    transform_pass_ops = []
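
# A hedged usage sketch for quant_aware (illustrative, not part of this
# module). The tiny network and hyperparameters are made-up; only the
# quant_aware call itself follows the API defined above.
import paddle.fluid as fluid
from paddleslim.quant import quant_aware

main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    image = fluid.data(name='image', shape=[None, 1, 28, 28])
    label = fluid.data(name='label', shape=[None, 1], dtype='int64')
    out = fluid.layers.fc(input=image, size=10, act='softmax')
    loss = fluid.layers.mean(
        fluid.layers.cross_entropy(input=out, label=label))
    fluid.optimizer.SGD(learning_rate=0.001).minimize(loss)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)

# config=None falls back to _quant_config_default, as implemented above.
quant_train_prog = quant_aware(main_prog, place, config=None, for_test=False)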
@@ -197,7 +230,10 @@ def quant_post(executor,
                   batch_nums=None,
                   scope=None,
                   algo='KL',
                   quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
                   is_full_quantize=False,
                   is_use_cache_file=False,
                   cache_dir="./temp_post_training"):
    """
    The function utilizes post training quantization method to quantize the
    fp32 model. It uses calibrate data to calculate the scale factor of
@@ -232,6 +268,11 @@ def quant_post(executor,
        quantizable_op_type(list[str], optional): The list of op types
            that will be quantized. Default is ["conv2d", "depthwise_conv2d",
            "mul"].
        is_full_quantize(bool): if True, apply quantization to all supported quantizable op types.
            If False, only apply quantization to the input quantizable_op_type. Default is False.
        is_use_cache_file(bool): If False, all temp data will be saved in memory. If True,
            all temp data will be saved to disk. Default is False.
        cache_dir(str): When 'is_use_cache_file' is True, temp data will be saved in 'cache_dir'. Default is './temp_post_training'.
    Returns:
        None
    """
@@ -246,41 +287,64 @@ def quant_post(executor,
        scope=scope,
        algo=algo,
        quantizable_op_type=quantizable_op_type,
        is_full_quantize=is_full_quantize,
        is_use_cache_file=is_use_cache_file,
        cache_dir=cache_dir)
    post_training_quantization.quantize()
    post_training_quantization.save_quantized_model(quantize_model_path)
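
# A hedged sketch of post-training quantization with quant_post (illustrative,
# not part of this module). The model directory, output path, and calibration
# reader are placeholders; the leading parameters (model_dir,
# quantize_model_path, sample_generator, batch_size) are truncated in this
# diff and assumed from the surrounding docstring.
import numpy as np
import paddle.fluid as fluid
from paddleslim.quant import quant_post

def calib_reader():
    # Replace with a reader over real calibration data; each sample must
    # match the inputs of the saved inference model.
    for _ in range(32):
        yield (np.random.random((1, 28, 28)).astype('float32'), )

exe = fluid.Executor(fluid.CPUPlace())
quant_post(
    executor=exe,
    model_dir='./float32_model',  # placeholder path
    quantize_model_path='./int8_model',  # placeholder path
    sample_generator=calib_reader,
    batch_size=16,
    batch_nums=2,
    is_full_quantize=False,
    is_use_cache_file=False)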
def convert(program, place, config=None, scope=None, save_int8=False):
    """
    Change the order of quantization ops in the program so that the result can be used by Paddle-Lite.
    Args:
        program(fluid.Program): the program returned by quant_aware
        place(fluid.CPUPlace or fluid.CUDAPlace): CPU or CUDA device
        scope(fluid.Scope, optional): the scope to store var, it should be program's scope. if None, will use fluid.global_scope().
            default is None.
        config(dict, optional): configs for convert. if set None, will use default config. Default is None.
            It must be the same config that was used in 'quant_aware'.
        save_int8(bool): if True, also return an int8 freezed program. The int8 program can only be used to check the size of model weights.
            It cannot be used in Fluid or Paddle-Lite.
    Return:
        freezed_program(fluid.Program): freezed program which can be used for inference.
            parameters are float32 type, but their values are in the int8 range.
        freezed_program_int8(fluid.Program): freezed int8 program.
            when save_int8 is False, only freezed_program is returned.
            when save_int8 is True, freezed_program and freezed_program_int8 are returned.
    """
    scope = fluid.global_scope() if not scope else scope

    if config is None:
        config = _quant_config_default
    else:
        assert isinstance(config, dict), "config must be dict"
        config = _parse_configs(config)
    _logger.info("convert config {}".format(config))

    test_graph = IrGraph(core.Graph(program.desc), for_test=True)
    support_op_types = []
    for op in config['quantize_op_types']:
        if op in QuantizationFreezePass._supported_quantizable_op_type:
            support_op_types.append(op)

    # Freeze the graph after training by adjusting the quantize
    # operators' order for the inference.
    freeze_pass = QuantizationFreezePass(
        scope=scope,
        place=place,
        weight_bits=config['weight_bits'],
        activation_bits=config['activation_bits'],
        weight_quantize_type=config['weight_quantize_type'],
        quantizable_op_type=support_op_types)
    freeze_pass.apply(test_graph)
    freezed_program = test_graph.to_program()

    if save_int8:
        convert_int8_pass = ConvertToInt8Pass(
            scope=fluid.global_scope(),
            place=place,
            quantizable_op_type=support_op_types)
        convert_int8_pass.apply(test_graph)
        freezed_program_int8 = test_graph.to_program()
        return freezed_program, freezed_program_int8
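
# A hedged end-to-end sketch tying quant_aware and convert together
# (illustrative, not part of this module). The program and weights are
# placeholders; in practice the quantization scales come from quant-aware
# finetuning before convert is called.
import paddle.fluid as fluid
from paddleslim.quant import quant_aware, convert

test_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(test_prog, startup_prog):
    image = fluid.data(name='image', shape=[None, 1, 28, 28])
    out = fluid.layers.fc(input=image, size=10, act='softmax')

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)

# Insert fake-quant ops for evaluation, then freeze the graph for inference.
quant_eval_prog = quant_aware(test_prog, place, config=None, for_test=True)
freezed_prog, freezed_prog_int8 = convert(
    quant_eval_prog, place, config=None, save_int8=True)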
......
@@ -32,17 +32,6 @@ max_version, mid_version, min_version = python_version()
with open('./requirements.txt') as f:
    setup_requires = f.read().splitlines()
packages = [
'paddleslim',
'paddleslim.prune',
'paddleslim.dist',
'paddleslim.nas',
'paddleslim.analysis',
'paddleslim.quant',
'paddleslim.core',
'paddleslim.common',
]
setup(
    name='paddleslim',
    version=slim_version,
@@ -52,7 +41,7 @@ setup(
    author='PaddlePaddle Author',
    author_email='dltp-all@baidu.com',
    install_requires=setup_requires,
    packages=find_packages(),
    # PyPI package information.
    classifiers=[
        'Development Status :: 4 - Beta',
......
@@ -15,7 +15,7 @@ import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from paddleslim.prune.walk_pruner import Pruner
from layers import conv_bn_layer
@@ -72,6 +72,7 @@ class TestPrune(unittest.TestCase):
        for param in main_program.global_block().all_parameters():
            if "weights" in param.name:
                print("param: {}; param shape: {}".format(param.name, param.shape))
                self.assertTrue(param.shape == shapes[param.name])
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from paddleslim.prune import Pruner
from paddleslim.core import GraphWrapper
from paddleslim.prune import conv2d as conv2d_walker
from layers import conv_bn_layer
class TestPrune(unittest.TestCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
# X X O X O
# conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
# | ^ | ^
# |____________| |____________________|
#
# X: prune output channels
# O: prune input channels
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
shapes = {}
for param in main_program.global_block().all_parameters():
shapes[param.name] = param.shape
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
exe.run(startup_program, scope=scope)
graph = GraphWrapper(main_program)
conv_op = graph.var("conv4_weights").outputs()[0]
walker = conv2d_walker(conv_op, [])
walker.prune(graph.var("conv4_weights"), pruned_axis=0, pruned_idx=[])
        print(walker.pruned_params)
if __name__ == '__main__':
unittest.main()