Unverified commit e274d7fb, authored by Chang Xu, committed by GitHub

new_export_func (#824)

* new_export_func

* add_test

* remove_kernel_prune

* add_test_&_update_clear_ss
Parent da0f5a6d
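In short: instead of rebuilding a static graph and computing a per-parameter prune config (the removed get_prune_params_config / prune_params path), every Super* layer now records the shape it actually used during the last forward pass in cur_config['prune_dim'], and export() slices each parameter down to that recorded shape. A minimal sketch of the idea, using toy numpy stand-ins rather than the real PaddleSlim classes:

import numpy as np

# layer name -> recorded prune_dim (what each Super layer stores in cur_config)
recorded = {}
weights = {"linear_0.weight": np.random.rand(64, 128).astype("float32")}

def toy_forward(expand_ratio):
    # the super layer only uses part of its output dim ...
    out_nc = int(128 * expand_ratio)
    # ... and remembers the shape it actually used
    recorded["linear_0"] = [64, out_nc]

def toy_export():
    pruned = {}
    for name, prune_dim in recorded.items():
        w = weights[name + ".weight"]
        pruned[name + ".weight"] = w[:prune_dim[0], :prune_dim[1]]
    return pruned

toy_forward(expand_ratio=0.5)
print(toy_export()["linear_0.weight"].shape)  # (64, 64)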
......@@ -64,6 +64,7 @@ def extract_vars(inputs):
f"Variable is excepted, but get an element with type({type(_value)}) from inputs whose type is dict. And the key of element is {_key}."
)
elif isinstance(inputs, (tuple, list)):
for _value in inputs:
vars.extend(extract_vars(_value))
if len(vars) == 0:
......@@ -99,7 +100,6 @@ def dygraph2program(layer,
extract_outputs_fn=None,
dtypes=None):
assert isinstance(layer, Layer)
extract_inputs_fn = extract_inputs_fn if extract_inputs_fn is not None else extract_vars
extract_outputs_fn = extract_outputs_fn if extract_outputs_fn is not None else extract_vars
tracer = _dygraph_tracer()._get_program_desc_tracer()
......@@ -116,6 +116,7 @@ def dygraph2program(layer,
else:
inputs = to_variables(inputs)
input_var_list = extract_inputs_fn(inputs)
original_outputs = layer(*inputs)
# 'original_outputs' may be a dict, so we should convert it to a list of variables.
# And we should not create new variables in 'extract_vars'.
......
......@@ -17,7 +17,7 @@ import paddle
from paddle.fluid import core
from .layers_base import BaseBlock
__all__ = ['get_prune_params_config', 'prune_params', 'check_search_space']
__all__ = ['check_search_space']
WEIGHT_OP = [
'conv2d', 'linear', 'embedding', 'conv2d_transpose', 'depthwise_conv2d'
......@@ -28,63 +28,6 @@ CONV_TYPES = [
]
def get_prune_params_config(graph, origin_model_config):
""" Convert config of search space to parameters' prune config.
"""
param_config = {}
precedor = None
for op in graph.ops():
### TODO(ceci3):
### 1. fix config when this op is concat by graph.pre_ops(op)
### 2. add kernel_size in config
for inp in op.all_inputs():
n_ops = graph.next_ops(op)
if inp._var.name in origin_model_config.keys():
if 'expand_ratio' in origin_model_config[
inp._var.name] or 'channel' in origin_model_config[
inp._var.name]:
key = 'channel' if 'channel' in origin_model_config[
inp._var.name] else 'expand_ratio'
tmp = origin_model_config[inp._var.name][key]
if len(inp._var.shape) > 1:
if inp._var.name in param_config.keys():
param_config[inp._var.name].append(tmp)
### first op
else:
param_config[inp._var.name] = [precedor, tmp]
else:
param_config[inp._var.name] = [tmp]
precedor = tmp
else:
precedor = None
for n_op in n_ops:
for next_inp in n_op.all_inputs():
if next_inp._var.persistable == True:
if next_inp._var.name in origin_model_config.keys():
if 'expand_ratio' in origin_model_config[
next_inp._var.
name] or 'channel' in origin_model_config[
next_inp._var.name]:
key = 'channel' if 'channel' in origin_model_config[
next_inp._var.name] else 'expand_ratio'
tmp = origin_model_config[next_inp._var.name][
key]
pre = tmp if precedor is None else precedor
if len(next_inp._var.shape) > 1:
param_config[next_inp._var.name] = [pre]
else:
param_config[next_inp._var.name] = [tmp]
else:
if len(next_inp._var.
shape) > 1 and precedor != None:
param_config[
next_inp._var.name] = [precedor, None]
else:
param_config[next_inp._var.name] = [precedor]
return param_config
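For reference, the removed helper above walked the program graph and turned a search-space config keyed by parameter name into a per-parameter [input_ratio, output_ratio] list, where the input ratio is inherited from the preceding weight. A hypothetical illustration of that mapping (parameter names invented, not taken from a real model):

# Hypothetical illustration only; names are invented.
search_space_config = {
    "conv2d_0.w_0": {"expand_ratio": 0.5},
    "conv2d_1.w_0": {"expand_ratio": 0.25},
}
# Shape of the result produced by get_prune_params_config: the first conv has no
# predecessor, so its input ratio is None; the second inherits 0.5 from the first.
expected_param_config = {
    "conv2d_0.w_0": [None, 0.5],
    "conv2d_1.w_0": [0.5, 0.25],
}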
def get_actual_shape(transform, channel):
if transform == None:
channel = int(channel)
......@@ -96,66 +39,6 @@ def get_actual_shape(transform, channel):
return channel
def prune_params(model, param_config, super_model_sd=None):
""" Prune parameters according to the config.
Parameters:
model(paddle.nn.Layer): instance of model.
param_config(dict): prune config of each weight.
super_model_sd(dict, optional): parameters come from supernet. If super_model_sd is not None, transfer parameters from this dict to model; otherwise, prune model from itself.
"""
for l_name, sublayer in model.named_sublayers():
if isinstance(sublayer, BaseBlock):
continue
for p_name, param in sublayer.named_parameters(include_sublayers=False):
t_value = param.value().get_tensor()
value = np.array(t_value).astype("float32")
if super_model_sd != None:
name = l_name + '.' + p_name
super_t_value = super_model_sd[name].value().get_tensor()
super_value = np.array(super_t_value).astype("float32")
super_model_sd.pop(name)
if param.name in param_config.keys():
if len(param_config[param.name]) > 1:
in_exp = param_config[param.name][0]
out_exp = param_config[param.name][1]
if sublayer.__class__.__name__.lower() in CONV_TYPES:
in_chn = get_actual_shape(in_exp, value.shape[1])
out_chn = get_actual_shape(out_exp, value.shape[0])
prune_value = super_value[:out_chn, :in_chn, ...] \
if super_model_sd != None else value[:out_chn, :in_chn, ...]
else:
in_chn = get_actual_shape(in_exp, value.shape[0])
out_chn = get_actual_shape(out_exp, value.shape[1])
prune_value = super_value[:in_chn, :out_chn, ...] \
if super_model_sd != None else value[:in_chn, :out_chn, ...]
else:
out_chn = get_actual_shape(param_config[param.name][0],
value.shape[0])
prune_value = super_value[:out_chn, ...] \
if super_model_sd != None else value[:out_chn, ...]
else:
prune_value = super_value if super_model_sd != None else value
p = t_value._place()
if p.is_cpu_place():
place = core.CPUPlace()
elif p.is_cuda_pinned_place():
place = core.CUDAPinnedPlace()
else:
place = core.CUDAPlace(p.gpu_device_id())
t_value.set(prune_value, place)
if param.trainable:
param.clear_gradient()
### initialize params which are not in sublayers, such as persistable inputs created by create_parameters
if super_model_sd != None and len(super_model_sd) != 0:
for k, v in super_model_sd.items():
setattr(model, k, v)
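The heart of the removed prune_params is the slicing above: conv weights of shape [out, in, kh, kw] are cut on the first two axes, while linear-style weights of shape [in, out] are cut the other way around. A standalone numpy sketch of just that slicing (toy shapes, not the real model traversal):

import numpy as np

conv_w = np.random.rand(64, 32, 3, 3).astype("float32")   # [out, in, kh, kw]
linear_w = np.random.rand(128, 256).astype("float32")     # [in, out]

in_chn, out_chn = 16, 32
pruned_conv = conv_w[:out_chn, :in_chn, ...]               # CONV_TYPES branch
pruned_linear = linear_w[:in_chn, :out_chn]                # non-conv branch

print(pruned_conv.shape, pruned_linear.shape)              # (32, 16, 3, 3) (16, 32)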
def _is_depthwise(op):
"""Check if this op is depthwise conv. Only Cin == Cout == groups be consider as depthwise conv.
The shape of input and the shape of output in depthwise conv must be same in superlayer,
......
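The truncated _is_depthwise docstring above reduces to a single condition; a minimal restatement of that condition only (the real helper works on graph ops, not raw integers):

def is_depthwise_like(c_in, c_out, groups):
    # Only Cin == Cout == groups counts as a depthwise conv.
    return c_in == c_out == groups

assert is_depthwise_like(32, 32, 32)
assert not is_depthwise_like(32, 64, 32)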
......@@ -177,6 +177,7 @@ class SuperConv2D(nn.Conv2D):
data_format=data_format)
self.candidate_config = candidate_config
self.cur_config = None
if len(candidate_config.items()) != 0:
for k, v in candidate_config.items():
candidate_config[k] = list(set(v))
......@@ -314,7 +315,7 @@ class SuperConv2D(nn.Conv2D):
bias = self.bias[:weight_out_nc]
else:
bias = self.bias
self.cur_config['prune_dim'] = list(weight.shape)
out = F.conv2d(
input,
weight,
......@@ -482,6 +483,7 @@ class SuperConv2DTranspose(nn.Conv2DTranspose):
data_format=data_format)
self.candidate_config = candidate_config
self.cur_config = None
if len(self.candidate_config.items()) != 0:
for k, v in candidate_config.items():
candidate_config[k] = list(set(v))
......@@ -620,7 +622,7 @@ class SuperConv2DTranspose(nn.Conv2DTranspose):
bias = self.bias[:weight_out_nc]
else:
bias = self.bias
self.cur_config['prune_dim'] = list(weight.shape)
out = F.conv2d_transpose(
input,
weight,
......@@ -733,6 +735,7 @@ class SuperSeparableConv2D(nn.Layer):
])
self.candidate_config = candidate_config
self.cur_config = None
self.expand_ratio = candidate_config[
'expand_ratio'] if 'expand_ratio' in candidate_config else None
self.base_output_dim = self.conv[0]._out_channels
......@@ -784,7 +787,7 @@ class SuperSeparableConv2D(nn.Layer):
bias = self.conv[2].bias[:out_nc]
else:
bias = self.conv[2].bias
self.cur_config['prune_dim'] = list(weight.shape)
conv1_out = F.conv2d(
norm_out,
weight,
......@@ -864,6 +867,7 @@ class SuperLinear(nn.Linear):
self._in_features = in_features
self._out_features = out_features
self.candidate_config = candidate_config
self.cur_config = None
self.expand_ratio = candidate_config[
'expand_ratio'] if 'expand_ratio' in candidate_config else None
self.base_output_dim = self._out_features
......@@ -896,7 +900,7 @@ class SuperLinear(nn.Linear):
bias = self.bias[:out_nc]
else:
bias = self.bias
self.cur_config['prune_dim'] = list(weight.shape)
out = F.linear(x=input, weight=weight, bias=bias, name=self.name)
return out
......@@ -945,6 +949,7 @@ class SuperBatchNorm2D(nn.BatchNorm2D):
super(SuperBatchNorm2D, self).__init__(
num_features, momentum, epsilon, weight_attr, bias_attr,
data_format, use_global_stats, name)
self.cur_config = None
def forward(self, input):
self._check_data_format(self._data_format)
......@@ -956,7 +961,7 @@ class SuperBatchNorm2D(nn.BatchNorm2D):
bias = self.bias[:feature_dim]
mean = self._mean[:feature_dim]
variance = self._variance[:feature_dim]
self.cur_config = {'prune_dim': feature_dim}
return F.batch_norm(
input,
mean,
......@@ -982,6 +987,7 @@ class SuperSyncBatchNorm(nn.SyncBatchNorm):
super(SuperSyncBatchNorm,
self).__init__(num_features, momentum, epsilon, weight_attr,
bias_attr, data_format, name)
self.cur_config = None
def forward(self, input):
......@@ -995,6 +1001,7 @@ class SuperSyncBatchNorm(nn.SyncBatchNorm):
mean_out = mean
# variance and variance out share the same memory
variance_out = variance
self.cur_config = {'prune_dim': feature_dim}
attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
"is_test", not self.training, "data_layout", self._data_format,
......@@ -1049,6 +1056,7 @@ class SuperInstanceNorm2D(nn.InstanceNorm2D):
super(SuperInstanceNorm2D, self).__init__(num_features, epsilon,
momentum, weight_attr,
bias_attr, data_format, name)
self.cur_config = None
def forward(self, input):
self._check_input_dim(input)
......@@ -1060,7 +1068,7 @@ class SuperInstanceNorm2D(nn.InstanceNorm2D):
else:
scale = self.scale[:feature_dim]
bias = self.bias[:feature_dim]
self.cur_config = {'prune_dim': feature_dim}
return F.instance_norm(input, scale, bias, eps=self._epsilon)
......@@ -1112,6 +1120,7 @@ class SuperLayerNorm(nn.LayerNorm):
name=None):
super(SuperLayerNorm, self).__init__(normalized_shape, epsilon,
weight_attr, bias_attr, name)
self.cur_config = None
def forward(self, input):
### TODO(ceci3): fix if normalized_shape is not a single number
......@@ -1127,6 +1136,8 @@ class SuperLayerNorm(nn.LayerNorm):
bias = self.bias[:feature_dim]
else:
bias = None
self.cur_config = {'prune_dim': feature_dim}
out, _, _ = core.ops.layer_norm(input, weight, bias, 'epsilon',
self._epsilon, 'begin_norm_axis',
begin_norm_axis)
......@@ -1191,6 +1202,7 @@ class SuperEmbedding(nn.Embedding):
padding_idx, sparse, weight_attr,
name)
self.candidate_config = candidate_config
self.cur_config = None
self.expand_ratio = candidate_config[
'expand_ratio'] if 'expand_ratio' in candidate_config else None
self.base_output_dim = self._embedding_dim
......@@ -1216,6 +1228,7 @@ class SuperEmbedding(nn.Embedding):
out_nc = self._embedding_dim
weight = self.weight[:, :out_nc]
self.cur_config = {'prune_dim': list(weight.shape)}
return F.embedding(
input,
weight=weight,
......
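All of the layers.py changes above follow one pattern: initialize cur_config to None in __init__ and, inside forward, record the dimensionality actually used, {'prune_dim': feature_dim} for norm layers and {'prune_dim': list(weight.shape)} for weighted layers. A toy layer showing the pattern (a simplified stand-in, not the real SuperLinear, which also handles candidate_config and expand ratios):

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class ToySuperLinear(nn.Linear):
    """Toy illustration of the recording pattern, not the real SuperLinear."""

    def __init__(self, in_features, out_features):
        super(ToySuperLinear, self).__init__(in_features, out_features)
        self.cur_config = None                       # filled in by forward()

    def forward(self, x, out_nc=None):
        out_nc = self.weight.shape[1] if out_nc is None else out_nc
        weight = self.weight[:, :out_nc]             # paddle Linear weight is [in, out]
        bias = self.bias[:out_nc] if self.bias is not None else None
        # export() later reads this to decide how much of each parameter to keep
        self.cur_config = {'prune_dim': list(weight.shape)}
        return F.linear(x, weight, bias)

layer = ToySuperLinear(16, 32)
out = layer(paddle.randn([4, 16]), out_nc=8)
print(layer.cur_config)                              # {'prune_dim': [16, 8]}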
......@@ -27,11 +27,14 @@ else:
from .layers import SuperConv2D, SuperLinear
Layer = paddle.nn.Layer
DataParallel = paddle.DataParallel
from .layers_base import BaseBlock
from .layers_base import BaseBlock, Block
from .utils.utils import search_idx
from ...common import get_logger
from ...core import GraphWrapper, dygraph2program
from .get_sub_model import get_prune_params_config, prune_params, check_search_space, broadcast_search_space
from .get_sub_model import check_search_space, broadcast_search_space
from paddle.fluid import core
from paddle.fluid.framework import Variable
import numbers
_logger = get_logger(__name__, level=logging.INFO)
......@@ -459,35 +462,41 @@ class OFA(OFABase):
def search(self, eval_func, condition):
pass
def _export_sub_model_config(self, origin_model, config, input_shapes,
                             input_dtypes):
    param2name = {}
    for name, sublayer in origin_model.named_sublayers():
        for param in sublayer.parameters(include_sublayers=False):
            if name.split('.')[-1] == 'fn':
                ### if the sublayer is a Block, param.name contains 'fn', but the config never contains 'fn'
                param2name[param.name] = name[:-3]
            else:
                param2name[param.name] = name
    program = dygraph2program(
        origin_model, inputs=input_shapes, dtypes=input_dtypes)
    graph = GraphWrapper(program)
    same_config, _ = check_search_space(graph)
    if same_config != None:
        broadcast_search_space(same_config, param2name, config)
    origin_model_config = {}
    for name, sublayer in origin_model.named_sublayers():
        if isinstance(sublayer, BaseBlock):
            sublayer = sublayer.fn
        for param in sublayer.parameters(include_sublayers=False):
            if name in config.keys():
                origin_model_config[param.name] = config[name]
    param_prune_config = get_prune_params_config(graph, origin_model_config)
    return param_prune_config
def _get_model_pruned_weight(self):
    pruned_param = {}
    for l_name, sublayer in self.model.named_sublayers():
        if getattr(sublayer, 'cur_config', None) == None:
            continue
        assert 'prune_dim' in sublayer.cur_config, 'The layer {} does not have prune_dim in cur_config.'.format(
            l_name)
        prune_shape = sublayer.cur_config['prune_dim']
        for p_name, param in sublayer.named_parameters(
                include_sublayers=False):
            origin_param = param.value().get_tensor()
            param = np.array(origin_param).astype("float32")
            name = l_name + '.' + p_name
            if isinstance(prune_shape, list):
                if len(param.shape) == 4:
                    pruned_param[name] = param[:prune_shape[0], :prune_shape[1], :, :]
                elif len(param.shape) == 2:
                    pruned_param[name] = param[:prune_shape[0], :prune_shape[1]]
                else:
                    if isinstance(sublayer, SuperLinear):
                        pruned_param[name] = param[:prune_shape[1]]
                    else:
                        pruned_param[name] = param[:prune_shape[0]]
            else:
                pruned_param[name] = param[:prune_shape]
    return pruned_param
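The new _get_model_pruned_weight above only consumes what forward() recorded: 4-D conv weights and 2-D linear weights are sliced on their first two axes, 1-D parameters on a single axis (the output axis for SuperLinear, the first axis otherwise), and norm layers record a plain int. A numpy sketch of just that dispatch (toy data, no layer traversal):

import numpy as np

def slice_by_prune_dim(param, prune_shape, is_linear_layer=False):
    """Toy restatement of the rank-based slicing in _get_model_pruned_weight."""
    if isinstance(prune_shape, list):
        if param.ndim == 4:                     # conv weight [out, in, kh, kw]
            return param[:prune_shape[0], :prune_shape[1], :, :]
        if param.ndim == 2:                     # linear weight [in, out]
            return param[:prune_shape[0], :prune_shape[1]]
        # 1-D parameters: linear biases follow the output dim, others the first dim
        return param[:prune_shape[1]] if is_linear_layer else param[:prune_shape[0]]
    return param[:prune_shape]                  # norm layers record a single int

w = np.zeros((64, 32, 3, 3), dtype="float32")
print(slice_by_prune_dim(w, [32, 16]).shape)    # (32, 16, 3, 3)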
def export(self,
config,
......@@ -510,17 +519,72 @@ class OFA(OFABase):
config = {'conv2d_0': {'expand_ratio': 2}, 'conv2d_1': {'expand_ratio': 2}}
origin_model = ofa_model.export(origin_model, config, input_shapes=[1, 3, 28, 28], input_dtypes=['float32'])
"""
super_sd = None
self.set_net_config(config)
self.model.eval()
def build_input(input_size, dtypes):
if isinstance(input_size, list) and all(
isinstance(i, numbers.Number) for i in input_size):
if isinstance(dtypes, list):
dtype = dtypes[0]
else:
dtype = dtypes
return paddle.cast(paddle.rand(list(input_size)), dtype)
if isinstance(input_size, dict):
inputs = {}
if isinstance(dtypes, list):
dtype = dtypes[0]
else:
dtype = dtypes
for key, value in input_size.items():
inputs[key] = paddle.cast(paddle.rand(list(value)), dtype)
return inputs
if isinstance(input_size, list):
return [
build_input(i, dtype)
for i, dtype in zip(input_size, dtypes)
]
data = build_input(input_shapes, input_dtypes)
if isinstance(data, list):
self.forward(*data)
else:
self.forward(data)
super_model_state_dict = None
if load_weights_from_supernet and origin_model != None:
super_sd = remove_model_fn(origin_model, self.model.state_dict())
super_model_state_dict = remove_model_fn(origin_model,
self.model.state_dict())
if origin_model == None:
origin_model = self.model
origin_model = origin_model._layers if isinstance(
origin_model, DataParallel) else origin_model
param_config = self._export_sub_model_config(origin_model, config,
input_shapes, input_dtypes)
prune_params(origin_model, param_config, super_sd)
_logger.info("Start to get pruned params, please wait...")
pruned_param = self._get_model_pruned_weight()
pruned_state_dict = remove_model_fn(origin_model, pruned_param)
_logger.info("Start to get pruned model, please wait...")
for l_name, sublayer in origin_model.named_sublayers():
for p_name, param in sublayer.named_parameters(
include_sublayers=False):
name = l_name + '.' + p_name
t_value = param.value().get_tensor()
if name in pruned_state_dict:
p = t_value._place()
if p.is_cpu_place():
place = core.CPUPlace()
elif p.is_cuda_pinned_place():
place = core.CUDAPinnedPlace()
else:
place = core.CUDAPlace(p.gpu_device_id())
t_value.set(pruned_state_dict[name], place)
if super_model_state_dict != None and len(super_model_state_dict) != 0:
for k, v in super_model_state_dict.items():
setattr(origin_model, k, v)
return origin_model
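With the build_input helper, export() now accepts nested shape specs, including dicts, which the new TestInputDict case at the bottom of this diff exercises. A usage sketch adapted from that test; the paddleslim import paths and ModelInputDict (defined in the test below) are assumptions here:

import paddle
from paddleslim.nas.ofa import OFA, RunConfig
from paddleslim.nas.ofa.convert_super import Convert, supernet

model = Convert(supernet(expand_ratio=[0.5, 1.0])).convert(ModelInputDict())
ofa_model = OFA(model, run_config=RunConfig(**{'skip_layers': ['conv1.0', 'conv2.0']}))
images = paddle.randn(shape=[2, 3, 32, 32], dtype='float32')
data = {'data': paddle.randn(shape=[2, 12, 32, 32], dtype='float32')}
ofa_model._clear_search_space(images, data=data)

config = ofa_model._sample_config(task="expand_ratio", sample_type="smallest")
sub_model = ofa_model.export(
    config,
    input_shapes=[[1, 3, 32, 32], {'data': [1, 12, 32, 32]}],
    input_dtypes=['float32', 'float32'])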
@property
......@@ -566,11 +630,26 @@ class OFA(OFABase):
input_shapes = []
input_dtypes = []
for n in inputs:
input_shapes.append(n.shape)
input_dtypes.append(n.numpy().dtype)
for n, v in kwargs.items():
input_shapes.append(v.shape)
input_dtypes.append(v.numpy().dtype)
if isinstance(n, Variable):
input_shapes.append(n)
input_dtypes.append(n.numpy().dtype)
for key, val in kwargs.items():
if isinstance(val, Variable):
input_shapes.append(val)
input_dtypes.append(val.numpy().dtype)
elif isinstance(val, dict):
input_shape = {}
input_dtype = {}
for k, v in val.items():
input_shape[k] = v
input_dtype[k] = v.numpy().dtype
input_shapes.append(input_shape)
input_dtypes.append(input_dtype)
else:
_logger.error(
"Cannot figure out the type of inputs! Right now, the type of inputs can be only Variable or dict."
)
### find shortcut block using static model
model_to_traverse = self.model._layers if isinstance(
......@@ -674,11 +753,9 @@ class OFA(OFABase):
_logger.debug("Current config is {}".format(self.current_config))
if 'depth' in self.current_config:
kwargs['depth'] = self.current_config['depth']
if self._broadcast:
broadcast_search_space(self._same_ss, self._param2key,
self.current_config)
student_output = self.model.forward(*inputs, **kwargs)
if self._add_teacher:
......
......@@ -142,10 +142,19 @@ class ModelConv2(nn.Layer):
class ModelLinear(nn.Layer):
def __init__(self):
super(ModelLinear, self).__init__()
with supernet(expand_ratio=(1.0, 2.0, 4.0)) as ofa_super:
with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
models = []
models += [nn.Embedding(num_embeddings=64, embedding_dim=64)]
models += [nn.Linear(64, 128)]
weight_attr = paddle.ParamAttr(
learning_rate=0.5,
regularizer=paddle.regularizer.L2Decay(1.0),
trainable=True)
bias_attr = paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=1.0))
models += [
nn.Linear(
64, 128, weight_attr=weight_attr, bias_attr=bias_attr)
]
models += [nn.LayerNorm(128)]
models += [nn.Linear(128, 256)]
models = ofa_super.convert(models)
......@@ -402,16 +411,7 @@ class TestExport(unittest.TestCase):
self.ofa_model = OFA(model)
def test_ofa(self):
config = {
'embedding_1': {
'expand_ratio': (2.0)
},
'linear_3': {
'expand_ratio': (2.0)
},
'linear_4': {},
'linear_5': {}
}
config = self.ofa_model._sample_config(task='expand_ratio', phase=None)
origin_dict = {}
for name, param in self.origin_model.named_parameters():
origin_dict[name] = param.shape
......@@ -459,9 +459,29 @@ class TestExportCase1(unittest.TestCase):
outs, _ = self.ofa_model(self.data)
self.config = self.ofa_model.current_config
def test_export_model(self):
self.ofa_model.export(
def test_export_model_linear1(self):
ex_model = self.ofa_model.export(
self.config, input_shapes=[[3, 64]], input_dtypes=['int64'])
ex_model(self.data)
assert len(self.ofa_model.ofa_layers) == 4
class TestExportCase2(unittest.TestCase):
def setUp(self):
model = ModelLinear()
data_np = np.random.random((3, 64)).astype(np.int64)
self.data = paddle.to_tensor(data_np)
self.ofa_model = OFA(model)
self.ofa_model.set_epoch(0)
outs, _ = self.ofa_model(self.data)
self.config = self.ofa_model.current_config
def test_export_model_linear2(self):
config = self.ofa_model._sample_config(
task='expand_ratio', phase=None, sample_type='smallest')
ex_model = self.ofa_model.export(
config, input_shapes=[[3, 64]], input_dtypes=['int64'])
ex_model(self.data)
assert len(self.ofa_model.ofa_layers) == 4
......
......@@ -81,6 +81,25 @@ class ModelShortcut(nn.Layer):
return z
class ModelInputDict(nn.Layer):
def __init__(self):
super(ModelInputDict, self).__init__()
self.conv0 = nn.Sequential(
nn.Conv2D(3, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
self.conv1 = nn.Sequential(
nn.Conv2D(12, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
self.conv2 = nn.Sequential(
nn.Conv2D(12, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
self.conv3 = nn.Sequential(
nn.Conv2D(12, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
def forward(self, x, data):
x = self.conv1(self.conv0(x))
y = self.conv2(x)
y = y + data['data']
return self.conv3(y)
class TestOFAV2(unittest.TestCase):
def setUp(self):
model = ModelV1()
......@@ -93,7 +112,6 @@ class TestOFAV2(unittest.TestCase):
self.ofa_model.set_epoch(0)
self.ofa_model.set_task('expand_ratio')
out, _ = self.ofa_model(self.images)
print(self.ofa_model.get_current_config)
class TestOFAV2Export(unittest.TestCase):
......@@ -151,5 +169,34 @@ class TestShortcutSkiplayersCase2(TestShortcutSkiplayers):
assert list(self.ofa_model._ofa_layers.keys()) == ['conv1.0', 'out.0']
class TestInputDict(unittest.TestCase):
def setUp(self):
model = ModelInputDict()
sp_net_config = supernet(expand_ratio=[0.5, 1.0])
self.model = Convert(sp_net_config).convert(model)
self.images = paddle.randn(shape=[2, 3, 32, 32], dtype='float32')
self.images2 = {
'data': paddle.randn(
shape=[2, 12, 32, 32], dtype='float32')
}
default_run_config = {'skip_layers': ['conv1.0', 'conv2.0']}
self.run_config = RunConfig(**default_run_config)
self.ofa_model = OFA(self.model, run_config=self.run_config)
self.ofa_model._clear_search_space(self.images, data=self.images2)
def test_export(self):
config = self.ofa_model._sample_config(
task="expand_ratio", sample_type="smallest")
self.ofa_model.export(
config,
input_shapes=[[1, 3, 32, 32], {
'data': [1, 12, 32, 32]
}],
input_dtypes=['float32', 'float32'])
if __name__ == '__main__':
unittest.main()