Unverified commit 7411b0a5, authored by Chang Xu, committed by GitHub

Update check ss (#817)

* update_check_ss

* update_check_ss

* update_check_ss

* ss_multi_output

* remove_some_comments

* Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleSlim into update_check_ss

* merge_conflicts

* add_coverage
Parent commit: e274d7fb
@@ -57,6 +57,15 @@ class VarWrapper(object):
     def __repr__(self):
         return self._var.name

+    def __lt__(self, other):
+        return self._var.name < other._var.name
+
+    def __gt__(self, other):
+        return self._var.name > other._var.name
+
+    def __eq__(self, other):
+        return self._var.name == other._var.name
+
     def shape(self):
         """
         Get the shape of the varibale.
@@ -144,6 +153,15 @@ class OpWrapper(object):
             self.type(),
             self.all_inputs())

+    def __lt__(self, other):
+        return self._op.idx < other._op.idx
+
+    def __gt__(self, other):
+        return self._op.idx > other._op.idx
+
+    def __eq__(self, other):
+        return self._op.idx == other._op.idx
+
     def is_bwd_op(self):
         """
         Whether this operator is backward op.
...
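These comparison operators make VarWrapper and OpWrapper instances orderable, so the search-space checks below can call sorted() on graph ops and vars and traverse them in a deterministic order. A minimal sketch of the idea, using a made-up _NamedVar class rather than the real wrappers:

# Minimal sketch: once a class defines __lt__ (and __eq__), Python's sorted()
# can order its instances, so traversal no longer depends on dict/set
# iteration order.  _NamedVar is a hypothetical stand-in for VarWrapper.
class _NamedVar:
    def __init__(self, name):
        self.name = name

    def __lt__(self, other):
        return self.name < other.name

    def __eq__(self, other):
        return self.name == other.name

    def __repr__(self):
        return self.name


vars_in_arbitrary_order = [_NamedVar('conv2_w'), _NamedVar('conv0_w'), _NamedVar('conv1_w')]
print(sorted(vars_in_arbitrary_order))  # [conv0_w, conv1_w, conv2_w]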
@@ -19,14 +19,37 @@ from .layers_base import BaseBlock

 __all__ = ['check_search_space']

-WEIGHT_OP = [
-    'conv2d', 'linear', 'embedding', 'conv2d_transpose', 'depthwise_conv2d'
+DYNAMIC_WEIGHT_OP = [
+    'conv2d', 'mul', 'matmul', 'embedding', 'conv2d_transpose',
+    'depthwise_conv2d'
 ]

 CONV_TYPES = [
     'conv2d', 'conv3d', 'conv1d', 'superconv2d', 'supergroupconv2d',
     'superdepthwiseconv2d'
 ]

+ALL_WEIGHT_OP = [
+    'conv2d', 'mul', 'matmul', 'elementwise_add', 'embedding',
+    'conv2d_transpose', 'depthwise_conv2d', 'batch_norm', 'layer_norm',
+    'instance_norm', 'sync_batch_norm'
+]
+
+
+def _is_dynamic_weight_op(op, all_weight_op=False):
+    if all_weight_op == True:
+        weight_ops = ALL_WEIGHT_OP
+    else:
+        weight_ops = DYNAMIC_WEIGHT_OP
+    if op.type() in weight_ops:
+        if op.type() in ['mul', 'matmul']:
+            for inp in sorted(op.all_inputs()):
+                if inp._var.persistable == True:
+                    return True
+            return False
+        return True
+    return False
+
+
 def get_actual_shape(transform, channel):
     if transform == None:
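DYNAMIC_WEIGHT_OP now lists mul/matmul, but only weight-bearing ones should count: a matmul between two activations (for example attention scores) has no parameter to prune. A small self-contained sketch of that rule, with hypothetical _FakeVar/_FakeOp stand-ins instead of PaddleSlim's wrappers:

# Illustrative only: _FakeVar/_FakeOp are made-up stand-ins for VarWrapper/OpWrapper.
DYNAMIC_WEIGHT_OP = [
    'conv2d', 'mul', 'matmul', 'embedding', 'conv2d_transpose',
    'depthwise_conv2d'
]


class _FakeVar:
    def __init__(self, name, persistable):
        self.name = name
        self.persistable = persistable


class _FakeOp:
    def __init__(self, op_type, inputs):
        self._type = op_type
        self._inputs = inputs

    def type(self):
        return self._type

    def all_inputs(self):
        return self._inputs


def looks_like_dynamic_weight_op(op):
    """Simplified restatement of the check in the diff above."""
    if op.type() not in DYNAMIC_WEIGHT_OP:
        return False
    if op.type() in ['mul', 'matmul']:
        # Only weight-bearing matmuls qualify (e.g. a Linear layer);
        # an activation-by-activation matmul does not.
        return any(inp.persistable for inp in op.all_inputs())
    return True


linear_matmul = _FakeOp('matmul', [_FakeVar('x', False), _FakeVar('fc_w', True)])
attn_matmul = _FakeOp('matmul', [_FakeVar('q', False), _FakeVar('k', False)])
print(looks_like_dynamic_weight_op(linear_matmul))  # True
print(looks_like_dynamic_weight_op(attn_matmul))    # False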
@@ -58,7 +81,7 @@ def _is_depthwise(op):
 def _find_weight_ops(op, graph, weights):
     """ Find the vars come from operators with weight.
     """
-    pre_ops = graph.pre_ops(op)
+    pre_ops = sorted(graph.pre_ops(op))
     for pre_op in pre_ops:
         ### if depthwise conv is one of elementwise's input,
         ### add it into this same search space
@@ -67,7 +90,7 @@ def _find_weight_ops(op, graph, weights):
             if inp._var.persistable:
                 weights.append(inp._var.name)

-        if pre_op.type() in WEIGHT_OP and not _is_depthwise(pre_op):
+        if _is_dynamic_weight_op(pre_op) and not _is_depthwise(pre_op):
             for inp in pre_op.all_inputs():
                 if inp._var.persistable:
                     weights.append(inp._var.name)
@@ -76,29 +99,70 @@ def _find_weight_ops(op, graph, weights):
     return weights


-def _find_pre_elementwise_add(op, graph):
+def _find_pre_elementwise_op(op, graph):
     """ Find precedors of the elementwise_add operator in the model.
     """
     same_prune_before_elementwise_add = []
-    pre_ops = graph.pre_ops(op)
+    pre_ops = sorted(graph.pre_ops(op))
     for pre_op in pre_ops:
-        if pre_op.type() in WEIGHT_OP:
+        if _is_dynamic_weight_op(pre_op):
             return
         same_prune_before_elementwise_add = _find_weight_ops(
             pre_op, graph, same_prune_before_elementwise_add)
     return same_prune_before_elementwise_add


+def _is_output_weight_ops(op, graph):
+    next_ops = sorted(graph.next_ops(op))
+    for next_op in next_ops:
+        if op == next_op:
+            continue
+        if _is_dynamic_weight_op(next_op):
+            return False
+        return _is_output_weight_ops(next_op, graph)
+    return True
+
+
 def check_search_space(graph):
     """ Find the shortcut in the model and set same config for this situation.
     """
+    output_conv = []
     same_search_space = []
     depthwise_conv = []
+    fixed_by_input = []
     for op in graph.ops():
+        # if there is no weight ops after this op,
+        # this op can be seen as an output
+        if _is_output_weight_ops(op, graph) and _is_dynamic_weight_op(op):
+            for inp in op.all_inputs():
+                if inp._var.persistable:
+                    output_conv.append(inp._var.name)
+
         if op.type() == 'elementwise_add' or op.type() == 'elementwise_mul':
             inp1, inp2 = op.all_inputs()[0], op.all_inputs()[1]
             if (not inp1._var.persistable) and (not inp2._var.persistable):
-                pre_ele_op = _find_pre_elementwise_add(op, graph)
+                # if one of two vars comes from input,
+                # then the two vars in this elementwise op should be all fixed
+                if inp1.inputs() and inp2.inputs():
+                    pre_fixed_op_1, pre_fixed_op_2 = [], []
+                    pre_fixed_op_1 = _find_weight_ops(inp1.inputs()[0], graph,
+                                                      pre_fixed_op_1)
+                    pre_fixed_op_2 = _find_weight_ops(inp2.inputs()[0], graph,
+                                                      pre_fixed_op_2)
+                    if not pre_fixed_op_1:
+                        fixed_by_input += pre_fixed_op_2
+                    if not pre_fixed_op_2:
+                        fixed_by_input += pre_fixed_op_1
+                elif (not inp1.inputs() and inp2.inputs()) or (
+                        inp1.inputs() and not inp2.inputs()):
+                    pre_fixed_op = []
+                    inputs = inp1.inputs() if not inp2.inputs(
+                    ) else inp2.inputs()
+                    pre_fixed_op = _find_weight_ops(inputs[0], graph,
+                                                    pre_fixed_op)
+                    fixed_by_input += pre_fixed_op
+
+                pre_ele_op = _find_pre_elementwise_op(op, graph)
                 if pre_ele_op != None:
                     same_search_space.append(pre_ele_op)
@@ -108,7 +172,7 @@ def check_search_space(graph):
                     depthwise_conv.append(inp._var.name)

     if len(same_search_space) == 0:
-        return None, []
+        return None, [], [], output_conv

     same_search_space = sorted([sorted(x) for x in same_search_space])
     final_search_space = []
@@ -129,8 +193,9 @@ def check_search_space(graph):
             final_search_space.append(l)
     final_search_space = sorted([sorted(x) for x in final_search_space])
     depthwise_conv = sorted(depthwise_conv)
+    fixed_by_input = sorted(fixed_by_input)

-    return (final_search_space, depthwise_conv)
+    return (final_search_space, depthwise_conv, fixed_by_input, output_conv)


 def broadcast_search_space(same_search_space, param2key, origin_config):
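check_search_space now returns four lists instead of two. The sketch below uses made-up parameter names to show what each list means and how they are combined downstream (the union mirrors the _cannot_changed_layer computation in ofa.py further down):

# Toy values standing in for what check_search_space would return.
same_ss = [['conv2d_0.w_0', 'conv2d_1.w_0']]    # groups that must share one config
depthwise_conv = ['depthwise_conv2d_0.w_0']     # output channel cannot change
fixed_by_input = ['conv2d_3.w_0']               # operates with a raw input tensor
output_conv = ['conv2d_5.w_0']                  # last weight layer, output shape fixed

cannot_changed_layer = sorted(set(output_conv + fixed_by_input + depthwise_conv))
print(cannot_changed_layer)
# ['conv2d_3.w_0', 'conv2d_5.w_0', 'depthwise_conv2d_0.w_0']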
@@ -156,11 +221,15 @@ def broadcast_search_space(same_search_space, param2key, origin_config):
                 'channel': origin_config[pre_key]['channel']
             })
         else:
-            if 'expand_ratio' in origin_config[pre_key]:
-                origin_config[key] = {
-                    'expand_ratio': origin_config[pre_key]['expand_ratio']
-                }
-            elif 'channel' in origin_config[pre_key]:
-                origin_config[key] = {
-                    'channel': origin_config[pre_key]['channel']
-                }
+            # if the pre_key is removed from config for some reasons
+            # such as it is fixed by hand or by elementwise op
+            if pre_key in origin_config:
+                if 'expand_ratio' in origin_config[pre_key]:
+                    origin_config[key] = {
+                        'expand_ratio':
+                        origin_config[pre_key]['expand_ratio']
+                    }
+                elif 'channel' in origin_config[pre_key]:
+                    origin_config[key] = {
+                        'channel': origin_config[pre_key]['channel']
+                    }
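The new `pre_key in origin_config` guard keeps broadcasting from raising a KeyError when the leading layer of a shared group has already been dropped from the config (fixed by hand, fixed by an elementwise op on a raw input, or treated as an output conv). A simplified toy illustration of the guard, not the full broadcast_search_space logic:

# The group says conv_b must copy conv_a's setting, but conv_a was already
# removed from the config, so nothing is copied and no KeyError is raised.
same_search_space = [['conv_a', 'conv_b']]
param2key = {'conv_a': 'conv_a', 'conv_b': 'conv_b'}
origin_config = {'conv_b': {'expand_ratio': 0.5}}   # 'conv_a' is absent on purpose

for group in same_search_space:
    pre_key = param2key[group[0]]
    for param in group[1:]:
        key = param2key[param]
        if pre_key in origin_config:                 # the guard added in this commit
            if 'expand_ratio' in origin_config[pre_key]:
                origin_config[key] = {
                    'expand_ratio': origin_config[pre_key]['expand_ratio']
                }
            elif 'channel' in origin_config[pre_key]:
                origin_config[key] = {'channel': origin_config[pre_key]['channel']}

print(origin_config)  # {'conv_b': {'expand_ratio': 0.5}} is left untouched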
@@ -114,7 +114,7 @@ class OFABase(Layer):
             if block.fixed == False and (self._skip_layers == None or
                                          (self._skip_layers != None and
                 self._key2name[block.key] not in self._skip_layers)) and \
-                (block.fn.weight.name not in self._depthwise_conv):
+                (block.fn.weight.name not in self._cannot_changed_layer):
                 assert self._key2name[
                     block.
                     key] in self.current_config, 'DONNT have {} layer in config.'.format(
@@ -183,7 +183,7 @@ class OFA(OFABase):
         self._build_ss = False
         self._broadcast = False
         self._skip_layers = None
-        self._depthwise_conv = []
+        self._cannot_changed_layer = []

         ### if elastic_order is none, use default order
         if self.elastic_order is not None:
@@ -346,6 +346,7 @@ class OFA(OFABase):
     def _sample_from_nestdict(self, cands, sample_type, task, phase):
         sample_cands = dict()
         for k, v in cands.items():
             if isinstance(v, dict):
                 sample_cands[k] = self._sample_from_nestdict(
                     v, sample_type=sample_type, task=task, phase=phase)
@@ -656,8 +657,9 @@ class OFA(OFABase):
             self.model, DataParallel) else self.model
         _st_prog = dygraph2program(
             model_to_traverse, inputs=input_shapes, dtypes=input_dtypes)
-        self._same_ss, self._depthwise_conv = check_search_space(
-            GraphWrapper(_st_prog))
+        self._same_ss, depthwise_conv, fixed_by_input, output_conv = check_search_space(
+            GraphWrapper(_st_prog))
+        self._cannot_changed_layer = output_conv

         if self._same_ss != None:
             self._param2key = {}
@@ -698,6 +700,15 @@ class OFA(OFABase):

             self._same_ss = tmp_same_ss

+            ### if fixed_by_input layer in a same ss,
+            ### layers in this same ss should all be fixed
+            tmp_fixed_by_input = []
+            for ss in self._same_ss:
+                for key in fixed_by_input:
+                    if key in ss:
+                        tmp_fixed_by_input += ss
+            fixed_by_input += tmp_fixed_by_input
+
             ### clear layer in ofa_layers set by skip layers
             if self._skip_layers != None:
                 for skip_layer in self._skip_layers:
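The loop above spreads the fixed_by_input property across a whole shared search-space group: if one member is fixed, every member is. A toy run of the same loop on plain lists:

# If any member of a group is fixed by the input, the whole group becomes fixed.
same_ss = [['conv_a_w', 'conv_b_w'], ['conv_c_w', 'conv_d_w']]
fixed_by_input = ['conv_b_w']

tmp_fixed_by_input = []
for ss in same_ss:
    for key in fixed_by_input:
        if key in ss:
            tmp_fixed_by_input += ss
fixed_by_input += tmp_fixed_by_input

print(sorted(set(fixed_by_input)))  # ['conv_a_w', 'conv_b_w']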
@@ -708,13 +719,17 @@ class OFA(OFABase):
                     for ss in per_ss[1:]:
                         self._clear_width(self._param2key[ss])

-        ### clear depthwise conv from search space because of its output channel cannot change
-        for name, sublayer in model_to_traverse.named_sublayers():
-            if isinstance(sublayer, BaseBlock):
-                for param in sublayer.parameters():
-                    if param.name in self._depthwise_conv and name in self._ofa_layers.keys(
-                    ):
-                        self._clear_width(name)
+        self._cannot_changed_layer = sorted(
+            set(output_conv + fixed_by_input + depthwise_conv))
+        ### clear depthwise convs from search space because of its output channel cannot change
+        ### clear output convs from search space because of model output shape cannot change
+        ### clear convs that operate with fixed input
+        for name, sublayer in model_to_traverse.named_sublayers():
+            if isinstance(sublayer, BaseBlock):
+                for param in sublayer.parameters():
+                    if param.name in self._cannot_changed_layer and name in self._ofa_layers.keys(
+                    ):
+                        self._clear_width(name)

     def forward(self, *inputs, **kwargs):
         # ===================== teacher process =====================
@@ -751,6 +766,7 @@ class OFA(OFABase):
         self.current_config = self.net_config
         _logger.debug("Current config is {}".format(self.current_config))
         if 'depth' in self.current_config:
             kwargs['depth'] = self.current_config['depth']

         if self._broadcast:
...
@@ -446,7 +446,7 @@ class TestShortCut(unittest.TestCase):
             self.config,
             input_shapes=[[2, 3, 224, 224]],
             input_dtypes=['float32'])
-        assert len(self.ofa_model.ofa_layers) == 38
+        assert len(self.ofa_model.ofa_layers) == 37


 class TestExportCase1(unittest.TestCase):
@@ -462,8 +462,8 @@ class TestExportCase1(unittest.TestCase):
     def test_export_model_linear1(self):
         ex_model = self.ofa_model.export(
             self.config, input_shapes=[[3, 64]], input_dtypes=['int64'])
+        assert len(self.ofa_model.ofa_layers) == 3
         ex_model(self.data)
-        assert len(self.ofa_model.ofa_layers) == 4


 class TestExportCase2(unittest.TestCase):
@@ -482,7 +482,7 @@ class TestExportCase2(unittest.TestCase):
         ex_model = self.ofa_model.export(
             config, input_shapes=[[3, 64]], input_dtypes=['int64'])
         ex_model(self.data)
-        assert len(self.ofa_model.ofa_layers) == 4
+        assert len(self.ofa_model.ofa_layers) == 3


 if __name__ == '__main__':
...
@@ -81,6 +81,74 @@ class ModelShortcut(nn.Layer):
         return z


+class ModelElementwise(nn.Layer):
+    def __init__(self):
+        super(ModelElementwise, self).__init__()
+        self.conv1 = nn.Sequential(
+            nn.Conv2D(3, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
+        self.conv2 = nn.Sequential(
+            nn.Conv2D(12, 24, 3), nn.BatchNorm2D(24), nn.ReLU())
+        self.conv3 = nn.Sequential(
+            nn.Conv2D(24, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
+        self.out = nn.Sequential(
+            nn.Conv2D(12, 6, 1), nn.BatchNorm2D(6), nn.ReLU())
+
+    def forward(self, x):
+        d = paddle.randn(shape=[2, 12, x.shape[2], x.shape[3]], dtype='float32')
+        d = nn.functional.softmax(d)
+        x = self.conv1(x)
+        x = x + d
+        x = self.conv2(x)
+        x = self.conv3(x)
+        x = self.out(x)
+        return x
+
+
+class ModelMultiExit(nn.Layer):
+    def __init__(self):
+        super(ModelMultiExit, self).__init__()
+        self.conv1 = nn.Sequential(
+            nn.Conv2D(3, 12, 3), nn.BatchNorm2D(12), nn.ReLU())
+        self.block1 = nn.Sequential(
+            nn.Conv2D(12, 24, 7),
+            nn.BatchNorm2D(24),
+            nn.ReLU(),
+            nn.MaxPool2D(
+                kernel_size=3, stride=2, padding=0),
+            nn.Conv2D(24, 24, 7),
+            nn.BatchNorm2D(24),
+            nn.ReLU(),
+            nn.MaxPool2D(
+                kernel_size=3, stride=2, padding=0))
+        self.block2 = nn.Sequential(
+            nn.Conv2D(24, 24, 1),
+            nn.BatchNorm2D(24),
+            nn.ReLU(),
+            nn.MaxPool2D(
+                kernel_size=3, stride=2, padding=1))
+        self.out1 = nn.Sequential(
+            nn.Conv2D(24, 24, 1), nn.BatchNorm2D(24), nn.ReLU())
+        self.out2 = nn.Sequential(
+            nn.Conv2D(48, 24, 7),
+            nn.BatchNorm2D(24),
+            nn.ReLU(), nn.Conv2D(24, 24, 3), nn.BatchNorm2D(24), nn.ReLU())
+
+    def forward(self, x):
+        x = self.conv1(x)
+        b1 = self.block1(x)
+        adapt = nn.UpsamplingBilinear2D(size=[b1.shape[2], b1.shape[2]])
+        b2 = self.block2(b1)
+        up = adapt(b2)
+        y1 = self.out1(b1)
+        y2 = paddle.concat([b1, up], axis=1)
+        y2 = self.out2(y2)
+        return [y1, y2]
+
+
 class ModelInputDict(nn.Layer):
     def __init__(self):
         super(ModelInputDict, self).__init__()
@@ -132,6 +200,37 @@ class TestOFAV2Export(unittest.TestCase):
             origin_model=origin_model)


+class Testelementwise(unittest.TestCase):
+    def setUp(self):
+        model = ModelElementwise()
+        sp_net_config = supernet(expand_ratio=[0.25, 0.5, 1.0])
+        self.model = Convert(sp_net_config).convert(model)
+        self.images = paddle.randn(shape=[2, 3, 32, 32], dtype='float32')
+
+    def test_elementwise(self):
+        self.ofa_model = OFA(self.model)
+        self.ofa_model.set_epoch(0)
+        self.ofa_model.set_task('expand_ratio')
+        out, _ = self.ofa_model(self.images)
+        assert list(self.ofa_model._ofa_layers.keys()) == ['conv2.0', 'conv3.0']
+
+
+class TestMultiExit(unittest.TestCase):
+    def setUp(self):
+        self.images = paddle.randn(shape=[1, 3, 224, 224], dtype='float32')
+        model = ModelMultiExit()
+        sp_net_config = supernet(expand_ratio=[0.25, 0.5, 1.0])
+        self.model = Convert(sp_net_config).convert(model)
+
+    def test_multiexit(self):
+        self.ofa_model = OFA(self.model)
+        self.ofa_model.set_epoch(0)
+        self.ofa_model.set_task('expand_ratio')
+        out, _ = self.ofa_model(self.images)
+        assert list(self.ofa_model._ofa_layers.keys(
+        )) == ['conv1.0', 'block1.0', 'block1.4', 'block2.0', 'out2.0']
+
+
 class TestShortcutSkiplayers(unittest.TestCase):
     def setUp(self):
         model = ModelShortcut()
@@ -151,7 +250,7 @@ class TestShortcutSkiplayers(unittest.TestCase):
         self.ofa_model.set_task('expand_ratio')
         for i in range(5):
             self.ofa_model(self.images)
-        assert list(self.ofa_model._ofa_layers.keys()) == ['branch2.0', 'out.0']
+        assert list(self.ofa_model._ofa_layers.keys()) == ['branch2.0']


 class TestShortcutSkiplayersCase1(TestShortcutSkiplayers):
@@ -166,7 +265,36 @@ class TestShortcutSkiplayersCase2(TestShortcutSkiplayers):
         self.run_config = RunConfig(**default_run_config)

     def test_shortcut(self):
-        assert list(self.ofa_model._ofa_layers.keys()) == ['conv1.0', 'out.0']
+        assert list(self.ofa_model._ofa_layers.keys()) == ['conv1.0']
+
+
+class TestInputDict(unittest.TestCase):
+    def setUp(self):
+        model = ModelInputDict()
+        sp_net_config = supernet(expand_ratio=[0.5, 1.0])
+        self.model = Convert(sp_net_config).convert(model)
+        self.images = paddle.randn(shape=[2, 3, 32, 32], dtype='float32')
+        self.images2 = {
+            'data': paddle.randn(
+                shape=[2, 12, 32, 32], dtype='float32')
+        }
+        default_run_config = {'skip_layers': ['conv1.0', 'conv2.0']}
+        self.run_config = RunConfig(**default_run_config)
+        self.ofa_model = OFA(self.model, run_config=self.run_config)
+        self.ofa_model._clear_search_space(self.images, data=self.images2)
+
+    def test_export(self):
+        config = self.ofa_model._sample_config(
+            task="expand_ratio", sample_type="smallest")
+        self.ofa_model.export(
+            config,
+            input_shapes=[[1, 3, 32, 32], {
+                'data': [1, 12, 32, 32]
+            }],
+            input_dtypes=['float32', 'float32'])


 class TestInputDict(unittest.TestCase):
...
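Taken together, the new tests pin down the expected behaviour: a layer whose output is combined elementwise with a raw (non-parameter) tensor and the layer that produces the model output are both removed from the searchable set. A usage sketch that mirrors the Testelementwise case above, assuming the ModelElementwise class from this test file is available in scope:

# Sketch only: conv1 feeds an add with a non-parameter tensor (fixed_by_input)
# and out.0 is the last conv (output_conv), so neither stays searchable.
import paddle
from paddleslim.nas.ofa import OFA
from paddleslim.nas.ofa.convert_super import supernet, Convert

model = ModelElementwise()  # test network defined in this file
sp_model = Convert(supernet(expand_ratio=[0.25, 0.5, 1.0])).convert(model)
ofa_model = OFA(sp_model)
ofa_model.set_epoch(0)
ofa_model.set_task('expand_ratio')
out, _ = ofa_model(paddle.randn(shape=[2, 3, 32, 32], dtype='float32'))
print(list(ofa_model._ofa_layers.keys()))  # expected: ['conv2.0', 'conv3.0']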