From 7411b0a54cb9d1e780236c45d4c23d71dad5e313 Mon Sep 17 00:00:00 2001 From: Chang Xu Date: Wed, 30 Jun 2021 17:18:56 +0800 Subject: [PATCH] Update check ss (#817) * update_check_ss * update_check_ss * update_check_ss * ss_multi_output * remove_some_comments * Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleSlim into update_check_ss * merge_conflicts * add_coverage --- paddleslim/core/graph_wrapper.py | 18 ++++ paddleslim/nas/ofa/get_sub_model.py | 105 ++++++++++++++++++---- paddleslim/nas/ofa/ofa.py | 36 +++++--- tests/test_ofa.py | 6 +- tests/test_ofa_v2.py | 132 +++++++++++++++++++++++++++- 5 files changed, 264 insertions(+), 33 deletions(-) diff --git a/paddleslim/core/graph_wrapper.py b/paddleslim/core/graph_wrapper.py index 86fb1736..34c6323b 100644 --- a/paddleslim/core/graph_wrapper.py +++ b/paddleslim/core/graph_wrapper.py @@ -57,6 +57,15 @@ class VarWrapper(object): def __repr__(self): return self._var.name + def __lt__(self, other): + return self._var.name < other._var.name + + def __gt__(self, other): + return self._var.name > other._var.name + + def __eq__(self, other): + return self._var.name == other._var.name + def shape(self): """ Get the shape of the varibale. @@ -144,6 +153,15 @@ class OpWrapper(object): self.type(), self.all_inputs()) + def __lt__(self, other): + return self._op.idx < other._op.idx + + def __gt__(self, other): + return self._op.idx > other._op.idx + + def __eq__(self, other): + return self._op.idx == other._op.idx + def is_bwd_op(self): """ Whether this operator is backward op. diff --git a/paddleslim/nas/ofa/get_sub_model.py b/paddleslim/nas/ofa/get_sub_model.py index b87d23bc..62414852 100644 --- a/paddleslim/nas/ofa/get_sub_model.py +++ b/paddleslim/nas/ofa/get_sub_model.py @@ -19,14 +19,37 @@ from .layers_base import BaseBlock __all__ = ['check_search_space'] -WEIGHT_OP = [ - 'conv2d', 'linear', 'embedding', 'conv2d_transpose', 'depthwise_conv2d' +DYNAMIC_WEIGHT_OP = [ + 'conv2d', 'mul', 'matmul', 'embedding', 'conv2d_transpose', + 'depthwise_conv2d' ] + CONV_TYPES = [ 'conv2d', 'conv3d', 'conv1d', 'superconv2d', 'supergroupconv2d', 'superdepthwiseconv2d' ] +ALL_WEIGHT_OP = [ + 'conv2d', 'mul', 'matmul', 'elementwise_add', 'embedding', + 'conv2d_transpose', 'depthwise_conv2d', 'batch_norm', 'layer_norm', + 'instance_norm', 'sync_batch_norm' +] + + +def _is_dynamic_weight_op(op, all_weight_op=False): + if all_weight_op == True: + weight_ops = ALL_WEIGHT_OP + else: + weight_ops = DYNAMIC_WEIGHT_OP + if op.type() in weight_ops: + if op.type() in ['mul', 'matmul']: + for inp in sorted(op.all_inputs()): + if inp._var.persistable == True: + return True + return False + return True + return False + def get_actual_shape(transform, channel): if transform == None: @@ -58,7 +81,7 @@ def _is_depthwise(op): def _find_weight_ops(op, graph, weights): """ Find the vars come from operators with weight. """ - pre_ops = graph.pre_ops(op) + pre_ops = sorted(graph.pre_ops(op)) for pre_op in pre_ops: ### if depthwise conv is one of elementwise's input, ### add it into this same search space @@ -67,7 +90,7 @@ def _find_weight_ops(op, graph, weights): if inp._var.persistable: weights.append(inp._var.name) - if pre_op.type() in WEIGHT_OP and not _is_depthwise(pre_op): + if _is_dynamic_weight_op(pre_op) and not _is_depthwise(pre_op): for inp in pre_op.all_inputs(): if inp._var.persistable: weights.append(inp._var.name) @@ -76,29 +99,70 @@ def _find_weight_ops(op, graph, weights): return weights -def _find_pre_elementwise_add(op, graph): +def _find_pre_elementwise_op(op, graph): """ Find precedors of the elementwise_add operator in the model. """ same_prune_before_elementwise_add = [] - pre_ops = graph.pre_ops(op) + pre_ops = sorted(graph.pre_ops(op)) for pre_op in pre_ops: - if pre_op.type() in WEIGHT_OP: + if _is_dynamic_weight_op(pre_op): return same_prune_before_elementwise_add = _find_weight_ops( pre_op, graph, same_prune_before_elementwise_add) return same_prune_before_elementwise_add +def _is_output_weight_ops(op, graph): + next_ops = sorted(graph.next_ops(op)) + for next_op in next_ops: + if op == next_op: + continue + if _is_dynamic_weight_op(next_op): + return False + return _is_output_weight_ops(next_op, graph) + return True + + def check_search_space(graph): """ Find the shortcut in the model and set same config for this situation. """ + output_conv = [] same_search_space = [] depthwise_conv = [] + fixed_by_input = [] for op in graph.ops(): + # if there is no weight ops after this op, + # this op can be seen as an output + if _is_output_weight_ops(op, graph) and _is_dynamic_weight_op(op): + for inp in op.all_inputs(): + if inp._var.persistable: + output_conv.append(inp._var.name) + if op.type() == 'elementwise_add' or op.type() == 'elementwise_mul': inp1, inp2 = op.all_inputs()[0], op.all_inputs()[1] if (not inp1._var.persistable) and (not inp2._var.persistable): - pre_ele_op = _find_pre_elementwise_add(op, graph) + # if one of two vars comes from input, + # then the two vars in this elementwise op should be all fixed + if inp1.inputs() and inp2.inputs(): + pre_fixed_op_1, pre_fixed_op_2 = [], [] + pre_fixed_op_1 = _find_weight_ops(inp1.inputs()[0], graph, + pre_fixed_op_1) + pre_fixed_op_2 = _find_weight_ops(inp2.inputs()[0], graph, + pre_fixed_op_2) + if not pre_fixed_op_1: + fixed_by_input += pre_fixed_op_2 + if not pre_fixed_op_2: + fixed_by_input += pre_fixed_op_1 + elif (not inp1.inputs() and inp2.inputs()) or ( + inp1.inputs() and not inp2.inputs()): + pre_fixed_op = [] + inputs = inp1.inputs() if not inp2.inputs( + ) else inp2.inputs() + pre_fixed_op = _find_weight_ops(inputs[0], graph, + pre_fixed_op) + fixed_by_input += pre_fixed_op + + pre_ele_op = _find_pre_elementwise_op(op, graph) if pre_ele_op != None: same_search_space.append(pre_ele_op) @@ -108,7 +172,7 @@ def check_search_space(graph): depthwise_conv.append(inp._var.name) if len(same_search_space) == 0: - return None, [] + return None, [], [], output_conv same_search_space = sorted([sorted(x) for x in same_search_space]) final_search_space = [] @@ -129,8 +193,9 @@ def check_search_space(graph): final_search_space.append(l) final_search_space = sorted([sorted(x) for x in final_search_space]) depthwise_conv = sorted(depthwise_conv) + fixed_by_input = sorted(fixed_by_input) - return (final_search_space, depthwise_conv) + return (final_search_space, depthwise_conv, fixed_by_input, output_conv) def broadcast_search_space(same_search_space, param2key, origin_config): @@ -156,11 +221,15 @@ def broadcast_search_space(same_search_space, param2key, origin_config): 'channel': origin_config[pre_key]['channel'] }) else: - if 'expand_ratio' in origin_config[pre_key]: - origin_config[key] = { - 'expand_ratio': origin_config[pre_key]['expand_ratio'] - } - elif 'channel' in origin_config[pre_key]: - origin_config[key] = { - 'channel': origin_config[pre_key]['channel'] - } + # if the pre_key is removed from config for some reasons + # such as it is fixed by hand or by elementwise op + if pre_key in origin_config: + if 'expand_ratio' in origin_config[pre_key]: + origin_config[key] = { + 'expand_ratio': + origin_config[pre_key]['expand_ratio'] + } + elif 'channel' in origin_config[pre_key]: + origin_config[key] = { + 'channel': origin_config[pre_key]['channel'] + } diff --git a/paddleslim/nas/ofa/ofa.py b/paddleslim/nas/ofa/ofa.py index c4ea47a3..89810533 100644 --- a/paddleslim/nas/ofa/ofa.py +++ b/paddleslim/nas/ofa/ofa.py @@ -114,7 +114,7 @@ class OFABase(Layer): if block.fixed == False and (self._skip_layers == None or (self._skip_layers != None and self._key2name[block.key] not in self._skip_layers)) and \ - (block.fn.weight.name not in self._depthwise_conv): + (block.fn.weight.name not in self._cannot_changed_layer): assert self._key2name[ block. key] in self.current_config, 'DONNT have {} layer in config.'.format( @@ -183,7 +183,7 @@ class OFA(OFABase): self._build_ss = False self._broadcast = False self._skip_layers = None - self._depthwise_conv = [] + self._cannot_changed_layer = [] ### if elastic_order is none, use default order if self.elastic_order is not None: @@ -346,6 +346,7 @@ class OFA(OFABase): def _sample_from_nestdict(self, cands, sample_type, task, phase): sample_cands = dict() for k, v in cands.items(): + if isinstance(v, dict): sample_cands[k] = self._sample_from_nestdict( v, sample_type=sample_type, task=task, phase=phase) @@ -656,8 +657,9 @@ class OFA(OFABase): self.model, DataParallel) else self.model _st_prog = dygraph2program( model_to_traverse, inputs=input_shapes, dtypes=input_dtypes) - self._same_ss, self._depthwise_conv = check_search_space( + self._same_ss, depthwise_conv, fixed_by_input, output_conv = check_search_space( GraphWrapper(_st_prog)) + self._cannot_changed_layer = output_conv if self._same_ss != None: self._param2key = {} @@ -698,6 +700,15 @@ class OFA(OFABase): self._same_ss = tmp_same_ss + ### if fixed_by_input layer in a same ss, + ### layers in this same ss should all be fixed + tmp_fixed_by_input = [] + for ss in self._same_ss: + for key in fixed_by_input: + if key in ss: + tmp_fixed_by_input += ss + fixed_by_input += tmp_fixed_by_input + ### clear layer in ofa_layers set by skip layers if self._skip_layers != None: for skip_layer in self._skip_layers: @@ -708,13 +719,17 @@ class OFA(OFABase): for ss in per_ss[1:]: self._clear_width(self._param2key[ss]) - ### clear depthwise conv from search space because of its output channel cannot change - for name, sublayer in model_to_traverse.named_sublayers(): - if isinstance(sublayer, BaseBlock): - for param in sublayer.parameters(): - if param.name in self._depthwise_conv and name in self._ofa_layers.keys( - ): - self._clear_width(name) + self._cannot_changed_layer = sorted( + set(output_conv + fixed_by_input + depthwise_conv)) + ### clear depthwise convs from search space because of its output channel cannot change + ### clear output convs from search space because of model output shape cannot change + ### clear convs that operate with fixed input + for name, sublayer in model_to_traverse.named_sublayers(): + if isinstance(sublayer, BaseBlock): + for param in sublayer.parameters(): + if param.name in self._cannot_changed_layer and name in self._ofa_layers.keys( + ): + self._clear_width(name) def forward(self, *inputs, **kwargs): # ===================== teacher process ===================== @@ -751,6 +766,7 @@ class OFA(OFABase): self.current_config = self.net_config _logger.debug("Current config is {}".format(self.current_config)) + if 'depth' in self.current_config: kwargs['depth'] = self.current_config['depth'] if self._broadcast: diff --git a/tests/test_ofa.py b/tests/test_ofa.py index 2c109f87..68e2810a 100644 --- a/tests/test_ofa.py +++ b/tests/test_ofa.py @@ -446,7 +446,7 @@ class TestShortCut(unittest.TestCase): self.config, input_shapes=[[2, 3, 224, 224]], input_dtypes=['float32']) - assert len(self.ofa_model.ofa_layers) == 38 + assert len(self.ofa_model.ofa_layers) == 37 class TestExportCase1(unittest.TestCase): @@ -462,8 +462,8 @@ class TestExportCase1(unittest.TestCase): def test_export_model_linear1(self): ex_model = self.ofa_model.export( self.config, input_shapes=[[3, 64]], input_dtypes=['int64']) + assert len(self.ofa_model.ofa_layers) == 3 ex_model(self.data) - assert len(self.ofa_model.ofa_layers) == 4 class TestExportCase2(unittest.TestCase): @@ -482,7 +482,7 @@ class TestExportCase2(unittest.TestCase): ex_model = self.ofa_model.export( config, input_shapes=[[3, 64]], input_dtypes=['int64']) ex_model(self.data) - assert len(self.ofa_model.ofa_layers) == 4 + assert len(self.ofa_model.ofa_layers) == 3 if __name__ == '__main__': diff --git a/tests/test_ofa_v2.py b/tests/test_ofa_v2.py index 3b7eee3b..ae337505 100644 --- a/tests/test_ofa_v2.py +++ b/tests/test_ofa_v2.py @@ -81,6 +81,74 @@ class ModelShortcut(nn.Layer): return z +class ModelElementwise(nn.Layer): + def __init__(self): + super(ModelElementwise, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2D(3, 12, 1), nn.BatchNorm2D(12), nn.ReLU()) + self.conv2 = nn.Sequential( + nn.Conv2D(12, 24, 3), nn.BatchNorm2D(24), nn.ReLU()) + self.conv3 = nn.Sequential( + nn.Conv2D(24, 12, 1), nn.BatchNorm2D(12), nn.ReLU()) + self.out = nn.Sequential( + nn.Conv2D(12, 6, 1), nn.BatchNorm2D(6), nn.ReLU()) + + def forward(self, x): + d = paddle.randn(shape=[2, 12, x.shape[2], x.shape[3]], dtype='float32') + d = nn.functional.softmax(d) + + x = self.conv1(x) + x = x + d + x = self.conv2(x) + x = self.conv3(x) + x = self.out(x) + return x + + +class ModelMultiExit(nn.Layer): + def __init__(self): + super(ModelMultiExit, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2D(3, 12, 3), nn.BatchNorm2D(12), nn.ReLU()) + self.block1 = nn.Sequential( + nn.Conv2D(12, 24, 7), + nn.BatchNorm2D(24), + nn.ReLU(), + nn.MaxPool2D( + kernel_size=3, stride=2, padding=0), + nn.Conv2D(24, 24, 7), + nn.BatchNorm2D(24), + nn.ReLU(), + nn.MaxPool2D( + kernel_size=3, stride=2, padding=0)) + self.block2 = nn.Sequential( + nn.Conv2D(24, 24, 1), + nn.BatchNorm2D(24), + nn.ReLU(), + nn.MaxPool2D( + kernel_size=3, stride=2, padding=1)) + + self.out1 = nn.Sequential( + nn.Conv2D(24, 24, 1), nn.BatchNorm2D(24), nn.ReLU()) + + self.out2 = nn.Sequential( + nn.Conv2D(48, 24, 7), + nn.BatchNorm2D(24), + nn.ReLU(), nn.Conv2D(24, 24, 3), nn.BatchNorm2D(24), nn.ReLU()) + + def forward(self, x): + x = self.conv1(x) + + b1 = self.block1(x) + adapt = nn.UpsamplingBilinear2D(size=[b1.shape[2], b1.shape[2]]) + b2 = self.block2(b1) + up = adapt(b2) + y1 = self.out1(b1) + y2 = paddle.concat([b1, up], axis=1) + y2 = self.out2(y2) + return [y1, y2] + + class ModelInputDict(nn.Layer): def __init__(self): super(ModelInputDict, self).__init__() @@ -132,6 +200,37 @@ class TestOFAV2Export(unittest.TestCase): origin_model=origin_model) +class Testelementwise(unittest.TestCase): + def setUp(self): + model = ModelElementwise() + sp_net_config = supernet(expand_ratio=[0.25, 0.5, 1.0]) + self.model = Convert(sp_net_config).convert(model) + self.images = paddle.randn(shape=[2, 3, 32, 32], dtype='float32') + + def test_elementwise(self): + self.ofa_model = OFA(self.model) + self.ofa_model.set_epoch(0) + self.ofa_model.set_task('expand_ratio') + out, _ = self.ofa_model(self.images) + assert list(self.ofa_model._ofa_layers.keys()) == ['conv2.0', 'conv3.0'] + + +class TestMultiExit(unittest.TestCase): + def setUp(self): + self.images = paddle.randn(shape=[1, 3, 224, 224], dtype='float32') + model = ModelMultiExit() + sp_net_config = supernet(expand_ratio=[0.25, 0.5, 1.0]) + self.model = Convert(sp_net_config).convert(model) + + def test_multiexit(self): + self.ofa_model = OFA(self.model) + self.ofa_model.set_epoch(0) + self.ofa_model.set_task('expand_ratio') + out, _ = self.ofa_model(self.images) + assert list(self.ofa_model._ofa_layers.keys( + )) == ['conv1.0', 'block1.0', 'block1.4', 'block2.0', 'out2.0'] + + class TestShortcutSkiplayers(unittest.TestCase): def setUp(self): model = ModelShortcut() @@ -151,7 +250,7 @@ class TestShortcutSkiplayers(unittest.TestCase): self.ofa_model.set_task('expand_ratio') for i in range(5): self.ofa_model(self.images) - assert list(self.ofa_model._ofa_layers.keys()) == ['branch2.0', 'out.0'] + assert list(self.ofa_model._ofa_layers.keys()) == ['branch2.0'] class TestShortcutSkiplayersCase1(TestShortcutSkiplayers): @@ -166,7 +265,36 @@ class TestShortcutSkiplayersCase2(TestShortcutSkiplayers): self.run_config = RunConfig(**default_run_config) def test_shortcut(self): - assert list(self.ofa_model._ofa_layers.keys()) == ['conv1.0', 'out.0'] + assert list(self.ofa_model._ofa_layers.keys()) == ['conv1.0'] + + +class TestInputDict(unittest.TestCase): + def setUp(self): + model = ModelInputDict() + + sp_net_config = supernet(expand_ratio=[0.5, 1.0]) + self.model = Convert(sp_net_config).convert(model) + self.images = paddle.randn(shape=[2, 3, 32, 32], dtype='float32') + self.images2 = { + 'data': paddle.randn( + shape=[2, 12, 32, 32], dtype='float32') + } + default_run_config = {'skip_layers': ['conv1.0', 'conv2.0']} + self.run_config = RunConfig(**default_run_config) + + self.ofa_model = OFA(self.model, run_config=self.run_config) + self.ofa_model._clear_search_space(self.images, data=self.images2) + + def test_export(self): + + config = self.ofa_model._sample_config( + task="expand_ratio", sample_type="smallest") + self.ofa_model.export( + config, + input_shapes=[[1, 3, 32, 32], { + 'data': [1, 12, 32, 32] + }], + input_dtypes=['float32', 'float32']) class TestInputDict(unittest.TestCase): -- GitLab