Unverified commit dae3c307, authored by ceci3 and committed by GitHub

fix ofa search space when model has shortcut (#679)

* fix shortcut

* fix when kernel_size in ss

* fix export

* fix dp

* fix export model prune param twice

* add comment

* fix unittest

* fix

* update
Parent adc57c75
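The user-facing change: `OFA.export()` now takes `origin_model` as an optional keyword argument instead of the first positional argument, and the search space is built and cleaned automatically on the first forward pass when the model contains shortcuts. A minimal end-to-end sketch of the updated API, mirroring the TestShortCut unittest added at the bottom of this diff (assumes paddle and paddleslim with this patch installed):

# Minimal sketch of the updated API; it mirrors the TestShortCut unittest below.
import paddle
from paddle.vision.models import resnet50
from paddleslim.nas.ofa import OFA
from paddleslim.nas.ofa.convert_super import Convert, supernet

model = Convert(supernet(expand_ratio=[0.25, 0.5, 1.0])).convert(resnet50())
ofa_model = OFA(model)
ofa_model.set_epoch(0)

# The first forward pass builds the search space and drops duplicate configs
# for layers that feed the same shortcut (elementwise_add).
images = paddle.randn(shape=[2, 3, 224, 224], dtype='float32')
outs, _ = ofa_model(images)

# export() now takes origin_model as an optional keyword argument; when it is
# omitted, the supernet held by ofa_model is pruned in place.
ofa_model.export(
    ofa_model.current_config,
    input_shapes=[[2, 3, 224, 224]],
    input_dtypes=['float32'])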
@@ -124,10 +124,10 @@ def do_train(args):
     ofa_model.model.set_state_dict(sd)
     best_config = utils.dynabert_config(ofa_model, args.width_mult)
     ofa_model.export(
-        origin_model,
         best_config,
         input_shapes=[[1, args.max_seq_length], [1, args.max_seq_length]],
-        input_dtypes=['int64', 'int64'])
+        input_dtypes=['int64', 'int64'],
+        origin_model=origin_model)
     for name, sublayer in origin_model.named_sublayers():
         if isinstance(sublayer, paddle.nn.MultiHeadAttention):
             sublayer.num_heads = int(args.width_mult * sublayer.num_heads)
......
@@ -34,6 +34,7 @@ else:
     from .layers import *
     from . import layers
     Layer = paddle.nn.Layer
+from .layers_base import Block
 _logger = get_logger(__name__, level=logging.INFO)
......
@@ -15,11 +15,20 @@
 import numpy as np
 import paddle
 from paddle.fluid import core
+from .layers_base import BaseBlock

-__all__ = ['get_prune_params_config', 'prune_params']
+__all__ = ['get_prune_params_config', 'prune_params', 'check_search_space']
+
+WEIGHT_OP = ['conv2d', 'conv3d', 'conv1d', 'linear', 'embedding']
+CONV_TYPES = [
+    'conv2d', 'conv3d', 'conv1d', 'superconv2d', 'supergroupconv2d',
+    'superdepthwiseconv2d'
+]


 def get_prune_params_config(graph, origin_model_config):
+    """ Convert config of search space to parameters' prune config.
+    """
     param_config = {}
     precedor = None
     for op in graph.ops():
@@ -68,40 +77,124 @@ def get_prune_params_config(graph, origin_model_config):

 def prune_params(model, param_config, super_model_sd=None):
-    for name, param in model.named_parameters():
-        t_value = param.value().get_tensor()
-        value = np.array(t_value).astype("float32")
-
-        if super_model_sd != None:
-            super_t_value = super_model_sd[name].value().get_tensor()
-            super_value = np.array(super_t_value).astype("float32")
-
-        if param.name in param_config.keys():
-            if len(param_config[param.name]) > 1:
-                in_exp = param_config[param.name][0]
-                out_exp = param_config[param.name][1]
-                in_chn = int(value.shape[0]) if in_exp == None else int(
-                    value.shape[0] * in_exp)
-                out_chn = int(value.shape[1]) if out_exp == None else int(
-                    value.shape[1] * out_exp)
-                prune_value = super_value[:in_chn, :out_chn, ...] \
-                    if super_model_sd != None else value[:in_chn, :out_chn, ...]
-            else:
-                out_chn = int(value.shape[0]) if param_config[param.name][
-                    0] == None else int(value.shape[0] *
-                                        param_config[param.name][0])
-                prune_value = super_value[:out_chn, ...] \
-                    if super_model_sd != None else value[:out_chn, ...]
-        else:
-            prune_value = super_value if super_model_sd != None else value
-
-        p = t_value._place()
-        if p.is_cpu_place():
-            place = core.CPUPlace()
-        elif p.is_cuda_pinned_place():
-            place = core.CUDAPinnedPlace()
-        else:
-            place = core.CUDAPlace(p.gpu_device_id())
-        t_value.set(prune_value, place)
-        if param.trainable:
-            param.clear_gradient()
+    """ Prune parameters according to the config.
+
+    Parameters:
+        model(paddle.nn.Layer): instance of model.
+        param_config(dict): prune config of each weight.
+        super_model_sd(dict, optional): parameters come from supernet. If super_model_sd is not None, transfer parameters from this dict to model; otherwise, prune model from itself.
+    """
+    for l_name, sublayer in model.named_sublayers():
+        if isinstance(sublayer, BaseBlock):
+            continue
+        for p_name, param in sublayer.named_parameters(include_sublayers=False):
+            t_value = param.value().get_tensor()
+            value = np.array(t_value).astype("float32")
+
+            if super_model_sd != None:
+                name = l_name + '.' + p_name
+                super_t_value = super_model_sd[name].value().get_tensor()
+                super_value = np.array(super_t_value).astype("float32")
+
+            if param.name in param_config.keys():
+                if len(param_config[param.name]) > 1:
+                    in_exp = param_config[param.name][0]
+                    out_exp = param_config[param.name][1]
+                    if sublayer.__class__.__name__.lower() in CONV_TYPES:
+                        in_chn = int(value.shape[1]) if in_exp == None else int(
+                            value.shape[1] * in_exp)
+                        out_chn = int(value.shape[0]) if out_exp == None else int(
+                            value.shape[0] * out_exp)
+                        prune_value = super_value[:out_chn, :in_chn, ...] \
+                            if super_model_sd != None else value[:out_chn, :in_chn, ...]
+                    else:
+                        in_chn = int(value.shape[0]) if in_exp == None else int(
+                            value.shape[0] * in_exp)
+                        out_chn = int(value.shape[1]) if out_exp == None else int(
+                            value.shape[1] * out_exp)
+                        prune_value = super_value[:in_chn, :out_chn, ...] \
+                            if super_model_sd != None else value[:in_chn, :out_chn, ...]
+                else:
+                    out_chn = int(value.shape[0]) if param_config[param.name][
+                        0] == None else int(value.shape[0] *
+                                            param_config[param.name][0])
+                    prune_value = super_value[:out_chn, ...] \
+                        if super_model_sd != None else value[:out_chn, ...]
+            else:
+                prune_value = super_value if super_model_sd != None else value
+
+            p = t_value._place()
+            if p.is_cpu_place():
+                place = core.CPUPlace()
+            elif p.is_cuda_pinned_place():
+                place = core.CUDAPinnedPlace()
+            else:
+                place = core.CUDAPlace(p.gpu_device_id())
+            t_value.set(prune_value, place)
+            if param.trainable:
+                param.clear_gradient()
+
+
+def _find_weight_ops(op, graph, weights):
+    """ Find the vars come from operators with weight.
+    """
+    pre_ops = graph.pre_ops(op)
+    for pre_op in pre_ops:
+        if pre_op.type() in WEIGHT_OP:
+            for inp in pre_op.all_inputs():
+                if inp._var.persistable:
+                    weights.append(inp._var.name)
+            return weights
+        return _find_weight_ops(pre_op, graph, weights)
+
+
+def _find_pre_elementwise_add(op, graph):
+    """ Find precedors of the elementwise_add operator in the model.
+    """
+    same_prune_before_elementwise_add = []
+    pre_ops = graph.pre_ops(op)
+    for pre_op in pre_ops:
+        if pre_op.type() in WEIGHT_OP:
+            return
+        same_prune_before_elementwise_add = _find_weight_ops(
+            pre_op, graph, same_prune_before_elementwise_add)
+    return same_prune_before_elementwise_add
+
+
+def check_search_space(graph):
+    """ Find the shortcut in the model and set same config for this situation.
+    """
+    same_search_space = []
+    for op in graph.ops():
+        if op.type() == 'elementwise_add':
+            inp1, inp2 = op.all_inputs()[0], op.all_inputs()[1]
+            if (not inp1._var.persistable) and (not inp2._var.persistable):
+                pre_ele_op = _find_pre_elementwise_add(op, graph)
+                if pre_ele_op != None:
+                    same_search_space.append(pre_ele_op)
+
+    if len(same_search_space) == 0:
+        return None
+
+    same_search_space = sorted([sorted(x) for x in same_search_space])
+    final_search_space = []
+
+    if len(same_search_space) >= 1:
+        final_search_space = [same_search_space[0]]
+        if len(same_search_space) > 1:
+            for l in same_search_space[1:]:
+                listset = set(l)
+                merged = False
+                for idx in range(len(final_search_space)):
+                    rset = set(final_search_space[idx])
+                    if len(listset & rset) != 0:
+                        final_search_space[idx] = list(listset | rset)
+                        merged = True
+                        break
+                if not merged:
+                    final_search_space.append(l)
+
+    return final_search_space
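The merging loop above unions any shortcut groups that share a parameter, so chained residual connections collapse into a single group. A standalone toy run of just that union step (the parameter names are made up):

# Toy illustration of the group-merging loop in check_search_space();
# the parameter names below are invented for the example.
same_search_space = sorted([sorted(x) for x in [
    ['conv1.w', 'conv2.w'],   # shortcut 1
    ['conv2.w', 'conv3.w'],   # shortcut 2 shares conv2.w with shortcut 1
    ['conv5.w', 'conv6.w'],   # an unrelated shortcut
]])

final_search_space = [same_search_space[0]]
for l in same_search_space[1:]:
    listset = set(l)
    merged = False
    for idx in range(len(final_search_space)):
        rset = set(final_search_space[idx])
        if listset & rset:
            final_search_space[idx] = list(listset | rset)
            merged = True
            break
    if not merged:
        final_search_space.append(l)

print(final_search_space)
# e.g. [['conv1.w', 'conv2.w', 'conv3.w'], ['conv5.w', 'conv6.w']]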
@@ -23,10 +23,11 @@ import paddle.fluid.core as core
 from ...common import get_logger
 from .utils.utils import compute_start_end, get_same_padding, convert_to_list
+from .layers_base import *

 __all__ = [
     'SuperConv2D', 'SuperConv2DTranspose', 'SuperSeparableConv2D',
-    'SuperBatchNorm2D', 'SuperLinear', 'SuperInstanceNorm2D', 'Block',
+    'SuperBatchNorm2D', 'SuperLinear', 'SuperInstanceNorm2D',
     'SuperGroupConv2D', 'SuperDepthwiseConv2D', 'SuperGroupConv2DTranspose',
     'SuperDepthwiseConv2DTranspose', 'SuperLayerNorm', 'SuperEmbedding'
 ]
@@ -35,52 +36,6 @@ _logger = get_logger(__name__, level=logging.INFO)

 ### TODO: if task is elastic width, need to add re_organize_middle_weight in 1x1 conv in MBBlock

-_cnt = 0
-
-
-def counter():
-    global _cnt
-    _cnt += 1
-    return _cnt
-
-
-class BaseBlock(paddle.nn.Layer):
-    def __init__(self, key=None):
-        super(BaseBlock, self).__init__()
-        if key is not None:
-            self._key = str(key)
-        else:
-            self._key = self.__class__.__name__ + str(counter())
-
-    # set SuperNet class
-    def set_supernet(self, supernet):
-        self.__dict__['supernet'] = supernet
-
-    @property
-    def key(self):
-        return self._key
-
-
-class Block(BaseBlock):
-    """
-    Model is composed of nest blocks.
-
-    Parameters:
-        fn(paddle.nn.Layer): instance of super layers, such as: SuperConv2D(3, 5, 3).
-        fixed(bool, optional): whether to fix the shape of the weight in this layer. Default: False.
-        key(str, optional): key of this layer, one-to-one correspondence between key and candidate config. Default: None.
-    """
-
-    def __init__(self, fn, fixed=False, key=None):
-        super(Block, self).__init__(key)
-        self.fn = fn
-        self.fixed = fixed
-        self.candidate_config = self.fn.candidate_config
-
-    def forward(self, *inputs, **kwargs):
-        out = self.supernet.layers_forward(self, *inputs, **kwargs)
-        return out
-
-
 class SuperConv2D(nn.Conv2D):
     """
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .utils.utils import get_paddle_version
pd_ver = get_paddle_version()

import paddle
if pd_ver == 185:
    Layer = paddle.fluid.dygraph.Layer
else:
    Layer = paddle.nn.Layer

_cnt = 0


def counter():
    global _cnt
    _cnt += 1
    return _cnt


class BaseBlock(Layer):
    def __init__(self, key=None):
        super(BaseBlock, self).__init__()
        if key is not None:
            self._key = str(key)
        else:
            self._key = self.__class__.__name__ + str(counter())

    # set SuperNet class
    def set_supernet(self, supernet):
        self.__dict__['supernet'] = supernet

    @property
    def key(self):
        return self._key


class Block(BaseBlock):
    """
    Model is composed of nest blocks.

    Parameters:
        fn(paddle.nn.Layer): instance of super layers, such as: SuperConv2D(3, 5, 3).
        fixed(bool, optional): whether to fix the shape of the weight in this layer. Default: False.
        key(str, optional): key of this layer, one-to-one correspondence between key and candidate config. Default: None.
    """

    def __init__(self, fn, fixed=False, key=None):
        super(Block, self).__init__(key)
        self.fn = fn
        self.fixed = fixed
        self.candidate_config = self.fn.candidate_config

    def forward(self, *inputs, **kwargs):
        out = self.supernet.layers_forward(self, *inputs, **kwargs)
        return out
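`Block` now lives in this shared base module used by both `layers.py` and `layers_old.py`; it only tags a super layer with a key and routes `forward` through the owning supernet. A minimal hand-wiring sketch (in normal use `Convert(supernet(...))` wraps layers automatically; the constructor arguments and key below are illustrative):

# Minimal sketch: wrapping a super layer in a Block by hand. The arguments and
# key are illustrative; Convert(supernet(...)) normally does this for you.
from paddleslim.nas.ofa.layers import SuperConv2D, Block

conv = SuperConv2D(3, 8, 3, candidate_config={'expand_ratio': [0.5, 1.0]})
block = Block(conv, fixed=False, key='conv_block_0')

print(block.key)               # 'conv_block_0'
print(block.candidate_config)  # {'expand_ratio': [0.5, 1.0]}
# block.forward() is dispatched through the supernet once OFA(...) has called
# set_supernet(), so a bare Block is not meant to be run on its own.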
@@ -25,11 +25,12 @@ from paddle.fluid.dygraph.nn import InstanceNorm, Conv2D, Conv2DTranspose, Batch
 from ...common import get_logger
 from .utils.utils import compute_start_end, get_same_padding, convert_to_list
+from .layers_base import *

 __all__ = [
     'SuperConv2D', 'SuperConv2DTranspose', 'SuperSeparableConv2D',
-    'SuperBatchNorm', 'SuperLinear', 'SuperInstanceNorm', 'Block',
-    'SuperGroupConv2D', 'SuperDepthwiseConv2D', 'SuperGroupConv2DTranspose',
+    'SuperBatchNorm', 'SuperLinear', 'SuperInstanceNorm', 'SuperGroupConv2D',
+    'SuperDepthwiseConv2D', 'SuperGroupConv2DTranspose',
     'SuperDepthwiseConv2DTranspose', 'SuperLayerNorm', 'SuperEmbedding'
 ]
@@ -37,51 +38,6 @@ _logger = get_logger(__name__, level=logging.INFO)

 ### TODO: if task is elastic width, need to add re_organize_middle_weight in 1x1 conv in MBBlock

-_cnt = 0
-
-
-def counter():
-    global _cnt
-    _cnt += 1
-    return _cnt
-
-
-class BaseBlock(fluid.dygraph.Layer):
-    def __init__(self, key=None):
-        super(BaseBlock, self).__init__()
-        if key is not None:
-            self._key = str(key)
-        else:
-            self._key = self.__class__.__name__ + str(counter())
-
-    # set SuperNet class
-    def set_supernet(self, supernet):
-        self.__dict__['supernet'] = supernet
-
-    @property
-    def key(self):
-        return self._key
-
-
-class Block(BaseBlock):
-    """
-    Model is composed of nest blocks.
-
-    Parameters:
-        fn(Layer): instance of super layers, such as: SuperConv2D(3, 5, 3).
-        key(str, optional): key of this layer, one-to-one correspondence between key and candidate config. Default: None.
-    """
-
-    def __init__(self, fn, fixed=False, key=None):
-        super(Block, self).__init__(key)
-        self.fn = fn
-        self.fixed = fixed
-        self.candidate_config = self.fn.candidate_config
-
-    def forward(self, *inputs, **kwargs):
-        out = self.supernet.layers_forward(self, *inputs, **kwargs)
-        return out
-
-
 class SuperConv2D(fluid.dygraph.Conv2D):
     """
......
@@ -20,15 +20,18 @@ import paddle.fluid as fluid
 from .utils.utils import get_paddle_version, remove_model_fn
 pd_ver = get_paddle_version()
 if pd_ver == 185:
-    from .layers_old import BaseBlock, SuperConv2D, SuperLinear
+    from .layers_old import SuperConv2D, SuperLinear
     Layer = paddle.fluid.dygraph.Layer
+    DataParallel = paddle.fluid.dygraph.DataParallel
 else:
-    from .layers import BaseBlock, SuperConv2D, SuperLinear
+    from .layers import SuperConv2D, SuperLinear
     Layer = paddle.nn.Layer
+    DataParallel = paddle.DataParallel
+from .layers_base import BaseBlock
 from .utils.utils import search_idx
 from ...common import get_logger
 from ...core import GraphWrapper, dygraph2program
-from .get_sub_model import get_prune_params_config, prune_params
+from .get_sub_model import get_prune_params_config, prune_params, check_search_space

 _logger = get_logger(__name__, level=logging.INFO)
@@ -75,19 +78,26 @@ class OFABase(Layer):
     def __init__(self, model):
         super(OFABase, self).__init__()
         self.model = model
-        self._layers, self._elastic_task = self.get_layers()
+        self._ofa_layers, self._elastic_task, self._key2name, self._layers = self.get_layers()

     def get_layers(self):
+        ofa_layers = dict()
         layers = dict()
+        key2name = dict()
         elastic_task = set()
-        for name, sublayer in self.model.named_sublayers():
+        model_to_traverse = self.model._layers if isinstance(
+            self.model, DataParallel) else self.model
+        for name, sublayer in model_to_traverse.named_sublayers():
             if isinstance(sublayer, BaseBlock):
                 sublayer.set_supernet(self)
                 if not sublayer.fixed:
+                    ofa_layers[name] = sublayer.candidate_config
                     layers[sublayer.key] = sublayer.candidate_config
+                    key2name[sublayer.key] = name
                     for k in sublayer.candidate_config.keys():
                         elastic_task.add(k)
-        return layers, elastic_task
+        return ofa_layers, elastic_task, key2name, layers

     def forward(self, *inputs, **kwargs):
         raise NotImplementedError
@@ -97,9 +107,11 @@ class OFABase(Layer):
         ### if block is fixed, donnot join key into candidate
         ### concrete config as parameter in kwargs
         if block.fixed == False:
-            assert block.key in self.current_config, 'DONNT have {} layer in config.'.format(
-                block.key)
-            config = self.current_config[block.key]
+            assert self._key2name[block.key] in self.current_config, \
+                'DONNT have {} layer in config.'.format(self._key2name[block.key])
+            config = self.current_config[self._key2name[block.key]]
         else:
             config = dict()
         config.update(kwargs)
@@ -109,6 +121,10 @@ class OFABase(Layer):

         return block.fn(*inputs, **config)

+    @property
+    def ofa_layers(self):
+        return self._ofa_layers
+
     @property
     def layers(self):
         return self._layers
@@ -156,6 +172,9 @@ class OFA(OFABase):
         self.task_idx = 0
         self._add_teacher = False
         self.netAs_param = []
+        self._mapping_layers = None
+        self._build_ss = False
+        self._broadcast = False

         ### if elastic_order is none, use default order
         if self.elastic_order is not None:
@@ -165,7 +184,7 @@ class OFA(OFABase):
             if getattr(self.run_config, 'elastic_depth', None) != None:
                 depth_list = list(set(self.run_config.elastic_depth))
                 depth_list.sort()
-                self.layers['depth'] = depth_list
+                self._ofa_layers['depth'] = depth_list

         if self.elastic_order is None:
             self.elastic_order = []
@@ -178,7 +197,7 @@ class OFA(OFABase):
             if getattr(self.run_config, 'elastic_depth', None) != None:
                 depth_list = list(set(self.run_config.elastic_depth))
                 depth_list.sort()
-                self.layers['depth'] = depth_list
+                self._ofa_layers['depth'] = depth_list
                 self.elastic_order.append('depth')

         # final, elastic width
@@ -236,9 +255,14 @@ class OFA(OFABase):
         # if mapping layer is NOT None, add hook and compute distill loss about mapping layers.
         mapping_layers = getattr(self.distill_config, 'mapping_layers', None)
         if mapping_layers != None:
+            if isinstance(self.model, DataParallel):
+                for idx, name in enumerate(mapping_layers):
+                    if name[:7] != '_layers':
+                        mapping_layers[idx] = '_layers.' + name
+            self._mapping_layers = mapping_layers
             self.netAs = []
             for name, sublayer in self.model.named_sublayers():
-                if name in mapping_layers:
+                if name in self._mapping_layers:
                     if self.distill_config.mapping_op != None:
                         if self.distill_config.mapping_op.lower() == 'conv2d':
                             netA = SuperConv2D(
@@ -265,8 +289,7 @@ class OFA(OFABase):
     def _reset_hook_before_forward(self):
         self.Tacts, self.Sacts = {}, {}
-        mapping_layers = getattr(self.distill_config, 'mapping_layers', None)
-        if mapping_layers != None:
+        if self._mapping_layers != None:

             def get_activation(mem, name):
                 def get_output_hook(layer, input, output):
@@ -279,8 +302,9 @@ class OFA(OFABase):
                     if n in mapping_layers:
                         m.register_forward_post_hook(get_activation(mem, n))

-            add_hook(self.model, self.Sacts, mapping_layers)
-            add_hook(self.ofa_teacher_model.model, self.Tacts, mapping_layers)
+            add_hook(self.model, self.Sacts, self._mapping_layers)
+            add_hook(self.ofa_teacher_model.model, self.Tacts,
+                     self._mapping_layers)

     def _compute_epochs(self):
         if getattr(self, 'epoch', None) == None:
@@ -326,7 +350,7 @@ class OFA(OFABase):
     def _sample_config(self, task, sample_type='random', phase=None):
         config = self._sample_from_nestdict(
-            self.layers, sample_type=sample_type, task=task, phase=phase)
+            self._ofa_layers, sample_type=sample_type, task=task, phase=phase)
         return config

     def set_task(self, task, phase=None):
@@ -356,7 +380,13 @@ class OFA(OFABase):
     def _progressive_shrinking(self):
         epoch = self._compute_epochs()
-        self.task_idx, phase_idx = search_idx(epoch, self.run_config.n_epochs)
+        phase_idx = None
+        if len(self.elastic_order) != 1:
+            assert self.run_config.n_epochs is not None, \
+                "if not use set_task() to set current task, please set n_epochs in run_config " \
+                "for to compute which task in this epoch."
+            self.task_idx, phase_idx = search_idx(epoch,
+                                                  self.run_config.n_epochs)
         self.task = self.elastic_order[:self.task_idx + 1]
         if 'width' in self.task:
             ### change width in task to concrete config
@@ -365,8 +395,6 @@ class OFA(OFABase):
                 self.task.append('expand_ratio')
             if 'channel' in self._elastic_task:
                 self.task.append('channel')
-        if len(self.run_config.n_epochs[self.task_idx]) == 1:
-            phase_idx = None
         return self._sample_config(task=self.task, phase=phase_idx)

     def calc_distill_loss(self):
@@ -413,21 +441,13 @@ class OFA(OFABase):
     def _export_sub_model_config(self, origin_model, config, input_shapes,
                                  input_dtypes):
-        super_model_config = {}
-        for name, sublayer in self.model.named_sublayers():
-            if isinstance(sublayer, BaseBlock):
-                for param in sublayer.parameters():
-                    super_model_config[name] = sublayer.key
-
-        for name, value in super_model_config.items():
-            super_model_config[name] = config[value] if value in config.keys(
-            ) else {}
-
         origin_model_config = {}
         for name, sublayer in origin_model.named_sublayers():
+            if isinstance(sublayer, BaseBlock):
+                sublayer = sublayer.fn
             for param in sublayer.parameters(include_sublayers=False):
-                if name in super_model_config.keys():
-                    origin_model_config[param.name] = super_model_config[name]
+                if name in config.keys():
+                    origin_model_config[param.name] = config[name]

         program = dygraph2program(
             origin_model, inputs=input_shapes, dtypes=input_dtypes)
@@ -436,10 +456,10 @@ class OFA(OFABase):
         return param_prune_config

     def export(self,
-               origin_model,
                config,
                input_shapes,
                input_dtypes,
+               origin_model=None,
               load_weights_from_supernet=True):
        """
        Export the weights according origin model and sub model config.
@@ -458,9 +478,14 @@ class OFA(OFABase):
             origin_model = ofa_model.export(origin_model, config, input_shapes=[1, 3, 28, 28], input_dtypes=['float32'])
         """
         super_sd = None
-        if load_weights_from_supernet:
+        if load_weights_from_supernet and origin_model != None:
             super_sd = remove_model_fn(origin_model, self.model.state_dict())

+        if origin_model == None:
+            origin_model = self.model
+        origin_model = origin_model._layers if isinstance(
+            origin_model, DataParallel) else origin_model
         param_config = self._export_sub_model_config(origin_model, config,
                                                      input_shapes, input_dtypes)
         prune_params(origin_model, param_config, super_sd)
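Because `origin_model` is now optional, the export path above either loads supernet weights into a separately built original network or, when the argument is omitted, prunes the supernet held by the OFA instance. A hedged sketch of the first pattern; the shapes, dtypes, and config contents are illustrative:

# Sketch of exporting into a separately constructed origin_model; shapes,
# dtypes and the config contents are illustrative, not taken from a real run.
config = ofa_model.current_config      # or a hand-written {sublayer_name: {...}} dict
sub_model = ofa_model.export(
    config,
    input_shapes=[[1, 3, 28, 28]],
    input_dtypes=['float32'],
    origin_model=origin_model)         # omit this kwarg to prune the supernet itself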
@@ -482,6 +507,88 @@ class OFA(OFABase):
         """
         self.net_config = net_config

+    def _find_ele(self, inp, targets):
+        def _roll_eles(target_list, types=(list, set, tuple)):
+            if isinstance(target_list, types):
+                for targ in target_list:
+                    for v in _roll_eles(targ, types):
+                        yield v
+            else:
+                yield target_list
+
+        if inp in list(_roll_eles(targets)):
+            return True
+        else:
+            return False
+
+    def _clear_search_space(self, *inputs, **kwargs):
+        """ find shortcut in model, and clear up the search space """
+        input_shapes = []
+        input_dtypes = []
+        for n in inputs:
+            input_shapes.append(n.shape)
+            input_dtypes.append(n.numpy().dtype)
+        for n, v in kwargs.items():
+            input_shapes.append(v.shape)
+            input_dtypes.append(v.numpy().dtype)
+
+        ### find shortcut block using static model
+        _st_prog = dygraph2program(
+            self.model, inputs=input_shapes, dtypes=input_dtypes)
+        self._same_ss = check_search_space(GraphWrapper(_st_prog))
+
+        if self._same_ss != None:
+            self._same_ss = sorted(self._same_ss)
+            self._param2key = {}
+            self._broadcast = True
+
+            ### the name of sublayer is the key in search space
+            ### param.name is the name in self._same_ss
+            model_to_traverse = self.model._layers if isinstance(
+                self.model, DataParallel) else self.model
+            for name, sublayer in model_to_traverse.named_sublayers():
+                if isinstance(sublayer, BaseBlock):
+                    for param in sublayer.parameters():
+                        if self._find_ele(param.name, self._same_ss):
+                            self._param2key[param.name] = name
+
+            for per_ss in self._same_ss:
+                for ss in per_ss[1:]:
+                    if 'expand_ratio' in self._ofa_layers[self._param2key[ss]]:
+                        self._ofa_layers[self._param2key[ss]].pop(
+                            'expand_ratio')
+                    elif 'channel' in self._ofa_layers[self._param2key[ss]]:
+                        self._ofa_layers[self._param2key[ss]].pop('channel')
+                    if len(self._ofa_layers[self._param2key[ss]]) == 0:
+                        self._ofa_layers.pop(self._param2key[ss])
+
+    def _broadcast_ss(self):
+        """ broadcast search space after random sample."""
+        for per_ss in self._same_ss:
+            for ss in per_ss[1:]:
+                key = self._param2key[ss]
+                pre_key = self._param2key[per_ss[0]]
+                if key in self.current_config:
+                    if 'expand_ratio' in self.current_config[pre_key]:
+                        self.current_config[key].update({
+                            'expand_ratio':
+                            self.current_config[pre_key]['expand_ratio']
+                        })
+                    elif 'channel' in self.current_config[pre_key]:
+                        self.current_config[key].update({
+                            'channel': self.current_config[pre_key]['channel']
+                        })
+                else:
+                    if 'expand_ratio' in self.current_config[pre_key]:
+                        self.current_config[key] = {
+                            'expand_ratio':
+                            self.current_config[pre_key]['expand_ratio']
+                        }
+                    elif 'channel' in self.current_config[pre_key]:
+                        self.current_config[key] = {
+                            'channel': self.current_config[pre_key]['channel']
+                        }
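`_clear_search_space` keeps only the head of each shortcut group in `_ofa_layers`, and `_broadcast_ss` copies the head's sampled value to the other members after each sample. A toy illustration of the effect on `current_config` (the sublayer names are made up; the real code resolves them from parameter names via `_param2key`):

# Toy illustration of broadcasting a sampled config across one shortcut group;
# the sublayer names are invented for the example.
_same_ss_group = ['conv_a', 'conv_b', 'conv_c']     # layers feeding one elementwise_add
current_config = {'conv_a': {'expand_ratio': 0.5}}  # only the group head is sampled

head = _same_ss_group[0]
for name in _same_ss_group[1:]:
    current_config.setdefault(name, {})
    current_config[name]['expand_ratio'] = current_config[head]['expand_ratio']

print(current_config)
# {'conv_a': {'expand_ratio': 0.5}, 'conv_b': {'expand_ratio': 0.5}, 'conv_c': {'expand_ratio': 0.5}}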
     def forward(self, *inputs, **kwargs):
         # ===================== teacher process =====================
         teacher_output = None
@@ -492,6 +599,10 @@ class OFA(OFABase):
         # ============================================================

         # ==================== student process =====================
+        if not self._build_ss and self.net_config == None:
+            self._clear_search_space(*inputs, **kwargs)
+            self._build_ss = True
+
         if getattr(self.run_config, 'dynamic_batch_size', None) != None:
             self.dynamic_iter += 1
             if self.dynamic_iter == self.run_config.dynamic_batch_size[
@@ -516,4 +627,7 @@ class OFA(OFABase):
         if 'depth' in self.current_config:
             kwargs['depth'] = self.current_config['depth']

+        if self._broadcast:
+            self._broadcast_ss()
+
         return self.model.forward(*inputs, **kwargs), teacher_output
@@ -272,6 +272,31 @@ def _encoder_forward(self, src, src_mask=[None, None]):
     return output


+def _encoder_layer_forward(self, src, src_mask=None, cache=None):
+    residual = src
+    if self.normalize_before:
+        src = self.norm1(src)
+    # Add cache for encoder for the usage like UniLM
+    if cache is None:
+        src = self.self_attn(src, src, src, src_mask)
+    else:
+        src, incremental_cache = self.self_attn(src, src, src, src_mask, cache)
+    src = residual + self.dropout1(src)
+    if not self.normalize_before:
+        src = self.norm1(src)
+
+    residual = src
+    if self.normalize_before:
+        src = self.norm2(src)
+    src = self.linear2(self.dropout(self.activation(self.linear1(src))))
+    src = residual + self.dropout2(src)
+    if not self.normalize_before:
+        src = self.norm2(src)
+    return src if cache is None else (src, incremental_cache)
+
+
 nn.MultiHeadAttention.forward = _mha_forward
 nn.MultiHeadAttention._prepare_qkv = _prepare_qkv
 nn.TransformerEncoder.forward = _encoder_forward
+nn.TransformerEncoderLayer.forward = _encoder_layer_forward
@@ -45,5 +45,6 @@ def dynabert_config(model, width_mult, depth_mult=1.0):
         if block_k == 'depth':
             block_v = depth_mult
-        new_config[block_k] = block_v
+        new_block_k = model._key2name[block_k]
+        new_config[new_block_k] = block_v
     return new_config
@@ -48,7 +48,7 @@ def set_state_dict(model, state_dict):
     """
     assert isinstance(model, Layer)
     assert isinstance(state_dict, dict)
-    for name, param in model.named_parameters():
+    for name, param in model.state_dict().items():
         tmp_n = name.split('.')[:-2] + [name.split('.')[-1]]
         tmp_n = '.'.join(tmp_n)
         if name in state_dict:
@@ -59,16 +59,15 @@ def set_state_dict(model, state_dict):
             _logger.info('{} is not in state_dict'.format(tmp_n))


-def remove_model_fn(model, sd):
+def remove_model_fn(model, state_dict):
     new_dict = {}
     keys = []
-    for name, param in model.named_parameters():
+    for name, param in model.state_dict().items():
         keys.append(name)
-    for name, param in sd.items():
+    for name, param in state_dict.items():
         if name.split('.')[-2] == 'fn':
             tmp_n = name.split('.')[:-2] + [name.split('.')[-1]]
             tmp_n = '.'.join(tmp_n)
-            #print(name, tmp_n)
         if name in keys:
             new_dict[name] = param
         elif tmp_n in keys:
......
@@ -31,7 +31,7 @@ class TestConvertSuper(unittest.TestCase):
         assert len(sp_model.sublayers()) == 151


-class TestConvertSuper(unittest.TestCase):
+class TestConvertSuperCase1(unittest.TestCase):
     def setUp(self):
         class Model(nn.Layer):
             def __init__(self):
......
@@ -19,10 +19,12 @@ import unittest
 import paddle
 import paddle.nn as nn
 from paddle.nn import ReLU
+from paddle.vision.models import resnet50

 from paddleslim.nas import ofa
 from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
 from paddleslim.nas.ofa.convert_super import supernet
 from paddleslim.nas.ofa.layers import Block, SuperSeparableConv2D
+from paddleslim.nas.ofa.convert_super import Convert, supernet


 class ModelConv(nn.Layer):
@@ -413,10 +415,10 @@ class TestExport(unittest.TestCase):
         for name, param in self.origin_model.named_parameters():
             origin_dict[name] = param.shape
         self.ofa_model.export(
-            self.origin_model,
             config,
             input_shapes=[[1, 64]],
-            input_dtypes=['int64'])
+            input_dtypes=['int64'],
+            origin_model=self.origin_model)
         for name, param in self.origin_model.named_parameters():
             if name in config.keys():
                 if 'expand_ratio' in config[name]:
@@ -424,5 +426,27 @@ class TestExport(unittest.TestCase):
                         name]['expand_ratio']


+class TestShortCut(unittest.TestCase):
+    def setUp(self):
+        model = resnet50()
+        sp_net_config = supernet(expand_ratio=[0.25, 0.5, 1.0])
+        self.model = Convert(sp_net_config).convert(model)
+        self.images = paddle.randn(shape=[2, 3, 224, 224], dtype='float32')
+        self._test_clear_search_space()
+
+    def _test_clear_search_space(self):
+        self.ofa_model = OFA(self.model)
+        self.ofa_model.set_epoch(0)
+        outs, _ = self.ofa_model(self.images)
+        self.config = self.ofa_model.current_config
+
+    def test_export_model(self):
+        self.ofa_model.export(
+            self.config,
+            input_shapes=[[2, 3, 224, 224]],
+            input_dtypes=['float32'])
+        assert len(self.ofa_model.ofa_layers) == 38
+
+
 if __name__ == '__main__':
     unittest.main()