未验证 提交 609efce5 编写于 作者: C ceci3 提交者: GitHub

add export model for ofa (#548)


* add export model and depth
上级 526b9ca5
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import random
import time
import json
from functools import partial
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlenlp.transformers import BertModel, BertForSequenceClassification, BertTokenizer
from paddlenlp.utils.log import logger
from paddleslim.nas.ofa import OFA, utils
from paddleslim.nas.ofa.convert_super import Convert, supernet
from paddleslim.nas.ofa.layers import BaseBlock
MODEL_CLASSES = {"bert": (BertForSequenceClassification, BertTokenizer), }
def parse_args():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_type",
default=None,
type=str,
required=True,
help="Model type selected in the list: " +
", ".join(MODEL_CLASSES.keys()), )
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pre-trained model or shortcut name selected in the list: "
+ ", ".join(
sum([
list(classes[-1].pretrained_init_configuration.keys())
for classes in MODEL_CLASSES.values()
], [])), )
parser.add_argument(
"--sub_model_output_dir",
default=None,
type=str,
help="The output directory where the sub model predictions and checkpoints will be written.",
)
parser.add_argument(
"--static_sub_model",
default=None,
type=str,
help="The output directory where the sub static model will be written. If set to None, not export static model",
)
parser.add_argument(
"--max_seq_length",
default=128,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.", )
parser.add_argument(
"--n_gpu",
type=int,
default=1,
help="number of gpus to use, 0 for cpu.")
parser.add_argument(
'--width_mult',
type=float,
default=1.0,
help="width mult you want to export")
args = parser.parse_args()
return args
def export_static_model(model, model_path, max_seq_length):
input_shape = [
paddle.static.InputSpec(
shape=[None, max_seq_length], dtype='int64'),
paddle.static.InputSpec(
shape=[None, max_seq_length], dtype='int64')
]
net = paddle.jit.to_static(model, input_spec=input_shape)
paddle.jit.save(net, model_path)
def do_train(args):
paddle.set_device("gpu" if args.n_gpu else "cpu")
args.model_type = args.model_type.lower()
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config_path = os.path.join(args.model_name_or_path, 'model_config.json')
cfg_dict = dict(json.loads(open(config_path).read()))
num_labels = cfg_dict['num_classes']
model = model_class.from_pretrained(
args.model_name_or_path, num_classes=num_labels)
origin_model = model_class.from_pretrained(
args.model_name_or_path, num_classes=num_labels)
sp_config = supernet(expand_ratio=[1.0, args.width_mult])
model = Convert(sp_config).convert(model)
ofa_model = OFA(model)
sd = paddle.load(
os.path.join(args.model_name_or_path, 'model_state.pdparams'))
ofa_model.model.set_state_dict(sd)
best_config = utils.dynabert_config(ofa_model, args.width_mult)
ofa_model.export(
origin_model,
best_config,
input_shapes=[[1, args.max_seq_length], [1, args.max_seq_length]],
input_dtypes=['int64', 'int64'])
for name, sublayer in origin_model.named_sublayers():
if isinstance(sublayer, paddle.nn.MultiHeadAttention):
sublayer.num_heads = int(args.width_mult * sublayer.num_heads)
if args.static_sub_model != None:
export_static_model(origin_model, args.static_sub_model,
args.max_seq_length)
def print_arguments(args):
"""print arguments"""
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).items()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
if __name__ == "__main__":
args = parse_args()
print_arguments(args)
do_train(args)
......@@ -31,6 +31,7 @@ from paddlenlp.utils.log import logger
from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman
import paddlenlp.datasets as datasets
from paddleslim.nas.ofa import OFA, DistillConfig, utils
from paddleslim.nas.ofa.utils import nlp_utils
from paddleslim.nas.ofa.convert_super import Convert, supernet
TASK_CLASSES = {
......@@ -215,13 +216,13 @@ def reorder_neuron_head(model, head_importance, neuron_importance):
for layer, current_importance in enumerate(neuron_importance):
# reorder heads
idx = paddle.argsort(head_importance[layer], descending=True)
utils.reorder_head(model.bert.encoder.layers[layer].self_attn, idx)
nlp_utils.reorder_head(model.bert.encoder.layers[layer].self_attn, idx)
# reorder neurons
idx = paddle.argsort(
paddle.to_tensor(current_importance), descending=True)
utils.reorder_neuron(
nlp_utils.reorder_neuron(
model.bert.encoder.layers[layer].linear1.fn, idx, dim=1)
utils.reorder_neuron(
nlp_utils.reorder_neuron(
model.bert.encoder.layers[layer].linear2.fn, idx, dim=0)
......@@ -422,7 +423,7 @@ def do_train(args):
# Step6: Calculate the importance of neurons and head,
# and then reorder them according to the importance.
head_importance, neuron_importance = utils.compute_neuron_head_importance(
head_importance, neuron_importance = nlp_utils.compute_neuron_head_importance(
args.task_name,
ofa_model.model,
dev_data_loader,
......@@ -512,7 +513,7 @@ def do_train(args):
dev_data_loader,
width_mult=100)
for idx, width_mult in enumerate(args.width_mult_list):
net_config = apply_config(ofa_model, width_mult)
net_config = utils.dynabert_config(ofa_model, width_mult)
ofa_model.set_net_config(net_config)
tic_eval = time.time()
if args.task_name == "mnli":
......
此差异已折叠。
......@@ -26,6 +26,10 @@ import logging
import logging
from functools import partial
import six
if six.PY2:
from pathlib2 import Path
else:
from pathlib import Path
import paddle.fluid.dygraph as D
import paddle.fluid as F
......@@ -288,8 +292,16 @@ def get_config(pretrain_dir_or_url):
'ernie-2.0-large-en': bce + 'model-ernie2.0-large-en.1.tar.gz',
'ernie-tiny': bce + 'model-ernie_tiny.1.tar.gz',
}
url = resource_map[pretrain_dir_or_url]
pretrain_dir = _fetch_from_remote(url, False)
if not Path(pretrain_dir_or_url).exists() and str(
pretrain_dir_or_url) in resource_map:
url = resource_map[pretrain_dir_or_url]
pretrain_dir = _fetch_from_remote(url, False)
else:
log.info('pretrain dir %s not in %s, read from local' %
(pretrain_dir_or_url, repr(resource_map)))
pretrain_dir = Path(pretrain_dir_or_url)
config_path = os.path.join(pretrain_dir, 'ernie_config.json')
if not os.path.exists(config_path):
raise ValueError('config path not found: %s' % config_path)
......
......@@ -88,9 +88,14 @@ OFA实例
.. code-block:: python
from paddlslim.nas.ofa import OFA
ofa_model = OFA(model)
from paddle.vision.models import mobilenet_v1
from paddleslim.nas.ofa import OFA
from paddleslim.nas.ofa.convert_super import Convert, supernet
model = mobilenet_v1()
sp_net_config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
sp_model = Convert(sp_net_config).convert(model)
ofa_model = OFA(sp_model)
..
.. py:method:: set_epoch(epoch)
......@@ -140,7 +145,7 @@ OFA实例
.. code-block:: python
config = ofa_model.current_config
config = {'conv2d_0': {'expand_ratio': 2}, 'conv2d_1': {'expand_ratio': 2}}
ofa_model.set_net_config(config)
.. py:method:: calc_distill_loss()
......@@ -159,15 +164,25 @@ OFA实例
.. py:method:: search()
### TODO
.. py:method:: export(config)
.. py:method:: export(origin_model, config, input_shapes, input_dtypes, load_weights_from_supernet=True)
根据传入的子网络配置导出当前子网络的参数
根据传入的原始模型结构、子网络配置,模型输入的形状和类型导出当前子网络,导出的子网络可以进一步训练、预测或者调用框架动静转换功能转为静态图模型
**参数:**
- **config(dict):** 某个子网络每层的配置。
- **origin_model(paddle.nn.Layer)** 原始模型实例,子模型会直接在原始模型的基础上进行修改。
- **config(dict)** 某个子网络每层的配置,可以用。
- **input_shapes(list|list(list))** 模型输入的形状。
- **input_dtypes(list)** 模型输入的类型。
- **load_weights_from_supernet(bool, optional)** 是否从超网络加载参数。若为False,则不从超网络加载参数,则只根据config裁剪原始模型的网络结构;若为True,则用超网络参数来初始化原始模型,并根据config裁剪原始模型的网络结构。默认:True
**返回:**
TODO
子模型实例。
**示例代码:**
TODO
.. code-block:: python
from paddle.vision.models import mobilenet_v1
origin_model = mobilenet_v1()
config = {'conv2d_0': {'expand_ratio': 2}, 'conv2d_1': {'expand_ratio': 2}}
origin_model = ofa_model.export(origin_model, config, input_shapes=[1, 3, 28, 28], input_dtypes=['float32'])
......@@ -13,7 +13,6 @@
# limitations under the License.
import paddle
import numpy as np
import paddle.jit as jit
from ..core import GraphWrapper, dygraph2program
__all__ = ["flops", "dygraph_flops"]
......
......@@ -14,6 +14,8 @@
from .ofa import OFA, RunConfig, DistillConfig
from .convert_super import supernet
from .utils.special_config import *
from .get_sub_model import *
from .utils.utils import get_paddle_version
pd_ver = get_paddle_version()
......
......@@ -579,14 +579,10 @@ class Convert:
new_attr_name = []
if pd_ver == 185:
new_attr_name += [
'size', 'is_sparse', 'is_distributed', 'param_attr',
'dtype'
'is_sparse', 'is_distributed', 'param_attr', 'dtype'
]
else:
new_attr_name += [
'num_embeddings', 'embedding_dim', 'sparse',
'weight_attr', 'name'
]
new_attr_name += ['sparse', 'weight_attr', 'name']
self._change_name(layer, pd_ver, has_bias=False)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
__all__ = ['get_prune_params_config', 'prune_params']
def get_prune_params_config(graph, origin_model_config):
param_config = {}
precedor = None
for op in graph.ops():
### TODO(ceci3):
### 1. fix config when this op is concat by graph.pre_ops(op)
### 2. add kernel_size in config
### 3. add channel in config
for inp in op.all_inputs():
n_ops = graph.next_ops(op)
if inp._var.name in origin_model_config.keys():
if 'expand_ratio' in origin_model_config[inp._var.name].keys():
tmp = origin_model_config[inp._var.name]['expand_ratio']
if len(inp._var.shape) > 1:
if inp._var.name in param_config.keys():
param_config[inp._var.name].append(tmp)
### first op
else:
param_config[inp._var.name] = [precedor, tmp]
else:
param_config[inp._var.name] = [tmp]
precedor = tmp
else:
precedor = None
for n_op in n_ops:
for next_inp in n_op.all_inputs():
if next_inp._var.persistable == True:
if next_inp._var.name in origin_model_config.keys():
if 'expand_ratio' in origin_model_config[
next_inp._var.name].keys():
tmp = origin_model_config[next_inp._var.name][
'expand_ratio']
pre = tmp if precedor is None else precedor
if len(next_inp._var.shape) > 1:
param_config[next_inp._var.name] = [pre]
else:
param_config[next_inp._var.name] = [tmp]
else:
if len(next_inp._var.
shape) > 1 and precedor != None:
param_config[
next_inp._var.name] = [precedor, None]
else:
param_config[next_inp._var.name] = [precedor]
return param_config
def prune_params(model, param_config, super_model_sd=None):
for name, param in model.named_parameters():
t_value = param.value().get_tensor()
value = np.array(t_value).astype("float32")
if super_model_sd != None:
super_t_value = super_model_sd[name].value().get_tensor()
super_value = np.array(super_t_value).astype("float32")
if param.name in param_config.keys():
if len(param_config[param.name]) > 1:
in_exp = param_config[param.name][0]
out_exp = param_config[param.name][1]
in_chn = int(value.shape[0]) if in_exp == None else int(
value.shape[0] * in_exp)
out_chn = int(value.shape[1]) if out_exp == None else int(
value.shape[1] * out_exp)
prune_value = super_value[:in_chn, :out_chn, ...] \
if super_model_sd != None else value[:in_chn, :out_chn, ...]
else:
out_chn = int(value.shape[0]) if param_config[param.name][
0] == None else int(value.shape[0] *
param_config[param.name][0])
prune_value = super_value[:out_chn, ...] \
if super_model_sd != None else value[:out_chn, ...]
else:
prune_value = super_value if super_model_sd != None else value
p = t_value._place()
if p.is_cpu_place():
place = paddle.CPUPlace()
elif p.is_cuda_pinned_place():
place = paddle.CUDAPinnedPlace()
else:
place = paddle.CUDAPlace(p.gpu_device_id())
t_value.set(prune_value, place)
if param.trainable:
param.clear_gradient()
......@@ -17,7 +17,7 @@ import numpy as np
from collections import namedtuple
import paddle
import paddle.fluid as fluid
from .utils.utils import get_paddle_version
from .utils.utils import get_paddle_version, remove_model_fn
pd_ver = get_paddle_version()
if pd_ver == 185:
from .layers_old import BaseBlock, SuperConv2D, SuperLinear
......@@ -27,6 +27,8 @@ else:
Layer = paddle.nn.Layer
from .utils.utils import search_idx
from ...common import get_logger
from ...core import GraphWrapper, dygraph2program
from .get_sub_model import get_prune_params_config, prune_params
_logger = get_logger(__name__, level=logging.INFO)
......@@ -125,8 +127,14 @@ class OFA(OFABase):
Examples:
.. code-block:: python
from paddlslim.nas.ofa import OFA
ofa_model = OFA(model)
from paddle.vision.models import mobilenet_v1
from paddleslim.nas.ofa import OFA
from paddleslim.nas.ofa.convert_super import Convert, supernet
model = mobilenet_v1()
sp_net_config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
sp_model = Convert(sp_net_config).convert(model)
ofa_model = OFA(sp_model)
"""
......@@ -206,8 +214,6 @@ class OFA(OFABase):
self.model.train()
def _prepare_distill(self):
self.Tacts, self.Sacts = {}, {}
if self.distill_config.teacher_model == None:
logging.error(
'If you want to add distill, please input instance of teacher model'
......@@ -257,6 +263,11 @@ class OFA(OFABase):
self.netAs_param.extend(netA.parameters())
self.netAs.append(netA)
def _reset_hook_before_forward(self):
self.Tacts, self.Sacts = {}, {}
mapping_layers = getattr(self.distill_config, 'mapping_layers', None)
if mapping_layers != None:
def get_activation(mem, name):
def get_output_hook(layer, input, output):
mem[name] = output
......@@ -369,6 +380,9 @@ class OFA(OFABase):
assert len(self.netAs) > 0
for i, netA in enumerate(self.netAs):
n = self.distill_config.mapping_layers[i]
### add for elastic depth
if n not in self.Sacts.keys():
continue
Tact = self.Tacts[n]
Sact = self.Sacts[n]
if isinstance(netA, SuperConv2D):
......@@ -397,9 +411,64 @@ class OFA(OFABase):
def search(self, eval_func, condition):
pass
### TODO: complete it
def export(self, config):
pass
def _export_sub_model_config(self, origin_model, config, input_shapes,
input_dtypes):
super_model_config = {}
for name, sublayer in self.model.named_sublayers():
if isinstance(sublayer, BaseBlock):
for param in sublayer.parameters():
super_model_config[name] = sublayer.key
for name, value in super_model_config.items():
super_model_config[name] = config[value] if value in config.keys(
) else {}
origin_model_config = {}
for name, sublayer in origin_model.named_sublayers():
for param in sublayer.parameters(include_sublayers=False):
if name in super_model_config.keys():
origin_model_config[param.name] = super_model_config[name]
program = dygraph2program(
origin_model, inputs=input_shapes, dtypes=input_dtypes)
graph = GraphWrapper(program)
param_prune_config = get_prune_params_config(graph, origin_model_config)
return param_prune_config
def export(self,
origin_model,
config,
input_shapes,
input_dtypes,
load_weights_from_supernet=True):
"""
Export the weights according origin model and sub model config.
Parameters:
origin_model(paddle.nn.Layer): the instance of original model.
config(dict): the config of sub model, can get by OFA.get_current_config() or some special config, such as paddleslim.nas.ofa.utils.dynabert_config(width_mult).
input_shapes(list|list(list)): the shape of all inputs.
input_dtypes(list): the dtype of all inputs.
load_weights_from_supernet(bool, optional): whether to load weights from SuperNet. Default: False.
Examples:
.. code-block:: python
from paddle.vision.models import mobilenet_v1
origin_model = mobilenet_v1()
config = {'conv2d_0': {'expand_ratio': 2}, 'conv2d_1': {'expand_ratio': 2}}
origin_model = ofa_model.export(origin_model, config, input_shapes=[1, 3, 28, 28], input_dtypes=['float32'])
"""
super_sd = None
if load_weights_from_supernet:
super_sd = remove_model_fn(origin_model, self.model.state_dict())
param_config = self._export_sub_model_config(origin_model, config,
input_shapes, input_dtypes)
prune_params(origin_model, param_config, super_sd)
return origin_model
@property
def get_current_config(self):
return self.current_config
def set_net_config(self, net_config):
"""
......@@ -408,7 +477,7 @@ class OFA(OFABase):
net_config(dict): special the config of sug-network.
Examples:
.. code-block:: python
config = ofa_model.current_config
config = {'conv2d_0': {'expand_ratio': 2}, 'conv2d_1': {'expand_ratio': 2}}
ofa_model.set_net_config(config)
"""
self.net_config = net_config
......@@ -417,6 +486,7 @@ class OFA(OFABase):
# ===================== teacher process =====================
teacher_output = None
if self._add_teacher:
self._reset_hook_before_forward()
teacher_output = self.ofa_teacher_model.model.forward(*inputs,
**kwargs)
# ============================================================
......
......@@ -17,5 +17,3 @@ from .special_config import *
from .utils import get_paddle_version
pd_ver = get_paddle_version()
if pd_ver == 200:
from .nlp_utils import *
......@@ -27,10 +27,17 @@ def dynabert_config(model, width_mult, depth_mult=1.0):
return True
return False
start_idx = 0
for idx, (block_k, block_v) in enumerate(model.layers.items()):
if 'linear' in block_k:
start_idx = int(block_k.split('_')[1])
break
for idx, (block_k, block_v) in enumerate(model.layers.items()):
if isinstance(block_v, dict) and len(block_v.keys()) != 0:
name, name_idx = block_k.split('_'), int(block_k.split('_')[1])
if fix_exp(name_idx) or 'emb' in block_k or idx >= block_name:
if fix_exp(name_idx -
start_idx) or 'emb' in block_k or idx >= block_name:
block_v['expand_ratio'] = 1.0
else:
block_v['expand_ratio'] = width_mult
......
......@@ -59,6 +59,25 @@ def set_state_dict(model, state_dict):
_logger.info('{} is not in state_dict'.format(tmp_n))
def remove_model_fn(model, sd):
new_dict = {}
keys = []
for name, param in model.named_parameters():
keys.append(name)
for name, param in sd.items():
if name.split('.')[-2] == 'fn':
tmp_n = name.split('.')[:-2] + [name.split('.')[-1]]
tmp_n = '.'.join(tmp_n)
#print(name, tmp_n)
if name in keys:
new_dict[name] = param
elif tmp_n in keys:
new_dict[tmp_n] = param
else:
_logger.debug('{} is not in state_dict'.format(tmp_n))
return new_dict
def compute_start_end(kernel_size, sub_kernel_size):
center = kernel_size // 2
sub_center = sub_kernel_size // 2
......
......@@ -139,7 +139,7 @@ class ModelConv2(nn.Layer):
class ModelLinear(nn.Layer):
def __init__(self):
super(ModelLinear, self).__init__()
with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
with supernet(expand_ratio=(1.0, 2.0, 4.0)) as ofa_super:
models = []
models += [nn.Embedding(num_embeddings=64, embedding_dim=64)]
models += [nn.Linear(64, 128)]
......@@ -167,6 +167,22 @@ class ModelLinear(nn.Layer):
return inputs
class ModelOriginLinear(nn.Layer):
def __init__(self):
super(ModelOriginLinear, self).__init__()
models = []
models += [nn.Embedding(num_embeddings=64, embedding_dim=64)]
models += [nn.Linear(64, 128)]
models += [nn.LayerNorm(128)]
models += [nn.Linear(128, 256)]
models += [nn.Linear(256, 256)]
self.models = paddle.nn.Sequential(*models)
def forward(self, inputs):
return self.models(inputs)
class ModelLinear1(nn.Layer):
def __init__(self):
super(ModelLinear1, self).__init__()
......@@ -373,5 +389,40 @@ class TestOFACase4(unittest.TestCase):
self.model = ModelConv2()
class TestExport(unittest.TestCase):
def setUp(self):
self._init_model()
def _init_model(self):
self.origin_model = ModelOriginLinear()
model = ModelLinear()
self.ofa_model = OFA(model)
def test_ofa(self):
config = {
'embedding_1': {
'expand_ratio': (2.0)
},
'linear_3': {
'expand_ratio': (2.0)
},
'linear_4': {},
'linear_5': {}
}
origin_dict = {}
for name, param in self.origin_model.named_parameters():
origin_dict[name] = param.shape
self.ofa_model.export(
self.origin_model,
config,
input_shapes=[[1, 64]],
input_dtypes=['int64'])
for name, param in self.origin_model.named_parameters():
if name in config.keys():
if 'expand_ratio' in config[name]:
assert origin_dict[name][-1] == param.shape[-1] * config[
name]['expand_ratio']
if __name__ == '__main__':
unittest.main()
......@@ -20,7 +20,8 @@ import paddle
import paddle.nn as nn
from paddle.vision.models import mobilenet_v1
from paddleslim.nas.ofa.convert_super import Convert, supernet
from paddleslim.nas.ofa.utils import compute_neuron_head_importance, reorder_head, reorder_neuron, set_state_dict, dynabert_config
from paddleslim.nas.ofa.utils import set_state_dict, dynabert_config
from paddleslim.nas.ofa.utils.nlp_utils import compute_neuron_head_importance, reorder_head, reorder_neuron
from paddleslim.nas.ofa import OFA
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册