Unverified commit dae3c307, authored by ceci3, committed by GitHub

fix ofa search space when model has shortcut (#679)

* fix shortcut

* fix when kernel_size in ss

* fix export

* fix dp

* fix export model prune param twice

* add comment

* fix unittest

* fix

* update
Parent adc57c75
......@@ -124,10 +124,10 @@ def do_train(args):
ofa_model.model.set_state_dict(sd)
best_config = utils.dynabert_config(ofa_model, args.width_mult)
ofa_model.export(
origin_model,
best_config,
input_shapes=[[1, args.max_seq_length], [1, args.max_seq_length]],
input_dtypes=['int64', 'int64'])
input_dtypes=['int64', 'int64'],
origin_model=origin_model)
for name, sublayer in origin_model.named_sublayers():
if isinstance(sublayer, paddle.nn.MultiHeadAttention):
sublayer.num_heads = int(args.width_mult * sublayer.num_heads)
......
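Call sites that passed origin_model positionally need the same update as the hunk above; a minimal sketch of the old and new call forms, using only the argument names that appear in this diff:

# before: origin_model was the first positional argument
# ofa_model.export(origin_model, best_config,
#                  input_shapes=[[1, args.max_seq_length], [1, args.max_seq_length]],
#                  input_dtypes=['int64', 'int64'])
# after: config comes first and origin_model becomes a keyword argument
ofa_model.export(
    best_config,
    input_shapes=[[1, args.max_seq_length], [1, args.max_seq_length]],
    input_dtypes=['int64', 'int64'],
    origin_model=origin_model)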
......@@ -34,6 +34,7 @@ else:
from .layers import *
from . import layers
Layer = paddle.nn.Layer
from .layers_base import Block
_logger = get_logger(__name__, level=logging.INFO)
......
......@@ -15,11 +15,20 @@
import numpy as np
import paddle
from paddle.fluid import core
from .layers_base import BaseBlock
__all__ = ['get_prune_params_config', 'prune_params']
__all__ = ['get_prune_params_config', 'prune_params', 'check_search_space']
WEIGHT_OP = ['conv2d', 'conv3d', 'conv1d', 'linear', 'embedding']
CONV_TYPES = [
'conv2d', 'conv3d', 'conv1d', 'superconv2d', 'supergroupconv2d',
'superdepthwiseconv2d'
]
def get_prune_params_config(graph, origin_model_config):
""" Convert config of search space to parameters' prune config.
"""
param_config = {}
precedor = None
for op in graph.ops():
......@@ -68,40 +77,124 @@ def get_prune_params_config(graph, origin_model_config):
def prune_params(model, param_config, super_model_sd=None):
for name, param in model.named_parameters():
t_value = param.value().get_tensor()
value = np.array(t_value).astype("float32")
if super_model_sd != None:
super_t_value = super_model_sd[name].value().get_tensor()
super_value = np.array(super_t_value).astype("float32")
if param.name in param_config.keys():
if len(param_config[param.name]) > 1:
in_exp = param_config[param.name][0]
out_exp = param_config[param.name][1]
in_chn = int(value.shape[0]) if in_exp == None else int(
value.shape[0] * in_exp)
out_chn = int(value.shape[1]) if out_exp == None else int(
value.shape[1] * out_exp)
prune_value = super_value[:in_chn, :out_chn, ...] \
if super_model_sd != None else value[:in_chn, :out_chn, ...]
""" Prune parameters according to the config.
Parameters:
model(paddle.nn.Layer): instance of model.
param_config(dict): prune config of each weight.
super_model_sd(dict, optional): parameters that come from the supernet. If super_model_sd is not None, transfer parameters from this dict to the model; otherwise, prune the model in place.
"""
for l_name, sublayer in model.named_sublayers():
if isinstance(sublayer, BaseBlock):
continue
for p_name, param in sublayer.named_parameters(include_sublayers=False):
t_value = param.value().get_tensor()
value = np.array(t_value).astype("float32")
if super_model_sd != None:
name = l_name + '.' + p_name
super_t_value = super_model_sd[name].value().get_tensor()
super_value = np.array(super_t_value).astype("float32")
if param.name in param_config.keys():
if len(param_config[param.name]) > 1:
in_exp = param_config[param.name][0]
out_exp = param_config[param.name][1]
if sublayer.__class__.__name__.lower() in CONV_TYPES:
in_chn = int(value.shape[1]) if in_exp == None else int(
value.shape[1] * in_exp)
out_chn = int(value.shape[
0]) if out_exp == None else int(value.shape[0] *
out_exp)
prune_value = super_value[:out_chn, :in_chn, ...] \
if super_model_sd != None else value[:out_chn, :in_chn, ...]
else:
in_chn = int(value.shape[0]) if in_exp == None else int(
value.shape[0] * in_exp)
out_chn = int(value.shape[
1]) if out_exp == None else int(value.shape[1] *
out_exp)
prune_value = super_value[:in_chn, :out_chn, ...] \
if super_model_sd != None else value[:in_chn, :out_chn, ...]
else:
out_chn = int(value.shape[0]) if param_config[param.name][
0] == None else int(value.shape[0] *
param_config[param.name][0])
prune_value = super_value[:out_chn, ...] \
if super_model_sd != None else value[:out_chn, ...]
else:
out_chn = int(value.shape[0]) if param_config[param.name][
0] == None else int(value.shape[0] *
param_config[param.name][0])
prune_value = super_value[:out_chn, ...] \
if super_model_sd != None else value[:out_chn, ...]
else:
prune_value = super_value if super_model_sd != None else value
p = t_value._place()
if p.is_cpu_place():
place = core.CPUPlace()
elif p.is_cuda_pinned_place():
place = core.CUDAPinnedPlace()
else:
place = core.CUDAPlace(p.gpu_device_id())
t_value.set(prune_value, place)
if param.trainable:
param.clear_gradient()
prune_value = super_value if super_model_sd != None else value
p = t_value._place()
if p.is_cpu_place():
place = core.CPUPlace()
elif p.is_cuda_pinned_place():
place = core.CUDAPinnedPlace()
else:
place = core.CUDAPlace(p.gpu_device_id())
t_value.set(prune_value, place)
if param.trainable:
param.clear_gradient()
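# --- Illustration only (not part of this commit): the axis convention the branches
# above rely on. Paddle conv weights are laid out as [out_channels, in_channels, kH, kW],
# while linear weights are [in_features, out_features], so the conv branch slices
# (out, in) on axes (0, 1) and the linear/default branch slices (in, out). Toy shapes:
import numpy as np
conv_w = np.zeros((64, 32, 3, 3), dtype="float32")  # [out, in, kH, kW]
fc_w = np.zeros((128, 256), dtype="float32")        # [in, out]
ratio = 0.5
pruned_conv = conv_w[:int(64 * ratio), :int(32 * ratio), ...]  # -> (32, 16, 3, 3)
pruned_fc = fc_w[:int(128 * ratio), :int(256 * ratio)]         # -> (64, 128)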
def _find_weight_ops(op, graph, weights):
""" Find the vars come from operators with weight.
"""
pre_ops = graph.pre_ops(op)
for pre_op in pre_ops:
if pre_op.type() in WEIGHT_OP:
for inp in pre_op.all_inputs():
if inp._var.persistable:
weights.append(inp._var.name)
return weights
return _find_weight_ops(pre_op, graph, weights)
def _find_pre_elementwise_add(op, graph):
""" Find precedors of the elementwise_add operator in the model.
"""
same_prune_before_elementwise_add = []
pre_ops = graph.pre_ops(op)
for pre_op in pre_ops:
if pre_op.type() in WEIGHT_OP:
return
same_prune_before_elementwise_add = _find_weight_ops(
pre_op, graph, same_prune_before_elementwise_add)
return same_prune_before_elementwise_add
def check_search_space(graph):
""" Find the shortcut in the model and set same config for this situation.
"""
same_search_space = []
for op in graph.ops():
if op.type() == 'elementwise_add':
inp1, inp2 = op.all_inputs()[0], op.all_inputs()[1]
if (not inp1._var.persistable) and (not inp2._var.persistable):
pre_ele_op = _find_pre_elementwise_add(op, graph)
if pre_ele_op != None:
same_search_space.append(pre_ele_op)
if len(same_search_space) == 0:
return None
same_search_space = sorted([sorted(x) for x in same_search_space])
final_search_space = []
if len(same_search_space) >= 1:
final_search_space = [same_search_space[0]]
if len(same_search_space) > 1:
for l in same_search_space[1:]:
listset = set(l)
merged = False
for idx in range(len(final_search_space)):
rset = set(final_search_space[idx])
if len(listset & rset) != 0:
final_search_space[idx] = list(listset | rset)
merged = True
break
if not merged:
final_search_space.append(l)
return final_search_space
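The second half of check_search_space is a plain union-merge over groups that share at least one weight; a self-contained sketch of that step with invented weight names:

def merge_groups(groups):
    """Merge groups that share at least one element (same logic as the loop above)."""
    groups = sorted([sorted(g) for g in groups])
    merged_groups = [groups[0]]
    for group in groups[1:]:
        group_set = set(group)
        for idx, existing in enumerate(merged_groups):
            if group_set & set(existing):
                merged_groups[idx] = list(group_set | set(existing))
                break
        else:
            merged_groups.append(group)
    return merged_groups

# the first two shortcut groups share 'conv2.w_0', so they collapse into one group
print(merge_groups([['conv1.w_0', 'conv2.w_0'],
                    ['conv2.w_0', 'conv3.w_0'],
                    ['conv5.w_0', 'conv6.w_0']]))
# e.g. [['conv1.w_0', 'conv2.w_0', 'conv3.w_0'], ['conv5.w_0', 'conv6.w_0']]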
......@@ -23,10 +23,11 @@ import paddle.fluid.core as core
from ...common import get_logger
from .utils.utils import compute_start_end, get_same_padding, convert_to_list
from .layers_base import *
__all__ = [
'SuperConv2D', 'SuperConv2DTranspose', 'SuperSeparableConv2D',
'SuperBatchNorm2D', 'SuperLinear', 'SuperInstanceNorm2D', 'Block',
'SuperBatchNorm2D', 'SuperLinear', 'SuperInstanceNorm2D',
'SuperGroupConv2D', 'SuperDepthwiseConv2D', 'SuperGroupConv2DTranspose',
'SuperDepthwiseConv2DTranspose', 'SuperLayerNorm', 'SuperEmbedding'
]
......@@ -35,52 +36,6 @@ _logger = get_logger(__name__, level=logging.INFO)
### TODO: if task is elastic width, need to add re_organize_middle_weight in 1x1 conv in MBBlock
_cnt = 0
def counter():
global _cnt
_cnt += 1
return _cnt
class BaseBlock(paddle.nn.Layer):
def __init__(self, key=None):
super(BaseBlock, self).__init__()
if key is not None:
self._key = str(key)
else:
self._key = self.__class__.__name__ + str(counter())
# set SuperNet class
def set_supernet(self, supernet):
self.__dict__['supernet'] = supernet
@property
def key(self):
return self._key
class Block(BaseBlock):
"""
Model is composed of nested blocks.
Parameters:
fn(paddle.nn.Layer): instance of super layers, such as: SuperConv2D(3, 5, 3).
fixed(bool, optional): whether to fix the shape of the weight in this layer. Default: False.
key(str, optional): key of this layer, one-to-one correspondence between key and candidate config. Default: None.
"""
def __init__(self, fn, fixed=False, key=None):
super(Block, self).__init__(key)
self.fn = fn
self.fixed = fixed
self.candidate_config = self.fn.candidate_config
def forward(self, *inputs, **kwargs):
out = self.supernet.layers_forward(self, *inputs, **kwargs)
return out
class SuperConv2D(nn.Conv2D):
"""
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .utils.utils import get_paddle_version
pd_ver = get_paddle_version()
import paddle
if pd_ver == 185:
Layer = paddle.fluid.dygraph.Layer
else:
Layer = paddle.nn.Layer
_cnt = 0
def counter():
global _cnt
_cnt += 1
return _cnt
class BaseBlock(Layer):
def __init__(self, key=None):
super(BaseBlock, self).__init__()
if key is not None:
self._key = str(key)
else:
self._key = self.__class__.__name__ + str(counter())
# set SuperNet class
def set_supernet(self, supernet):
self.__dict__['supernet'] = supernet
@property
def key(self):
return self._key
class Block(BaseBlock):
"""
Model is composed of nested blocks.
Parameters:
fn(paddle.nn.Layer): instance of super layers, such as: SuperConv2D(3, 5, 3).
fixed(bool, optional): whether to fix the shape of the weight in this layer. Default: False.
key(str, optional): key of this layer, one-to-one correspondence between key and candidate config. Default: None.
"""
def __init__(self, fn, fixed=False, key=None):
super(Block, self).__init__(key)
self.fn = fn
self.fixed = fixed
self.candidate_config = self.fn.candidate_config
def forward(self, *inputs, **kwargs):
out = self.supernet.layers_forward(self, *inputs, **kwargs)
return out
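A dependency-free analogue of the key assignment this new layers_base module centralises: a block either takes an explicit key or generates one from its class name plus the global counter (plain-Python stand-in, not the paddle Layer subclass):

_toy_cnt = 0

def toy_counter():
    global _toy_cnt
    _toy_cnt += 1
    return _toy_cnt

class ToyBlock:
    def __init__(self, key=None):
        # explicit key wins; otherwise fall back to "<ClassName><counter>"
        self._key = str(key) if key is not None else self.__class__.__name__ + str(toy_counter())

    @property
    def key(self):
        return self._key

print(ToyBlock().key)             # ToyBlock1
print(ToyBlock().key)             # ToyBlock2
print(ToyBlock(key='conv1').key)  # conv1

Note that set_supernet above writes the reference through self.__dict__, presumably so the supernet is not registered as a sublayer of every block.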
......@@ -25,11 +25,12 @@ from paddle.fluid.dygraph.nn import InstanceNorm, Conv2D, Conv2DTranspose, Batch
from ...common import get_logger
from .utils.utils import compute_start_end, get_same_padding, convert_to_list
from .layers_base import *
__all__ = [
'SuperConv2D', 'SuperConv2DTranspose', 'SuperSeparableConv2D',
'SuperBatchNorm', 'SuperLinear', 'SuperInstanceNorm', 'Block',
'SuperGroupConv2D', 'SuperDepthwiseConv2D', 'SuperGroupConv2DTranspose',
'SuperBatchNorm', 'SuperLinear', 'SuperInstanceNorm', 'SuperGroupConv2D',
'SuperDepthwiseConv2D', 'SuperGroupConv2DTranspose',
'SuperDepthwiseConv2DTranspose', 'SuperLayerNorm', 'SuperEmbedding'
]
......@@ -37,51 +38,6 @@ _logger = get_logger(__name__, level=logging.INFO)
### TODO: if task is elastic width, need to add re_organize_middle_weight in 1x1 conv in MBBlock
_cnt = 0
def counter():
global _cnt
_cnt += 1
return _cnt
class BaseBlock(fluid.dygraph.Layer):
def __init__(self, key=None):
super(BaseBlock, self).__init__()
if key is not None:
self._key = str(key)
else:
self._key = self.__class__.__name__ + str(counter())
# set SuperNet class
def set_supernet(self, supernet):
self.__dict__['supernet'] = supernet
@property
def key(self):
return self._key
class Block(BaseBlock):
"""
Model is composed of nested blocks.
Parameters:
fn(Layer): instance of super layers, such as: SuperConv2D(3, 5, 3).
key(str, optional): key of this layer, one-to-one correspondence between key and candidate config. Default: None.
"""
def __init__(self, fn, fixed=False, key=None):
super(Block, self).__init__(key)
self.fn = fn
self.fixed = fixed
self.candidate_config = self.fn.candidate_config
def forward(self, *inputs, **kwargs):
out = self.supernet.layers_forward(self, *inputs, **kwargs)
return out
class SuperConv2D(fluid.dygraph.Conv2D):
"""
......
......@@ -20,15 +20,18 @@ import paddle.fluid as fluid
from .utils.utils import get_paddle_version, remove_model_fn
pd_ver = get_paddle_version()
if pd_ver == 185:
from .layers_old import BaseBlock, SuperConv2D, SuperLinear
from .layers_old import SuperConv2D, SuperLinear
Layer = paddle.fluid.dygraph.Layer
DataParallel = paddle.fluid.dygraph.DataParallel
else:
from .layers import BaseBlock, SuperConv2D, SuperLinear
from .layers import SuperConv2D, SuperLinear
Layer = paddle.nn.Layer
DataParallel = paddle.DataParallel
from .layers_base import BaseBlock
from .utils.utils import search_idx
from ...common import get_logger
from ...core import GraphWrapper, dygraph2program
from .get_sub_model import get_prune_params_config, prune_params
from .get_sub_model import get_prune_params_config, prune_params, check_search_space
_logger = get_logger(__name__, level=logging.INFO)
......@@ -75,19 +78,26 @@ class OFABase(Layer):
def __init__(self, model):
super(OFABase, self).__init__()
self.model = model
self._layers, self._elastic_task = self.get_layers()
self._ofa_layers, self._elastic_task, self._key2name, self._layers = self.get_layers(
)
def get_layers(self):
ofa_layers = dict()
layers = dict()
key2name = dict()
elastic_task = set()
for name, sublayer in self.model.named_sublayers():
model_to_traverse = self.model._layers if isinstance(
self.model, DataParallel) else self.model
for name, sublayer in model_to_traverse.named_sublayers():
if isinstance(sublayer, BaseBlock):
sublayer.set_supernet(self)
if not sublayer.fixed:
ofa_layers[name] = sublayer.candidate_config
layers[sublayer.key] = sublayer.candidate_config
key2name[sublayer.key] = name
for k in sublayer.candidate_config.keys():
elastic_task.add(k)
return layers, elastic_task
return ofa_layers, elastic_task, key2name, layers
def forward(self, *inputs, **kwargs):
raise NotImplementedError
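With this change the user-facing config is keyed by sublayer name while each Block still carries its auto-generated key internally; a toy sketch of the two-step lookup that layers_forward now performs (names and keys invented for illustration):

# built once in get_layers(): block key -> sublayer name
key2name = {'Block1': 'conv1', 'Block2': 'block0.fc'}

# user-facing config is keyed by sublayer name
current_config = {'conv1': {'expand_ratio': 0.5},
                  'block0.fc': {'expand_ratio': 1.0}}

block_key = 'Block2'
config = current_config[key2name[block_key]]
print(config)  # {'expand_ratio': 1.0}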
......@@ -97,9 +107,11 @@ class OFABase(Layer):
### if block is fixed, donnot join key into candidate
### concrete config as parameter in kwargs
if block.fixed == False:
assert block.key in self.current_config, 'DONNT have {} layer in config.'.format(
block.key)
config = self.current_config[block.key]
assert self._key2name[
block.
key] in self.current_config, 'DONNT have {} layer in config.'.format(
self._key2name[block.key])
config = self.current_config[self._key2name[block.key]]
else:
config = dict()
config.update(kwargs)
......@@ -109,6 +121,10 @@ class OFABase(Layer):
return block.fn(*inputs, **config)
@property
def ofa_layers(self):
return self._ofa_layers
@property
def layers(self):
return self._layers
......@@ -156,6 +172,9 @@ class OFA(OFABase):
self.task_idx = 0
self._add_teacher = False
self.netAs_param = []
self._mapping_layers = None
self._build_ss = False
self._broadcast = False
### if elastic_order is none, use default order
if self.elastic_order is not None:
......@@ -165,7 +184,7 @@ class OFA(OFABase):
if getattr(self.run_config, 'elastic_depth', None) != None:
depth_list = list(set(self.run_config.elastic_depth))
depth_list.sort()
self.layers['depth'] = depth_list
self._ofa_layers['depth'] = depth_list
if self.elastic_order is None:
self.elastic_order = []
......@@ -178,7 +197,7 @@ class OFA(OFABase):
if getattr(self.run_config, 'elastic_depth', None) != None:
depth_list = list(set(self.run_config.elastic_depth))
depth_list.sort()
self.layers['depth'] = depth_list
self._ofa_layers['depth'] = depth_list
self.elastic_order.append('depth')
# final, elastic width
......@@ -236,9 +255,14 @@ class OFA(OFABase):
# if mapping layer is NOT None, add hook and compute distill loss about mapping layers.
mapping_layers = getattr(self.distill_config, 'mapping_layers', None)
if mapping_layers != None:
if isinstance(self.model, DataParallel):
for idx, name in enumerate(mapping_layers):
if name[:7] != '_layers':
mapping_layers[idx] = '_layers.' + name
self._mapping_layers = mapping_layers
self.netAs = []
for name, sublayer in self.model.named_sublayers():
if name in mapping_layers:
if name in self._mapping_layers:
if self.distill_config.mapping_op != None:
if self.distill_config.mapping_op.lower() == 'conv2d':
netA = SuperConv2D(
......@@ -265,8 +289,7 @@ class OFA(OFABase):
def _reset_hook_before_forward(self):
self.Tacts, self.Sacts = {}, {}
mapping_layers = getattr(self.distill_config, 'mapping_layers', None)
if mapping_layers != None:
if self._mapping_layers != None:
def get_activation(mem, name):
def get_output_hook(layer, input, output):
......@@ -279,8 +302,9 @@ class OFA(OFABase):
if n in mapping_layers:
m.register_forward_post_hook(get_activation(mem, n))
add_hook(self.model, self.Sacts, mapping_layers)
add_hook(self.ofa_teacher_model.model, self.Tacts, mapping_layers)
add_hook(self.model, self.Sacts, self._mapping_layers)
add_hook(self.ofa_teacher_model.model, self.Tacts,
self._mapping_layers)
def _compute_epochs(self):
if getattr(self, 'epoch', None) == None:
......@@ -326,7 +350,7 @@ class OFA(OFABase):
def _sample_config(self, task, sample_type='random', phase=None):
config = self._sample_from_nestdict(
self.layers, sample_type=sample_type, task=task, phase=phase)
self._ofa_layers, sample_type=sample_type, task=task, phase=phase)
return config
def set_task(self, task, phase=None):
......@@ -356,7 +380,13 @@ class OFA(OFABase):
def _progressive_shrinking(self):
epoch = self._compute_epochs()
self.task_idx, phase_idx = search_idx(epoch, self.run_config.n_epochs)
phase_idx = None
if len(self.elastic_order) != 1:
assert self.run_config.n_epochs is not None, \
"if not use set_task() to set current task, please set n_epochs in run_config " \
"for to compute which task in this epoch."
self.task_idx, phase_idx = search_idx(epoch,
self.run_config.n_epochs)
self.task = self.elastic_order[:self.task_idx + 1]
if 'width' in self.task:
### change width in task to concrete config
......@@ -365,8 +395,6 @@ class OFA(OFABase):
self.task.append('expand_ratio')
if 'channel' in self._elastic_task:
self.task.append('channel')
if len(self.run_config.n_epochs[self.task_idx]) == 1:
phase_idx = None
return self._sample_config(task=self.task, phase=phase_idx)
def calc_distill_loss(self):
......@@ -413,21 +441,13 @@ class OFA(OFABase):
def _export_sub_model_config(self, origin_model, config, input_shapes,
input_dtypes):
super_model_config = {}
for name, sublayer in self.model.named_sublayers():
if isinstance(sublayer, BaseBlock):
for param in sublayer.parameters():
super_model_config[name] = sublayer.key
for name, value in super_model_config.items():
super_model_config[name] = config[value] if value in config.keys(
) else {}
origin_model_config = {}
for name, sublayer in origin_model.named_sublayers():
if isinstance(sublayer, BaseBlock):
sublayer = sublayer.fn
for param in sublayer.parameters(include_sublayers=False):
if name in super_model_config.keys():
origin_model_config[param.name] = super_model_config[name]
if name in config.keys():
origin_model_config[param.name] = config[name]
program = dygraph2program(
origin_model, inputs=input_shapes, dtypes=input_dtypes)
......@@ -436,10 +456,10 @@ class OFA(OFABase):
return param_prune_config
def export(self,
origin_model,
config,
input_shapes,
input_dtypes,
origin_model=None,
load_weights_from_supernet=True):
"""
Export the weights according to the origin model and the sub-model config.
......@@ -458,9 +478,14 @@ class OFA(OFABase):
origin_model = ofa_model.export(config, input_shapes=[[1, 3, 28, 28]], input_dtypes=['float32'], origin_model=origin_model)
"""
super_sd = None
if load_weights_from_supernet:
if load_weights_from_supernet and origin_model != None:
super_sd = remove_model_fn(origin_model, self.model.state_dict())
if origin_model == None:
origin_model = self.model
origin_model = origin_model._layers if isinstance(
origin_model, DataParallel) else origin_model
param_config = self._export_sub_model_config(origin_model, config,
input_shapes, input_dtypes)
prune_params(origin_model, param_config, super_sd)
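Since origin_model now defaults to None and falls back to the supernet itself, a sub-model can also be exported without passing a separate origin model; a sketch based on the new TestShortCut case further down:

config = ofa_model.current_config   # or any config assembled by the caller
ofa_model.export(
    config,
    input_shapes=[[2, 3, 224, 224]],
    input_dtypes=['float32'])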
......@@ -482,6 +507,88 @@ class OFA(OFABase):
"""
self.net_config = net_config
def _find_ele(self, inp, targets):
def _roll_eles(target_list, types=(list, set, tuple)):
if isinstance(target_list, types):
for targ in target_list:
for v in _roll_eles(targ, types):
yield v
else:
yield target_list
if inp in list(_roll_eles(targets)):
return True
else:
return False
def _clear_search_space(self, *inputs, **kwargs):
""" find shortcut in model, and clear up the search space """
input_shapes = []
input_dtypes = []
for n in inputs:
input_shapes.append(n.shape)
input_dtypes.append(n.numpy().dtype)
for n, v in kwargs.items():
input_shapes.append(v.shape)
input_dtypes.append(v.numpy().dtype)
### find shortcut block using static model
_st_prog = dygraph2program(
self.model, inputs=input_shapes, dtypes=input_dtypes)
self._same_ss = check_search_space(GraphWrapper(_st_prog))
if self._same_ss != None:
self._same_ss = sorted(self._same_ss)
self._param2key = {}
self._broadcast = True
### the name of sublayer is the key in search space
### param.name is the name in self._same_ss
model_to_traverse = self.model._layers if isinstance(
self.model, DataParallel) else self.model
for name, sublayer in model_to_traverse.named_sublayers():
if isinstance(sublayer, BaseBlock):
for param in sublayer.parameters():
if self._find_ele(param.name, self._same_ss):
self._param2key[param.name] = name
for per_ss in self._same_ss:
for ss in per_ss[1:]:
if 'expand_ratio' in self._ofa_layers[self._param2key[ss]]:
self._ofa_layers[self._param2key[ss]].pop(
'expand_ratio')
elif 'channel' in self._ofa_layers[self._param2key[ss]]:
self._ofa_layers[self._param2key[ss]].pop('channel')
if len(self._ofa_layers[self._param2key[ss]]) == 0:
self._ofa_layers.pop(self._param2key[ss])
def _broadcast_ss(self):
""" broadcast search space after random sample."""
for per_ss in self._same_ss:
for ss in per_ss[1:]:
key = self._param2key[ss]
pre_key = self._param2key[per_ss[0]]
if key in self.current_config:
if 'expand_ratio' in self.current_config[pre_key]:
self.current_config[key].update({
'expand_ratio':
self.current_config[pre_key]['expand_ratio']
})
elif 'channel' in self.current_config[pre_key]:
self.current_config[key].update({
'channel': self.current_config[pre_key]['channel']
})
else:
if 'expand_ratio' in self.current_config[pre_key]:
self.current_config[key] = {
'expand_ratio':
self.current_config[pre_key]['expand_ratio']
}
elif 'channel' in self.current_config[pre_key]:
self.current_config[key] = {
'channel': self.current_config[pre_key]['channel']
}
def forward(self, *inputs, **kwargs):
# ===================== teacher process =====================
teacher_output = None
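What _broadcast_ss does after sampling can be shown without paddle: every later member of a shortcut group copies 'expand_ratio' (or 'channel') from the group's first member. A toy sketch with invented names:

same_ss = [['conv1.w_0', 'conv3.w_0', 'conv5.w_0']]        # one shortcut group
param2key = {'conv1.w_0': 'conv1', 'conv3.w_0': 'conv3', 'conv5.w_0': 'conv5'}
current_config = {'conv1': {'expand_ratio': 0.5}}           # only the first member was sampled

for group in same_ss:
    leader = param2key[group[0]]
    for member in group[1:]:
        key = param2key[member]
        current_config.setdefault(key, {})['expand_ratio'] = current_config[leader]['expand_ratio']

print(current_config)
# {'conv1': {'expand_ratio': 0.5}, 'conv3': {'expand_ratio': 0.5}, 'conv5': {'expand_ratio': 0.5}}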
......@@ -492,6 +599,10 @@ class OFA(OFABase):
# ============================================================
# ==================== student process =====================
if not self._build_ss and self.net_config == None:
self._clear_search_space(*inputs, **kwargs)
self._build_ss = True
if getattr(self.run_config, 'dynamic_batch_size', None) != None:
self.dynamic_iter += 1
if self.dynamic_iter == self.run_config.dynamic_batch_size[
......@@ -516,4 +627,7 @@ class OFA(OFABase):
if 'depth' in self.current_config:
kwargs['depth'] = self.current_config['depth']
if self._broadcast:
self._broadcast_ss()
return self.model.forward(*inputs, **kwargs), teacher_output
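End to end, the shortcut handling is transparent to the caller; a minimal usage sketch mirroring the TestShortCut case added in this commit (requires paddle and paddleslim):

import paddle
from paddle.vision.models import resnet50
from paddleslim.nas.ofa import OFA
from paddleslim.nas.ofa.convert_super import Convert, supernet

model = Convert(supernet(expand_ratio=[0.25, 0.5, 1.0])).convert(resnet50())
ofa_model = OFA(model)
ofa_model.set_epoch(0)

images = paddle.randn(shape=[2, 3, 224, 224], dtype='float32')
# the first forward builds the search space, drops shortcut-constrained entries,
# and broadcasts the sampled config across each shortcut group
outs, _ = ofa_model(images)
print(len(ofa_model.ofa_layers))  # the unit test below expects 38 entries for resnet50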
......@@ -272,6 +272,31 @@ def _encoder_forward(self, src, src_mask=[None, None]):
return output
def _encoder_layer_forward(self, src, src_mask=None, cache=None):
residual = src
if self.normalize_before:
src = self.norm1(src)
# Add cache for encoder for the usage like UniLM
if cache is None:
src = self.self_attn(src, src, src, src_mask)
else:
src, incremental_cache = self.self_attn(src, src, src, src_mask, cache)
src = residual + self.dropout1(src)
if not self.normalize_before:
src = self.norm1(src)
residual = src
if self.normalize_before:
src = self.norm2(src)
src = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = residual + self.dropout2(src)
if not self.normalize_before:
src = self.norm2(src)
return src if cache is None else (src, incremental_cache)
nn.MultiHeadAttention.forward = _mha_forward
nn.MultiHeadAttention._prepare_qkv = _prepare_qkv
nn.TransformerEncoder.forward = _encoder_forward
nn.TransformerEncoderLayer.forward = _encoder_layer_forward
......@@ -45,5 +45,6 @@ def dynabert_config(model, width_mult, depth_mult=1.0):
if block_k == 'depth':
block_v = depth_mult
new_config[block_k] = block_v
new_block_k = model._key2name[block_k]
new_config[new_block_k] = block_v
return new_config
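dynabert_config has to follow the same key change: block keys from the supernet are translated into sublayer names through model._key2name before being written into the returned config. A toy sketch of just that translation (keys, names, and values invented):

key2name = {'SuperLinear1': 'encoder.layers.0.self_attn.q_proj',
            'SuperLinear2': 'encoder.layers.0.self_attn.out_proj'}
width_mult = 0.75
new_config = {}
for block_k in key2name:                  # previously block_k itself was used as the dict key
    new_config[key2name[block_k]] = {'expand_ratio': width_mult}
print(new_config)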
......@@ -48,7 +48,7 @@ def set_state_dict(model, state_dict):
"""
assert isinstance(model, Layer)
assert isinstance(state_dict, dict)
for name, param in model.named_parameters():
for name, param in model.state_dict().items():
tmp_n = name.split('.')[:-2] + [name.split('.')[-1]]
tmp_n = '.'.join(tmp_n)
if name in state_dict:
......@@ -59,16 +59,15 @@ def set_state_dict(model, state_dict):
_logger.info('{} is not in state_dict'.format(tmp_n))
def remove_model_fn(model, sd):
def remove_model_fn(model, state_dict):
new_dict = {}
keys = []
for name, param in model.named_parameters():
for name, param in model.state_dict().items():
keys.append(name)
for name, param in sd.items():
for name, param in state_dict.items():
if name.split('.')[-2] == 'fn':
tmp_n = name.split('.')[:-2] + [name.split('.')[-1]]
tmp_n = '.'.join(tmp_n)
#print(name, tmp_n)
if name in keys:
new_dict[name] = param
elif tmp_n in keys:
......
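The renaming in remove_model_fn is pure string surgery: supernet parameter paths contain an extra '.fn' segment introduced by the Block wrapper that does not exist in the exported model. A self-contained sketch of the mapping with an invented path:

supernet_name = 'blocks.0.fn.weight'          # Block-wrapped parameter in the supernet
parts = supernet_name.split('.')
if parts[-2] == 'fn':
    exported_name = '.'.join(parts[:-2] + [parts[-1]])
else:
    exported_name = supernet_name

print(exported_name)  # blocks.0.weight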
......@@ -31,7 +31,7 @@ class TestConvertSuper(unittest.TestCase):
assert len(sp_model.sublayers()) == 151
class TestConvertSuper(unittest.TestCase):
class TestConvertSuperCase1(unittest.TestCase):
def setUp(self):
class Model(nn.Layer):
def __init__(self):
......
......@@ -19,10 +19,12 @@ import unittest
import paddle
import paddle.nn as nn
from paddle.nn import ReLU
from paddle.vision.models import resnet50
from paddleslim.nas import ofa
from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
from paddleslim.nas.ofa.convert_super import supernet
from paddleslim.nas.ofa.layers import Block, SuperSeparableConv2D
from paddleslim.nas.ofa.convert_super import Convert, supernet
class ModelConv(nn.Layer):
......@@ -413,10 +415,10 @@ class TestExport(unittest.TestCase):
for name, param in self.origin_model.named_parameters():
origin_dict[name] = param.shape
self.ofa_model.export(
self.origin_model,
config,
input_shapes=[[1, 64]],
input_dtypes=['int64'])
input_dtypes=['int64'],
origin_model=self.origin_model)
for name, param in self.origin_model.named_parameters():
if name in config.keys():
if 'expand_ratio' in config[name]:
......@@ -424,5 +426,27 @@ class TestExport(unittest.TestCase):
name]['expand_ratio']
class TestShortCut(unittest.TestCase):
def setUp(self):
model = resnet50()
sp_net_config = supernet(expand_ratio=[0.25, 0.5, 1.0])
self.model = Convert(sp_net_config).convert(model)
self.images = paddle.randn(shape=[2, 3, 224, 224], dtype='float32')
self._test_clear_search_space()
def _test_clear_search_space(self):
self.ofa_model = OFA(self.model)
self.ofa_model.set_epoch(0)
outs, _ = self.ofa_model(self.images)
self.config = self.ofa_model.current_config
def test_export_model(self):
self.ofa_model.export(
self.config,
input_shapes=[[2, 3, 224, 224]],
input_dtypes=['float32'])
assert len(self.ofa_model.ofa_layers) == 38
if __name__ == '__main__':
unittest.main()