Unverified commit 84d34761 authored by ceci3, committed by GitHub

fix get sub model (#733)

* fix get sub model

* fix

* update

* update

* update include_self
Parent 022c5fbd
@@ -37,12 +37,15 @@ def get_prune_params_config(graph, origin_model_config):
         ### TODO(ceci3):
         ### 1. fix config when this op is concat by graph.pre_ops(op)
         ### 2. add kernel_size in config
+        ### 3. add channel in config
         for inp in op.all_inputs():
             n_ops = graph.next_ops(op)
             if inp._var.name in origin_model_config.keys():
-                if 'expand_ratio' in origin_model_config[inp._var.name].keys():
-                    tmp = origin_model_config[inp._var.name]['expand_ratio']
+                if 'expand_ratio' in origin_model_config[
+                        inp._var.name] or 'channel' in origin_model_config[
+                            inp._var.name]:
+                    key = 'channel' if 'channel' in origin_model_config[
+                        inp._var.name] else 'expand_ratio'
+                    tmp = origin_model_config[inp._var.name][key]
                     if len(inp._var.shape) > 1:
                         if inp._var.name in param_config.keys():
                             param_config[inp._var.name].append(tmp)
@@ -59,9 +62,13 @@ def get_prune_params_config(graph, origin_model_config):
                     if next_inp._var.persistable == True:
                         if next_inp._var.name in origin_model_config.keys():
-                            if 'expand_ratio' in origin_model_config[
-                                    next_inp._var.name].keys():
-                                tmp = origin_model_config[next_inp._var.name][
-                                    'expand_ratio']
+                            if 'expand_ratio' in origin_model_config[
+                                    next_inp._var.
+                                    name] or 'channel' in origin_model_config[
+                                        next_inp._var.name]:
+                                key = 'channel' if 'channel' in origin_model_config[
+                                    next_inp._var.name] else 'expand_ratio'
+                                tmp = origin_model_config[next_inp._var.name][
+                                    key]
                                 pre = tmp if precedor is None else precedor
                                 if len(next_inp._var.shape) > 1:
                                     param_config[next_inp._var.name] = [pre]
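The two hunks above relax the lookup so that a search-space entry may carry either an 'expand_ratio' or an absolute 'channel' value, with 'channel' taking precedence when both are present. A hypothetical origin_model_config illustrating the accepted shapes (parameter names invented for illustration):

    origin_model_config = {
        'conv2d_0.w_0': {'expand_ratio': 0.5},                # keep half of the filters
        'conv2d_1.w_0': {'channel': 32},                      # keep exactly 32 filters
        'conv2d_2.w_0': {'channel': 16, 'expand_ratio': 0.25} # 'channel' wins over 'expand_ratio'
    }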
@@ -78,9 +85,19 @@ def get_prune_params_config(graph, origin_model_config):
     return param_config
 
 
+def get_actual_shape(transform, channel):
+    if transform == None:
+        channel = int(channel)
+    else:
+        if isinstance(transform, float):
+            channel = int(channel * transform)
+        else:
+            channel = int(transform)
+    return channel
+
+
 def prune_params(model, param_config, super_model_sd=None):
     """ Prune parameters according to the config.
     Parameters:
         model(paddle.nn.Layer): instance of model.
         param_config(dict): prune config of each weight.
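The new get_actual_shape helper folds the two config styles into one place: None keeps the original width, a float is treated as an expand ratio, and anything else is taken as an absolute channel count. A quick sketch of its behavior (values chosen only for illustration):

    # Assuming a weight with 64 channels along the relevant axis:
    get_actual_shape(None, 64)   # -> 64, no pruning
    get_actual_shape(0.5, 64)    # -> 32, float is used as an expand_ratio multiplier
    get_actual_shape(16, 64)     # -> 16, non-float is an absolute 'channel' target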
@@ -104,25 +121,18 @@ def prune_params(model, param_config, super_model_sd=None):
                     in_exp = param_config[param.name][0]
                     out_exp = param_config[param.name][1]
                     if sublayer.__class__.__name__.lower() in CONV_TYPES:
-                        in_chn = int(value.shape[1]) if in_exp == None else int(
-                            value.shape[1] * in_exp)
-                        out_chn = int(value.shape[
-                            0]) if out_exp == None else int(value.shape[0] *
-                                                            out_exp)
+                        in_chn = get_actual_shape(in_exp, value.shape[1])
+                        out_chn = get_actual_shape(out_exp, value.shape[0])
                         prune_value = super_value[:out_chn, :in_chn, ...] \
                                  if super_model_sd != None else value[:out_chn, :in_chn, ...]
                     else:
-                        in_chn = int(value.shape[0]) if in_exp == None else int(
-                            value.shape[0] * in_exp)
-                        out_chn = int(value.shape[
-                            1]) if out_exp == None else int(value.shape[1] *
-                                                            out_exp)
+                        in_chn = get_actual_shape(in_exp, value.shape[0])
+                        out_chn = get_actual_shape(out_exp, value.shape[1])
                         prune_value = super_value[:in_chn, :out_chn, ...] \
                                  if super_model_sd != None else value[:in_chn, :out_chn, ...]
                 else:
-                    out_chn = int(value.shape[0]) if param_config[param.name][
-                        0] == None else int(value.shape[0] *
-                                            param_config[param.name][0])
+                    out_chn = get_actual_shape(param_config[param.name][0],
+                                               value.shape[0])
                     prune_value = super_value[:out_chn, ...] \
                              if super_model_sd != None else value[:out_chn, ...]
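In the rewritten branches above, both transforms of a 2-D weight are resolved through get_actual_shape: for conv weights the input axis is shape[1] and the output axis is shape[0], while for linear-style weights it is the reverse. A hypothetical param_config entry, just to show the shapes involved:

    # Hypothetical entry for a conv weight of shape [out=64, in=32, 3, 3];
    # each element may be None, a float ratio, or an absolute channel count.
    param_config = {'conv2d_5.w_0': [0.5, 16]}
    # -> in_chn = get_actual_shape(0.5, 32) = 16, out_chn = get_actual_shape(16, 64) = 16,
    #    so the kept slice is value[:16, :16, ...]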
@@ -140,23 +150,24 @@ def prune_params(model, param_config, super_model_sd=None):
             if param.trainable:
                 param.clear_gradient()
     ### initialize params which are not in sublayers, such as persistable inputs created by create_parameter
     if super_model_sd != None and len(super_model_sd) != 0:
         for k, v in super_model_sd.items():
             setattr(model, k, v)
 
 
 def _is_depthwise(op):
-    """Check if this op is a depthwise conv.
+    """Check if this op is a depthwise conv. Only Cin == Cout == groups is considered depthwise here.
     The shape of input and the shape of output in depthwise conv must be the same in the superlayer,
     so a depthwise op cannot be considered as a weight op.
     """
-    if op.type() == 'depthwise_conv':
-        return True
-    elif 'conv' in op.type():
+    #if op.type() == 'depthwise_conv2d': ### depthwise_conv2d in paddle only requires Cout % Cin == 0
+    #    return True
+    if 'conv' in op.type():
         for inp in op.all_inputs():
-            if not inp._var.persistable and op.attr('groups') == inp._var.shape[
-                    1]:
+            if inp._var.persistable and (
+                    op.attr('groups') == inp._var.shape[0] and
+                    op.attr('groups') * inp._var.shape[1] == inp._var.shape[0]):
                 return True
     return False
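The predicate now inspects the persistable weight input (shape [Cout, Cin/groups, kH, kW]) and only returns True when groups == Cout and groups * (Cin / groups) == Cout, i.e. Cin == Cout == groups; the plain op-type check was dropped because, as the new comment notes, Paddle's depthwise_conv2d only requires Cout % Cin == 0. A rough illustration with paddle.nn layers (shapes and the op lowering are for illustration only):

    import paddle.nn as nn

    # Cin == Cout == groups: weight shape [16, 1, 3, 3]; _is_depthwise() would return True.
    dw = nn.Conv2D(in_channels=16, out_channels=16, kernel_size=3, groups=16)

    # Channel-multiplier case: weight shape [32, 1, 3, 3]; groups (16) != Cout (32),
    # so _is_depthwise() now returns False and this weight is handled like any other
    # weight op, even if Paddle still lowers the op to a depthwise_conv2d kernel.
    dw_mult = nn.Conv2D(in_channels=16, out_channels=32, kernel_size=3, groups=16)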
@@ -179,6 +190,7 @@ def _find_weight_ops(op, graph, weights):
                 weights.append(inp._var.name)
             return weights
         return _find_weight_ops(pre_op, graph, weights)
+    return weights
 
 
 def _find_pre_elementwise_add(op, graph):
@@ -236,3 +248,36 @@ def check_search_space(graph):
     depthwise_conv = sorted(depthwise_conv)
     return (final_search_space, depthwise_conv)
+
+
+def broadcast_search_space(same_search_space, param2key, origin_config):
+    """
+    Inplace broadcast the origin_config according to the same search space.
+    For example, given same_search_space = [['conv1_weight', 'conv3_weight']],
+    param2key = {'conv1_weight': 'conv1.weight', 'conv3_weight': 'conv3.weight'} and
+    origin_config = {'conv1.weight': {'channel': 10}, 'conv2.weight': {'channel': 20}},
+    the result of this function is origin_config = {'conv1.weight': {'channel': 10},
+    'conv2.weight': {'channel': 20}, 'conv3.weight': {'channel': 10}}.
+
+    Args:
+        same_search_space(list<list>): broadcast according to this list; the parameters in each inner list must share the same channel setting.
+        param2key(dict): map from the name of a parameter to the corresponding key in origin_config.
+        origin_config(dict): the search space which can be searched.
+    """
+    for per_ss in same_search_space:
+        for ss in per_ss[1:]:
+            key = param2key[ss]
+            pre_key = param2key[per_ss[0]]
+            if key in origin_config:
+                if 'expand_ratio' in origin_config[pre_key]:
+                    origin_config[key].update({
+                        'expand_ratio': origin_config[pre_key]['expand_ratio']
+                    })
+                elif 'channel' in origin_config[pre_key]:
+                    origin_config[key].update({
+                        'channel': origin_config[pre_key]['channel']
+                    })
+            else:
+                if 'expand_ratio' in origin_config[pre_key]:
+                    origin_config[key] = {
+                        'expand_ratio': origin_config[pre_key]['expand_ratio']
+                    }
+                elif 'channel' in origin_config[pre_key]:
+                    origin_config[key] = {
+                        'channel': origin_config[pre_key]['channel']
+                    }
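A short, self-contained walk-through of the broadcast described above, using the hypothetical names from the docstring:

    same_search_space = [['conv1_weight', 'conv3_weight']]
    param2key = {'conv1_weight': 'conv1.weight', 'conv3_weight': 'conv3.weight'}
    config = {'conv1.weight': {'channel': 10}, 'conv2.weight': {'channel': 20}}

    broadcast_search_space(same_search_space, param2key, config)
    # conv3.weight is tied to conv1.weight, so it inherits its setting:
    # {'conv1.weight': {'channel': 10}, 'conv2.weight': {'channel': 20},
    #  'conv3.weight': {'channel': 10}}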
@@ -31,7 +31,7 @@ from .layers_base import BaseBlock
 from .utils.utils import search_idx
 from ...common import get_logger
 from ...core import GraphWrapper, dygraph2program
-from .get_sub_model import get_prune_params_config, prune_params, check_search_space
+from .get_sub_model import get_prune_params_config, prune_params, check_search_space, broadcast_search_space
 
 _logger = get_logger(__name__, level=logging.INFO)
@@ -156,7 +156,6 @@ class OFA(OFABase):
             sp_net_config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
             sp_model = Convert(sp_net_config).convert(model)
             ofa_model = OFA(sp_model)
-
     """

     def __init__(self,
@@ -461,6 +460,23 @@ class OFA(OFABase):
     def _export_sub_model_config(self, origin_model, config, input_shapes,
                                  input_dtypes):
+        param2name = {}
+        for name, sublayer in origin_model.named_sublayers():
+            for param in sublayer.parameters(include_sublayers=False):
+                if name.split('.')[-1] == 'fn':
+                    ### if the sublayer is a Block, its name ends with 'fn' but the config keys never do, so strip it
+                    param2name[param.name] = name[:-3]
+                else:
+                    param2name[param.name] = name
+
+        program = dygraph2program(
+            origin_model, inputs=input_shapes, dtypes=input_dtypes)
+        graph = GraphWrapper(program)
+        same_config, _ = check_search_space(graph)
+        if same_config != None:
+            broadcast_search_space(same_config, param2name, config)
+
         origin_model_config = {}
         for name, sublayer in origin_model.named_sublayers():
             if isinstance(sublayer, BaseBlock):
@@ -469,9 +485,6 @@ class OFA(OFABase):
                     if name in config.keys():
                         origin_model_config[param.name] = config[name]
 
-        program = dygraph2program(
-            origin_model, inputs=input_shapes, dtypes=input_dtypes)
-        graph = GraphWrapper(program)
         param_prune_config = get_prune_params_config(graph, origin_model_config)
         return param_prune_config
@@ -493,7 +506,6 @@ class OFA(OFABase):
           .. code-block:: python
             from paddle.vision.models import mobilenet_v1
             origin_model = mobilenet_v1()
-
             config = {'conv2d_0': {'expand_ratio': 2}, 'conv2d_1': {'expand_ratio': 2}}
             origin_model = ofa_model.export(origin_model, config, input_shapes=[1, 3, 28, 28], input_dtypes=['float32'])
         """
@@ -505,7 +517,6 @@ class OFA(OFABase):
             origin_model = self.model
         origin_model = origin_model._layers if isinstance(
             origin_model, DataParallel) else origin_model
-
         param_config = self._export_sub_model_config(origin_model, config,
                                                      input_shapes, input_dtypes)
         prune_params(origin_model, param_config, super_sd)
@@ -602,7 +613,6 @@ class OFA(OFABase):
                         per_ss.append(key)
                     else:
                         _logger.info("{} not in ss".format(key))
-
             if len(per_ss) != 0:
                 tmp_same_ss.append(per_ss)
@@ -626,33 +636,6 @@ class OFA(OFABase):
                 ):
                     self._clear_width(name)
 
-    def _broadcast_ss(self):
-        """ broadcast search space after random sample."""
-        for per_ss in self._same_ss:
-            for ss in per_ss[1:]:
-                key = self._param2key[ss]
-                pre_key = self._param2key[per_ss[0]]
-                if key in self.current_config:
-                    if 'expand_ratio' in self.current_config[pre_key]:
-                        self.current_config[key].update({
-                            'expand_ratio':
-                            self.current_config[pre_key]['expand_ratio']
-                        })
-                    elif 'channel' in self.current_config[pre_key]:
-                        self.current_config[key].update({
-                            'channel': self.current_config[pre_key]['channel']
-                        })
-                else:
-                    if 'expand_ratio' in self.current_config[pre_key]:
-                        self.current_config[key] = {
-                            'expand_ratio':
-                            self.current_config[pre_key]['expand_ratio']
-                        }
-                    elif 'channel' in self.current_config[pre_key]:
-                        self.current_config[key] = {
-                            'channel': self.current_config[pre_key]['channel']
-                        }
-
     def forward(self, *inputs, **kwargs):
         # ===================== teacher process =====================
         teacher_output = None
@@ -692,7 +675,8 @@ class OFA(OFABase):
                 kwargs['depth'] = self.current_config['depth']

         if self._broadcast:
-            self._broadcast_ss()
+            broadcast_search_space(self._same_ss, self._param2key,
+                                   self.current_config)

         student_output = self.model.forward(*inputs, **kwargs)
......
@@ -34,7 +34,7 @@ def OneShotSearch(model, eval_func, strategy='sa', search_steps=100):
         list<int>: The best tokens searched.
     """
     super_net = None
-    for layer in model.sublayers(include_sublayers=False):
+    for layer in model.sublayers(include_self=True):
         print("layer: {}".format(layer))
         if isinstance(layer, OneShotSuperNet):
             super_net = layer
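This hunk, and the one further below, track the Paddle 2.x Layer.sublayers() API, where the old include_sublayers argument was replaced by include_self; passing include_self=True additionally yields the layer itself, so OneShotSearch still finds the supernet when the model passed in is itself the OneShotSuperNet. A minimal sketch of the difference, assuming a recent Paddle install:

    import paddle.nn as nn

    net = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
    # include_self defaults to False and returns only descendants: [Linear, ReLU]
    print(len(net.sublayers()))                   # expected: 2
    # include_self=True also yields `net` itself: [Sequential, Linear, ReLU]
    print(len(net.sublayers(include_self=True)))  # expected: 3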
......
@@ -37,14 +37,13 @@ class PrePostProcessLayer(Layer):
         for cmd in self.process_cmd:
             if cmd == "a":  # add residual connection
-                self.functors.append(
-                    lambda x, y: x + y if y is not None else x)
+                self.functors.append(lambda x, y: x + y if y is not None else x)
                 self.exec_order += "a"
             elif cmd == "n":  # add layer normalization
                 self.functors.append(
                     self.add_sublayer(
                         "layer_norm_%d" % len(
-                            self.sublayers(include_sublayers=False)),
+                            self.sublayers(include_self=True)),
                         LayerNorm(
                             normalized_shape=d_model,
                             param_attr=fluid.ParamAttr(
......
@@ -449,5 +449,21 @@ class TestShortCut(unittest.TestCase):
         assert len(self.ofa_model.ofa_layers) == 38
 
 
+class TestExportCase1(unittest.TestCase):
+    def setUp(self):
+        model = ModelLinear1()
+        data_np = np.random.random((3, 64)).astype(np.int64)
+        self.data = paddle.to_tensor(data_np)
+        self.ofa_model = OFA(model)
+        self.ofa_model.set_epoch(0)
+        outs, _ = self.ofa_model(self.data)
+        self.config = self.ofa_model.current_config
+
+    def test_export_model(self):
+        self.ofa_model.export(
+            self.config, input_shapes=[[3, 64]], input_dtypes=['int64'])
+        assert len(self.ofa_model.ofa_layers) == 4
+
+
 if __name__ == '__main__':
     unittest.main()
@@ -77,6 +77,7 @@ class ModelShortcut(nn.Layer):
         y = x + y
         z = self.branch2(y)
         z = z + y
+        z = self.out(z)
         return z
......