diff --git a/demo/one_shot/ofa_train.py b/demo/one_shot/ofa_train.py new file mode 100644 index 0000000000000000000000000000000000000000..4a47a219c1096d750757f407cfde4ff37691efb7 --- /dev/null +++ b/demo/one_shot/ofa_train.py @@ -0,0 +1,127 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.dygraph.nn as nn +from paddle.nn import ReLU +from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig +from paddleslim.nas.ofa import supernet + + +class Model(fluid.dygraph.Layer): + def __init__(self): + super(Model, self).__init__() + with supernet( + kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4]) as ofa_super: + models = [] + models += [nn.Conv2D(1, 6, 3)] + models += [ReLU()] + models += [nn.Pool2D(2, 'max', 2)] + models += [nn.Conv2D(6, 16, 5, padding=0)] + models += [ReLU()] + models += [nn.Pool2D(2, 'max', 2)] + models += [ + nn.Linear(784, 120), nn.Linear(120, 84), nn.Linear(84, 10) + ] + models = ofa_super.convert(models) + self.models = paddle.nn.Sequential(*models) + + def forward(self, inputs, label, depth=None): + if depth != None: + assert isinstance(depth, int) + assert depth < len(self.models) + models = self.models[:depth] + else: + depth = len(self.models) + models = self.models[:] + + for idx, layer in enumerate(models): + if idx == 6: + inputs = fluid.layers.flatten(inputs, 1) + inputs = layer(inputs) + + inputs = fluid.layers.softmax(inputs) + return inputs + + +def test_ofa(): + + default_run_config = { + 'train_batch_size': 256, + 'eval_batch_size': 64, + 'n_epochs': [[1], [2, 3], [4, 5]], + 'init_learning_rate': [[0.001], [0.003, 0.001], [0.003, 0.001]], + 'dynamic_batch_size': [1, 1, 1], + 'total_images': 50000, #1281167, + 'elastic_depth': (2, 5, 8) + } + run_config = RunConfig(**default_run_config) + + default_distill_config = { + 'lambda_distill': 0.01, + 'teacher_model': Model, + 'mapping_layers': ['models.0.fn'] + } + distill_config = DistillConfig(**default_distill_config) + + fluid.enable_dygraph() + model = Model() + ofa_model = OFA(model, run_config, distill_config=distill_config) + + train_reader = paddle.fluid.io.batch( + paddle.dataset.mnist.train(), batch_size=256, drop_last=True) + + start_epoch = 0 + for idx in range(len(run_config.n_epochs)): + cur_idx = run_config.n_epochs[idx] + for ph_idx in range(len(cur_idx)): + cur_lr = run_config.init_learning_rate[idx][ph_idx] + adam = fluid.optimizer.Adam( + learning_rate=cur_lr, + parameter_list=(ofa_model.parameters() + ofa_model.netAs_param)) + for epoch_id in range(start_epoch, + run_config.n_epochs[idx][ph_idx]): + for batch_id, data in enumerate(train_reader()): + dy_x_data = np.array( + [x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(-1, 1) + + img = fluid.dygraph.to_variable(dy_x_data) + label = fluid.dygraph.to_variable(y_data) + label.stop_gradient = True + + for model_no in 
range(run_config.dynamic_batch_size[idx]): + output, _ = ofa_model(img, label) + loss = fluid.layers.reduce_mean(output) + dis_loss = ofa_model.calc_distill_loss() + loss += dis_loss + loss.backward() + + if batch_id % 10 == 0: + print( + 'epoch: {}, batch: {}, loss: {}, distill loss: {}'. + format(epoch_id, batch_id, + loss.numpy()[0], dis_loss.numpy()[0])) + ### accumurate dynamic_batch_size network of gradients for same batch of data + ### NOTE: need to fix gradients accumulate in PaddlePaddle + adam.minimize(loss) + adam.clear_gradients() + start_epoch = run_config.n_epochs[idx][ph_idx] + + +test_ofa() diff --git a/paddleslim/nas/ofa/__init__.py b/paddleslim/nas/ofa/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db1394baf6dc59286b678f302b80fe2c5de404c1 --- /dev/null +++ b/paddleslim/nas/ofa/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .ofa import OFA, RunConfig, DistillConfig +from .convert_super import supernet +from .layers import * diff --git a/paddleslim/nas/ofa/convert_super.py b/paddleslim/nas/ofa/convert_super.py new file mode 100644 index 0000000000000000000000000000000000000000..f7ff8a1e530cef850415049c1d8a1b42dfcc0345 --- /dev/null +++ b/paddleslim/nas/ofa/convert_super.py @@ -0,0 +1,417 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +import decorator +import logging +import paddle +import paddle.fluid as fluid +from paddle.fluid import framework +from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, BatchNorm, InstanceNorm +from .layers import * +from ...common import get_logger + +_logger = get_logger(__name__, level=logging.INFO) + +__all__ = ['supernet'] + +WEIGHT_LAYER = ['conv', 'linear'] + + +### TODO: add decorator +class Convert: + def __init__(self, context): + self.context = context + + def convert(self, model): + # search the first and last weight layer, don't change out channel of the last weight layer + # don't change in channel of the first weight layer + first_weight_layer_idx = -1 + last_weight_layer_idx = -1 + weight_layer_count = 0 + # NOTE: pre_channel store for shortcut module + pre_channel = 0 + cur_channel = None + for idx, layer in enumerate(model): + cls_name = layer.__class__.__name__.lower() + if 'conv' in cls_name or 'linear' in cls_name: + weight_layer_count += 1 + last_weight_layer_idx = idx + if first_weight_layer_idx == -1: + first_weight_layer_idx = idx + + if getattr(self.context, 'channel', None) != None: + assert len( + self.context.channel + ) == weight_layer_count, "length of channel must same as weight layer." + + for idx, layer in enumerate(model): + if isinstance(layer, Conv2D): + attr_dict = layer.__dict__ + key = attr_dict['_full_name'] + + new_attr_name = [ + '_stride', '_dilation', '_groups', '_param_attr', + '_bias_attr', '_use_cudnn', '_act', '_dtype' + ] + + new_attr_dict = dict() + new_attr_dict['candidate_config'] = dict() + self.kernel_size = getattr(self.context, 'kernel_size', None) + + if self.kernel_size != None: + new_attr_dict['transform_kernel'] = True + + # if the kernel_size of conv is 1, don't change it. + #if self.kernel_size and int(attr_dict['_filter_size'][0]) != 1: + if self.kernel_size and int(attr_dict['_filter_size']) != 1: + new_attr_dict['filter_size'] = max(self.kernel_size) + new_attr_dict['candidate_config'].update({ + 'kernel_size': self.kernel_size + }) + else: + new_attr_dict['filter_size'] = attr_dict['_filter_size'] + + if self.context.expand: + ### first super convolution + if idx == first_weight_layer_idx: + new_attr_dict['num_channels'] = attr_dict[ + '_num_channels'] + else: + new_attr_dict[ + 'num_channels'] = self.context.expand * attr_dict[ + '_num_channels'] + ### last super convolution + if idx == last_weight_layer_idx: + new_attr_dict['num_filters'] = attr_dict['_num_filters'] + else: + new_attr_dict[ + 'num_filters'] = self.context.expand * attr_dict[ + '_num_filters'] + new_attr_dict['candidate_config'].update({ + 'expand_ratio': self.context.expand_ratio + }) + elif self.context.channel: + if attr_dict['_groups'] != None and ( + int(attr_dict['_groups']) == + int(attr_dict['_num_channels'])): + ### depthwise conv, if conv is depthwise, use pre channel as cur_channel + _logger.warn( + "If convolution is a depthwise conv, output channel change" \ + " to the same channel with input, output channel in search is not used." 
+ ) + cur_channel = pre_channel + else: + cur_channel = self.context.channel[0] + self.context.channel = self.context.channel[1:] + if idx == first_weight_layer_idx: + new_attr_dict['num_channels'] = attr_dict[ + '_num_channels'] + else: + new_attr_dict['num_channels'] = max(pre_channel) + + if idx == last_weight_layer_idx: + new_attr_dict['num_filters'] = attr_dict['_num_filters'] + else: + new_attr_dict['num_filters'] = max(cur_channel) + new_attr_dict['candidate_config'].update({ + 'channel': cur_channel + }) + pre_channel = cur_channel + else: + new_attr_dict['num_filters'] = attr_dict['_num_filters'] + new_attr_dict['num_channels'] = attr_dict['_num_channels'] + + for attr in new_attr_name: + new_attr_dict[attr[1:]] = attr_dict[attr] + + del layer + + if attr_dict['_groups'] == None or int(attr_dict[ + '_groups']) == 1: + ### standard conv + layer = Block(SuperConv2D(**new_attr_dict), key=key) + elif int(attr_dict['_groups']) == int(attr_dict[ + '_num_channels']): + # if conv is depthwise conv, groups = in_channel, out_channel = in_channel, + # channel in candidate_config = in_channel_list + if 'channel' in new_attr_dict['candidate_config']: + new_attr_dict['num_channels'] = max(cur_channel) + new_attr_dict['num_filters'] = new_attr_dict[ + 'num_channels'] + new_attr_dict['candidate_config'][ + 'channel'] = cur_channel + new_attr_dict['groups'] = new_attr_dict['num_channels'] + layer = Block( + SuperDepthwiseConv2D(**new_attr_dict), key=key) + else: + ### group conv + layer = Block(SuperGroupConv2D(**new_attr_dict), key=key) + model[idx] = layer + + elif isinstance(layer, BatchNorm) and ( + getattr(self.context, 'expand', None) != None or + getattr(self.context, 'channel', None) != None): + # num_features in BatchNorm don't change after last weight operators + if idx > last_weight_layer_idx: + continue + + attr_dict = layer.__dict__ + new_attr_name = [ + '_param_attr', '_bias_attr', '_act', '_dtype', '_in_place', + '_data_layout', '_momentum', '_epsilon', '_is_test', + '_use_global_stats', '_trainable_statistics' + ] + new_attr_dict = dict() + if self.context.expand: + new_attr_dict['num_channels'] = self.context.expand * int( + layer._parameters['weight'].shape[0]) + elif self.context.channel: + new_attr_dict['num_channels'] = max(cur_channel) + + for attr in new_attr_name: + new_attr_dict[attr[1:]] = attr_dict[attr] + + del layer, attr_dict + + layer = SuperBatchNorm(**new_attr_dict) + model[idx] = layer + + ### assume output_size = None, filter_size != None + ### NOTE: output_size != None may raise error, solve when it happend. + elif isinstance(layer, Conv2DTranspose): + attr_dict = layer.__dict__ + key = attr_dict['_full_name'] + + new_attr_name = [ + '_stride', '_dilation', '_groups', '_param_attr', + '_bias_attr', '_use_cudnn', '_act', '_dtype', '_output_size' + ] + assert attr_dict[ + '_filter_size'] != None, "Conv2DTranspose only support filter size != None now" + + new_attr_dict = dict() + new_attr_dict['candidate_config'] = dict() + self.kernel_size = getattr(self.context, 'kernel_size', None) + + if self.kernel_size != None: + new_attr_dict['transform_kernel'] = True + + # if the kernel_size of conv transpose is 1, don't change it. 
+ if self.kernel_size and int(attr_dict['_filter_size'][0]) != 1: + new_attr_dict['filter_size'] = max(self.kernel_size) + new_attr_dict['candidate_config'].update({ + 'kernel_size': self.kernel_size + }) + else: + new_attr_dict['filter_size'] = attr_dict['_filter_size'] + + if self.context.expand: + ### first super convolution transpose + if idx == first_weight_layer_idx: + new_attr_dict['num_channels'] = attr_dict[ + '_num_channels'] + else: + new_attr_dict[ + 'num_channels'] = self.context.expand * attr_dict[ + '_num_channels'] + ### last super convolution transpose + if idx == last_weight_layer_idx: + new_attr_dict['num_filters'] = attr_dict['_num_filters'] + else: + new_attr_dict[ + 'num_filters'] = self.context.expand * attr_dict[ + '_num_filters'] + new_attr_dict['candidate_config'].update({ + 'expand_ratio': self.context.expand_ratio + }) + elif self.context.channel: + if attr_dict['_groups'] != None and ( + int(attr_dict['_groups']) == + int(attr_dict['_num_channels'])): + ### depthwise conv_transpose + _logger.warn( + "If convolution is a depthwise conv_transpose, output channel " \ + "change to the same channel with input, output channel in search is not used." + ) + cur_channel = pre_channel + else: + cur_channel = self.context.channel[0] + self.context.channel = self.context.channel[1:] + if idx == first_weight_layer_idx: + new_attr_dict['num_channels'] = attr_dict[ + '_num_channels'] + else: + new_attr_dict['num_channels'] = max(pre_channel) + + if idx == last_weight_layer_idx: + new_attr_dict['num_filters'] = attr_dict['_num_filters'] + else: + new_attr_dict['num_filters'] = max(cur_channel) + new_attr_dict['candidate_config'].update({ + 'channel': cur_channel + }) + pre_channel = cur_channel + else: + new_attr_dict['num_filters'] = attr_dict['_num_filters'] + new_attr_dict['num_channels'] = attr_dict['_num_channels'] + + for attr in new_attr_name: + new_attr_dict[attr[1:]] = attr_dict[attr] + + del layer + + if new_attr_dict['output_size'] == []: + new_attr_dict['output_size'] = None + + if attr_dict['_groups'] == None or int(attr_dict[ + '_groups']) == 1: + ### standard conv_transpose + layer = Block( + SuperConv2DTranspose(**new_attr_dict), key=key) + elif int(attr_dict['_groups']) == int(attr_dict[ + '_num_channels']): + # if conv is depthwise conv, groups = in_channel, out_channel = in_channel, + # channel in candidate_config = in_channel_list + if 'channel' in new_attr_dict['candidate_config']: + new_attr_dict['num_channels'] = max(cur_channel) + new_attr_dict['num_filters'] = new_attr_dict[ + 'num_channels'] + new_attr_dict['candidate_config'][ + 'channel'] = cur_channel + new_attr_dict['groups'] = new_attr_dict['num_channels'] + layer = Block( + SuperDepthwiseConv2DTranspose(**new_attr_dict), key=key) + else: + ### group conv_transpose + layer = Block( + SuperGroupConv2DTranspose(**new_attr_dict), key=key) + model[idx] = layer + + elif isinstance(layer, Linear) and ( + getattr(self.context, 'expand', None) != None or + getattr(self.context, 'channel', None) != None): + attr_dict = layer.__dict__ + key = attr_dict['_full_name'] + ### TODO(paddle): add _param_attr and _bias_attr as private variable of Linear + #new_attr_name = ['_act', '_dtype', '_param_attr', '_bias_attr'] + new_attr_name = ['_act', '_dtype'] + in_nc, out_nc = layer._parameters['weight'].shape + + new_attr_dict = dict() + new_attr_dict['candidate_config'] = dict() + if self.context.expand: + if idx == first_weight_layer_idx: + new_attr_dict['input_dim'] = int(in_nc) + else: + 
new_attr_dict['input_dim'] = self.context.expand * int( + in_nc) + + if idx == last_weight_layer_idx: + new_attr_dict['output_dim'] = int(out_nc) + else: + new_attr_dict['output_dim'] = self.context.expand * int( + out_nc) + new_attr_dict['candidate_config'].update({ + 'expand_ratio': self.context.expand_ratio + }) + elif self.context.channel: + cur_channel = self.context.channel[0] + self.context.channel = self.context.channel[1:] + if idx == first_weight_layer_idx: + new_attr_dict['input_dim'] = int(in_nc) + else: + new_attr_dict['input_dim'] = max(pre_channel) + + if idx == last_weight_layer_idx: + new_attr_dict['output_dim'] = int(out_nc) + else: + new_attr_dict['output_dim'] = max(cur_channel) + new_attr_dict['candidate_config'].update({ + 'channel': cur_channel + }) + pre_channel = cur_channel + else: + new_attr_dict['input_dim'] = int(in_nc) + new_attr_dict['output_dim'] = int(out_nc) + + for attr in new_attr_name: + new_attr_dict[attr[1:]] = attr_dict[attr] + + del layer, attr_dict + + layer = Block(SuperLinear(**new_attr_dict), key=key) + model[idx] = layer + + elif isinstance(layer, InstanceNorm) and ( + getattr(self.context, 'expand', None) != None or + getattr(self.context, 'channel', None) != None): + # num_features in InstanceNorm don't change after last weight operators + if idx > last_weight_layer_idx: + continue + + attr_dict = layer.__dict__ + new_attr_name = [ + '_param_attr', '_bias_attr', '_dtype', '_epsilon' + ] + new_attr_dict = dict() + if self.context.expand: + new_attr_dict['num_channels'] = self.context.expand * int( + layer._parameters['scale'].shape[0]) + elif self.context.channel: + new_attr_dict['num_channels'] = max(cur_channel) + + for attr in new_attr_name: + new_attr_dict[attr[1:]] = attr_dict[attr] + + del layer, attr_dict + + layer = SuperInstanceNorm(**new_attr_dict) + model[idx] = layer + + return model + + +class supernet: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + assert ( + getattr(self, 'expand_ratio', None) == None or + getattr(self, 'channel', None) == None + ), "expand_ratio and channel CANNOT be NOT None at the same time." + + self.expand = None + if 'expand_ratio' in kwargs.keys(): + if isinstance(self.expand_ratio, list) or isinstance( + self.expand_ratio, tuple): + self.expand = max(self.expand_ratio) + elif isinstance(self.expand_ratio, int): + self.expand = self.expand_ratio + + def __enter__(self): + return Convert(self) + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + +#def ofa_supernet(kernel_size, expand_ratio): +# def _ofa_supernet(func): +# @functools.wraps(func) +# def convert(*args, **kwargs): +# supernet_convert(*args, **kwargs) +# return convert +# return _ofa_supernet diff --git a/paddleslim/nas/ofa/layers.py b/paddleslim/nas/ofa/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..4d91f5338a8a1f9ee67cc1d7dab2657d85348454 --- /dev/null +++ b/paddleslim/nas/ofa/layers.py @@ -0,0 +1,929 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import logging +import paddle.fluid as fluid +import paddle.fluid.core as core +import paddle.fluid.dygraph_utils as dygraph_utils +from paddle.fluid.data_feeder import check_variable_and_dtype +from paddle.fluid.framework import in_dygraph_mode, _varbase_creator +from paddle.fluid.dygraph.nn import InstanceNorm, Conv2D, Conv2DTranspose, BatchNorm + +from ...common import get_logger +from .utils.utils import compute_start_end, get_same_padding, convert_to_list + +__all__ = [ + 'SuperConv2D', 'SuperConv2DTranspose', 'SuperSeparableConv2D', + 'SuperBatchNorm', 'SuperLinear', 'SuperInstanceNorm', 'Block', + 'SuperGroupConv2D', 'SuperDepthwiseConv2D', 'SuperGroupConv2DTranspose', + 'SuperDepthwiseConv2DTranspose' +] + +_logger = get_logger(__name__, level=logging.INFO) + +### TODO: if task is elastic width, need to add re_organize_middle_weight in 1x1 conv in MBBlock + +_cnt = 0 + + +def counter(): + global _cnt + _cnt += 1 + return _cnt + + +class BaseBlock(fluid.dygraph.Layer): + def __init__(self, key=None): + super(BaseBlock, self).__init__() + if key is not None: + self._key = str(key) + else: + self._key = self.__class__.__name__ + str(counter()) + + # set SuperNet class + def set_supernet(self, supernet): + self.__dict__['supernet'] = supernet + + @property + def key(self): + return self._key + + +class Block(BaseBlock): + """ + Model is composed of nest blocks. + + Parameters: + fn(Layer): instance of super layers, such as: SuperConv2D(3, 5, 3). + key(str, optional): key of this layer, one-to-one correspondence between key and candidate config. Default: None. + """ + + def __init__(self, fn, key=None): + super(Block, self).__init__(key) + self.fn = fn + self.candidate_config = self.fn.candidate_config + + def forward(self, *inputs, **kwargs): + out = self.supernet.layers_forward(self, *inputs, **kwargs) + return out + + +class SuperConv2D(fluid.dygraph.Conv2D): + """ + This interface is used to construct a callable object of the ``SuperConv2D`` class. + The difference between ```SuperConv2D``` and ```Conv2D``` is: ```SuperConv2D``` need + to feed a config dictionary with the format of {'channel', num_of_channel} represents + the channels of the outputs, used to change the first dimension of weight and bias, + only train the first channels of the weight and bias. + + Note: the channel in config need to less than first defined. + + The super convolution2D layer calculates the output based on the input, filter + and strides, paddings, dilations, groups parameters. Input and + Output are in NCHW format, where N is batch size, C is the number of + the feature map, H is the height of the feature map, and W is the width of the feature map. + Filter's shape is [MCHW] , where M is the number of output feature map, + C is the number of input feature map, H is the height of the filter, + and W is the width of the filter. If the groups is greater than 1, + C will equal the number of input feature map divided by the groups. + Please refer to UFLDL's `convolution + `_ + for more details. + If bias attribution and activation type are provided, bias is added to the + output of the convolution, and the corresponding activation function is + applied to the final result. + For each input :math:`X`, the equation is: + .. math:: + Out = \\sigma (W \\ast X + b) + Where: + * :math:`X`: Input value, a ``Tensor`` with NCHW format. 
+ * :math:`W`: Filter value, a ``Tensor`` with shape [MCHW] . + * :math:`\\ast`: Convolution operation. + * :math:`b`: Bias value, a 2-D ``Tensor`` with shape [M, 1]. + * :math:`\\sigma`: Activation function. + * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different. + + Example: + - Input: + Input shape: :math:`(N, C_{in}, H_{in}, W_{in})` + Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)` + - Output: + Output shape: :math:`(N, C_{out}, H_{out}, W_{out})` + Where + .. math:: + H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1 \\\\ + W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1 + Parameters: + num_channels(int): The number of channels in the input image. + num_filters(int): The number of filter. It is as same as the output + feature map. + filter_size (int or tuple): The filter size. If filter_size is a tuple, + it must contain two integers, (filter_size_H, filter_size_W). + Otherwise, the filter will be a square. + candidate_config(dict, optional): Dictionary descripts candidate config of this layer, + such as {'kernel_size': (3, 5, 7), 'channel': (4, 6, 8)}, means the kernel size of + this layer can be choose from (3, 5, 7), the key of candidate_config + only can be 'kernel_size', 'channel' and 'expand_ratio', 'channel' and 'expand_ratio' + CANNOT be set at the same time. Default: None. + transform_kernel(bool, optional): Whether to use transform matrix to transform a large filter + to a small filter. Default: False. + stride (int or tuple, optional): The stride size. If stride is a tuple, it must + contain two integers, (stride_H, stride_W). Otherwise, the + stride_H = stride_W = stride. Default: 1. + padding (int or tuple, optional): The padding size. If padding is a tuple, it must + contain two integers, (padding_H, padding_W). Otherwise, the + padding_H = padding_W = padding. Default: 0. + dilation (int or tuple, optional): The dilation size. If dilation is a tuple, it must + contain two integers, (dilation_H, dilation_W). Otherwise, the + dilation_H = dilation_W = dilation. Default: 1. + groups (int, optional): The groups number of the Conv2d Layer. According to grouped + convolution in Alex Krizhevsky's Deep CNN paper: when group=2, + the first half of the filters is only connected to the first half + of the input channels, while the second half of the filters is only + connected to the second half of the input channels. Default: 1. + param_attr (ParamAttr, optional): The parameter attribute for learnable weights(Parameter) + of conv2d. If it is set to None or one attribute of ParamAttr, conv2d + will create ParamAttr as param_attr. If the Initializer of the param_attr + is not set, the parameter is initialized with :math:`Normal(0.0, std)`, + and the :math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None. + bias_attr (ParamAttr or bool, optional): The attribute for the bias of conv2d. + If it is set to False, no bias will be added to the output units. + If it is set to None or one attribute of ParamAttr, conv2d + will create ParamAttr as bias_attr. If the Initializer of the bias_attr + is not set, the bias is initialized zero. Default: None. + use_cudnn (bool, optional): Use cudnn kernel or not, it is valid only when the cudnn + library is installed. Default: True. + act (str, optional): Activation type, if it is set to None, activation is not appended. + Default: None. 
+ dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32". + Attribute: + **weight** (Parameter): the learnable weights of filter of this layer. + **bias** (Parameter or None): the learnable bias of this layer. + Returns: + None + + Raises: + ValueError: if ``use_cudnn`` is not a bool value. + Examples: + .. code-block:: python + from paddle.fluid.dygraph.base import to_variable + import paddle.fluid as fluid + from paddleslim.core.layers import SuperConv2D + import numpy as np + data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32') + with fluid.dygraph.guard(): + super_conv2d = SuperConv2D(3, 10, 3) + config = {'channel': 5} + data = to_variable(data) + conv = super_conv2d(data, config) + + """ + + ### NOTE: filter_size, num_channels and num_filters must be the max of candidate to define a largest network. + def __init__(self, + num_channels, + num_filters, + filter_size, + candidate_config={}, + transform_kernel=False, + stride=1, + dilation=1, + padding=0, + groups=None, + param_attr=None, + bias_attr=None, + use_cudnn=True, + act=None, + dtype='float32'): + ### NOTE: padding always is 0, add padding in forward because of kernel size is uncertain + ### TODO: change padding to any padding + super(SuperConv2D, self).__init__( + num_channels, num_filters, filter_size, stride, padding, dilation, + groups, param_attr, bias_attr, use_cudnn, act, dtype) + + if isinstance(self._filter_size, int): + self._filter_size = convert_to_list(self._filter_size, 2) + + self.candidate_config = candidate_config + if len(candidate_config.items()) != 0: + for k, v in candidate_config.items(): + candidate_config[k] = list(set(v)) + + self.ks_set = candidate_config[ + 'kernel_size'] if 'kernel_size' in candidate_config else None + + self.expand_ratio = candidate_config[ + 'expand_ratio'] if 'expand_ratio' in candidate_config else None + self.channel = candidate_config[ + 'channel'] if 'channel' in candidate_config else None + self.base_channel = None + if self.expand_ratio != None: + self.base_channel = int(self._num_filters / max(self.expand_ratio)) + + self.transform_kernel = transform_kernel + if self.ks_set != None: + self.ks_set.sort() + if self.transform_kernel != False: + scale_param = dict() + ### create parameter to transform kernel + for i in range(len(self.ks_set) - 1): + ks_small = self.ks_set[i] + ks_large = self.ks_set[i + 1] + param_name = '%dto%d_matrix' % (ks_large, ks_small) + ks_t = ks_small**2 + scale_param[param_name] = self.create_parameter( + attr=fluid.ParamAttr( + name=self._full_name + param_name, + initializer=fluid.initializer.NumpyArrayInitializer( + np.eye(ks_t))), + shape=(ks_t, ks_t), + dtype=self._dtype) + + for name, param in scale_param.items(): + setattr(self, name, param) + + def get_active_filter(self, in_nc, out_nc, kernel_size): + start, end = compute_start_end(self._filter_size[0], kernel_size) + ### if NOT transform kernel, intercept a center filter with kernel_size from largest filter + filters = self.weight[:out_nc, :in_nc, start:end, start:end] + if self.transform_kernel != False and kernel_size < self._filter_size[ + 0]: + ### if transform kernel, then use matrix to transform + start_filter = self.weight[:out_nc, :in_nc, :, :] + for i in range(len(self.ks_set) - 1, 0, -1): + src_ks = self.ks_set[i] + if src_ks <= kernel_size: + break + target_ks = self.ks_set[i - 1] + start, end = compute_start_end(src_ks, target_ks) + _input_filter = start_filter[:, :, start:end, start:end] + _input_filter = fluid.layers.reshape( + 
_input_filter, + shape=[(_input_filter.shape[0] * _input_filter.shape[1]), + -1]) + core.ops.matmul(_input_filter, + self.__getattr__('%dto%d_matrix' % + (src_ks, target_ks)), + _input_filter, 'transpose_X', False, + 'transpose_Y', False, "alpha", 1) + _input_filter = fluid.layers.reshape( + _input_filter, + shape=[ + filters.shape[0], filters.shape[1], target_ks, target_ks + ]) + start_filter = _input_filter + filters = start_filter + return filters + + def get_groups_in_out_nc(self, in_nc, out_nc): + ### standard conv + return self._groups, in_nc, out_nc + + def forward(self, input, kernel_size=None, expand_ratio=None, channel=None): + + if not in_dygraph_mode(): + _logger.error("NOT support static graph") + + in_nc = int(input.shape[1]) + assert ( + expand_ratio == None or channel == None + ), "expand_ratio and channel CANNOT be NOT None at the same time." + if expand_ratio != None: + out_nc = int(expand_ratio * self.base_channel) + elif channel != None: + out_nc = int(channel) + else: + out_nc = self._num_filters + ks = int(self._filter_size[0]) if kernel_size == None else int( + kernel_size) + + groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc, + out_nc) + + weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks) + padding = convert_to_list(get_same_padding(ks), 2) + + if self._l_type == 'conv2d': + attrs = ('strides', self._stride, 'paddings', padding, 'dilations', + self._dilation, 'groups', groups + if groups else 1, 'use_cudnn', self._use_cudnn) + out = core.ops.conv2d(input, weight, *attrs) + elif self._l_type == 'depthwise_conv2d': + attrs = ('strides', self._stride, 'paddings', padding, 'dilations', + self._dilation, 'groups', groups + if groups else self._groups, 'use_cudnn', self._use_cudnn) + out = core.ops.depthwise_conv2d(input, weight, *attrs) + else: + raise ValueError("conv type error") + + pre_bias = out + out_nc = int(pre_bias.shape[1]) + if self.bias is not None: + bias = self.bias[:out_nc] + pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, bias, 1) + else: + pre_act = pre_bias + + return dygraph_utils._append_activation_in_dygraph(pre_act, self._act) + + +class SuperGroupConv2D(SuperConv2D): + def get_groups_in_out_nc(self, in_nc, out_nc): + ### groups convolution + ### conv: weight: (Cout, Cin/G, Kh, Kw) + groups = self._groups + in_nc = int(in_nc // groups) + return groups, in_nc, out_nc + + +class SuperDepthwiseConv2D(SuperConv2D): + ### depthwise convolution + def get_groups_in_out_nc(self, in_nc, out_nc): + if in_nc != out_nc: + _logger.debug( + "input channel and output channel in depthwise conv is different, change output channel to input channel! origin channel:(in_nc {}, out_nc {}): ". + format(in_nc, out_nc)) + groups = in_nc + out_nc = in_nc + return groups, in_nc, out_nc + + +class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose): + """ + This interface is used to construct a callable object of the ``SuperConv2DTranspose`` + class. + The difference between ```SuperConv2DTranspose``` and ```Conv2DTranspose``` is: + ```SuperConv2DTranspose``` need to feed a config dictionary with the format of + {'channel', num_of_channel} represents the channels of the outputs, used to change + the first dimension of weight and bias, only train the first channels of the weight + and bias. + + Note: the channel in config need to less than first defined. + + The super convolution2D transpose layer calculates the output based on the input, + filter, and dilations, strides, paddings. Input and output + are in NCHW format. 
Where N is batch size, C is the number of feature map, + H is the height of the feature map, and W is the width of the feature map. + Filter's shape is [MCHW] , where M is the number of input feature map, + C is the number of output feature map, H is the height of the filter, + and W is the width of the filter. If the groups is greater than 1, + C will equal the number of input feature map divided by the groups. + If bias attribution and activation type are provided, bias is added to + the output of the convolution, and the corresponding activation function + is applied to the final result. + The details of convolution transpose layer, please refer to the following explanation and references + `conv2dtranspose `_ . + For each input :math:`X`, the equation is: + .. math:: + Out = \sigma (W \\ast X + b) + Where: + * :math:`X`: Input value, a ``Tensor`` with NCHW format. + * :math:`W`: Filter value, a ``Tensor`` with shape [MCHW] . + * :math:`\\ast`: Convolution operation. + * :math:`b`: Bias value, a 2-D ``Tensor`` with shape [M, 1]. + * :math:`\\sigma`: Activation function. + * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different. + Example: + - Input: + Input shape: :math:`(N, C_{in}, H_{in}, W_{in})` + Filter shape: :math:`(C_{in}, C_{out}, H_f, W_f)` + - Output: + Output shape: :math:`(N, C_{out}, H_{out}, W_{out})` + Where + .. math:: + H^\prime_{out} &= (H_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (H_f - 1) + 1 \\\\ + W^\prime_{out} &= (W_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (W_f - 1) + 1 \\\\ + H_{out} &\in [ H^\prime_{out}, H^\prime_{out} + strides[0] ) \\\\ + W_{out} &\in [ W^\prime_{out}, W^\prime_{out} + strides[1] ) + Parameters: + num_channels(int): The number of channels in the input image. + num_filters(int): The number of the filter. It is as same as the output + feature map. + filter_size(int or tuple): The filter size. If filter_size is a tuple, + it must contain two integers, (filter_size_H, filter_size_W). + Otherwise, the filter will be a square. + candidate_config(dict, optional): Dictionary descripts candidate config of this layer, + such as {'kernel_size': (3, 5, 7), 'channel': (4, 6, 8)}, means the kernel size of + this layer can be choose from (3, 5, 7), the key of candidate_config + only can be 'kernel_size', 'channel' and 'expand_ratio', 'channel' and 'expand_ratio' + CANNOT be set at the same time. Default: None. + transform_kernel(bool, optional): Whether to use transform matrix to transform a large filter + to a small filter. Default: False. + output_size(int or tuple, optional): The output image size. If output size is a + tuple, it must contain two integers, (image_H, image_W). None if use + filter_size, padding, and stride to calculate output_size. + if output_size and filter_size are specified at the same time, They + should follow the formula above. Default: None. + padding(int or tuple, optional): The padding size. If padding is a tuple, it must + contain two integers, (padding_H, padding_W). Otherwise, the + padding_H = padding_W = padding. Default: 0. + stride(int or tuple, optional): The stride size. If stride is a tuple, it must + contain two integers, (stride_H, stride_W). Otherwise, the + stride_H = stride_W = stride. Default: 1. + dilation(int or tuple, optional): The dilation size. If dilation is a tuple, it must + contain two integers, (dilation_H, dilation_W). Otherwise, the + dilation_H = dilation_W = dilation. Default: 1. 
+        groups(int, optional): The groups number of the Conv2d transpose layer. Inspired by
+            grouped convolution in Alex Krizhevsky's Deep CNN paper, in which
+            when group=2, the first half of the filters is only connected to the
+            first half of the input channels, while the second half of the
+            filters is only connected to the second half of the input channels.
+            Default: 1.
+        param_attr (ParamAttr, optional): The parameter attribute for learnable weights(Parameter)
+            of conv2d_transpose. If it is set to None or one attribute of ParamAttr, conv2d_transpose
+            will create ParamAttr as param_attr. If the Initializer of the param_attr
+            is not set, the parameter is initialized with Xavier. Default: None.
+        bias_attr (ParamAttr or bool, optional): The attribute for the bias of conv2d_transpose.
+            If it is set to False, no bias will be added to the output units.
+            If it is set to None or one attribute of ParamAttr, conv2d_transpose
+            will create ParamAttr as bias_attr. If the Initializer of the bias_attr
+            is not set, the bias is initialized zero. Default: None.
+        use_cudnn(bool, optional): Use cudnn kernel or not, it is valid only when the cudnn
+            library is installed. Default: True.
+        act (str, optional): Activation type, if it is set to None, activation is not appended.
+            Default: None.
+        dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".
+    Attribute:
+        **weight** (Parameter): the learnable weights of filters of this layer.
+        **bias** (Parameter or None): the learnable bias of this layer.
+    Returns:
+        None
+    Examples:
+        .. code-block:: python
+            import paddle.fluid as fluid
+            from paddleslim.core.layers import SuperConv2DTranspose
+            import numpy as np
+            with fluid.dygraph.guard():
+                data = np.random.random((3, 32, 32, 5)).astype('float32')
+                config = {'channel': 5}
+                super_convtranspose = SuperConv2DTranspose(num_channels=32, num_filters=10, filter_size=3)
+                ret = super_convtranspose(fluid.dygraph.base.to_variable(data), **config)
+    """
+
+    def __init__(self,
+                 num_channels,
+                 num_filters,
+                 filter_size,
+                 output_size=None,
+                 candidate_config={},
+                 transform_kernel=False,
+                 stride=1,
+                 dilation=1,
+                 padding=0,
+                 groups=None,
+                 param_attr=None,
+                 bias_attr=None,
+                 use_cudnn=True,
+                 act=None,
+                 dtype='float32'):
+        ### NOTE: padding is always 0 here; padding is added in forward because the kernel size is uncertain
+        super(SuperConv2DTranspose, self).__init__(
+            num_channels, num_filters, filter_size, output_size, padding,
+            stride, dilation, groups, param_attr, bias_attr, use_cudnn, act,
+            dtype)
+        self.candidate_config = candidate_config
+        if len(self.candidate_config.items()) != 0:
+            for k, v in candidate_config.items():
+                candidate_config[k] = list(set(v))
+        self.ks_set = candidate_config[
+            'kernel_size'] if 'kernel_size' in candidate_config else None
+
+        if isinstance(self._filter_size, int):
+            self._filter_size = convert_to_list(self._filter_size, 2)
+
+        self.expand_ratio = candidate_config[
+            'expand_ratio'] if 'expand_ratio' in candidate_config else None
+        self.channel = candidate_config[
+            'channel'] if 'channel' in candidate_config else None
+        self.base_channel = None
+        if self.expand_ratio:
+            self.base_channel = int(self._num_filters / max(self.expand_ratio))
+
+        self.transform_kernel = transform_kernel
+        if self.ks_set != None:
+            self.ks_set.sort()
+        if self.transform_kernel != False:
+            scale_param = dict()
+            ### create parameter to transform kernel
+            for i in range(len(self.ks_set) - 1):
+                ks_small = self.ks_set[i]
+                ks_large = self.ks_set[i + 1]
+                param_name = 
'%dto%d_matrix' % (ks_large, ks_small) + ks_t = ks_small**2 + scale_param[param_name] = self.create_parameter( + attr=fluid.ParamAttr( + name=self._full_name + param_name, + initializer=fluid.initializer.NumpyArrayInitializer( + np.eye(ks_t))), + shape=(ks_t, ks_t), + dtype=self._dtype) + + for name, param in scale_param.items(): + setattr(self, name, param) + + def get_active_filter(self, in_nc, out_nc, kernel_size): + start, end = compute_start_end(self._filter_size[0], kernel_size) + filters = self.weight[:in_nc, :out_nc, start:end, start:end] + if self.transform_kernel != False and kernel_size < self._filter_size[ + 0]: + start_filter = self.weight[:in_nc, :out_nc, :, :] + for i in range(len(self.ks_set) - 1, 0, -1): + src_ks = self.ks_set[i] + if src_ks <= kernel_size: + break + target_ks = self.ks_set[i - 1] + start, end = compute_start_end(src_ks, target_ks) + _input_filter = start_filter[:, :, start:end, start:end] + _input_filter = fluid.layers.reshape( + _input_filter, + shape=[(_input_filter.shape[0] * _input_filter.shape[1]), + -1]) + core.ops.matmul(_input_filter, + self.__getattr__('%dto%d_matrix' % + (src_ks, target_ks)), + _input_filter, 'transpose_X', False, + 'transpose_Y', False, "alpha", 1) + _input_filter = fluid.layers.reshape( + _input_filter, + shape=[ + filters.shape[0], filters.shape[1], target_ks, target_ks + ]) + start_filter = _input_filter + filters = start_filter + return filters + + def get_groups_in_out_nc(self, in_nc, out_nc): + ### standard conv + return self._groups, in_nc, out_nc + + def forward(self, input, kernel_size=None, expand_ratio=None, channel=None): + if not in_dygraph_mode(): + _logger.error("NOT support static graph") + + in_nc = int(input.shape[1]) + assert ( + expand_ratio == None or channel == None + ), "expand_ratio and channel CANNOT be NOT None at the same time." + if expand_ratio != None: + out_nc = int(expand_ratio * self.base_channel) + elif channel != None: + out_nc = int(channel) + else: + out_nc = self._num_filters + + ks = int(self._filter_size[0]) if kernel_size == None else int( + kernel_size) + + groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc, + out_nc) + + weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks) + padding = convert_to_list(get_same_padding(ks), 2) + + op = getattr(core.ops, self._op_type) + out = op(input, weight, 'output_size', self._output_size, 'strides', + self._stride, 'paddings', padding, 'dilations', self._dilation, + 'groups', groups, 'use_cudnn', self._use_cudnn) + pre_bias = out + out_nc = int(pre_bias.shape[1]) + if self.bias is not None: + bias = self.bias[:out_nc] + pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, bias, 1) + else: + pre_act = pre_bias + + return dygraph_utils._append_activation_in_dygraph( + pre_act, act=self._act) + + +class SuperGroupConv2DTranspose(SuperConv2DTranspose): + def get_groups_in_out_nc(self, in_nc, out_nc): + ### groups convolution + ### groups conv transpose: weight: (Cin, Cout/G, Kh, Kw) + groups = self._groups + out_nc = int(out_nc // groups) + return groups, in_nc, out_nc + + +class SuperDepthwiseConv2DTranspose(SuperConv2DTranspose): + def get_groups_in_out_nc(self, in_nc, out_nc): + if in_nc != out_nc: + _logger.debug( + "input channel and output channel in depthwise conv transpose is different, change output channel to input channel! origin channel:(in_nc {}, out_nc {}): ". 
+ format(in_nc, out_nc)) + groups = in_nc + out_nc = in_nc + return groups, in_nc, out_nc + + +### NOTE: only search channel, write for GAN-compression, maybe change to SuperDepthwiseConv and SuperConv after. +class SuperSeparableConv2D(fluid.dygraph.Layer): + """ + This interface is used to construct a callable object of the ``SuperSeparableConv2D`` + class. + The difference between ```SuperSeparableConv2D``` and ```SeparableConv2D``` is: + ```SuperSeparableConv2D``` need to feed a config dictionary with the format of + {'channel', num_of_channel} represents the channels of the first conv's outputs and + the second conv's inputs, used to change the first dimension of weight and bias, + only train the first channels of the weight and bias. + + The architecture of super separable convolution2D op is [Conv2D, norm layer(may be BatchNorm + or InstanceNorm), Conv2D]. The first conv is depthwise conv, the filter number is input channel + multiply scale_factor, the group is equal to the number of input channel. The second conv + is standard conv, which filter size and stride size are 1. + + Parameters: + num_channels(int): The number of channels in the input image. + num_filters(int): The number of the second conv's filter. It is as same as the output + feature map. + filter_size(int or tuple): The first conv's filter size. If filter_size is a tuple, + it must contain two integers, (filter_size_H, filter_size_W). + Otherwise, the filter will be a square. + padding(int or tuple, optional): The first conv's padding size. If padding is a tuple, + it must contain two integers, (padding_H, padding_W). Otherwise, the + padding_H = padding_W = padding. Default: 0. + stride(int or tuple, optional): The first conv's stride size. If stride is a tuple, + it must contain two integers, (stride_H, stride_W). Otherwise, the + stride_H = stride_W = stride. Default: 1. + dilation(int or tuple, optional): The first conv's dilation size. If dilation is a tuple, + it must contain two integers, (dilation_H, dilation_W). Otherwise, the + dilation_H = dilation_W = dilation. Default: 1. + norm_layer(class): The normalization layer between two convolution. Default: InstanceNorm. + bias_attr (ParamAttr or bool, optional): The attribute for the bias of convolution. + If it is set to False, no bias will be added to the output units. + If it is set to None or one attribute of ParamAttr, convolution + will create ParamAttr as bias_attr. If the Initializer of the bias_attr + is not set, the bias is initialized zero. Default: None. + scale_factor(float): The scale factor of the first conv's output channel. Default: 1. + use_cudnn(bool, optional): Use cudnn kernel or not, it is valid only when the cudnn + library is installed. Default: True. 
+    Returns:
+        None
+    """
+
+    def __init__(self,
+                 num_channels,
+                 num_filters,
+                 filter_size,
+                 candidate_config={},
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 norm_layer=InstanceNorm,
+                 bias_attr=None,
+                 scale_factor=1,
+                 use_cudnn=False):
+        super(SuperSeparableConv2D, self).__init__()
+        self.conv = fluid.dygraph.LayerList([
+            fluid.dygraph.nn.Conv2D(
+                num_channels=num_channels,
+                num_filters=num_channels * scale_factor,
+                filter_size=filter_size,
+                stride=stride,
+                padding=padding,
+                use_cudnn=False,
+                groups=num_channels,
+                bias_attr=bias_attr)
+        ])
+
+        self.conv.extend([norm_layer(num_channels * scale_factor)])
+
+        self.conv.extend([
+            Conv2D(
+                num_channels=num_channels * scale_factor,
+                num_filters=num_filters,
+                filter_size=1,
+                stride=1,
+                use_cudnn=use_cudnn,
+                bias_attr=bias_attr)
+        ])
+
+        self.candidate_config = candidate_config
+        self.expand_ratio = candidate_config[
+            'expand_ratio'] if 'expand_ratio' in candidate_config else None
+        self.base_output_dim = None
+        if self.expand_ratio != None:
+            self.base_output_dim = int(num_filters / max(self.expand_ratio))
+
+    def forward(self, input, expand_ratio=None, channel=None):
+        if not in_dygraph_mode():
+            _logger.error("NOT support static graph")
+
+        in_nc = int(input.shape[1])
+        assert (
+            expand_ratio == None or channel == None
+        ), "expand_ratio and channel CANNOT be NOT None at the same time."
+        if expand_ratio != None:
+            out_nc = int(expand_ratio * self.base_output_dim)
+        elif channel != None:
+            out_nc = int(channel)
+        else:
+            out_nc = self.conv[0]._num_filters
+
+        weight = self.conv[0].weight[:in_nc]
+        ### conv1
+        if self.conv[0]._l_type == 'conv2d':
+            attrs = ('strides', self.conv[0]._stride, 'paddings',
+                     self.conv[0]._padding, 'dilations', self.conv[0]._dilation,
+                     'groups', in_nc, 'use_cudnn', self.conv[0]._use_cudnn)
+            out = core.ops.conv2d(input, weight, *attrs)
+        elif self.conv[0]._l_type == 'depthwise_conv2d':
+            attrs = ('strides', self.conv[0]._stride, 'paddings',
+                     self.conv[0]._padding, 'dilations', self.conv[0]._dilation,
+                     'groups', in_nc, 'use_cudnn', self.conv[0]._use_cudnn)
+            out = core.ops.depthwise_conv2d(input, weight, *attrs)
+        else:
+            raise ValueError("conv type error")
+
+        pre_bias = out
+        if self.conv[0].bias is not None:
+            bias = self.conv[0].bias[:in_nc]
+            pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, bias, 1)
+        else:
+            pre_act = pre_bias
+
+        conv0_out = dygraph_utils._append_activation_in_dygraph(
+            pre_act, self.conv[0]._act)
+
+        norm_out = self.conv[1](conv0_out)
+
+        weight = self.conv[2].weight[:out_nc, :in_nc, :, :]
+
+        if self.conv[2]._l_type == 'conv2d':
+            attrs = ('strides', self.conv[2]._stride, 'paddings',
+                     self.conv[2]._padding, 'dilations', self.conv[2]._dilation,
+                     'groups', self.conv[2]._groups if self.conv[2]._groups else
+                     1, 'use_cudnn', self.conv[2]._use_cudnn)
+            out = core.ops.conv2d(norm_out, weight, *attrs)
+        elif self.conv[2]._l_type == 'depthwise_conv2d':
+            attrs = ('strides', self.conv[2]._stride, 'paddings',
+                     self.conv[2]._padding, 'dilations', self.conv[2]._dilation,
+                     'groups', self.conv[2]._groups, 'use_cudnn',
+                     self.conv[2]._use_cudnn)
+            out = core.ops.depthwise_conv2d(norm_out, weight, *attrs)
+        else:
+            raise ValueError("conv type error")
+
+        pre_bias = out
+        if self.conv[2].bias is not None:
+            bias = self.conv[2].bias[:out_nc]
+            pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, bias, 1)
+        else:
+            pre_act = pre_bias
+
+        conv1_out = dygraph_utils._append_activation_in_dygraph(
+            pre_act, self.conv[2]._act)
+
+        return conv1_out
+
+
+class 
SuperLinear(fluid.dygraph.Linear): + """ + """ + + def __init__(self, + input_dim, + output_dim, + candidate_config={}, + param_attr=None, + bias_attr=None, + act=None, + dtype="float32"): + super(SuperLinear, self).__init__(input_dim, output_dim, param_attr, + bias_attr, act, dtype) + self._param_attr = param_attr + self._bias_attr = bias_attr + self.output_dim = output_dim + self.candidate_config = candidate_config + self.expand_ratio = candidate_config[ + 'expand_ratio'] if 'expand_ratio' in candidate_config else None + self.base_output_dim = None + if self.expand_ratio != None: + self.base_output_dim = int(self.output_dim / max(self.expand_ratio)) + + def forward(self, input, expand_ratio=None, channel=None): + if not in_dygraph_mode(): + _logger.error("NOT support static graph") + + ### weight: (Cin, Cout) + in_nc = int(input.shape[1]) + assert ( + expand_ratio == None or channel == None + ), "expand_ratio and channel CANNOT be NOT None at the same time." + if expand_ratio != None: + out_nc = int(expand_ratio * self.base_output_dim) + elif channel != None: + out_nc = int(channel) + else: + out_nc = self.output_dim + + weight = self.weight[:in_nc, :out_nc] + if self._bias_attr != False: + bias = self.bias[:out_nc] + use_bias = True + + pre_bias = _varbase_creator(dtype=input.dtype) + core.ops.matmul(input, weight, pre_bias, 'transpose_X', False, + 'transpose_Y', False, "alpha", 1) + if self._bias_attr != False: + pre_act = dygraph_utils._append_bias_in_dygraph( + pre_bias, bias, axis=len(input.shape) - 1) + else: + pre_act = pre_bias + + return dygraph_utils._append_activation_in_dygraph(pre_act, self._act) + + +class SuperBatchNorm(fluid.dygraph.BatchNorm): + """ + add comment + """ + + def __init__(self, + num_channels, + act=None, + is_test=False, + momentum=0.9, + epsilon=1e-05, + param_attr=None, + bias_attr=None, + dtype='float32', + data_layout='NCHW', + in_place=False, + moving_mean_name=None, + moving_variance_name=None, + do_model_average_for_mean_and_var=True, + use_global_stats=False, + trainable_statistics=False): + super(SuperBatchNorm, self).__init__( + num_channels, act, is_test, momentum, epsilon, param_attr, + bias_attr, dtype, data_layout, in_place, moving_mean_name, + moving_variance_name, do_model_average_for_mean_and_var, + use_global_stats, trainable_statistics) + + def forward(self, input): + if not in_dygraph_mode(): + _logger.error("NOT support static graph") + + feature_dim = int(input.shape[1]) + + weight = self.weight[:feature_dim] + bias = self.bias[:feature_dim] + mean = self._mean[:feature_dim] + variance = self._variance[:feature_dim] + + mean_out = mean + variance_out = variance + + attrs = ("momentum", self._momentum, "epsilon", self._epsilon, + "is_test", not self.training, "data_layout", self._data_layout, + "use_mkldnn", False, "fuse_with_relu", self._fuse_with_relu, + "use_global_stats", self._use_global_stats, + 'trainable_statistics', self._trainable_statistics) + batch_norm_out, _, _, _, _, _ = core.ops.batch_norm( + input, weight, bias, mean, variance, mean_out, variance_out, *attrs) + return dygraph_utils._append_activation_in_dygraph( + batch_norm_out, act=self._act) + + +class SuperInstanceNorm(fluid.dygraph.InstanceNorm): + """ + """ + + def __init__(self, + num_channels, + epsilon=1e-05, + param_attr=None, + bias_attr=None, + dtype='float32'): + super(SuperInstanceNorm, self).__init__(num_channels, epsilon, + param_attr, bias_attr, dtype) + + def forward(self, input): + if not in_dygraph_mode(): + _logger.error("NOT support static graph") 
+ + feature_dim = int(input.shape[1]) + + if self._param_attr == False and self._bias_attr == False: + scale = None + bias = None + else: + scale = self.scale[:feature_dim] + bias = self.bias[:feature_dim] + + out, _, _ = core.ops.instance_norm(input, scale, bias, 'epsilon', + self._epsilon) + return out diff --git a/paddleslim/nas/ofa/ofa.py b/paddleslim/nas/ofa/ofa.py new file mode 100644 index 0000000000000000000000000000000000000000..9fd7f5ada5d0f59eabd9ac580b9453f183bd78f1 --- /dev/null +++ b/paddleslim/nas/ofa/ofa.py @@ -0,0 +1,319 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import numpy as np +from collections import namedtuple +import paddle +import paddle.nn as nn +import paddle.fluid as fluid +from paddle.fluid.dygraph import Conv2D +from .layers import BaseBlock, Block, SuperConv2D, SuperBatchNorm +from .utils.utils import search_idx +from ...common import get_logger + +_logger = get_logger(__name__, level=logging.INFO) + +__all__ = ['OFA', 'RunConfig', 'DistillConfig'] + +RunConfig = namedtuple('RunConfig', [ + 'train_batch_size', 'eval_batch_size', 'n_epochs', 'save_frequency', + 'eval_frequency', 'init_learning_rate', 'total_images', 'elastic_depth', + 'dynamic_batch_size' +]) +RunConfig.__new__.__defaults__ = (None, ) * len(RunConfig._fields) + +DistillConfig = namedtuple('DistillConfig', [ + 'lambda_distill', 'teacher_model', 'mapping_layers', 'teacher_model_path', + 'distill_fn' +]) +DistillConfig.__new__.__defaults__ = (None, ) * len(DistillConfig._fields) + + +class OFABase(fluid.dygraph.Layer): + def __init__(self, model): + super(OFABase, self).__init__() + self.model = model + self._layers, self._elastic_task = self.get_layers() + + def get_layers(self): + layers = dict() + elastic_task = set() + for name, sublayer in self.model.named_sublayers(): + if isinstance(sublayer, BaseBlock): + sublayer.set_supernet(self) + layers[sublayer.key] = sublayer.candidate_config + for k in sublayer.candidate_config.keys(): + elastic_task.add(k) + return layers, elastic_task + + def forward(self, *inputs, **kwargs): + raise NotImplementedError + + # NOTE: config means set forward config for layers, used in distill. 
+    def layers_forward(self, block, *inputs, **kwargs):
+        if getattr(self, 'current_config', None) != None:
+            assert block.key in self.current_config, 'The config does not contain layer {}.'.format(
+                block.key)
+            config = self.current_config[block.key]
+        else:
+            config = dict()
+        _logger.debug("layer {} uses config: {}".format(block.key, config))
+
+        return block.fn(*inputs, **config)
+
+    @property
+    def layers(self):
+        return self._layers
+
+
+class OFA(OFABase):
+    def __init__(self,
+                 model,
+                 run_config,
+                 net_config=None,
+                 distill_config=None,
+                 elastic_order=None,
+                 train_full=False):
+        super(OFA, self).__init__(model)
+        self.net_config = net_config
+        self.run_config = run_config
+        self.distill_config = distill_config
+        self.elastic_order = elastic_order
+        self.train_full = train_full
+        self.iter_per_epochs = self.run_config.total_images // self.run_config.train_batch_size
+        self.iter = 0
+        self.dynamic_iter = 0
+        self.manual_set_task = False
+        self.task_idx = 0
+        self._add_teacher = False
+        self.netAs_param = []
+
+        for idx in range(len(run_config.n_epochs)):
+            assert isinstance(
+                run_config.init_learning_rate[idx],
+                list), "each element in init_learning_rate must be a list"
+            assert isinstance(run_config.n_epochs[idx],
+                              list), "each element in n_epochs must be a list"
+
+        ### if elastic_order is None, use the default order
+        if self.elastic_order is not None:
+            assert isinstance(self.elastic_order,
+                              list), 'elastic_order must be a list'
+
+        if self.elastic_order is None:
+            self.elastic_order = []
+            # zero, elastic resolution, handled in the demo
+            # first, elastic kernel size
+            if 'kernel_size' in self._elastic_task:
+                self.elastic_order.append('kernel_size')
+
+            # second, elastic depth, such as: list(2, 3, 4)
+            if getattr(self.run_config, 'elastic_depth', None) != None:
+                depth_list = list(set(self.run_config.elastic_depth))
+                depth_list.sort()
+                self.layers['depth'] = depth_list
+                self.elastic_order.append('depth')
+
+            # finally, elastic width
+            if 'expand_ratio' in self._elastic_task:
+                self.elastic_order.append('width')
+
+            if 'channel' in self._elastic_task and 'width' not in self.elastic_order:
+                self.elastic_order.append('width')
+
+        assert len(self.run_config.n_epochs) == len(self.elastic_order)
+        assert len(self.run_config.n_epochs) == len(
+            self.run_config.dynamic_batch_size)
+        assert len(self.run_config.n_epochs) == len(
+            self.run_config.init_learning_rate)
+
+        ### ================= add distill prepare ======================
+        if self.distill_config != None and (
+                self.distill_config.lambda_distill != None and
+                self.distill_config.lambda_distill > 0):
+            self._add_teacher = True
+            self._prepare_distill()
+
+        self.model.train()
+
+    def _prepare_distill(self):
+        self.Tacts, self.Sacts = {}, {}
+
+        if self.distill_config.teacher_model == None:
+            _logger.error(
+                'If you want to add distillation, please provide the teacher model.'
+            )
+
+        assert isinstance(self.distill_config.teacher_model,
+                          paddle.fluid.dygraph.Layer)
+
+        # load teacher parameters
+        if self.distill_config.teacher_model_path != None:
+            param_state_dict, _ = paddle.load_dygraph(
+                self.distill_config.teacher_model_path)
+            self.distill_config.teacher_model.set_dict(param_state_dict)
+
+        self.ofa_teacher_model = OFABase(self.distill_config.teacher_model)
+        self.ofa_teacher_model.model.eval()
+
+        # add hooks if mapping_layers is not None;
+        # if mapping_layers is None, only the output of the teacher model is used,
+        # if mapping_layers is NOT None, add hooks and compute the distill loss on the mapping layers.
+        mapping_layers = self.distill_config.mapping_layers
+        if mapping_layers is not None:
+            self.netAs = []
+            for name, sublayer in self.model.named_sublayers():
+                if name in mapping_layers:
+                    netA = SuperConv2D(
+                        sublayer._num_filters,
+                        sublayer._num_filters,
+                        filter_size=1)
+                    self.netAs_param.extend(netA.parameters())
+                    self.netAs.append(netA)
+
+            def get_activation(mem, name):
+                def get_output_hook(layer, input, output):
+                    mem[name] = output
+
+                return get_output_hook
+
+            def add_hook(net, mem, mapping_layers):
+                for idx, (n, m) in enumerate(net.named_sublayers()):
+                    if n in mapping_layers:
+                        m.register_forward_post_hook(get_activation(mem, n))
+
+            add_hook(self.model, self.Sacts, mapping_layers)
+            add_hook(self.ofa_teacher_model.model, self.Tacts, mapping_layers)
+
+    def _compute_epochs(self):
+        if getattr(self, 'epoch', None) is None:
+            epoch = self.iter // self.iter_per_epochs
+        else:
+            epoch = self.epoch
+        return epoch
+
+    def _sample_from_nestdict(self, cands, sample_type, task, phase):
+        sample_cands = dict()
+        for k, v in cands.items():
+            if isinstance(v, dict):
+                sample_cands[k] = self._sample_from_nestdict(
+                    v, sample_type=sample_type, task=task, phase=phase)
+            elif isinstance(v, (list, set, tuple)):
+                if sample_type == 'largest':
+                    sample_cands[k] = v[-1]
+                elif sample_type == 'smallest':
+                    sample_cands[k] = v[0]
+                else:
+                    if k not in task:
+                        # candidate_config is sorted and deduplicated;
+                        # dimensions not in the task list stay fixed at the largest candidate
+                        sample_cands[k] = v[-1]
+                    else:
+                        # phase is None -> sample from all candidates;
+                        # otherwise smaller candidates are added phase by phase.
+                        # phase only affects the last task in the current task list.
+                        if phase is not None and k == task[-1]:
+                            start = -(phase + 2)
+                        else:
+                            start = 0
+                        sample_cands[k] = np.random.choice(v[start:])
+
+        return sample_cands
+
+    def _sample_config(self, task, sample_type='random', phase=None):
+        config = self._sample_from_nestdict(
+            self.layers, sample_type=sample_type, task=task, phase=phase)
+        return config
+
+    def set_task(self, task=None, phase=None):
+        self.manual_set_task = True
+        self.task = task
+        self.phase = phase
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
+    def _progressive_shrinking(self):
+        epoch = self._compute_epochs()
+        self.task_idx, phase_idx = search_idx(epoch, self.run_config.n_epochs)
+        self.task = self.elastic_order[:self.task_idx + 1]
+        if 'width' in self.task:
+            ### replace the abstract 'width' task with the concrete config keys
+            self.task.remove('width')
+            if 'expand_ratio' in self._elastic_task:
+                self.task.append('expand_ratio')
+            if 'channel' in self._elastic_task:
+                self.task.append('channel')
+        if len(self.run_config.n_epochs[self.task_idx]) == 1:
+            phase_idx = None
+        return self._sample_config(task=self.task, phase=phase_idx)
+
+    def calc_distill_loss(self):
+        losses = []
+        assert len(self.netAs) > 0
+        for i, netA in enumerate(self.netAs):
+            assert isinstance(netA, SuperConv2D)
+            n = self.distill_config.mapping_layers[i]
+            Tact = self.Tacts[n]
+            Sact = self.Sacts[n]
+            Sact = netA(Sact, channel=netA._num_filters)
+            if self.distill_config.distill_fn is None:
+                loss = fluid.layers.mse_loss(Sact, Tact)
+            else:
+                loss = self.distill_config.distill_fn(Sact, Tact)
+            losses.append(loss)
+        return sum(losses) * self.distill_config.lambda_distill
+
+    ### TODO: complete it
+    def search(self, eval_func, condition):
+        pass
+
+    ### TODO: complete it
+    def export(self, config):
+        pass
+
+    def forward(self, *inputs, **kwargs):
+        # ===================== teacher process =====================
+        teacher_output = None
+        if self._add_teacher:
+            teacher_output = self.ofa_teacher_model.model.forward(*inputs,
+                                                                  **kwargs)
+        # ============================================================
+
+        # ==================== student process =====================
+        self.dynamic_iter += 1
+        if self.dynamic_iter == self.run_config.dynamic_batch_size[
+                self.task_idx]:
+            self.iter += 1
+            self.dynamic_iter = 0
+
+        if self.net_config is None:
+            if self.train_full:
+                self.current_config = self._sample_config(
+                    task=None, sample_type='largest')
+            else:
+                if not self.manual_set_task:
+                    self.current_config = self._progressive_shrinking()
+                else:
+                    self.current_config = self._sample_config(
+                        self.task, phase=self.phase)
+        else:
+            self.current_config = self.net_config
+
+        _logger.debug("Current config is {}".format(self.current_config))
+        if 'depth' in self.current_config:
+            kwargs['depth'] = int(self.current_config['depth'])
+
+        return self.model.forward(*inputs, **kwargs), teacher_output
diff --git a/paddleslim/nas/ofa/utils/__init__.py b/paddleslim/nas/ofa/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..342ae0eddcff168fb62bb08708af868dbc808aa5
--- /dev/null
+++ b/paddleslim/nas/ofa/utils/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .utils import *
diff --git a/paddleslim/nas/ofa/utils/utils.py b/paddleslim/nas/ofa/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fad016a754f61df9c72c04956901d978db0b6df6
--- /dev/null
+++ b/paddleslim/nas/ofa/utils/utils.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
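The progressive-shrinking schedule above relies on search_idx (defined in the file that
follows) to turn the current epoch into a (task_idx, phase_idx) pair over the nested
n_epochs list, where each inner list holds the per-phase epoch boundaries of one elastic
task. A small sketch of the expected mapping, assuming the package layout in this diff:

    from paddleslim.nas.ofa.utils import search_idx

    n_epochs = [[1], [2, 3], [4, 5]]          # three tasks; the last two have two phases
    assert search_idx(0, n_epochs) == (0, 0)  # epochs 0-1 -> first task, phase 0
    assert search_idx(2, n_epochs) == (1, 0)  # epoch 2    -> second task, phase 0
    assert search_idx(3, n_epochs) == (1, 1)  # epoch 3    -> second task, phase 1
    assert search_idx(5, n_epochs) == (2, 1)  # epoch 5    -> third task, phase 1
    assert search_idx(9, n_epochs) == (2, 1)  # past the schedule -> last task, last phase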
+ + +def compute_start_end(kernel_size, sub_kernel_size): + center = kernel_size // 2 + sub_center = sub_kernel_size // 2 + start = center - sub_center + end = center + sub_center + 1 + assert end - start == sub_kernel_size + return start, end + + +def get_same_padding(kernel_size): + assert isinstance(kernel_size, int) + assert kernel_size % 2 > 0, "kernel size must be odd number" + return kernel_size // 2 + + +def convert_to_list(value, n): + return [value, ] * n + + +def search_idx(num, sorted_nestlist): + max_num = -1 + max_idx = -1 + for idx in range(len(sorted_nestlist)): + task_ = sorted_nestlist[idx] + max_num = task_[-1] + max_idx = len(task_) - 1 + for phase_idx in range(len(task_)): + if num <= task_[phase_idx]: + return idx, phase_idx + assert num > max_num + return len(sorted_nestlist) - 1, max_idx diff --git a/tests/test_ofa.py b/tests/test_ofa.py new file mode 100644 index 0000000000000000000000000000000000000000..b65d12e74a6f9ece7866db8468f7e8a1337e485c --- /dev/null +++ b/tests/test_ofa.py @@ -0,0 +1,216 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +sys.path.append("../") +import numpy as np +import unittest +import paddle +import paddle.fluid as fluid +import paddle.fluid.dygraph.nn as nn +from paddle.nn import ReLU +from paddleslim.nas import ofa +from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig +from paddleslim.nas.ofa.convert_super import supernet +from paddleslim.nas.ofa.layers import Block, SuperSeparableConv2D + + +class ModelConv(fluid.dygraph.Layer): + def __init__(self): + super(ModelConv, self).__init__() + with supernet( + kernel_size=(3, 5, 7), + channel=((4, 8, 12), (8, 12, 16), (8, 12, 16), + (8, 12, 16))) as ofa_super: + models = [] + models += [nn.Conv2D(3, 4, 3)] + models += [nn.InstanceNorm(4)] + models += [ReLU()] + models += [nn.Conv2D(4, 4, 3, groups=4)] + models += [nn.InstanceNorm(4)] + models += [ReLU()] + models += [nn.Conv2DTranspose(4, 4, 3, groups=4, use_cudnn=True)] + models += [nn.BatchNorm(4)] + models += [ReLU()] + models += [nn.Conv2D(4, 3, 3)] + models += [ReLU()] + models = ofa_super.convert(models) + + models += [ + Block( + SuperSeparableConv2D( + 3, 6, 1, candidate_config={'channel': (3, 6)})) + ] + with supernet( + kernel_size=(3, 5, 7), expand_ratio=(1, 2, 4)) as ofa_super: + models1 = [] + models1 += [nn.Conv2D(6, 4, 3)] + models1 += [nn.BatchNorm(4)] + models1 += [ReLU()] + models1 += [nn.Conv2D(4, 4, 3, groups=2)] + models1 += [nn.InstanceNorm(4)] + models1 += [ReLU()] + models1 += [nn.Conv2DTranspose(4, 4, 3, groups=2)] + models1 += [nn.BatchNorm(4)] + models1 += [ReLU()] + models1 += [nn.Conv2DTranspose(4, 4, 3)] + models1 += [nn.BatchNorm(4)] + models1 += [ReLU()] + models1 = ofa_super.convert(models1) + + models += models1 + + self.models = paddle.nn.Sequential(*models) + + def forward(self, inputs, depth=None): + if depth != None: + assert isinstance(depth, int) + assert depth <= len(self.models) + else: + depth = len(self.models) + 
for idx in range(depth): + layer = self.models[idx] + inputs = layer(inputs) + return inputs + + +class ModelLinear(fluid.dygraph.Layer): + def __init__(self): + super(ModelLinear, self).__init__() + models = [] + with supernet(expand_ratio=(1, 2, 4)) as ofa_super: + models1 = [] + models1 += [nn.Linear(64, 128)] + models1 += [nn.Linear(128, 256)] + models1 = ofa_super.convert(models1) + + models += models1 + + with supernet(channel=((64, 128, 256), (64, 128, 256))) as ofa_super: + models1 = [] + models1 += [nn.Linear(256, 128)] + models1 += [nn.Linear(128, 256)] + models1 = ofa_super.convert(models1) + + models += models1 + + self.models = paddle.nn.Sequential(*models) + + def forward(self, inputs, depth=None): + if depth != None: + assert isinstance(depth, int) + assert depth < len(self.models) + else: + depth = len(self.models) + for idx in range(depth): + layer = self.models[idx] + inputs = layer(inputs) + return inputs + + +class TestOFA(unittest.TestCase): + def setUp(self): + fluid.enable_dygraph() + self.init_model_and_data() + self.init_config() + + def init_model_and_data(self): + self.model = ModelConv() + self.teacher_model = ModelConv() + data_np = np.random.random((1, 3, 10, 10)).astype(np.float32) + label_np = np.random.random((1)).astype(np.float32) + + self.data = fluid.dygraph.to_variable(data_np) + + def init_config(self): + default_run_config = { + 'train_batch_size': 1, + 'eval_batch_size': 1, + 'n_epochs': [[1], [2, 3], [4, 5]], + 'init_learning_rate': [[0.001], [0.003, 0.001], [0.003, 0.001]], + 'dynamic_batch_size': [1, 1, 1], + 'total_images': 1, + 'elastic_depth': (5, 15, 24) + } + self.run_config = RunConfig(**default_run_config) + + default_distill_config = { + 'lambda_distill': 0.01, + 'teacher_model': self.teacher_model, + 'mapping_layers': ['models.0.fn'] + } + self.distill_config = DistillConfig(**default_distill_config) + + def test_ofa(self): + ofa_model = OFA(self.model, + self.run_config, + distill_config=self.distill_config) + + start_epoch = 0 + for idx in range(len(self.run_config.n_epochs)): + cur_idx = self.run_config.n_epochs[idx] + for ph_idx in range(len(cur_idx)): + cur_lr = self.run_config.init_learning_rate[idx][ph_idx] + adam = fluid.optimizer.Adam( + learning_rate=cur_lr, + parameter_list=( + ofa_model.parameters() + ofa_model.netAs_param)) + for epoch_id in range(start_epoch, + self.run_config.n_epochs[idx][ph_idx]): + for model_no in range(self.run_config.dynamic_batch_size[ + idx]): + output, _ = ofa_model(self.data) + loss = fluid.layers.reduce_mean(output) + if self.distill_config.mapping_layers != None: + dis_loss = ofa_model.calc_distill_loss() + loss += dis_loss + dis_loss = dis_loss.numpy()[0] + else: + dis_loss = 0 + print('epoch: {}, loss: {}, distill loss: {}'.format( + epoch_id, loss.numpy()[0], dis_loss)) + loss.backward() + adam.minimize(loss) + adam.clear_gradients() + start_epoch = self.run_config.n_epochs[idx][ph_idx] + + +class TestOFACase1(TestOFA): + def init_model_and_data(self): + self.model = ModelLinear() + self.teacher_model = ModelLinear() + data_np = np.random.random((3, 64)).astype(np.float32) + + self.data = fluid.dygraph.to_variable(data_np) + + def init_config(self): + default_run_config = { + 'train_batch_size': 1, + 'eval_batch_size': 1, + 'n_epochs': [[2, 5]], + 'init_learning_rate': [[0.003, 0.001]], + 'dynamic_batch_size': [1], + 'total_images': 1, + } + self.run_config = RunConfig(**default_run_config) + + default_distill_config = { + 'lambda_distill': 0.01, + 'teacher_model': self.teacher_model, + } + 
self.distill_config = DistillConfig(**default_distill_config) + + +if __name__ == '__main__': + unittest.main()
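Beyond the epoch-driven loop exercised by these tests, set_epoch/set_task can pin the
elastic task manually. A minimal sketch reusing the ModelLinear supernet from this test
file (the RunConfig values are illustrative only):

    import numpy as np
    import paddle.fluid as fluid
    from paddleslim.nas.ofa import OFA, RunConfig

    fluid.enable_dygraph()
    run_config = RunConfig(
        train_batch_size=1,
        n_epochs=[[1]],
        init_learning_rate=[[0.001]],
        dynamic_batch_size=[1],
        total_images=1)
    ofa_model = OFA(ModelLinear(), run_config)

    data = fluid.dygraph.to_variable(
        np.random.random((3, 64)).astype(np.float32))

    # From here on, only expand_ratio is sampled; every other elastic
    # dimension stays at its largest candidate.
    ofa_model.set_epoch(0)
    ofa_model.set_task(['expand_ratio'])
    output, _ = ofa_model(data)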