diff --git a/paddleslim/core/__init__.py b/paddleslim/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6ad753dfa11c695be32a2211e27dddfdaed7072 --- /dev/null +++ b/paddleslim/core/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import graph_wrapper +from .graph_wrapper import * +__all__ = graph_wrapper.__all__ diff --git a/paddleslim/core/graph_wrapper.py b/paddleslim/core/graph_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..72de894a2e4345c32e7a4eee2f35249b77c2f467 --- /dev/null +++ b/paddleslim/core/graph_wrapper.py @@ -0,0 +1,355 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import copy +import pickle +import numpy as np +from collections import OrderedDict +from collections import Iterable +from paddle.fluid.framework import Program, program_guard, Parameter, Variable + +__all__ = ['GraphWrapper', 'VarWrapper', 'OpWrapper'] + +OPTIMIZER_OPS = [ + 'momentum', + 'lars_momentum', + 'adagrad', + 'adam', + 'adamax', + 'dpsgd', + 'decayed_adagrad', + 'adadelta', + 'rmsprop', +] + + +class VarWrapper(object): + def __init__(self, var, graph): + assert isinstance(var, Variable) + assert isinstance(graph, GraphWrapper) + self._var = var + self._graph = graph + + def __eq__(self, v): + """ + Overwrite this function for ...in... syntax in python. + """ + return self._var.name == v._var.name + + def name(self): + """ + Get the name of the variable. + """ + return self._var.name + + def shape(self): + """ + Get the shape of the varibale. + """ + return self._var.shape + + def set_shape(self, shape): + """ + Set the shape of the variable. + """ + self._var.desc.set_shape(shape) + + def inputs(self): + """ + Get all the operators that use this variable as output. + Returns: + list: A list of operators. + """ + ops = [] + for op in self._graph.ops(): + if self in op.all_outputs(): + ops.append(op) + return ops + + def outputs(self): + """ + Get all the operators that use this variable as input. + Returns: + list: A list of operators. + """ + ops = [] + for op in self._graph.ops(): + if self in op.all_inputs(): + ops.append(op) + return ops + + +class OpWrapper(object): + def __init__(self, op, graph): + assert isinstance(graph, GraphWrapper) + self._op = op + self._graph = graph + + def __eq__(self, op): + """ + Overwrite this function for ...in... 
syntax in python. + """ + return self.idx() == op.idx() + + def all_inputs(self): + """ + Get all the input variables of this operator. + """ + return [ + self._graph.var(var_name) for var_name in self._op.input_arg_names + ] + + def all_outputs(self): + """ + Get all the output variables of this operator. + """ + return [ + self._graph.var(var_name) for var_name in self._op.output_arg_names + ] + + def idx(self): + """ + Get the id of this operator. + """ + return self._op.idx + + def type(self): + """ + Get the type of this operator. + """ + return self._op.type + + def is_bwd_op(self): + """ + Whether this operator is backward op. + """ + return self.type().endswith('_grad') + + def is_opt_op(self): + """ + Whether this operator is optimizer op. + """ + return self.type() in OPTIMIZER_OPS + + def inputs(self, name): + """ + Get all the varibales by the input name. + """ + return [self._graph.var(var_name) for var_name in self._op.input(name)] + + def outputs(self, name): + """ + Get all the varibales by the output name. + """ + return [ + self._graph.var(var_name) for var_name in self._op.output(name) + ] + + def set_attr(self, key, value): + """ + Set the value of attribute by attribute's name. + + Args: + key(str): the attribute name. + value(bool|int|str|float|list): the value of the attribute. + """ + self._op._set_attr(key, value) + + def attr(self, name): + """ + Get the attribute by name. + + Args: + name(str): the attribute name. + + Returns: + bool|int|str|float|list: The attribute value. The return value + can be any valid attribute type. + """ + return self._op.attr(name) + + +class GraphWrapper(object): + """ + It is a wrapper of paddle.fluid.framework.IrGraph with some special functions + for paddle slim framework. + """ + + def __init__(self, program=None, in_nodes=[], out_nodes=[]): + """ + Args: + program(framework.Program): A program with + in_nodes(dict): A dict to indicate the input nodes of the graph. + The key is user-defined and human-readable name. + The value is the name of Variable. + out_nodes(dict): A dict to indicate the input nodes of the graph. + The key is user-defined and human-readable name. + The value is the name of Variable. + """ + super(GraphWrapper, self).__init__() + self.program = Program() if program is None else program + self.persistables = {} + self.teacher_persistables = {} + for var in self.program.list_vars(): + if var.persistable: + self.persistables[var.name] = var + self.compiled_graph = None + in_nodes = [] if in_nodes is None else in_nodes + out_nodes = [] if out_nodes is None else out_nodes + self.in_nodes = OrderedDict(in_nodes) + self.out_nodes = OrderedDict(out_nodes) + self._attrs = OrderedDict() + + def all_parameters(self): + """ + Get all the parameters in this graph. + Returns: + list: A list of VarWrapper instances. + """ + params = [] + for block in self.program.blocks: + for param in block.all_parameters(): + params.append(VarWrapper(param, self)) + return params + + def is_parameter(self, var): + """ + Whether the given variable is parameter. + Args: + var(VarWrapper): The given varibale. + """ + return isinstance(var._var, Parameter) + + def is_persistable(self, var): + """ + Whether the given variable is persistable. + Args: + var(VarWrapper): The given varibale. + """ + return var._var.persistable + + def ops(self): + """ + Return all operator nodes included in the graph as a set. 
+ """ + ops = [] + for block in self.program.blocks: + for op in block.ops: + ops.append(OpWrapper(op, self)) + return ops + + def vars(self): + """ + Get all the variables. + """ + return [VarWrapper(var, self) for var in self.program.list_vars()] + + def var(self, name): + """ + Get the variable by variable name. + """ + return VarWrapper(self.program.global_block().var(name), self) + + def clone(self, for_test=False): + """ + Clone a new graph from current graph. + Returns: + (GraphWrapper): The wrapper of a new graph. + """ + return GraphWrapper( + self.program.clone(for_test), + copy.deepcopy(self.in_nodes), copy.deepcopy(self.out_nodes)) + + def program(self): + """ + Get the program in current wrapper. + """ + return self.program + + def pre_ops(self, op): + """ + Get all the previous operators of target operator. + Args: + op(OpWrapper): Target operator.. + Returns: + list: A list of operators. + """ + ops = [] + for p in self.ops(): + for in_var in op.all_inputs(): + if in_var in p.all_outputs(): + ops.append(p) + return ops + + def next_ops(self, op): + """ + Get all the next operators of target operator. + Args: + op(OpWrapper): Target operator.. + Returns: + list: A list of operators. + """ + ops = [] + for p in self.ops(): + for out_var in op.all_outputs(): + if out_var in p.all_inputs(): + ops.append(p) + return ops + + def get_param_by_op(self, op): + """ + Get the parameters used by target operator. + """ + assert isinstance(op, OpWrapper) + params = [] + for var in op.all_inputs(): + if isinstance(var._var, Parameter): + params.append(var) + assert len(params) > 0 + return params + + def numel_params(self): + """ + Get the number of elements in all parameters. + """ + ret = 0 + for param in self.all_parameters(): + ret += np.product(param.shape()) + return ret + + def update_param_shape(self, scope): + """ + Update the shape of parameters in the graph according to tensors in scope. + It is used after loading pruned parameters from file. + """ + for param in self.all_parameters(): + tensor_shape = np.array( + scope.find_var(param.name()).get_tensor()).shape + param.set_shape(tensor_shape) + + def infer_shape(self): + """ + Update the groups of convolution layer according to current filters. + It is used after loading pruned parameters from file. + """ + for op in self.ops(): + if op.type() != 'conditional_block': + op._op.desc.infer_shape(op._op.block.desc) + + def update_groups_of_conv(self): + for op in self.ops(): + if op.type() == 'depthwise_conv2d' or op.type( + ) == 'depthwise_conv2d_grad': + op.set_attr('groups', op.inputs('Filter')[0].shape()[0]) diff --git a/paddleslim/dist/mp_distiller.py b/paddleslim/dist/mp_distiller.py new file mode 100755 index 0000000000000000000000000000000000000000..ff15f5f17dd130edfd6fc5bfa1d8c358da2a5ae2 --- /dev/null +++ b/paddleslim/dist/mp_distiller.py @@ -0,0 +1,223 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import logging
+import numpy as np
+from six.moves.queue import Queue
+
+import paddle.fluid as fluid
+from paddle.fluid.framework import Variable
+from paddle.fluid.reader import DataLoaderBase
+from paddle.fluid.core import EOFException
+from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+__all__ = ['Knowledge']
+
+
+class Knowledge(object):
+    """
+    The knowledge class describes how to extract and store the dark knowledge
+    of the teacher model, and how the student model learns this dark knowledge.
+    """
+
+    def __init__(self,
+                 path,
+                 items,
+                 reduce_strategy={'type': 'sum',
+                                  'key': 'image'}):
+        """Init a knowledge instance.
+        Args:
+            path(list, str, optional): Specifies the storage path of the knowledge,
+                        supports AFS/HDFS, the local file system, and memory.
+            items(list): Names of the tensors to be saved.
+            reduce_strategy(dict, optional): The policy for performing the reduce
+                        operation. If it is set to None,
+                        the reduce operation is not performed.
+            reduce_strategy.type(str): Type of reduce operation.
+            reduce_strategy.key(str): The key of the reduce operation.
+                        It is an element of items.
+        """
+        assert (isinstance(path, list) or isinstance(path, str) or
+                (path is None)), "path type should be list or str or None"
+        assert (isinstance(items, list)), "items should be a list"
+        assert (isinstance(reduce_strategy,
+                           dict)), "reduce_strategy should be a dict"
+        self.path = path
+        if isinstance(self.path, list):
+            self.write_type = 'HDFS/AFS'
+            assert (
+                len(self.path) == 4 and isinstance(self.path[0], str) and
+                isinstance(self.path[1], str) and
+                isinstance(self.path[2], str) and isinstance(self.path[3], str)
+            ), "path should contain four str: ['local hadoop home', 'fs.default.name', 'hadoop.job.ugi', 'FS path']"
+
+            hadoop_home = self.path[0]
+            configs = {
+                "fs.default.name": self.path[1],
+                "hadoop.job.ugi": self.path[2]
+            }
+            self.client = HDFSClient(hadoop_home, configs)
+            assert (
+                self.client.is_exist(self.path[3]) == True
+            ), "Please make sure your hadoop configuration is correct and the FS path exists"
+
+            self.hdfs_local_path = "./teacher_knowledge"
+            if not os.path.exists(self.hdfs_local_path):
+                os.mkdir(self.hdfs_local_path)
+        elif isinstance(self.path, str):
+            self.write_type = "LocalFS"
+            if not os.path.exists(path):
+                raise ValueError("The local path [%s] does not exist." %
+                                 (path))
+        else:
+            self.write_type = "MEM"
+            self.knowledge_queue = Queue(64)
+
+        self.items = items
+        self.reduce_strategy = reduce_strategy
+
+    def _write(self, data):
+        if self.write_type == 'HDFS/AFS':
+            file_name = 'knowledge_' + str(self.file_cnt)
+            file_path = os.path.join(self.hdfs_local_path, file_name)
+            file_path += ".npy"
+            np.save(file_path, data)
+            self.file_cnt += 1
+            self.client.upload(self.path[3], file_path)
+            logger.info('{}.npy pushed to HDFS/AFS: {}'.format(file_name,
+                                                               self.path[3]))
+
+        elif self.write_type == 'LocalFS':
+            file_name = 'knowledge_' + str(self.file_cnt)
+            file_path = os.path.join(self.path, file_name)
+            np.save(file_path, data)
+            logger.info('{}.npy saved'.format(file_name))
+            self.file_cnt += 1
+
+        else:
+            self.knowledge_queue.put(data)
+            logger.info('{} knowledge batches pushed to Queue'.format(
+                self.knowledge_queue.qsize()))
+
+    def run(self, teacher_program, exe, place, scope, reader, inputs, outputs,
+            call_back):
+        """Run the teacher model to generate knowledge.
+ Args: + teacher_program(Program): teacher program. + scope(Scope): The scope used to execute the teacher, + which contains the initialized variables. + reader(reader): The data reader used by the teacher. + inputs(list): The name of variables to feed the teacher program. + outputs(list): Need to write to the variable instance's names of + the Knowledge instance, which needs to correspond + to the Knowledge's items. + call_back(func, optional): The callback function that handles the + outputs of the teacher, which is none by default, + that is, the output of the teacher is concat directly. + Return: + (bool): Whether the teacher task was successfully registered and started + """ + assert (isinstance( + teacher_program, + fluid.Program)), "teacher_program should be a fluid.Program" + assert (isinstance(inputs, list)), "inputs should be a list" + assert (isinstance(outputs, list)), "outputs should be a list" + assert (len(self.items) == len(outputs) + ), "the length of outputs list should be equal with items list" + assert (callable(call_back) or (call_back is None) + ), "call_back should be a callable function or NoneType." + + for var in teacher_program.list_vars(): + var.stop_gradient = True + + compiled_teacher_program = fluid.compiler.CompiledProgram( + teacher_program) + self.file_cnt = 0 + if isinstance(reader, Variable) or ( + isinstance(reader, DataLoaderBase) and (not reader.iterable)): + reader.start() + try: + while True: + logits = exe.run(compiled_teacher_program, + scope=scope, + fetch_list=outputs, + feed=None) + knowledge = dict() + for index, array in enumerate(logits): + knowledge[self.items[index]] = array + self._write(knowledge) + except EOFException: + reader.reset() + + else: + if not isinstance(reader, DataLoaderBase): + feeder = fluid.DataFeeder( + feed_list=inputs, place=place, program=teacher_program) + for batch_id, data in enumerate(reader()): + if not isinstance(reader, DataLoaderBase): + data = feeder.feed(data) + logits = exe.run(compiled_teacher_program, + scope=scope, + fetch_list=outputs, + feed=data) + knowledge = dict() + for index, array in enumerate(logits): + knowledge[self.items[index]] = array + self._write(knowledge) + return True + + def dist(self, student_program, losses): + """Building the distillation network + Args: + student_program(Program): student program. + losses(list, optional): The losses need to add. If set to None + does not add any loss. + Return: + (Program): Program for distillation. + (startup_program): Program for initializing distillation network. + (reader): Data reader for distillation training. + (Variable): Loss of distillation training + """ + + def loss(self, loss_func, *variables): + """User-defined loss + Args: + loss_func(func): Function used to define loss. + *variables(list): Variable name list. + Return: + (Variable): Distillation loss. + """ + pass + + def fsp_loss(self): + """fsp loss + """ + pass + + def l2_loss(self): + """l2 loss + """ + pass + + def softlabel_loss(self): + """softlabel_loss + """ + pass diff --git a/paddleslim/nas/search_space_base.py b/paddleslim/nas/search_space_base.py new file mode 100644 index 0000000000000000000000000000000000000000..cc1d462aca3b81c231df53ec2b5995cbb1deb5d5 --- /dev/null +++ b/paddleslim/nas/search_space_base.py @@ -0,0 +1,44 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ['SearchSpaceBase']
+
+class SearchSpaceBase(object):
+    """Base class of search spaces for Neural Architecture Search.
+    """
+
+    def __init__(self, input_size, output_size, block_num, *args):
+        self.input_size = input_size
+        self.output_size = output_size
+        self.block_num = block_num
+
+    def init_tokens(self):
+        """Get the initial tokens of the search space.
+        """
+        raise NotImplementedError('Abstract method.')
+
+    def range_table(self):
+        """Get the range table of the current search space.
+        """
+        raise NotImplementedError('Abstract method.')
+
+    def token2arch(self, tokens):
+        """Create networks for training and evaluation according to tokens.
+        Args:
+            tokens(list): The tokens which represent a network.
+        Returns:
+            list: Functions that build the network.
+        """
+        raise NotImplementedError('Abstract method.')
+
diff --git a/paddleslim/nas/search_space_factory.py b/paddleslim/nas/search_space_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..10d076e8722a89cb46bac94360b1968e1de9e33a
--- /dev/null
+++ b/paddleslim/nas/search_space_factory.py
@@ -0,0 +1,36 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .searchspace.registry import SEARCHSPACE
+
+class SearchSpaceFactory(object):
+    def __init__(self):
+        pass
+
+    def get_search_space(self, key, config):
+        """
+        Get a specific search space based on key and config.
+
+        Args:
+            key(str): search space name.
+            config(dict): basic config information.
+        Returns:
+            An instance of the search space registered under key.
+        """
+        cls = SEARCHSPACE.get(key)
+        space = cls(config['input_size'], config['output_size'], config['block_num'])
+
+        return space
+
+
diff --git a/paddleslim/nas/searchspace/__init__.py b/paddleslim/nas/searchspace/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b5c527794b03967e6ce77f8bb16e883c06dbbf
--- /dev/null
+++ b/paddleslim/nas/searchspace/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
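
A hedged sketch of how the factory, the registry, and a registered space are meant to interact; the config values are illustrative, and MobileNetV2Space is the space registered in the module below:

```python
from paddleslim.nas.search_space_factory import SearchSpaceFactory

config = {'input_size': 224, 'output_size': 7, 'block_num': 5}
factory = SearchSpaceFactory()
space = factory.get_search_space('MobileNetV2Space', config)

tokens = space.init_tokens()         # default architecture encoding
table = space.range_table()          # how many choices each token has
net_arch = space.token2arch(tokens)  # callable that builds the network
```
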
+ +from .mobilenetv2_space import MobileNetV2Space diff --git a/paddleslim/nas/searchspace/layer.py b/paddleslim/nas/searchspace/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..75ce180b4279e32601fe3fab6fea44d719a1a701 --- /dev/null +++ b/paddleslim/nas/searchspace/layer.py @@ -0,0 +1,42 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr + + +def conv_bn_layer(input, filter_size, num_filters, stride, padding, num_groups=1, act=None, name=None, use_cudnn=True): + """Build convolution and batch normalization layers. + Args: + input(Variable): input. + filter_size(int): filter size. + num_filters(int): number of filters. + stride(int): stride. + padding(int|list|str): padding. + num_groups(int): number of groups. + act(str): activation type. + name(str): name. + use_cudnn(bool): whether use cudnn. + Returns: + Variable, layers output. + """ + conv = fluid.layers.conv2d(input, num_filters=num_filters, filter_size=filter_size, stride=stride, padding=padding, + groups=num_groups, act=None, use_cudnn=use_cudnn, param_attr=ParamAttr(name=name+'_weights'), bias_attr=False) + bn_name = name + '_bn' + bn = fluid.layers.batch_norm(input=conv, param_attr=ParamAttr(name=bn_name+'_scale'), bias_attr=ParamAttr(name=bn_name+'_offset'), + moving_mean_name=bn_name+'_mean', moving_variance_name=bn_name+'_variance') + if act == 'relu6': + return fluid.layers.relu6(bn) + else: + return bn diff --git a/paddleslim/nas/searchspace/mobilenetv2_space.py b/paddleslim/nas/searchspace/mobilenetv2_space.py new file mode 100644 index 0000000000000000000000000000000000000000..bf224d6f6fb56b29fbee5e297f7a95650bd20dbd --- /dev/null +++ b/paddleslim/nas/searchspace/mobilenetv2_space.py @@ -0,0 +1,270 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
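
A minimal sketch of calling the conv_bn_layer helper from layer.py above, assuming the fluid 1.x API with string 'SAME' padding support; note that only act='relu6' is applied by the helper, and any other value falls through to the plain batch-norm output:

```python
import paddle.fluid as fluid
from paddleslim.nas.searchspace.layer import conv_bn_layer

prog, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(prog, startup):
    x = fluid.layers.data(name='x', shape=[3, 224, 224], dtype='float32')
    y = conv_bn_layer(x, filter_size=3, num_filters=32, stride=2,
                      padding='SAME', act='relu6', name='conv1')
    print(y.shape)  # e.g. (-1, 32, 112, 112) with 'SAME' padding
```
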
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +from ..search_space_base import SearchSpaceBase +from .layer import conv_bn_layer +from .registry import SEARCHSPACE + +@SEARCHSPACE.register_module +class MobileNetV2Space(SearchSpaceBase): + def __init__(self, input_size, output_size, block_num, scale=1.0, class_dim=1000): + super(MobileNetV2Space, self).__init__(input_size, output_size, block_num) + self.head_num = np.array([3,4,8,12,16,24,32]) #7 + self.filter_num1 = np.array([3,4,8,12,16,24,32,48]) #8 + self.filter_num2 = np.array([8,12,16,24,32,48,64,80]) #8 + self.filter_num3 = np.array([16,24,32,48,64,80,96,128]) #8 + self.filter_num4 = np.array([24,32,48,64,80,96,128,144,160,192]) #10 + self.filter_num5 = np.array([32,48,64,80,96,128,144,160,192,224]) #10 + self.filter_num6 = np.array([64,80,96,128,144,160,192,224,256,320,384,512]) #12 + self.k_size = np.array([3,5]) #2 + self.multiply = np.array([1,2,3,4,6]) #5 + self.repeat = np.array([1,2,3,4,5,6]) #6 + self.scale=scale + self.class_dim=class_dim + + def init_tokens(self): + """ + The initial token send to controller. + The first one is the index of the first layers' channel in self.head_num, + each line in the following represent the index of the [expansion_factor, filter_num, repeat_num, kernel_size] + """ + # original MobileNetV2 + return [4, # 1, 16, 1 + 4, 5, 1, 0, # 6, 24, 1 + 4, 5, 1, 0, # 6, 24, 2 + 4, 4, 2, 0, # 6, 32, 3 + 4, 4, 3, 0, # 6, 64, 4 + 4, 5, 2, 0, # 6, 96, 3 + 4, 7, 2, 0, # 6, 160, 3 + 4, 9, 0, 0] # 6, 320, 1 + + def range_table(self): + """ + get range table of current search space + """ + # head_num + 7 * [multiple(expansion_factor), filter_num, repeat, kernel_size] + return [7, + 5, 8, 6, 2, + 5, 8, 6, 2, + 5, 8, 6, 2, + 5, 8, 6, 2, + 5, 10, 6, 2, + 5, 10, 6, 2, + 5, 12, 6, 2] + + def token2arch(self, tokens=None): + """ + return net_arch function + """ + if tokens is None: + tokens = self.init_tokens() + + base_bottleneck_params_list = [ + (1, self.head_num[tokens[0]], 1, 1, 3), + (self.multiply[tokens[1]], self.filter_num1[tokens[2]], self.repeat[tokens[3]], 2, self.k_size[tokens[4]]), + (self.multiply[tokens[5]], self.filter_num1[tokens[6]], self.repeat[tokens[7]], 2, self.k_size[tokens[8]]), + (self.multiply[tokens[9]], self.filter_num2[tokens[10]], self.repeat[tokens[11]], 2, self.k_size[tokens[12]]), + (self.multiply[tokens[13]], self.filter_num3[tokens[14]], self.repeat[tokens[15]], 2, self.k_size[tokens[16]]), + (self.multiply[tokens[17]], self.filter_num3[tokens[18]], self.repeat[tokens[19]], 1, self.k_size[tokens[20]]), + (self.multiply[tokens[21]], self.filter_num5[tokens[22]], self.repeat[tokens[23]], 2, self.k_size[tokens[24]]), + (self.multiply[tokens[25]], self.filter_num6[tokens[26]], self.repeat[tokens[27]], 1, self.k_size[tokens[28]]), + ] + + assert self.block_num < 7, 'block number must less than 7, but receive block number is {}'.format(self.block_num) + + # the stride = 2 means downsample feature map in the convolution, so only when stride=2, block_num minus 1, + # otherwise, add layers to params_list directly. 
+ bottleneck_params_list = [] + for param_list in base_bottleneck_params_list: + if param_list[3] == 1: + bottleneck_params_list.append(param_list) + else: + if self.block_num > 1: + bottleneck_params_list.append(param_list) + self.block_num -= 1 + else: + break + + def net_arch(input): + #conv1 + # all padding is 'SAME' in the conv2d, can compute the actual padding automatic. + input = conv_bn_layer( + input, + num_filters=int(32 * self.scale), + filter_size=3, + stride=2, + padding='SAME', + act='relu6', + name='conv1_1') + + # bottleneck sequences + i = 1 + in_c = int(32 * self.scale) + for layer_setting in bottleneck_params_list: + t, c, n, s, k = layer_setting + i += 1 + input = self.invresi_blocks( + input=input, + in_c=in_c, + t=t, + c=int(c * self.scale), + n=n, + s=s, + k=k, + name='conv' + str(i)) + in_c = int(c * self.scale) + + # if output_size is 1, add fc layer in the end + if self.output_size == 1: + input = fluid.layers.fc(input=input, + size=self.class_dim, + param_attr=ParamAttr(name='fc10_weights'), + bias_attr=ParamAttr(name='fc10_offset')) + else: + assert self.output_size == input.shape[2], \ + ("output_size must EQUAL to input_size / (2^block_num)." + "But receive input_size={}, output_size={}, block_num={}".format( + self.input_size, self.output_size, self.block_num)) + + return input + + return net_arch + + + def shortcut(self, input, data_residual): + """Build shortcut layer. + Args: + input(Variable): input. + data_residual(Variable): residual layer. + Returns: + Variable, layer output. + """ + return fluid.layers.elementwise_add(input, data_residual) + + + def inverted_residual_unit(self, + input, + num_in_filter, + num_filters, + ifshortcut, + stride, + filter_size, + expansion_factor, + reduction_ratio=4, + name=None): + """Build inverted residual unit. + Args: + input(Variable), input. + num_in_filter(int), number of in filters. + num_filters(int), number of filters. + ifshortcut(bool), whether using shortcut. + stride(int), stride. + filter_size(int), filter size. + padding(str|int|list), padding. + expansion_factor(float), expansion factor. + name(str), name. + Returns: + Variable, layers output. + """ + num_expfilter = int(round(num_in_filter * expansion_factor)) + channel_expand = conv_bn_layer( + input=input, + num_filters=num_expfilter, + filter_size=1, + stride=1, + padding='SAME', + num_groups=1, + act='relu6', + name=name + '_expand') + + bottleneck_conv = conv_bn_layer( + input=channel_expand, + num_filters=num_expfilter, + filter_size=filter_size, + stride=stride, + padding='SAME', + num_groups=num_expfilter, + act='relu6', + name=name + '_dwise', + use_cudnn=False) + + linear_out = conv_bn_layer( + input=bottleneck_conv, + num_filters=num_filters, + filter_size=1, + stride=1, + padding='SAME', + num_groups=1, + act=None, + name=name + '_linear') + out = linear_out + if ifshortcut: + out = self.shortcut(input=input, data_residual=out) + return out + + def invresi_blocks(self, + input, + in_c, + t, + c, + n, + s, + k, + name=None): + """Build inverted residual blocks. + Args: + input: Variable, input. + in_c: int, number of in filters. + t: float, expansion factor. + c: int, number of filters. + n: int, number of layers. + s: int, stride. + k: int, filter size. + name: str, name. + Returns: + Variable, layers output. 
+ """ + first_block = self.inverted_residual_unit( + input=input, + num_in_filter=in_c, + num_filters=c, + ifshortcut=False, + stride=s, + filter_size=k, + expansion_factor=t, + name=name + '_1') + + last_residual_block = first_block + last_c = c + + for i in range(1, n): + last_residual_block = self.inverted_residual_unit( + input=last_residual_block, + num_in_filter=last_c, + num_filters=c, + ifshortcut=True, + stride=1, + filter_size=k, + expansion_factor=t, + name=name + '_' + str(i + 1)) + return last_residual_block + + diff --git a/paddleslim/nas/searchspace/registry.py b/paddleslim/nas/searchspace/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..69fd9edf6cc97b1b71935f8b82d3056544fced42 --- /dev/null +++ b/paddleslim/nas/searchspace/registry.py @@ -0,0 +1,3 @@ +from ..utils.registry import Registry + +SEARCHSPACE = Registry('searchspace') diff --git a/paddleslim/common/__init__.py b/paddleslim/nas/utils/__init__.py similarity index 100% rename from paddleslim/common/__init__.py rename to paddleslim/nas/utils/__init__.py diff --git a/paddleslim/nas/utils/registry.py b/paddleslim/nas/utils/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..5d055a9c3de98db19bc2f6ccd85f762b384e6ce3 --- /dev/null +++ b/paddleslim/nas/utils/registry.py @@ -0,0 +1,31 @@ +import inspect + +class Registry(object): + def __init__(self, name): + self._name = name + self._module_dict = dict() + def __repr__(self): + format_str = self.__class__.__name__ + '(name={}, items={})'.format(self._name, list(self._module_dict.keys())) + return format_str + + @property + def name(self): + return self._name + @property + def module_dict(self): + return self._module_dict + + def get(self, key): + return self._module_dict.get(key, None) + + def _register_module(self, module_class): + if not inspect.isclass(module_class): + raise TypeError('module must be a class, but receive {}.'.format(type(module_class))) + module_name = module_class.__name__ + if module_name in self._module_dict: + raise KeyError('{} is already registered in {}.'.format(module_name, self.name)) + self._module_dict[module_name] = module_class + + def register_module(self, cls): + self._register_module(cls) + return cls diff --git a/paddleslim/prune/__init__.py b/paddleslim/prune/__init__.py index 9d0531501ca43921438ee5b2fb58ac0ad2396d1b..926586c67d9e0b73ecd66f107ef897b389c5844f 100644 --- a/paddleslim/prune/__init__.py +++ b/paddleslim/prune/__init__.py @@ -11,3 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from pruner import Pruner diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py new file mode 100644 index 0000000000000000000000000000000000000000..30341f63407aa1b0cc52ec5b43eadead27aec2ab --- /dev/null +++ b/paddleslim/prune/pruner.py @@ -0,0 +1,553 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import numpy as np
+import paddle.fluid as fluid
+from ..core import VarWrapper, OpWrapper, GraphWrapper
+
+__all__ = ["Pruner"]
+
+
+class Pruner(object):
+    def __init__(self, criterion="l1_norm"):
+        """
+        Args:
+            criterion(str): the criterion used to sort channels for pruning.
+                            It only supports 'l1_norm' currently.
+        """
+        self.criterion = criterion
+
+    def prune(self,
+              program,
+              scope,
+              params,
+              ratios,
+              place=None,
+              lazy=False,
+              only_graph=False,
+              param_backup=None,
+              param_shape_backup=None):
+        """
+        Prune the given parameters.
+        Args:
+            program(fluid.Program): The program to be pruned.
+            scope(fluid.Scope): The scope storing parameters to be pruned.
+            params(list): A list of parameter names to be pruned.
+            ratios(list): A list of ratios used for pruning the parameters.
+            place(fluid.Place): The device place of filter parameters. Default: None.
+            lazy(bool): True means setting the pruned elements to zero.
+                        False means cutting down the pruned elements. Default: False.
+            only_graph(bool): True means only modifying the graph.
+                              False means modifying graph and variables in scope. Default: False.
+            param_backup(dict): A dict to backup the values of parameters. Default: None.
+            param_shape_backup(dict): A dict to backup the shapes of parameters. Default: None.
+        Returns:
+            Program: The pruned program.
+        """
+
+        self.pruned_list = []
+        graph = GraphWrapper(program.clone())
+        # forward the caller's options instead of hard-coded defaults
+        self._prune_parameters(
+            graph,
+            scope,
+            params,
+            ratios,
+            place,
+            lazy=lazy,
+            only_graph=only_graph,
+            param_backup=param_backup,
+            param_shape_backup=param_shape_backup)
+        return graph.program
+
+    def _prune_filters_by_ratio(self,
+                                scope,
+                                params,
+                                ratio,
+                                place,
+                                lazy=False,
+                                only_graph=False,
+                                param_shape_backup=None,
+                                param_backup=None):
+        """
+        Prune filters by the given ratio.
+        Args:
+            scope(fluid.core.Scope): The scope used to prune filters.
+            params(list): A list of filter parameters.
+            ratio(float): The ratio to be pruned.
+            place(fluid.Place): The device place of filter parameters.
+            lazy(bool): True means setting the pruned elements to zero.
+                        False means cutting down the pruned elements.
+            only_graph(bool): True means only modifying the graph.
+                              False means modifying graph and variables in scope.
+        """
+        if params[0].name() in self.pruned_list[0]:
+            return
+        param_t = scope.find_var(params[0].name()).get_tensor()
+        pruned_idx = self._cal_pruned_idx(
+            params[0].name(), np.array(param_t), ratio, axis=0)
+        for param in params:
+            assert isinstance(param, VarWrapper)
+            param_t = scope.find_var(param.name()).get_tensor()
+            if param_backup is not None and (param.name() not in param_backup):
+                param_backup[param.name()] = copy.deepcopy(np.array(param_t))
+            pruned_param = self._prune_tensor(
+                np.array(param_t), pruned_idx, pruned_axis=0, lazy=lazy)
+            if not only_graph:
+                param_t.set(pruned_param, place)
+            if param_shape_backup is not None and (
+                    param.name() not in param_shape_backup):
+                param_shape_backup[param.name()] = copy.deepcopy(param.shape())
+            new_shape = list(param.shape())
+            new_shape[0] = pruned_param.shape[0]
+            param.set_shape(new_shape)
+            self.pruned_list[0].append(param.name())
+        return pruned_idx
+
+    def _prune_parameter_by_idx(self,
+                                scope,
+                                params,
+                                pruned_idx,
+                                pruned_axis,
+                                place,
+                                lazy=False,
+                                only_graph=False,
+                                param_shape_backup=None,
+                                param_backup=None):
+        """
+        Prune parameters along the given axis.
+ Args: + scope(fluid.core.Scope): The scope storing paramaters to be pruned. + params(VarWrapper): The parameter to be pruned. + pruned_idx(list): The index of elements to be pruned. + pruned_axis(int): The pruning axis. + place(fluid.Place): The device place of filter parameters. + lazy(bool): True means setting the pruned elements to zero. + False means cutting down the pruned elements. + only_graph(bool): True means only modifying the graph. + False means modifying graph and variables in scope. + """ + if params[0].name() in self.pruned_list[pruned_axis]: + return + for param in params: + assert isinstance(param, VarWrapper) + param_t = scope.find_var(param.name()).get_tensor() + if param_backup is not None and (param.name() not in param_backup): + param_backup[param.name()] = copy.deepcopy(np.array(param_t)) + pruned_param = self._prune_tensor( + np.array(param_t), pruned_idx, pruned_axis, lazy=lazy) + if not only_graph: + param_t.set(pruned_param, place) + ori_shape = param.shape() + + if param_shape_backup is not None and ( + param.name() not in param_shape_backup): + param_shape_backup[param.name()] = copy.deepcopy(param.shape()) + new_shape = list(param.shape()) + new_shape[pruned_axis] = pruned_param.shape[pruned_axis] + param.set_shape(new_shape) + self.pruned_list[pruned_axis].append(param.name()) + + def _forward_search_related_op(self, graph, param): + """ + Forward search operators that will be affected by pruning of param. + Args: + graph(GraphWrapper): The graph to be searched. + param(VarWrapper): The current pruned parameter. + Returns: + list: A list of operators. + """ + assert isinstance(param, VarWrapper) + visited = {} + for op in graph.ops(): + visited[op.idx()] = False + stack = [] + for op in graph.ops(): + if (not op.is_bwd_op()) and (param in op.all_inputs()): + stack.append(op) + visit_path = [] + while len(stack) > 0: + top_op = stack[len(stack) - 1] + if visited[top_op.idx()] == False: + visit_path.append(top_op) + visited[top_op.idx()] = True + next_ops = None + if top_op.type() == "conv2d" and param not in top_op.all_inputs(): + next_ops = None + elif top_op.type() == "mul": + next_ops = None + else: + next_ops = self._get_next_unvisited_op(graph, visited, top_op) + if next_ops == None: + stack.pop() + else: + stack += next_ops + return visit_path + + def _get_next_unvisited_op(self, graph, visited, top_op): + """ + Get next unvisited adjacent operators of given operators. + Args: + graph(GraphWrapper): The graph used to search. + visited(list): The ids of operators that has been visited. + top_op: The given operator. + Returns: + list: A list of operators. + """ + assert isinstance(top_op, OpWrapper) + next_ops = [] + for op in graph.next_ops(top_op): + if (visited[op.idx()] == False) and (not op.is_bwd_op()): + next_ops.append(op) + return next_ops if len(next_ops) > 0 else None + + def _get_accumulator(self, graph, param): + """ + Get accumulators of given parameter. The accumulator was created by optimizer. + Args: + graph(GraphWrapper): The graph used to search. + param(VarWrapper): The given parameter. + Returns: + list: A list of accumulators which are variables. 
+ """ + assert isinstance(param, VarWrapper) + params = [] + for op in param.outputs(): + if op.is_opt_op(): + for out_var in op.all_outputs(): + if graph.is_persistable(out_var) and out_var.name( + ) != param.name(): + params.append(out_var) + return params + + def _forward_pruning_ralated_params(self, + graph, + scope, + param, + place, + ratio=None, + pruned_idxs=None, + lazy=False, + only_graph=False, + param_backup=None, + param_shape_backup=None): + """ + Pruning all the parameters affected by the pruning of given parameter. + Args: + graph(GraphWrapper): The graph to be searched. + scope(fluid.core.Scope): The scope storing paramaters to be pruned. + param(VarWrapper): The given parameter. + place(fluid.Place): The device place of filter parameters. + ratio(float): The target ratio to be pruned. + pruned_idx(list): The index of elements to be pruned. + lazy(bool): True means setting the pruned elements to zero. + False means cutting down the pruned elements. + only_graph(bool): True means only modifying the graph. + False means modifying graph and variables in scope. + """ + assert isinstance( + graph, + GraphWrapper), "graph must be instance of slim.core.GraphWrapper" + assert isinstance( + param, + VarWrapper), "param must be instance of slim.core.VarWrapper" + + if param.name() in self.pruned_list[0]: + return + related_ops = self._forward_search_related_op(graph, param) + + if ratio is None: + assert pruned_idxs is not None + self._prune_parameter_by_idx( + scope, [param] + self._get_accumulator(graph, param), + pruned_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + + else: + pruned_idxs = self._prune_filters_by_ratio( + scope, [param] + self._get_accumulator(graph, param), + ratio, + place, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + corrected_idxs = pruned_idxs[:] + + for idx, op in enumerate(related_ops): + if op.type() == "conv2d" and (param not in op.all_inputs()): + for in_var in op.all_inputs(): + if graph.is_parameter(in_var): + conv_param = in_var + self._prune_parameter_by_idx( + scope, [conv_param] + self._get_accumulator( + graph, conv_param), + corrected_idxs, + pruned_axis=1, + place=place, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + if op.type() == "depthwise_conv2d": + for in_var in op.all_inputs(): + if graph.is_parameter(in_var): + conv_param = in_var + self._prune_parameter_by_idx( + scope, [conv_param] + self._get_accumulator( + graph, conv_param), + corrected_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + elif op.type() == "elementwise_add": + # pruning bias + for in_var in op.all_inputs(): + if graph.is_parameter(in_var): + bias_param = in_var + self._prune_parameter_by_idx( + scope, [bias_param] + self._get_accumulator( + graph, bias_param), + pruned_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + elif op.type() == "mul": # pruning fc layer + fc_input = None + fc_param = None + for in_var in op.all_inputs(): + if graph.is_parameter(in_var): + fc_param = in_var + else: + fc_input = in_var + + idx = [] + feature_map_size = fc_input.shape()[2] * fc_input.shape()[3] + range_idx = np.array(range(feature_map_size)) + for i in 
corrected_idxs: + idx += list(range_idx + i * feature_map_size) + corrected_idxs = idx + self._prune_parameter_by_idx( + scope, [fc_param] + self._get_accumulator(graph, fc_param), + corrected_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + + elif op.type() == "concat": + concat_inputs = op.all_inputs() + last_op = related_ops[idx - 1] + for out_var in last_op.all_outputs(): + if out_var in concat_inputs: + concat_idx = concat_inputs.index(out_var) + offset = 0 + for ci in range(concat_idx): + offset += concat_inputs[ci].shape()[1] + corrected_idxs = [x + offset for x in pruned_idxs] + elif op.type() == "batch_norm": + bn_inputs = op.all_inputs() + mean = bn_inputs[2] + variance = bn_inputs[3] + alpha = bn_inputs[0] + beta = bn_inputs[1] + self._prune_parameter_by_idx( + scope, [mean] + self._get_accumulator(graph, mean), + corrected_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + self._prune_parameter_by_idx( + scope, [variance] + self._get_accumulator(graph, variance), + corrected_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + self._prune_parameter_by_idx( + scope, [alpha] + self._get_accumulator(graph, alpha), + corrected_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + self._prune_parameter_by_idx( + scope, [beta] + self._get_accumulator(graph, beta), + corrected_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + + def _prune_parameters(self, + graph, + scope, + params, + ratios, + place, + lazy=False, + only_graph=False, + param_backup=None, + param_shape_backup=None): + """ + Pruning the given parameters. + Args: + graph(GraphWrapper): The graph to be searched. + scope(fluid.core.Scope): The scope storing paramaters to be pruned. + params(list): A list of parameter names to be pruned. + ratios(list): A list of ratios to be used to pruning parameters. + place(fluid.Place): The device place of filter parameters. + pruned_idx(list): The index of elements to be pruned. + lazy(bool): True means setting the pruned elements to zero. + False means cutting down the pruned elements. + only_graph(bool): True means only modifying the graph. + False means modifying graph and variables in scope. + """ + assert len(params) == len(ratios) + self.pruned_list = [[], []] + for param, ratio in zip(params, ratios): + assert isinstance(param, str) or isinstance(param, unicode) + param = graph.var(param) + self._forward_pruning_ralated_params( + graph, + scope, + param, + place, + ratio=ratio, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + ops = param.outputs() + for op in ops: + if op.type() == 'conv2d': + brother_ops = self._search_brother_ops(graph, op) + for broher in brother_ops: + for p in graph.get_param_by_op(broher): + self._forward_pruning_ralated_params( + graph, + scope, + p, + place, + ratio=ratio, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + + def _search_brother_ops(self, graph, op_node): + """ + Search brother operators that was affected by pruning of given operator. 
+ Args: + graph(GraphWrapper): The graph to be searched. + op_node(OpWrapper): The start node for searching. + Returns: + list: A list of operators. + """ + visited = [op_node.idx()] + stack = [] + brothers = [] + for op in graph.next_ops(op_node): + if (op.type() != 'conv2d') and (op.type() != 'fc') and ( + not op.is_bwd_op()): + stack.append(op) + visited.append(op.idx()) + while len(stack) > 0: + top_op = stack.pop() + for parent in graph.pre_ops(top_op): + if parent.idx() not in visited and (not parent.is_bwd_op()): + if ((parent.type() == 'conv2d') or + (parent.type() == 'fc')): + brothers.append(parent) + else: + stack.append(parent) + visited.append(parent.idx()) + + for child in graph.next_ops(top_op): + if (child.type() != 'conv2d') and (child.type() != 'fc') and ( + child.idx() not in visited) and ( + not child.is_bwd_op()): + stack.append(child) + visited.append(child.idx()) + return brothers + + def _cal_pruned_idx(self, name, param, ratio, axis): + """ + Calculate the index to be pruned on axis by given pruning ratio. + Args: + name(str): The name of parameter to be pruned. + param(np.array): The data of parameter to be pruned. + ratio(float): The ratio to be pruned. + axis(int): The axis to be used for pruning given parameter. + If it is None, the value in self.pruning_axis will be used. + default: None. + Returns: + list: The indexes to be pruned on axis. + """ + prune_num = int(round(param.shape[axis] * ratio)) + reduce_dims = [i for i in range(len(param.shape)) if i != axis] + if self.criterion == 'l1_norm': + criterions = np.sum(np.abs(param), axis=tuple(reduce_dims)) + pruned_idx = criterions.argsort()[:prune_num] + return pruned_idx + + def _prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False): + """ + Pruning a array by indexes on given axis. + Args: + tensor(numpy.array): The target array to be pruned. + pruned_idx(list): The indexes to be pruned. + pruned_axis(int): The axis of given array to be pruned on. + lazy(bool): True means setting the pruned elements to zero. + False means remove the pruned elements from memory. + default: False. + Returns: + numpy.array: The pruned array. + """ + mask = np.zeros(tensor.shape[pruned_axis], dtype=bool) + mask[pruned_idx] = True + + def func(data): + return data[~mask] + + def lazy_func(data): + data[mask] = 0 + return data + + if lazy: + return np.apply_along_axis(lazy_func, pruned_axis, tensor) + else: + return np.apply_along_axis(func, pruned_axis, tensor) diff --git a/paddleslim/quant/__init__.py b/paddleslim/quant/__init__.py index 183c65ea68000e711ae7a8c720d20b65966a01ee..aa42adbf2e751d1fa2bba50280164fe92a6b10ce 100644 --- a/paddleslim/quant/__init__.py +++ b/paddleslim/quant/__init__.py @@ -12,4 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from quanter import quant_aware, quant_post, convert \ No newline at end of file +from quanter import quant_aware, quant_post, convert +from .quant_embedding import quant_embedding diff --git a/paddleslim/quant/quant_embedding.py b/paddleslim/quant/quant_embedding.py new file mode 100755 index 0000000000000000000000000000000000000000..46a81db65c55f91fdf5525bf0da25414598a0b71 --- /dev/null +++ b/paddleslim/quant/quant_embedding.py @@ -0,0 +1,259 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import copy +import numpy as np + +import paddle.fluid as fluid +from paddle.fluid.framework import IrGraph +from paddle.fluid import core + +#_logger = logging.basicConfig(level=logging.DEBUG) + +__all__ = ['quant_embedding'] + +default_config = { + "quantize_type": "abs_max", + "quantize_bits": 8, + "dtype": "int8" +} + +support_quantize_types = ['abs_max'] +support_quantize_bits = [8] +support_dtype = ['int8'] + + +def _merge_config(old_config, new_config): + """ + merge default config and user defined config + + Args: + old_config(dict): the copy of default_config + new_config(dict): the user defined config, 'params_name' must be set. + When 'threshold' is not set, quant embedding without clip . + """ + old_config.update(new_config) + keys = old_config.keys() + assert 'params_name' in keys, "params_name must be set" + + quantize_type = old_config['quantize_type'] + assert isinstance(quantize_type, str), "quantize_type must be \ + str" + + assert quantize_type in support_quantize_types, " \ + quantize_type {} is not supported, now supported quantize type \ + are {}.".format(quantize_type, support_quantize_types) + + quantize_bits = old_config['quantize_bits'] + assert isinstance(quantize_bits, int), "quantize_bits must be int" + assert quantize_bits in support_quantize_bits, " quantize_bits {} \ + is not supported, now supported quantize bits are \ + {}. ".format(quantize_bits, support_quantize_bits) + + dtype = old_config['dtype'] + assert isinstance(dtype, str), "dtype must be str" + assert dtype in support_dtype, " dtype {} is not \ + supported, now supported dtypes are {} \ + ".format(dtype, support_dtype) + if 'threshold' in keys: + assert isinstance(old_config['threshold'], (float, int)), "threshold \ + must be number." + + print("quant_embedding config {}".format(old_config)) + return old_config + + +def _get_var_tensor(scope, var_name): + """ + get tensor array by name. 
+ Args: + scope(fluid.Scope): scope to get var + var_name(str): vatiable name + Return: + np.array + """ + return np.array(scope.find_var(var_name).get_tensor()) + + +def _clip_tensor(tensor_array, threshold): + """ + when 'threshold' is set, clip tensor by 'threshold' and '-threshold' + Args: + tensor_array(np.array): array to clip + config(dict): config dict + """ + tensor_array[tensor_array > threshold] = threshold + tensor_array[tensor_array < -threshold] = -threshold + return tensor_array + + +def _get_scale_var_name(var_name): + """ + get scale var name + """ + return var_name + '.scale' + + +def _get_quant_var_name(var_name): + """ + get quantized var name + """ + return var_name + '.int8' + + +def _get_dequant_var_name(var_name): + """ + get dequantized var name + """ + return var_name + '.dequantize' + + +def _restore_var(name, arr, scope, place): + """ + restore quantized array to quantized var + """ + tensor = scope.find_var(name).get_tensor() + tensor.set(arr, place) + + +def _clear_var(var_name, scope): + """ + free memory of var + """ + tensor = scope.find_var(var_name).get_tensor() + tensor._clear() + + +def _quant_embedding_abs_max(graph, scope, place, config): + """ + quantize embedding using abs_max + + Args: + graph(IrGraph): graph that includes lookup_table op + scope(fluid.Scope): scope + place(fluid.CPUPlace or flud.CUDAPlace): place + config(dict): config to quant + """ + + def _quant_abs_max(tensor_array, config): + """ + quant array using abs_max op + """ + bit_length = config['quantize_bits'] + scale = np.max(np.abs(tensor_array)).astype("float32") + quanted_tensor = np.round(tensor_array / scale * ( + (1 << (bit_length - 1)) - 1)) + return scale, quanted_tensor.astype(config['dtype']) + + def _insert_dequant_abs_max_op(graph, scope, var_node, scale_node, config): + """ + Insert dequantize_abs_max op in graph + """ + assert var_node.is_var(), "{} is not a var".format(var_node.name()) + + dequant_var_node = graph.create_var_node( + name=_get_dequant_var_name(var_node.name()), + var_type=var_node.type(), + shape=var_node.shape(), + var_dtype=core.VarDesc.VarType.FP32) + scope.var(dequant_var_node.name()) + + max_range = (1 << (config['quantize_bits'] - 1)) - 1 + output_ops = var_node.outputs + dequant_op = graph.create_op_node( + op_type='dequantize_abs_max', + attrs={ + 'max_range': float(max_range), + 'op_role': core.op_proto_and_checker_maker.OpRole.Forward + }, + inputs={'X': var_node, + 'Scale': scale_node}, + outputs={'Out': dequant_var_node}) + graph.link_to(var_node, dequant_op) + graph.link_to(scale_node, dequant_op) + graph.link_to(dequant_op, dequant_var_node) + for node in output_ops: + graph.update_input_link(var_node, dequant_var_node, node) + + all_var_nodes = graph.all_var_nodes() + var_name = config['params_name'] + # find embedding var node by 'params_name' + embedding_node = graph._find_node_by_name(all_var_nodes, var_name) + embedding_tensor = _get_var_tensor(scope, var_name) + if 'threshold' in config.keys(): + embedding_tensor = _clip_tensor(embedding_tensor, config['threshold']) + + # get scale and quanted tensor + scale, quanted_tensor = _quant_abs_max(embedding_tensor, config) + + #create params must to use create_persistable_node + scale_var = graph.create_persistable_node( + _get_scale_var_name(var_name), + var_type=embedding_node.type(), + shape=[1], + var_dtype=core.VarDesc.VarType.FP32) + quant_tensor_var = graph.create_persistable_node( + _get_quant_var_name(var_name), + var_type=embedding_node.type(), + shape=embedding_node.shape(), + 
diff --git a/paddleslim/tests/layers.py b/paddleslim/tests/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..140ff5919b9d8c9821b371db5ca4896db28bf7f0
--- /dev/null
+++ b/paddleslim/tests/layers.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.fluid as fluid
+from paddle.fluid.param_attr import ParamAttr
+
+
+def conv_bn_layer(input,
+                  num_filters,
+                  filter_size,
+                  name,
+                  stride=1,
+                  groups=1,
+                  act=None):
+    conv = fluid.layers.conv2d(
+        input=input,
+        num_filters=num_filters,
+        filter_size=filter_size,
+        stride=stride,
+        padding=(filter_size - 1) // 2,
+        groups=groups,
+        act=None,
+        param_attr=ParamAttr(name=name + "_weights"),
+        bias_attr=False,
+        name=name + "_out")
+    bn_name = name + "_bn"
+    return fluid.layers.batch_norm(
+        input=conv,
+        act=act,
+        name=bn_name + '_output',
+        param_attr=ParamAttr(name=bn_name + '_scale'),
+        bias_attr=ParamAttr(name=bn_name + '_offset'),
+        moving_mean_name=bn_name + '_mean',
+        moving_variance_name=bn_name + '_variance')
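A quick sketch of how the conv_bn_layer helper above is driven by the tests that follow; the program and shapes here are illustrative:

    import paddle.fluid as fluid
    from layers import conv_bn_layer  # the test helper module above

    main_prog = fluid.Program()
    startup_prog = fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        image = fluid.data(name="image", shape=[None, 3, 16, 16])
        # padding = (3 - 1) // 2 = 1, so the 16x16 spatial size is preserved
        feat = conv_bn_layer(image, num_filters=8, filter_size=3,
                             name="conv1", act="relu")

Naming every parameter explicitly (conv1_weights, conv1_bn_scale, and so on) is what lets the pruning test further below address weights by name.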
diff --git a/paddleslim/tests/test_nas_search_space.py b/paddleslim/tests/test_nas_search_space.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2f2af5b8b9fac38a6d8f1273e853aefc6983bff
--- /dev/null
+++ b/paddleslim/tests/test_nas_search_space.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+sys.path.append('..')
+import unittest
+import paddle.fluid as fluid
+from nas.search_space_factory import SearchSpaceFactory
+
+
+class TestSearchSpace(unittest.TestCase):
+    def test_searchspace(self):
+        # if output_size is 1, the model appends an fc layer at the end.
+        config = {'input_size': 224, 'output_size': 7, 'block_num': 5}
+        space = SearchSpaceFactory()
+
+        my_space = space.get_search_space('MobileNetV2Space', config)
+        model_arch = my_space.token2arch()
+
+        train_prog = fluid.Program()
+        startup_prog = fluid.Program()
+        with fluid.program_guard(train_prog, startup_prog):
+            input_size = config['input_size']
+            model_input = fluid.layers.data(
+                name='model_in',
+                shape=[1, 3, input_size, input_size],
+                dtype='float32',
+                append_batch_size=False)
+            predict = model_arch(model_input)
+            self.assertEqual(predict.shape[2], config['output_size'])
+
+
+if __name__ == '__main__':
+    unittest.main()
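The pruning test below prunes only conv4_weights, yet expects the conv1, conv2, conv3, and conv5 weights to change shape too: sum2 = conv4 + sum1 forces sum1 (and therefore the outputs of conv1 and conv2) to match conv4's pruned channel count, while ops that merely consume a pruned tensor lose input channels. A sanity check of the expected shapes, assuming Paddle's [out_channels, in_channels, k, k] convolution weight layout:

    # Recompute the shapes asserted by test_prune.py from the pruning plan.
    channels, ratio = 8, 0.5
    kept = int(channels * (1 - ratio))  # conv4 keeps 4 output channels

    expected = {
        "conv1_weights": (kept, 3, 3, 3),          # output pruned via sum1
        "conv2_weights": (kept, kept, 3, 3),       # input from conv1, output via sum1
        "conv3_weights": (channels, kept, 3, 3),   # consumes the pruned sum1
        "conv4_weights": (kept, channels, 3, 3),   # the directly pruned layer
        "conv5_weights": (channels, kept, 3, 3),   # consumes the pruned sum2
        "conv6_weights": (channels, channels, 3, 3),
    }
    assert expected["conv4_weights"] == (4, 8, 3, 3)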
diff --git a/paddleslim/tests/test_prune.py b/paddleslim/tests/test_prune.py
new file mode 100644
index 0000000000000000000000000000000000000000..93609367351618ce375f164a1dca284e85369e4c
--- /dev/null
+++ b/paddleslim/tests/test_prune.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+sys.path.append("../")
+import unittest
+import paddle.fluid as fluid
+from prune import Pruner
+from layers import conv_bn_layer
+
+
+class TestPrune(unittest.TestCase):
+    def test_prune(self):
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        #   X       X              O       X              O
+        # conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
+        #     |            ^ |                    ^
+        #     |____________| |____________________|
+        #
+        # X: prune output channels
+        # O: prune input channels
+        with fluid.program_guard(main_program, startup_program):
+            input = fluid.data(name="image", shape=[None, 3, 16, 16])
+            conv1 = conv_bn_layer(input, 8, 3, "conv1")
+            conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
+            sum1 = conv1 + conv2
+            conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
+            conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
+            sum2 = conv4 + sum1
+            conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
+            conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
+
+        place = fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        scope = fluid.Scope()
+        exe.run(startup_program, scope=scope)
+        pruner = Pruner()
+        main_program = pruner.prune(
+            main_program,
+            scope,
+            params=["conv4_weights"],
+            ratios=[0.5],
+            place=place,
+            lazy=False,
+            only_graph=False,
+            param_backup=None,
+            param_shape_backup=None)
+
+        shapes = {
+            "conv1_weights": (4, 3, 3, 3),
+            "conv2_weights": (4, 4, 3, 3),
+            "conv3_weights": (8, 4, 3, 3),
+            "conv4_weights": (4, 8, 3, 3),
+            "conv5_weights": (8, 4, 3, 3),
+            "conv6_weights": (8, 8, 3, 3)
+        }
+
+        for param in main_program.global_block().all_parameters():
+            if "weights" in param.name:
+                self.assertEqual(param.shape, shapes[param.name])
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/setup.py b/setup.py
index d79620c5791c3a0144ca3aaa9f1d5d7b979dff31..86421878ed5493c3ab5f8b446f6b62a3b0135975 100644
--- a/setup.py
+++ b/setup.py
@@ -33,8 +33,12 @@ with open('./requirements.txt') as f:
     setup_requires = f.read().splitlines()
 
 packages = [
-    'paddleslim', 'paddleslim.prune', 'paddleslim.dist', 'paddleslim.nas',
-    'paddleslim.analysis', 'paddleslim.quant'
+    'paddleslim',
+    'paddleslim.prune',
+    'paddleslim.dist',
+    'paddleslim.nas',
+    'paddleslim.analysis',
+    'paddleslim.quant',
 ]
 
 setup(