diff --git a/paddleslim/nas/ofa/convert_super.py b/paddleslim/nas/ofa/convert_super.py
index b08116f1b3e884bbad10bb114d86b9cdf9e6eec5..580f6b656f18ebb519adcc1d9ef6790858d2d266 100644
--- a/paddleslim/nas/ofa/convert_super.py
+++ b/paddleslim/nas/ofa/convert_super.py
@@ -35,7 +35,7 @@ else:
     from . import layers
     Layer = paddle.nn.Layer
 from .layers_base import Block
-
+from . import layers_old
 _logger = get_logger(__name__, level=logging.INFO)
 
 __all__ = ['supernet', 'Convert']
@@ -58,11 +58,16 @@ class Convert:
     def __init__(self, context):
         self.context = context
 
-    def _change_name(self, layer, pd_ver, has_bias=True, conv=False):
+    def _change_name(self,
+                     layer,
+                     pd_ver,
+                     has_bias=True,
+                     conv=False,
+                     use_bn_old=False):
         if conv:
             w_attr = layer._param_attr
         else:
-            w_attr = layer._param_attr if pd_ver == 185 else layer._weight_attr
+            w_attr = layer._param_attr if pd_ver == 185 or use_bn_old else layer._weight_attr
 
         if isinstance(w_attr, ParamAttr):
             if w_attr != None and not isinstance(w_attr,
@@ -241,18 +246,22 @@ class Convert:
                 layer = Block(SuperGroupConv2D(**new_attr_dict), key=key)
                 model[idx] = layer
 
-            elif isinstance(layer,
-                            getattr(nn, 'BatchNorm2D', nn.BatchNorm)) and (
-                                getattr(self.context, 'expand', None) != None or
-                                getattr(self.context, 'channel', None) != None):
+            elif (isinstance(layer, nn.BatchNorm2D) or
+                  isinstance(layer, nn.BatchNorm)) and (
+                      getattr(self.context, 'expand', None) != None or
+                      getattr(self.context, 'channel', None) != None):
                 # num_features in BatchNorm don't change after last weight operators
                 if idx > last_weight_layer_idx:
                     continue
 
+                use_bn_old = False
+                if isinstance(layer, nn.BatchNorm):
+                    use_bn_old = True
+
                 attr_dict = layer.__dict__
                 new_attr_name = ['momentum', 'epsilon', 'bias_attr']
-                if pd_ver == 185:
+                if pd_ver == 185 or use_bn_old:
                     new_attr_name += [
                         'param_attr', 'act', 'dtype', 'in_place', 'data_layout',
                         'is_test', 'use_global_stats', 'trainable_statistics'
@@ -260,9 +269,9 @@ class Convert:
                 else:
                     new_attr_name += ['weight_attr', 'data_format', 'name']
 
-                self._change_name(layer, pd_ver)
+                self._change_name(layer, pd_ver, use_bn_old=use_bn_old)
                 new_attr_dict = dict.fromkeys(new_attr_name, None)
-                if pd_ver == 185:
+                if pd_ver == 185 or use_bn_old:
                     new_attr_dict['num_channels'] = None
                 else:
                     new_attr_dict['num_features'] = None
@@ -284,9 +293,10 @@ class Convert:
 
                 del layer, attr_dict
 
-                layer = layers.SuperBatchNorm(
+                layer = layers_old.SuperBatchNorm(
                     **new_attr_dict
-                ) if pd_ver == 185 else layers.SuperBatchNorm2D(**new_attr_dict)
+                ) if pd_ver == 185 or use_bn_old else layers.SuperBatchNorm2D(
+                    **new_attr_dict)
                 model[idx] = layer
 
             elif isinstance(layer, SyncBatchNorm) and (
@@ -755,4 +765,4 @@ class supernet:
 #        def convert(*args, **kwargs):
 #            supernet_convert(*args, **kwargs)
 #        return convert
-#    return _ofa_supernet
+#    return _ofa_supernet
\ No newline at end of file
diff --git a/paddleslim/nas/ofa/layers.py b/paddleslim/nas/ofa/layers.py
index b45ffd33014561a4f0bf488b47c6587b2e884806..aad221475ee3de79374dc085c0ca4c6155bc8740 100644
--- a/paddleslim/nas/ofa/layers.py
+++ b/paddleslim/nas/ofa/layers.py
@@ -40,9 +40,7 @@ _logger = get_logger(__name__, level=logging.INFO)
 class SuperConv2D(nn.Conv2D):
     """This interface is used to construct a callable object of the ``SuperConv2D`` class.
-
     Note: the channel in config need to less than first defined.
-
     The super convolution2D layer calculates the output based on the input, filter
     and strides, paddings, dilations, groups parameters.
     Input and Output are in NCHW format, where N is batch size, C is the number of
@@ -59,9 +57,7 @@ class SuperConv2D(nn.Conv2D):
     applied to the final result.
     For each input :math:`X`, the equation is:
     .. math::
-
         Out = sigma (W \\ast X + b)
-
     Where:
     * :math:`X`: Input value, a ``Tensor`` with NCHW format.
     * :math:`W`: Filter value, a ``Tensor`` with shape [MCHW] .
@@ -69,7 +65,6 @@ class SuperConv2D(nn.Conv2D):
     * :math:`b`: Bias value, a 2-D ``Tensor`` with shape [M, 1].
     * :math:`\\sigma`: Activation function.
     * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
-
     Example:
         - Input:
           Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
@@ -78,11 +73,8 @@ class SuperConv2D(nn.Conv2D):
           Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
           Where
         .. math::
-
             H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1
-
             W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
-
     Parameters:
         num_channels(int): The number of channels in the input image.
         num_filters(int): The number of filter. It is as same as the output
@@ -144,7 +136,6 @@ class SuperConv2D(nn.Conv2D):
             config = {'channel': 5}
             data = paddle.to_tensor(data)
             conv = super_conv2d(data, config)
-
     """
 
     ### NOTE: filter_size, num_channels and num_filters must be the max of candidate to define a largest network.
@@ -214,10 +205,6 @@ class SuperConv2D(nn.Conv2D):
                 setattr(self, name, param)
 
     def get_active_filter(self, in_nc, out_nc, kernel_size):
-        ### Unsupport for asymmetric kernels
-        if self._kernel_size[0] != self._kernel_size[1]:
-            return self.weight[:out_nc, :in_nc, :, :]
-
         start, end = compute_start_end(self._kernel_size[0], kernel_size)
         ### if NOT transform kernel, intercept a center filter with kernel_size from largest filter
         filters = self.weight[:out_nc, :in_nc, start:end, start:end]
@@ -292,14 +279,9 @@ class SuperConv2D(nn.Conv2D):
             out_nc = int(channel)
         else:
             out_nc = self._out_channels
-
         ks = int(self._kernel_size[0]) if kernel_size == None else int(
             kernel_size)
-        if kernel_size is not None and self._kernel_size[
-                0] != self._kernel_size[1]:
-            _logger.error("Searching for asymmetric kernels is NOT supported")
-
         groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc,
                                                                         out_nc)
 
@@ -324,6 +306,7 @@ class SuperConv2D(nn.Conv2D):
         else:
             bias = self.bias
         self.cur_config['prune_dim'] = list(weight.shape)
+        self.cur_config['prune_group'] = groups
         out = F.conv2d(
             input,
             weight,
@@ -361,9 +344,7 @@ class SuperConv2DTranspose(nn.Conv2DTranspose):
     """
     This interface is used to construct a callable object of the ``SuperConv2DTranspose``
     class.
-
     Note: the channel in config need to less than first defined.
-
     The super convolution2D transpose layer calculates the output based on the input,
     filter, and dilations, strides, paddings. Input and output
     are in NCHW format. Where N is batch size, C is the number of feature map,
@@ -527,9 +508,6 @@ class SuperConv2DTranspose(nn.Conv2DTranspose):
                 setattr(self, name, param)
 
     def get_active_filter(self, in_nc, out_nc, kernel_size):
-        ### Unsupport for asymmetric kernels
-        if self._kernel_size[0] != self._kernel_size[1]:
-            return self.weight[:out_nc, :in_nc, :, :]
         start, end = compute_start_end(self._kernel_size[0], kernel_size)
         filters = self.weight[:in_nc, :out_nc, start:end, start:end]
         if self.transform_kernel != False and kernel_size < self._kernel_size[
@@ -612,10 +590,6 @@ class SuperConv2DTranspose(nn.Conv2DTranspose):
 
         ks = int(self._kernel_size[0]) if kernel_size == None else int(
             kernel_size)
-        if kernel_size is not None and self._kernel_size[
-                0] != self._kernel_size[1]:
-            _logger.error("Searching for asymmetric kernels is NOT supported")
-
         groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc,
                                                                         out_nc)
 
@@ -638,6 +612,7 @@ class SuperConv2DTranspose(nn.Conv2DTranspose):
         else:
             bias = self.bias
         self.cur_config['prune_dim'] = list(weight.shape)
+        self.cur_config['prune_group'] = groups
         out = F.conv2d_transpose(
             input,
             weight,
@@ -682,12 +657,10 @@ class SuperSeparableConv2D(nn.Layer):
     {'channel', num_of_channel} represents the channels of the first conv's outputs and
     the second conv's inputs, used to change the first dimension of weight and bias,
     only train the first channels of the weight and bias.
-
     The architecture of super separable convolution2D op is [Conv2D, norm layer(may be BatchNorm2D
     or InstanceNorm2D), Conv2D]. The first conv is depthwise conv, the filter number is input channel
     multiply scale_factor, the group is equal to the number of input channel. The second conv is standard
     conv, which filter size and stride size are 1.
-
     Parameters:
         num_channels(int): The number of channels in the input image.
         num_filters(int): The number of the second conv's filter. It is as same as the output
@@ -923,7 +896,6 @@ class SuperLinear(nn.Linear):
 class SuperBatchNorm2D(nn.BatchNorm2D):
     """
     This interface is used to construct a callable object of the ``SuperBatchNorm2D`` class.
-
     Parameters:
         num_features(int): Indicate the number of channels of the input ``Tensor``.
         epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
@@ -938,7 +910,6 @@ class SuperBatchNorm2D(nn.BatchNorm2D):
             If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None.
         data_format(str, optional): Specify the input data format, the data format can be "NCHW" or "NHWC". Default: NCHW.
         name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`..
-
     Examples:
         .. code-block:: python
           import paddle
@@ -1062,7 +1033,6 @@ class SuperSyncBatchNorm(nn.SyncBatchNorm):
 class SuperInstanceNorm2D(nn.InstanceNorm2D):
     """
     This interface is used to construct a callable object of the ``SuperInstanceNorm2D`` class.
-
     Parameters:
         num_features(int): Indicate the number of channels of the input ``Tensor``.
         epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
@@ -1077,7 +1047,6 @@ class SuperInstanceNorm2D(nn.InstanceNorm2D):
             If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None.
         data_format(str, optional): Specify the input data format, the data format can be "NCHW" or "NHWC". Default: NCHW.
         name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`..
-
     Examples:
         .. code-block:: python
          import paddle
@@ -1121,11 +1090,9 @@ class SuperInstanceNorm2D(nn.InstanceNorm2D):
 class SuperLayerNorm(nn.LayerNorm):
     """
     This interface is used to construct a callable object of the ``SuperLayerNorm`` class.
-
     The difference between ```SuperLayerNorm``` and ```LayerNorm``` is:
     the trained weight and bias in ```SuperLayerNorm``` can be changed according to the shape of input,
     only train the first channels of the weight and bias.
-
     Parameters:
         normalized_shape(int|list|tuple): Input shape from an expected input of
             size :math:`[*, normalized_shape[0], normalized_shape[1], ..., normalized_shape[-1]]`.
@@ -1193,7 +1160,6 @@ class SuperLayerNorm(nn.LayerNorm):
 class SuperEmbedding(nn.Embedding):
     """
     This interface is used to construct a callable object of the ``SuperEmbedding`` class.
-
     Parameters:
         num_embeddings (int): Just one element which indicate the size
             of the dictionary of embeddings.
@@ -1280,4 +1246,4 @@ class SuperEmbedding(nn.Embedding):
             weight=weight,
             padding_idx=self._padding_idx,
             sparse=self._sparse,
-            name=self._name)
+            name=self._name)
\ No newline at end of file
diff --git a/paddleslim/nas/ofa/ofa.py b/paddleslim/nas/ofa/ofa.py
index ac788209407d53a996aec05564942821c6cfda0d..b0a02fbad850264ccd5a7686236b93c74ff190ce 100644
--- a/paddleslim/nas/ofa/ofa.py
+++ b/paddleslim/nas/ofa/ofa.py
@@ -17,7 +17,7 @@ import numpy as np
 from collections import namedtuple
 import paddle
 import paddle.fluid as fluid
-from .utils.utils import get_paddle_version, remove_model_fn
+from .utils.utils import get_paddle_version, remove_model_fn, build_input
 pd_ver = get_paddle_version()
 if pd_ver == 185:
     from .layers_old import SuperConv2D, SuperLinear
@@ -56,7 +56,11 @@ RunConfig = namedtuple(
     # list, the number of sub-network to train per mini-batch data, used to get current epoch, default: None
     'dynamic_batch_size',
     # the shape of weights in the skip_layers will not change in the training, default: None
-    'skip_layers'
+    'skip_layers',
+    # same search space designed by hand for some complicated models
+    'same_search_space',
+    # ofa_layers designed by hand if different ratio or channel is needed for different layers
+    'ofa_layers',
 ])
 RunConfig.__new__.__defaults__ = (None, ) * len(RunConfig._fields)
 
@@ -79,21 +83,6 @@ DistillConfig = namedtuple(
 DistillConfig.__new__.__defaults__ = (None, ) * len(DistillConfig._fields)
 
 
-def to_tensor(string_values, name="text"):
-    """
-    Create the tensor that the value holds the list of string.
-    NOTICE: The value will be holded in the cpu place.
-
-    Parameters:
-        string_values(list[string]): The value will be setted to the tensor.
-        name(string): The name of the tensor.
- """ - tensor = paddle.Tensor(core.VarDesc.VarType.STRING, [], name, - core.VarDesc.VarType.STRINGS, False) - tensor.value().set_string_list(string_values) - return tensor - - class OFABase(Layer): def __init__(self, model): super(OFABase, self).__init__() @@ -114,6 +103,11 @@ class OFABase(Layer): if isinstance(sublayer, BaseBlock): sublayer.set_supernet(self) if not sublayer.fixed: + config = sublayer.candidate_config + for k, v in config.items(): + if isinstance(v, list) or isinstance( + v, set) or isinstance(v, tuple): + sublayer.candidate_config[k] = sorted(list(v)) ofa_layers[name] = sublayer.candidate_config layers[sublayer.key] = sublayer.candidate_config key2name[sublayer.key] = name @@ -158,20 +152,17 @@ class OFABase(Layer): class OFA(OFABase): """ Convert the training progress to the Once-For-All training progress, a detailed description in the paper: `Once-for-All: Train One Network and Specialize it for Efficient Deployment`_ . This paper propose a training propgress named progressive shrinking (PS), which means we start with training the largest neural network with the maximum kernel size (i.e., 7), depth (i.e., 4), and width (i.e., 6). Next, we progressively fine-tune the network to support smaller sub-networks by gradually adding them into the sampling space (larger sub-networks may also be sampled). Specifically, after training the largest network, we first support elastic kernel size which can choose from {3, 5, 7} at each layer, while the depth and width remain the maximum values. Then, we support elastic depth and elastic width sequentially. - Parameters: model(paddle.nn.Layer): instance of model. run_config(paddleslim.ofa.RunConfig, optional): config in ofa training, can reference `<>`_ . Default: None. distill_config(paddleslim.ofa.DistillConfig, optional): config of distilltion in ofa training, can reference `<>`_. Default: None. elastic_order(list, optional): define the training order, if it set to None, use the default order in the paper. Default: None. train_full(bool, optional): whether to train the largest sub-network only. Default: False. - Examples: .. 
           from paddle.vision.models import mobilenet_v1
           from paddleslim.nas.ofa import OFA
           from paddleslim.nas.ofa.convert_super import Convert, supernet
-
           model = mobilenet_v1()
           sp_net_config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
           sp_model = Convert(sp_net_config).convert(model)
@@ -201,6 +192,8 @@ class OFA(OFABase):
         self._broadcast = False
         self._skip_layers = None
         self._cannot_changed_layer = []
+        self.token_map = {}
+        self.search_cands = []
 
         ### if elastic_order is none, use default order
         if self.elastic_order is not None:
@@ -235,6 +228,12 @@ class OFA(OFABase):
         if 'channel' in self._elastic_task and 'width' not in self.elastic_order:
             self.elastic_order.append('width')
 
+        if getattr(self.run_config, 'ofa_layers', None) != None:
+            for key in self.run_config.ofa_layers:
+                assert key in self._ofa_layers, "layer {} is not in current _ofa_layers".format(
+                    key)
+                self._ofa_layers[key] = self.run_config.ofa_layers[key]
+
         if getattr(self.run_config, 'n_epochs', None) != None:
             assert len(self.run_config.n_epochs) == len(self.elastic_order)
             for idx in range(len(run_config.n_epochs)):
@@ -258,6 +257,11 @@ class OFA(OFABase):
                 None) != None:
             self._skip_layers = self.run_config.skip_layers
 
+        if self.run_config != None and getattr(
+                self.run_config, 'same_search_space', None) != None:
+            self._same_ss_by_hand = self.run_config.same_search_space
+        else:
+            self._same_ss_by_hand = None
         ### ================= add distill prepare ======================
         if self.distill_config != None:
             self._add_teacher = True
@@ -393,6 +397,47 @@ class OFA(OFABase):
             self._ofa_layers, sample_type=sample_type, task=task, phase=phase)
         return config
 
+    def tokenize(self):
+        '''
+        Tokenize current search space. Task should be set before tokenize.
+        Example: token_map = {
+                     'expand_ratio': {
+                         'conv1': {0: 0.25, 1: 0.5, 2: 0.75}
+                         'conv2': {0: 0.25, 1: 0.5, 2: 0.75}
+                     }
+                 }
+
+        '''
+        all_tokens = []
+        for name, cands in self._ofa_layers.items():
+            if self.task in cands:
+                all_tokens += list(cands[self.task])
+
+        all_tokens = sorted(list(set(all_tokens)))
+        self.token_map[self.task] = {}
+        for name, cands in self._ofa_layers.items():
+            if not cands:
+                continue
+            if self.task in cands:
+                self.token_map[self.task][name] = {}
+                for cand in cands[self.task]:
+                    key = all_tokens.index(cand)
+                    self.token_map[self.task][name][key] = cand
+            else:
+                raise NotImplementedError("Task {} not in ofa layers".format(
+                    self.task))
+
+        self.search_cands = []
+        for layer, t_map in self.token_map[self.task].items():
+            self.search_cands.append(list(t_map.keys()))
+
+    def decode_token(self, token):
+        config = {}
+        for i, name in enumerate(self.token_map[self.task].keys()):
+            config[name] = self.token_map[self.task][name][token[i]]
+        self.net_config = config
+        return config
+
     def set_task(self, task, phase=None):
         """
         set task in the ofa training progress.
@@ -480,7 +525,7 @@ class OFA(OFABase):
         pass
 
     def _get_model_pruned_weight(self):
-
+        prune_groups = {}
         pruned_param = {}
 
         for l_name, sublayer in self.model.named_sublayers():
@@ -490,6 +535,9 @@ class OFA(OFABase):
             assert 'prune_dim' in sublayer.cur_config, 'The laycer {} do not have prune_dim in cur_config.'.format(
                 l_name)
             prune_shape = sublayer.cur_config['prune_dim']
+            if 'prune_group' in sublayer.cur_config:
+                prune_group = sublayer.cur_config['prune_group']
+                prune_groups[l_name] = prune_group
 
             for p_name, param in sublayer.named_parameters(
                     include_sublayers=False):
@@ -513,7 +561,7 @@ class OFA(OFABase):
                 else:
                     pruned_param[name] = param[:prune_shape]
 
-        return pruned_param
+        return pruned_param, prune_groups
 
     def export(self,
                config,
@@ -539,31 +587,6 @@ class OFA(OFABase):
         self.set_net_config(config)
         self.model.eval()
 
-        def build_input(input_size, dtypes):
-            if isinstance(input_size, list) and all(
-                    isinstance(i, numbers.Number) for i in input_size):
-                if isinstance(dtypes, list):
-                    dtype = dtypes[0]
-                else:
-                    dtype = dtypes
-                if dtype == core.VarDesc.VarType.STRINGS:
-                    return to_tensor([""])
-                return paddle.cast(paddle.rand(list(input_size)), dtype)
-            if isinstance(input_size, dict):
-                inputs = {}
-                if isinstance(dtypes, list):
-                    dtype = dtypes[0]
-                else:
-                    dtype = dtypes
-                for key, value in input_size.items():
-                    inputs[key] = paddle.cast(paddle.rand(list(value)), dtype)
-                return inputs
-            if isinstance(input_size, list):
-                return [
-                    build_input(i, dtype)
-                    for i, dtype in zip(input_size, dtypes)
-                ]
-
         data = build_input(input_shapes, input_dtypes)
 
         if isinstance(data, list):
@@ -582,10 +605,12 @@ class OFA(OFABase):
             origin_model, DataParallel) else origin_model
 
         _logger.info("Start to get pruned params, please wait...")
-        pruned_param = self._get_model_pruned_weight()
+        pruned_param, pruned_groups = self._get_model_pruned_weight()
         pruned_state_dict = remove_model_fn(origin_model, pruned_param)
         _logger.info("Start to get pruned model, please wait...")
         for l_name, sublayer in origin_model.named_sublayers():
+            if l_name in pruned_groups:
+                sublayer._groups = pruned_groups[l_name]
             for p_name, param in sublayer.named_parameters(
                     include_sublayers=False):
                 name = l_name + '.' + p_name
@@ -643,40 +668,67 @@ class OFA(OFABase):
             if len(self._ofa_layers[key]) == 0:
                 self._ofa_layers.pop(key)
 
-    def _clear_search_space(self, *inputs, **kwargs):
+    def _clear_search_space(self, *inputs, input_spec=None, **kwargs):
         """ find shortcut in model, and clear up the search space """
-        input_shapes = []
-        input_dtypes = []
-        for n in inputs:
-            if isinstance(n, Variable):
-                input_shapes.append(n)
-                input_dtypes.append(n.numpy().dtype)
-
-        for key, val in kwargs.items():
-            if isinstance(val, Variable):
-                input_shapes.append(val)
-                input_dtypes.append(val.numpy().dtype)
-            elif isinstance(val, dict):
-                input_shape = {}
-                input_dtype = {}
-                for k, v in val.items():
-                    input_shape[k] = v
-                    input_dtype[k] = v.numpy().dtype
-                input_shapes.append(input_shape)
-                input_dtypes.append(input_dtype)
-            else:
-                _logger.error(
-                    "Cannot figure out the type of inputs! Right now, the type of inputs can be only Variable or dict."
-                )
+        if input_spec is None:
+            input_shapes = []
+            input_dtypes = []
+            for n in inputs:
+                if isinstance(n, Variable):
+                    input_shapes.append(n)
+                    input_dtypes.append(n.numpy().dtype)
+
+            for key, val in kwargs.items():
+                if isinstance(val, Variable):
+                    input_shapes.append(val)
+                    input_dtypes.append(val.numpy().dtype)
+                elif isinstance(val, dict):
+                    input_shape = {}
+                    input_dtype = {}
+                    for k, v in val.items():
+                        input_shape[k] = v
+                        input_dtype[k] = v.numpy().dtype
+                    input_shapes.append(input_shape)
+                    input_dtypes.append(input_dtype)
+                else:
+                    _logger.error(
+                        "Cannot figure out the type of inputs! Right now, the type of inputs can be only Variable or dict."
+                    )
 
-        ### find shortcut block using static model
-        model_to_traverse = self.model._layers if isinstance(
-            self.model, DataParallel) else self.model
-        _st_prog = dygraph2program(
-            model_to_traverse, inputs=input_shapes, dtypes=input_dtypes)
-        self._same_ss, depthwise_conv, fixed_by_input, output_conv = check_search_space(
-            GraphWrapper(_st_prog))
-        self._cannot_changed_layer = output_conv
+            ### find shortcut block using static model
+            model_to_traverse = self.model._layers if isinstance(
+                self.model, DataParallel) else self.model
+            _st_prog = dygraph2program(
+                model_to_traverse, inputs=input_shapes, dtypes=input_dtypes)
+
+        else:
+            model_to_traverse = self.model._layers if isinstance(
+                self.model, DataParallel) else self.model
+
+            model_to_traverse.eval()
+            _st_prog = dygraph2program(model_to_traverse, inputs=input_spec)
+            model_to_traverse.train()
+
+        if self._same_ss_by_hand is None:
+            self._same_ss, depthwise_conv, fixed_by_input, output_conv = check_search_space(
+                GraphWrapper(_st_prog))
+            self._cannot_changed_layer = output_conv
+        else:
+            output_conv = []
+            fixed_by_input = []
+            depthwise_conv = []
+            self._cannot_changed_layer = output_conv
+            self._same_ss = []
+            self._key2param = {}
+            for name, sublayer in model_to_traverse.named_sublayers():
+                if isinstance(sublayer, BaseBlock):
+                    for param in sublayer.parameters():
+                        self._key2param[name] = param.name
+            for ss in self._same_ss_by_hand:
+                param_ss = []
+                for key in ss:
+                    param_ss.append(self._key2param[key])
+                self._same_ss.append(param_ss)
 
         if self._same_ss != None:
             self._param2key = {}
@@ -793,5 +845,6 @@ class OFA(OFABase):
 
         if self._add_teacher:
             self._remove_hook_after_forward()
+            return student_output, teacher_output
 
-        return student_output, teacher_output
+        return student_output  #, teacher_output
\ No newline at end of file
diff --git a/paddleslim/nas/ofa/utils/utils.py b/paddleslim/nas/ofa/utils/utils.py
index 837159170d13b306c958b37752be553edc6da736..1887f4f04ad7c517735cdc11eb88919d987a0bb7 100644
--- a/paddleslim/nas/ofa/utils/utils.py
+++ b/paddleslim/nas/ofa/utils/utils.py
@@ -14,6 +14,9 @@
 import logging
 
 import paddle
+import numbers
+import numpy as np
+from paddle.fluid import core
 
 from ....common import get_logger
 
@@ -41,7 +44,6 @@ __all__ = ['set_state_dict']
 def set_state_dict(model, state_dict):
     """
     Set state dict from origin model to supernet model.
-
     Args:
         model(paddle.nn.Layer): model after convert to supernet.
         state_dict(dict): dict with the type of {name: param} in origin model.
@@ -59,12 +61,50 @@ def set_state_dict(model, state_dict):
                 _logger.info('{} is not in state_dict'.format(tmp_n))
 
 
+def to_tensor(string_values, name="text"):
+    """
+    Create the tensor that the value holds the list of string.
+    NOTICE: The value will be holded in the cpu place.
+    Parameters:
+        string_values(list[string]): The value will be setted to the tensor.
+        name(string): The name of the tensor.
+    """
+    tensor = paddle.Tensor(core.VarDesc.VarType.STRING, [], name,
+                           core.VarDesc.VarType.STRINGS, False)
+    tensor.value().set_string_list(string_values)
+    return tensor
+
+
+def build_input(input_size, dtypes):
+    if isinstance(input_size, list) and all(
+            isinstance(i, numbers.Number) for i in input_size):
+        if isinstance(dtypes, list):
+            dtype = dtypes[0]
+        else:
+            dtype = dtypes
+        if dtype == core.VarDesc.VarType.STRINGS:
+            return to_tensor([""])
+        return paddle.cast(paddle.rand(list(input_size)), dtype)
+    if isinstance(input_size, dict):
+        inputs = {}
+        if isinstance(dtypes, list):
+            dtype = dtypes[0]
+        else:
+            dtype = dtypes
+        for key, value in input_size.items():
+            inputs[key] = paddle.cast(paddle.rand(list(value)), dtype)
+        return inputs
+    if isinstance(input_size, list):
+        return [build_input(i, dtype) for i, dtype in zip(input_size, dtypes)]
+
+
 def remove_model_fn(model, state_dict):
     new_dict = {}
     keys = []
     for name, param in model.state_dict().items():
         keys.append(name)
     for name, param in state_dict.items():
+        tmp_n = None
         if len(name.split('.')) <= 2:
             new_dict[name] = param
             continue
@@ -113,4 +153,4 @@ def search_idx(num, sorted_nestlist):
             if num <= task_[phase_idx]:
                 return idx, phase_idx
     assert num > max_num
-    return len(sorted_nestlist) - 1, max_idx
+    return len(sorted_nestlist) - 1, max_idx
\ No newline at end of file
diff --git a/tests/test_ofa.py b/tests/test_ofa.py
index baa5f4c43c5678194eb86480f994a9f4360a4a0e..997793fc6f8872a6cacfa45000e85c4bf217cc23 100644
--- a/tests/test_ofa.py
+++ b/tests/test_ofa.py
@@ -25,6 +25,7 @@ from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
 from paddleslim.nas.ofa.convert_super import supernet
 from paddleslim.nas.ofa.layers import Block, SuperSeparableConv2D
 from paddleslim.nas.ofa.convert_super import Convert, supernet
+from paddle.static import InputSpec
 
 
 class ModelConv(nn.Layer):
@@ -321,7 +322,7 @@ class TestOFA(unittest.TestCase):
                 ofa_model.set_epoch(epoch_id)
                 for model_no in range(self.run_config.dynamic_batch_size[
                         idx]):
-                    output, _ = ofa_model(self.data)
+                    output = ofa_model(self.data)
                     loss = paddle.mean(output)
                     if self.distill_config.mapping_layers != None:
                         dis_loss = ofa_model.calc_distill_loss()
@@ -440,7 +441,7 @@ class TestShortCut(unittest.TestCase):
     def _test_clear_search_space(self):
        self.ofa_model = OFA(self.model)
        self.ofa_model.set_epoch(0)
-       outs, _ = self.ofa_model(self.images)
+       outs = self.ofa_model(self.images)
        self.config = self.ofa_model.current_config
 
     def test_export_model(self):
@@ -458,7 +459,7 @@ class TestExportCase1(unittest.TestCase):
         self.data = paddle.to_tensor(data_np)
         self.ofa_model = OFA(model)
         self.ofa_model.set_epoch(0)
-        outs, _ = self.ofa_model(self.data)
+        outs = self.ofa_model(self.data)
         self.config = self.ofa_model.current_config
 
     def test_export_model_linear1(self):
@@ -475,7 +476,7 @@ class TestExportCase2(unittest.TestCase):
         self.data = paddle.to_tensor(data_np)
         self.ofa_model = OFA(model)
         self.ofa_model.set_epoch(0)
-        outs, _ = self.ofa_model(self.data)
+        outs = self.ofa_model(self.data)
         self.config = self.ofa_model.current_config
 
     def test_export_model_linear2(self):
@@ -487,5 +488,121 @@ class TestExportCase2(unittest.TestCase):
         assert len(self.ofa_model.ofa_layers) == 3
 
 
+class TestManualSetting(unittest.TestCase):
+    def setUp(self):
+        self._init_model()
+
+    def _init_model(self):
+        model = ModelOriginLinear()
+        data_np = np.random.random((3, 64)).astype(np.int64)
+        self.data = paddle.to_tensor(data_np)
+        sp_net_config = supernet(expand_ratio=[0.25, 0.5, 1.0])
+        self.model = Convert(sp_net_config).convert(model)
+
+    def test_setting_byhand(self):
+        self.ofa_model1 = OFA(self.model)
+        for key, value in self.ofa_model1._ofa_layers.items():
+            if 'expand_ratio' in value:
+                assert value['expand_ratio'] == [0.25, 0.5, 1.0]
+        self.ofa_model1._clear_search_space(self.data)
+        assert len(self.ofa_model1._ofa_layers) == 3
+
+        ofa_layers = {
+            'models.0': {
+                'expand_ratio': [0.5, 1.0]
+            },
+            'models.1': {
+                'expand_ratio': [0.25, 1.0]
+            },
+            'models.3': {
+                'expand_ratio': [0.25, 1.0]
+            },
+            'models.4': {}
+        }
+        same_search_space = [['models.1', 'models.3']]
+        skip_layers = ['models.0']
+        cfg = {
+            'ofa_layers': ofa_layers,
+            'same_search_space': same_search_space,
+            'skip_layers': skip_layers
+        }
+        run_config = RunConfig(**cfg)
+        self.ofa_model2 = OFA(self.model, run_config=run_config)
+        self.ofa_model2._clear_search_space(self.data)
+        #print(self.ofa_model2._ofa_layers)
+        assert self.ofa_model2._ofa_layers['models.1'][
+            'expand_ratio'] == [0.25, 1.0]
+        assert len(self.ofa_model2._ofa_layers) == 2
+        #print(self.ofa_model_1._ofa_layers)
+
+    def test_tokenize(self):
+        self.ofa_model = OFA(self.model)
+        self.ofa_model.set_task('expand_ratio')
+        self.ofa_model._clear_search_space(self.data)
+        self.ofa_model.tokenize()
+        assert self.ofa_model.token_map == {
+            'expand_ratio': {
+                'models.0': {
+                    0: 0.25,
+                    1: 0.5,
+                    2: 1.0
+                },
+                'models.1': {
+                    0: 0.25,
+                    1: 0.5,
+                    2: 1.0
+                },
+                'models.3': {
+                    0: 0.25,
+                    1: 0.5,
+                    2: 1.0
+                }
+            }
+        }
+        assert self.ofa_model.search_cands == [[0, 1, 2], [0, 1, 2], [0, 1, 2]]
+
+        ofa_layers = {
+            'models.0': {
+                'expand_ratio': [0.5, 1.0]
+            },
+            'models.1': {
+                'expand_ratio': [0.25, 1.0]
+            },
+            'models.3': {
+                'expand_ratio': [0.25, 1.0]
+            },
+            'models.4': {}
+        }
+        same_search_space = [['models.1', 'models.3']]
+        cfg = {'ofa_layers': ofa_layers, 'same_search_space': same_search_space}
+        run_config = RunConfig(**cfg)
+        self.ofa_model2 = OFA(self.model, run_config=run_config)
+        self.ofa_model2.set_task('expand_ratio')
+        self.ofa_model2._clear_search_space(self.data)
+        self.ofa_model2.tokenize()
+        assert self.ofa_model2.token_map == {
+            'expand_ratio': {
+                'models.0': {
+                    1: 0.5,
+                    2: 1.0
+                },
+                'models.1': {
+                    0: 0.25,
+                    2: 1.0
+                }
+            }
+        }
+        assert self.ofa_model2.search_cands == [[1, 2], [0, 2]]
+
+        token = [1, 2]
+        config = self.ofa_model2.decode_token(token)
+        assert config == {'models.0': 0.5, 'models.1': 1.0}
+
+    def test_input_spec(self):
+        self.ofa_model = OFA(self.model)
+        self.ofa_model.set_task('expand_ratio')
+        self.ofa_model._clear_search_space(input_spec=[self.data])
+
+
 if __name__ == '__main__':
-    unittest.main()
+    unittest.main()
\ No newline at end of file
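
Appended usage note (not part of the patch): a minimal sketch of how the tokenize()/decode_token() API introduced above might be driven from user code, mirroring TestManualSetting.test_tokenize in tests/test_ofa.py. ModelOriginLinear, the (3, 64) int64 input, and the import path are assumptions taken from that test file; any model converted with Convert(supernet(...)) should follow the same flow.

# Sketch only; assumes the PaddleSlim repo root is on sys.path so the test
# helper model can be imported. Everything else uses the API added in this diff.
import numpy as np
import paddle
from paddleslim.nas.ofa import OFA
from paddleslim.nas.ofa.convert_super import Convert, supernet
from tests.test_ofa import ModelOriginLinear  # hypothetical import path, see lead-in

# Convert a toy model into a supernet and wrap it with OFA.
model = Convert(supernet(expand_ratio=[0.25, 0.5, 1.0])).convert(ModelOriginLinear())
ofa_model = OFA(model)
ofa_model.set_epoch(0)
ofa_model.set_task('expand_ratio')        # tokenize() requires a task to be set

# Trace once to build _ofa_layers, then map candidates to integer tokens.
data = paddle.to_tensor(np.random.random((3, 64)).astype(np.int64))
ofa_model._clear_search_space(data)
ofa_model.tokenize()                      # fills token_map: {task: {layer: {token: candidate}}}

# One token per searchable layer; decode_token() turns the token vector back
# into a per-layer config and stores it in ofa_model.net_config.
token = [0] * len(ofa_model.search_cands)  # e.g. the lowest token for every layer
config = ofa_model.decode_token(token)

# With this patch the plain (no-distillation) forward returns a single output.
output = ofa_model(data)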