Unverified commit 7411b0a5 authored by Chang Xu, committed by GitHub

Update check ss (#817)

* update_check_ss

* update_check_ss

* update_check_ss

* ss_multi_output

* remove_some_comments

* Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleSlim into update_check_ss

* merge_conflicts

* add_coverage
Parent e274d7fb
......@@ -57,6 +57,15 @@ class VarWrapper(object):
def __repr__(self):
return self._var.name
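# comparison operators based on the variable name, so that VarWrappers
# can be sorted deterministically during graph traversal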
def __lt__(self, other):
return self._var.name < other._var.name
def __gt__(self, other):
return self._var.name > other._var.name
def __eq__(self, other):
return self._var.name == other._var.name
def shape(self):
"""
Get the shape of the variable.
......@@ -144,6 +153,15 @@ class OpWrapper(object):
self.type(),
self.all_inputs())
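# order OpWrappers by their index in the program, so that the results of
# graph.pre_ops()/next_ops() can be sorted deterministically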
def __lt__(self, other):
return self._op.idx < other._op.idx
def __gt__(self, other):
return self._op.idx > other._op.idx
def __eq__(self, other):
return self._op.idx == other._op.idx
def is_bwd_op(self):
"""
Whether this operator is a backward op.
......
......@@ -19,14 +19,37 @@ from .layers_base import BaseBlock
__all__ = ['check_search_space']
WEIGHT_OP = [
'conv2d', 'linear', 'embedding', 'conv2d_transpose', 'depthwise_conv2d'
DYNAMIC_WEIGHT_OP = [
'conv2d', 'mul', 'matmul', 'embedding', 'conv2d_transpose',
'depthwise_conv2d'
]
CONV_TYPES = [
'conv2d', 'conv3d', 'conv1d', 'superconv2d', 'supergroupconv2d',
'superdepthwiseconv2d'
]
ALL_WEIGHT_OP = [
'conv2d', 'mul', 'matmul', 'elementwise_add', 'embedding',
'conv2d_transpose', 'depthwise_conv2d', 'batch_norm', 'layer_norm',
'instance_norm', 'sync_batch_norm'
]
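# an op is a dynamic-weight op when its type carries learnable weights;
# for 'mul'/'matmul' this only holds if one of its inputs is a persistable parameter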
def _is_dynamic_weight_op(op, all_weight_op=False):
if all_weight_op == True:
weight_ops = ALL_WEIGHT_OP
else:
weight_ops = DYNAMIC_WEIGHT_OP
if op.type() in weight_ops:
if op.type() in ['mul', 'matmul']:
for inp in sorted(op.all_inputs()):
if inp._var.persistable == True:
return True
return False
return True
return False
def get_actual_shape(transform, channel):
if transform == None:
......@@ -58,7 +81,7 @@ def _is_depthwise(op):
def _find_weight_ops(op, graph, weights):
""" Find the vars come from operators with weight.
"""
pre_ops = graph.pre_ops(op)
pre_ops = sorted(graph.pre_ops(op))
for pre_op in pre_ops:
### if depthwise conv is one of elementwise's input,
### add it into this same search space
......@@ -67,7 +90,7 @@ def _find_weight_ops(op, graph, weights):
if inp._var.persistable:
weights.append(inp._var.name)
if pre_op.type() in WEIGHT_OP and not _is_depthwise(pre_op):
if _is_dynamic_weight_op(pre_op) and not _is_depthwise(pre_op):
for inp in pre_op.all_inputs():
if inp._var.persistable:
weights.append(inp._var.name)
......@@ -76,29 +99,70 @@ def _find_weight_ops(op, graph, weights):
return weights
def _find_pre_elementwise_add(op, graph):
def _find_pre_elementwise_op(op, graph):
""" Find precedors of the elementwise_add operator in the model.
"""
same_prune_before_elementwise_add = []
pre_ops = graph.pre_ops(op)
pre_ops = sorted(graph.pre_ops(op))
for pre_op in pre_ops:
if pre_op.type() in WEIGHT_OP:
if _is_dynamic_weight_op(pre_op):
return
same_prune_before_elementwise_add = _find_weight_ops(
pre_op, graph, same_prune_before_elementwise_add)
return same_prune_before_elementwise_add
def _is_output_weight_ops(op, graph):
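# return True if no dynamic-weight op appears after `op` in the graph,
# i.e. `op` can be regarded as feeding a model output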
next_ops = sorted(graph.next_ops(op))
for next_op in next_ops:
if op == next_op:
continue
if _is_dynamic_weight_op(next_op):
return False
return _is_output_weight_ops(next_op, graph)
return True
def check_search_space(graph):
""" Find the shortcut in the model and set same config for this situation.
"""
output_conv = []
same_search_space = []
depthwise_conv = []
fixed_by_input = []
for op in graph.ops():
# if there are no weight ops after this op,
# it can be treated as an output
if _is_output_weight_ops(op, graph) and _is_dynamic_weight_op(op):
for inp in op.all_inputs():
if inp._var.persistable:
output_conv.append(inp._var.name)
if op.type() == 'elementwise_add' or op.type() == 'elementwise_mul':
inp1, inp2 = op.all_inputs()[0], op.all_inputs()[1]
if (not inp1._var.persistable) and (not inp2._var.persistable):
pre_ele_op = _find_pre_elementwise_add(op, graph)
# if one of the two vars comes from the model input,
# then both vars in this elementwise op should be fixed
if inp1.inputs() and inp2.inputs():
pre_fixed_op_1, pre_fixed_op_2 = [], []
pre_fixed_op_1 = _find_weight_ops(inp1.inputs()[0], graph,
pre_fixed_op_1)
pre_fixed_op_2 = _find_weight_ops(inp2.inputs()[0], graph,
pre_fixed_op_2)
if not pre_fixed_op_1:
fixed_by_input += pre_fixed_op_2
if not pre_fixed_op_2:
fixed_by_input += pre_fixed_op_1
elif (not inp1.inputs() and inp2.inputs()) or (
inp1.inputs() and not inp2.inputs()):
pre_fixed_op = []
inputs = inp1.inputs() if not inp2.inputs(
) else inp2.inputs()
pre_fixed_op = _find_weight_ops(inputs[0], graph,
pre_fixed_op)
fixed_by_input += pre_fixed_op
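# weight ops feeding the same elementwise op must be pruned together,
# so collect them into one search-space group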
pre_ele_op = _find_pre_elementwise_op(op, graph)
if pre_ele_op != None:
same_search_space.append(pre_ele_op)
......@@ -108,7 +172,7 @@ def check_search_space(graph):
depthwise_conv.append(inp._var.name)
if len(same_search_space) == 0:
return None, []
return None, [], [], output_conv
same_search_space = sorted([sorted(x) for x in same_search_space])
final_search_space = []
......@@ -129,8 +193,9 @@ def check_search_space(graph):
final_search_space.append(l)
final_search_space = sorted([sorted(x) for x in final_search_space])
depthwise_conv = sorted(depthwise_conv)
fixed_by_input = sorted(fixed_by_input)
return (final_search_space, depthwise_conv)
return (final_search_space, depthwise_conv, fixed_by_input, output_conv)
def broadcast_search_space(same_search_space, param2key, origin_config):
......@@ -156,11 +221,15 @@ def broadcast_search_space(same_search_space, param2key, origin_config):
'channel': origin_config[pre_key]['channel']
})
else:
if 'expand_ratio' in origin_config[pre_key]:
origin_config[key] = {
'expand_ratio': origin_config[pre_key]['expand_ratio']
}
elif 'channel' in origin_config[pre_key]:
origin_config[key] = {
'channel': origin_config[pre_key]['channel']
}
# if pre_key has been removed from the config for some reason,
# e.g. it was fixed by hand or by an elementwise op
if pre_key in origin_config:
if 'expand_ratio' in origin_config[pre_key]:
origin_config[key] = {
'expand_ratio':
origin_config[pre_key]['expand_ratio']
}
elif 'channel' in origin_config[pre_key]:
origin_config[key] = {
'channel': origin_config[pre_key]['channel']
}
......@@ -114,7 +114,7 @@ class OFABase(Layer):
if block.fixed == False and (self._skip_layers == None or
(self._skip_layers != None and
self._key2name[block.key] not in self._skip_layers)) and \
(block.fn.weight.name not in self._depthwise_conv):
(block.fn.weight.name not in self._cannot_changed_layer):
assert self._key2name[
block.
key] in self.current_config, 'DONNT have {} layer in config.'.format(
......@@ -183,7 +183,7 @@ class OFA(OFABase):
self._build_ss = False
self._broadcast = False
self._skip_layers = None
self._depthwise_conv = []
self._cannot_changed_layer = []
### if elastic_order is none, use default order
if self.elastic_order is not None:
......@@ -346,6 +346,7 @@ class OFA(OFABase):
def _sample_from_nestdict(self, cands, sample_type, task, phase):
sample_cands = dict()
for k, v in cands.items():
if isinstance(v, dict):
sample_cands[k] = self._sample_from_nestdict(
v, sample_type=sample_type, task=task, phase=phase)
......@@ -656,8 +657,9 @@ class OFA(OFABase):
self.model, DataParallel) else self.model
_st_prog = dygraph2program(
model_to_traverse, inputs=input_shapes, dtypes=input_dtypes)
self._same_ss, self._depthwise_conv = check_search_space(
self._same_ss, depthwise_conv, fixed_by_input, output_conv = check_search_space(
GraphWrapper(_st_prog))
self._cannot_changed_layer = output_conv
if self._same_ss != None:
self._param2key = {}
......@@ -698,6 +700,15 @@ class OFA(OFABase):
self._same_ss = tmp_same_ss
### if a fixed_by_input layer belongs to a same-ss group,
### every layer in that group should also be fixed
tmp_fixed_by_input = []
for ss in self._same_ss:
for key in fixed_by_input:
if key in ss:
tmp_fixed_by_input += ss
fixed_by_input += tmp_fixed_by_input
### clear layer in ofa_layers set by skip layers
if self._skip_layers != None:
for skip_layer in self._skip_layers:
......@@ -708,13 +719,17 @@ class OFA(OFABase):
for ss in per_ss[1:]:
self._clear_width(self._param2key[ss])
### clear depthwise conv from search space because of its output channel cannot change
for name, sublayer in model_to_traverse.named_sublayers():
if isinstance(sublayer, BaseBlock):
for param in sublayer.parameters():
if param.name in self._depthwise_conv and name in self._ofa_layers.keys(
):
self._clear_width(name)
self._cannot_changed_layer = sorted(
set(output_conv + fixed_by_input + depthwise_conv))
### clear depthwise convs from the search space because their output channels cannot change
### clear output convs from the search space because the model output shape cannot change
### clear convs that operate on a fixed input
for name, sublayer in model_to_traverse.named_sublayers():
if isinstance(sublayer, BaseBlock):
for param in sublayer.parameters():
if param.name in self._cannot_changed_layer and name in self._ofa_layers.keys(
):
self._clear_width(name)
def forward(self, *inputs, **kwargs):
# ===================== teacher process =====================
......@@ -751,6 +766,7 @@ class OFA(OFABase):
self.current_config = self.net_config
_logger.debug("Current config is {}".format(self.current_config))
if 'depth' in self.current_config:
kwargs['depth'] = self.current_config['depth']
if self._broadcast:
......
......@@ -446,7 +446,7 @@ class TestShortCut(unittest.TestCase):
self.config,
input_shapes=[[2, 3, 224, 224]],
input_dtypes=['float32'])
assert len(self.ofa_model.ofa_layers) == 38
assert len(self.ofa_model.ofa_layers) == 37
class TestExportCase1(unittest.TestCase):
......@@ -462,8 +462,8 @@ class TestExportCase1(unittest.TestCase):
def test_export_model_linear1(self):
ex_model = self.ofa_model.export(
self.config, input_shapes=[[3, 64]], input_dtypes=['int64'])
assert len(self.ofa_model.ofa_layers) == 3
ex_model(self.data)
assert len(self.ofa_model.ofa_layers) == 4
class TestExportCase2(unittest.TestCase):
......@@ -482,7 +482,7 @@ class TestExportCase2(unittest.TestCase):
ex_model = self.ofa_model.export(
config, input_shapes=[[3, 64]], input_dtypes=['int64'])
ex_model(self.data)
assert len(self.ofa_model.ofa_layers) == 4
assert len(self.ofa_model.ofa_layers) == 3
if __name__ == '__main__':
......
......@@ -81,6 +81,74 @@ class ModelShortcut(nn.Layer):
return z
class ModelElementwise(nn.Layer):
def __init__(self):
super(ModelElementwise, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2D(3, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
self.conv2 = nn.Sequential(
nn.Conv2D(12, 24, 3), nn.BatchNorm2D(24), nn.ReLU())
self.conv3 = nn.Sequential(
nn.Conv2D(24, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
self.out = nn.Sequential(
nn.Conv2D(12, 6, 1), nn.BatchNorm2D(6), nn.ReLU())
def forward(self, x):
d = paddle.randn(shape=[2, 12, x.shape[2], x.shape[3]], dtype='float32')
d = nn.functional.softmax(d)
x = self.conv1(x)
x = x + d
x = self.conv2(x)
x = self.conv3(x)
x = self.out(x)
return x
class ModelMultiExit(nn.Layer):
def __init__(self):
super(ModelMultiExit, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2D(3, 12, 3), nn.BatchNorm2D(12), nn.ReLU())
self.block1 = nn.Sequential(
nn.Conv2D(12, 24, 7),
nn.BatchNorm2D(24),
nn.ReLU(),
nn.MaxPool2D(
kernel_size=3, stride=2, padding=0),
nn.Conv2D(24, 24, 7),
nn.BatchNorm2D(24),
nn.ReLU(),
nn.MaxPool2D(
kernel_size=3, stride=2, padding=0))
self.block2 = nn.Sequential(
nn.Conv2D(24, 24, 1),
nn.BatchNorm2D(24),
nn.ReLU(),
nn.MaxPool2D(
kernel_size=3, stride=2, padding=1))
self.out1 = nn.Sequential(
nn.Conv2D(24, 24, 1), nn.BatchNorm2D(24), nn.ReLU())
self.out2 = nn.Sequential(
nn.Conv2D(48, 24, 7),
nn.BatchNorm2D(24),
nn.ReLU(), nn.Conv2D(24, 24, 3), nn.BatchNorm2D(24), nn.ReLU())
def forward(self, x):
x = self.conv1(x)
b1 = self.block1(x)
adapt = nn.UpsamplingBilinear2D(size=[b1.shape[2], b1.shape[2]])
b2 = self.block2(b1)
up = adapt(b2)
y1 = self.out1(b1)
y2 = paddle.concat([b1, up], axis=1)
y2 = self.out2(y2)
return [y1, y2]
class ModelInputDict(nn.Layer):
def __init__(self):
super(ModelInputDict, self).__init__()
......@@ -132,6 +200,37 @@ class TestOFAV2Export(unittest.TestCase):
origin_model=origin_model)
class Testelementwise(unittest.TestCase):
def setUp(self):
model = ModelElementwise()
sp_net_config = supernet(expand_ratio=[0.25, 0.5, 1.0])
self.model = Convert(sp_net_config).convert(model)
self.images = paddle.randn(shape=[2, 3, 32, 32], dtype='float32')
def test_elementwise(self):
self.ofa_model = OFA(self.model)
self.ofa_model.set_epoch(0)
self.ofa_model.set_task('expand_ratio')
out, _ = self.ofa_model(self.images)
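# conv1 is added to a random (non-parameter) tensor, so it is fixed by the input;
# the last conv produces the model output, leaving only conv2 and conv3 searchable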
assert list(self.ofa_model._ofa_layers.keys()) == ['conv2.0', 'conv3.0']
class TestMultiExit(unittest.TestCase):
def setUp(self):
self.images = paddle.randn(shape=[1, 3, 224, 224], dtype='float32')
model = ModelMultiExit()
sp_net_config = supernet(expand_ratio=[0.25, 0.5, 1.0])
self.model = Convert(sp_net_config).convert(model)
def test_multiexit(self):
self.ofa_model = OFA(self.model)
self.ofa_model.set_epoch(0)
self.ofa_model.set_task('expand_ratio')
out, _ = self.ofa_model(self.images)
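# the convs that directly produce the two exits (out1.0 and the last conv of out2)
# are treated as output convs and removed from the search space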
assert list(self.ofa_model._ofa_layers.keys(
)) == ['conv1.0', 'block1.0', 'block1.4', 'block2.0', 'out2.0']
class TestShortcutSkiplayers(unittest.TestCase):
def setUp(self):
model = ModelShortcut()
......@@ -151,7 +250,7 @@ class TestShortcutSkiplayers(unittest.TestCase):
self.ofa_model.set_task('expand_ratio')
for i in range(5):
self.ofa_model(self.images)
assert list(self.ofa_model._ofa_layers.keys()) == ['branch2.0', 'out.0']
assert list(self.ofa_model._ofa_layers.keys()) == ['branch2.0']
class TestShortcutSkiplayersCase1(TestShortcutSkiplayers):
......@@ -166,7 +265,36 @@ class TestShortcutSkiplayersCase2(TestShortcutSkiplayers):
self.run_config = RunConfig(**default_run_config)
def test_shortcut(self):
assert list(self.ofa_model._ofa_layers.keys()) == ['conv1.0', 'out.0']
assert list(self.ofa_model._ofa_layers.keys()) == ['conv1.0']
class TestInputDict(unittest.TestCase):
def setUp(self):
model = ModelInputDict()
sp_net_config = supernet(expand_ratio=[0.5, 1.0])
self.model = Convert(sp_net_config).convert(model)
self.images = paddle.randn(shape=[2, 3, 32, 32], dtype='float32')
self.images2 = {
'data': paddle.randn(
shape=[2, 12, 32, 32], dtype='float32')
}
default_run_config = {'skip_layers': ['conv1.0', 'conv2.0']}
self.run_config = RunConfig(**default_run_config)
self.ofa_model = OFA(self.model, run_config=self.run_config)
self.ofa_model._clear_search_space(self.images, data=self.images2)
def test_export(self):
config = self.ofa_model._sample_config(
task="expand_ratio", sample_type="smallest")
self.ofa_model.export(
config,
input_shapes=[[1, 3, 32, 32], {
'data': [1, 12, 32, 32]
}],
input_dtypes=['float32', 'float32'])
class TestInputDict(unittest.TestCase):
......