Unverified commit 84d34761, authored by ceci3, committed by GitHub

fix get sub model (#733)

* fix get sub model

* fix

* update

* update

* update include_self
Parent 022c5fbd
@@ -37,12 +37,15 @@ def get_prune_params_config(graph, origin_model_config):
         ### TODO(ceci3):
         ### 1. fix config when this op is concat by graph.pre_ops(op)
         ### 2. add kernel_size in config
-        ### 3. add channel in config
         for inp in op.all_inputs():
             n_ops = graph.next_ops(op)
             if inp._var.name in origin_model_config.keys():
-                if 'expand_ratio' in origin_model_config[inp._var.name].keys():
-                    tmp = origin_model_config[inp._var.name]['expand_ratio']
+                if 'expand_ratio' in origin_model_config[
+                        inp._var.name] or 'channel' in origin_model_config[
+                            inp._var.name]:
+                    key = 'channel' if 'channel' in origin_model_config[
+                        inp._var.name] else 'expand_ratio'
+                    tmp = origin_model_config[inp._var.name][key]
                     if len(inp._var.shape) > 1:
                         if inp._var.name in param_config.keys():
                             param_config[inp._var.name].append(tmp)
@@ -59,9 +62,13 @@ def get_prune_params_config(graph, origin_model_config):
                     if next_inp._var.persistable == True:
                         if next_inp._var.name in origin_model_config.keys():
                             if 'expand_ratio' in origin_model_config[
-                                    next_inp._var.name].keys():
+                                    next_inp._var.
+                                    name] or 'channel' in origin_model_config[
+                                        next_inp._var.name]:
+                                key = 'channel' if 'channel' in origin_model_config[
+                                    next_inp._var.name] else 'expand_ratio'
                                 tmp = origin_model_config[next_inp._var.name][
-                                    'expand_ratio']
+                                    key]
                                 pre = tmp if precedor is None else precedor
                                 if len(next_inp._var.shape) > 1:
                                     param_config[next_inp._var.name] = [pre]
@@ -78,9 +85,19 @@ def get_prune_params_config(graph, origin_model_config):
return param_config
def get_actual_shape(transform, channel):
if transform == None:
channel = int(channel)
else:
if isinstance(transform, float):
channel = int(channel * transform)
else:
channel = int(transform)
return channel
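
For reference, the new get_actual_shape helper folds the three config forms into one rule: None keeps the width, a float acts as a ratio, and an int is an explicit channel count. A minimal illustration (the import path and the numbers are illustrative only):

from paddleslim.nas.ofa.get_sub_model import get_actual_shape  # assumed module path

assert get_actual_shape(None, 64) == 64   # no transform: width unchanged
assert get_actual_shape(0.5, 64) == 32    # float 'expand_ratio': 64 * 0.5
assert get_actual_shape(16, 64) == 16     # int 'channel': explicit width
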
def prune_params(model, param_config, super_model_sd=None):
""" Prune parameters according to the config.
Parameters:
model(paddle.nn.Layer): instance of model.
param_config(dict): prune config of each weight.
@@ -104,25 +121,18 @@ def prune_params(model, param_config, super_model_sd=None):
                     in_exp = param_config[param.name][0]
                     out_exp = param_config[param.name][1]
                     if sublayer.__class__.__name__.lower() in CONV_TYPES:
-                        in_chn = int(value.shape[1]) if in_exp == None else int(
-                            value.shape[1] * in_exp)
-                        out_chn = int(value.shape[
-                            0]) if out_exp == None else int(value.shape[0] *
-                                                            out_exp)
+                        in_chn = get_actual_shape(in_exp, value.shape[1])
+                        out_chn = get_actual_shape(out_exp, value.shape[0])
                         prune_value = super_value[:out_chn, :in_chn, ...] \
                             if super_model_sd != None else value[:out_chn, :in_chn, ...]
                     else:
-                        in_chn = int(value.shape[0]) if in_exp == None else int(
-                            value.shape[0] * in_exp)
-                        out_chn = int(value.shape[
-                            1]) if out_exp == None else int(value.shape[1] *
-                                                            out_exp)
+                        in_chn = get_actual_shape(in_exp, value.shape[0])
+                        out_chn = get_actual_shape(out_exp, value.shape[1])
                         prune_value = super_value[:in_chn, :out_chn, ...] \
                             if super_model_sd != None else value[:in_chn, :out_chn, ...]
                 else:
-                    out_chn = int(value.shape[0]) if param_config[param.name][
-                        0] == None else int(value.shape[0] *
-                        param_config[param.name][0])
+                    out_chn = get_actual_shape(param_config[param.name][0],
+                                               value.shape[0])
                     prune_value = super_value[:out_chn, ...] \
                         if super_model_sd != None else value[:out_chn, ...]
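
The slicing direction differs between the two branches because Paddle stores conv filters as [out_channels, in_channels // groups, kh, kw] while Linear weights are laid out as [in_features, out_features]. A small illustration of the same slicing with made-up shapes, independent of prune_params:

import numpy as np

conv_w = np.zeros((32, 16, 3, 3))             # conv filter: [Cout, Cin, kh, kw]
fc_w = np.zeros((16, 32))                     # linear weight: [in_features, out_features]
out_chn, in_chn = 24, 8                       # widths as returned by get_actual_shape

pruned_conv = conv_w[:out_chn, :in_chn, ...]  # outputs first, then inputs -> (24, 8, 3, 3)
pruned_fc = fc_w[:in_chn, :out_chn, ...]      # inputs first, then outputs -> (8, 24)

assert pruned_conv.shape == (24, 8, 3, 3) and pruned_fc.shape == (8, 24)
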
@@ -140,23 +150,24 @@ def prune_params(model, param_config, super_model_sd=None):
                 if param.trainable:
                     param.clear_gradient()
-        ### initialize param which not in sublayers, such as create persistable inputs by create_parameters
+    ### initialize param which not in sublayers, such as create persistable inputs by create_parameters
     if super_model_sd != None and len(super_model_sd) != 0:
         for k, v in super_model_sd.items():
             setattr(model, k, v)


 def _is_depthwise(op):
-    """Check if this op is depthwise conv.
+    """Check if this op is a depthwise conv. Only Cin == Cout == groups is considered depthwise conv.
     The shape of input and the shape of output in depthwise conv must be same in superlayer,
     so depthwise op cannot be consider as weight op
     """
-    if op.type() == 'depthwise_conv':
-        return True
-    elif 'conv' in op.type():
+    #if op.type() == 'depthwise_conv2d': ### depthwise_conv2d in paddle is Cout % Cin == 0
+    #    return True
+    if 'conv' in op.type():
         for inp in op.all_inputs():
-            if not inp._var.persistable and op.attr('groups') == inp._var.shape[
-                1]:
+            if inp._var.persistable and (
+                    op.attr('groups') == inp._var.shape[0] and
+                    op.attr('groups') * inp._var.shape[1] == inp._var.shape[0]):
                 return True
     return False
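
With the tightened check, only Cin == Cout == groups counts as depthwise: a Paddle depthwise filter is stored as [groups, 1, kh, kw], so both conditions hold, while a grouped conv whose Cout is merely a multiple of Cin (which depthwise_conv2d also permits) no longer matches. A rough standalone sketch of the same predicate (a hypothetical helper working on plain shapes, not the graph OpWrapper API):

def is_depthwise_filter(filter_shape, groups):
    # Paddle conv filters have shape [out_channels, in_channels // groups, kh, kw].
    # Depthwise here means Cin == Cout == groups, i.e. exactly one filter per channel.
    out_channels, in_per_group = filter_shape[0], filter_shape[1]
    return groups == out_channels and groups * in_per_group == out_channels

assert is_depthwise_filter([32, 1, 3, 3], groups=32)      # true depthwise
assert not is_depthwise_filter([64, 1, 3, 3], groups=32)  # Cout % Cin == 0 but Cout != Cin
assert not is_depthwise_filter([64, 32, 3, 3], groups=1)  # ordinary conv
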
@@ -179,6 +190,7 @@ def _find_weight_ops(op, graph, weights):
weights.append(inp._var.name)
return weights
return _find_weight_ops(pre_op, graph, weights)
return weights
def _find_pre_elementwise_add(op, graph):
@@ -236,3 +248,36 @@ def check_search_space(graph):
depthwise_conv = sorted(depthwise_conv)
return (final_search_space, depthwise_conv)
def broadcast_search_space(same_search_space, param2key, origin_config):
"""
Inplace broadcast the origin_config according to the same search space. Such as: same_search_space = [['conv1_weight', 'conv3_weight']], param2key = {'conv1_weight': 'conv1.conv', 'conv3_weight': 'conv3.weight'}, origin_config= {'conv1.weight': {'channel': 10}, 'conv2.weight': {'channel': 20}}, the result after this function is origin_config={'conv1.weight': {'channel': 10}, 'conv2.weight': {'channel': 20}, 'conv3.weight': {'channel': 10}}
Args:
same_search_space(list<list>): broadcast according this list, each list in same_search_space means the channel must be consistent.
param2key(dict): the name of layers corresponds to the name of parameter.
origin_config(dict): the search space which can be searched.
"""
for per_ss in same_search_space:
for ss in per_ss[1:]:
key = param2key[ss]
pre_key = param2key[per_ss[0]]
if key in origin_config:
if 'expand_ratio' in origin_config[pre_key]:
origin_config[key].update({
'expand_ratio': origin_config[pre_key]['expand_ratio']
})
elif 'channel' in origin_config[pre_key]:
origin_config[key].update({
'channel': origin_config[pre_key]['channel']
})
else:
if 'expand_ratio' in origin_config[pre_key]:
origin_config[key] = {
'expand_ratio': origin_config[pre_key]['expand_ratio']
}
elif 'channel' in origin_config[pre_key]:
origin_config[key] = {
'channel': origin_config[pre_key]['channel']
}
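
A runnable rendering of the docstring's example (with the param2key entries made consistent with origin_config; the import path is assumed):

from paddleslim.nas.ofa.get_sub_model import broadcast_search_space  # assumed module path

same_search_space = [['conv1_weight', 'conv3_weight']]
param2key = {'conv1_weight': 'conv1.weight', 'conv3_weight': 'conv3.weight'}
origin_config = {'conv1.weight': {'channel': 10}, 'conv2.weight': {'channel': 20}}

broadcast_search_space(same_search_space, param2key, origin_config)

# conv3 shares conv1's search space, so it inherits conv1's channel setting.
assert origin_config['conv3.weight'] == {'channel': 10}
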
@@ -31,7 +31,7 @@ from .layers_base import BaseBlock
 from .utils.utils import search_idx
 from ...common import get_logger
 from ...core import GraphWrapper, dygraph2program
-from .get_sub_model import get_prune_params_config, prune_params, check_search_space
+from .get_sub_model import get_prune_params_config, prune_params, check_search_space, broadcast_search_space

 _logger = get_logger(__name__, level=logging.INFO)
@@ -156,7 +156,6 @@ class OFA(OFABase):
sp_net_config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
sp_model = Convert(sp_net_config).convert(model)
ofa_model = OFA(sp_model)
"""
def __init__(self,
@@ -461,6 +460,23 @@
def _export_sub_model_config(self, origin_model, config, input_shapes,
input_dtypes):
param2name = {}
for name, sublayer in origin_model.named_sublayers():
for param in sublayer.parameters(include_sublayers=False):
if name.split('.')[-1] == 'fn':
### if the sublayer is a Block, param.name contains 'fn', but the config keys never contain 'fn'
param2name[param.name] = name[:-3]
else:
param2name[param.name] = name
program = dygraph2program(
origin_model, inputs=input_shapes, dtypes=input_dtypes)
graph = GraphWrapper(program)
same_config, _ = check_search_space(graph)
if same_config != None:
broadcast_search_space(same_config, param2name, config)
origin_model_config = {}
for name, sublayer in origin_model.named_sublayers():
if isinstance(sublayer, BaseBlock):
@@ -469,9 +485,6 @@
                 if name in config.keys():
                     origin_model_config[param.name] = config[name]

-        program = dygraph2program(
-            origin_model, inputs=input_shapes, dtypes=input_dtypes)
-        graph = GraphWrapper(program)
         param_prune_config = get_prune_params_config(graph, origin_model_config)
         return param_prune_config
@@ -493,7 +506,6 @@
.. code-block:: python
from paddle.vision.models import mobilenet_v1
origin_model = mobilenet_v1()
config = {'conv2d_0': {'expand_ratio': 2}, 'conv2d_1': {'expand_ratio': 2}}
origin_model = ofa_model.export(origin_model, config, input_shapes=[1, 3, 28, 28], input_dtypes=['float32'])
"""
@@ -505,7 +517,6 @@
origin_model = self.model
origin_model = origin_model._layers if isinstance(
origin_model, DataParallel) else origin_model
param_config = self._export_sub_model_config(origin_model, config,
input_shapes, input_dtypes)
prune_params(origin_model, param_config, super_sd)
@@ -602,7 +613,6 @@
per_ss.append(key)
else:
_logger.info("{} not in ss".format(key))
if len(per_ss) != 0:
tmp_same_ss.append(per_ss)
@@ -626,33 +636,6 @@
             ):
                 self._clear_width(name)

-    def _broadcast_ss(self):
-        """ broadcast search space after random sample."""
-        for per_ss in self._same_ss:
-            for ss in per_ss[1:]:
-                key = self._param2key[ss]
-                pre_key = self._param2key[per_ss[0]]
-                if key in self.current_config:
-                    if 'expand_ratio' in self.current_config[pre_key]:
-                        self.current_config[key].update({
-                            'expand_ratio':
-                            self.current_config[pre_key]['expand_ratio']
-                        })
-                    elif 'channel' in self.current_config[pre_key]:
-                        self.current_config[key].update({
-                            'channel': self.current_config[pre_key]['channel']
-                        })
-                else:
-                    if 'expand_ratio' in self.current_config[pre_key]:
-                        self.current_config[key] = {
-                            'expand_ratio':
-                            self.current_config[pre_key]['expand_ratio']
-                        }
-                    elif 'channel' in self.current_config[pre_key]:
-                        self.current_config[key] = {
-                            'channel': self.current_config[pre_key]['channel']
-                        }
def forward(self, *inputs, **kwargs):
# ===================== teacher process =====================
teacher_output = None
@@ -692,7 +675,8 @@
                 kwargs['depth'] = self.current_config['depth']

         if self._broadcast:
-            self._broadcast_ss()
+            broadcast_search_space(self._same_ss, self._param2key,
+                                   self.current_config)

         student_output = self.model.forward(*inputs, **kwargs)
@@ -34,7 +34,7 @@ def OneShotSearch(model, eval_func, strategy='sa', search_steps=100):
         list<int>: The best tokens searched.
     """
     super_net = None
-    for layer in model.sublayers(include_sublayers=False):
+    for layer in model.sublayers(include_self=True):
         print("layer: {}".format(layer))
         if isinstance(layer, OneShotSuperNet):
             super_net = layer
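
The switch from include_sublayers=False to include_self=True tracks the paddle.nn.Layer.sublayers signature in Paddle 2.x, where the flag controls whether the layer itself appears in the returned list. A minimal sketch, assuming Paddle 2.x is installed:

import paddle.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
print(len(net.sublayers()))                   # 2: only the children
print(len(net.sublayers(include_self=True)))  # 3: the container itself is included,
                                              # which lets OneShotSearch also match a
                                              # top-level OneShotSuperNet model
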
@@ -37,14 +37,13 @@ class PrePostProcessLayer(Layer):
         for cmd in self.process_cmd:
             if cmd == "a": # add residual connection
-                self.functors.append(
-                    lambda x, y: x + y if y is not None else x)
+                self.functors.append(lambda x, y: x + y if y is not None else x)
                 self.exec_order += "a"
             elif cmd == "n": # add layer normalization
                 self.functors.append(
                     self.add_sublayer(
                         "layer_norm_%d" % len(
-                            self.sublayers(include_sublayers=False)),
+                            self.sublayers(include_self=True)),
                         LayerNorm(
                             normalized_shape=d_model,
                             param_attr=fluid.ParamAttr(
@@ -449,5 +449,21 @@ class TestShortCut(unittest.TestCase):
assert len(self.ofa_model.ofa_layers) == 38
class TestExportCase1(unittest.TestCase):
def setUp(self):
model = ModelLinear1()
data_np = np.random.random((3, 64)).astype(np.int64)
self.data = paddle.to_tensor(data_np)
self.ofa_model = OFA(model)
self.ofa_model.set_epoch(0)
outs, _ = self.ofa_model(self.data)
self.config = self.ofa_model.current_config
def test_export_model(self):
self.ofa_model.export(
self.config, input_shapes=[[3, 64]], input_dtypes=['int64'])
assert len(self.ofa_model.ofa_layers) == 4
if __name__ == '__main__':
unittest.main()
@@ -77,6 +77,7 @@ class ModelShortcut(nn.Layer):
y = x + y
z = self.branch2(y)
z = z + y
z = self.out(z)
return z