提交 eb45a9a1 编写于 作者: W wanghaoshuang

Add simple prune API and unitest.

上级 3082ddab
......@@ -11,3 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import graph_wrapper
from .graph_wrapper import *
__all__ = graph_wrapper.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
import pickle
import numpy as np
from collections import OrderedDict
from collections import Iterable
from paddle.fluid.framework import Program, program_guard, Parameter, Variable
__all__ = ['GraphWrapper', 'VarWrapper', 'OpWrapper']
OPTIMIZER_OPS = [
'momentum',
'lars_momentum',
'adagrad',
'adam',
'adamax',
'dpsgd',
'decayed_adagrad',
'adadelta',
'rmsprop',
]
class VarWrapper(object):
def __init__(self, var, graph):
assert isinstance(var, Variable)
assert isinstance(graph, GraphWrapper)
self._var = var
self._graph = graph
def __eq__(self, v):
"""
Overwrite this function for ...in... syntax in python.
"""
return self._var.name == v._var.name
def name(self):
"""
Get the name of the variable.
"""
return self._var.name
def shape(self):
"""
Get the shape of the varibale.
"""
return self._var.shape
def set_shape(self, shape):
"""
Set the shape of the variable.
"""
self._var.desc.set_shape(shape)
def inputs(self):
"""
Get all the operators that use this variable as output.
Returns:
list<OpWrapper>: A list of operators.
"""
ops = []
for op in self._graph.ops():
if self in op.all_outputs():
ops.append(op)
return ops
def outputs(self):
"""
Get all the operators that use this variable as input.
Returns:
list<OpWrapper>: A list of operators.
"""
ops = []
for op in self._graph.ops():
if self in op.all_inputs():
ops.append(op)
return ops
class OpWrapper(object):
def __init__(self, op, graph):
assert isinstance(graph, GraphWrapper)
self._op = op
self._graph = graph
def __eq__(self, op):
"""
Overwrite this function for ...in... syntax in python.
"""
return self.idx() == op.idx()
def all_inputs(self):
"""
Get all the input variables of this operator.
"""
return [
self._graph.var(var_name) for var_name in self._op.input_arg_names
]
def all_outputs(self):
"""
Get all the output variables of this operator.
"""
return [
self._graph.var(var_name) for var_name in self._op.output_arg_names
]
def idx(self):
"""
Get the id of this operator.
"""
return self._op.idx
def type(self):
"""
Get the type of this operator.
"""
return self._op.type
def is_bwd_op(self):
"""
Whether this operator is backward op.
"""
return self.type().endswith('_grad')
def is_opt_op(self):
"""
Whether this operator is optimizer op.
"""
return self.type() in OPTIMIZER_OPS
def inputs(self, name):
"""
Get all the varibales by the input name.
"""
return [self._graph.var(var_name) for var_name in self._op.input(name)]
def outputs(self, name):
"""
Get all the varibales by the output name.
"""
return [
self._graph.var(var_name) for var_name in self._op.output(name)
]
def set_attr(self, key, value):
"""
Set the value of attribute by attribute's name.
Args:
key(str): the attribute name.
value(bool|int|str|float|list): the value of the attribute.
"""
self._op._set_attr(key, value)
def attr(self, name):
"""
Get the attribute by name.
Args:
name(str): the attribute name.
Returns:
bool|int|str|float|list: The attribute value. The return value
can be any valid attribute type.
"""
return self._op.attr(name)
class GraphWrapper(object):
"""
It is a wrapper of paddle.fluid.framework.IrGraph with some special functions
for paddle slim framework.
"""
def __init__(self, program=None, in_nodes=[], out_nodes=[]):
"""
Args:
program(framework.Program): A program with
in_nodes(dict): A dict to indicate the input nodes of the graph.
The key is user-defined and human-readable name.
The value is the name of Variable.
out_nodes(dict): A dict to indicate the input nodes of the graph.
The key is user-defined and human-readable name.
The value is the name of Variable.
"""
super(GraphWrapper, self).__init__()
self.program = Program() if program is None else program
self.persistables = {}
self.teacher_persistables = {}
for var in self.program.list_vars():
if var.persistable:
self.persistables[var.name] = var
self.compiled_graph = None
in_nodes = [] if in_nodes is None else in_nodes
out_nodes = [] if out_nodes is None else out_nodes
self.in_nodes = OrderedDict(in_nodes)
self.out_nodes = OrderedDict(out_nodes)
self._attrs = OrderedDict()
def all_parameters(self):
"""
Get all the parameters in this graph.
Returns:
list<VarWrapper>: A list of VarWrapper instances.
"""
params = []
for block in self.program.blocks:
for param in block.all_parameters():
params.append(VarWrapper(param, self))
return params
def is_parameter(self, var):
"""
Whether the given variable is parameter.
Args:
var(VarWrapper): The given varibale.
"""
return isinstance(var._var, Parameter)
def is_persistable(self, var):
"""
Whether the given variable is persistable.
Args:
var(VarWrapper): The given varibale.
"""
return var._var.persistable
def ops(self):
"""
Return all operator nodes included in the graph as a set.
"""
ops = []
for block in self.program.blocks:
for op in block.ops:
ops.append(OpWrapper(op, self))
return ops
def vars(self):
"""
Get all the variables.
"""
return [VarWrapper(var, self) for var in self.program.list_vars()]
def var(self, name):
"""
Get the variable by variable name.
"""
return VarWrapper(self.program.global_block().var(name), self)
def clone(self, for_test=False):
"""
Clone a new graph from current graph.
Returns:
(GraphWrapper): The wrapper of a new graph.
"""
return GraphWrapper(
self.program.clone(for_test),
copy.deepcopy(self.in_nodes), copy.deepcopy(self.out_nodes))
def program(self):
"""
Get the program in current wrapper.
"""
return self.program
def pre_ops(self, op):
"""
Get all the previous operators of target operator.
Args:
op(OpWrapper): Target operator..
Returns:
list<OpWrapper>: A list of operators.
"""
ops = []
for p in self.ops():
for in_var in op.all_inputs():
if in_var in p.all_outputs():
ops.append(p)
return ops
def next_ops(self, op):
"""
Get all the next operators of target operator.
Args:
op(OpWrapper): Target operator..
Returns:
list<OpWrapper>: A list of operators.
"""
ops = []
for p in self.ops():
for out_var in op.all_outputs():
if out_var in p.all_inputs():
ops.append(p)
return ops
def get_param_by_op(self, op):
"""
Get the parameters used by target operator.
"""
assert isinstance(op, OpWrapper)
params = []
for var in op.all_inputs():
if isinstance(var._var, Parameter):
params.append(var)
assert len(params) > 0
return params
def numel_params(self):
"""
Get the number of elements in all parameters.
"""
ret = 0
for param in self.all_parameters():
ret += np.product(param.shape())
return ret
def flops(self, only_conv=False):
"""
Get the flops of current graph.
Args:
only_conv: Only calculating the conv layers. default: False.
Returns:
int: The flops of current graph.
"""
flops = 0
for op in self.ops():
if op.type() in ['conv2d', 'depthwise_conv2d']:
filter_shape = op.inputs("Filter")[0].shape()
input_shape = op.inputs("Input")[0].shape()
output_shape = op.outputs("Output")[0].shape()
c_out, c_in, k_h, k_w = filter_shape
_, _, h_out, w_out = output_shape
groups = op.attr("groups")
kernel_ops = k_h * k_w * (c_in / groups)
if len(op.inputs("Bias")) > 0:
with_bias = 1
else:
with_bias = 0
flops += 2 * h_out * w_out * c_out * (kernel_ops + with_bias)
elif op.type() == 'pool2d' and not only_conv:
input_shape = op.inputs("X")[0].shape()
output_shape = op.outputs("Out")[0].shape()
_, c_out, h_out, w_out = output_shape
k_size = op.attr("ksize")
flops += h_out * w_out * c_out * (k_size[0]**2)
elif op.type() == 'mul' and not only_conv:
x_shape = list(op.inputs("X")[0].shape())
y_shape = op.inputs("Y")[0].shape()
if x_shape[0] == -1:
x_shape[0] = 1
flops += 2 * x_shape[0] * x_shape[1] * y_shape[1]
elif op.type() in ['relu', 'sigmoid', 'batch_norm'
] and not only_conv:
input_shape = list(op.inputs("X")[0].shape())
if input_shape[0] == -1:
input_shape[0] = 1
flops += np.product(input_shape)
return flops
def update_param_shape(self, scope):
"""
Update the shape of parameters in the graph according to tensors in scope.
It is used after loading pruned parameters from file.
"""
for param in self.all_parameters():
tensor_shape = np.array(
scope.find_var(param.name()).get_tensor()).shape
param.set_shape(tensor_shape)
def infer_shape(self):
"""
Update the groups of convolution layer according to current filters.
It is used after loading pruned parameters from file.
"""
for op in self.ops():
if op.type() != 'conditional_block':
op._op.desc.infer_shape(op._op.block.desc)
def update_groups_of_conv(self):
for op in self.ops():
if op.type() == 'depthwise_conv2d' or op.type(
) == 'depthwise_conv2d_grad':
op.set_attr('groups', op.inputs('Filter')[0].shape()[0])
......@@ -11,3 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pruner import Pruner
此差异已折叠。
......@@ -33,8 +33,12 @@ with open('./requirements.txt') as f:
setup_requires = f.read().splitlines()
packages = [
'paddleslim', 'paddleslim.prune', 'paddleslim.dist', 'paddleslim.nas',
'paddleslim.analysis', 'paddleslim.quant'
'paddleslim',
'paddleslim.prune',
'paddleslim.dist',
'paddleslim.nas',
'paddleslim.analysis',
'paddleslim.quant',
]
setup(
......
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
def conv_bn_layer(input,
num_filters,
filter_size,
name,
stride=1,
groups=1,
act=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False,
name=name + "_out")
bn_name = name + "_bn"
return fluid.layers.batch_norm(
input=conv,
act=act,
name=bn_name + '_output',
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance', )
import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from paddleslim.prune import Pruner
from layers import conv_bn_layer
class TestPrune(unittest.TestCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
shapes = {}
for param in main_program.global_block().all_parameters():
shapes[param.name] = param.shape
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
exe.run(startup_program, scope=scope)
pruner = Pruner()
main_program = pruner.prune(
main_program,
scope,
params=["conv4_weights"],
ratios=[0.5],
place=place,
lazy=False,
only_graph=False,
param_backup=None,
param_shape_backup=None)
shapes = {
"conv5_weights": (8L, 4L, 3L, 3L),
"conv1_weights": (4L, 3L, 3L, 3L),
"conv6_weights": (8L, 8L, 3L, 3L),
"conv3_weights": (8L, 4L, 3L, 3L),
"conv2_weights": (4L, 4L, 3L, 3L),
"conv4_weights": (4L, 8L, 3L, 3L)
}
for param in main_program.global_block().all_parameters():
if "weights" in param.name:
self.assertTrue(param.shape == shapes[param.name])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册