Unverified commit 8cc4e44c authored by ceci3, committed by GitHub

fix some bug of ofa (#712)

* fix dist

* fix

* remove hook

* fix skip layers

* fix depthwise

* fix depthwise

* fix depthwise
Parent 58ac4d43
......@@ -105,6 +105,8 @@ class Convert:
cur_channel = None
for idx, layer in enumerate(model):
cls_name = layer.__class__.__name__.lower()
### basic api in paddle
if len(layer.sublayers()) == 0:
if 'conv' in cls_name or 'linear' in cls_name or 'embedding' in cls_name:
weight_layer_count += 1
last_weight_layer_idx = idx
......
......@@ -146,12 +146,34 @@ def prune_params(model, param_config, super_model_sd=None):
setattr(model, k, v)
def _is_depthwise(op):
"""Check if this op is depthwise conv.
The input shape and output shape of a depthwise conv must stay the same inside a super layer,
so a depthwise op cannot be considered a weight op.
"""
if op.type() == 'depthwise_conv':
return True
elif 'conv' in op.type():
for inp in op.all_inputs():
if not inp._var.persistable and op.attr('groups') == inp._var.shape[
1]:
return True
return False
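For reference, here is a minimal dygraph-level sketch of the same rule (assuming Paddle's nn.Conv2D keeps its constructor arguments as the private _groups/_in_channels attributes, as recent 2.x releases do); the graph-level _is_depthwise above applies the same groups == input-channels check to the static program:

import paddle.nn as nn

def is_depthwise_layer(layer):
    # a conv is depthwise when every input channel forms its own group
    return isinstance(layer, nn.Conv2D) and layer._groups == layer._in_channels

dw = nn.Conv2D(12, 12, 3, groups=12)   # depthwise: groups == in_channels
pw = nn.Conv2D(12, 24, 1)              # pointwise: groups defaults to 1
print(is_depthwise_layer(dw), is_depthwise_layer(pw))   # True False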
def _find_weight_ops(op, graph, weights):
""" Find the vars come from operators with weight.
"""
pre_ops = graph.pre_ops(op)
for pre_op in pre_ops:
if pre_op.type() in WEIGHT_OP:
### if a depthwise conv is one of the elementwise op's inputs,
### add it into the same search space
if _is_depthwise(pre_op):
for inp in pre_op.all_inputs():
if inp._var.persistable:
weights.append(inp._var.name)
if pre_op.type() in WEIGHT_OP and not _is_depthwise(pre_op):
for inp in pre_op.all_inputs():
if inp._var.persistable:
weights.append(inp._var.name)
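The persistable flag checked here is what separates parameters from activations in the static program. A small static-graph sketch of that fact (a sketch only, assuming the paddle.static.nn.conv2d API; the names are illustrative):

import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data('x', [1, 8, 16, 16], 'float32')
    y = paddle.static.nn.conv2d(x, num_filters=8, filter_size=3, groups=8)
# conv weights and biases are persistable vars, which is what _find_weight_ops collects
print([(p.name, p.persistable) for p in main_prog.global_block().all_parameters()])
paddle.disable_static()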
......@@ -176,16 +198,22 @@ def check_search_space(graph):
""" Find the shortcut in the model and set same config for this situation.
"""
same_search_space = []
depthwise_conv = []
for op in graph.ops():
if op.type() == 'elementwise_add':
if op.type() == 'elementwise_add' or op.type() == 'elementwise_mul':
inp1, inp2 = op.all_inputs()[0], op.all_inputs()[1]
if (not inp1._var.persistable) and (not inp2._var.persistable):
pre_ele_op = _find_pre_elementwise_add(op, graph)
if pre_ele_op != None:
same_search_space.append(pre_ele_op)
if _is_depthwise(op):
for inp in op.all_inputs():
if inp._var.persistable:
depthwise_conv.append(inp._var.name)
if len(same_search_space) == 0:
return None
return None, None
same_search_space = sorted([sorted(x) for x in same_search_space])
final_search_space = []
......@@ -204,5 +232,7 @@ def check_search_space(graph):
break
if not merged:
final_search_space.append(l)
final_search_space = sorted([sorted(x) for x in final_search_space])
depthwise_conv = sorted(depthwise_conv)
return final_search_space
return (final_search_space, depthwise_conv)
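A toy illustration of why both inputs of an elementwise add/mul end up in one shared search space (nothing OFA-specific, just tensor broadcasting):

import paddle

# a residual add only works when both branches keep the same channel count,
# so pruning the width of one input without the other breaks the shortcut
a = paddle.randn([1, 12, 8, 8])
print((a + paddle.randn([1, 12, 8, 8])).shape)   # [1, 12, 8, 8]
try:
    _ = a + paddle.randn([1, 8, 8, 8])           # mismatched channels
except Exception as err:
    print('channel mismatch:', type(err).__name__)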
......@@ -302,7 +302,16 @@ class SuperConv2D(nn.Conv2D):
padding = self._padding
if self.bias is not None:
bias = self.bias[:out_nc]
### for a depthwise conv, expand_ratio is 0; but if the expand ratio of the
### conv before this depthwise conv is not 1.0, the weight shape of this
### depthwise conv changes while out_nc does not, so the bias has to be
### sliced according to weight_out_nc instead.
### if in_nc > groups > 1, the actual conv output is weight_out_nc * groups,
### so slice the bias by weight_out_nc * groups.
### if in_nc == groups, slice the bias by weight_out_nc.
if groups != in_nc:
weight_out_nc = weight_out_nc * groups
bias = self.bias[:weight_out_nc]
else:
bias = self.bias
......@@ -313,7 +322,7 @@ class SuperConv2D(nn.Conv2D):
stride=self._stride,
padding=padding,
dilation=self._dilation,
groups=self._groups,
groups=groups,
data_format=self._data_format)
return out
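The bias comment above boils down to keeping the bias length equal to the sliced weight's output-channel count. A small self-contained illustration with the functional API (not the library's own slicing logic):

import paddle
import paddle.nn.functional as F

# the weight of a grouped conv has shape [out_nc, in_nc // groups, kh, kw];
# the output channel count, and therefore the bias length, follows out_nc
x = paddle.randn([1, 8, 16, 16])
w = paddle.randn([4, 2, 3, 3])      # out_nc=4, in_nc//groups=2 -> groups=4
b = paddle.zeros([4])               # bias must match the sliced out_nc
y = F.conv2d(x, w, bias=b, padding=1, groups=4)
print(y.shape)                      # [1, 4, 16, 16]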
......@@ -606,7 +615,9 @@ class SuperConv2DTranspose(nn.Conv2DTranspose):
output_padding = 0
if self.bias is not None:
bias = self.bias[:out_nc]
if groups != in_nc:
weight_out_nc = weight_out_nc * groups
bias = self.bias[:weight_out_nc]
else:
bias = self.bias
......@@ -618,7 +629,7 @@ class SuperConv2DTranspose(nn.Conv2DTranspose):
output_padding=output_padding,
stride=self._stride,
dilation=self._dilation,
groups=self._groups,
groups=groups,
output_size=output_size,
data_format=self._data_format)
return out
......@@ -929,10 +940,11 @@ class SuperBatchNorm2D(nn.BatchNorm2D):
weight_attr=None,
bias_attr=None,
data_format='NCHW',
use_global_stats=None,
name=None):
super(SuperBatchNorm2D, self).__init__(num_features, momentum, epsilon,
weight_attr, bias_attr,
data_format, name)
super(SuperBatchNorm2D, self).__init__(
num_features, momentum, epsilon, weight_attr, bias_attr,
data_format, use_global_stats, name)
def forward(self, input):
self._check_data_format(self._data_format)
......@@ -954,7 +966,8 @@ class SuperBatchNorm2D(nn.BatchNorm2D):
training=self.training,
momentum=self._momentum,
epsilon=self._epsilon,
data_format=self._data_format)
data_format=self._data_format,
use_global_stats=self._use_global_stats)
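The new use_global_stats pass-through simply forwards the flag that Paddle's BatchNorm2D already accepts (assuming Paddle >= 2.1); when it is set, BN normalizes with the accumulated running statistics even in train mode. A minimal sketch:

import paddle
import paddle.nn as nn

bn = nn.BatchNorm2D(12, use_global_stats=True)
bn.train()                              # still normalizes with running stats
y = bn(paddle.randn([2, 12, 8, 8]))
print(y.shape)                          # [2, 12, 8, 8]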
class SuperSyncBatchNorm(nn.SyncBatchNorm):
......
......@@ -51,7 +51,9 @@ RunConfig = namedtuple(
# list, elastic depth of the model in training, default: None
'elastic_depth',
# list, the number of sub-networks to train per mini-batch of data, used to get the current epoch, default: None
'dynamic_batch_size'
'dynamic_batch_size',
# list, the shape of weights in skip_layers will not change during training, default: None
'skip_layers'
])
RunConfig.__new__.__defaults__ = (None, ) * len(RunConfig._fields)
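A hedged usage sketch of the new skip_layers field ('branch1.6' is just the sublayer name used by the tests below; any named sublayer of the converted supernet works):

from paddleslim.nas.ofa import OFA, RunConfig

# layers listed in skip_layers keep their original width during training
run_config = RunConfig(skip_layers=['branch1.6'])
print(run_config.skip_layers, run_config.elastic_depth)   # ['branch1.6'] None
# ofa_model = OFA(super_model, run_config=run_config)     # super_model: a Convert(...)-ed network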
......@@ -106,7 +108,10 @@ class OFABase(Layer):
if getattr(self, 'current_config', None) != None:
### if the block is fixed, do not add its key to the candidates;
### the concrete config is passed as a parameter in kwargs
if block.fixed == False:
if block.fixed == False and (
self._skip_layers != None and
self._key2name[block.key] not in self._skip_layers) and \
(block.fn.weight.name not in self._depthwise_conv):
assert self._key2name[
block.
key] in self.current_config, 'config does not have the {} layer.'.format(
......@@ -117,7 +122,7 @@ class OFABase(Layer):
config.update(kwargs)
else:
config = dict()
logging.debug(self.model, config)
_logger.debug(self.model, config)
return block.fn(*inputs, **config)
......@@ -175,6 +180,7 @@ class OFA(OFABase):
self._mapping_layers = None
self._build_ss = False
self._broadcast = False
self._skip_layers = None
### if elastic_order is none, use default order
if self.elastic_order is not None:
......@@ -185,6 +191,7 @@ class OFA(OFABase):
depth_list = list(set(self.run_config.elastic_depth))
depth_list.sort()
self._ofa_layers['depth'] = depth_list
self._layers['depth'] = depth_list
if self.elastic_order is None:
self.elastic_order = []
......@@ -198,6 +205,7 @@ class OFA(OFABase):
depth_list = list(set(self.run_config.elastic_depth))
depth_list.sort()
self._ofa_layers['depth'] = depth_list
self._layers['depth'] = depth_list
self.elastic_order.append('depth')
# final, elastic width
......@@ -225,6 +233,11 @@ class OFA(OFABase):
run_config.init_learning_rate[idx], list
), "each candidate in init_learning_rate must be list"
### remove skip layers in search space
if self.run_config != None and getattr(self.run_config, 'skip_layers',
None) != None:
self._skip_layers = self.run_config.skip_layers
### ================= add distill prepare ======================
if self.distill_config != None:
self._add_teacher = True
......@@ -234,7 +247,7 @@ class OFA(OFABase):
def _prepare_distill(self):
if self.distill_config.teacher_model == None:
logging.error(
_logger.error(
'If you want to add distill, please input instance of teacher model'
)
......@@ -289,6 +302,7 @@ class OFA(OFABase):
def _reset_hook_before_forward(self):
self.Tacts, self.Sacts = {}, {}
self.hooks = []
if self._mapping_layers != None:
def get_activation(mem, name):
......@@ -300,12 +314,18 @@ class OFA(OFABase):
def add_hook(net, mem, mapping_layers):
for idx, (n, m) in enumerate(net.named_sublayers()):
if n in mapping_layers:
m.register_forward_post_hook(get_activation(mem, n))
self.hooks.append(
m.register_forward_post_hook(
get_activation(mem, n)))
add_hook(self.model, self.Sacts, self._mapping_layers)
add_hook(self.ofa_teacher_model.model, self.Tacts,
self._mapping_layers)
def _remove_hook_after_forward(self):
for hook in self.hooks:
hook.remove()
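The fix keeps the handles returned by register_forward_post_hook so the distillation hooks can be detached after every forward instead of piling up. A standalone sketch of that pattern:

import paddle
import paddle.nn as nn

acts = {}
def record(layer, inp, out):
    acts['fc'] = out                     # same idea as get_activation above

fc = nn.Linear(4, 4)
handle = fc.register_forward_post_hook(record)   # returns a removable handle
fc(paddle.randn([1, 4]))
handle.remove()                          # what _remove_hook_after_forward does
print(list(acts.keys()))                 # ['fc']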
def _compute_epochs(self):
if getattr(self, 'epoch', None) == None:
assert self.run_config.total_images is not None, \
......@@ -521,6 +541,14 @@ class OFA(OFABase):
else:
return False
def _clear_width(self, key):
if 'expand_ratio' in self._ofa_layers[key]:
self._ofa_layers[key].pop('expand_ratio')
elif 'channel' in self._ofa_layers[key]:
self._ofa_layers[key].pop('channel')
if len(self._ofa_layers[key]) == 0:
self._ofa_layers.pop(key)
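A toy run of the rule _clear_width applies (hypothetical layer names): width candidates are dropped, and a key left with no candidates disappears from the layer dict entirely.

ofa_layers = {
    'conv1': {'expand_ratio': [0.5, 1.0]},
    'conv2': {'channel': [8, 16], 'kernel_size': [3, 5]},
}
for key in list(ofa_layers):
    cfg = ofa_layers[key]
    if 'expand_ratio' in cfg:
        cfg.pop('expand_ratio')
    elif 'channel' in cfg:
        cfg.pop('channel')
    if len(cfg) == 0:
        ofa_layers.pop(key)
print(ofa_layers)                        # {'conv2': {'kernel_size': [3, 5]}}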
def _clear_search_space(self, *inputs, **kwargs):
""" find shortcut in model, and clear up the search space """
input_shapes = []
......@@ -533,34 +561,70 @@ class OFA(OFABase):
input_dtypes.append(v.numpy().dtype)
### find shortcut block using static model
model_to_traverse = self.model._layers if isinstance(
self.model, DataParallel) else self.model
_st_prog = dygraph2program(
self.model, inputs=input_shapes, dtypes=input_dtypes)
self._same_ss = check_search_space(GraphWrapper(_st_prog))
model_to_traverse, inputs=input_shapes, dtypes=input_dtypes)
self._same_ss, self._depthwise_conv = check_search_space(
GraphWrapper(_st_prog))
if self._same_ss != None:
self._same_ss = sorted(self._same_ss)
self._param2key = {}
self._broadcast = True
### the sublayer name is the key in the search space;
### param.name is the name used in self._same_ss
model_to_traverse = self.model._layers if isinstance(
self.model, DataParallel) else self.model
for name, sublayer in model_to_traverse.named_sublayers():
if isinstance(sublayer, BaseBlock):
for param in sublayer.parameters():
if self._find_ele(param.name, self._same_ss):
self._param2key[param.name] = name
### filter the same search space a second time to avoid keeping weights that are no longer searchable in the same ss.
tmp_same_ss = []
for ss in self._same_ss:
per_ss = []
for key in ss:
if key not in self._param2key.keys():
continue
### if skip_layers and the same ss both contain the same layer,
### the other layers in that ss need to be added to skip_layers as well
if self._skip_layers != None and self._param2key[
key] in self._skip_layers:
self._skip_layers += [self._param2key[sk] for sk in ss]
per_ss = []
break
if self._param2key[key] in self._ofa_layers.keys() and \
('expand_ratio' in self._ofa_layers[self._param2key[key]] or \
'channel' in self._ofa_layers[self._param2key[key]]):
per_ss.append(key)
else:
_logger.info("{} not in ss".format(key))
if len(per_ss) != 0:
tmp_same_ss.append(per_ss)
self._same_ss = tmp_same_ss
### remove the layers specified by skip_layers from ofa_layers
if self._skip_layers != None:
for skip_layer in self._skip_layers:
if skip_layer in self._ofa_layers.keys():
self._ofa_layers.pop(skip_layer)
for per_ss in self._same_ss:
for ss in per_ss[1:]:
if 'expand_ratio' in self._ofa_layers[self._param2key[ss]]:
self._ofa_layers[self._param2key[ss]].pop(
'expand_ratio')
elif 'channel' in self._ofa_layers[self._param2key[ss]]:
self._ofa_layers[self._param2key[ss]].pop('channel')
if len(self._ofa_layers[self._param2key[ss]]) == 0:
self._ofa_layers.pop(self._param2key[ss])
self._clear_width(self._param2key[ss])
### remove depthwise convs from the search space because their output channels cannot change
for name, sublayer in model_to_traverse.named_sublayers():
if isinstance(sublayer, BaseBlock):
for param in sublayer.parameters():
if param.name in self._depthwise_conv and name in self._ofa_layers.keys(
):
self._clear_width(name)
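How a parameter name is mapped back to the sublayer name that keys the search space (the pairing _clear_search_space stores in _param2key), sketched on a plain Sequential:

import paddle.nn as nn

net = nn.Sequential(nn.Conv2D(3, 8, 3), nn.Conv2D(8, 8, 1))
param2key = {p.name: name
             for name, sub in net.named_sublayers()
             for p in sub.parameters()}
print(sorted(set(param2key.values())))   # ['0', '1']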
def _broadcast_ss(self):
""" broadcast search space after random sample."""
......@@ -630,4 +694,9 @@ class OFA(OFABase):
if self._broadcast:
self._broadcast_ss()
return self.model.forward(*inputs, **kwargs), teacher_output
student_output = self.model.forward(*inputs, **kwargs)
if self._add_teacher:
self._remove_hook_after_forward()
return student_output, teacher_output
......@@ -45,6 +45,9 @@ def dynabert_config(model, width_mult, depth_mult=1.0):
if block_k == 'depth':
block_v = depth_mult
if block_k != 'depth':
new_block_k = model._key2name[block_k]
else:
new_block_k = 'depth'
new_config[new_block_k] = block_v
return new_config
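A toy illustration of the key translation added above, with hypothetical key and layer names: every non-'depth' block key is routed through model._key2name before the config is returned.

key2name = {'super_linear_0': 'encoder.0.fc', 'super_linear_1': 'encoder.1.fc'}
old_config = {'super_linear_0': {'expand_ratio': 0.75}, 'depth': 0.5}
new_config = {('depth' if k == 'depth' else key2name[k]): v
              for k, v in old_config.items()}
print(new_config)   # {'encoder.0.fc': {'expand_ratio': 0.75}, 'depth': 0.5}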
......@@ -40,6 +40,46 @@ class ModelV1(nn.Layer):
return self.cls + self.model(inputs)
class ModelShortcut(nn.Layer):
def __init__(self):
super(ModelShortcut, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2D(3, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
self.branch1 = nn.Sequential(
nn.Conv2D(12, 12, 1),
nn.BatchNorm2D(12),
nn.ReLU(),
nn.Conv2D(
12, 12, 1, groups=12),
nn.BatchNorm2D(12),
nn.ReLU(),
nn.Conv2D(
12, 12, 1, groups=12),
nn.BatchNorm2D(12),
nn.ReLU())
self.branch2 = nn.Sequential(
nn.Conv2D(12, 12, 1),
nn.BatchNorm2D(12),
nn.ReLU(),
nn.Conv2D(
12, 12, 1, groups=12),
nn.BatchNorm2D(12),
nn.ReLU(),
nn.Conv2D(12, 12, 1),
nn.BatchNorm2D(12),
nn.ReLU())
self.out = nn.Sequential(
nn.Conv2D(12, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
def forward(self, x):
x = self.conv1(x)
y = self.branch1(x)
y = x + y
z = self.branch2(y)
z = z + y
return z
class TestOFAV2(unittest.TestCase):
def setUp(self):
model = ModelV1()
......@@ -73,5 +113,42 @@ class TestOFAV2Export(unittest.TestCase):
origin_model=origin_model)
class TestShortcutSkiplayers(unittest.TestCase):
def setUp(self):
model = ModelShortcut()
sp_net_config = supernet(expand_ratio=[0.5, 1.0])
self.model = Convert(sp_net_config).convert(model)
self.images = paddle.randn(shape=[2, 3, 32, 32], dtype='float32')
self.init_config()
self.ofa_model = OFA(self.model, run_config=self.run_config)
self.ofa_model._clear_search_space(self.images)
def init_config(self):
default_run_config = {'skip_layers': ['branch1.6']}
self.run_config = RunConfig(**default_run_config)
def test_shortcut(self):
self.ofa_model.set_epoch(0)
self.ofa_model.set_task('expand_ratio')
for i in range(5):
self.ofa_model(self.images)
assert list(self.ofa_model._ofa_layers.keys()) == ['branch2.0', 'out.0']
class TestShortcutSkiplayersCase1(TestShortcutSkiplayers):
def init_config(self):
default_run_config = {'skip_layers': ['conv1.0']}
self.run_config = RunConfig(**default_run_config)
class TestShortcutSkiplayersCase2(TestShortcutSkiplayers):
def init_config(self):
default_run_config = {'skip_layers': ['branch2.0']}
self.run_config = RunConfig(**default_run_config)
def test_shortcut(self):
assert list(self.ofa_model._ofa_layers.keys()) == ['conv1.0', 'out.0']
if __name__ == '__main__':
unittest.main()