diff --git a/demo/one_shot/ofa_train.py b/demo/one_shot/ofa_train.py
index 4a47a219c1096d750757f407cfde4ff37691efb7..bc7f864cef4e0543de0556d827e4f20b161d1e0b 100644
--- a/demo/one_shot/ofa_train.py
+++ b/demo/one_shot/ofa_train.py
@@ -14,14 +14,14 @@
 import numpy as np
 import paddle
-import paddle.fluid as fluid
-import paddle.fluid.dygraph.nn as nn
+import paddle.nn as nn
+import paddle.nn.functional as F
 from paddle.nn import ReLU
 
 from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
 from paddleslim.nas.ofa import supernet
 
 
-class Model(fluid.dygraph.Layer):
+class Model(nn.Layer):
     def __init__(self):
         super(Model, self).__init__()
         with supernet(
@@ -50,18 +50,20 @@ class Model(fluid.dygraph.Layer):
         for idx, layer in enumerate(models):
             if idx == 6:
-                inputs = fluid.layers.flatten(inputs, 1)
+                inputs = paddle.flatten(inputs, 1)
             inputs = layer(inputs)
-        inputs = fluid.layers.softmax(inputs)
+        inputs = F.softmax(inputs)
         return inputs
 
 
 def test_ofa():
+    model = Model()
+    teacher_model = Model()
+
     default_run_config = {
         'train_batch_size': 256,
-        'eval_batch_size': 64,
         'n_epochs': [[1], [2, 3], [4, 5]],
         'init_learning_rate': [[0.001], [0.003, 0.001], [0.003, 0.001]],
         'dynamic_batch_size': [1, 1, 1],
@@ -72,42 +74,46 @@ def test_ofa():
 
     default_distill_config = {
         'lambda_distill': 0.01,
-        'teacher_model': Model,
+        'teacher_model': teacher_model,
         'mapping_layers': ['models.0.fn']
     }
     distill_config = DistillConfig(**default_distill_config)
 
-    fluid.enable_dygraph()
-    model = Model()
     ofa_model = OFA(model, run_config, distill_config=distill_config)
 
-    train_reader = paddle.fluid.io.batch(
-        paddle.dataset.mnist.train(), batch_size=256, drop_last=True)
+    train_dataset = paddle.vision.datasets.MNIST(
+        mode='train', backend='cv2', transform=transform)
+    train_loader = paddle.io.DataLoader(
+        train_dataset,
+        places=place,
+        feed_list=[image, label],
+        drop_last=True,
+        batch_size=64)
 
     start_epoch = 0
     for idx in range(len(run_config.n_epochs)):
         cur_idx = run_config.n_epochs[idx]
         for ph_idx in range(len(cur_idx)):
             cur_lr = run_config.init_learning_rate[idx][ph_idx]
-            adam = fluid.optimizer.Adam(
+            adam = paddle.optimizer.Adam(
                 learning_rate=cur_lr,
                 parameter_list=(ofa_model.parameters() + ofa_model.netAs_param))
             for epoch_id in range(start_epoch,
                                   run_config.n_epochs[idx][ph_idx]):
-                for batch_id, data in enumerate(train_reader()):
+                for batch_id, data in enumerate(train_loader()):
                     dy_x_data = np.array(
                         [x[0].reshape(1, 28, 28)
                          for x in data]).astype('float32')
                     y_data = np.array(
                         [x[1] for x in data]).astype('int64').reshape(-1, 1)
 
-                    img = fluid.dygraph.to_variable(dy_x_data)
-                    label = fluid.dygraph.to_variable(y_data)
+                    img = paddle.dygraph.to_variable(dy_x_data)
+                    label = paddle.dygraph.to_variable(y_data)
                     label.stop_gradient = True
 
                     for model_no in range(run_config.dynamic_batch_size[idx]):
                         output, _ = ofa_model(img, label)
-                        loss = fluid.layers.reduce_mean(output)
+                        loss = F.mean(output)
                         dis_loss = ofa_model.calc_distill_loss()
                         loss += dis_loss
                         loss.backward()
diff --git a/paddleslim/nas/__init__.py b/paddleslim/nas/__init__.py
index d438e54c572efb25fdf9f33db0d4c4d10a2487e7..6932afa1fd8d97dfaf51d7d8cc15995067fabe65 100644
--- a/paddleslim/nas/__init__.py
+++ b/paddleslim/nas/__init__.py
@@ -19,6 +19,7 @@ from .sa_nas import *
 from .rl_nas import *
 from ..nas import darts
 from .darts import *
+from .ofa import *
 
 __all__ = []
 __all__ += sa_nas.__all__
diff --git a/paddleslim/nas/ofa/convert_super.py b/paddleslim/nas/ofa/convert_super.py
index f7ff8a1e530cef850415049c1d8a1b42dfcc0345..a95d1fd03ebc2ab2855b9301d175c7bb2371858b 100644
--- a/paddleslim/nas/ofa/convert_super.py
+++ b/paddleslim/nas/ofa/convert_super.py
@@ -16,9 +16,8 @@
 import inspect
 import decorator
 import logging
 import paddle
-import paddle.fluid as fluid
-from paddle.fluid import framework
-from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, BatchNorm, InstanceNorm
+import numbers
+from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, BatchNorm, InstanceNorm, LayerNorm, Embedding
 from .layers import *
 from ...common import get_logger
@@ -26,7 +25,7 @@ _logger = get_logger(__name__, level=logging.INFO)
 
 __all__ = ['supernet']
 
-WEIGHT_LAYER = ['conv', 'linear']
+WEIGHT_LAYER = ['conv', 'linear', 'embedding']
 
 
 ### TODO: add decorator
@@ -45,7 +44,7 @@ class Convert:
         cur_channel = None
         for idx, layer in enumerate(model):
             cls_name = layer.__class__.__name__.lower()
-            if 'conv' in cls_name or 'linear' in cls_name:
+            if 'conv' in cls_name or 'linear' in cls_name or 'embedding' in cls_name:
                 weight_layer_count += 1
                 last_weight_layer_idx = idx
                 if first_weight_layer_idx == -1:
@@ -63,7 +62,7 @@ class Convert:
 
                 new_attr_name = [
                     '_stride', '_dilation', '_groups', '_param_attr',
-                    '_bias_attr', '_use_cudnn', '_act', '_dtype'
+                    '_bias_attr', '_use_cudnn', '_act', '_dtype', '_padding'
                 ]
 
                 new_attr_dict = dict()
@@ -179,6 +178,8 @@ class Convert:
                         layer._parameters['weight'].shape[0])
                 elif self.context.channel:
                     new_attr_dict['num_channels'] = max(cur_channel)
+                else:
+                    new_attr_dict['num_channels'] = attr_dict['_num_channels']
 
                 for attr in new_attr_name:
                     new_attr_dict[attr[1:]] = attr_dict[attr]
@@ -196,7 +197,8 @@ class Convert:
 
                 new_attr_name = [
                     '_stride', '_dilation', '_groups', '_param_attr',
-                    '_bias_attr', '_use_cudnn', '_act', '_dtype', '_output_size'
+                    '_padding', '_bias_attr', '_use_cudnn', '_act', '_dtype',
+                    '_output_size'
                 ]
                 assert attr_dict[
                     '_filter_size'] != None, "Conv2DTranspose only support filter size != None now"
@@ -371,6 +373,8 @@ class Convert:
                         layer._parameters['scale'].shape[0])
                 elif self.context.channel:
                     new_attr_dict['num_channels'] = max(cur_channel)
+                else:
+                    new_attr_dict['num_channels'] = attr_dict['_num_channels']
 
                 for attr in new_attr_name:
                     new_attr_dict[attr[1:]] = attr_dict[attr]
@@ -380,6 +384,76 @@ class Convert:
                 layer = SuperInstanceNorm(**new_attr_dict)
                 model[idx] = layer
 
+            elif isinstance(layer, LayerNorm) and (
+                    getattr(self.context, 'expand', None) != None or
+                    getattr(self.context, 'channel', None) != None):
+                ### TODO(ceci3): fix when normalized_shape != last_dim_of_input
+                if idx > last_weight_layer_idx:
+                    continue
+
+                attr_dict = layer.__dict__
+                new_attr_name = [
+                    '_scale', '_shift', '_param_attr', '_bias_attr', '_act',
+                    '_dtype', '_epsilon'
+                ]
+                new_attr_dict = dict()
+                if self.context.expand:
+                    new_attr_dict[
+                        'normalized_shape'] = self.context.expand * int(
+                            attr_dict['_normalized_shape'][0])
+                elif self.context.channel:
+                    new_attr_dict['normalized_shape'] = max(cur_channel)
+                else:
+                    new_attr_dict['normalized_shape'] = attr_dict[
+                        '_normalized_shape']
+
+                for attr in new_attr_name:
+                    new_attr_dict[attr[1:]] = attr_dict[attr]
+
+                del layer, attr_dict
+                layer = SuperLayerNorm(**new_attr_dict)
+                model[idx] = layer
+
+            elif isinstance(layer, Embedding) and (
+                    getattr(self.context, 'expand', None) != None or
+                    getattr(self.context, 'channel', None) != None):
+                attr_dict = layer.__dict__
+                key = attr_dict['_full_name']
+                new_attr_name = [
+                    '_is_sparse', '_is_distributed', '_padding_idx',
+                    '_param_attr', '_dtype'
+                ]
+
+                new_attr_dict = dict()
+                new_attr_dict['candidate_config'] = dict()
+                bef_size = attr_dict['_size']
+                if self.context.expand:
+                    new_attr_dict['size'] = [
+                        bef_size[0], self.context.expand * bef_size[1]
+                    ]
+                    new_attr_dict['candidate_config'].update({
+                        'expand_ratio': self.context.expand_ratio
+                    })
+
+                elif self.context.channel:
+                    cur_channel = self.context.channel[0]
+                    self.context.channel = self.context.channel[1:]
+                    new_attr_dict['size'] = [bef_size[0], max(cur_channel)]
+                    new_attr_dict['candidate_config'].update({
+                        'channel': cur_channel
+                    })
+                    pre_channel = cur_channel
+                else:
+                    new_attr_dict['size'] = bef_size
+
+                for attr in new_attr_name:
+                    new_attr_dict[attr[1:]] = attr_dict[attr]
+
+                del layer, attr_dict
+
+                layer = Block(SuperEmbedding(**new_attr_dict), key=key)
+                model[idx] = layer
+
         return model
diff --git a/paddleslim/nas/ofa/layers.py b/paddleslim/nas/ofa/layers.py
index 4d91f5338a8a1f9ee67cc1d7dab2657d85348454..4b9f05bb2a8471f460a38843dbe64030842035c2 100644
--- a/paddleslim/nas/ofa/layers.py
+++ b/paddleslim/nas/ofa/layers.py
@@ -28,7 +28,7 @@ __all__ = [
     'SuperConv2D', 'SuperConv2DTranspose', 'SuperSeparableConv2D',
     'SuperBatchNorm', 'SuperLinear', 'SuperInstanceNorm', 'Block',
     'SuperGroupConv2D', 'SuperDepthwiseConv2D', 'SuperGroupConv2DTranspose',
-    'SuperDepthwiseConv2DTranspose'
+    'SuperDepthwiseConv2DTranspose', 'SuperLayerNorm', 'SuperEmbedding'
 ]
 
 _logger = get_logger(__name__, level=logging.INFO)
@@ -70,9 +70,10 @@ class Block(BaseBlock):
         key(str, optional): key of this layer, one-to-one correspondence between key and candidate config. Default: None.
     """
 
-    def __init__(self, fn, key=None):
+    def __init__(self, fn, fixed=False, key=None):
         super(Block, self).__init__(key)
         self.fn = fn
+        self.fixed = fixed
         self.candidate_config = self.fn.candidate_config
 
     def forward(self, *inputs, **kwargs):
@@ -208,7 +209,6 @@ class SuperConv2D(fluid.dygraph.Conv2D):
                  act=None,
                  dtype='float32'):
         ### NOTE: padding always is 0, add padding in forward because of kernel size is uncertain
-        ### TODO: change padding to any padding
         super(SuperConv2D, self).__init__(
             num_channels, num_filters, filter_size, stride, padding, dilation,
             groups, param_attr, bias_attr, use_cudnn, act, dtype)
@@ -228,7 +228,7 @@ class SuperConv2D(fluid.dygraph.Conv2D):
             'expand_ratio'] if 'expand_ratio' in candidate_config else None
         self.channel = candidate_config[
             'channel'] if 'channel' in candidate_config else None
-        self.base_channel = None
+        self.base_channel = self._num_filters
         if self.expand_ratio != None:
             self.base_channel = int(self._num_filters / max(self.expand_ratio))
 
@@ -296,6 +296,11 @@ class SuperConv2D(fluid.dygraph.Conv2D):
         if not in_dygraph_mode():
             _logger.error("NOT support static graph")
 
+        self.cur_config = {
+            'kernel_size': kernel_size,
+            'expand_ratio': expand_ratio,
+            'channel': channel
+        }
         in_nc = int(input.shape[1])
         assert (
             expand_ratio == None or channel == None
@@ -313,7 +318,11 @@ class SuperConv2D(fluid.dygraph.Conv2D):
                                                          out_nc)
 
         weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks)
-        padding = convert_to_list(get_same_padding(ks), 2)
+
+        if kernel_size != None or 'kernel_size' in self.candidate_config.keys():
+            padding = convert_to_list(get_same_padding(ks), 2)
+        else:
+            padding = self._padding
 
         if self._l_type == 'conv2d':
             attrs = ('strides', self._stride, 'paddings', padding, 'dilations',
@@ -488,7 +497,6 @@ class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose):
                  use_cudnn=True,
                  act=None,
                  dtype='float32'):
-        ### NOTE: padding always is 0, add padding in forward because of kernel size is uncertain
         super(SuperConv2DTranspose, self).__init__(
             num_channels, num_filters, filter_size, output_size, padding,
             stride, dilation, groups, param_attr, bias_attr, use_cudnn, act,
@@ -507,7 +515,7 @@ class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose):
             'expand_ratio'] if 'expand_ratio' in candidate_config else None
         self.channel = candidate_config[
             'channel'] if 'channel' in candidate_config else None
-        self.base_channel = None
+        self.base_channel = self._num_filters
         if self.expand_ratio:
             self.base_channel = int(self._num_filters / max(self.expand_ratio))
 
@@ -572,6 +580,11 @@ class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose):
         if not in_dygraph_mode():
             _logger.error("NOT support static graph")
 
+        self.cur_config = {
+            'kernel_size': kernel_size,
+            'expand_ratio': expand_ratio,
+            'channel': channel
+        }
         in_nc = int(input.shape[1])
         assert (
             expand_ratio == None or channel == None
@@ -590,7 +603,10 @@ class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose):
                                                          out_nc)
 
         weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks)
-        padding = convert_to_list(get_same_padding(ks), 2)
+        if kernel_size != None or 'kernel_size' in self.candidate_config.keys():
+            padding = convert_to_list(get_same_padding(ks), 2)
+        else:
+            padding = self._padding
 
         op = getattr(core.ops, self._op_type)
         out = op(input, weight, 'output_size', self._output_size, 'strides',
@@ -701,7 +717,7 @@ class SuperSeparableConv2D(fluid.dygraph.Layer):
             self.conv.extend([norm_layer(num_channels * scale_factor)])
 
         self.conv.extend([
-            Conv2D(
+            fluid.dygraph.nn.Conv2D(
                 num_channels=num_channels * scale_factor,
                 num_filters=num_filters,
                 filter_size=1,
@@ -713,14 +729,16 @@ class SuperSeparableConv2D(fluid.dygraph.Layer):
         self.candidate_config = candidate_config
         self.expand_ratio = candidate_config[
             'expand_ratio'] if 'expand_ratio' in candidate_config else None
-        self.base_output_dim = None
+        self.base_output_dim = self.conv[0]._num_filters
         if self.expand_ratio != None:
-            self.base_output_dim = int(self.output_dim / max(self.expand_ratio))
+            self.base_output_dim = int(self.conv[0]._num_filters /
+                                       max(self.expand_ratio))
 
     def forward(self, input, expand_ratio=None, channel=None):
         if not in_dygraph_mode():
             _logger.error("NOT support static graph")
 
+        self.cur_config = {'expand_ratio': expand_ratio, 'channel': channel}
         in_nc = int(input.shape[1])
         assert (
             expand_ratio == None or channel == None
@@ -809,7 +827,7 @@ class SuperLinear(fluid.dygraph.Linear):
         self.candidate_config = candidate_config
         self.expand_ratio = candidate_config[
             'expand_ratio'] if 'expand_ratio' in candidate_config else None
-        self.base_output_dim = None
+        self.base_output_dim = self.output_dim
         if self.expand_ratio != None:
             self.base_output_dim = int(self.output_dim / max(self.expand_ratio))
 
@@ -817,8 +835,9 @@ class SuperLinear(fluid.dygraph.Linear):
         if not in_dygraph_mode():
             _logger.error("NOT support static graph")
 
+        self.cur_config = {'expand_ratio': expand_ratio, 'channel': channel}
         ### weight: (Cin, Cout)
-        in_nc = int(input.shape[1])
+        in_nc = int(input.shape[-1])
         assert (
             expand_ratio == None or channel == None
         ), "expand_ratio and channel CANNOT be NOT None at the same time."
@@ -927,3 +946,77 @@ class SuperInstanceNorm(fluid.dygraph.InstanceNorm):
         out, _, _ = core.ops.instance_norm(input, scale, bias, 'epsilon',
                                            self._epsilon)
         return out
+
+
+class SuperLayerNorm(fluid.dygraph.LayerNorm):
+    def __init__(self,
+                 normalized_shape,
+                 candidate_config={},
+                 scale=True,
+                 shift=True,
+                 epsilon=1e-05,
+                 param_attr=None,
+                 bias_attr=None,
+                 act=None,
+                 dtype='float32'):
+        super(SuperLayerNorm,
+              self).__init__(normalized_shape, scale, shift, epsilon,
+                             param_attr, bias_attr, act, dtype)
+
+    def forward(self, input):
+        if not in_dygraph_mode():
+            _logger.error("NOT support static graph")
+
+        input_shape = list(input.shape)
+        input_ndim = len(input_shape)
+        normalized_ndim = len(self._normalized_shape)
+        self._begin_norm_axis = input_ndim - normalized_ndim
+
+        ### TODO(ceci3): fix if normalized_shape is not a single number
+        feature_dim = int(input.shape[-1])
+        weight = self.weight[:feature_dim]
+        bias = self.bias[:feature_dim]
+        pre_act, _, _ = core.ops.layer_norm(input, weight, bias, 'epsilon',
+                                            self._epsilon, 'begin_norm_axis',
+                                            self._begin_norm_axis)
+        return dygraph_utils._append_activation_in_dygraph(
+            pre_act, act=self._act)
+
+
+class SuperEmbedding(fluid.dygraph.Embedding):
+    def __init__(self,
+                 size,
+                 candidate_config={},
+                 is_sparse=False,
+                 is_distributed=False,
+                 padding_idx=None,
+                 param_attr=None,
+                 dtype='float32'):
+        super(SuperEmbedding, self).__init__(size, is_sparse, is_distributed,
+                                             padding_idx, param_attr, dtype)
+        self.candidate_config = candidate_config
+        self.expand_ratio = candidate_config[
+            'expand_ratio'] if 'expand_ratio' in candidate_config else None
+        self.base_output_dim = self._size[-1]
+        if self.expand_ratio != None:
+            self.base_output_dim = int(self._size[-1] / max(self.expand_ratio))
+
+    def forward(self, input, expand_ratio=None, channel=None):
+        if not in_dygraph_mode():
+            _logger.error("NOT support static graph")
+
+        assert (
+            expand_ratio == None or channel == None
+        ), "expand_ratio and channel CANNOT be NOT None at the same time."
+        if expand_ratio != None:
+            out_nc = int(expand_ratio * self.base_output_dim)
+        elif channel != None:
+            out_nc = int(channel)
+        else:
+            out_nc = self._size[-1]
+
+        weight = self.weight[:, :out_nc]
+        return core.ops.lookup_table_v2(
+            weight, input, 'is_sparse', self._is_sparse, 'is_distributed',
+            self._is_distributed, 'remote_prefetch', self._remote_prefetch,
+            'padding_idx', self._padding_idx)
diff --git a/paddleslim/nas/ofa/ofa.py b/paddleslim/nas/ofa/ofa.py
index 9fd7f5ada5d0f59eabd9ac580b9453f183bd78f1..682cd1c0b45234fa5618aaa492d254c635ac38a3 100644
--- a/paddleslim/nas/ofa/ofa.py
+++ b/paddleslim/nas/ofa/ofa.py
@@ -16,7 +16,7 @@ import logging
 import numpy as np
 from collections import namedtuple
 import paddle
-import paddle.nn as nn
+#import paddle.nn as nn
 import paddle.fluid as fluid
 from paddle.fluid.dygraph import Conv2D
 from .layers import BaseBlock, Block, SuperConv2D, SuperBatchNorm
@@ -28,9 +28,8 @@ _logger = get_logger(__name__, level=logging.INFO)
 
 __all__ = ['OFA', 'RunConfig', 'DistillConfig']
 
 RunConfig = namedtuple('RunConfig', [
-    'train_batch_size', 'eval_batch_size', 'n_epochs', 'save_frequency',
-    'eval_frequency', 'init_learning_rate', 'total_images', 'elastic_depth',
-    'dynamic_batch_size'
+    'train_batch_size', 'n_epochs', 'save_frequency', 'eval_frequency',
+    'init_learning_rate', 'total_images', 'elastic_depth', 'dynamic_batch_size'
 ])
 RunConfig.__new__.__defaults__ = (None, ) * len(RunConfig._fields)
 
@@ -53,20 +52,26 @@ class OFABase(fluid.dygraph.Layer):
         for name, sublayer in self.model.named_sublayers():
             if isinstance(sublayer, BaseBlock):
                 sublayer.set_supernet(self)
-                layers[sublayer.key] = sublayer.candidate_config
-                for k in sublayer.candidate_config.keys():
-                    elastic_task.add(k)
+                if not sublayer.fixed:
+                    layers[sublayer.key] = sublayer.candidate_config
+                    for k in sublayer.candidate_config.keys():
+                        elastic_task.add(k)
         return layers, elastic_task
 
     def forward(self, *inputs, **kwargs):
         raise NotImplementedError
 
-    # NOTE: config means set forward config for layers, used in distill.
     def layers_forward(self, block, *inputs, **kwargs):
         if getattr(self, 'current_config', None) != None:
-            assert block.key in self.current_config, 'DONNT have {} layer in config.'.format(
-                block.key)
-            config = self.current_config[block.key]
+            ### if block is fixed, donnot join key into candidate
+            ### concrete config as parameter in kwargs
+            if block.fixed == False:
+                assert block.key in self.current_config, 'DONNT have {} layer in config.'.format(
+                    block.key)
+                config = self.current_config[block.key]
+            else:
+                config = dict()
+                config.update(kwargs)
         else:
             config = dict()
         logging.debug(self.model, config)
@@ -81,7 +86,7 @@ class OFABase(fluid.dygraph.Layer):
 class OFA(OFABase):
     def __init__(self,
                  model,
-                 run_config,
+                 run_config=None,
                  net_config=None,
                  distill_config=None,
                  elastic_order=None,
@@ -92,7 +97,6 @@ class OFA(OFABase):
         self.distill_config = distill_config
         self.elastic_order = elastic_order
         self.train_full = train_full
-        self.iter_per_epochs = self.run_config.total_images // self.run_config.train_batch_size
         self.iter = 0
         self.dynamic_iter = 0
         self.manual_set_task = False
@@ -100,18 +104,16 @@ class OFA(OFABase):
         self._add_teacher = False
         self.netAs_param = []
 
-        for idx in range(len(run_config.n_epochs)):
-            assert isinstance(
-                run_config.init_learning_rate[idx],
-                list), "each candidate in init_learning_rate must be list"
-            assert isinstance(run_config.n_epochs[idx],
-                              list), "each candidate in n_epochs must be list"
-
         ### if elastic_order is none, use default order
         if self.elastic_order is not None:
             assert isinstance(self.elastic_order,
                               list), 'elastic_order must be a list'
 
+        if getattr(self.run_config, 'elastic_depth', None) != None:
+            depth_list = list(set(self.run_config.elastic_depth))
+            depth_list.sort()
+            self.layers['depth'] = depth_list
+
         if self.elastic_order is None:
             self.elastic_order = []
 
             # zero, elastic resulotion, write in demo
@@ -133,16 +135,26 @@ class OFA(OFABase):
             if 'channel' in self._elastic_task and 'width' not in self.elastic_order:
                 self.elastic_order.append('width')
 
-        assert len(self.run_config.n_epochs) == len(self.elastic_order)
-        assert len(self.run_config.n_epochs) == len(
-            self.run_config.dynamic_batch_size)
-        assert len(self.run_config.n_epochs) == len(
-            self.run_config.init_learning_rate)
+        if getattr(self.run_config, 'n_epochs', None) != None:
+            assert len(self.run_config.n_epochs) == len(self.elastic_order)
+            for idx in range(len(run_config.n_epochs)):
+                assert isinstance(
+                    run_config.n_epochs[idx],
+                    list), "each candidate in n_epochs must be list"
+
+            if self.run_config.dynamic_batch_size != None:
+                assert len(self.run_config.n_epochs) == len(
+                    self.run_config.dynamic_batch_size)
+            if self.run_config.init_learning_rate != None:
+                assert len(self.run_config.n_epochs) == len(
+                    self.run_config.init_learning_rate)
+                for idx in range(len(run_config.n_epochs)):
+                    assert isinstance(
+                        run_config.init_learning_rate[idx], list
+                    ), "each candidate in init_learning_rate must be list"
 
         ### ================= add distill prepare ======================
-        if self.distill_config != None and (
-                self.distill_config.lambda_distill != None and
-                self.distill_config.lambda_distill > 0):
+        if self.distill_config != None:
             self._add_teacher = True
             self._prepare_distill()
 
@@ -153,9 +165,10 @@ class OFA(OFABase):
 
         if self.distill_config.teacher_model == None:
             logging.error(
-                'If you want to add distill, please input class of teacher model'
+                'If you want to add distill, please input instance of teacher model'
             )
 
+        ### instance model by user can input super-param easily.
         assert isinstance(self.distill_config.teacher_model,
                           paddle.fluid.dygraph.Layer)
 
@@ -171,7 +184,7 @@ class OFA(OFABase):
         # add hook if mapping layers is not None
         # if mapping layer is None, return the output of the teacher model,
         # if mapping layer is NOT None, add hook and compute distill loss about mapping layers.
-        mapping_layers = self.distill_config.mapping_layers
+        mapping_layers = getattr(self.distill_config, 'mapping_layers', None)
         if mapping_layers != None:
             self.netAs = []
             for name, sublayer in self.model.named_sublayers():
@@ -199,9 +212,16 @@ class OFA(OFABase):
 
     def _compute_epochs(self):
         if getattr(self, 'epoch', None) == None:
+            assert self.run_config.total_images is not None, \
+                "if not use set_epoch() to set epoch, please set total_images in run_config."
+            assert self.run_config.train_batch_size is not None, \
+                "if not use set_epoch() to set epoch, please set train_batch_size in run_config."
+            assert self.run_config.n_epochs is not None, \
+                "if not use set_epoch() to set epoch, please set n_epochs in run_config."
+            self.iter_per_epochs = self.run_config.total_images // self.run_config.train_batch_size
             epoch = self.iter // self.iter_per_epochs
         else:
-            epoch = self.epochs
+            epoch = self.epoch
         return epoch
 
     def _sample_from_nestdict(self, cands, sample_type, task, phase):
@@ -284,6 +304,9 @@ class OFA(OFABase):
     def export(self, config):
         pass
 
+    def set_net_config(self, net_config):
+        self.net_config = net_config
+
     def forward(self, *inputs, **kwargs):
         # ===================== teacher process =====================
         teacher_output = None
@@ -293,11 +316,12 @@ class OFA(OFABase):
         # ============================================================
 
         # ==================== student process =====================
-        self.dynamic_iter += 1
-        if self.dynamic_iter == self.run_config.dynamic_batch_size[
-                self.task_idx]:
-            self.iter += 1
-            self.dynamic_iter = 0
+        if getattr(self.run_config, 'dynamic_batch_size', None) != None:
+            self.dynamic_iter += 1
+            if self.dynamic_iter == self.run_config.dynamic_batch_size[
+                    self.task_idx]:
+                self.iter += 1
+                self.dynamic_iter = 0
 
         if self.net_config == None:
             if self.train_full == True:
@@ -314,6 +338,6 @@ class OFA(OFABase):
         _logger.debug("Current config is {}".format(self.current_config))
 
         if 'depth' in self.current_config:
-            kwargs['depth'] = int(self.current_config['depth'])
+            kwargs['depth'] = self.current_config['depth']
 
         return self.model.forward(*inputs, **kwargs), teacher_output
diff --git a/tests/test_ofa.py b/tests/test_ofa.py
index 7c7575d021eda9686265beae18d1a79edbd19d34..e589928c3657796f18037f576c4ed456d23a6162 100644
--- a/tests/test_ofa.py
+++ b/tests/test_ofa.py
@@ -17,7 +17,6 @@ sys.path.append("../")
 import numpy as np
 import unittest
 import paddle
-from static_case import StaticCase
 import paddle.fluid as fluid
 import paddle.fluid.dygraph.nn as nn
 from paddle.nn import ReLU
@@ -35,13 +34,16 @@ class ModelConv(fluid.dygraph.Layer):
                     channel=((4, 8, 12), (8, 12, 16), (8, 12, 16),
                              (8, 12, 16))) as ofa_super:
                 models = []
-                models += [nn.Conv2D(3, 4, 3)]
+                models += [nn.Conv2D(3, 4, 3, padding=1)]
                 models += [nn.InstanceNorm(4)]
                 models += [ReLU()]
                 models += [nn.Conv2D(4, 4, 3, groups=4)]
                 models += [nn.InstanceNorm(4)]
                 models += [ReLU()]
-                models += [nn.Conv2DTranspose(4, 4, 3, groups=4, use_cudnn=True)]
+                models += [
+                    nn.Conv2DTranspose(
+                        4, 4, 3, groups=4, padding=1, use_cudnn=True)
+                ]
                 models += [nn.BatchNorm(4)]
                 models += [ReLU()]
                 models += [nn.Conv2D(4, 3, 3)]
@@ -51,7 +53,8 @@ class ModelConv(fluid.dygraph.Layer):
             models += [
                 Block(
                     SuperSeparableConv2D(
-                        3, 6, 1, candidate_config={'channel': (3, 6)}))
+                        3, 6, 1, padding=1, candidate_config={'channel': (3, 6)}),
+                    fixed=True)
             ]
             with supernet(
                     kernel_size=(3, 5, 7), expand_ratio=(1, 2, 4)) as ofa_super:
@@ -92,15 +95,37 @@ class ModelLinear(fluid.dygraph.Layer):
         models = []
         with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
             models1 = []
+            models1 += [nn.Embedding(size=(64, 64))]
             models1 += [nn.Linear(64, 128)]
+            models1 += [nn.LayerNorm(128)]
             models1 += [nn.Linear(128, 256)]
             models1 = ofa_super.convert(models1)
 
         models += models1
+        self.models = paddle.nn.Sequential(*models)
+
+    def forward(self, inputs, depth=None):
+        if depth != None:
+            assert isinstance(depth, int)
+            assert depth < len(self.models)
+        else:
+            depth = len(self.models)
+        for idx in range(depth):
+            layer = self.models[idx]
+            inputs = layer(inputs)
+        return inputs
 
-        with supernet(channel=((64, 128, 256), (64, 128, 256))) as ofa_super:
+
+class ModelLinear1(fluid.dygraph.Layer):
+    def __init__(self):
+        super(ModelLinear1, self).__init__()
+        models = []
+        with supernet(channel=((64, 128, 256), (64, 128, 256),
+                               (64, 128, 256))) as ofa_super:
             models1 = []
-            models1 += [nn.Linear(256, 128)]
+            models1 += [nn.Embedding(size=(64, 64))]
+            models1 += [nn.Linear(64, 128)]
+            models1 += [nn.LayerNorm(128)]
             models1 += [nn.Linear(128, 256)]
 
             models1 = ofa_super.convert(models1)
@@ -120,7 +145,35 @@ class ModelLinear(fluid.dygraph.Layer):
         return inputs
 
 
-class TestOFA(StaticCase):
+class ModelLinear2(fluid.dygraph.Layer):
+    def __init__(self):
+        super(ModelLinear2, self).__init__()
+        models = []
+        with supernet(expand_ratio=None) as ofa_super:
+            models1 = []
+            models1 += [nn.Embedding(size=(64, 64))]
+            models1 += [nn.Linear(64, 128)]
+            models1 += [nn.LayerNorm(128)]
+            models1 += [nn.Linear(128, 256)]
+            models1 = ofa_super.convert(models1)
+
+        models += models1
+
+        self.models = paddle.nn.Sequential(*models)
+
+    def forward(self, inputs, depth=None):
+        if depth != None:
+            assert isinstance(depth, int)
+            assert depth < len(self.models)
+        else:
+            depth = len(self.models)
+        for idx in range(depth):
+            layer = self.models[idx]
+            inputs = layer(inputs)
+        return inputs
+
+
+class TestOFA(unittest.TestCase):
     def setUp(self):
         fluid.enable_dygraph()
         self.init_model_and_data()
@@ -137,7 +190,6 @@
     def init_config(self):
         default_run_config = {
             'train_batch_size': 1,
-            'eval_batch_size': 1,
             'n_epochs': [[1], [2, 3], [4, 5]],
             'init_learning_rate': [[0.001], [0.003, 0.001], [0.003, 0.001]],
             'dynamic_batch_size': [1, 1, 1],
@@ -152,11 +204,13 @@ def init_config(self):
            'mapping_layers': ['models.0.fn']
         }
         self.distill_config = DistillConfig(**default_distill_config)
+        self.elastic_order = ['kernel_size', 'width', 'depth']
 
     def test_ofa(self):
         ofa_model = OFA(self.model,
                         self.run_config,
-                        distill_config=self.distill_config)
+                        distill_config=self.distill_config,
+                        elastic_order=self.elastic_order)
 
         start_epoch = 0
         for idx in range(len(self.run_config.n_epochs)):
@@ -169,6 +223,8 @@ def test_ofa(self):
                     ofa_model.parameters() + ofa_model.netAs_param))
                 for epoch_id in range(start_epoch,
                                       self.run_config.n_epochs[idx][ph_idx]):
+                    if epoch_id == 0:
+                        ofa_model.set_epoch(epoch_id)
                     for model_no in range(self.run_config.dynamic_batch_size[
                             idx]):
                         output, _ = ofa_model(self.data)
@@ -191,14 +247,13 @@ class TestOFACase1(TestOFA):
     def init_model_and_data(self):
         self.model = ModelLinear()
         self.teacher_model = ModelLinear()
-        data_np = np.random.random((3, 64)).astype(np.float32)
+        data_np = np.random.random((3, 64)).astype(np.int64)
 
         self.data = fluid.dygraph.to_variable(data_np)
 
     def init_config(self):
         default_run_config = {
             'train_batch_size': 1,
-            'eval_batch_size': 1,
             'n_epochs': [[2, 5]],
             'init_learning_rate': [[0.003, 0.001]],
             'dynamic_batch_size': [1],
@@ -211,6 +266,23 @@ class TestOFACase1(TestOFA):
             'teacher_model': self.teacher_model,
         }
         self.distill_config = DistillConfig(**default_distill_config)
+        self.elastic_order = None
+
+
+class TestOFACase2(TestOFACase1):
+    def init_model_and_data(self):
+        self.model = ModelLinear1()
+        self.teacher_model = ModelLinear1()
+        data_np = np.random.random((3, 64)).astype(np.int64)
+
+        self.data = fluid.dygraph.to_variable(data_np)
+
+
+class TestOFACase3(unittest.TestCase):
+    def test_ofa(self):
+        self.model = ModelLinear2()
+        ofa_model = OFA(self.model)
+        ofa_model.set_net_config({'expand_ratio': None})
 
 
 if __name__ == '__main__':