Unverified · Commit 8cc4e44c authored by ceci3, committed by GitHub

fix some bug of ofa (#712)

* fix dist

* fix

* remove hook

* fix skip layers

* fix depthwise

* fix depthwise

* fix depthwise
Parent 58ac4d43
@@ -105,6 +105,8 @@ class Convert:
         cur_channel = None
         for idx, layer in enumerate(model):
             cls_name = layer.__class__.__name__.lower()
+            ### basic api in paddle
+            if len(layer.sublayers()) == 0:
                 if 'conv' in cls_name or 'linear' in cls_name or 'embedding' in cls_name:
                     weight_layer_count += 1
                     last_weight_layer_idx = idx
...
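A minimal sketch (not part of the diff, assuming standard paddle.nn behaviour) of why the added `len(layer.sublayers()) == 0` check restricts the counting to basic-API leaf layers:

import paddle.nn as nn

conv = nn.Conv2D(3, 8, 3)
block = nn.Sequential(conv, nn.ReLU())

# Only leaf layers report an empty sublayers() list, so containers such as
# nn.Sequential are skipped while basic layers like Conv2D are counted.
print(len(block.sublayers()))  # 2 -> container, skipped by the new check
print(len(conv.sublayers()))   # 0 -> leaf Conv2D, counted as a weight layer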
@@ -146,12 +146,34 @@ def prune_params(model, param_config, super_model_sd=None):
                 setattr(model, k, v)


+def _is_depthwise(op):
+    """Check whether this op is a depthwise conv.
+    The input and output shapes of a depthwise conv must stay the same in the
+    super layer, so a depthwise op cannot be treated as a weight op.
+    """
+    if op.type() == 'depthwise_conv':
+        return True
+    elif 'conv' in op.type():
+        for inp in op.all_inputs():
+            if not inp._var.persistable and op.attr('groups') == inp._var.shape[
+                    1]:
+                return True
+    return False
+
+
 def _find_weight_ops(op, graph, weights):
     """ Find the vars that come from operators with weight.
     """
     pre_ops = graph.pre_ops(op)
     for pre_op in pre_ops:
-        if pre_op.type() in WEIGHT_OP:
+        ### if a depthwise conv is one of the elementwise op's inputs,
+        ### add it into the same search space
+        if _is_depthwise(pre_op):
+            for inp in pre_op.all_inputs():
+                if inp._var.persistable:
+                    weights.append(inp._var.name)
+
+        if pre_op.type() in WEIGHT_OP and not _is_depthwise(pre_op):
             for inp in pre_op.all_inputs():
                 if inp._var.persistable:
                     weights.append(inp._var.name)
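For intuition, a hedged sketch with plain Paddle layers (layer names are illustrative, not from the diff): a depthwise conv is one whose groups equals its input channel count, so its output width is tied to its input and cannot be searched as an independent weight op.

import paddle.nn as nn

regular = nn.Conv2D(in_channels=12, out_channels=24, kernel_size=1)
depthwise = nn.Conv2D(in_channels=12, out_channels=12, kernel_size=1, groups=12)

# Paddle stores conv weights as [out_channels, in_channels // groups, kh, kw].
print(regular.weight.shape)    # [24, 12, 1, 1] -> output width can be pruned freely
print(depthwise.weight.shape)  # [12, 1, 1, 1]  -> output width is bound to the input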
@@ -176,16 +198,22 @@ def check_search_space(graph):
     """ Find the shortcut in the model and set same config for this situation.
     """
     same_search_space = []
+    depthwise_conv = []
     for op in graph.ops():
-        if op.type() == 'elementwise_add':
+        if op.type() == 'elementwise_add' or op.type() == 'elementwise_mul':
             inp1, inp2 = op.all_inputs()[0], op.all_inputs()[1]
             if (not inp1._var.persistable) and (not inp2._var.persistable):
                 pre_ele_op = _find_pre_elementwise_add(op, graph)
                 if pre_ele_op != None:
                     same_search_space.append(pre_ele_op)

+        if _is_depthwise(op):
+            for inp in op.all_inputs():
+                if inp._var.persistable:
+                    depthwise_conv.append(inp._var.name)
+
     if len(same_search_space) == 0:
-        return None
+        return None, None

     same_search_space = sorted([sorted(x) for x in same_search_space])
     final_search_space = []
@@ -204,5 +232,7 @@ def check_search_space(graph):
                 break
         if not merged:
             final_search_space.append(l)
+    final_search_space = sorted([sorted(x) for x in final_search_space])
+    depthwise_conv = sorted(depthwise_conv)

-    return final_search_space
+    return (final_search_space, depthwise_conv)
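Since check_search_space now returns a pair, callers have to unpack two values; a hedged sketch mirroring the call that appears in the ofa.py part of this diff (the `program` variable is an assumption here, standing in for the static Program that dygraph2program produces):

from paddleslim.core import GraphWrapper

# `program` is assumed to be the static Program traced from the dygraph model,
# as done in _clear_search_space further down in this diff.
same_ss, depthwise_convs = check_search_space(GraphWrapper(program))
if same_ss is None:
    # both return values are None when no shortcut structure is found
    same_ss, depthwise_convs = [], []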
@@ -302,7 +302,16 @@ class SuperConv2D(nn.Conv2D):
             padding = self._padding

         if self.bias is not None:
-            bias = self.bias[:out_nc]
+            ### if this conv is a depthwise conv whose expand_ratio is 0 but the
+            ### expand ratio of the conv before it is not 1.0, the shape of this
+            ### depthwise conv's weight changes while out_nc does not, so the bias
+            ### has to be sliced according to weight_out_nc.
+            ### if in_nc > groups > 1, the actual output of the conv is
+            ### weight_out_nc * groups, so slice the bias by weight_out_nc * groups.
+            ### if in_nc == groups, slice the bias by weight_out_nc.
+            if groups != in_nc:
+                weight_out_nc = weight_out_nc * groups
+            bias = self.bias[:weight_out_nc]
         else:
             bias = self.bias

@@ -313,7 +322,7 @@ class SuperConv2D(nn.Conv2D):
             stride=self._stride,
             padding=padding,
             dilation=self._dilation,
-            groups=self._groups,
+            groups=groups,
             data_format=self._data_format)
         return out
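Reading the new bias handling with concrete, purely illustrative numbers: with groups = 3 and a pruned weight_out_nc of 4, the layer actually produces 4 * 3 = 12 output channels, so the bias must be cut to bias[:12]; with a depthwise conv (groups == in_nc), the output width already equals weight_out_nc, so bias[:weight_out_nc] matches without any scaling.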
@@ -606,7 +615,9 @@ class SuperConv2DTranspose(nn.Conv2DTranspose):
             output_padding = 0

         if self.bias is not None:
-            bias = self.bias[:out_nc]
+            if groups != in_nc:
+                weight_out_nc = weight_out_nc * groups
+            bias = self.bias[:weight_out_nc]
         else:
             bias = self.bias

@@ -618,7 +629,7 @@ class SuperConv2DTranspose(nn.Conv2DTranspose):
             output_padding=output_padding,
             stride=self._stride,
             dilation=self._dilation,
-            groups=self._groups,
+            groups=groups,
             output_size=output_size,
             data_format=self._data_format)
         return out
@@ -929,10 +940,11 @@ class SuperBatchNorm2D(nn.BatchNorm2D):
                  weight_attr=None,
                  bias_attr=None,
                  data_format='NCHW',
+                 use_global_stats=None,
                  name=None):
-        super(SuperBatchNorm2D, self).__init__(num_features, momentum, epsilon,
-                                               weight_attr, bias_attr,
-                                               data_format, name)
+        super(SuperBatchNorm2D, self).__init__(
+            num_features, momentum, epsilon, weight_attr, bias_attr,
+            data_format, use_global_stats, name)

     def forward(self, input):
         self._check_data_format(self._data_format)
@@ -954,7 +966,8 @@ class SuperBatchNorm2D(nn.BatchNorm2D):
             training=self.training,
             momentum=self._momentum,
             epsilon=self._epsilon,
-            data_format=self._data_format)
+            data_format=self._data_format,
+            use_global_stats=self._use_global_stats)


 class SuperSyncBatchNorm(nn.SyncBatchNorm):
...
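A hedged usage sketch of the widened constructor (assuming SuperBatchNorm2D is importable from paddleslim.nas.ofa.layers, where this class is defined):

import paddle
from paddleslim.nas.ofa.layers import SuperBatchNorm2D

# use_global_stats is now accepted and forwarded to F.batch_norm, so the super
# layer can run with frozen running statistics, matching nn.BatchNorm2D.
bn = SuperBatchNorm2D(12, use_global_stats=True)
out = bn(paddle.randn([2, 12, 8, 8]))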
@@ -51,7 +51,9 @@ RunConfig = namedtuple(
     # list, elastic depth of the model in training, default: None
     'elastic_depth',
     # list, the number of sub-network to train per mini-batch data, used to get current epoch, default: None
-    'dynamic_batch_size'
+    'dynamic_batch_size',
+    # list, the shapes of the weights in skip_layers will not change during training, default: None
+    'skip_layers'
 ])
 RunConfig.__new__.__defaults__ = (None, ) * len(RunConfig._fields)
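The new field is exercised by the tests added at the end of this diff; a minimal hedged sketch of how a user would pass it (my_model is a placeholder for any paddle.nn.Layer, and 'branch1.6' is the sublayer name used in the new tests):

from paddleslim.nas.ofa import OFA, RunConfig
from paddleslim.nas.ofa.convert_super import Convert, supernet

sp_model = Convert(supernet(expand_ratio=[0.5, 1.0])).convert(my_model)
# layers listed in skip_layers keep their original weight shapes during training
ofa_model = OFA(sp_model, run_config=RunConfig(skip_layers=['branch1.6']))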
@@ -106,7 +108,10 @@ class OFABase(Layer):
         if getattr(self, 'current_config', None) != None:
             ### if a block is fixed, do not join its key into the candidate
             ### concrete config as a parameter in kwargs
-            if block.fixed == False:
+            if block.fixed == False and (
+                    self._skip_layers != None and
+                    self._key2name[block.key] not in self._skip_layers) and \
+                    (block.fn.weight.name not in self._depthwise_conv):
                 assert self._key2name[
                     block.
                     key] in self.current_config, 'DONNT have {} layer in config.'.format(
@@ -117,7 +122,7 @@ class OFABase(Layer):
                 config.update(kwargs)
             else:
                 config = dict()
-        logging.debug(self.model, config)
+        _logger.debug(self.model, config)
         return block.fn(*inputs, **config)
@@ -175,6 +180,7 @@ class OFA(OFABase):
         self._mapping_layers = None
         self._build_ss = False
         self._broadcast = False
+        self._skip_layers = None

         ### if elastic_order is none, use default order
         if self.elastic_order is not None:
@@ -185,6 +191,7 @@ class OFA(OFABase):
                 depth_list = list(set(self.run_config.elastic_depth))
                 depth_list.sort()
                 self._ofa_layers['depth'] = depth_list
+                self._layers['depth'] = depth_list

         if self.elastic_order is None:
             self.elastic_order = []
@@ -198,6 +205,7 @@ class OFA(OFABase):
                 depth_list = list(set(self.run_config.elastic_depth))
                 depth_list.sort()
                 self._ofa_layers['depth'] = depth_list
+                self._layers['depth'] = depth_list
                 self.elastic_order.append('depth')

             # final, elastic width
@@ -225,6 +233,11 @@ class OFA(OFABase):
                     run_config.init_learning_rate[idx], list
                 ), "each candidate in init_learning_rate must be list"

+        ### remove skip layers in search space
+        if self.run_config != None and getattr(self.run_config, 'skip_layers',
+                                               None) != None:
+            self._skip_layers = self.run_config.skip_layers
+
         ### ================= add distill prepare ======================
         if self.distill_config != None:
             self._add_teacher = True
@@ -234,7 +247,7 @@ class OFA(OFABase):

     def _prepare_distill(self):
         if self.distill_config.teacher_model == None:
-            logging.error(
+            _logger.error(
                 'If you want to add distill, please input instance of teacher model'
             )
@@ -289,6 +302,7 @@ class OFA(OFABase):

     def _reset_hook_before_forward(self):
         self.Tacts, self.Sacts = {}, {}
+        self.hooks = []
         if self._mapping_layers != None:

             def get_activation(mem, name):
@@ -300,12 +314,18 @@ class OFA(OFABase):
             def add_hook(net, mem, mapping_layers):
                 for idx, (n, m) in enumerate(net.named_sublayers()):
                     if n in mapping_layers:
-                        m.register_forward_post_hook(get_activation(mem, n))
+                        self.hooks.append(
+                            m.register_forward_post_hook(
+                                get_activation(mem, n)))

             add_hook(self.model, self.Sacts, self._mapping_layers)
             add_hook(self.ofa_teacher_model.model, self.Tacts,
                      self._mapping_layers)

+    def _remove_hook_after_forward(self):
+        for hook in self.hooks:
+            hook.remove()
+
     def _compute_epochs(self):
         if getattr(self, 'epoch', None) == None:
             assert self.run_config.total_images is not None, \
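The point of collecting the handles is that Paddle's register_forward_post_hook returns a removable handle; a small sketch of the mechanism the fix relies on (the layer and names here are illustrative only):

import paddle
import paddle.nn as nn

acts = {}
layer = nn.Linear(4, 4)

def save_act(sublayer, inputs, output):
    acts['fc'] = output

handle = layer.register_forward_post_hook(save_act)
layer(paddle.randn([1, 4]))   # hook fires, activation is recorded
handle.remove()               # what _remove_hook_after_forward() now does per hook
layer(paddle.randn([1, 4]))   # hook no longer fires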
@@ -521,6 +541,14 @@ class OFA(OFABase):
         else:
             return False

+    def _clear_width(self, key):
+        if 'expand_ratio' in self._ofa_layers[key]:
+            self._ofa_layers[key].pop('expand_ratio')
+        elif 'channel' in self._ofa_layers[key]:
+            self._ofa_layers[key].pop('channel')
+        if len(self._ofa_layers[key]) == 0:
+            self._ofa_layers.pop(key)
+
     def _clear_search_space(self, *inputs, **kwargs):
         """ find shortcut in model, and clear up the search space """
         input_shapes = []
@@ -533,34 +561,70 @@ class OFA(OFABase):
             input_dtypes.append(v.numpy().dtype)

         ### find shortcut block using static model
+        model_to_traverse = self.model._layers if isinstance(
+            self.model, DataParallel) else self.model
         _st_prog = dygraph2program(
-            self.model, inputs=input_shapes, dtypes=input_dtypes)
-        self._same_ss = check_search_space(GraphWrapper(_st_prog))
+            model_to_traverse, inputs=input_shapes, dtypes=input_dtypes)
+        self._same_ss, self._depthwise_conv = check_search_space(
+            GraphWrapper(_st_prog))

         if self._same_ss != None:
-            self._same_ss = sorted(self._same_ss)
             self._param2key = {}
             self._broadcast = True

             ### the name of sublayer is the key in search space
             ### param.name is the name in self._same_ss
-            model_to_traverse = self.model._layers if isinstance(
-                self.model, DataParallel) else self.model
             for name, sublayer in model_to_traverse.named_sublayers():
                 if isinstance(sublayer, BaseBlock):
                     for param in sublayer.parameters():
                         if self._find_ele(param.name, self._same_ss):
                             self._param2key[param.name] = name
+
+            ### clear the same search space again to avoid output weights in the same ss.
+            tmp_same_ss = []
+            for ss in self._same_ss:
+                per_ss = []
+                for key in ss:
+                    if key not in self._param2key.keys():
+                        continue
+
+                    ### if skip_layers and the same ss both contain the same layer,
+                    ### the layers related to this layer need to be added to skip_layers
+                    if self._skip_layers != None and self._param2key[
+                            key] in self._skip_layers:
+                        self._skip_layers += [self._param2key[sk] for sk in ss]
+                        per_ss = []
+                        break
+
+                    if self._param2key[key] in self._ofa_layers.keys() and \
+                       ('expand_ratio' in self._ofa_layers[self._param2key[key]] or \
+                        'channel' in self._ofa_layers[self._param2key[key]]):
+                        per_ss.append(key)
+                    else:
+                        _logger.info("{} not in ss".format(key))
+
+                if len(per_ss) != 0:
+                    tmp_same_ss.append(per_ss)
+
+            self._same_ss = tmp_same_ss
+
+            ### clear layers in ofa_layers that are set by skip_layers
+            if self._skip_layers != None:
+                for skip_layer in self._skip_layers:
+                    if skip_layer in self._ofa_layers.keys():
+                        self._ofa_layers.pop(skip_layer)
+
             for per_ss in self._same_ss:
                 for ss in per_ss[1:]:
-                    if 'expand_ratio' in self._ofa_layers[self._param2key[ss]]:
-                        self._ofa_layers[self._param2key[ss]].pop(
-                            'expand_ratio')
-                    elif 'channel' in self._ofa_layers[self._param2key[ss]]:
-                        self._ofa_layers[self._param2key[ss]].pop('channel')
-                    if len(self._ofa_layers[self._param2key[ss]]) == 0:
-                        self._ofa_layers.pop(self._param2key[ss])
+                    self._clear_width(self._param2key[ss])
+
+            ### clear depthwise convs from the search space because their output channels cannot change
+            for name, sublayer in model_to_traverse.named_sublayers():
+                if isinstance(sublayer, BaseBlock):
+                    for param in sublayer.parameters():
+                        if param.name in self._depthwise_conv and name in self._ofa_layers.keys(
+                        ):
+                            self._clear_width(name)

     def _broadcast_ss(self):
         """ broadcast search space after random sample."""
@@ -630,4 +694,9 @@ class OFA(OFABase):
         if self._broadcast:
             self._broadcast_ss()

-        return self.model.forward(*inputs, **kwargs), teacher_output
+        student_output = self.model.forward(*inputs, **kwargs)
+
+        if self._add_teacher:
+            self._remove_hook_after_forward()
+
+        return student_output, teacher_output
@@ -45,6 +45,9 @@ def dynabert_config(model, width_mult, depth_mult=1.0):
         if block_k == 'depth':
             block_v = depth_mult

+        if block_k != 'depth':
             new_block_k = model._key2name[block_k]
+        else:
+            new_block_k = 'depth'
         new_config[new_block_k] = block_v

     return new_config
@@ -40,6 +40,46 @@ class ModelV1(nn.Layer):
         return self.cls + self.model(inputs)


+class ModelShortcut(nn.Layer):
+    def __init__(self):
+        super(ModelShortcut, self).__init__()
+        self.conv1 = nn.Sequential(
+            nn.Conv2D(3, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
+        self.branch1 = nn.Sequential(
+            nn.Conv2D(12, 12, 1),
+            nn.BatchNorm2D(12),
+            nn.ReLU(),
+            nn.Conv2D(
+                12, 12, 1, groups=12),
+            nn.BatchNorm2D(12),
+            nn.ReLU(),
+            nn.Conv2D(
+                12, 12, 1, groups=12),
+            nn.BatchNorm2D(12),
+            nn.ReLU())
+        self.branch2 = nn.Sequential(
+            nn.Conv2D(12, 12, 1),
+            nn.BatchNorm2D(12),
+            nn.ReLU(),
+            nn.Conv2D(
+                12, 12, 1, groups=12),
+            nn.BatchNorm2D(12),
+            nn.ReLU(),
+            nn.Conv2D(12, 12, 1),
+            nn.BatchNorm2D(12),
+            nn.ReLU())
+        self.out = nn.Sequential(
+            nn.Conv2D(12, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
+
+    def forward(self, x):
+        x = self.conv1(x)
+        y = self.branch1(x)
+        y = x + y
+        z = self.branch2(y)
+        z = z + y
+        return z
+
+
 class TestOFAV2(unittest.TestCase):
     def setUp(self):
         model = ModelV1()
@@ -73,5 +113,42 @@ class TestOFAV2Export(unittest.TestCase):
             origin_model=origin_model)


+class TestShortcutSkiplayers(unittest.TestCase):
+    def setUp(self):
+        model = ModelShortcut()
+        sp_net_config = supernet(expand_ratio=[0.5, 1.0])
+        self.model = Convert(sp_net_config).convert(model)
+        self.images = paddle.randn(shape=[2, 3, 32, 32], dtype='float32')
+        self.init_config()
+        self.ofa_model = OFA(self.model, run_config=self.run_config)
+        self.ofa_model._clear_search_space(self.images)
+
+    def init_config(self):
+        default_run_config = {'skip_layers': ['branch1.6']}
+        self.run_config = RunConfig(**default_run_config)
+
+    def test_shortcut(self):
+        self.ofa_model.set_epoch(0)
+        self.ofa_model.set_task('expand_ratio')
+        for i in range(5):
+            self.ofa_model(self.images)
+        assert list(self.ofa_model._ofa_layers.keys()) == ['branch2.0', 'out.0']
+
+
+class TestShortcutSkiplayersCase1(TestShortcutSkiplayers):
+    def init_config(self):
+        default_run_config = {'skip_layers': ['conv1.0']}
+        self.run_config = RunConfig(**default_run_config)
+
+
+class TestShortcutSkiplayersCase2(TestShortcutSkiplayers):
+    def init_config(self):
+        default_run_config = {'skip_layers': ['branch2.0']}
+        self.run_config = RunConfig(**default_run_config)
+
+    def test_shortcut(self):
+        assert list(self.ofa_model._ofa_layers.keys()) == ['conv1.0', 'out.0']
+
+
 if __name__ == '__main__':
     unittest.main()