提交 a5701512 编写于 作者: I itminner

Merge remote-tracking branch 'upstream/develop' into develop

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import graph_wrapper
from .graph_wrapper import *
__all__ = graph_wrapper.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
import pickle
import numpy as np
from collections import OrderedDict
from collections import Iterable
from paddle.fluid.framework import Program, program_guard, Parameter, Variable
__all__ = ['GraphWrapper', 'VarWrapper', 'OpWrapper']
OPTIMIZER_OPS = [
'momentum',
'lars_momentum',
'adagrad',
'adam',
'adamax',
'dpsgd',
'decayed_adagrad',
'adadelta',
'rmsprop',
]
class VarWrapper(object):
def __init__(self, var, graph):
assert isinstance(var, Variable)
assert isinstance(graph, GraphWrapper)
self._var = var
self._graph = graph
def __eq__(self, v):
"""
Overwrite this function for ...in... syntax in python.
"""
return self._var.name == v._var.name
def name(self):
"""
Get the name of the variable.
"""
return self._var.name
def shape(self):
"""
Get the shape of the varibale.
"""
return self._var.shape
def set_shape(self, shape):
"""
Set the shape of the variable.
"""
self._var.desc.set_shape(shape)
def inputs(self):
"""
Get all the operators that use this variable as output.
Returns:
list<OpWrapper>: A list of operators.
"""
ops = []
for op in self._graph.ops():
if self in op.all_outputs():
ops.append(op)
return ops
def outputs(self):
"""
Get all the operators that use this variable as input.
Returns:
list<OpWrapper>: A list of operators.
"""
ops = []
for op in self._graph.ops():
if self in op.all_inputs():
ops.append(op)
return ops
class OpWrapper(object):
def __init__(self, op, graph):
assert isinstance(graph, GraphWrapper)
self._op = op
self._graph = graph
def __eq__(self, op):
"""
Overwrite this function for ...in... syntax in python.
"""
return self.idx() == op.idx()
def all_inputs(self):
"""
Get all the input variables of this operator.
"""
return [
self._graph.var(var_name) for var_name in self._op.input_arg_names
]
def all_outputs(self):
"""
Get all the output variables of this operator.
"""
return [
self._graph.var(var_name) for var_name in self._op.output_arg_names
]
def idx(self):
"""
Get the id of this operator.
"""
return self._op.idx
def type(self):
"""
Get the type of this operator.
"""
return self._op.type
def is_bwd_op(self):
"""
Whether this operator is backward op.
"""
return self.type().endswith('_grad')
def is_opt_op(self):
"""
Whether this operator is optimizer op.
"""
return self.type() in OPTIMIZER_OPS
def inputs(self, name):
"""
Get all the varibales by the input name.
"""
return [self._graph.var(var_name) for var_name in self._op.input(name)]
def outputs(self, name):
"""
Get all the varibales by the output name.
"""
return [
self._graph.var(var_name) for var_name in self._op.output(name)
]
def set_attr(self, key, value):
"""
Set the value of attribute by attribute's name.
Args:
key(str): the attribute name.
value(bool|int|str|float|list): the value of the attribute.
"""
self._op._set_attr(key, value)
def attr(self, name):
"""
Get the attribute by name.
Args:
name(str): the attribute name.
Returns:
bool|int|str|float|list: The attribute value. The return value
can be any valid attribute type.
"""
return self._op.attr(name)
class GraphWrapper(object):
"""
It is a wrapper of paddle.fluid.framework.IrGraph with some special functions
for paddle slim framework.
"""
def __init__(self, program=None, in_nodes=[], out_nodes=[]):
"""
Args:
program(framework.Program): A program with
in_nodes(dict): A dict to indicate the input nodes of the graph.
The key is user-defined and human-readable name.
The value is the name of Variable.
out_nodes(dict): A dict to indicate the input nodes of the graph.
The key is user-defined and human-readable name.
The value is the name of Variable.
"""
super(GraphWrapper, self).__init__()
self.program = Program() if program is None else program
self.persistables = {}
self.teacher_persistables = {}
for var in self.program.list_vars():
if var.persistable:
self.persistables[var.name] = var
self.compiled_graph = None
in_nodes = [] if in_nodes is None else in_nodes
out_nodes = [] if out_nodes is None else out_nodes
self.in_nodes = OrderedDict(in_nodes)
self.out_nodes = OrderedDict(out_nodes)
self._attrs = OrderedDict()
def all_parameters(self):
"""
Get all the parameters in this graph.
Returns:
list<VarWrapper>: A list of VarWrapper instances.
"""
params = []
for block in self.program.blocks:
for param in block.all_parameters():
params.append(VarWrapper(param, self))
return params
def is_parameter(self, var):
"""
Whether the given variable is parameter.
Args:
var(VarWrapper): The given varibale.
"""
return isinstance(var._var, Parameter)
def is_persistable(self, var):
"""
Whether the given variable is persistable.
Args:
var(VarWrapper): The given varibale.
"""
return var._var.persistable
def ops(self):
"""
Return all operator nodes included in the graph as a set.
"""
ops = []
for block in self.program.blocks:
for op in block.ops:
ops.append(OpWrapper(op, self))
return ops
def vars(self):
"""
Get all the variables.
"""
return [VarWrapper(var, self) for var in self.program.list_vars()]
def var(self, name):
"""
Get the variable by variable name.
"""
return VarWrapper(self.program.global_block().var(name), self)
def clone(self, for_test=False):
"""
Clone a new graph from current graph.
Returns:
(GraphWrapper): The wrapper of a new graph.
"""
return GraphWrapper(
self.program.clone(for_test),
copy.deepcopy(self.in_nodes), copy.deepcopy(self.out_nodes))
def program(self):
"""
Get the program in current wrapper.
"""
return self.program
def pre_ops(self, op):
"""
Get all the previous operators of target operator.
Args:
op(OpWrapper): Target operator..
Returns:
list<OpWrapper>: A list of operators.
"""
ops = []
for p in self.ops():
for in_var in op.all_inputs():
if in_var in p.all_outputs():
ops.append(p)
return ops
def next_ops(self, op):
"""
Get all the next operators of target operator.
Args:
op(OpWrapper): Target operator..
Returns:
list<OpWrapper>: A list of operators.
"""
ops = []
for p in self.ops():
for out_var in op.all_outputs():
if out_var in p.all_inputs():
ops.append(p)
return ops
def get_param_by_op(self, op):
"""
Get the parameters used by target operator.
"""
assert isinstance(op, OpWrapper)
params = []
for var in op.all_inputs():
if isinstance(var._var, Parameter):
params.append(var)
assert len(params) > 0
return params
def numel_params(self):
"""
Get the number of elements in all parameters.
"""
ret = 0
for param in self.all_parameters():
ret += np.product(param.shape())
return ret
def update_param_shape(self, scope):
"""
Update the shape of parameters in the graph according to tensors in scope.
It is used after loading pruned parameters from file.
"""
for param in self.all_parameters():
tensor_shape = np.array(
scope.find_var(param.name()).get_tensor()).shape
param.set_shape(tensor_shape)
def infer_shape(self):
"""
Update the groups of convolution layer according to current filters.
It is used after loading pruned parameters from file.
"""
for op in self.ops():
if op.type() != 'conditional_block':
op._op.desc.infer_shape(op._op.block.desc)
def update_groups_of_conv(self):
for op in self.ops():
if op.type() == 'depthwise_conv2d' or op.type(
) == 'depthwise_conv2d_grad':
op.set_attr('groups', op.inputs('Filter')[0].shape()[0])
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import logging
import numpy as np
from six.moves.queue import Queue
import paddle.fluid as fluid
from paddle.fluid.framework import Variable
from paddle.fluid.reader import DataLoaderBase
from paddle.fluid.core import EOFException
from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
__all__ = ['Knowledge']
class Knowledge(object):
"""
The knowledge class describes how to extract and store the dark knowledge
of the teacher model, and how the student model learns these dark knowledge.
"""
def __init__(self,
path,
items,
reduce_strategy={'type': 'sum',
'key': 'image'}):
"""Init a knowledge instance.
Args:
path(list<str>, str, optional): Specifies the storage path of the knowledge,
supports AFS/HDFS, local file system, and memory.
items(list<str>): Save the tensor of the specified name
reduce_strategy(dict, optional): The policy for performing the reduce
operation. If it is set to None,
the reduce operation is not performed.
reduce_strategy.type(str): Type of reduce operation.
reduce_strategy.key(str): The key of the reduce operation.
It is an element in the item.
"""
assert (isinstance(path, list) or isinstance(path, str) or
(path is None)), "path type should be list or str or None"
assert (isinstance(items, list)), "items should be a list"
assert (isinstance(reduce_strategy,
dict)), "reduce_strategy should be a dict"
self.path = path
if isinstance(self.path, list):
self.write_type = 'HDFS/AFS'
assert (
len(self.path) == 4 and isinstance(self.path[0], str) and
isinstance(self.path[1], str) and
isinstance(self.path[2], str) and isinstance(self.path[3], str)
), "path should contains four str, ['local hadoop home', 'fs.default.name', 'hadoop.job.ugi', 'FS path']"
hadoop_home = self.path[0]
configs = {
"fs.default.name": self.path[1],
"hadoop.job.ugi": self.path[2]
}
self.client = HDFSClient(hadoop_home, configs)
assert (
self.client.is_exist(self.path[3]) == True
), "Plese make sure your hadoop confiuration is correct and FS path exists"
self.hdfs_local_path = "./teacher_knowledge"
if not os.path.exists(self.hdfs_local_path):
os.mkdir(self.hdfs_local_path)
elif isinstance(self.path, str):
self.write_type = "LocalFS"
if not os.path.exists(path):
raise ValueError("The local path [%s] does not exist." %
(path))
else:
self.write_type = "MEM"
self.knowledge_queue = Queue(64)
self.items = items
self.reduce_strategy = reduce_strategy
def _write(self, data):
if self.write_type == 'HDFS/AFS':
file_name = 'knowledge_' + str(self.file_cnt)
file_path = os.path.join(self.hdfs_local_path, file_name)
file_path += ".npy"
np.save(file_path, data)
self.file_cnt += 1
self.client.upload(self.path[3], file_path)
logger.info('{}.npy pushed to HDFS/AFS: {}'.format(file_name,
self.path[3]))
elif self.write_type == 'LocalFS':
file_name = 'knowledge_' + str(self.file_cnt)
file_path = os.path.join(self.path, file_name)
np.save(file_path, data)
logger.info('{}.npy saved'.format(file_name))
self.file_cnt += 1
else:
self.knowledge_queue.put(data)
logger.info('{} pushed to Queue'.format(file_name))
def run(self, teacher_program, exe, place, scope, reader, inputs, outputs,
call_back):
"""Start teacher model to do information.
Args:
teacher_program(Program): teacher program.
scope(Scope): The scope used to execute the teacher,
which contains the initialized variables.
reader(reader): The data reader used by the teacher.
inputs(list<str>): The name of variables to feed the teacher program.
outputs(list<str>): Need to write to the variable instance's names of
the Knowledge instance, which needs to correspond
to the Knowledge's items.
call_back(func, optional): The callback function that handles the
outputs of the teacher, which is none by default,
that is, the output of the teacher is concat directly.
Return:
(bool): Whether the teacher task was successfully registered and started
"""
assert (isinstance(
teacher_program,
fluid.Program)), "teacher_program should be a fluid.Program"
assert (isinstance(inputs, list)), "inputs should be a list"
assert (isinstance(outputs, list)), "outputs should be a list"
assert (len(self.items) == len(outputs)
), "the length of outputs list should be equal with items list"
assert (callable(call_back) or (call_back is None)
), "call_back should be a callable function or NoneType."
for var in teacher_program.list_vars():
var.stop_gradient = True
compiled_teacher_program = fluid.compiler.CompiledProgram(
teacher_program)
self.file_cnt = 0
if isinstance(reader, Variable) or (
isinstance(reader, DataLoaderBase) and (not reader.iterable)):
reader.start()
try:
while True:
logits = exe.run(compiled_teacher_program,
scope=scope,
fetch_list=outputs,
feed=None)
knowledge = dict()
for index, array in enumerate(logits):
knowledge[self.items[index]] = array
self._write(knowledge)
except EOFException:
reader.reset()
else:
if not isinstance(reader, DataLoaderBase):
feeder = fluid.DataFeeder(
feed_list=inputs, place=place, program=teacher_program)
for batch_id, data in enumerate(reader()):
if not isinstance(reader, DataLoaderBase):
data = feeder.feed(data)
logits = exe.run(compiled_teacher_program,
scope=scope,
fetch_list=outputs,
feed=data)
knowledge = dict()
for index, array in enumerate(logits):
knowledge[self.items[index]] = array
self._write(knowledge)
return True
def dist(self, student_program, losses):
"""Building the distillation network
Args:
student_program(Program): student program.
losses(list<Variable>, optional): The losses need to add. If set to None
does not add any loss.
Return:
(Program): Program for distillation.
(startup_program): Program for initializing distillation network.
(reader): Data reader for distillation training.
(Variable): Loss of distillation training
"""
def loss(self, loss_func, *variables):
"""User-defined loss
Args:
loss_func(func): Function used to define loss.
*variables(list<str>): Variable name list.
Return:
(Variable): Distillation loss.
"""
pass
def fsp_loss(self):
"""fsp loss
"""
pass
def l2_loss(self):
"""l2 loss
"""
pass
def softlabel_loss(self):
"""softlabel_loss
"""
pass
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['SearchSpaceBase']
class SearchSpaceBase(object):
"""Controller for Neural Architecture Search.
"""
def __init__(self, input_size, output_size, block_num, *argss):
self.input_size = input_size
self.output_size = output_size
self.block_num = block_num
def init_tokens(self):
"""Get init tokens in search space.
"""
raise NotImplementedError('Abstract method.')
def range_table(self):
"""Get range table of current search space.
"""
raise NotImplementedError('Abstract method.')
def token2arch(self, tokens):
"""Create networks for training and evaluation according to tokens.
Args:
tokens(list<int>): The tokens which represent a network.
Return:
list<layers>
"""
raise NotImplementedError('Abstract method.')
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from searchspace.registry import SEARCHSPACE
class SearchSpaceFactory(object):
def __init__(self):
pass
def get_search_space(self, key, config):
"""
get specific model space based on key and config.
Args:
key(str): model space name.
config(dict): basic config information.
return:
model space(class)
"""
cls = SEARCHSPACE.get(key)
space = cls(config['input_size'], config['output_size'], config['block_num'])
return space
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .mobilenetv2_space import MobileNetV2Space
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
def conv_bn_layer(input, filter_size, num_filters, stride, padding, num_groups=1, act=None, name=None, use_cudnn=True):
"""Build convolution and batch normalization layers.
Args:
input(Variable): input.
filter_size(int): filter size.
num_filters(int): number of filters.
stride(int): stride.
padding(int|list|str): padding.
num_groups(int): number of groups.
act(str): activation type.
name(str): name.
use_cudnn(bool): whether use cudnn.
Returns:
Variable, layers output.
"""
conv = fluid.layers.conv2d(input, num_filters=num_filters, filter_size=filter_size, stride=stride, padding=padding,
groups=num_groups, act=None, use_cudnn=use_cudnn, param_attr=ParamAttr(name=name+'_weights'), bias_attr=False)
bn_name = name + '_bn'
bn = fluid.layers.batch_norm(input=conv, param_attr=ParamAttr(name=bn_name+'_scale'), bias_attr=ParamAttr(name=bn_name+'_offset'),
moving_mean_name=bn_name+'_mean', moving_variance_name=bn_name+'_variance')
if act == 'relu6':
return fluid.layers.relu6(bn)
else:
return bn
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from ..search_space_base import SearchSpaceBase
from .layer import conv_bn_layer
from .registry import SEARCHSPACE
@SEARCHSPACE.register_module
class MobileNetV2Space(SearchSpaceBase):
def __init__(self, input_size, output_size, block_num, scale=1.0, class_dim=1000):
super(MobileNetV2Space, self).__init__(input_size, output_size, block_num)
self.head_num = np.array([3,4,8,12,16,24,32]) #7
self.filter_num1 = np.array([3,4,8,12,16,24,32,48]) #8
self.filter_num2 = np.array([8,12,16,24,32,48,64,80]) #8
self.filter_num3 = np.array([16,24,32,48,64,80,96,128]) #8
self.filter_num4 = np.array([24,32,48,64,80,96,128,144,160,192]) #10
self.filter_num5 = np.array([32,48,64,80,96,128,144,160,192,224]) #10
self.filter_num6 = np.array([64,80,96,128,144,160,192,224,256,320,384,512]) #12
self.k_size = np.array([3,5]) #2
self.multiply = np.array([1,2,3,4,6]) #5
self.repeat = np.array([1,2,3,4,5,6]) #6
self.scale=scale
self.class_dim=class_dim
def init_tokens(self):
"""
The initial token send to controller.
The first one is the index of the first layers' channel in self.head_num,
each line in the following represent the index of the [expansion_factor, filter_num, repeat_num, kernel_size]
"""
# original MobileNetV2
return [4, # 1, 16, 1
4, 5, 1, 0, # 6, 24, 1
4, 5, 1, 0, # 6, 24, 2
4, 4, 2, 0, # 6, 32, 3
4, 4, 3, 0, # 6, 64, 4
4, 5, 2, 0, # 6, 96, 3
4, 7, 2, 0, # 6, 160, 3
4, 9, 0, 0] # 6, 320, 1
def range_table(self):
"""
get range table of current search space
"""
# head_num + 7 * [multiple(expansion_factor), filter_num, repeat, kernel_size]
return [7,
5, 8, 6, 2,
5, 8, 6, 2,
5, 8, 6, 2,
5, 8, 6, 2,
5, 10, 6, 2,
5, 10, 6, 2,
5, 12, 6, 2]
def token2arch(self, tokens=None):
"""
return net_arch function
"""
if tokens is None:
tokens = self.init_tokens()
base_bottleneck_params_list = [
(1, self.head_num[tokens[0]], 1, 1, 3),
(self.multiply[tokens[1]], self.filter_num1[tokens[2]], self.repeat[tokens[3]], 2, self.k_size[tokens[4]]),
(self.multiply[tokens[5]], self.filter_num1[tokens[6]], self.repeat[tokens[7]], 2, self.k_size[tokens[8]]),
(self.multiply[tokens[9]], self.filter_num2[tokens[10]], self.repeat[tokens[11]], 2, self.k_size[tokens[12]]),
(self.multiply[tokens[13]], self.filter_num3[tokens[14]], self.repeat[tokens[15]], 2, self.k_size[tokens[16]]),
(self.multiply[tokens[17]], self.filter_num3[tokens[18]], self.repeat[tokens[19]], 1, self.k_size[tokens[20]]),
(self.multiply[tokens[21]], self.filter_num5[tokens[22]], self.repeat[tokens[23]], 2, self.k_size[tokens[24]]),
(self.multiply[tokens[25]], self.filter_num6[tokens[26]], self.repeat[tokens[27]], 1, self.k_size[tokens[28]]),
]
assert self.block_num < 7, 'block number must less than 7, but receive block number is {}'.format(self.block_num)
# the stride = 2 means downsample feature map in the convolution, so only when stride=2, block_num minus 1,
# otherwise, add layers to params_list directly.
bottleneck_params_list = []
for param_list in base_bottleneck_params_list:
if param_list[3] == 1:
bottleneck_params_list.append(param_list)
else:
if self.block_num > 1:
bottleneck_params_list.append(param_list)
self.block_num -= 1
else:
break
def net_arch(input):
#conv1
# all padding is 'SAME' in the conv2d, can compute the actual padding automatic.
input = conv_bn_layer(
input,
num_filters=int(32 * self.scale),
filter_size=3,
stride=2,
padding='SAME',
act='relu6',
name='conv1_1')
# bottleneck sequences
i = 1
in_c = int(32 * self.scale)
for layer_setting in bottleneck_params_list:
t, c, n, s, k = layer_setting
i += 1
input = self.invresi_blocks(
input=input,
in_c=in_c,
t=t,
c=int(c * self.scale),
n=n,
s=s,
k=k,
name='conv' + str(i))
in_c = int(c * self.scale)
# if output_size is 1, add fc layer in the end
if self.output_size == 1:
input = fluid.layers.fc(input=input,
size=self.class_dim,
param_attr=ParamAttr(name='fc10_weights'),
bias_attr=ParamAttr(name='fc10_offset'))
else:
assert self.output_size == input.shape[2], \
("output_size must EQUAL to input_size / (2^block_num)."
"But receive input_size={}, output_size={}, block_num={}".format(
self.input_size, self.output_size, self.block_num))
return input
return net_arch
def shortcut(self, input, data_residual):
"""Build shortcut layer.
Args:
input(Variable): input.
data_residual(Variable): residual layer.
Returns:
Variable, layer output.
"""
return fluid.layers.elementwise_add(input, data_residual)
def inverted_residual_unit(self,
input,
num_in_filter,
num_filters,
ifshortcut,
stride,
filter_size,
expansion_factor,
reduction_ratio=4,
name=None):
"""Build inverted residual unit.
Args:
input(Variable), input.
num_in_filter(int), number of in filters.
num_filters(int), number of filters.
ifshortcut(bool), whether using shortcut.
stride(int), stride.
filter_size(int), filter size.
padding(str|int|list), padding.
expansion_factor(float), expansion factor.
name(str), name.
Returns:
Variable, layers output.
"""
num_expfilter = int(round(num_in_filter * expansion_factor))
channel_expand = conv_bn_layer(
input=input,
num_filters=num_expfilter,
filter_size=1,
stride=1,
padding='SAME',
num_groups=1,
act='relu6',
name=name + '_expand')
bottleneck_conv = conv_bn_layer(
input=channel_expand,
num_filters=num_expfilter,
filter_size=filter_size,
stride=stride,
padding='SAME',
num_groups=num_expfilter,
act='relu6',
name=name + '_dwise',
use_cudnn=False)
linear_out = conv_bn_layer(
input=bottleneck_conv,
num_filters=num_filters,
filter_size=1,
stride=1,
padding='SAME',
num_groups=1,
act=None,
name=name + '_linear')
out = linear_out
if ifshortcut:
out = self.shortcut(input=input, data_residual=out)
return out
def invresi_blocks(self,
input,
in_c,
t,
c,
n,
s,
k,
name=None):
"""Build inverted residual blocks.
Args:
input: Variable, input.
in_c: int, number of in filters.
t: float, expansion factor.
c: int, number of filters.
n: int, number of layers.
s: int, stride.
k: int, filter size.
name: str, name.
Returns:
Variable, layers output.
"""
first_block = self.inverted_residual_unit(
input=input,
num_in_filter=in_c,
num_filters=c,
ifshortcut=False,
stride=s,
filter_size=k,
expansion_factor=t,
name=name + '_1')
last_residual_block = first_block
last_c = c
for i in range(1, n):
last_residual_block = self.inverted_residual_unit(
input=last_residual_block,
num_in_filter=last_c,
num_filters=c,
ifshortcut=True,
stride=1,
filter_size=k,
expansion_factor=t,
name=name + '_' + str(i + 1))
return last_residual_block
from ..utils.registry import Registry
SEARCHSPACE = Registry('searchspace')
import inspect
class Registry(object):
def __init__(self, name):
self._name = name
self._module_dict = dict()
def __repr__(self):
format_str = self.__class__.__name__ + '(name={}, items={})'.format(self._name, list(self._module_dict.keys()))
return format_str
@property
def name(self):
return self._name
@property
def module_dict(self):
return self._module_dict
def get(self, key):
return self._module_dict.get(key, None)
def _register_module(self, module_class):
if not inspect.isclass(module_class):
raise TypeError('module must be a class, but receive {}.'.format(type(module_class)))
module_name = module_class.__name__
if module_name in self._module_dict:
raise KeyError('{} is already registered in {}.'.format(module_name, self.name))
self._module_dict[module_name] = module_class
def register_module(self, cls):
self._register_module(cls)
return cls
...@@ -11,3 +11,4 @@ ...@@ -11,3 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from pruner import Pruner
此差异已折叠。
...@@ -13,3 +13,4 @@ ...@@ -13,3 +13,4 @@
# limitations under the License. # limitations under the License.
from quanter import quant_aware, quant_post, convert from quanter import quant_aware, quant_post, convert
from .quant_embedding import quant_embedding
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import copy
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
#_logger = logging.basicConfig(level=logging.DEBUG)
__all__ = ['quant_embedding']
default_config = {
"quantize_type": "abs_max",
"quantize_bits": 8,
"dtype": "int8"
}
support_quantize_types = ['abs_max']
support_quantize_bits = [8]
support_dtype = ['int8']
def _merge_config(old_config, new_config):
"""
merge default config and user defined config
Args:
old_config(dict): the copy of default_config
new_config(dict): the user defined config, 'params_name' must be set.
When 'threshold' is not set, quant embedding without clip .
"""
old_config.update(new_config)
keys = old_config.keys()
assert 'params_name' in keys, "params_name must be set"
quantize_type = old_config['quantize_type']
assert isinstance(quantize_type, str), "quantize_type must be \
str"
assert quantize_type in support_quantize_types, " \
quantize_type {} is not supported, now supported quantize type \
are {}.".format(quantize_type, support_quantize_types)
quantize_bits = old_config['quantize_bits']
assert isinstance(quantize_bits, int), "quantize_bits must be int"
assert quantize_bits in support_quantize_bits, " quantize_bits {} \
is not supported, now supported quantize bits are \
{}. ".format(quantize_bits, support_quantize_bits)
dtype = old_config['dtype']
assert isinstance(dtype, str), "dtype must be str"
assert dtype in support_dtype, " dtype {} is not \
supported, now supported dtypes are {} \
".format(dtype, support_dtype)
if 'threshold' in keys:
assert isinstance(old_config['threshold'], (float, int)), "threshold \
must be number."
print("quant_embedding config {}".format(old_config))
return old_config
def _get_var_tensor(scope, var_name):
"""
get tensor array by name.
Args:
scope(fluid.Scope): scope to get var
var_name(str): vatiable name
Return:
np.array
"""
return np.array(scope.find_var(var_name).get_tensor())
def _clip_tensor(tensor_array, threshold):
"""
when 'threshold' is set, clip tensor by 'threshold' and '-threshold'
Args:
tensor_array(np.array): array to clip
config(dict): config dict
"""
tensor_array[tensor_array > threshold] = threshold
tensor_array[tensor_array < -threshold] = -threshold
return tensor_array
def _get_scale_var_name(var_name):
"""
get scale var name
"""
return var_name + '.scale'
def _get_quant_var_name(var_name):
"""
get quantized var name
"""
return var_name + '.int8'
def _get_dequant_var_name(var_name):
"""
get dequantized var name
"""
return var_name + '.dequantize'
def _restore_var(name, arr, scope, place):
"""
restore quantized array to quantized var
"""
tensor = scope.find_var(name).get_tensor()
tensor.set(arr, place)
def _clear_var(var_name, scope):
"""
free memory of var
"""
tensor = scope.find_var(var_name).get_tensor()
tensor._clear()
def _quant_embedding_abs_max(graph, scope, place, config):
"""
quantize embedding using abs_max
Args:
graph(IrGraph): graph that includes lookup_table op
scope(fluid.Scope): scope
place(fluid.CPUPlace or flud.CUDAPlace): place
config(dict): config to quant
"""
def _quant_abs_max(tensor_array, config):
"""
quant array using abs_max op
"""
bit_length = config['quantize_bits']
scale = np.max(np.abs(tensor_array)).astype("float32")
quanted_tensor = np.round(tensor_array / scale * (
(1 << (bit_length - 1)) - 1))
return scale, quanted_tensor.astype(config['dtype'])
def _insert_dequant_abs_max_op(graph, scope, var_node, scale_node, config):
"""
Insert dequantize_abs_max op in graph
"""
assert var_node.is_var(), "{} is not a var".format(var_node.name())
dequant_var_node = graph.create_var_node(
name=_get_dequant_var_name(var_node.name()),
var_type=var_node.type(),
shape=var_node.shape(),
var_dtype=core.VarDesc.VarType.FP32)
scope.var(dequant_var_node.name())
max_range = (1 << (config['quantize_bits'] - 1)) - 1
output_ops = var_node.outputs
dequant_op = graph.create_op_node(
op_type='dequantize_abs_max',
attrs={
'max_range': float(max_range),
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
},
inputs={'X': var_node,
'Scale': scale_node},
outputs={'Out': dequant_var_node})
graph.link_to(var_node, dequant_op)
graph.link_to(scale_node, dequant_op)
graph.link_to(dequant_op, dequant_var_node)
for node in output_ops:
graph.update_input_link(var_node, dequant_var_node, node)
all_var_nodes = graph.all_var_nodes()
var_name = config['params_name']
# find embedding var node by 'params_name'
embedding_node = graph._find_node_by_name(all_var_nodes, var_name)
embedding_tensor = _get_var_tensor(scope, var_name)
if 'threshold' in config.keys():
embedding_tensor = _clip_tensor(embedding_tensor, config['threshold'])
# get scale and quanted tensor
scale, quanted_tensor = _quant_abs_max(embedding_tensor, config)
#create params must to use create_persistable_node
scale_var = graph.create_persistable_node(
_get_scale_var_name(var_name),
var_type=embedding_node.type(),
shape=[1],
var_dtype=core.VarDesc.VarType.FP32)
quant_tensor_var = graph.create_persistable_node(
_get_quant_var_name(var_name),
var_type=embedding_node.type(),
shape=embedding_node.shape(),
var_dtype=core.VarDesc.VarType.INT8)
# create var in scope
scope.var(_get_quant_var_name(var_name))
scope.var(_get_scale_var_name(var_name))
#set var by tensor array or scale
_restore_var(_get_quant_var_name(var_name), quanted_tensor, scope, place)
_restore_var(_get_scale_var_name(var_name), np.array(scale), scope, place)
# insert dequantize_abs_max op
for op_node in embedding_node.outputs:
if op_node.name() == 'lookup_table':
graph.update_input_link(embedding_node, quant_tensor_var, op_node)
var_node = op_node.outputs[0]
_insert_dequant_abs_max_op(graph, scope, var_node, scale_var,
config)
# free float embedding params memory
_clear_var(embedding_node.name(), scope)
graph.safe_remove_nodes(embedding_node)
def quant_embedding(program, place, config, scope=None):
"""
quant lookup_table op parameters
Args:
program(fluid.Program): infer program
scope(fluid.Scope): the scope to store var, when is None will use fluid.global_scope()
place(fluid.CPUPlace or fluid.CUDAPlace): place
config(dict): config to quant. The keys are 'params_name', 'quantize_type', \
'quantize_bits', 'dtype', 'threshold'. \
'params_name': parameter name to quant, must be set.
'quantize_type': quantize type, supported types are ['abs_max']. default is "abs_max".
'quantize_bits': quantize bits, supported bits are [8]. default is 8.
'dtype': quantize dtype, supported dtype are ['int8']. default is 'int8'.
'threshold': threshold to clip tensor before quant. When threshold is not set, \
tensor will not be clipped.
"""
assert isinstance(config, dict), "config must be dict"
config = _merge_config(copy.deepcopy(default_config), config)
scope = fluid.global_scope() if scope is None else scope
graph = IrGraph(core.Graph(program.desc), for_test=True)
if config['quantize_type'] == 'abs_max':
_quant_embedding_abs_max(graph, scope, place, config)
return graph.to_program()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
def conv_bn_layer(input,
num_filters,
filter_size,
name,
stride=1,
groups=1,
act=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False,
name=name + "_out")
bn_name = name + "_bn"
return fluid.layers.batch_norm(
input=conv,
act=act,
name=bn_name + '_output',
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance', )
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append('..')
import unittest
import paddle.fluid as fluid
from nas.search_space_factory import SearchSpaceFactory
class TestSearchSpace(unittest.TestCase):
def test_searchspace(self):
# if output_size is 1, the model will add fc layer in the end.
config = {'input_size': 224, 'output_size': 7, 'block_num': 5}
space = SearchSpaceFactory()
my_space = space.get_search_space('MobileNetV2Space', config)
model_arch = my_space.token2arch()
train_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
input_size= config['input_size']
model_input = fluid.layers.data(name='model_in', shape=[1, 3, input_size, input_size], dtype='float32', append_batch_size=False)
predict = model_arch(model_input)
self.assertTrue(predict.shape[2] == config['output_size'])
#for op in train_prog.global_block().ops:
# print(op.type)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from prune import Pruner
from layers import conv_bn_layer
class TestPrune(unittest.TestCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
# X X O X O
# conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
# | ^ | ^
# |____________| |____________________|
#
# X: prune output channels
# O: prune input channels
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
shapes = {}
for param in main_program.global_block().all_parameters():
shapes[param.name] = param.shape
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
exe.run(startup_program, scope=scope)
pruner = Pruner()
main_program = pruner.prune(
main_program,
scope,
params=["conv4_weights"],
ratios=[0.5],
place=place,
lazy=False,
only_graph=False,
param_backup=None,
param_shape_backup=None)
shapes = {
"conv1_weights": (4L, 3L, 3L, 3L),
"conv2_weights": (4L, 4L, 3L, 3L),
"conv3_weights": (8L, 4L, 3L, 3L),
"conv4_weights": (4L, 8L, 3L, 3L),
"conv5_weights": (8L, 4L, 3L, 3L),
"conv6_weights": (8L, 8L, 3L, 3L)
}
for param in main_program.global_block().all_parameters():
if "weights" in param.name:
self.assertTrue(param.shape == shapes[param.name])
if __name__ == '__main__':
unittest.main()
...@@ -33,8 +33,12 @@ with open('./requirements.txt') as f: ...@@ -33,8 +33,12 @@ with open('./requirements.txt') as f:
setup_requires = f.read().splitlines() setup_requires = f.read().splitlines()
packages = [ packages = [
'paddleslim', 'paddleslim.prune', 'paddleslim.dist', 'paddleslim.nas', 'paddleslim',
'paddleslim.analysis', 'paddleslim.quant' 'paddleslim.prune',
'paddleslim.dist',
'paddleslim.nas',
'paddleslim.analysis',
'paddleslim.quant',
] ]
setup( setup(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册