Unverified Commit 790a9ffb authored by: W whs committed by: GitHub

Add pruning walker. (#5)

Parent cc1cd6f4
from .mobilenet import MobileNet
from .resnet import ResNet34, ResNet50
from .mobilenet_v2 import MobileNetV2
from .pvanet import PVANet
__all__ = ['MobileNet', 'ResNet34', 'ResNet50', 'MobileNetV2']
__all__ = ['MobileNet', 'ResNet34', 'ResNet50', 'MobileNetV2', 'PVANet']
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import MSRA
import os, sys, time, math
import numpy as np
from collections import namedtuple
BLOCK_TYPE_MCRELU = 'BLOCK_TYPE_MCRELU'
BLOCK_TYPE_INCEP = 'BLOCK_TYPE_INCEP'
BlockConfig = namedtuple('BlockConfig',
'stride, num_outputs, preact_bn, block_type')
__all__ = ['PVANet']
class PVANet():
def __init__(self):
pass
def net(self, input, include_last_bn_relu=True, class_dim=1000):
conv1 = self._conv_bn_crelu(input, 16, 7, stride=2, name="conv1_1")
pool1 = fluid.layers.pool2d(
input=conv1,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max',
name='pool1')
end_points = {}
conv2 = self._conv_stage(
pool1,
block_configs=[
BlockConfig(1, (24, 24, 48), False, BLOCK_TYPE_MCRELU),
BlockConfig(1, (24, 24, 48), True, BLOCK_TYPE_MCRELU),
BlockConfig(1, (24, 24, 48), True, BLOCK_TYPE_MCRELU)
],
name='conv2',
end_points=end_points)
conv3 = self._conv_stage(
conv2,
block_configs=[
BlockConfig(2, (48, 48, 96), True, BLOCK_TYPE_MCRELU),
BlockConfig(1, (48, 48, 96), True, BLOCK_TYPE_MCRELU),
BlockConfig(1, (48, 48, 96), True, BLOCK_TYPE_MCRELU),
BlockConfig(1, (48, 48, 96), True, BLOCK_TYPE_MCRELU)
],
name='conv3',
end_points=end_points)
conv4 = self._conv_stage(
conv3,
block_configs=[
BlockConfig(2, '64 48-96 24-48-48 96 128', True,
BLOCK_TYPE_INCEP),
BlockConfig(1, '64 64-96 24-48-48 128', True,
BLOCK_TYPE_INCEP),
BlockConfig(1, '64 64-96 24-48-48 128', True,
BLOCK_TYPE_INCEP),
BlockConfig(1, '64 64-96 24-48-48 128', True, BLOCK_TYPE_INCEP)
],
name='conv4',
end_points=end_points)
conv5 = self._conv_stage(
conv4,
block_configs=[
BlockConfig(2, '64 96-128 32-64-64 128 196', True,
BLOCK_TYPE_INCEP),
BlockConfig(1, '64 96-128 32-64-64 196', True,
BLOCK_TYPE_INCEP),
BlockConfig(1, '64 96-128 32-64-64 196', True,
BLOCK_TYPE_INCEP), BlockConfig(
1, '64 96-128 32-64-64 196', True,
BLOCK_TYPE_INCEP)
],
name='conv5',
end_points=end_points)
if include_last_bn_relu:
conv5 = self._bn(conv5, 'relu', 'conv5_4_last_bn')
end_points['conv5'] = conv5
output = fluid.layers.fc(input=conv5,
size=class_dim,
act='softmax',
param_attr=ParamAttr(
initializer=MSRA(), name="fc_weights"),
bias_attr=ParamAttr(name="fc_offset"))
return output
def _conv_stage(self, input, block_configs, name, end_points):
net = input
for idx, bc in enumerate(block_configs):
if bc.block_type == BLOCK_TYPE_MCRELU:
block_scope = '{}_{}'.format(name, idx + 1)
fn = self._mCReLU
elif bc.block_type == BLOCK_TYPE_INCEP:
block_scope = '{}_{}_incep'.format(name, idx + 1)
fn = self._inception_block
net = fn(net, bc, block_scope)
end_points[block_scope] = net
end_points[name] = net
return net
def _mCReLU(self, input, mc_config, name):
"""
Each mCReLU block has three conv steps:
bn_relu_conv (or a plain conv when preact_bn is False), bn_relu_conv and bn_crelu_conv.
If the input has a different number of channels than the cReLU output,
an extra 1x1 conv is added before the sum.
"""
if mc_config.preact_bn:
conv1_fn = self._bn_relu_conv
conv1_scope = name + '_1'
else:
conv1_fn = self._conv
conv1_scope = name + '_1_conv'
sub_conv1 = conv1_fn(input, mc_config.num_outputs[0], 1, conv1_scope,
mc_config.stride)
sub_conv2 = self._bn_relu_conv(sub_conv1, mc_config.num_outputs[1], 3,
name + '_2')
sub_conv3 = self._bn_crelu_conv(sub_conv2, mc_config.num_outputs[2], 1,
name + '_3')
if int(input.shape[1]) == mc_config.num_outputs[2]:
conv_proj = input
else:
conv_proj = self._conv(input, mc_config.num_outputs[2], 1,
name + '_proj', mc_config.stride)
conv = sub_conv3 + conv_proj
return conv
def _inception_block(self, input, block_config, name):
num_outputs = block_config.num_outputs.split() # e.g. 64 24-48-48 128
num_outputs = [list(map(int, s.split('-'))) for s in num_outputs]
inception_outputs = num_outputs[-1][0]
num_outputs = num_outputs[:-1]
stride = block_config.stride
pool_path_outputs = None
if stride > 1:
pool_path_outputs = num_outputs[-1][0]
num_outputs = num_outputs[:-1]
scopes = [['_0']] # follow the name style of caffe pva
kernel_sizes = [[1]]
for path_idx, path_outputs in enumerate(num_outputs[1:]):
path_idx += 1
path_scopes = ['_{}_reduce'.format(path_idx)]
path_scopes.extend([
'_{}_{}'.format(path_idx, i - 1)
for i in range(1, len(path_outputs))
])
scopes.append(path_scopes)
path_kernel_sizes = [1, 3, 3][:len(path_outputs)]
kernel_sizes.append(path_kernel_sizes)
paths = []
if block_config.preact_bn:
preact = self._bn(input, 'relu', name + '_bn')
else:
preact = input
path_params = zip(num_outputs, scopes, kernel_sizes)
for path_idx, path_param in enumerate(path_params):
path_net = preact
for conv_idx, (num_output, scope,
kernel_size) in enumerate(zip(*path_param)):
if conv_idx == 0:
conv_stride = stride
else:
conv_stride = 1
path_net = self._conv_bn_relu(path_net, num_output,
kernel_size, name + scope,
conv_stride)
paths.append(path_net)
if stride > 1:
path_net = fluid.layers.pool2d(
input,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max',
name=name + '_pool')
path_net = self._conv_bn_relu(path_net, pool_path_outputs, 1,
name + '_poolproj')
paths.append(path_net)
block_net = fluid.layers.concat(paths, axis=1)
block_net = self._conv(block_net, inception_outputs, 1,
name + '_out_conv')
if int(input.shape[1]) == inception_outputs:
proj = input
else:
proj = self._conv(input, inception_outputs, 1, name + '_proj',
stride)
return block_net + proj
def _scale(self, input, name, axis=1, num_axes=1):
assert num_axes == 1, "layer scale does not support num_axes[%d] yet" % (
num_axes)
prefix = name + '_'
scale_shape = input.shape[axis:axis + num_axes]
param_attr = fluid.ParamAttr(name=prefix + 'gamma')
scale_param = fluid.layers.create_parameter(
shape=scale_shape,
dtype=input.dtype,
name=name,
attr=param_attr,
is_bias=True,
default_initializer=fluid.initializer.Constant(value=1.0))
offset_attr = fluid.ParamAttr(name=prefix + 'beta')
offset_param = fluid.layers.create_parameter(
shape=scale_shape,
dtype=input.dtype,
name=name,
attr=offset_attr,
is_bias=True,
default_initializer=fluid.initializer.Constant(value=0.0))
output = fluid.layers.elementwise_mul(
input, scale_param, axis=axis, name=prefix + 'mul')
output = fluid.layers.elementwise_add(
output, offset_param, axis=axis, name=prefix + 'add')
return output
def _conv(self,
input,
num_filters,
filter_size,
name,
stride=1,
groups=1,
act=None):
net = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=act,
use_cudnn=True,
param_attr=ParamAttr(name=name + '_weights'),
bias_attr=ParamAttr(name=name + '_bias'),
name=name)
return net
def _bn(self, input, act, name):
net = fluid.layers.batch_norm(
input=input,
act=act,
name=name,
moving_mean_name=name + '_mean',
moving_variance_name=name + '_variance',
param_attr=ParamAttr(name=name + '_scale'),
bias_attr=ParamAttr(name=name + '_offset'))
return net
def _bn_relu_conv(self,
input,
num_filters,
filter_size,
name,
stride=1,
groups=1):
net = self._bn(input, 'relu', name + '_bn')
net = self._conv(net, num_filters, filter_size, name + '_conv', stride,
groups)
return net
def _conv_bn_relu(self,
input,
num_filters,
filter_size,
name,
stride=1,
groups=1):
net = self._conv(input, num_filters, filter_size, name + '_conv',
stride, groups)
net = self._bn(net, 'relu', name + '_bn')
return net
def _bn_crelu(self, input, name):
net = self._bn(input, None, name + '_bn_1')
neg_net = fluid.layers.scale(net, scale=-1.0, name=name + '_neg')
net = fluid.layers.concat([net, neg_net], axis=1)
net = self._scale(net, name + '_scale')
net = fluid.layers.relu(net, name=name + '_relu')
return net
def _conv_bn_crelu(self,
input,
num_filters,
filter_size,
name,
stride=1,
groups=1,
act=None):
net = self._conv(input, num_filters, filter_size, name + '_conv',
stride, groups)
net = self._bn_crelu(net, name)
return net
def _bn_crelu_conv(self,
input,
num_filters,
filter_size,
name,
stride=1,
groups=1,
act=None):
net = self._bn_crelu(input, name)
net = self._conv(net, num_filters, filter_size, name + '_conv', stride,
groups)
return net
def deconv_bn_layer(self,
input,
num_filters,
filter_size=4,
stride=2,
padding=1,
act='relu',
name=None):
"""Deconv bn layer."""
deconv = fluid.layers.conv2d_transpose(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
act=None,
param_attr=ParamAttr(name=name + '_weights'),
bias_attr=ParamAttr(name=name + '_bias'),
name=name + 'deconv')
return self._bn(deconv, act, name + '_bn')
def conv_bn_layer(self,
input,
num_filters,
filter_size,
name,
stride=1,
groups=1):
return self._conv_bn_relu(input, num_filters, filter_size, name,
stride, groups)
def Fpn_Fusion(blocks, net):
f = [blocks['conv5'], blocks['conv4'], blocks['conv3'], blocks['conv2']]
num_outputs = [64] * len(f)
g = [None] * len(f)
h = [None] * len(f)
for i in range(len(f)):
h[i] = net.conv_bn_layer(f[i], num_outputs[i], 1, 'fpn_pre_' + str(i))
for i in range(len(f) - 1):
if i == 0:
g[i] = net.deconv_bn_layer(h[i], num_outputs[i], name='fpn_0')
else:
out = fluid.layers.elementwise_add(x=g[i - 1], y=h[i])
out = net.conv_bn_layer(out, num_outputs[i], 1,
'fpn_trans_' + str(i))
g[i] = net.deconv_bn_layer(
out, num_outputs[i], name='fpn_' + str(i))
out = fluid.layers.elementwise_add(x=g[-2], y=h[-1])
out = net.conv_bn_layer(out, num_outputs[-1], 1, 'fpn_post_0')
out = net.conv_bn_layer(out, num_outputs[-1], 3, 'fpn_post_1')
return out
def Detector_Header(f_common, net, class_num):
"""Detector header."""
f_geo = net.conv_bn_layer(f_common, 64, 1, name='geo_1')
f_geo = net.conv_bn_layer(f_geo, 64, 3, name='geo_2')
f_geo = net.conv_bn_layer(f_geo, 64, 1, name='geo_3')
f_geo = fluid.layers.conv2d(
f_geo,
8,
1,
use_cudnn=True,
param_attr=ParamAttr(name='geo_4_conv_weights'),
bias_attr=ParamAttr(name='geo_4_conv_bias'),
name='geo_4_conv')
name = 'score_class_num' + str(class_num + 1)
f_score = net.conv_bn_layer(f_common, 64, 1, 'score_1')
f_score = net.conv_bn_layer(f_score, 64, 3, 'score_2')
f_score = net.conv_bn_layer(f_score, 64, 1, 'score_3')
f_score = fluid.layers.conv2d(
f_score,
class_num + 1,
1,
use_cudnn=True,
param_attr=ParamAttr(name=name + '_conv_weights'),
bias_attr=ParamAttr(name=name + '_conv_bias'),
name=name + '_conv')
f_score = fluid.layers.transpose(f_score, perm=[0, 2, 3, 1])
f_score = fluid.layers.reshape(f_score, shape=[-1, class_num + 1])
f_score = fluid.layers.softmax(input=f_score)
return f_score, f_geo
def east(input, class_num=31):
net = PVANet()
out = net.net(input)
blocks = []
for i, j, k in zip(['conv2', 'conv3', 'conv4', 'conv5'], [1, 2, 4, 8],
[64, 64, 64, 64]):
if j == 1:
conv = net.conv_bn_layer(
out[i], k, 1, name='fusion_' + str(len(blocks)))
elif j <= 4:
conv = net.deconv_bn_layer(
out[i], k, 2 * j, j, j // 2,
name='fusion_' + str(len(blocks)))
else:
conv = net.deconv_bn_layer(
out[i], 32, 8, 4, 2, name='fusion_' + str(len(blocks)) + '_1')
conv = net.deconv_bn_layer(
conv,
k,
j // 2,
j // 4,
j // 8,
name='fusion_' + str(len(blocks)) + '_2')
blocks.append(conv)
conv = fluid.layers.concat(blocks, axis=1)
f_score, f_geo = Detector_Header(conv, net, class_num)
return f_score, f_geo
def inference(input, class_num=1, nms_thresh=0.2, score_thresh=0.5):
f_score, f_geo = east(input, class_num)
print("f_geo shape={}".format(f_geo.shape))
print("f_score shape={}".format(f_score.shape))
f_score = fluid.layers.transpose(f_score, perm=[1, 0])
return f_score, f_geo
def loss(f_score, f_geo, l_score, l_geo, l_mask, class_num=1):
'''
predictions: f_score: -1 x 1 x H x W; f_geo: -1 x 8 x H x W
targets: l_score: -1 x 1 x H x W; l_geo: -1 x 9 x H x W; l_mask: -1 x 1 x H x W
return: softmax_loss, smooth_l1_loss
'''
#smooth_l1_loss
channels = 8
l_geo_split, l_short_edge = fluid.layers.split(
l_geo, num_or_sections=[channels, 1],
dim=1) #last channel is short_edge_norm
f_geo_split = fluid.layers.split(f_geo, num_or_sections=[channels], dim=1)
f_geo_split = f_geo_split[0]
geo_diff = l_geo_split - f_geo_split
abs_geo_diff = fluid.layers.abs(geo_diff)
l_flag = l_score >= 1
l_flag = fluid.layers.cast(x=l_flag, dtype="float32")
l_flag = fluid.layers.expand(x=l_flag, expand_times=[1, channels, 1, 1])
smooth_l1_sign = abs_geo_diff < l_flag
smooth_l1_sign = fluid.layers.cast(x=smooth_l1_sign, dtype="float32")
in_loss = abs_geo_diff * abs_geo_diff * smooth_l1_sign + (
abs_geo_diff - 0.5) * (1.0 - smooth_l1_sign)
l_short_edge = fluid.layers.expand(
x=l_short_edge, expand_times=[1, channels, 1, 1])
out_loss = l_short_edge * in_loss * l_flag
out_loss = out_loss * l_flag
smooth_l1_loss = fluid.layers.reduce_mean(out_loss)
##softmax_loss
l_score.stop_gradient = True
l_score = fluid.layers.transpose(l_score, perm=[0, 2, 3, 1])
l_score.stop_gradient = True
l_score = fluid.layers.reshape(l_score, shape=[-1, 1])
l_score.stop_gradient = True
l_score = fluid.layers.cast(x=l_score, dtype="int64")
l_score.stop_gradient = True
softmax_loss = fluid.layers.cross_entropy(input=f_score, label=l_score)
softmax_loss = fluid.layers.reduce_mean(softmax_loss)
return softmax_loss, smooth_l1_loss
......@@ -40,11 +40,33 @@ add_arg('test_period', int, 10, "Test period in epoches.")
model_list = [m for m in dir(models) if "__" not in m]
def get_pruned_params(args, program):
params = []
if args.model == "MobileNet":
for param in program.global_block().all_parameters():
if "_sep_weights" in param.name:
params.append(param.name)
elif args.model == "MobileNetV2":
for param in program.global_block().all_parameters():
if "linear_weights" in param.name or "expand_weights" in param.name:
params.append(param.name)
elif args.model == "ResNet34":
for param in program.global_block().all_parameters():
if "weights" in param.name and "branch" in param.name:
params.append(param.name)
elif args.model == "PVANet":
for param in program.global_block().all_parameters():
if "conv_weights" in param.name:
params.append(param.name)
return params
def piecewise_decay(args):
step = int(math.ceil(float(args.total_images) / args.batch_size))
bd = [step * e for e in args.step_epochs]
lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
learning_rate = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=args.momentum_rate,
......@@ -176,14 +198,11 @@ def compress(args):
end_time - start_time))
batch_id += 1
params = []
for param in fluid.default_main_program().global_block().all_parameters():
if "_sep_weights" in param.name:
params.append(param.name)
_logger.info("fops before pruning: {}".format(
params = get_pruned_params(args, fluid.default_main_program())
_logger.info("FLOPs before pruning: {}".format(
flops(fluid.default_main_program())))
pruner = Pruner()
pruned_val_program = pruner.prune(
pruned_val_program, _, _ = pruner.prune(
val_program,
fluid.global_scope(),
params=params,
......@@ -191,19 +210,13 @@ def compress(args):
place=place,
only_graph=True)
pruned_program = pruner.prune(
pruned_program, _, _ = pruner.prune(
fluid.default_main_program(),
fluid.global_scope(),
params=params,
ratios=[0.33] * len(params),
place=place)
for param in pruned_program[0].global_block().all_parameters():
if "weights" in param.name:
print param.name, param.shape
return
_logger.info("fops after pruning: {}".format(flops(pruned_program)))
_logger.info("FLOPs after pruning: {}".format(flops(pruned_program)))
for i in range(args.num_epochs):
train(i, pruned_program)
if i % args.test_period == 0:
......
# Convolution Channel Pruning Example
This example demonstrates how to prune the number of channels of each convolution layer by a specified pruning ratio. By default, the example automatically downloads and uses the MNIST dataset.
The following classification models are currently supported:
- MobileNetV1
- MobileNetV2
- ResNet50
- PVANet
## API Introduction
This example uses the `paddleslim.Pruner` utility class. For an introduction to the user API, please refer to the [API documentation](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/).
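A minimal sketch of how the pruner is invoked (mirroring `train.py` in this example; `params` is the list of convolution weight names, e.g. from `get_pruned_params`, and the device and ratio below are adjustable assumptions):
```
import paddle.fluid as fluid
from paddleslim.prune import Pruner

place = fluid.CUDAPlace(0)  # or fluid.CPUPlace()
pruner = Pruner()
# Prune 33% of the output channels of every selected convolution layer.
pruned_program, _, _ = pruner.prune(
    fluid.default_main_program(),
    fluid.global_scope(),
    params=params,
    ratios=[0.33] * len(params),
    place=place)
```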
## Selecting the Parameters to Prune
Parameter naming differs between models, so the parameter names of the convolution layers to prune must be determined before pruning. All parameter names can be listed with:
```
for param in program.global_block().all_parameters():
print("param name: {}; shape: {}".format(param.name, param.shape))
```
The `train.py` script provides a `get_pruned_params` method, which selects the parameters to prune according to the `--model` option set by the user.
## Launching the Pruning Task
Start the pruning task with the following commands:
```
export CUDA_VISIBLE_DEVICES=0
python train.py
```
Run `python train.py --help` to see more options.
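For example, to prune the PVANet model added in this commit (all other options keep their defaults):
```
export CUDA_VISIBLE_DEVICES=0
python train.py --model PVANet
```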
## Notes
1. In the arguments of the `paddleslim.Pruner.prune` interface, `params` and `ratios` must have the same length.
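For instance, a uniform ratio can simply be broadcast to match the selected parameters:
```
ratios = [0.33] * len(params)  # one ratio per entry in `params`
```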
......@@ -36,7 +36,7 @@ def flops(program, only_conv=True, detail=False):
return _graph_flops(graph, only_conv=only_conv, detail=detail)
def _graph_flops(graph, only_conv=False, detail=False):
def _graph_flops(graph, only_conv=True, detail=False):
assert isinstance(graph, GraphWrapper)
flops = 0
params2flops = {}
......@@ -66,12 +66,14 @@ def _graph_flops(graph, only_conv=False, detail=False):
y_shape = op.inputs("Y")[0].shape()
if x_shape[0] == -1:
x_shape[0] = 1
flops += x_shape[0] * x_shape[1] * y_shape[1]
op_flops = x_shape[0] * x_shape[1] * y_shape[1]
flops += op_flops
params2flops[op.inputs("Y")[0].name()] = op_flops
elif op.type() in ['relu', 'sigmoid', 'batch_norm', 'relu6'] and not only_conv:
elif op.type() in ['relu', 'sigmoid', 'batch_norm', 'relu6'
] and not only_conv:
input_shape = list(op.inputs("X")[0].shape())
if input_shape[0] == -1:
input_shape[0] = 1
......
......@@ -93,6 +93,8 @@ class VarWrapper(object):
ops.append(op)
return ops
def is_parameter(self):
return isinstance(self._var, Parameter)
class OpWrapper(object):
def __init__(self, op, graph):
......
......@@ -23,6 +23,8 @@ from .sensitive_pruner import *
import sensitive_pruner
from .sensitive import *
import sensitive
from prune_walker import *
import prune_walker
__all__ = []
......@@ -32,3 +34,4 @@ __all__ += controller_server.__all__
__all__ += controller_client.__all__
__all__ += sensitive_pruner.__all__
__all__ += sensitive.__all__
__all__ += prune_walker.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
from ..core import Registry
from ..common import get_logger
__all__ = ["PRUNE_WORKER", "conv2d"]
_logger = get_logger(__name__, level=logging.INFO)
PRUNE_WORKER = Registry('prune_worker')
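# Registry mapping operator types to PruneWorker subclasses; the workers below
# register themselves via @PRUNE_WORKER.register and are looked up by op type
# in PruneWorker._prune_op().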
class PruneWorker(object):
def __init__(self, op, pruned_params=[], visited={}):
"""
A wrapper of an operator, used to infer the pruning information of all related variables.
Args:
op(Operator): The operator to be pruned.
pruned_params(list): The list used to store the pruning information inferred by the walker.
visited(dict): An auxiliary dict recording visited operators and variables. The key is an encoded string of the operator id and variable name.
Returns: An instance of PruneWorker.
"""
self.op = op
self.pruned_params = pruned_params
self.visited = visited
def prune(self, var, pruned_axis, pruned_idx):
"""
Infer the shapes of the variables related to the current operator, its predecessors and successors.
It searches the graph for all variables related to `var` and records the pruning information.
Args:
var(Variable): The root variable of the search. It can be an input or output of the current operator.
pruned_axis(int): The axis of the root variable to be pruned.
pruned_idx(list): The indices to be pruned on `pruned_axis` of the root variable.
"""
key = "_".join([str(self.op.idx()), var.name()])
if pruned_axis not in self.visited:
self.visited[pruned_axis] = {}
if key in self.visited[pruned_axis]:
return
else:
self.visited[pruned_axis][key] = True
self._prune(var, pruned_axis, pruned_idx)
def _prune(self, var, pruned_axis, pruned_idx):
raise NotImplementedError('Abstract method.')
def _prune_op(self, op, var, pruned_axis, pruned_idx, visited=None):
if op.type().endswith("_grad"):
return
if visited is not None:
self.visited = visited
cls = PRUNE_WORKER.get(op.type())
assert cls is not None, "The walker of {} is not registered.".format(
op.type())
_logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format(
self.op, op, pruned_axis, var.name()))
walker = cls(op,
pruned_params=self.pruned_params,
visited=self.visited)
walker.prune(var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class conv2d(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(conv2d, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
data_format = self.op.attr("data_format")
channel_axis = 1
if data_format == "NHWC":
channel_axis = 3
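# Three cases: `var` may be the conv input (prune the filter along its
# input-channel axis 1), the filter itself (axis 0 or 1), or the conv output
# (prune the filter along its output-channel axis 0 together with the bias).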
if var in self.op.inputs("Input"):
assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}; var: {}".format(
pruned_axis, var.name())
filter_var = self.op.inputs("Filter")[0]
key = "_".join([str(self.op.idx()), filter_var.name()])
self.visited[1][key] = True
self.pruned_params.append((filter_var, 1, pruned_idx))
for op in filter_var.outputs():
self._prune_op(op, filter_var, 1, pruned_idx)
elif var in self.op.inputs("Filter"):
assert pruned_axis in [0, 1]
self.pruned_params.append((var, pruned_axis, pruned_idx))
for op in var.outputs():
self._prune_op(op, var, pruned_axis, pruned_idx)
if pruned_axis == 0:
if len(self.op.inputs("Bias")) > 0:
self.pruned_params.append(
(self.op.inputs("Bias"), channel_axis, pruned_idx))
output_var = self.op.outputs("Output")[0]
key = "_".join([str(self.op.idx()), output_var.name()])
self.visited[channel_axis][key] = True
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
elif pruned_axis == 1:
input_var = self.op.inputs("Input")[0]
key = "_".join([str(self.op.idx()), input_var.name()])
self.visited[channel_axis][key] = True
pre_ops = input_var.inputs()
for op in pre_ops:
self._prune_op(op, input_var, channel_axis, pruned_idx)
elif var in self.op.outputs("Output"):
assert pruned_axis == channel_axis, "pruned_axis: {}; var: {}".format(
pruned_axis, var.name())
filter_var = self.op.inputs("Filter")[0]
key = "_".join([str(self.op.idx()), filter_var.name()])
self.visited[0][key] = True
self.pruned_params.append((filter_var, 0, pruned_idx))
for op in filter_var.outputs():
self._prune_op(op, filter_var, 0, pruned_idx)
if len(self.op.inputs("Bias")) > 0:
self.pruned_params.append(
(self.op.inputs("Bias")[0], channel_axis, pruned_idx))
output_var = self.op.outputs("Output")[0]
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
@PRUNE_WORKER.register
class batch_norm(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(batch_norm, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if (var not in self.op.outputs("Y")) and (
var not in self.op.inputs("X")):
return
if var in self.op.outputs("Y"):
in_var = self.op.inputs("X")[0]
key = "_".join([str(self.op.idx()), in_var.name()])
self.visited[pruned_axis][key] = True
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
for param in ["Scale", "Bias", "Mean", "Variance"]:
param_var = self.op.inputs(param)[0]
for op in param_var.outputs():
self._prune_op(op, param_var, 0, pruned_idx)
self.pruned_params.append((param_var, 0, pruned_idx))
out_var = self.op.outputs("Y")[0]
key = "_".join([str(self.op.idx()), out_var.name()])
self.visited[pruned_axis][key] = True
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
class elementwise_op(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(elementwise_op, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
axis = self.op.attr("axis")
if axis == -1: # TODO
axis = 0
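# `axis` is the dimension of X that Y's first dimension aligns with, so an
# index pruned on axis `pruned_axis` of X/Out corresponds to axis
# `pruned_axis - axis` of Y (and `pruned_axis + axis` when walking from Y to X).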
if var in self.op.outputs("Out"):
for name in ["X", "Y"]:
actual_axis = pruned_axis
if name == "Y":
actual_axis = pruned_axis - axis
in_var = self.op.inputs(name)[0]
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, actual_axis, pruned_idx)
else:
if var in self.op.inputs("X"):
in_var = self.op.inputs("Y")[0]
if in_var.is_parameter():
self.pruned_params.append(
(in_var, pruned_axis - axis, pruned_idx))
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis - axis, pruned_idx)
elif var in self.op.inputs("Y"):
in_var = self.op.inputs("X")[0]
pre_ops = in_var.inputs()
pruned_axis = pruned_axis + axis
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0]
key = "_".join([str(self.op.idx()), out_var.name()])
self.visited[pruned_axis][key] = True
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class elementwise_add(elementwise_op):
def __init__(self, op, pruned_params, visited):
super(elementwise_add, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class elementwise_sub(elementwise_op):
def __init__(self, op, pruned_params, visited):
super(elementwise_sub, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class elementwise_mul(elementwise_op):
def __init__(self, op, pruned_params, visited):
super(elementwise_mul, self).__init__(op, pruned_params, visited)
class activation(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(activation, self).__init__(op, pruned_params, visited)
self.input_name = "X"
self.output_name = "Out"
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.outputs(self.output_name):
in_var = self.op.inputs(self.input_name)[0]
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
out_var = self.op.outputs(self.output_name)[0]
key = "_".join([str(self.op.idx()), out_var.name()])
self.visited[pruned_axis][key] = True
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class uniform_random_batch_size_like(activation):
def __init__(self, op, pruned_params, visited):
super(uniform_random_batch_size_like, self).__init__(op, pruned_params,
visited)
self.input_name = "Input"
self.output_name = "Out"
@PRUNE_WORKER.register
class bilinear_interp(activation):
def __init__(self, op, pruned_params, visited):
super(bilinear_interp, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class relu(activation):
def __init__(self, op, pruned_params, visited):
super(relu, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class floor(activation):
def __init__(self, op, pruned_params, visited):
super(floor, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class relu6(activation):
def __init__(self, op, pruned_params, visited):
super(relu6, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class pool2d(activation):
def __init__(self, op, pruned_params, visited):
super(pool2d, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class sum(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(sum, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.outputs("Out"):
for in_var in self.op.inputs("X"):
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
elif var in self.op.inputs("X"):
for in_var in self.op.inputs("X"):
if in_var != var:
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0]
key = "_".join([str(self.op.idx()), out_var.name()])
self.visited[pruned_axis][key] = True
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class concat(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(concat, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
idx = []
axis = self.op.attr("axis")
if var in self.op.outputs("Out"):
start = 0
if axis == pruned_axis:
for _, in_var in enumerate(self.op.inputs("X")):
idx = []
for i in pruned_idx:
r_idx = i - start
if r_idx < in_var.shape()[pruned_axis] and r_idx >= 0:
idx.append(r_idx)
start += in_var.shape()[pruned_axis]
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, idx)
idx = pruned_idx[:]
else:
for _, in_var in enumerate(self.op.inputs("X")):
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
elif var in self.op.inputs("X"):
if axis == pruned_axis:
idx = []
start = 0
for v in self.op.inputs("X"):
if v.name() == var.name():
idx = [i + start for i in pruned_idx]
else:
start += v.shape()[pruned_axis]
out_var = self.op.outputs("Out")[0]
key = "_".join([str(self.op.idx()), out_var.name()])
self.visited[pruned_axis][key] = True
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, idx, visited={})
else:
for v in self.op.inputs("X"):
for op in v.inputs():
self._prune_op(op, v, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0]
key = "_".join([str(self.op.idx()), out_var.name()])
self.visited[pruned_axis][key] = True
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class depthwise_conv2d(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(depthwise_conv2d, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
data_format = self.op.attr("data_format")
channel_axis = 1
if data_format == "NHWC":
channel_axis = 3
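# A depthwise conv has one filter (group) per input channel, so pruning
# channels also shrinks the `groups` attribute to the remaining filter count.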
if var in self.op.inputs("Input"):
assert pruned_axis == channel_axis, "The Input of depthwise_conv2d can only be pruned at the channel axis, but got {}".format(
pruned_axis)
filter_var = self.op.inputs("Filter")[0]
self.pruned_params.append((filter_var, 0, pruned_idx))
key = "_".join([str(self.op.idx()), filter_var.name()])
self.visited[0][key] = True
new_groups = filter_var.shape()[0] - len(pruned_idx)
self.op.set_attr("groups", new_groups)
for op in filter_var.outputs():
self._prune_op(op, filter_var, 0, pruned_idx)
output_var = self.op.outputs("Output")[0]
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
elif var in self.op.inputs("Filter"):
assert pruned_axis in [0]
if pruned_axis == 0:
if len(self.op.inputs("Bias")) > 0:
self.pruned_params.append(
(self.op.inputs("Bias"), channel_axis, pruned_idx))
self.pruned_params.append((var, 0, pruned_idx))
new_groups = var.shape()[0] - len(pruned_idx)
self.op.set_attr("groups", new_groups)
for op in var.outputs():
self._prune_op(op, var, 0, pruned_idx)
output_var = self.op.outputs("Output")[0]
key = "_".join([str(self.op.idx()), output_var.name()])
self.visited[channel_axis][key] = True
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
for op in var.outputs():
self._prune_op(op, var, pruned_axis, pruned_idx)
elif var in self.op.outputs("Output"):
assert pruned_axis == channel_axis
filter_var = self.op.inputs("Filter")[0]
self.pruned_params.append((filter_var, 0, pruned_idx))
key = "_".join([str(self.op.idx()), filter_var.name()])
self.visited[0][key] = True
new_groups = filter_var.shape()[0] - len(pruned_idx)
op.set_attr("groups", new_groups)
for op in filter_var.outputs():
self._prune_op(op, filter_var, 0, pruned_idx)
if len(self.op.inputs("Bias")) > 0:
self.pruned_params.append(
(self.op.inputs("Bias")[0], channel_axis, pruned_idx))
in_var = self.op.inputs("Input")[0]
key = "_".join([str(self.op.idx()), in_var.name()])
self.visited[channel_axis][key] = True
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, channel_axis, pruned_idx)
output_var = self.op.outputs("Output")[0]
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
@PRUNE_WORKER.register
class mul(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(mul, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("X"):
assert pruned_axis == 1, "The Input of conv2d can only be pruned at axis 1, but got {}".format(
pruned_axis)
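# The conv feature map [N, C, H, W] is flattened to [N, C*H*W] before mul, so
# each pruned channel i maps to the H*W consecutive rows starting at
# i * H * W of the fc weight Y.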
idx = []
feature_map_size = var.shape()[2] * var.shape()[3]
range_idx = np.array(range(feature_map_size))
for i in pruned_idx:
idx += list(range_idx + i * feature_map_size)
param_var = self.op.inputs("Y")[0]
self.pruned_params.append((param_var, 0, idx))
for op in param_var.outputs():
self._prune_op(op, param_var, 0, pruned_idx)
@PRUNE_WORKER.register
class scale(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(scale, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("X"):
out_var = self.op.outputs("Out")[0]
for op in out_var.outputs():
self._prune_op(op, out_var, pruned_axis, pruned_idx)
elif var in self.op.outputs("Out"):
in_var = self.op.inputs("X")[0]
for op in in_var.inputs():
self._prune_op(op, in_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class momentum(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(momentum, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("Param"):
_logger.debug("pruning momentum, var:{}".format(var.name()))
velocity_var = self.op.inputs("Velocity")[0]
self.pruned_params.append((velocity_var, pruned_axis, pruned_idx))
@PRUNE_WORKER.register
class adam(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(adam, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("Param"):
_logger.debug("pruning momentum, var:{}".format(var.name()))
moment1_var = self.op.inputs("Moment1")[0]
self.pruned_params.append((moment1_var, pruned_axis, pruned_idx))
moment2_var = self.op.inputs("Moment2")[0]
self.pruned_params.append((moment2_var, pruned_axis, pruned_idx))
......@@ -17,6 +17,7 @@ import numpy as np
import paddle.fluid as fluid
import copy
from ..core import VarWrapper, OpWrapper, GraphWrapper
from .prune_walker import conv2d as conv2d_walker
from ..common import get_logger
__all__ = ["Pruner"]
......@@ -67,561 +68,60 @@ class Pruner():
graph = GraphWrapper(program.clone())
param_backup = {} if param_backup else None
param_shape_backup = {} if param_shape_backup else None
self._prune_parameters(
graph,
scope,
params,
ratios,
place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
for op in graph.ops():
if op.type() == 'depthwise_conv2d' or op.type(
) == 'depthwise_conv2d_grad':
op.set_attr('groups', op.inputs('Filter')[0].shape()[0])
return graph.program, param_backup, param_shape_backup
def _prune_filters_by_ratio(self,
scope,
params,
ratio,
place,
lazy=False,
only_graph=False,
param_shape_backup=None,
param_backup=None):
"""
Pruning filters by given ratio.
Args:
scope(fluid.core.Scope): The scope used to pruning filters.
params(list<VarWrapper>): A list of filter parameters.
ratio(float): The ratio to be pruned.
place(fluid.Place): The device place of filter parameters.
lazy(bool): True means setting the pruned elements to zero.
False means cutting down the pruned elements.
only_graph(bool): True means only modifying the graph.
False means modifying graph and variables in scope.
"""
if params[0].name() in self.pruned_list[0]:
return
if only_graph:
pruned_num = int(round(params[0].shape()[0] * ratio))
for param in params:
ori_shape = param.shape()
if param_backup is not None and (
param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(ori_shape)
new_shape = list(ori_shape)
new_shape[0] -= pruned_num
param.set_shape(new_shape)
_logger.debug("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[0].append(param.name())
return range(pruned_num)
else:
param_t = scope.find_var(params[0].name()).get_tensor()
pruned_idx = self._cal_pruned_idx(
params[0].name(), np.array(param_t), ratio, axis=0)
for param in params:
assert isinstance(param, VarWrapper)
param_t = scope.find_var(param.name()).get_tensor()
if param_backup is not None and (
param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(
np.array(param_t))
try:
pruned_param = self._prune_tensor(
np.array(param_t),
pruned_idx,
pruned_axis=0,
lazy=lazy)
except IndexError as e:
_logger.error("Pruning {}, but get [{}]".format(param.name(
), e))
param_t.set(pruned_param, place)
ori_shape = param.shape()
if param_shape_backup is not None and (
param.name() not in param_shape_backup):
param_shape_backup[param.name()] = copy.deepcopy(
param.shape())
new_shape = list(param.shape())
new_shape[0] = pruned_param.shape[0]
param.set_shape(new_shape)
_logger.debug("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[0].append(param.name())
return pruned_idx
def _prune_parameter_by_idx(self,
scope,
params,
pruned_idx,
pruned_axis,
place,
lazy=False,
only_graph=False,
param_shape_backup=None,
param_backup=None):
"""
Pruning parameters in given axis.
Args:
scope(fluid.core.Scope): The scope storing parameters to be pruned.
params(VarWrapper): The parameter to be pruned.
pruned_idx(list): The index of elements to be pruned.
pruned_axis(int): The pruning axis.
place(fluid.Place): The device place of filter parameters.
lazy(bool): True means setting the pruned elements to zero.
False means cutting down the pruned elements.
only_graph(bool): True means only modifying the graph.
False means modifying graph and variables in scope.
"""
if params[0].name() in self.pruned_list[pruned_axis]:
return
if only_graph:
pruned_num = len(pruned_idx)
for param in params:
ori_shape = param.shape()
if param_backup is not None and (
param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(ori_shape)
new_shape = list(ori_shape)
new_shape[pruned_axis] -= pruned_num
param.set_shape(new_shape)
_logger.debug("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[pruned_axis].append(param.name())
else:
for param in params:
assert isinstance(param, VarWrapper)
param_t = scope.find_var(param.name()).get_tensor()
if param_backup is not None and (
param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(
np.array(param_t))
pruned_param = self._prune_tensor(
np.array(param_t), pruned_idx, pruned_axis, lazy=lazy)
param_t.set(pruned_param, place)
ori_shape = param.shape()
if param_shape_backup is not None and (
param.name() not in param_shape_backup):
param_shape_backup[param.name()] = copy.deepcopy(
param.shape())
new_shape = list(param.shape())
new_shape[pruned_axis] = pruned_param.shape[pruned_axis]
param.set_shape(new_shape)
_logger.debug("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[pruned_axis].append(param.name())
def _forward_search_related_op(self, graph, node):
"""
Forward search operators that will be affected by pruning of param.
Args:
graph(GraphWrapper): The graph to be searched.
node(VarWrapper|OpWrapper): The current pruned parameter or operator.
Returns:
list<OpWrapper>: A list of operators.
"""
visited = {}
for op in graph.ops():
visited[op.idx()] = False
stack = []
visit_path = []
if isinstance(node, VarWrapper):
for op in graph.ops():
if (not op.is_bwd_op()) and (node in op.all_inputs()):
next_ops = self._get_next_unvisited_op(graph, visited, op)
# visit_path.append(op)
visited[op.idx()] = True
for next_op in next_ops:
if visited[next_op.idx()] == False:
stack.append(next_op)
visit_path.append(next_op)
visited[next_op.idx()] = True
elif isinstance(node, OpWrapper):
next_ops = self._get_next_unvisited_op(graph, visited, node)
for next_op in next_ops:
if visited[next_op.idx()] == False:
stack.append(next_op)
visit_path.append(next_op)
visited[next_op.idx()] = True
while len(stack) > 0:
#top_op = stack[len(stack) - 1]
top_op = stack.pop(0)
next_ops = None
if top_op.type() in ["conv2d", "deformable_conv"]:
next_ops = None
elif top_op.type() in ["mul", "concat"]:
next_ops = None
else:
next_ops = self._get_next_unvisited_op(graph, visited, top_op)
if next_ops != None:
for op in next_ops:
if visited[op.idx()] == False:
stack.append(op)
visit_path.append(op)
visited[op.idx()] = True
return visit_path
def _get_next_unvisited_op(self, graph, visited, top_op):
"""
Get next unvisited adjacent operators of given operators.
Args:
graph(GraphWrapper): The graph used to search.
visited(list): The ids of operators that has been visited.
top_op: The given operator.
Returns:
list<OpWrapper>: A list of operators.
"""
assert isinstance(top_op, OpWrapper)
next_ops = []
for op in graph.next_ops(top_op):
if (visited[op.idx()] == False) and (not op.is_bwd_op()):
next_ops.append(op)
return next_ops
def _get_accumulator(self, graph, param):
"""
Get accumulators of given parameter. The accumulator was created by optimizer.
Args:
graph(GraphWrapper): The graph used to search.
param(VarWrapper): The given parameter.
Returns:
list<VarWrapper>: A list of accumulators which are variables.
"""
assert isinstance(param, VarWrapper)
params = []
for op in param.outputs():
if op.is_opt_op():
for out_var in op.all_outputs():
if graph.is_persistable(out_var) and out_var.name(
) != param.name():
params.append(out_var)
return params
def _forward_pruning_ralated_params(self,
graph,
scope,
param,
place,
ratio=None,
pruned_idxs=None,
lazy=False,
only_graph=False,
param_backup=None,
param_shape_backup=None):
"""
Pruning all the parameters affected by the pruning of given parameter.
Args:
graph(GraphWrapper): The graph to be searched.
scope(fluid.core.Scope): The scope storing parameters to be pruned.
param(VarWrapper): The given parameter.
place(fluid.Place): The device place of filter parameters.
ratio(float): The target ratio to be pruned.
pruned_idx(list): The index of elements to be pruned.
lazy(bool): True means setting the pruned elements to zero.
False means cutting down the pruned elements.
only_graph(bool): True means only modifying the graph.
False means modifying graph and variables in scope.
"""
assert isinstance(
graph,
GraphWrapper), "graph must be instance of slim.core.GraphWrapper"
assert isinstance(
param,
VarWrapper), "param must be instance of slim.core.VarWrapper"
if param.name() in self.pruned_list[0]:
return
related_ops = self._forward_search_related_op(graph, param)
for op in related_ops:
_logger.debug("relate op: {};".format(op))
if ratio is None:
assert pruned_idxs is not None
self._prune_parameter_by_idx(
scope, [param] + self._get_accumulator(graph, param),
pruned_idxs,
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
else:
pruned_idxs = self._prune_filters_by_ratio(
scope, [param] + self._get_accumulator(graph, param),
ratio,
place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
self._prune_ops(related_ops, pruned_idxs, graph, scope, place, lazy,
only_graph, param_backup, param_shape_backup)
def _prune_ops(self, ops, pruned_idxs, graph, scope, place, lazy,
only_graph, param_backup, param_shape_backup):
for idx, op in enumerate(ops):
if op.type() in ["conv2d", "deformable_conv"]:
for in_var in op.all_inputs():
if graph.is_parameter(in_var):
conv_param = in_var
self._prune_parameter_by_idx(
scope, [conv_param] + self._get_accumulator(
graph, conv_param),
pruned_idxs,
pruned_axis=1,
place=place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
if op.type() == "depthwise_conv2d":
for in_var in op.all_inputs():
if graph.is_parameter(in_var):
conv_param = in_var
self._prune_parameter_by_idx(
scope, [conv_param] + self._get_accumulator(
graph, conv_param),
pruned_idxs,
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
elif op.type() == "elementwise_add":
# pruning bias
for in_var in op.all_inputs():
if graph.is_parameter(in_var):
bias_param = in_var
self._prune_parameter_by_idx(
scope, [bias_param] + self._get_accumulator(
graph, bias_param),
pruned_idxs,
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
elif op.type() == "mul": # pruning fc layer
fc_input = None
fc_param = None
for in_var in op.all_inputs():
if graph.is_parameter(in_var):
fc_param = in_var
else:
fc_input = in_var
idx = []
feature_map_size = fc_input.shape()[2] * fc_input.shape()[3]
range_idx = np.array(range(feature_map_size))
for i in pruned_idxs:
idx += list(range_idx + i * feature_map_size)
corrected_idxs = idx
self._prune_parameter_by_idx(
scope, [fc_param] + self._get_accumulator(graph, fc_param),
corrected_idxs,
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
elif op.type() == "concat":
concat_inputs = op.all_inputs()
last_op = ops[idx - 1]
concat_idx = None
for last_op in reversed(ops):
for out_var in last_op.all_outputs():
if out_var in concat_inputs:
concat_idx = concat_inputs.index(out_var)
break
if concat_idx is not None:
break
offset = 0
for ci in range(concat_idx):
offset += concat_inputs[ci].shape()[1]
corrected_idxs = [x + offset for x in pruned_idxs]
related_ops = self._forward_search_related_op(graph, op)
for op in related_ops:
_logger.debug("concat relate op: {};".format(op))
self._prune_ops(related_ops, corrected_idxs, graph, scope,
place, lazy, only_graph, param_backup,
param_shape_backup)
elif op.type() == "batch_norm":
bn_inputs = op.all_inputs()
in_num = len(bn_inputs)
beta = bn_inputs[0]
mean = bn_inputs[1]
alpha = bn_inputs[2]
variance = bn_inputs[3]
self._prune_parameter_by_idx(
scope, [mean] + self._get_accumulator(graph, mean),
pruned_idxs,
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
self._prune_parameter_by_idx(
scope, [variance] + self._get_accumulator(graph, variance),
pruned_idxs,
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
self._prune_parameter_by_idx(
scope, [alpha] + self._get_accumulator(graph, alpha),
pruned_idxs,
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
self._prune_parameter_by_idx(
scope, [beta] + self._get_accumulator(graph, beta),
pruned_idxs,
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
def _prune_parameters(self,
graph,
scope,
params,
ratios,
place,
lazy=False,
only_graph=False,
param_backup=None,
param_shape_backup=None):
"""
Pruning the given parameters.
Args:
graph(GraphWrapper): The graph to be searched.
scope(fluid.core.Scope): The scope storing parameters to be pruned.
params(list<str>): A list of parameter names to be pruned.
ratios(list<float>): A list of ratios to be used to pruning parameters.
place(fluid.Place): The device place of filter parameters.
pruned_idx(list): The index of elements to be pruned.
lazy(bool): True means setting the pruned elements to zero.
False means cutting down the pruned elements.
only_graph(bool): True means only modifying the graph.
False means modifying graph and variables in scope.
"""
assert len(params) == len(ratios)
self.pruned_list = [[], []]
pruned_params = []
for param, ratio in zip(params, ratios):
assert isinstance(param, str) or isinstance(param, unicode)
if param in self.pruned_list[0]:
_logger.info("Skip {}".format(param))
continue
_logger.info("pruning param: {}".format(param))
if only_graph:
param_v = graph.var(param)
pruned_num = int(round(param_v.shape()[0] * ratio))
pruned_idx = [0] * pruned_num
else:
param_t = np.array(scope.find_var(param).get_tensor())
pruned_idx = self._cal_pruned_idx(param_t, ratio, axis=0)
param = graph.var(param)
self._forward_pruning_ralated_params(
graph,
scope,
param,
place,
ratio=ratio,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
ops = param.outputs()
for op in ops:
if op.type() in ['conv2d', 'deformable_conv']:
brother_ops = self._search_brother_ops(graph, op)
for brother in brother_ops:
_logger.debug("pruning brother: {}".format(brother))
for p in graph.get_param_by_op(brother):
self._forward_pruning_ralated_params(
graph,
scope,
p,
place,
ratio=ratio,
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
def _search_brother_ops(self, graph, op_node):
"""
Search brother operators that are affected by the pruning of the given operator.
Args:
graph(GraphWrapper): The graph to be searched.
op_node(OpWrapper): The start node for searching.
Returns:
list<VarWrapper>: A list of operators.
"""
_logger.debug("######################search: {}######################".
format(op_node))
visited = [op_node.idx()]
stack = []
brothers = []
for op in graph.next_ops(op_node):
if ("conv2d" not in op.type()) and (
"concat" not in op.type()) and (
"deformable_conv" not in op.type()) and (
op.type() != 'fc') and (
not op.is_bwd_op()) and (not op.is_opt_op()):
stack.append(op)
visited.append(op.idx())
while len(stack) > 0:
top_op = stack.pop()
for parent in graph.pre_ops(top_op):
if parent.idx() not in visited and (
not parent.is_bwd_op()) and (not parent.is_opt_op()):
_logger.debug("----------go back from {} to {}----------".
format(top_op, parent))
if (('conv2d' in parent.type()) or
("deformable_conv" in parent.type()) or
(parent.type() == 'fc')):
brothers.append(parent)
else:
stack.append(parent)
visited.append(parent.idx())
for child in graph.next_ops(top_op):
if ('conv2d' not in child.type()) and (
"concat" not in child.type()) and (
'deformable_conv' not in child.type()) and (
child.type() != 'fc') and (
child.idx() not in visited) and (
not child.is_bwd_op()) and (
not child.is_opt_op()):
stack.append(child)
visited.append(child.idx())
_logger.debug("brothers: {}".format(brothers))
_logger.debug(
"######################Finish search######################".format(
op_node))
return brothers
conv_op = param.outputs()[0]
walker = conv2d_walker(conv_op, pruned_params=pruned_params, visited=visited)
walker.prune(param, pruned_axis=0, pruned_idx=pruned_idx)
merge_pruned_params = {}
for param, pruned_axis, pruned_idx in pruned_params:
if param.name() not in merge_pruned_params:
merge_pruned_params[param.name()] = {}
if pruned_axis not in merge_pruned_params[param.name()]:
merge_pruned_params[param.name()][pruned_axis] = []
merge_pruned_params[param.name()][pruned_axis].append(pruned_idx)
for param_name in merge_pruned_params:
for pruned_axis in merge_pruned_params[param_name]:
pruned_idx = np.concatenate(merge_pruned_params[param_name][pruned_axis])
param = graph.var(param_name)
_logger.debug("{}\t{}\t{}".format(param.name(), pruned_axis, len(pruned_idx)))
if param_shape_backup is not None:
origin_shape = copy.deepcopy(param.shape())
param_shape_backup[param.name()] = origin_shape
new_shape = list(param.shape())
new_shape[pruned_axis] -= len(pruned_idx)
param.set_shape(new_shape)
if not only_graph:
param_t = scope.find_var(param.name()).get_tensor()
if param_backup is not None and (param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(np.array(param_t))
try:
pruned_param = self._prune_tensor(
np.array(param_t),
pruned_idx,
pruned_axis=pruned_axis,
lazy=lazy)
except IndexError as e:
_logger.error("Pruning {}, but get [{}]".format(param.name(
), e))
param_t.set(pruned_param, place)
return graph.program, param_backup, param_shape_backup
def _cal_pruned_idx(self, name, param, ratio, axis):
def _cal_pruned_idx(self, param, ratio, axis):
"""
Calculate the index to be pruned on axis by given pruning ratio.
Args:
......
......@@ -15,7 +15,7 @@ import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from paddleslim.prune import Pruner
from paddleslim.prune.walk_pruner import Pruner
from layers import conv_bn_layer
......@@ -72,6 +72,7 @@ class TestPrune(unittest.TestCase):
for param in main_program.global_block().all_parameters():
if "weights" in param.name:
print("param: {}; param shape: {}".format(param.name, param.shape))
self.assertTrue(param.shape == shapes[param.name])
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from paddleslim.prune import Pruner
from paddleslim.core import GraphWrapper
from paddleslim.prune import conv2d as conv2d_walker
from layers import conv_bn_layer
class TestPrune(unittest.TestCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
# X X O X O
# conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
# | ^ | ^
# |____________| |____________________|
#
# X: prune output channels
# O: prune input channels
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
shapes = {}
for param in main_program.global_block().all_parameters():
shapes[param.name] = param.shape
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
exe.run(startup_program, scope=scope)
graph = GraphWrapper(main_program)
conv_op = graph.var("conv4_weights").outputs()[0]
walker = conv2d_walker(conv_op, [])
walker.prune(graph.var("conv4_weights"), pruned_axis=0, pruned_idx=[])
print(walker.pruned_params)
if __name__ == '__main__':
unittest.main()