Unverified commit e274d7fb, authored by Chang Xu and committed by GitHub

new_export_func (#824)

* new_export_func

* add_test

* remove_kernel_prune

* add_test_&_update_clear_ss
Parent da0f5a6d
@@ -64,6 +64,7 @@ def extract_vars(inputs):
                f"Variable is excepted, but get an element with type({type(_value)}) from inputs whose type is dict. And the key of element is {_key}."
            )
    elif isinstance(inputs, (tuple, list)):
        for _value in inputs:
            vars.extend(extract_vars(_value))
    if len(vars) == 0:
@@ -99,7 +100,6 @@ def dygraph2program(layer,
                    extract_outputs_fn=None,
                    dtypes=None):
    assert isinstance(layer, Layer)
    extract_inputs_fn = extract_inputs_fn if extract_inputs_fn is not None else extract_vars
    extract_outputs_fn = extract_outputs_fn if extract_outputs_fn is not None else extract_vars
    tracer = _dygraph_tracer()._get_program_desc_tracer()
@@ -116,6 +116,7 @@ def dygraph2program(layer,
        else:
            inputs = to_variables(inputs)
            input_var_list = extract_inputs_fn(inputs)
        original_outputs = layer(*inputs)
        # 'original_outputs' may be dict, so we should convert it to list of varibles.
        # And should not create new varibles in 'extract_vars'.
...
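The hunks above only touch `extract_vars` and `dygraph2program` lightly, but the flattening contract they rely on is what the rest of the commit builds on. Below is a minimal, self-contained sketch (not the library code; `FakeVar` and `flatten_vars` are names invented here for illustration) of how nested dict/list/tuple inputs are walked so that every variable-like leaf is collected in order.

```python
# Illustration only: mirrors the recursion visible in extract_vars() above.
class FakeVar:  # stand-in for a Paddle Variable, for this sketch only
    def __init__(self, name):
        self.name = name


def flatten_vars(inputs):
    found = []
    if isinstance(inputs, dict):
        for _key, _value in inputs.items():
            found.extend(flatten_vars(_value))
    elif isinstance(inputs, (tuple, list)):
        for _value in inputs:
            found.extend(flatten_vars(_value))
    elif isinstance(inputs, FakeVar):
        found.append(inputs)
    return found


# [v.name for v in flatten_vars([FakeVar("x"), {"data": FakeVar("y")}])]
# -> ['x', 'y']
```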
@@ -17,7 +17,7 @@ import paddle
from paddle.fluid import core
from .layers_base import BaseBlock
-__all__ = ['get_prune_params_config', 'prune_params', 'check_search_space']
+__all__ = ['check_search_space']

WEIGHT_OP = [
    'conv2d', 'linear', 'embedding', 'conv2d_transpose', 'depthwise_conv2d'
@@ -28,63 +28,6 @@ CONV_TYPES = [
]

-def get_prune_params_config(graph, origin_model_config):
-    """ Convert config of search space to parameters' prune config.
-    """
-    param_config = {}
-    precedor = None
-    for op in graph.ops():
-        ### TODO(ceci3):
-        ### 1. fix config when this op is concat by graph.pre_ops(op)
-        ### 2. add kernel_size in config
-        for inp in op.all_inputs():
-            n_ops = graph.next_ops(op)
-            if inp._var.name in origin_model_config.keys():
-                if 'expand_ratio' in origin_model_config[
-                        inp._var.name] or 'channel' in origin_model_config[
-                            inp._var.name]:
-                    key = 'channel' if 'channel' in origin_model_config[
-                        inp._var.name] else 'expand_ratio'
-                    tmp = origin_model_config[inp._var.name][key]
-                    if len(inp._var.shape) > 1:
-                        if inp._var.name in param_config.keys():
-                            param_config[inp._var.name].append(tmp)
-                        ### first op
-                        else:
-                            param_config[inp._var.name] = [precedor, tmp]
-                    else:
-                        param_config[inp._var.name] = [tmp]
-                    precedor = tmp
-                else:
-                    precedor = None
-            for n_op in n_ops:
-                for next_inp in n_op.all_inputs():
-                    if next_inp._var.persistable == True:
-                        if next_inp._var.name in origin_model_config.keys():
-                            if 'expand_ratio' in origin_model_config[
-                                    next_inp._var.
-                                    name] or 'channel' in origin_model_config[
-                                        next_inp._var.name]:
-                                key = 'channel' if 'channel' in origin_model_config[
-                                    next_inp._var.name] else 'expand_ratio'
-                                tmp = origin_model_config[next_inp._var.name][
-                                    key]
-                                pre = tmp if precedor is None else precedor
-                                if len(next_inp._var.shape) > 1:
-                                    param_config[next_inp._var.name] = [pre]
-                                else:
-                                    param_config[next_inp._var.name] = [tmp]
-                        else:
-                            if len(next_inp._var.
-                                   shape) > 1 and precedor != None:
-                                param_config[
-                                    next_inp._var.name] = [precedor, None]
-                            else:
-                                param_config[next_inp._var.name] = [precedor]
-    return param_config

def get_actual_shape(transform, channel):
    if transform == None:
        channel = int(channel)
@@ -96,66 +39,6 @@ def get_actual_shape(transform, channel):
    return channel

-def prune_params(model, param_config, super_model_sd=None):
-    """ Prune parameters according to the config.
-        Parameters:
-            model(paddle.nn.Layer): instance of model.
-            param_config(dict): prune config of each weight.
-            super_model_sd(dict, optional): parameters come from supernet. If super_model_sd is not None, transfer parameters from this dict to model; otherwise, prune model from itself.
-    """
-    for l_name, sublayer in model.named_sublayers():
-        if isinstance(sublayer, BaseBlock):
-            continue
-        for p_name, param in sublayer.named_parameters(include_sublayers=False):
-            t_value = param.value().get_tensor()
-            value = np.array(t_value).astype("float32")
-            if super_model_sd != None:
-                name = l_name + '.' + p_name
-                super_t_value = super_model_sd[name].value().get_tensor()
-                super_value = np.array(super_t_value).astype("float32")
-                super_model_sd.pop(name)
-            if param.name in param_config.keys():
-                if len(param_config[param.name]) > 1:
-                    in_exp = param_config[param.name][0]
-                    out_exp = param_config[param.name][1]
-                    if sublayer.__class__.__name__.lower() in CONV_TYPES:
-                        in_chn = get_actual_shape(in_exp, value.shape[1])
-                        out_chn = get_actual_shape(out_exp, value.shape[0])
-                        prune_value = super_value[:out_chn, :in_chn, ...] \
-                            if super_model_sd != None else value[:out_chn, :in_chn, ...]
-                    else:
-                        in_chn = get_actual_shape(in_exp, value.shape[0])
-                        out_chn = get_actual_shape(out_exp, value.shape[1])
-                        prune_value = super_value[:in_chn, :out_chn, ...] \
-                            if super_model_sd != None else value[:in_chn, :out_chn, ...]
-                else:
-                    out_chn = get_actual_shape(param_config[param.name][0],
-                                               value.shape[0])
-                    prune_value = super_value[:out_chn, ...] \
-                        if super_model_sd != None else value[:out_chn, ...]
-            else:
-                prune_value = super_value if super_model_sd != None else value
-            p = t_value._place()
-            if p.is_cpu_place():
-                place = core.CPUPlace()
-            elif p.is_cuda_pinned_place():
-                place = core.CUDAPinnedPlace()
-            else:
-                place = core.CUDAPlace(p.gpu_device_id())
-            t_value.set(prune_value, place)
-            if param.trainable:
-                param.clear_gradient()
-    ### initialize param which not in sublayers, such as create persistable inputs by create_parameters
-    if super_model_sd != None and len(super_model_sd) != 0:
-        for k, v in super_model_sd.items():
-            setattr(model, k, v)

def _is_depthwise(op):
    """Check if this op is depthwise conv. Only Cin == Cout == groups be consider as depthwise conv.
       The shape of input and the shape of output in depthwise conv must be same in superlayer,
...
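This file change is the heart of the commit: the graph-walking helpers `get_prune_params_config` and `prune_params` are deleted, and pruning decisions instead come from the `cur_config['prune_dim']` values that the super layers record during a forward pass (see the `layers` and `ofa.py` hunks below). The sketch below restates the slicing rule in isolation; it is a hedged illustration only, `slice_param` and `is_linear_bias` are names invented here, and the real logic lives in `OFA._get_model_pruned_weight` later in this diff.

```python
import numpy as np

# Hedged sketch of the slicing rule used by the new export path:
# 'prune_dim' is either a pruned weight shape (a list) recorded by
# conv/linear/embedding-like super layers, or a single feature dim (an int)
# recorded by the norm layers.
def slice_param(param, prune_dim, is_linear_bias=False):
    param = np.asarray(param, dtype="float32")
    if isinstance(prune_dim, list):
        if param.ndim == 4:                      # conv weight: [out, in, kh, kw]
            return param[:prune_dim[0], :prune_dim[1], :, :]
        if param.ndim == 2:                      # linear / embedding weight
            return param[:prune_dim[0], :prune_dim[1]]
        # 1-D params (biases): linear biases follow the second recorded dim,
        # other biases follow the first recorded dim.
        return param[:prune_dim[1]] if is_linear_bias else param[:prune_dim[0]]
    return param[:prune_dim]                     # norm layers record an int


# slice_param(np.ones((8, 8, 3, 3)), [4, 2]).shape -> (4, 2, 3, 3)
```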
@@ -177,6 +177,7 @@ class SuperConv2D(nn.Conv2D):
            data_format=data_format)
        self.candidate_config = candidate_config
+        self.cur_config = None
        if len(candidate_config.items()) != 0:
            for k, v in candidate_config.items():
                candidate_config[k] = list(set(v))
@@ -314,7 +315,7 @@ class SuperConv2D(nn.Conv2D):
            bias = self.bias[:weight_out_nc]
        else:
            bias = self.bias
+        self.cur_config['prune_dim'] = list(weight.shape)
        out = F.conv2d(
            input,
            weight,
@@ -482,6 +483,7 @@ class SuperConv2DTranspose(nn.Conv2DTranspose):
            data_format=data_format)
        self.candidate_config = candidate_config
+        self.cur_config = None
        if len(self.candidate_config.items()) != 0:
            for k, v in candidate_config.items():
                candidate_config[k] = list(set(v))
@@ -620,7 +622,7 @@ class SuperConv2DTranspose(nn.Conv2DTranspose):
            bias = self.bias[:weight_out_nc]
        else:
            bias = self.bias
+        self.cur_config['prune_dim'] = list(weight.shape)
        out = F.conv2d_transpose(
            input,
            weight,
@@ -733,6 +735,7 @@ class SuperSeparableConv2D(nn.Layer):
        ])
        self.candidate_config = candidate_config
+        self.cur_config = None
        self.expand_ratio = candidate_config[
            'expand_ratio'] if 'expand_ratio' in candidate_config else None
        self.base_output_dim = self.conv[0]._out_channels
@@ -784,7 +787,7 @@ class SuperSeparableConv2D(nn.Layer):
            bias = self.conv[2].bias[:out_nc]
        else:
            bias = self.conv[2].bias
+        self.cur_config['prune_dim'] = list(weight.shape)
        conv1_out = F.conv2d(
            norm_out,
            weight,
@@ -864,6 +867,7 @@ class SuperLinear(nn.Linear):
        self._in_features = in_features
        self._out_features = out_features
        self.candidate_config = candidate_config
+        self.cur_config = None
        self.expand_ratio = candidate_config[
            'expand_ratio'] if 'expand_ratio' in candidate_config else None
        self.base_output_dim = self._out_features
@@ -896,7 +900,7 @@ class SuperLinear(nn.Linear):
            bias = self.bias[:out_nc]
        else:
            bias = self.bias
+        self.cur_config['prune_dim'] = list(weight.shape)
        out = F.linear(x=input, weight=weight, bias=bias, name=self.name)
        return out
@@ -945,6 +949,7 @@ class SuperBatchNorm2D(nn.BatchNorm2D):
        super(SuperBatchNorm2D, self).__init__(
            num_features, momentum, epsilon, weight_attr, bias_attr,
            data_format, use_global_stats, name)
+        self.cur_config = None

    def forward(self, input):
        self._check_data_format(self._data_format)
@@ -956,7 +961,7 @@ class SuperBatchNorm2D(nn.BatchNorm2D):
        bias = self.bias[:feature_dim]
        mean = self._mean[:feature_dim]
        variance = self._variance[:feature_dim]
+        self.cur_config = {'prune_dim': feature_dim}
        return F.batch_norm(
            input,
            mean,
@@ -982,6 +987,7 @@ class SuperSyncBatchNorm(nn.SyncBatchNorm):
        super(SuperSyncBatchNorm,
              self).__init__(num_features, momentum, epsilon, weight_attr,
                             bias_attr, data_format, name)
+        self.cur_config = None

    def forward(self, input):
@@ -995,6 +1001,7 @@ class SuperSyncBatchNorm(nn.SyncBatchNorm):
        mean_out = mean
        # variance and variance out share the same memory
        variance_out = variance
+        self.cur_config = {'prune_dim': feature_dim}
        attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
                 "is_test", not self.training, "data_layout", self._data_format,
@@ -1049,6 +1056,7 @@ class SuperInstanceNorm2D(nn.InstanceNorm2D):
        super(SuperInstanceNorm2D, self).__init__(num_features, epsilon,
                                                  momentum, weight_attr,
                                                  bias_attr, data_format, name)
+        self.cur_config = None

    def forward(self, input):
        self._check_input_dim(input)
@@ -1060,7 +1068,7 @@ class SuperInstanceNorm2D(nn.InstanceNorm2D):
        else:
            scale = self.scale[:feature_dim]
            bias = self.bias[:feature_dim]
+        self.cur_config = {'prune_dim': feature_dim}
        return F.instance_norm(input, scale, bias, eps=self._epsilon)
@@ -1112,6 +1120,7 @@ class SuperLayerNorm(nn.LayerNorm):
                 name=None):
        super(SuperLayerNorm, self).__init__(normalized_shape, epsilon,
                                             weight_attr, bias_attr, name)
+        self.cur_config = None

    def forward(self, input):
        ### TODO(ceci3): fix if normalized_shape is not a single number
@@ -1127,6 +1136,8 @@ class SuperLayerNorm(nn.LayerNorm):
            bias = self.bias[:feature_dim]
        else:
            bias = None
+        self.cur_config = {'prune_dim': feature_dim}
        out, _, _ = core.ops.layer_norm(input, weight, bias, 'epsilon',
                                        self._epsilon, 'begin_norm_axis',
                                        begin_norm_axis)
@@ -1191,6 +1202,7 @@ class SuperEmbedding(nn.Embedding):
                                             padding_idx, sparse, weight_attr,
                                             name)
        self.candidate_config = candidate_config
+        self.cur_config = None
        self.expand_ratio = candidate_config[
            'expand_ratio'] if 'expand_ratio' in candidate_config else None
        self.base_output_dim = self._embedding_dim
@@ -1216,6 +1228,7 @@ class SuperEmbedding(nn.Embedding):
            out_nc = self._embedding_dim
        weight = self.weight[:, :out_nc]
+        self.cur_config = {'prune_dim': list(weight.shape)}
        return F.embedding(
            input,
            weight=weight,
...
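Every `Super*` layer now initializes `self.cur_config = None` in its constructor and fills in a `prune_dim` entry on each forward pass under the currently selected sub-network. A rough illustration of what those recorded configs look like after one forward pass (the concrete numbers below are made up):

```python
# Illustration only; values depend on the model and the sampled config.
# Conv / separable-conv / linear / embedding super layers record the sliced
# weight shape as a list:
conv_cur_config = {'prune_dim': [16, 8, 3, 3]}   # [out_nc, in_nc, kh, kw]
linear_cur_config = {'prune_dim': [64, 32]}      # sliced weight shape

# BatchNorm / SyncBatchNorm / InstanceNorm / LayerNorm super layers record the
# kept feature dimension as a plain int:
norm_cur_config = {'prune_dim': 16}
```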
@@ -27,11 +27,14 @@ else:
    from .layers import SuperConv2D, SuperLinear
    Layer = paddle.nn.Layer
    DataParallel = paddle.DataParallel
-from .layers_base import BaseBlock
+from .layers_base import BaseBlock, Block
from .utils.utils import search_idx
from ...common import get_logger
from ...core import GraphWrapper, dygraph2program
-from .get_sub_model import get_prune_params_config, prune_params, check_search_space, broadcast_search_space
+from .get_sub_model import check_search_space, broadcast_search_space
+from paddle.fluid import core
+from paddle.fluid.framework import Variable
+import numbers

_logger = get_logger(__name__, level=logging.INFO)
@@ -459,35 +462,41 @@ class OFA(OFABase):
    def search(self, eval_func, condition):
        pass

-    def _export_sub_model_config(self, origin_model, config, input_shapes,
-                                 input_dtypes):
-        param2name = {}
-        for name, sublayer in origin_model.named_sublayers():
-            for param in sublayer.parameters(include_sublayers=False):
-                if name.split('.')[-1] == 'fn':
-                    ### if sublayer is Block, the name of the param.name has 'fn', the config always donnot have 'fn'
-                    param2name[param.name] = name[:-3]
-                else:
-                    param2name[param.name] = name
-        program = dygraph2program(
-            origin_model, inputs=input_shapes, dtypes=input_dtypes)
-        graph = GraphWrapper(program)
-        same_config, _ = check_search_space(graph)
-        if same_config != None:
-            broadcast_search_space(same_config, param2name, config)
-        origin_model_config = {}
-        for name, sublayer in origin_model.named_sublayers():
-            if isinstance(sublayer, BaseBlock):
-                sublayer = sublayer.fn
-            for param in sublayer.parameters(include_sublayers=False):
-                if name in config.keys():
-                    origin_model_config[param.name] = config[name]
-        param_prune_config = get_prune_params_config(graph, origin_model_config)
-        return param_prune_config
+    def _get_model_pruned_weight(self):
+        pruned_param = {}
+        for l_name, sublayer in self.model.named_sublayers():
+            if getattr(sublayer, 'cur_config', None) == None:
+                continue
+            assert 'prune_dim' in sublayer.cur_config, 'The laycer {} do not have prune_dim in cur_config.'.format(
+                l_name)
+            prune_shape = sublayer.cur_config['prune_dim']
+            for p_name, param in sublayer.named_parameters(
+                    include_sublayers=False):
+                origin_param = param.value().get_tensor()
+                param = np.array(origin_param).astype("float32")
+                name = l_name + '.' + p_name
+                if isinstance(prune_shape, list):
+                    if len(param.shape) == 4:
+                        pruned_param[name] = param[:prune_shape[0], :
+                                                   prune_shape[1], :, :]
+                    elif len(param.shape) == 2:
+                        pruned_param[name] = param[:prune_shape[0], :
+                                                   prune_shape[1]]
+                    else:
+                        if isinstance(sublayer, SuperLinear):
+                            pruned_param[name] = param[:prune_shape[1]]
+                        else:
+                            pruned_param[name] = param[:prune_shape[0]]
+                else:
+                    pruned_param[name] = param[:prune_shape]
+        return pruned_param

    def export(self,
               config,
@@ -510,17 +519,72 @@ class OFA(OFABase):
            config = {'conv2d_0': {'expand_ratio': 2}, 'conv2d_1': {'expand_ratio': 2}}
            origin_model = ofa_model.export(origin_model, config, input_shapes=[1, 3, 28, 28], input_dtypes=['float32'])
        """
-        super_sd = None
+        self.set_net_config(config)
+        self.model.eval()
+
+        def build_input(input_size, dtypes):
+            if isinstance(input_size, list) and all(
+                    isinstance(i, numbers.Number) for i in input_size):
+                if isinstance(dtypes, list):
+                    dtype = dtypes[0]
+                else:
+                    dtype = dtypes
+                return paddle.cast(paddle.rand(list(input_size)), dtype)
+            if isinstance(input_size, dict):
+                inputs = {}
+                if isinstance(dtypes, list):
+                    dtype = dtypes[0]
+                else:
+                    dtype = dtypes
+                for key, value in input_size.items():
+                    inputs[key] = paddle.cast(paddle.rand(list(value)), dtype)
+                return inputs
+            if isinstance(input_size, list):
+                return [
+                    build_input(i, dtype)
+                    for i, dtype in zip(input_size, dtypes)
+                ]
+
+        data = build_input(input_shapes, input_dtypes)
+        if isinstance(data, list):
+            self.forward(*data)
+        else:
+            self.forward(data)
+
+        super_model_state_dict = None
        if load_weights_from_supernet and origin_model != None:
-            super_sd = remove_model_fn(origin_model, self.model.state_dict())
+            super_model_state_dict = remove_model_fn(origin_model,
+                                                     self.model.state_dict())

        if origin_model == None:
            origin_model = self.model

        origin_model = origin_model._layers if isinstance(
            origin_model, DataParallel) else origin_model
-        param_config = self._export_sub_model_config(origin_model, config,
-                                                      input_shapes, input_dtypes)
-        prune_params(origin_model, param_config, super_sd)
+
+        _logger.info("Start to get pruned params, please wait...")
+        pruned_param = self._get_model_pruned_weight()
+        pruned_state_dict = remove_model_fn(origin_model, pruned_param)
+        _logger.info("Start to get pruned model, please wait...")
+        for l_name, sublayer in origin_model.named_sublayers():
+            for p_name, param in sublayer.named_parameters(
+                    include_sublayers=False):
+                name = l_name + '.' + p_name
+                t_value = param.value().get_tensor()
+                if name in pruned_state_dict:
+                    p = t_value._place()
+                    if p.is_cpu_place():
+                        place = core.CPUPlace()
+                    elif p.is_cuda_pinned_place():
+                        place = core.CUDAPinnedPlace()
+                    else:
+                        place = core.CUDAPlace(p.gpu_device_id())
+                    t_value.set(pruned_state_dict[name], place)
+
+        if super_model_state_dict != None and len(super_model_state_dict) != 0:
+            for k, v in super_model_state_dict.items():
+                setattr(origin_model, k, v)
+
        return origin_model

    @property
@@ -566,11 +630,26 @@ class OFA(OFABase):
        input_shapes = []
        input_dtypes = []
        for n in inputs:
-            input_shapes.append(n.shape)
-            input_dtypes.append(n.numpy().dtype)
-        for n, v in kwargs.items():
-            input_shapes.append(v.shape)
-            input_dtypes.append(v.numpy().dtype)
+            if isinstance(n, Variable):
+                input_shapes.append(n)
+                input_dtypes.append(n.numpy().dtype)
+        for key, val in kwargs.items():
+            if isinstance(val, Variable):
+                input_shapes.append(val)
+                input_dtypes.append(val.numpy().dtype)
+            elif isinstance(val, dict):
+                input_shape = {}
+                input_dtype = {}
+                for k, v in val.items():
+                    input_shape[k] = v
+                    input_dtype[k] = v.numpy().dtype
+                input_shapes.append(input_shape)
+                input_dtypes.append(input_dtype)
+            else:
+                _logger.error(
+                    "Cannot figure out the type of inputs! Right now, the type of inputs can be only Variable or dict."
+                )

        ### find shortcut block using static model
        model_to_traverse = self.model._layers if isinstance(
@@ -674,11 +753,9 @@ class OFA(OFABase):
        _logger.debug("Current config is {}".format(self.current_config))
        if 'depth' in self.current_config:
            kwargs['depth'] = self.current_config['depth']
        if self._broadcast:
            broadcast_search_space(self._same_ss, self._param2key,
                                   self.current_config)
        student_output = self.model.forward(*inputs, **kwargs)
        if self._add_teacher:
...
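Putting the `ofa.py` changes together: `export()` no longer traces a static program to derive a prune config; it sets the requested config, runs one forward pass on random tensors built from `input_shapes`/`input_dtypes` so every super layer records its `prune_dim`, and then copies the sliced weights into the exported model. The snippet below is a usage sketch pieced together from the docstring and the new tests; the two-layer model is purely illustrative and the `paddleslim.nas.ofa` import paths are assumed from the test files, so treat it as a sketch rather than canonical API documentation.

```python
import paddle
import paddle.nn as nn
from paddleslim.nas.ofa import OFA
from paddleslim.nas.ofa.convert_super import supernet

# Build a tiny supernet: ofa_super.convert() wraps the plain layers into Super* layers.
with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
    models = ofa_super.convert([nn.Conv2D(3, 12, 3), nn.Conv2D(12, 12, 3)])
model = nn.Sequential(*models)
ofa_model = OFA(model)

# Run one forward pass first, as the tests do, so the search space is set up.
ofa_model.set_epoch(0)
_ = ofa_model(paddle.randn([1, 3, 28, 28], dtype='float32'))

# Sample a sub-network config and export it; export() internally runs a
# forward pass on random data built from input_shapes/input_dtypes.
config = ofa_model._sample_config(task='expand_ratio', sample_type='smallest')
pruned_model = ofa_model.export(
    config, input_shapes=[[1, 3, 28, 28]], input_dtypes=['float32'])

# Dict inputs are supported too, mirroring TestInputDict below, e.g.
# input_shapes=[[1, 3, 32, 32], {'data': [1, 12, 32, 32]}],
# input_dtypes=['float32', 'float32'].
```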
@@ -142,10 +142,19 @@ class ModelConv2(nn.Layer):
class ModelLinear(nn.Layer):
    def __init__(self):
        super(ModelLinear, self).__init__()
-        with supernet(expand_ratio=(1.0, 2.0, 4.0)) as ofa_super:
+        with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
            models = []
            models += [nn.Embedding(num_embeddings=64, embedding_dim=64)]
-            models += [nn.Linear(64, 128)]
+            weight_attr = paddle.ParamAttr(
+                learning_rate=0.5,
+                regularizer=paddle.regularizer.L2Decay(1.0),
+                trainable=True)
+            bias_attr = paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(value=1.0))
+            models += [
+                nn.Linear(
+                    64, 128, weight_attr=weight_attr, bias_attr=bias_attr)
+            ]
            models += [nn.LayerNorm(128)]
            models += [nn.Linear(128, 256)]
            models = ofa_super.convert(models)
@@ -402,16 +411,7 @@ class TestExport(unittest.TestCase):
        self.ofa_model = OFA(model)

    def test_ofa(self):
-        config = {
-            'embedding_1': {
-                'expand_ratio': (2.0)
-            },
-            'linear_3': {
-                'expand_ratio': (2.0)
-            },
-            'linear_4': {},
-            'linear_5': {}
-        }
+        config = self.ofa_model._sample_config(task='expand_ratio', phase=None)
        origin_dict = {}
        for name, param in self.origin_model.named_parameters():
            origin_dict[name] = param.shape
@@ -459,9 +459,29 @@ class TestExportCase1(unittest.TestCase):
        outs, _ = self.ofa_model(self.data)
        self.config = self.ofa_model.current_config

-    def test_export_model(self):
-        self.ofa_model.export(
+    def test_export_model_linear1(self):
+        ex_model = self.ofa_model.export(
            self.config, input_shapes=[[3, 64]], input_dtypes=['int64'])
+        ex_model(self.data)
+        assert len(self.ofa_model.ofa_layers) == 4
+
+
+class TestExportCase2(unittest.TestCase):
+    def setUp(self):
+        model = ModelLinear()
+        data_np = np.random.random((3, 64)).astype(np.int64)
+        self.data = paddle.to_tensor(data_np)
+        self.ofa_model = OFA(model)
+        self.ofa_model.set_epoch(0)
+        outs, _ = self.ofa_model(self.data)
+        self.config = self.ofa_model.current_config
+
+    def test_export_model_linear2(self):
+        config = self.ofa_model._sample_config(
+            task='expand_ratio', phase=None, sample_type='smallest')
+        ex_model = self.ofa_model.export(
+            config, input_shapes=[[3, 64]], input_dtypes=['int64'])
+        ex_model(self.data)
        assert len(self.ofa_model.ofa_layers) == 4
...
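The updated tests sample the export config with `_sample_config` rather than spelling it out by hand. The hand-written form from the export docstring still has the same shape: each key names a candidate layer and each value picks one option for it. A hedged restatement (layer names are illustrative, taken from the docstring above):

```python
# Hand-written config of the same shape as a sampled one; the layer names
# here come from the export docstring and are illustrative only.
config = {
    'conv2d_0': {'expand_ratio': 2},
    'conv2d_1': {'expand_ratio': 2},
}

# The updated tests prefer sampling instead:
# config = ofa_model._sample_config(task='expand_ratio', phase=None)
```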
@@ -81,6 +81,25 @@ class ModelShortcut(nn.Layer):
        return z

+class ModelInputDict(nn.Layer):
+    def __init__(self):
+        super(ModelInputDict, self).__init__()
+        self.conv0 = nn.Sequential(
+            nn.Conv2D(3, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
+        self.conv1 = nn.Sequential(
+            nn.Conv2D(12, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
+        self.conv2 = nn.Sequential(
+            nn.Conv2D(12, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
+        self.conv3 = nn.Sequential(
+            nn.Conv2D(12, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
+
+    def forward(self, x, data):
+        x = self.conv1(self.conv0(x))
+        y = self.conv2(x)
+        y = y + data['data']
+        return self.conv3(y)

class TestOFAV2(unittest.TestCase):
    def setUp(self):
        model = ModelV1()
@@ -93,7 +112,6 @@ class TestOFAV2(unittest.TestCase):
        self.ofa_model.set_epoch(0)
        self.ofa_model.set_task('expand_ratio')
        out, _ = self.ofa_model(self.images)
-        print(self.ofa_model.get_current_config)

class TestOFAV2Export(unittest.TestCase):
@@ -151,5 +169,34 @@ class TestShortcutSkiplayersCase2(TestShortcutSkiplayers):
        assert list(self.ofa_model._ofa_layers.keys()) == ['conv1.0', 'out.0']

+class TestInputDict(unittest.TestCase):
+    def setUp(self):
+        model = ModelInputDict()
+        sp_net_config = supernet(expand_ratio=[0.5, 1.0])
+        self.model = Convert(sp_net_config).convert(model)
+        self.images = paddle.randn(shape=[2, 3, 32, 32], dtype='float32')
+        self.images2 = {
+            'data': paddle.randn(
+                shape=[2, 12, 32, 32], dtype='float32')
+        }
+        default_run_config = {'skip_layers': ['conv1.0', 'conv2.0']}
+        self.run_config = RunConfig(**default_run_config)
+        self.ofa_model = OFA(self.model, run_config=self.run_config)
+        self.ofa_model._clear_search_space(self.images, data=self.images2)
+
+    def test_export(self):
+        config = self.ofa_model._sample_config(
+            task="expand_ratio", sample_type="smallest")
+        self.ofa_model.export(
+            config,
+            input_shapes=[[1, 3, 32, 32], {
+                'data': [1, 12, 32, 32]
+            }],
+            input_dtypes=['float32', 'float32'])

if __name__ == '__main__':
    unittest.main()