diff --git a/demo/one_shot/ofa_train.py b/demo/one_shot/ofa_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a47a219c1096d750757f407cfde4ff37691efb7
--- /dev/null
+++ b/demo/one_shot/ofa_train.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.dygraph.nn as nn
+from paddle.nn import ReLU
+from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
+from paddleslim.nas.ofa import supernet
+
+
+class Model(fluid.dygraph.Layer):
+ def __init__(self):
+ super(Model, self).__init__()
+ with supernet(
+ kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4]) as ofa_super:
+ models = []
+ models += [nn.Conv2D(1, 6, 3)]
+ models += [ReLU()]
+ models += [nn.Pool2D(2, 'max', 2)]
+ models += [nn.Conv2D(6, 16, 5, padding=0)]
+ models += [ReLU()]
+ models += [nn.Pool2D(2, 'max', 2)]
+ models += [
+ nn.Linear(784, 120), nn.Linear(120, 84), nn.Linear(84, 10)
+ ]
+ models = ofa_super.convert(models)
+ self.models = paddle.nn.Sequential(*models)
+
+ def forward(self, inputs, label, depth=None):
+ if depth != None:
+ assert isinstance(depth, int)
+ assert depth < len(self.models)
+ models = self.models[:depth]
+ else:
+ depth = len(self.models)
+ models = self.models[:]
+
+ for idx, layer in enumerate(models):
+ if idx == 6:
+ inputs = fluid.layers.flatten(inputs, 1)
+ inputs = layer(inputs)
+
+ inputs = fluid.layers.softmax(inputs)
+ return inputs
+
+
+def test_ofa():
+
+ default_run_config = {
+ 'train_batch_size': 256,
+ 'eval_batch_size': 64,
+ 'n_epochs': [[1], [2, 3], [4, 5]],
+ 'init_learning_rate': [[0.001], [0.003, 0.001], [0.003, 0.001]],
+ 'dynamic_batch_size': [1, 1, 1],
+ 'total_images': 50000, #1281167,
+ 'elastic_depth': (2, 5, 8)
+ }
+ run_config = RunConfig(**default_run_config)
+
+ default_distill_config = {
+ 'lambda_distill': 0.01,
+ 'teacher_model': Model,
+ 'mapping_layers': ['models.0.fn']
+ }
+ distill_config = DistillConfig(**default_distill_config)
+
+ fluid.enable_dygraph()
+ model = Model()
+ ofa_model = OFA(model, run_config, distill_config=distill_config)
+
+ train_reader = paddle.fluid.io.batch(
+ paddle.dataset.mnist.train(), batch_size=256, drop_last=True)
+
+ start_epoch = 0
+ for idx in range(len(run_config.n_epochs)):
+ cur_idx = run_config.n_epochs[idx]
+ for ph_idx in range(len(cur_idx)):
+ cur_lr = run_config.init_learning_rate[idx][ph_idx]
+ adam = fluid.optimizer.Adam(
+ learning_rate=cur_lr,
+ parameter_list=(ofa_model.parameters() + ofa_model.netAs_param))
+ for epoch_id in range(start_epoch,
+ run_config.n_epochs[idx][ph_idx]):
+ for batch_id, data in enumerate(train_reader()):
+ dy_x_data = np.array(
+ [x[0].reshape(1, 28, 28)
+ for x in data]).astype('float32')
+ y_data = np.array(
+ [x[1] for x in data]).astype('int64').reshape(-1, 1)
+
+ img = fluid.dygraph.to_variable(dy_x_data)
+ label = fluid.dygraph.to_variable(y_data)
+ label.stop_gradient = True
+
+ for model_no in range(run_config.dynamic_batch_size[idx]):
+ output, _ = ofa_model(img, label)
+ loss = fluid.layers.reduce_mean(output)
+ dis_loss = ofa_model.calc_distill_loss()
+ loss += dis_loss
+ loss.backward()
+
+ if batch_id % 10 == 0:
+ print(
+ 'epoch: {}, batch: {}, loss: {}, distill loss: {}'.
+ format(epoch_id, batch_id,
+ loss.numpy()[0], dis_loss.numpy()[0]))
+ ### accumurate dynamic_batch_size network of gradients for same batch of data
+ ### NOTE: need to fix gradients accumulate in PaddlePaddle
+ adam.minimize(loss)
+ adam.clear_gradients()
+ start_epoch = run_config.n_epochs[idx][ph_idx]
+
+
+test_ofa()
diff --git a/paddleslim/nas/ofa/__init__.py b/paddleslim/nas/ofa/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..db1394baf6dc59286b678f302b80fe2c5de404c1
--- /dev/null
+++ b/paddleslim/nas/ofa/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .ofa import OFA, RunConfig, DistillConfig
+from .convert_super import supernet
+from .layers import *
diff --git a/paddleslim/nas/ofa/convert_super.py b/paddleslim/nas/ofa/convert_super.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7ff8a1e530cef850415049c1d8a1b42dfcc0345
--- /dev/null
+++ b/paddleslim/nas/ofa/convert_super.py
@@ -0,0 +1,417 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import decorator
+import logging
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid import framework
+from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, BatchNorm, InstanceNorm
+from .layers import *
+from ...common import get_logger
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+__all__ = ['supernet']
+
+WEIGHT_LAYER = ['conv', 'linear']
+
+
+### TODO: add decorator
+class Convert:
+ def __init__(self, context):
+ self.context = context
+
+ def convert(self, model):
+ # search the first and last weight layer, don't change out channel of the last weight layer
+ # don't change in channel of the first weight layer
+ first_weight_layer_idx = -1
+ last_weight_layer_idx = -1
+ weight_layer_count = 0
+ # NOTE: pre_channel store for shortcut module
+ pre_channel = 0
+ cur_channel = None
+ for idx, layer in enumerate(model):
+ cls_name = layer.__class__.__name__.lower()
+ if 'conv' in cls_name or 'linear' in cls_name:
+ weight_layer_count += 1
+ last_weight_layer_idx = idx
+ if first_weight_layer_idx == -1:
+ first_weight_layer_idx = idx
+
+ if getattr(self.context, 'channel', None) != None:
+ assert len(
+ self.context.channel
+ ) == weight_layer_count, "length of channel must same as weight layer."
+
+ for idx, layer in enumerate(model):
+ if isinstance(layer, Conv2D):
+ attr_dict = layer.__dict__
+ key = attr_dict['_full_name']
+
+ new_attr_name = [
+ '_stride', '_dilation', '_groups', '_param_attr',
+ '_bias_attr', '_use_cudnn', '_act', '_dtype'
+ ]
+
+ new_attr_dict = dict()
+ new_attr_dict['candidate_config'] = dict()
+ self.kernel_size = getattr(self.context, 'kernel_size', None)
+
+ if self.kernel_size != None:
+ new_attr_dict['transform_kernel'] = True
+
+ # if the kernel_size of conv is 1, don't change it.
+ #if self.kernel_size and int(attr_dict['_filter_size'][0]) != 1:
+ if self.kernel_size and int(attr_dict['_filter_size']) != 1:
+ new_attr_dict['filter_size'] = max(self.kernel_size)
+ new_attr_dict['candidate_config'].update({
+ 'kernel_size': self.kernel_size
+ })
+ else:
+ new_attr_dict['filter_size'] = attr_dict['_filter_size']
+
+ if self.context.expand:
+ ### first super convolution
+ if idx == first_weight_layer_idx:
+ new_attr_dict['num_channels'] = attr_dict[
+ '_num_channels']
+ else:
+ new_attr_dict[
+ 'num_channels'] = self.context.expand * attr_dict[
+ '_num_channels']
+ ### last super convolution
+ if idx == last_weight_layer_idx:
+ new_attr_dict['num_filters'] = attr_dict['_num_filters']
+ else:
+ new_attr_dict[
+ 'num_filters'] = self.context.expand * attr_dict[
+ '_num_filters']
+ new_attr_dict['candidate_config'].update({
+ 'expand_ratio': self.context.expand_ratio
+ })
+ elif self.context.channel:
+ if attr_dict['_groups'] != None and (
+ int(attr_dict['_groups']) ==
+ int(attr_dict['_num_channels'])):
+ ### depthwise conv, if conv is depthwise, use pre channel as cur_channel
+ _logger.warn(
+ "If convolution is a depthwise conv, output channel change" \
+ " to the same channel with input, output channel in search is not used."
+ )
+ cur_channel = pre_channel
+ else:
+ cur_channel = self.context.channel[0]
+ self.context.channel = self.context.channel[1:]
+ if idx == first_weight_layer_idx:
+ new_attr_dict['num_channels'] = attr_dict[
+ '_num_channels']
+ else:
+ new_attr_dict['num_channels'] = max(pre_channel)
+
+ if idx == last_weight_layer_idx:
+ new_attr_dict['num_filters'] = attr_dict['_num_filters']
+ else:
+ new_attr_dict['num_filters'] = max(cur_channel)
+ new_attr_dict['candidate_config'].update({
+ 'channel': cur_channel
+ })
+ pre_channel = cur_channel
+ else:
+ new_attr_dict['num_filters'] = attr_dict['_num_filters']
+ new_attr_dict['num_channels'] = attr_dict['_num_channels']
+
+ for attr in new_attr_name:
+ new_attr_dict[attr[1:]] = attr_dict[attr]
+
+ del layer
+
+ if attr_dict['_groups'] == None or int(attr_dict[
+ '_groups']) == 1:
+ ### standard conv
+ layer = Block(SuperConv2D(**new_attr_dict), key=key)
+ elif int(attr_dict['_groups']) == int(attr_dict[
+ '_num_channels']):
+ # if conv is depthwise conv, groups = in_channel, out_channel = in_channel,
+ # channel in candidate_config = in_channel_list
+ if 'channel' in new_attr_dict['candidate_config']:
+ new_attr_dict['num_channels'] = max(cur_channel)
+ new_attr_dict['num_filters'] = new_attr_dict[
+ 'num_channels']
+ new_attr_dict['candidate_config'][
+ 'channel'] = cur_channel
+ new_attr_dict['groups'] = new_attr_dict['num_channels']
+ layer = Block(
+ SuperDepthwiseConv2D(**new_attr_dict), key=key)
+ else:
+ ### group conv
+ layer = Block(SuperGroupConv2D(**new_attr_dict), key=key)
+ model[idx] = layer
+
+ elif isinstance(layer, BatchNorm) and (
+ getattr(self.context, 'expand', None) != None or
+ getattr(self.context, 'channel', None) != None):
+ # num_features in BatchNorm don't change after last weight operators
+ if idx > last_weight_layer_idx:
+ continue
+
+ attr_dict = layer.__dict__
+ new_attr_name = [
+ '_param_attr', '_bias_attr', '_act', '_dtype', '_in_place',
+ '_data_layout', '_momentum', '_epsilon', '_is_test',
+ '_use_global_stats', '_trainable_statistics'
+ ]
+ new_attr_dict = dict()
+ if self.context.expand:
+ new_attr_dict['num_channels'] = self.context.expand * int(
+ layer._parameters['weight'].shape[0])
+ elif self.context.channel:
+ new_attr_dict['num_channels'] = max(cur_channel)
+
+ for attr in new_attr_name:
+ new_attr_dict[attr[1:]] = attr_dict[attr]
+
+ del layer, attr_dict
+
+ layer = SuperBatchNorm(**new_attr_dict)
+ model[idx] = layer
+
+ ### assume output_size = None, filter_size != None
+ ### NOTE: output_size != None may raise error, solve when it happend.
+ elif isinstance(layer, Conv2DTranspose):
+ attr_dict = layer.__dict__
+ key = attr_dict['_full_name']
+
+ new_attr_name = [
+ '_stride', '_dilation', '_groups', '_param_attr',
+ '_bias_attr', '_use_cudnn', '_act', '_dtype', '_output_size'
+ ]
+ assert attr_dict[
+ '_filter_size'] != None, "Conv2DTranspose only support filter size != None now"
+
+ new_attr_dict = dict()
+ new_attr_dict['candidate_config'] = dict()
+ self.kernel_size = getattr(self.context, 'kernel_size', None)
+
+ if self.kernel_size != None:
+ new_attr_dict['transform_kernel'] = True
+
+ # if the kernel_size of conv transpose is 1, don't change it.
+ if self.kernel_size and int(attr_dict['_filter_size'][0]) != 1:
+ new_attr_dict['filter_size'] = max(self.kernel_size)
+ new_attr_dict['candidate_config'].update({
+ 'kernel_size': self.kernel_size
+ })
+ else:
+ new_attr_dict['filter_size'] = attr_dict['_filter_size']
+
+ if self.context.expand:
+ ### first super convolution transpose
+ if idx == first_weight_layer_idx:
+ new_attr_dict['num_channels'] = attr_dict[
+ '_num_channels']
+ else:
+ new_attr_dict[
+ 'num_channels'] = self.context.expand * attr_dict[
+ '_num_channels']
+ ### last super convolution transpose
+ if idx == last_weight_layer_idx:
+ new_attr_dict['num_filters'] = attr_dict['_num_filters']
+ else:
+ new_attr_dict[
+ 'num_filters'] = self.context.expand * attr_dict[
+ '_num_filters']
+ new_attr_dict['candidate_config'].update({
+ 'expand_ratio': self.context.expand_ratio
+ })
+ elif self.context.channel:
+ if attr_dict['_groups'] != None and (
+ int(attr_dict['_groups']) ==
+ int(attr_dict['_num_channels'])):
+ ### depthwise conv_transpose
+ _logger.warn(
+ "If convolution is a depthwise conv_transpose, output channel " \
+ "change to the same channel with input, output channel in search is not used."
+ )
+ cur_channel = pre_channel
+ else:
+ cur_channel = self.context.channel[0]
+ self.context.channel = self.context.channel[1:]
+ if idx == first_weight_layer_idx:
+ new_attr_dict['num_channels'] = attr_dict[
+ '_num_channels']
+ else:
+ new_attr_dict['num_channels'] = max(pre_channel)
+
+ if idx == last_weight_layer_idx:
+ new_attr_dict['num_filters'] = attr_dict['_num_filters']
+ else:
+ new_attr_dict['num_filters'] = max(cur_channel)
+ new_attr_dict['candidate_config'].update({
+ 'channel': cur_channel
+ })
+ pre_channel = cur_channel
+ else:
+ new_attr_dict['num_filters'] = attr_dict['_num_filters']
+ new_attr_dict['num_channels'] = attr_dict['_num_channels']
+
+ for attr in new_attr_name:
+ new_attr_dict[attr[1:]] = attr_dict[attr]
+
+ del layer
+
+ if new_attr_dict['output_size'] == []:
+ new_attr_dict['output_size'] = None
+
+ if attr_dict['_groups'] == None or int(attr_dict[
+ '_groups']) == 1:
+ ### standard conv_transpose
+ layer = Block(
+ SuperConv2DTranspose(**new_attr_dict), key=key)
+ elif int(attr_dict['_groups']) == int(attr_dict[
+ '_num_channels']):
+ # if conv is depthwise conv, groups = in_channel, out_channel = in_channel,
+ # channel in candidate_config = in_channel_list
+ if 'channel' in new_attr_dict['candidate_config']:
+ new_attr_dict['num_channels'] = max(cur_channel)
+ new_attr_dict['num_filters'] = new_attr_dict[
+ 'num_channels']
+ new_attr_dict['candidate_config'][
+ 'channel'] = cur_channel
+ new_attr_dict['groups'] = new_attr_dict['num_channels']
+ layer = Block(
+ SuperDepthwiseConv2DTranspose(**new_attr_dict), key=key)
+ else:
+ ### group conv_transpose
+ layer = Block(
+ SuperGroupConv2DTranspose(**new_attr_dict), key=key)
+ model[idx] = layer
+
+ elif isinstance(layer, Linear) and (
+ getattr(self.context, 'expand', None) != None or
+ getattr(self.context, 'channel', None) != None):
+ attr_dict = layer.__dict__
+ key = attr_dict['_full_name']
+ ### TODO(paddle): add _param_attr and _bias_attr as private variable of Linear
+ #new_attr_name = ['_act', '_dtype', '_param_attr', '_bias_attr']
+ new_attr_name = ['_act', '_dtype']
+ in_nc, out_nc = layer._parameters['weight'].shape
+
+ new_attr_dict = dict()
+ new_attr_dict['candidate_config'] = dict()
+ if self.context.expand:
+ if idx == first_weight_layer_idx:
+ new_attr_dict['input_dim'] = int(in_nc)
+ else:
+ new_attr_dict['input_dim'] = self.context.expand * int(
+ in_nc)
+
+ if idx == last_weight_layer_idx:
+ new_attr_dict['output_dim'] = int(out_nc)
+ else:
+ new_attr_dict['output_dim'] = self.context.expand * int(
+ out_nc)
+ new_attr_dict['candidate_config'].update({
+ 'expand_ratio': self.context.expand_ratio
+ })
+ elif self.context.channel:
+ cur_channel = self.context.channel[0]
+ self.context.channel = self.context.channel[1:]
+ if idx == first_weight_layer_idx:
+ new_attr_dict['input_dim'] = int(in_nc)
+ else:
+ new_attr_dict['input_dim'] = max(pre_channel)
+
+ if idx == last_weight_layer_idx:
+ new_attr_dict['output_dim'] = int(out_nc)
+ else:
+ new_attr_dict['output_dim'] = max(cur_channel)
+ new_attr_dict['candidate_config'].update({
+ 'channel': cur_channel
+ })
+ pre_channel = cur_channel
+ else:
+ new_attr_dict['input_dim'] = int(in_nc)
+ new_attr_dict['output_dim'] = int(out_nc)
+
+ for attr in new_attr_name:
+ new_attr_dict[attr[1:]] = attr_dict[attr]
+
+ del layer, attr_dict
+
+ layer = Block(SuperLinear(**new_attr_dict), key=key)
+ model[idx] = layer
+
+ elif isinstance(layer, InstanceNorm) and (
+ getattr(self.context, 'expand', None) != None or
+ getattr(self.context, 'channel', None) != None):
+ # num_features in InstanceNorm don't change after last weight operators
+ if idx > last_weight_layer_idx:
+ continue
+
+ attr_dict = layer.__dict__
+ new_attr_name = [
+ '_param_attr', '_bias_attr', '_dtype', '_epsilon'
+ ]
+ new_attr_dict = dict()
+ if self.context.expand:
+ new_attr_dict['num_channels'] = self.context.expand * int(
+ layer._parameters['scale'].shape[0])
+ elif self.context.channel:
+ new_attr_dict['num_channels'] = max(cur_channel)
+
+ for attr in new_attr_name:
+ new_attr_dict[attr[1:]] = attr_dict[attr]
+
+ del layer, attr_dict
+
+ layer = SuperInstanceNorm(**new_attr_dict)
+ model[idx] = layer
+
+ return model
+
+
+class supernet:
+ def __init__(self, **kwargs):
+ for key, value in kwargs.items():
+ setattr(self, key, value)
+
+ assert (
+ getattr(self, 'expand_ratio', None) == None or
+ getattr(self, 'channel', None) == None
+ ), "expand_ratio and channel CANNOT be NOT None at the same time."
+
+ self.expand = None
+ if 'expand_ratio' in kwargs.keys():
+ if isinstance(self.expand_ratio, list) or isinstance(
+ self.expand_ratio, tuple):
+ self.expand = max(self.expand_ratio)
+ elif isinstance(self.expand_ratio, int):
+ self.expand = self.expand_ratio
+
+ def __enter__(self):
+ return Convert(self)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+
+#def ofa_supernet(kernel_size, expand_ratio):
+# def _ofa_supernet(func):
+# @functools.wraps(func)
+# def convert(*args, **kwargs):
+# supernet_convert(*args, **kwargs)
+# return convert
+# return _ofa_supernet
diff --git a/paddleslim/nas/ofa/layers.py b/paddleslim/nas/ofa/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d91f5338a8a1f9ee67cc1d7dab2657d85348454
--- /dev/null
+++ b/paddleslim/nas/ofa/layers.py
@@ -0,0 +1,929 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import logging
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+import paddle.fluid.dygraph_utils as dygraph_utils
+from paddle.fluid.data_feeder import check_variable_and_dtype
+from paddle.fluid.framework import in_dygraph_mode, _varbase_creator
+from paddle.fluid.dygraph.nn import InstanceNorm, Conv2D, Conv2DTranspose, BatchNorm
+
+from ...common import get_logger
+from .utils.utils import compute_start_end, get_same_padding, convert_to_list
+
+__all__ = [
+ 'SuperConv2D', 'SuperConv2DTranspose', 'SuperSeparableConv2D',
+ 'SuperBatchNorm', 'SuperLinear', 'SuperInstanceNorm', 'Block',
+ 'SuperGroupConv2D', 'SuperDepthwiseConv2D', 'SuperGroupConv2DTranspose',
+ 'SuperDepthwiseConv2DTranspose'
+]
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+### TODO: if task is elastic width, need to add re_organize_middle_weight in 1x1 conv in MBBlock
+
+_cnt = 0
+
+
+def counter():
+ global _cnt
+ _cnt += 1
+ return _cnt
+
+
+class BaseBlock(fluid.dygraph.Layer):
+ def __init__(self, key=None):
+ super(BaseBlock, self).__init__()
+ if key is not None:
+ self._key = str(key)
+ else:
+ self._key = self.__class__.__name__ + str(counter())
+
+ # set SuperNet class
+ def set_supernet(self, supernet):
+ self.__dict__['supernet'] = supernet
+
+ @property
+ def key(self):
+ return self._key
+
+
+class Block(BaseBlock):
+ """
+ Model is composed of nest blocks.
+
+ Parameters:
+ fn(Layer): instance of super layers, such as: SuperConv2D(3, 5, 3).
+ key(str, optional): key of this layer, one-to-one correspondence between key and candidate config. Default: None.
+ """
+
+ def __init__(self, fn, key=None):
+ super(Block, self).__init__(key)
+ self.fn = fn
+ self.candidate_config = self.fn.candidate_config
+
+ def forward(self, *inputs, **kwargs):
+ out = self.supernet.layers_forward(self, *inputs, **kwargs)
+ return out
+
+
+class SuperConv2D(fluid.dygraph.Conv2D):
+ """
+ This interface is used to construct a callable object of the ``SuperConv2D`` class.
+ The difference between ```SuperConv2D``` and ```Conv2D``` is: ```SuperConv2D``` need
+ to feed a config dictionary with the format of {'channel', num_of_channel} represents
+ the channels of the outputs, used to change the first dimension of weight and bias,
+ only train the first channels of the weight and bias.
+
+ Note: the channel in config need to less than first defined.
+
+ The super convolution2D layer calculates the output based on the input, filter
+ and strides, paddings, dilations, groups parameters. Input and
+ Output are in NCHW format, where N is batch size, C is the number of
+ the feature map, H is the height of the feature map, and W is the width of the feature map.
+ Filter's shape is [MCHW] , where M is the number of output feature map,
+ C is the number of input feature map, H is the height of the filter,
+ and W is the width of the filter. If the groups is greater than 1,
+ C will equal the number of input feature map divided by the groups.
+ Please refer to UFLDL's `convolution
+ `_
+ for more details.
+ If bias attribution and activation type are provided, bias is added to the
+ output of the convolution, and the corresponding activation function is
+ applied to the final result.
+ For each input :math:`X`, the equation is:
+ .. math::
+ Out = \\sigma (W \\ast X + b)
+ Where:
+ * :math:`X`: Input value, a ``Tensor`` with NCHW format.
+ * :math:`W`: Filter value, a ``Tensor`` with shape [MCHW] .
+ * :math:`\\ast`: Convolution operation.
+ * :math:`b`: Bias value, a 2-D ``Tensor`` with shape [M, 1].
+ * :math:`\\sigma`: Activation function.
+ * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
+
+ Example:
+ - Input:
+ Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
+ Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)`
+ - Output:
+ Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
+ Where
+ .. math::
+ H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1 \\\\
+ W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
+ Parameters:
+ num_channels(int): The number of channels in the input image.
+ num_filters(int): The number of filter. It is as same as the output
+ feature map.
+ filter_size (int or tuple): The filter size. If filter_size is a tuple,
+ it must contain two integers, (filter_size_H, filter_size_W).
+ Otherwise, the filter will be a square.
+ candidate_config(dict, optional): Dictionary descripts candidate config of this layer,
+ such as {'kernel_size': (3, 5, 7), 'channel': (4, 6, 8)}, means the kernel size of
+ this layer can be choose from (3, 5, 7), the key of candidate_config
+ only can be 'kernel_size', 'channel' and 'expand_ratio', 'channel' and 'expand_ratio'
+ CANNOT be set at the same time. Default: None.
+ transform_kernel(bool, optional): Whether to use transform matrix to transform a large filter
+ to a small filter. Default: False.
+ stride (int or tuple, optional): The stride size. If stride is a tuple, it must
+ contain two integers, (stride_H, stride_W). Otherwise, the
+ stride_H = stride_W = stride. Default: 1.
+ padding (int or tuple, optional): The padding size. If padding is a tuple, it must
+ contain two integers, (padding_H, padding_W). Otherwise, the
+ padding_H = padding_W = padding. Default: 0.
+ dilation (int or tuple, optional): The dilation size. If dilation is a tuple, it must
+ contain two integers, (dilation_H, dilation_W). Otherwise, the
+ dilation_H = dilation_W = dilation. Default: 1.
+ groups (int, optional): The groups number of the Conv2d Layer. According to grouped
+ convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
+ the first half of the filters is only connected to the first half
+ of the input channels, while the second half of the filters is only
+ connected to the second half of the input channels. Default: 1.
+ param_attr (ParamAttr, optional): The parameter attribute for learnable weights(Parameter)
+ of conv2d. If it is set to None or one attribute of ParamAttr, conv2d
+ will create ParamAttr as param_attr. If the Initializer of the param_attr
+ is not set, the parameter is initialized with :math:`Normal(0.0, std)`,
+ and the :math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
+ bias_attr (ParamAttr or bool, optional): The attribute for the bias of conv2d.
+ If it is set to False, no bias will be added to the output units.
+ If it is set to None or one attribute of ParamAttr, conv2d
+ will create ParamAttr as bias_attr. If the Initializer of the bias_attr
+ is not set, the bias is initialized zero. Default: None.
+ use_cudnn (bool, optional): Use cudnn kernel or not, it is valid only when the cudnn
+ library is installed. Default: True.
+ act (str, optional): Activation type, if it is set to None, activation is not appended.
+ Default: None.
+ dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".
+ Attribute:
+ **weight** (Parameter): the learnable weights of filter of this layer.
+ **bias** (Parameter or None): the learnable bias of this layer.
+ Returns:
+ None
+
+ Raises:
+ ValueError: if ``use_cudnn`` is not a bool value.
+ Examples:
+ .. code-block:: python
+ from paddle.fluid.dygraph.base import to_variable
+ import paddle.fluid as fluid
+ from paddleslim.core.layers import SuperConv2D
+ import numpy as np
+ data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
+ with fluid.dygraph.guard():
+ super_conv2d = SuperConv2D(3, 10, 3)
+ config = {'channel': 5}
+ data = to_variable(data)
+ conv = super_conv2d(data, config)
+
+ """
+
+ ### NOTE: filter_size, num_channels and num_filters must be the max of candidate to define a largest network.
+ def __init__(self,
+ num_channels,
+ num_filters,
+ filter_size,
+ candidate_config={},
+ transform_kernel=False,
+ stride=1,
+ dilation=1,
+ padding=0,
+ groups=None,
+ param_attr=None,
+ bias_attr=None,
+ use_cudnn=True,
+ act=None,
+ dtype='float32'):
+ ### NOTE: padding always is 0, add padding in forward because of kernel size is uncertain
+ ### TODO: change padding to any padding
+ super(SuperConv2D, self).__init__(
+ num_channels, num_filters, filter_size, stride, padding, dilation,
+ groups, param_attr, bias_attr, use_cudnn, act, dtype)
+
+ if isinstance(self._filter_size, int):
+ self._filter_size = convert_to_list(self._filter_size, 2)
+
+ self.candidate_config = candidate_config
+ if len(candidate_config.items()) != 0:
+ for k, v in candidate_config.items():
+ candidate_config[k] = list(set(v))
+
+ self.ks_set = candidate_config[
+ 'kernel_size'] if 'kernel_size' in candidate_config else None
+
+ self.expand_ratio = candidate_config[
+ 'expand_ratio'] if 'expand_ratio' in candidate_config else None
+ self.channel = candidate_config[
+ 'channel'] if 'channel' in candidate_config else None
+ self.base_channel = None
+ if self.expand_ratio != None:
+ self.base_channel = int(self._num_filters / max(self.expand_ratio))
+
+ self.transform_kernel = transform_kernel
+ if self.ks_set != None:
+ self.ks_set.sort()
+ if self.transform_kernel != False:
+ scale_param = dict()
+ ### create parameter to transform kernel
+ for i in range(len(self.ks_set) - 1):
+ ks_small = self.ks_set[i]
+ ks_large = self.ks_set[i + 1]
+ param_name = '%dto%d_matrix' % (ks_large, ks_small)
+ ks_t = ks_small**2
+ scale_param[param_name] = self.create_parameter(
+ attr=fluid.ParamAttr(
+ name=self._full_name + param_name,
+ initializer=fluid.initializer.NumpyArrayInitializer(
+ np.eye(ks_t))),
+ shape=(ks_t, ks_t),
+ dtype=self._dtype)
+
+ for name, param in scale_param.items():
+ setattr(self, name, param)
+
+ def get_active_filter(self, in_nc, out_nc, kernel_size):
+ start, end = compute_start_end(self._filter_size[0], kernel_size)
+ ### if NOT transform kernel, intercept a center filter with kernel_size from largest filter
+ filters = self.weight[:out_nc, :in_nc, start:end, start:end]
+ if self.transform_kernel != False and kernel_size < self._filter_size[
+ 0]:
+ ### if transform kernel, then use matrix to transform
+ start_filter = self.weight[:out_nc, :in_nc, :, :]
+ for i in range(len(self.ks_set) - 1, 0, -1):
+ src_ks = self.ks_set[i]
+ if src_ks <= kernel_size:
+ break
+ target_ks = self.ks_set[i - 1]
+ start, end = compute_start_end(src_ks, target_ks)
+ _input_filter = start_filter[:, :, start:end, start:end]
+ _input_filter = fluid.layers.reshape(
+ _input_filter,
+ shape=[(_input_filter.shape[0] * _input_filter.shape[1]),
+ -1])
+ core.ops.matmul(_input_filter,
+ self.__getattr__('%dto%d_matrix' %
+ (src_ks, target_ks)),
+ _input_filter, 'transpose_X', False,
+ 'transpose_Y', False, "alpha", 1)
+ _input_filter = fluid.layers.reshape(
+ _input_filter,
+ shape=[
+ filters.shape[0], filters.shape[1], target_ks, target_ks
+ ])
+ start_filter = _input_filter
+ filters = start_filter
+ return filters
+
+ def get_groups_in_out_nc(self, in_nc, out_nc):
+ ### standard conv
+ return self._groups, in_nc, out_nc
+
+ def forward(self, input, kernel_size=None, expand_ratio=None, channel=None):
+
+ if not in_dygraph_mode():
+ _logger.error("NOT support static graph")
+
+ in_nc = int(input.shape[1])
+ assert (
+ expand_ratio == None or channel == None
+ ), "expand_ratio and channel CANNOT be NOT None at the same time."
+ if expand_ratio != None:
+ out_nc = int(expand_ratio * self.base_channel)
+ elif channel != None:
+ out_nc = int(channel)
+ else:
+ out_nc = self._num_filters
+ ks = int(self._filter_size[0]) if kernel_size == None else int(
+ kernel_size)
+
+ groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc,
+ out_nc)
+
+ weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks)
+ padding = convert_to_list(get_same_padding(ks), 2)
+
+ if self._l_type == 'conv2d':
+ attrs = ('strides', self._stride, 'paddings', padding, 'dilations',
+ self._dilation, 'groups', groups
+ if groups else 1, 'use_cudnn', self._use_cudnn)
+ out = core.ops.conv2d(input, weight, *attrs)
+ elif self._l_type == 'depthwise_conv2d':
+ attrs = ('strides', self._stride, 'paddings', padding, 'dilations',
+ self._dilation, 'groups', groups
+ if groups else self._groups, 'use_cudnn', self._use_cudnn)
+ out = core.ops.depthwise_conv2d(input, weight, *attrs)
+ else:
+ raise ValueError("conv type error")
+
+ pre_bias = out
+ out_nc = int(pre_bias.shape[1])
+ if self.bias is not None:
+ bias = self.bias[:out_nc]
+ pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, bias, 1)
+ else:
+ pre_act = pre_bias
+
+ return dygraph_utils._append_activation_in_dygraph(pre_act, self._act)
+
+
+class SuperGroupConv2D(SuperConv2D):
+ def get_groups_in_out_nc(self, in_nc, out_nc):
+ ### groups convolution
+ ### conv: weight: (Cout, Cin/G, Kh, Kw)
+ groups = self._groups
+ in_nc = int(in_nc // groups)
+ return groups, in_nc, out_nc
+
+
+class SuperDepthwiseConv2D(SuperConv2D):
+ ### depthwise convolution
+ def get_groups_in_out_nc(self, in_nc, out_nc):
+ if in_nc != out_nc:
+ _logger.debug(
+ "input channel and output channel in depthwise conv is different, change output channel to input channel! origin channel:(in_nc {}, out_nc {}): ".
+ format(in_nc, out_nc))
+ groups = in_nc
+ out_nc = in_nc
+ return groups, in_nc, out_nc
+
+
+class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose):
+ """
+ This interface is used to construct a callable object of the ``SuperConv2DTranspose``
+ class.
+ The difference between ```SuperConv2DTranspose``` and ```Conv2DTranspose``` is:
+ ```SuperConv2DTranspose``` need to feed a config dictionary with the format of
+ {'channel', num_of_channel} represents the channels of the outputs, used to change
+ the first dimension of weight and bias, only train the first channels of the weight
+ and bias.
+
+ Note: the channel in config need to less than first defined.
+
+ The super convolution2D transpose layer calculates the output based on the input,
+ filter, and dilations, strides, paddings. Input and output
+ are in NCHW format. Where N is batch size, C is the number of feature map,
+ H is the height of the feature map, and W is the width of the feature map.
+ Filter's shape is [MCHW] , where M is the number of input feature map,
+ C is the number of output feature map, H is the height of the filter,
+ and W is the width of the filter. If the groups is greater than 1,
+ C will equal the number of input feature map divided by the groups.
+ If bias attribution and activation type are provided, bias is added to
+ the output of the convolution, and the corresponding activation function
+ is applied to the final result.
+ The details of convolution transpose layer, please refer to the following explanation and references
+ `conv2dtranspose `_ .
+ For each input :math:`X`, the equation is:
+ .. math::
+ Out = \sigma (W \\ast X + b)
+ Where:
+ * :math:`X`: Input value, a ``Tensor`` with NCHW format.
+ * :math:`W`: Filter value, a ``Tensor`` with shape [MCHW] .
+ * :math:`\\ast`: Convolution operation.
+ * :math:`b`: Bias value, a 2-D ``Tensor`` with shape [M, 1].
+ * :math:`\\sigma`: Activation function.
+ * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
+ Example:
+ - Input:
+ Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
+ Filter shape: :math:`(C_{in}, C_{out}, H_f, W_f)`
+ - Output:
+ Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
+ Where
+ .. math::
+ H^\prime_{out} &= (H_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (H_f - 1) + 1 \\\\
+ W^\prime_{out} &= (W_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (W_f - 1) + 1 \\\\
+ H_{out} &\in [ H^\prime_{out}, H^\prime_{out} + strides[0] ) \\\\
+ W_{out} &\in [ W^\prime_{out}, W^\prime_{out} + strides[1] )
+ Parameters:
+ num_channels(int): The number of channels in the input image.
+ num_filters(int): The number of the filter. It is as same as the output
+ feature map.
+ filter_size(int or tuple): The filter size. If filter_size is a tuple,
+ it must contain two integers, (filter_size_H, filter_size_W).
+ Otherwise, the filter will be a square.
+ candidate_config(dict, optional): Dictionary descripts candidate config of this layer,
+ such as {'kernel_size': (3, 5, 7), 'channel': (4, 6, 8)}, means the kernel size of
+ this layer can be choose from (3, 5, 7), the key of candidate_config
+ only can be 'kernel_size', 'channel' and 'expand_ratio', 'channel' and 'expand_ratio'
+ CANNOT be set at the same time. Default: None.
+ transform_kernel(bool, optional): Whether to use transform matrix to transform a large filter
+ to a small filter. Default: False.
+ output_size(int or tuple, optional): The output image size. If output size is a
+ tuple, it must contain two integers, (image_H, image_W). None if use
+ filter_size, padding, and stride to calculate output_size.
+ if output_size and filter_size are specified at the same time, They
+ should follow the formula above. Default: None.
+ padding(int or tuple, optional): The padding size. If padding is a tuple, it must
+ contain two integers, (padding_H, padding_W). Otherwise, the
+ padding_H = padding_W = padding. Default: 0.
+ stride(int or tuple, optional): The stride size. If stride is a tuple, it must
+ contain two integers, (stride_H, stride_W). Otherwise, the
+ stride_H = stride_W = stride. Default: 1.
+ dilation(int or tuple, optional): The dilation size. If dilation is a tuple, it must
+ contain two integers, (dilation_H, dilation_W). Otherwise, the
+ dilation_H = dilation_W = dilation. Default: 1.
+ groups(int, optional): The groups number of the Conv2d transpose layer. Inspired by
+ grouped convolution in Alex Krizhevsky's Deep CNN paper, in which
+ when group=2, the first half of the filters is only connected to the
+ first half of the input channels, while the second half of the
+ filters is only connected to the second half of the input channels.
+ Default: 1.
+ param_attr (ParamAttr, optional): The parameter attribute for learnable weights(Parameter)
+ of conv2d_transpose. If it is set to None or one attribute of ParamAttr, conv2d_transpose
+ will create ParamAttr as param_attr. If the Initializer of the param_attr
+ is not set, the parameter is initialized with Xavier. Default: None.
+ bias_attr (ParamAttr or bool, optional): The attribute for the bias of conv2d_transpose.
+ If it is set to False, no bias will be added to the output units.
+ If it is set to None or one attribute of ParamAttr, conv2d_transpose
+ will create ParamAttr as bias_attr. If the Initializer of the bias_attr
+ is not set, the bias is initialized zero. Default: None.
+ use_cudnn(bool, optional): Use cudnn kernel or not, it is valid only when the cudnn
+ library is installed. Default: True.
+ act (str, optional): Activation type, if it is set to None, activation is not appended.
+ Default: None.
+ dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".
+ Attribute:
+ **weight** (Parameter): the learnable weights of filters of this layer.
+ **bias** (Parameter or None): the learnable bias of this layer.
+ Returns:
+ None
+ Examples:
+ .. code-block:: python
+ import paddle.fluid as fluid
+ from paddleslim.core.layers import SuperConv2DTranspose
+ import numpy as np
+ with fluid.dygraph.guard():
+ data = np.random.random((3, 32, 32, 5)).astype('float32')
+ config = {'channel': 5
+ super_convtranspose = SuperConv2DTranspose(num_channels=32, num_filters=10, filter_size=3)
+ ret = super_convtranspose(fluid.dygraph.base.to_variable(data), config)
+ """
+
+ def __init__(self,
+ num_channels,
+ num_filters,
+ filter_size,
+ output_size=None,
+ candidate_config={},
+ transform_kernel=False,
+ stride=1,
+ dilation=1,
+ padding=0,
+ groups=None,
+ param_attr=None,
+ bias_attr=None,
+ use_cudnn=True,
+ act=None,
+ dtype='float32'):
+ ### NOTE: padding always is 0, add padding in forward because of kernel size is uncertain
+ super(SuperConv2DTranspose, self).__init__(
+ num_channels, num_filters, filter_size, output_size, padding,
+ stride, dilation, groups, param_attr, bias_attr, use_cudnn, act,
+ dtype)
+ self.candidate_config = candidate_config
+ if len(self.candidate_config.items()) != 0:
+ for k, v in candidate_config.items():
+ candidate_config[k] = list(set(v))
+ self.ks_set = candidate_config[
+ 'kernel_size'] if 'kernel_size' in candidate_config else None
+
+ if isinstance(self._filter_size, int):
+ self._filter_size = convert_to_list(self._filter_size, 2)
+
+ self.expand_ratio = candidate_config[
+ 'expand_ratio'] if 'expand_ratio' in candidate_config else None
+ self.channel = candidate_config[
+ 'channel'] if 'channel' in candidate_config else None
+ self.base_channel = None
+ if self.expand_ratio:
+ self.base_channel = int(self._num_filters / max(self.expand_ratio))
+
+ self.transform_kernel = transform_kernel
+ if self.ks_set != None:
+ self.ks_set.sort()
+ if self.transform_kernel != False:
+ scale_param = dict()
+ ### create parameter to transform kernel
+ for i in range(len(self.ks_set) - 1):
+ ks_small = self.ks_set[i]
+ ks_large = self.ks_set[i + 1]
+ param_name = '%dto%d_matrix' % (ks_large, ks_small)
+ ks_t = ks_small**2
+ scale_param[param_name] = self.create_parameter(
+ attr=fluid.ParamAttr(
+ name=self._full_name + param_name,
+ initializer=fluid.initializer.NumpyArrayInitializer(
+ np.eye(ks_t))),
+ shape=(ks_t, ks_t),
+ dtype=self._dtype)
+
+ for name, param in scale_param.items():
+ setattr(self, name, param)
+
+ def get_active_filter(self, in_nc, out_nc, kernel_size):
+ start, end = compute_start_end(self._filter_size[0], kernel_size)
+ filters = self.weight[:in_nc, :out_nc, start:end, start:end]
+ if self.transform_kernel != False and kernel_size < self._filter_size[
+ 0]:
+ start_filter = self.weight[:in_nc, :out_nc, :, :]
+ for i in range(len(self.ks_set) - 1, 0, -1):
+ src_ks = self.ks_set[i]
+ if src_ks <= kernel_size:
+ break
+ target_ks = self.ks_set[i - 1]
+ start, end = compute_start_end(src_ks, target_ks)
+ _input_filter = start_filter[:, :, start:end, start:end]
+ _input_filter = fluid.layers.reshape(
+ _input_filter,
+ shape=[(_input_filter.shape[0] * _input_filter.shape[1]),
+ -1])
+ core.ops.matmul(_input_filter,
+ self.__getattr__('%dto%d_matrix' %
+ (src_ks, target_ks)),
+ _input_filter, 'transpose_X', False,
+ 'transpose_Y', False, "alpha", 1)
+ _input_filter = fluid.layers.reshape(
+ _input_filter,
+ shape=[
+ filters.shape[0], filters.shape[1], target_ks, target_ks
+ ])
+ start_filter = _input_filter
+ filters = start_filter
+ return filters
+
+ def get_groups_in_out_nc(self, in_nc, out_nc):
+ ### standard conv
+ return self._groups, in_nc, out_nc
+
+ def forward(self, input, kernel_size=None, expand_ratio=None, channel=None):
+ if not in_dygraph_mode():
+ _logger.error("NOT support static graph")
+
+ in_nc = int(input.shape[1])
+ assert (
+ expand_ratio == None or channel == None
+ ), "expand_ratio and channel CANNOT be NOT None at the same time."
+ if expand_ratio != None:
+ out_nc = int(expand_ratio * self.base_channel)
+ elif channel != None:
+ out_nc = int(channel)
+ else:
+ out_nc = self._num_filters
+
+ ks = int(self._filter_size[0]) if kernel_size == None else int(
+ kernel_size)
+
+ groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc,
+ out_nc)
+
+ weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks)
+ padding = convert_to_list(get_same_padding(ks), 2)
+
+ op = getattr(core.ops, self._op_type)
+ out = op(input, weight, 'output_size', self._output_size, 'strides',
+ self._stride, 'paddings', padding, 'dilations', self._dilation,
+ 'groups', groups, 'use_cudnn', self._use_cudnn)
+ pre_bias = out
+ out_nc = int(pre_bias.shape[1])
+ if self.bias is not None:
+ bias = self.bias[:out_nc]
+ pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, bias, 1)
+ else:
+ pre_act = pre_bias
+
+ return dygraph_utils._append_activation_in_dygraph(
+ pre_act, act=self._act)
+
+
+class SuperGroupConv2DTranspose(SuperConv2DTranspose):
+ def get_groups_in_out_nc(self, in_nc, out_nc):
+ ### groups convolution
+ ### groups conv transpose: weight: (Cin, Cout/G, Kh, Kw)
+ groups = self._groups
+ out_nc = int(out_nc // groups)
+ return groups, in_nc, out_nc
+
+
+class SuperDepthwiseConv2DTranspose(SuperConv2DTranspose):
+ def get_groups_in_out_nc(self, in_nc, out_nc):
+ if in_nc != out_nc:
+ _logger.debug(
+ "input channel and output channel in depthwise conv transpose is different, change output channel to input channel! origin channel:(in_nc {}, out_nc {}): ".
+ format(in_nc, out_nc))
+ groups = in_nc
+ out_nc = in_nc
+ return groups, in_nc, out_nc
+
+
+### NOTE: only search channel, write for GAN-compression, maybe change to SuperDepthwiseConv and SuperConv after.
+class SuperSeparableConv2D(fluid.dygraph.Layer):
+ """
+ This interface is used to construct a callable object of the ``SuperSeparableConv2D``
+ class.
+ The difference between ```SuperSeparableConv2D``` and ```SeparableConv2D``` is:
+ ```SuperSeparableConv2D``` need to feed a config dictionary with the format of
+ {'channel', num_of_channel} represents the channels of the first conv's outputs and
+ the second conv's inputs, used to change the first dimension of weight and bias,
+ only train the first channels of the weight and bias.
+
+ The architecture of super separable convolution2D op is [Conv2D, norm layer(may be BatchNorm
+ or InstanceNorm), Conv2D]. The first conv is depthwise conv, the filter number is input channel
+ multiply scale_factor, the group is equal to the number of input channel. The second conv
+ is standard conv, which filter size and stride size are 1.
+
+ Parameters:
+ num_channels(int): The number of channels in the input image.
+ num_filters(int): The number of the second conv's filter. It is as same as the output
+ feature map.
+ filter_size(int or tuple): The first conv's filter size. If filter_size is a tuple,
+ it must contain two integers, (filter_size_H, filter_size_W).
+ Otherwise, the filter will be a square.
+ padding(int or tuple, optional): The first conv's padding size. If padding is a tuple,
+ it must contain two integers, (padding_H, padding_W). Otherwise, the
+ padding_H = padding_W = padding. Default: 0.
+ stride(int or tuple, optional): The first conv's stride size. If stride is a tuple,
+ it must contain two integers, (stride_H, stride_W). Otherwise, the
+ stride_H = stride_W = stride. Default: 1.
+ dilation(int or tuple, optional): The first conv's dilation size. If dilation is a tuple,
+ it must contain two integers, (dilation_H, dilation_W). Otherwise, the
+ dilation_H = dilation_W = dilation. Default: 1.
+ norm_layer(class): The normalization layer between two convolution. Default: InstanceNorm.
+ bias_attr (ParamAttr or bool, optional): The attribute for the bias of convolution.
+ If it is set to False, no bias will be added to the output units.
+ If it is set to None or one attribute of ParamAttr, convolution
+ will create ParamAttr as bias_attr. If the Initializer of the bias_attr
+ is not set, the bias is initialized zero. Default: None.
+ scale_factor(float): The scale factor of the first conv's output channel. Default: 1.
+ use_cudnn(bool, optional): Use cudnn kernel or not, it is valid only when the cudnn
+ library is installed. Default: True.
+ Returns:
+ None
+ """
+
+ def __init__(self,
+ num_channels,
+ num_filters,
+ filter_size,
+ candidate_config={},
+ stride=1,
+ padding=0,
+ dilation=1,
+ norm_layer=InstanceNorm,
+ bias_attr=None,
+ scale_factor=1,
+ use_cudnn=False):
+ super(SuperSeparableConv2D, self).__init__()
+ self.conv = fluid.dygraph.LayerList([
+ fluid.dygraph.nn.Conv2D(
+ num_channels=num_channels,
+ num_filters=num_channels * scale_factor,
+ filter_size=filter_size,
+ stride=stride,
+ padding=padding,
+ use_cudnn=False,
+ groups=num_channels,
+ bias_attr=bias_attr)
+ ])
+
+ self.conv.extend([norm_layer(num_channels * scale_factor)])
+
+ self.conv.extend([
+ Conv2D(
+ num_channels=num_channels * scale_factor,
+ num_filters=num_filters,
+ filter_size=1,
+ stride=1,
+ use_cudnn=use_cudnn,
+ bias_attr=bias_attr)
+ ])
+
+ self.candidate_config = candidate_config
+ self.expand_ratio = candidate_config[
+ 'expand_ratio'] if 'expand_ratio' in candidate_config else None
+ self.base_output_dim = None
+ if self.expand_ratio != None:
+ self.base_output_dim = int(self.output_dim / max(self.expand_ratio))
+
+ def forward(self, input, expand_ratio=None, channel=None):
+ if not in_dygraph_mode():
+ _logger.error("NOT support static graph")
+
+ in_nc = int(input.shape[1])
+ assert (
+ expand_ratio == None or channel == None
+ ), "expand_ratio and channel CANNOT be NOT None at the same time."
+ if expand_ratio != None:
+ out_nc = int(expand_ratio * self.base_output_dim)
+ elif channel != None:
+ out_nc = int(channel)
+ else:
+ out_nc = self.conv[0]._num_filters
+
+ weight = self.conv[0].weight[:in_nc]
+ ### conv1
+ if self.conv[0]._l_type == 'conv2d':
+ attrs = ('strides', self.conv[0]._stride, 'paddings',
+ self.conv[0]._padding, 'dilations', self.conv[0]._dilation,
+ 'groups', in_nc, 'use_cudnn', self.conv[0]._use_cudnn)
+ out = core.ops.conv2d(input, weight, *attrs)
+ elif self.conv[0]._l_type == 'depthwise_conv2d':
+ attrs = ('strides', self.conv[0]._stride, 'paddings',
+ self.conv[0]._padding, 'dilations', self.conv[0]._dilation,
+ 'groups', in_nc, 'use_cudnn', self.conv[0]._use_cudnn)
+ out = core.ops.depthwise_conv2d(input, weight, *attrs)
+ else:
+ raise ValueError("conv type error")
+
+ pre_bias = out
+ if self.conv[0].bias is not None:
+ bias = self.conv[0].bias[:in_nc]
+ pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, bias, 1)
+ else:
+ pre_act = pre_bias
+
+ conv0_out = dygraph_utils._append_activation_in_dygraph(
+ pre_act, self.conv[0]._act)
+
+ norm_out = self.conv[1](conv0_out)
+
+ weight = self.conv[2].weight[:out_nc, :in_nc, :, :]
+
+ if self.conv[2]._l_type == 'conv2d':
+ attrs = ('strides', self.conv[2]._stride, 'paddings',
+ self.conv[2]._padding, 'dilations', self.conv[2]._dilation,
+ 'groups', self.conv[2]._groups if self.conv[2]._groups else
+ 1, 'use_cudnn', self.conv[2]._use_cudnn)
+ out = core.ops.conv2d(norm_out, weight, *attrs)
+ elif self.conv[2]._l_type == 'depthwise_conv2d':
+ attrs = ('strides', self.conv[2]._stride, 'paddings',
+ self.conv[2]._padding, 'dilations', self.conv[2]._dilation,
+ 'groups', self.conv[2]._groups, 'use_cudnn',
+ self.conv[2]._use_cudnn)
+ out = core.ops.depthwise_conv2d(norm_out, weight, *attrs)
+ else:
+ raise ValueError("conv type error")
+
+ pre_bias = out
+ if self.conv[2].bias is not None:
+ bias = self.conv[2].bias[:out_nc]
+ pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, bias, 1)
+ else:
+ pre_act = pre_bias
+
+ conv1_out = dygraph_utils._append_activation_in_dygraph(
+ pre_act, self.conv[2]._act)
+
+ return conv1_out
+
+
+class SuperLinear(fluid.dygraph.Linear):
+ """
+ """
+
+ def __init__(self,
+ input_dim,
+ output_dim,
+ candidate_config={},
+ param_attr=None,
+ bias_attr=None,
+ act=None,
+ dtype="float32"):
+ super(SuperLinear, self).__init__(input_dim, output_dim, param_attr,
+ bias_attr, act, dtype)
+ self._param_attr = param_attr
+ self._bias_attr = bias_attr
+ self.output_dim = output_dim
+ self.candidate_config = candidate_config
+ self.expand_ratio = candidate_config[
+ 'expand_ratio'] if 'expand_ratio' in candidate_config else None
+ self.base_output_dim = None
+ if self.expand_ratio != None:
+ self.base_output_dim = int(self.output_dim / max(self.expand_ratio))
+
+ def forward(self, input, expand_ratio=None, channel=None):
+ if not in_dygraph_mode():
+ _logger.error("NOT support static graph")
+
+ ### weight: (Cin, Cout)
+ in_nc = int(input.shape[1])
+ assert (
+ expand_ratio == None or channel == None
+ ), "expand_ratio and channel CANNOT be NOT None at the same time."
+ if expand_ratio != None:
+ out_nc = int(expand_ratio * self.base_output_dim)
+ elif channel != None:
+ out_nc = int(channel)
+ else:
+ out_nc = self.output_dim
+
+ weight = self.weight[:in_nc, :out_nc]
+ if self._bias_attr != False:
+ bias = self.bias[:out_nc]
+ use_bias = True
+
+ pre_bias = _varbase_creator(dtype=input.dtype)
+ core.ops.matmul(input, weight, pre_bias, 'transpose_X', False,
+ 'transpose_Y', False, "alpha", 1)
+ if self._bias_attr != False:
+ pre_act = dygraph_utils._append_bias_in_dygraph(
+ pre_bias, bias, axis=len(input.shape) - 1)
+ else:
+ pre_act = pre_bias
+
+ return dygraph_utils._append_activation_in_dygraph(pre_act, self._act)
+
+
+class SuperBatchNorm(fluid.dygraph.BatchNorm):
+ """
+ add comment
+ """
+
+ def __init__(self,
+ num_channels,
+ act=None,
+ is_test=False,
+ momentum=0.9,
+ epsilon=1e-05,
+ param_attr=None,
+ bias_attr=None,
+ dtype='float32',
+ data_layout='NCHW',
+ in_place=False,
+ moving_mean_name=None,
+ moving_variance_name=None,
+ do_model_average_for_mean_and_var=True,
+ use_global_stats=False,
+ trainable_statistics=False):
+ super(SuperBatchNorm, self).__init__(
+ num_channels, act, is_test, momentum, epsilon, param_attr,
+ bias_attr, dtype, data_layout, in_place, moving_mean_name,
+ moving_variance_name, do_model_average_for_mean_and_var,
+ use_global_stats, trainable_statistics)
+
+ def forward(self, input):
+ if not in_dygraph_mode():
+ _logger.error("NOT support static graph")
+
+ feature_dim = int(input.shape[1])
+
+ weight = self.weight[:feature_dim]
+ bias = self.bias[:feature_dim]
+ mean = self._mean[:feature_dim]
+ variance = self._variance[:feature_dim]
+
+ mean_out = mean
+ variance_out = variance
+
+ attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
+ "is_test", not self.training, "data_layout", self._data_layout,
+ "use_mkldnn", False, "fuse_with_relu", self._fuse_with_relu,
+ "use_global_stats", self._use_global_stats,
+ 'trainable_statistics', self._trainable_statistics)
+ batch_norm_out, _, _, _, _, _ = core.ops.batch_norm(
+ input, weight, bias, mean, variance, mean_out, variance_out, *attrs)
+ return dygraph_utils._append_activation_in_dygraph(
+ batch_norm_out, act=self._act)
+
+
+class SuperInstanceNorm(fluid.dygraph.InstanceNorm):
+ """
+ """
+
+ def __init__(self,
+ num_channels,
+ epsilon=1e-05,
+ param_attr=None,
+ bias_attr=None,
+ dtype='float32'):
+ super(SuperInstanceNorm, self).__init__(num_channels, epsilon,
+ param_attr, bias_attr, dtype)
+
+ def forward(self, input):
+ if not in_dygraph_mode():
+ _logger.error("NOT support static graph")
+
+ feature_dim = int(input.shape[1])
+
+ if self._param_attr == False and self._bias_attr == False:
+ scale = None
+ bias = None
+ else:
+ scale = self.scale[:feature_dim]
+ bias = self.bias[:feature_dim]
+
+ out, _, _ = core.ops.instance_norm(input, scale, bias, 'epsilon',
+ self._epsilon)
+ return out
diff --git a/paddleslim/nas/ofa/ofa.py b/paddleslim/nas/ofa/ofa.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fd7f5ada5d0f59eabd9ac580b9453f183bd78f1
--- /dev/null
+++ b/paddleslim/nas/ofa/ofa.py
@@ -0,0 +1,319 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import numpy as np
+from collections import namedtuple
+import paddle
+import paddle.nn as nn
+import paddle.fluid as fluid
+from paddle.fluid.dygraph import Conv2D
+from .layers import BaseBlock, Block, SuperConv2D, SuperBatchNorm
+from .utils.utils import search_idx
+from ...common import get_logger
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+__all__ = ['OFA', 'RunConfig', 'DistillConfig']
+
+RunConfig = namedtuple('RunConfig', [
+ 'train_batch_size', 'eval_batch_size', 'n_epochs', 'save_frequency',
+ 'eval_frequency', 'init_learning_rate', 'total_images', 'elastic_depth',
+ 'dynamic_batch_size'
+])
+RunConfig.__new__.__defaults__ = (None, ) * len(RunConfig._fields)
+
+DistillConfig = namedtuple('DistillConfig', [
+ 'lambda_distill', 'teacher_model', 'mapping_layers', 'teacher_model_path',
+ 'distill_fn'
+])
+DistillConfig.__new__.__defaults__ = (None, ) * len(DistillConfig._fields)
+
+
+class OFABase(fluid.dygraph.Layer):
+ def __init__(self, model):
+ super(OFABase, self).__init__()
+ self.model = model
+ self._layers, self._elastic_task = self.get_layers()
+
+ def get_layers(self):
+ layers = dict()
+ elastic_task = set()
+ for name, sublayer in self.model.named_sublayers():
+ if isinstance(sublayer, BaseBlock):
+ sublayer.set_supernet(self)
+ layers[sublayer.key] = sublayer.candidate_config
+ for k in sublayer.candidate_config.keys():
+ elastic_task.add(k)
+ return layers, elastic_task
+
+ def forward(self, *inputs, **kwargs):
+ raise NotImplementedError
+
+ # NOTE: config means set forward config for layers, used in distill.
+ def layers_forward(self, block, *inputs, **kwargs):
+ if getattr(self, 'current_config', None) != None:
+ assert block.key in self.current_config, 'DONNT have {} layer in config.'.format(
+ block.key)
+ config = self.current_config[block.key]
+ else:
+ config = dict()
+ logging.debug(self.model, config)
+
+ return block.fn(*inputs, **config)
+
+ @property
+ def layers(self):
+ return self._layers
+
+
+class OFA(OFABase):
+ def __init__(self,
+ model,
+ run_config,
+ net_config=None,
+ distill_config=None,
+ elastic_order=None,
+ train_full=False):
+ super(OFA, self).__init__(model)
+ self.net_config = net_config
+ self.run_config = run_config
+ self.distill_config = distill_config
+ self.elastic_order = elastic_order
+ self.train_full = train_full
+ self.iter_per_epochs = self.run_config.total_images // self.run_config.train_batch_size
+ self.iter = 0
+ self.dynamic_iter = 0
+ self.manual_set_task = False
+ self.task_idx = 0
+ self._add_teacher = False
+ self.netAs_param = []
+
+ for idx in range(len(run_config.n_epochs)):
+ assert isinstance(
+ run_config.init_learning_rate[idx],
+ list), "each candidate in init_learning_rate must be list"
+ assert isinstance(run_config.n_epochs[idx],
+ list), "each candidate in n_epochs must be list"
+
+ ### if elastic_order is none, use default order
+ if self.elastic_order is not None:
+ assert isinstance(self.elastic_order,
+ list), 'elastic_order must be a list'
+
+ if self.elastic_order is None:
+ self.elastic_order = []
+ # zero, elastic resulotion, write in demo
+ # first, elastic kernel size
+ if 'kernel_size' in self._elastic_task:
+ self.elastic_order.append('kernel_size')
+
+ # second, elastic depth, such as: list(2, 3, 4)
+ if getattr(self.run_config, 'elastic_depth', None) != None:
+ depth_list = list(set(self.run_config.elastic_depth))
+ depth_list.sort()
+ self.layers['depth'] = depth_list
+ self.elastic_order.append('depth')
+
+ # final, elastic width
+ if 'expand_ratio' in self._elastic_task:
+ self.elastic_order.append('width')
+
+ if 'channel' in self._elastic_task and 'width' not in self.elastic_order:
+ self.elastic_order.append('width')
+
+ assert len(self.run_config.n_epochs) == len(self.elastic_order)
+ assert len(self.run_config.n_epochs) == len(
+ self.run_config.dynamic_batch_size)
+ assert len(self.run_config.n_epochs) == len(
+ self.run_config.init_learning_rate)
+
+ ### ================= add distill prepare ======================
+ if self.distill_config != None and (
+ self.distill_config.lambda_distill != None and
+ self.distill_config.lambda_distill > 0):
+ self._add_teacher = True
+ self._prepare_distill()
+
+ self.model.train()
+
+ def _prepare_distill(self):
+ self.Tacts, self.Sacts = {}, {}
+
+ if self.distill_config.teacher_model == None:
+ logging.error(
+ 'If you want to add distill, please input class of teacher model'
+ )
+
+ assert isinstance(self.distill_config.teacher_model,
+ paddle.fluid.dygraph.Layer)
+
+ # load teacher parameter
+ if self.distill_config.teacher_model_path != None:
+ param_state_dict, _ = paddle.load_dygraph(
+ self.distill_config.teacher_model_path)
+ self.distill_config.teacher_model.set_dict(param_state_dict)
+
+ self.ofa_teacher_model = OFABase(self.distill_config.teacher_model)
+ self.ofa_teacher_model.model.eval()
+
+ # add hook if mapping layers is not None
+ # if mapping layer is None, return the output of the teacher model,
+ # if mapping layer is NOT None, add hook and compute distill loss about mapping layers.
+ mapping_layers = self.distill_config.mapping_layers
+ if mapping_layers != None:
+ self.netAs = []
+ for name, sublayer in self.model.named_sublayers():
+ if name in mapping_layers:
+ netA = SuperConv2D(
+ sublayer._num_filters,
+ sublayer._num_filters,
+ filter_size=1)
+ self.netAs_param.extend(netA.parameters())
+ self.netAs.append(netA)
+
+ def get_activation(mem, name):
+ def get_output_hook(layer, input, output):
+ mem[name] = output
+
+ return get_output_hook
+
+ def add_hook(net, mem, mapping_layers):
+ for idx, (n, m) in enumerate(net.named_sublayers()):
+ if n in mapping_layers:
+ m.register_forward_post_hook(get_activation(mem, n))
+
+ add_hook(self.model, self.Sacts, mapping_layers)
+ add_hook(self.ofa_teacher_model.model, self.Tacts, mapping_layers)
+
+ def _compute_epochs(self):
+ if getattr(self, 'epoch', None) == None:
+ epoch = self.iter // self.iter_per_epochs
+ else:
+ epoch = self.epochs
+ return epoch
+
+ def _sample_from_nestdict(self, cands, sample_type, task, phase):
+ sample_cands = dict()
+ for k, v in cands.items():
+ if isinstance(v, dict):
+ sample_cands[k] = self._sample_from_nestdict(
+ v, sample_type=sample_type, task=task, phase=phase)
+ elif isinstance(v, list) or isinstance(v, set) or isinstance(v,
+ tuple):
+ if sample_type == 'largest':
+ sample_cands[k] = v[-1]
+ elif sample_type == 'smallest':
+ sample_cands[k] = v[0]
+ else:
+ if k not in task:
+ # sort and deduplication in candidate_config
+ # fixed candidate not in task_list
+ sample_cands[k] = v[-1]
+ else:
+ # phase == None -> all candidate; phase == number, append small candidate in each phase
+ # phase only affect last task in current task_list
+ if phase != None and k == task[-1]:
+ start = -(phase + 2)
+ else:
+ start = 0
+ sample_cands[k] = np.random.choice(v[start:])
+
+ return sample_cands
+
+ def _sample_config(self, task, sample_type='random', phase=None):
+ config = self._sample_from_nestdict(
+ self.layers, sample_type=sample_type, task=task, phase=phase)
+ return config
+
+ def set_task(self, task=None, phase=None):
+ self.manual_set_task = True
+ self.task = task
+ self.phase = phase
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+ def _progressive_shrinking(self):
+ epoch = self._compute_epochs()
+ self.task_idx, phase_idx = search_idx(epoch, self.run_config.n_epochs)
+ self.task = self.elastic_order[:self.task_idx + 1]
+ if 'width' in self.task:
+ ### change width in task to concrete config
+ self.task.remove('width')
+ if 'expand_ratio' in self._elastic_task:
+ self.task.append('expand_ratio')
+ if 'channel' in self._elastic_task:
+ self.task.append('channel')
+ if len(self.run_config.n_epochs[self.task_idx]) == 1:
+ phase_idx = None
+ return self._sample_config(task=self.task, phase=phase_idx)
+
+ def calc_distill_loss(self):
+ losses = []
+ assert len(self.netAs) > 0
+ for i, netA in enumerate(self.netAs):
+ assert isinstance(netA, SuperConv2D)
+ n = self.distill_config.mapping_layers[i]
+ Tact = self.Tacts[n]
+ Sact = self.Sacts[n]
+ Sact = netA(Sact, channel=netA._num_filters)
+ if self.distill_config.distill_fn == None:
+ loss = fluid.layers.mse_loss(Sact, Tact)
+ else:
+ loss = distill_fn(Sact, Tact)
+ losses.append(loss)
+ return sum(losses) * self.distill_config.lambda_distill
+
+ ### TODO: complete it
+ def search(self, eval_func, condition):
+ pass
+
+ ### TODO: complete it
+ def export(self, config):
+ pass
+
+ def forward(self, *inputs, **kwargs):
+ # ===================== teacher process =====================
+ teacher_output = None
+ if self._add_teacher:
+ teacher_output = self.ofa_teacher_model.model.forward(*inputs,
+ **kwargs)
+ # ============================================================
+
+ # ==================== student process =====================
+ self.dynamic_iter += 1
+ if self.dynamic_iter == self.run_config.dynamic_batch_size[
+ self.task_idx]:
+ self.iter += 1
+ self.dynamic_iter = 0
+
+ if self.net_config == None:
+ if self.train_full == True:
+ self.current_config = self._sample_config(
+ task=None, sample_type='largest')
+ else:
+ if self.manual_set_task == False:
+ self.current_config = self._progressive_shrinking()
+ else:
+ self.current_config = self._sample_config(
+ self.task, phase=self.phase)
+ else:
+ self.current_config = self.net_config
+
+ _logger.debug("Current config is {}".format(self.current_config))
+ if 'depth' in self.current_config:
+ kwargs['depth'] = int(self.current_config['depth'])
+
+ return self.model.forward(*inputs, **kwargs), teacher_output
diff --git a/paddleslim/nas/ofa/utils/__init__.py b/paddleslim/nas/ofa/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..342ae0eddcff168fb62bb08708af868dbc808aa5
--- /dev/null
+++ b/paddleslim/nas/ofa/utils/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .utils import *
diff --git a/paddleslim/nas/ofa/utils/utils.py b/paddleslim/nas/ofa/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fad016a754f61df9c72c04956901d978db0b6df6
--- /dev/null
+++ b/paddleslim/nas/ofa/utils/utils.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def compute_start_end(kernel_size, sub_kernel_size):
+ center = kernel_size // 2
+ sub_center = sub_kernel_size // 2
+ start = center - sub_center
+ end = center + sub_center + 1
+ assert end - start == sub_kernel_size
+ return start, end
+
+
+def get_same_padding(kernel_size):
+ assert isinstance(kernel_size, int)
+ assert kernel_size % 2 > 0, "kernel size must be odd number"
+ return kernel_size // 2
+
+
+def convert_to_list(value, n):
+ return [value, ] * n
+
+
+def search_idx(num, sorted_nestlist):
+ max_num = -1
+ max_idx = -1
+ for idx in range(len(sorted_nestlist)):
+ task_ = sorted_nestlist[idx]
+ max_num = task_[-1]
+ max_idx = len(task_) - 1
+ for phase_idx in range(len(task_)):
+ if num <= task_[phase_idx]:
+ return idx, phase_idx
+ assert num > max_num
+ return len(sorted_nestlist) - 1, max_idx
diff --git a/tests/test_ofa.py b/tests/test_ofa.py
new file mode 100644
index 0000000000000000000000000000000000000000..b65d12e74a6f9ece7866db8468f7e8a1337e485c
--- /dev/null
+++ b/tests/test_ofa.py
@@ -0,0 +1,216 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+sys.path.append("../")
+import numpy as np
+import unittest
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.dygraph.nn as nn
+from paddle.nn import ReLU
+from paddleslim.nas import ofa
+from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
+from paddleslim.nas.ofa.convert_super import supernet
+from paddleslim.nas.ofa.layers import Block, SuperSeparableConv2D
+
+
+class ModelConv(fluid.dygraph.Layer):
+ def __init__(self):
+ super(ModelConv, self).__init__()
+ with supernet(
+ kernel_size=(3, 5, 7),
+ channel=((4, 8, 12), (8, 12, 16), (8, 12, 16),
+ (8, 12, 16))) as ofa_super:
+ models = []
+ models += [nn.Conv2D(3, 4, 3)]
+ models += [nn.InstanceNorm(4)]
+ models += [ReLU()]
+ models += [nn.Conv2D(4, 4, 3, groups=4)]
+ models += [nn.InstanceNorm(4)]
+ models += [ReLU()]
+ models += [nn.Conv2DTranspose(4, 4, 3, groups=4, use_cudnn=True)]
+ models += [nn.BatchNorm(4)]
+ models += [ReLU()]
+ models += [nn.Conv2D(4, 3, 3)]
+ models += [ReLU()]
+ models = ofa_super.convert(models)
+
+ models += [
+ Block(
+ SuperSeparableConv2D(
+ 3, 6, 1, candidate_config={'channel': (3, 6)}))
+ ]
+ with supernet(
+ kernel_size=(3, 5, 7), expand_ratio=(1, 2, 4)) as ofa_super:
+ models1 = []
+ models1 += [nn.Conv2D(6, 4, 3)]
+ models1 += [nn.BatchNorm(4)]
+ models1 += [ReLU()]
+ models1 += [nn.Conv2D(4, 4, 3, groups=2)]
+ models1 += [nn.InstanceNorm(4)]
+ models1 += [ReLU()]
+ models1 += [nn.Conv2DTranspose(4, 4, 3, groups=2)]
+ models1 += [nn.BatchNorm(4)]
+ models1 += [ReLU()]
+ models1 += [nn.Conv2DTranspose(4, 4, 3)]
+ models1 += [nn.BatchNorm(4)]
+ models1 += [ReLU()]
+ models1 = ofa_super.convert(models1)
+
+ models += models1
+
+ self.models = paddle.nn.Sequential(*models)
+
+ def forward(self, inputs, depth=None):
+ if depth != None:
+ assert isinstance(depth, int)
+ assert depth <= len(self.models)
+ else:
+ depth = len(self.models)
+ for idx in range(depth):
+ layer = self.models[idx]
+ inputs = layer(inputs)
+ return inputs
+
+
+class ModelLinear(fluid.dygraph.Layer):
+ def __init__(self):
+ super(ModelLinear, self).__init__()
+ models = []
+ with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
+ models1 = []
+ models1 += [nn.Linear(64, 128)]
+ models1 += [nn.Linear(128, 256)]
+ models1 = ofa_super.convert(models1)
+
+ models += models1
+
+ with supernet(channel=((64, 128, 256), (64, 128, 256))) as ofa_super:
+ models1 = []
+ models1 += [nn.Linear(256, 128)]
+ models1 += [nn.Linear(128, 256)]
+ models1 = ofa_super.convert(models1)
+
+ models += models1
+
+ self.models = paddle.nn.Sequential(*models)
+
+ def forward(self, inputs, depth=None):
+ if depth != None:
+ assert isinstance(depth, int)
+ assert depth < len(self.models)
+ else:
+ depth = len(self.models)
+ for idx in range(depth):
+ layer = self.models[idx]
+ inputs = layer(inputs)
+ return inputs
+
+
+class TestOFA(unittest.TestCase):
+ def setUp(self):
+ fluid.enable_dygraph()
+ self.init_model_and_data()
+ self.init_config()
+
+ def init_model_and_data(self):
+ self.model = ModelConv()
+ self.teacher_model = ModelConv()
+ data_np = np.random.random((1, 3, 10, 10)).astype(np.float32)
+ label_np = np.random.random((1)).astype(np.float32)
+
+ self.data = fluid.dygraph.to_variable(data_np)
+
+ def init_config(self):
+ default_run_config = {
+ 'train_batch_size': 1,
+ 'eval_batch_size': 1,
+ 'n_epochs': [[1], [2, 3], [4, 5]],
+ 'init_learning_rate': [[0.001], [0.003, 0.001], [0.003, 0.001]],
+ 'dynamic_batch_size': [1, 1, 1],
+ 'total_images': 1,
+ 'elastic_depth': (5, 15, 24)
+ }
+ self.run_config = RunConfig(**default_run_config)
+
+ default_distill_config = {
+ 'lambda_distill': 0.01,
+ 'teacher_model': self.teacher_model,
+ 'mapping_layers': ['models.0.fn']
+ }
+ self.distill_config = DistillConfig(**default_distill_config)
+
+ def test_ofa(self):
+ ofa_model = OFA(self.model,
+ self.run_config,
+ distill_config=self.distill_config)
+
+ start_epoch = 0
+ for idx in range(len(self.run_config.n_epochs)):
+ cur_idx = self.run_config.n_epochs[idx]
+ for ph_idx in range(len(cur_idx)):
+ cur_lr = self.run_config.init_learning_rate[idx][ph_idx]
+ adam = fluid.optimizer.Adam(
+ learning_rate=cur_lr,
+ parameter_list=(
+ ofa_model.parameters() + ofa_model.netAs_param))
+ for epoch_id in range(start_epoch,
+ self.run_config.n_epochs[idx][ph_idx]):
+ for model_no in range(self.run_config.dynamic_batch_size[
+ idx]):
+ output, _ = ofa_model(self.data)
+ loss = fluid.layers.reduce_mean(output)
+ if self.distill_config.mapping_layers != None:
+ dis_loss = ofa_model.calc_distill_loss()
+ loss += dis_loss
+ dis_loss = dis_loss.numpy()[0]
+ else:
+ dis_loss = 0
+ print('epoch: {}, loss: {}, distill loss: {}'.format(
+ epoch_id, loss.numpy()[0], dis_loss))
+ loss.backward()
+ adam.minimize(loss)
+ adam.clear_gradients()
+ start_epoch = self.run_config.n_epochs[idx][ph_idx]
+
+
+class TestOFACase1(TestOFA):
+ def init_model_and_data(self):
+ self.model = ModelLinear()
+ self.teacher_model = ModelLinear()
+ data_np = np.random.random((3, 64)).astype(np.float32)
+
+ self.data = fluid.dygraph.to_variable(data_np)
+
+ def init_config(self):
+ default_run_config = {
+ 'train_batch_size': 1,
+ 'eval_batch_size': 1,
+ 'n_epochs': [[2, 5]],
+ 'init_learning_rate': [[0.003, 0.001]],
+ 'dynamic_batch_size': [1],
+ 'total_images': 1,
+ }
+ self.run_config = RunConfig(**default_run_config)
+
+ default_distill_config = {
+ 'lambda_distill': 0.01,
+ 'teacher_model': self.teacher_model,
+ }
+ self.distill_config = DistillConfig(**default_distill_config)
+
+
+if __name__ == '__main__':
+ unittest.main()