diff --git a/paddleslim/nas/ofa/convert_super.py b/paddleslim/nas/ofa/convert_super.py
index 41ce4980ea54d52d64779a9d0c9727139ee7ed71..b08116f1b3e884bbad10bb114d86b9cdf9e6eec5 100644
--- a/paddleslim/nas/ofa/convert_super.py
+++ b/paddleslim/nas/ofa/convert_super.py
@@ -105,11 +105,13 @@ class Convert:
         cur_channel = None
         for idx, layer in enumerate(model):
             cls_name = layer.__class__.__name__.lower()
-            if 'conv' in cls_name or 'linear' in cls_name or 'embedding' in cls_name:
-                weight_layer_count += 1
-                last_weight_layer_idx = idx
-                if first_weight_layer_idx == -1:
-                    first_weight_layer_idx = idx
+            ### basic api in paddle
+            if len(layer.sublayers()) == 0:
+                if 'conv' in cls_name or 'linear' in cls_name or 'embedding' in cls_name:
+                    weight_layer_count += 1
+                    last_weight_layer_idx = idx
+                    if first_weight_layer_idx == -1:
+                        first_weight_layer_idx = idx
 
         if getattr(self.context, 'channel', None) != None:
             assert len(
diff --git a/paddleslim/nas/ofa/get_sub_model.py b/paddleslim/nas/ofa/get_sub_model.py
index 4bfb1a1aba1c47113a8ff95e3d2b348490d455f0..dd624e82764f634616fac3775bf3df1d81e0f543 100644
--- a/paddleslim/nas/ofa/get_sub_model.py
+++ b/paddleslim/nas/ofa/get_sub_model.py
@@ -146,12 +146,34 @@ def prune_params(model, param_config, super_model_sd=None):
         setattr(model, k, v)
 
 
+def _is_depthwise(op):
+    """Check if this op is a depthwise conv.
+       The shapes of the input and the output of a depthwise conv must stay the same in the superlayer,
+       so a depthwise op cannot be considered as a weight op.
+    """
+    if op.type() == 'depthwise_conv':
+        return True
+    elif 'conv' in op.type():
+        for inp in op.all_inputs():
+            if not inp._var.persistable and op.attr('groups') == inp._var.shape[
+                    1]:
+                return True
+    return False
+
+
 def _find_weight_ops(op, graph, weights):
     """ Find the vars come from operators with weight.
     """
     pre_ops = graph.pre_ops(op)
     for pre_op in pre_ops:
-        if pre_op.type() in WEIGHT_OP:
+        ### if a depthwise conv is one of the inputs of an elementwise op,
+        ### add it into the same search space
+        if _is_depthwise(pre_op):
+            for inp in pre_op.all_inputs():
+                if inp._var.persistable:
+                    weights.append(inp._var.name)
+
+        if pre_op.type() in WEIGHT_OP and not _is_depthwise(pre_op):
             for inp in pre_op.all_inputs():
                 if inp._var.persistable:
                     weights.append(inp._var.name)
@@ -176,16 +198,22 @@ def check_search_space(graph):
     """ Find the shortcut in the model and set same config for this situation.
""" same_search_space = [] + depthwise_conv = [] for op in graph.ops(): - if op.type() == 'elementwise_add': + if op.type() == 'elementwise_add' or op.type() == 'elementwise_mul': inp1, inp2 = op.all_inputs()[0], op.all_inputs()[1] if (not inp1._var.persistable) and (not inp2._var.persistable): pre_ele_op = _find_pre_elementwise_add(op, graph) if pre_ele_op != None: same_search_space.append(pre_ele_op) + if _is_depthwise(op): + for inp in op.all_inputs(): + if inp._var.persistable: + depthwise_conv.append(inp._var.name) + if len(same_search_space) == 0: - return None + return None, None same_search_space = sorted([sorted(x) for x in same_search_space]) final_search_space = [] @@ -204,5 +232,7 @@ def check_search_space(graph): break if not merged: final_search_space.append(l) + final_search_space = sorted([sorted(x) for x in final_search_space]) + depthwise_conv = sorted(depthwise_conv) - return final_search_space + return (final_search_space, depthwise_conv) diff --git a/paddleslim/nas/ofa/layers.py b/paddleslim/nas/ofa/layers.py index 0f69647db1396607c1eb48c361b44ea07a8d57a9..a0b1fe0b0577c495b109b517d5b698cb2b1ea735 100644 --- a/paddleslim/nas/ofa/layers.py +++ b/paddleslim/nas/ofa/layers.py @@ -302,7 +302,16 @@ class SuperConv2D(nn.Conv2D): padding = self._padding if self.bias is not None: - bias = self.bias[:out_nc] + ### if conv is depthwise conv, expand_ratio=0, but conv' expand + ### ratio before depthwise conv is not equal to 1.0, the shape of the weight + ### about this depthwise conv is changed, but out_nc is not change, + ### so need to change bias shape according to the weight_out_nc. + ### if in_nc > groups > 1, the actual output of conv is weight_out_nc * groups, + ### so slice the shape of bias by weight_out_nc and groups. + ### if in_nc = groups, slice the shape of bias by weight_out_nc. 
+            if groups != in_nc:
+                weight_out_nc = weight_out_nc * groups
+            bias = self.bias[:weight_out_nc]
         else:
             bias = self.bias
 
@@ -313,7 +322,7 @@ class SuperConv2D(nn.Conv2D):
             stride=self._stride,
             padding=padding,
             dilation=self._dilation,
-            groups=self._groups,
+            groups=groups,
             data_format=self._data_format)
         return out
 
@@ -606,7 +615,9 @@ class SuperConv2DTranspose(nn.Conv2DTranspose):
             output_padding = 0
 
         if self.bias is not None:
-            bias = self.bias[:out_nc]
+            if groups != in_nc:
+                weight_out_nc = weight_out_nc * groups
+            bias = self.bias[:weight_out_nc]
         else:
             bias = self.bias
 
@@ -618,7 +629,7 @@ class SuperConv2DTranspose(nn.Conv2DTranspose):
             output_padding=output_padding,
             stride=self._stride,
             dilation=self._dilation,
-            groups=self._groups,
+            groups=groups,
             output_size=output_size,
             data_format=self._data_format)
         return out
@@ -929,10 +940,11 @@ class SuperBatchNorm2D(nn.BatchNorm2D):
                  weight_attr=None,
                  bias_attr=None,
                  data_format='NCHW',
+                 use_global_stats=None,
                  name=None):
-        super(SuperBatchNorm2D, self).__init__(num_features, momentum, epsilon,
-                                               weight_attr, bias_attr,
-                                               data_format, name)
+        super(SuperBatchNorm2D, self).__init__(
+            num_features, momentum, epsilon, weight_attr, bias_attr,
+            data_format, use_global_stats, name)
 
     def forward(self, input):
         self._check_data_format(self._data_format)
@@ -954,7 +966,8 @@ class SuperBatchNorm2D(nn.BatchNorm2D):
             training=self.training,
             momentum=self._momentum,
             epsilon=self._epsilon,
-            data_format=self._data_format)
+            data_format=self._data_format,
+            use_global_stats=self._use_global_stats)
 
 
 class SuperSyncBatchNorm(nn.SyncBatchNorm):
diff --git a/paddleslim/nas/ofa/ofa.py b/paddleslim/nas/ofa/ofa.py
index d5c4836b8be2cd4a2f9a3344c7510a80f7c110ae..6a95ff5169064e50a1c623c5da2c8d51f232bb3b 100644
--- a/paddleslim/nas/ofa/ofa.py
+++ b/paddleslim/nas/ofa/ofa.py
@@ -51,7 +51,9 @@ RunConfig = namedtuple(
         # list, elactic depth of the model in training, default: None
         'elastic_depth',
         # list, the number of sub-network to train per mini-batch data, used to get current epoch, default: None
-        'dynamic_batch_size'
+        'dynamic_batch_size',
+        # list, the shape of weights in the skip_layers will not change during training, default: None
+        'skip_layers'
     ])
 RunConfig.__new__.__defaults__ = (None, ) * len(RunConfig._fields)
 
@@ -106,7 +108,10 @@ class OFABase(Layer):
         if getattr(self, 'current_config', None) != None:
             ### if block is fixed, donnot join key into candidate
             ### concrete config as parameter in kwargs
-            if block.fixed == False:
+            if block.fixed == False and (
+                    self._skip_layers == None or
+                    self._key2name[block.key] not in self._skip_layers) and \
+                    (block.fn.weight.name not in self._depthwise_conv):
                 assert self._key2name[
                     block.
                     key] in self.current_config, 'DONNT have {} layer in config.'.format(
@@ -117,7 +122,7 @@ class OFABase(Layer):
             config.update(kwargs)
         else:
             config = dict()
-        logging.debug(self.model, config)
+        _logger.debug(self.model, config)
         return block.fn(*inputs, **config)
 
 
@@ -175,6 +180,7 @@ class OFA(OFABase):
         self._mapping_layers = None
         self._build_ss = False
         self._broadcast = False
+        self._skip_layers = None
 
         ### if elastic_order is none, use default order
         if self.elastic_order is not None:
@@ -185,6 +191,7 @@ class OFA(OFABase):
                 depth_list = list(set(self.run_config.elastic_depth))
                 depth_list.sort()
                 self._ofa_layers['depth'] = depth_list
+                self._layers['depth'] = depth_list
 
         if self.elastic_order is None:
             self.elastic_order = []
@@ -198,6 +205,7 @@ class OFA(OFABase):
                 depth_list = list(set(self.run_config.elastic_depth))
                 depth_list.sort()
                 self._ofa_layers['depth'] = depth_list
+                self._layers['depth'] = depth_list
                 self.elastic_order.append('depth')
 
             # final, elastic width
@@ -225,6 +233,11 @@ class OFA(OFABase):
                     run_config.init_learning_rate[idx], list
                 ), "each candidate in init_learning_rate must be list"
 
+        ### remove skip layers from the search space
+        if self.run_config != None and getattr(self.run_config, 'skip_layers',
+                                               None) != None:
+            self._skip_layers = self.run_config.skip_layers
+
         ### ================= add distill prepare ======================
         if self.distill_config != None:
             self._add_teacher = True
@@ -234,7 +247,7 @@ class OFA(OFABase):
 
     def _prepare_distill(self):
         if self.distill_config.teacher_model == None:
-            logging.error(
+            _logger.error(
                 'If you want to add distill, please input instance of teacher model'
             )
 
@@ -289,6 +302,7 @@ class OFA(OFABase):
 
     def _reset_hook_before_forward(self):
         self.Tacts, self.Sacts = {}, {}
+        self.hooks = []
         if self._mapping_layers != None:
 
             def get_activation(mem, name):
@@ -300,12 +314,18 @@ class OFA(OFABase):
             def add_hook(net, mem, mapping_layers):
                 for idx, (n, m) in enumerate(net.named_sublayers()):
                     if n in mapping_layers:
-                        m.register_forward_post_hook(get_activation(mem, n))
+                        self.hooks.append(
+                            m.register_forward_post_hook(
+                                get_activation(mem, n)))
 
             add_hook(self.model, self.Sacts, self._mapping_layers)
             add_hook(self.ofa_teacher_model.model, self.Tacts,
                      self._mapping_layers)
 
+    def _remove_hook_after_forward(self):
+        for hook in self.hooks:
+            hook.remove()
+
     def _compute_epochs(self):
         if getattr(self, 'epoch', None) == None:
             assert self.run_config.total_images is not None, \
@@ -521,6 +541,14 @@ class OFA(OFABase):
         else:
             return False
 
+    def _clear_width(self, key):
+        if 'expand_ratio' in self._ofa_layers[key]:
+            self._ofa_layers[key].pop('expand_ratio')
+        elif 'channel' in self._ofa_layers[key]:
+            self._ofa_layers[key].pop('channel')
+        if len(self._ofa_layers[key]) == 0:
+            self._ofa_layers.pop(key)
+
     def _clear_search_space(self, *inputs, **kwargs):
         """ find shortcut in model, and clear up the search space """
         input_shapes = []
@@ -533,34 +561,70 @@ class OFA(OFABase):
             input_dtypes.append(v.numpy().dtype)
 
         ### find shortcut block using static model
+        model_to_traverse = self.model._layers if isinstance(
+            self.model, DataParallel) else self.model
         _st_prog = dygraph2program(
-            self.model, inputs=input_shapes, dtypes=input_dtypes)
-        self._same_ss = check_search_space(GraphWrapper(_st_prog))
+            model_to_traverse, inputs=input_shapes, dtypes=input_dtypes)
+        self._same_ss, self._depthwise_conv = check_search_space(
+            GraphWrapper(_st_prog))
 
         if self._same_ss != None:
-            self._same_ss = sorted(self._same_ss)
             self._param2key = {}
             self._broadcast = True
 
             ### the name of sublayer is the key in search space
             ### param.name is the name in self._same_ss
-            model_to_traverse = self.model._layers if isinstance(
-                self.model, DataParallel) else self.model
             for name, sublayer in model_to_traverse.named_sublayers():
                 if isinstance(sublayer, BaseBlock):
                     for param in sublayer.parameters():
                         if self._find_ele(param.name, self._same_ss):
                             self._param2key[param.name] = name
 
+            ### clean the same search space a second time so it only keeps searchable weights.
+            tmp_same_ss = []
+            for ss in self._same_ss:
+                per_ss = []
+                for key in ss:
+                    if key not in self._param2key.keys():
+                        continue
+
+                    ### if skip_layers and the same ss both contain the same layer,
+                    ### the layers related to this layer also need to be added to skip_layers
+                    if self._skip_layers != None and self._param2key[
+                            key] in self._skip_layers:
+                        self._skip_layers += [self._param2key[sk] for sk in ss]
+                        per_ss = []
+                        break
+
+                    if self._param2key[key] in self._ofa_layers.keys() and \
+                       ('expand_ratio' in self._ofa_layers[self._param2key[key]] or \
+                        'channel' in self._ofa_layers[self._param2key[key]]):
+                        per_ss.append(key)
+                    else:
+                        _logger.info("{} not in ss".format(key))
+
+                if len(per_ss) != 0:
+                    tmp_same_ss.append(per_ss)
+
+            self._same_ss = tmp_same_ss
+
+            ### remove the layers specified by skip_layers from ofa_layers
+            if self._skip_layers != None:
+                for skip_layer in self._skip_layers:
+                    if skip_layer in self._ofa_layers.keys():
+                        self._ofa_layers.pop(skip_layer)
+
             for per_ss in self._same_ss:
                 for ss in per_ss[1:]:
-                    if 'expand_ratio' in self._ofa_layers[self._param2key[ss]]:
-                        self._ofa_layers[self._param2key[ss]].pop(
-                            'expand_ratio')
-                    elif 'channel' in self._ofa_layers[self._param2key[ss]]:
-                        self._ofa_layers[self._param2key[ss]].pop('channel')
-                    if len(self._ofa_layers[self._param2key[ss]]) == 0:
-                        self._ofa_layers.pop(self._param2key[ss])
+                    self._clear_width(self._param2key[ss])
+
+        ### remove depthwise conv from the search space because its output channel cannot change
+        for name, sublayer in model_to_traverse.named_sublayers():
+            if isinstance(sublayer, BaseBlock):
+                for param in sublayer.parameters():
+                    if param.name in self._depthwise_conv and name in self._ofa_layers.keys(
+                    ):
+                        self._clear_width(name)
 
     def _broadcast_ss(self):
         """ broadcast search space after random sample."""
@@ -630,4 +694,9 @@ class OFA(OFABase):
         if self._broadcast:
             self._broadcast_ss()
 
-        return self.model.forward(*inputs, **kwargs), teacher_output
+        student_output = self.model.forward(*inputs, **kwargs)
+
+        if self._add_teacher:
+            self._remove_hook_after_forward()
+
+        return student_output, teacher_output
diff --git a/paddleslim/nas/ofa/utils/special_config.py b/paddleslim/nas/ofa/utils/special_config.py
index 4cfcdb23c11e7d62d6eb325f3dc1653874b1e226..473b6cdd0fd8f85e128d138273045238d6d1bcc9 100644
--- a/paddleslim/nas/ofa/utils/special_config.py
+++ b/paddleslim/nas/ofa/utils/special_config.py
@@ -45,6 +45,9 @@ def dynabert_config(model, width_mult, depth_mult=1.0):
         if block_k == 'depth':
             block_v = depth_mult
-        new_block_k = model._key2name[block_k]
+        if block_k != 'depth':
+            new_block_k = model._key2name[block_k]
+        else:
+            new_block_k = 'depth'
         new_config[new_block_k] = block_v
 
     return new_config
diff --git a/tests/test_ofa_v2.py b/tests/test_ofa_v2.py
index c2bef224f3c1c0f91fc4801eb406439025385f47..ee8e515f4dd8e7c514eee4e780298d21f00a127f 100644
--- a/tests/test_ofa_v2.py
+++ b/tests/test_ofa_v2.py
@@ -40,6 +40,46 @@ class ModelV1(nn.Layer):
         return self.cls + self.model(inputs)
 
 
+class ModelShortcut(nn.Layer):
+    def __init__(self):
+        super(ModelShortcut, self).__init__()
+        self.conv1 = nn.Sequential(
+            nn.Conv2D(3, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
+        self.branch1 = nn.Sequential(
+            nn.Conv2D(12, 12, 1),
+            nn.BatchNorm2D(12),
+            nn.ReLU(),
+            nn.Conv2D(
+                12, 12, 1, groups=12),
+            nn.BatchNorm2D(12),
+            nn.ReLU(),
+            nn.Conv2D(
+                12, 12, 1, groups=12),
+            nn.BatchNorm2D(12),
+            nn.ReLU())
+        self.branch2 = nn.Sequential(
+            nn.Conv2D(12, 12, 1),
+            nn.BatchNorm2D(12),
+            nn.ReLU(),
+            nn.Conv2D(
+                12, 12, 1, groups=12),
+            nn.BatchNorm2D(12),
+            nn.ReLU(),
+            nn.Conv2D(12, 12, 1),
+            nn.BatchNorm2D(12),
+            nn.ReLU())
+        self.out = nn.Sequential(
+            nn.Conv2D(12, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
+
+    def forward(self, x):
+        x = self.conv1(x)
+        y = self.branch1(x)
+        y = x + y
+        z = self.branch2(y)
+        z = z + y
+        return z
+
+
 class TestOFAV2(unittest.TestCase):
     def setUp(self):
         model = ModelV1()
@@ -73,5 +113,42 @@ class TestOFAV2Export(unittest.TestCase):
             origin_model=origin_model)
 
 
+class TestShortcutSkiplayers(unittest.TestCase):
+    def setUp(self):
+        model = ModelShortcut()
+        sp_net_config = supernet(expand_ratio=[0.5, 1.0])
+        self.model = Convert(sp_net_config).convert(model)
+        self.images = paddle.randn(shape=[2, 3, 32, 32], dtype='float32')
+        self.init_config()
+        self.ofa_model = OFA(self.model, run_config=self.run_config)
+        self.ofa_model._clear_search_space(self.images)
+
+    def init_config(self):
+        default_run_config = {'skip_layers': ['branch1.6']}
+        self.run_config = RunConfig(**default_run_config)
+
+    def test_shortcut(self):
+        self.ofa_model.set_epoch(0)
+        self.ofa_model.set_task('expand_ratio')
+        for i in range(5):
+            self.ofa_model(self.images)
+        assert list(self.ofa_model._ofa_layers.keys()) == ['branch2.0', 'out.0']
+
+
+class TestShortcutSkiplayersCase1(TestShortcutSkiplayers):
+    def init_config(self):
+        default_run_config = {'skip_layers': ['conv1.0']}
+        self.run_config = RunConfig(**default_run_config)
+
+
+class TestShortcutSkiplayersCase2(TestShortcutSkiplayers):
+    def init_config(self):
+        default_run_config = {'skip_layers': ['branch2.0']}
+        self.run_config = RunConfig(**default_run_config)
+
+    def test_shortcut(self):
+        assert list(self.ofa_model._ofa_layers.keys()) == ['conv1.0', 'out.0']
+
+
 if __name__ == '__main__':
     unittest.main()
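Usage sketch (illustration only, not part of the patch): the snippet below shows how the new skip_layers field of RunConfig is intended to be used together with Convert/supernet and OFA, modeled on the TestShortcutSkiplayers cases added above. TinyShortcutNet and the sublayer names it produces ('conv1.0', 'branch1.3', 'out.0') are hypothetical stand-ins for a real model; the OFA, RunConfig, Convert and supernet calls are the same ones the new tests exercise.

import paddle
import paddle.nn as nn
from paddleslim.nas.ofa import OFA, RunConfig
from paddleslim.nas.ofa.convert_super import Convert, supernet


class TinyShortcutNet(nn.Layer):
    """Hypothetical toy model with one shortcut, mirroring ModelShortcut above."""

    def __init__(self):
        super(TinyShortcutNet, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2D(3, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
        self.branch1 = nn.Sequential(
            nn.Conv2D(12, 12, 1), nn.BatchNorm2D(12), nn.ReLU(),
            nn.Conv2D(12, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
        self.out = nn.Sequential(
            nn.Conv2D(12, 12, 1), nn.BatchNorm2D(12), nn.ReLU())

    def forward(self, x):
        x = self.conv1(x)
        # shortcut: ties the weights of 'conv1.0' and 'branch1.3' into one search group
        y = x + self.branch1(x)
        return self.out(y)


# wrap the plain model into a supernet with elastic expand ratios
sp_net_config = supernet(expand_ratio=[0.5, 1.0])
sp_model = Convert(sp_net_config).convert(TinyShortcutNet())

# 'branch1.3' keeps its original weight shape during training; because it shares a
# shortcut with 'conv1.0', that layer is dropped from the search space as well,
# which is the behaviour the new TestShortcutSkiplayers cases assert on.
run_config = RunConfig(skip_layers=['branch1.3'])

ofa_model = OFA(sp_model, run_config=run_config)
ofa_model.set_epoch(0)
ofa_model.set_task('expand_ratio')
student_out, _ = ofa_model(paddle.randn(shape=[2, 3, 32, 32], dtype='float32'))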