From 1e12c3266776db0067bf9277451889d5a3ac4b37 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Wed, 9 Dec 2020 15:45:03 +0800 Subject: [PATCH] add for nlp (#529) * for bert --- paddleslim/nas/ofa/convert_super.py | 4 +- paddleslim/nas/ofa/ofa.py | 52 +++-- paddleslim/nas/ofa/utils/__init__.py | 5 + paddleslim/nas/ofa/utils/nlp_utils.py | 272 ++++++++++++++++++++++++++ paddleslim/nas/ofa/utils/utils.py | 56 +++++- tests/test_ofa.py | 27 ++- tests/test_ofa_utils.py | 129 ++++++++++++ 7 files changed, 515 insertions(+), 30 deletions(-) create mode 100644 paddleslim/nas/ofa/utils/nlp_utils.py create mode 100644 tests/test_ofa_utils.py diff --git a/paddleslim/nas/ofa/convert_super.py b/paddleslim/nas/ofa/convert_super.py index 4040b928..bb5eb1db 100644 --- a/paddleslim/nas/ofa/convert_super.py +++ b/paddleslim/nas/ofa/convert_super.py @@ -25,7 +25,7 @@ if pd_ver == 185: from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, LayerNorm, Embedding from .layers import * from . import layers - Layer = fluid.dygraph.nn.Layer + Layer = paddle.fluid.dygraph.Layer else: import paddle.nn as nn from paddle.nn import Conv2D, Conv2DTranspose, Linear, LayerNorm, Embedding @@ -35,7 +35,7 @@ else: _logger = get_logger(__name__, level=logging.INFO) -__all__ = ['supernet'] +__all__ = ['supernet', 'Convert'] WEIGHT_LAYER = ['conv', 'linear', 'embedding'] diff --git a/paddleslim/nas/ofa/ofa.py b/paddleslim/nas/ofa/ofa.py index af2ac29f..cbdda4b2 100644 --- a/paddleslim/nas/ofa/ofa.py +++ b/paddleslim/nas/ofa/ofa.py @@ -20,10 +20,10 @@ import paddle.fluid as fluid from .utils.utils import get_paddle_version pd_ver = get_paddle_version() if pd_ver == 185: - from .layers import BaseBlock, SuperConv2D + from .layers import BaseBlock, SuperConv2D, SuperLinear Layer = paddle.fluid.dygraph.Layer else: - from .layers_new import BaseBlock, SuperConv2D + from .layers_new import BaseBlock, SuperConv2D, SuperLinear Layer = paddle.nn.Layer from .utils.utils import search_idx from ...common import get_logger @@ -40,7 +40,7 @@ RunConfig.__new__.__defaults__ = (None, ) * len(RunConfig._fields) DistillConfig = namedtuple('DistillConfig', [ 'lambda_distill', 'teacher_model', 'mapping_layers', 'teacher_model_path', - 'distill_fn' + 'distill_fn', 'mapping_op' ]) DistillConfig.__new__.__defaults__ = (None, ) * len(DistillConfig._fields) @@ -193,12 +193,28 @@ class OFA(OFABase): self.netAs = [] for name, sublayer in self.model.named_sublayers(): if name in mapping_layers: - netA = SuperConv2D( - getattr(sublayer, '_num_filters', - sublayer._out_channels), - getattr(sublayer, '_num_filters', - sublayer._out_channels), 1) - self.netAs_param.extend(netA.parameters()) + if self.distill_config.mapping_op != None: + if self.distill_config.mapping_op.lower() == 'conv2d': + netA = SuperConv2D( + getattr(sublayer, '_num_filters', + sublayer._out_channels), + getattr(sublayer, '_num_filters', + sublayer._out_channels), 1) + elif self.distill_config.mapping_op.lower() == 'linear': + netA = SuperLinear( + getattr(sublayer, '_output_dim', + sublayer._out_features), + getattr(sublayer, '_output_dim', + sublayer._out_features)) + else: + raise NotImplementedError( + "Not Support Op: {}".format( + self.distill_config.mapping_op.lower())) + else: + netA = None + + if netA != None: + self.netAs_param.extend(netA.parameters()) self.netAs.append(netA) def get_activation(mem, name): @@ -289,16 +305,24 @@ class OFA(OFABase): losses = [] assert len(self.netAs) > 0 for i, netA in enumerate(self.netAs): - assert isinstance(netA, SuperConv2D) n = 
self.distill_config.mapping_layers[i] Tact = self.Tacts[n] Sact = self.Sacts[n] - Sact = netA( - Sact, channel=getattr(netA, '_num_filters', netA._out_channels)) + if isinstance(netA, SuperConv2D): + Sact = netA( + Sact, + channel=getattr(netA, '_num_filters', netA._out_channels)) + elif isinstance(netA, SuperLinear): + Sact = netA( + Sact, + channel=getattr(netA, '_output_dim', netA._out_features)) + else: + Sact = Sact + if self.distill_config.distill_fn == None: - loss = fluid.layers.mse_loss(Sact, Tact) + loss = fluid.layers.mse_loss(Sact, Tact.detach()) else: - loss = distill_fn(Sact, Tact) + loss = distill_fn(Sact, Tact.detach()) losses.append(loss) return sum(losses) * self.distill_config.lambda_distill diff --git a/paddleslim/nas/ofa/utils/__init__.py b/paddleslim/nas/ofa/utils/__init__.py index 342ae0ed..2a70169b 100644 --- a/paddleslim/nas/ofa/utils/__init__.py +++ b/paddleslim/nas/ofa/utils/__init__.py @@ -13,3 +13,8 @@ # limitations under the License. from .utils import * + +from .utils import get_paddle_version +pd_ver = get_paddle_version() +if pd_ver == 200: + from .nlp_utils import * diff --git a/paddleslim/nas/ofa/utils/nlp_utils.py b/paddleslim/nas/ofa/utils/nlp_utils.py new file mode 100644 index 00000000..598b1a4e --- /dev/null +++ b/paddleslim/nas/ofa/utils/nlp_utils.py @@ -0,0 +1,272 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +__all__ = ["compute_neuron_head_importance", "reorder_head", "reorder_neuron"] + + +def compute_neuron_head_importance(task_name, + model, + data_loader, + num_layers, + num_heads, + loss_fct=nn.loss.CrossEntropyLoss(), + intermediate_name='linear1', + output_name='linear2'): + """ + Compute the importance of multi-head attention and feed-forward neuron in each transformer layer. + + Args: + task_name(str): task name. + model(paddle.nn.Layer): the instance of transformer model. + data_loader(DataLoader): An iterable data loader is used for evaluate. An instance of `paddle.io.Dataloader`. + num_layers(int): number of transformer layers. + num_heads(int): number of heads in each multi-head attention. + loss_fct(Loss|optional): loss function can be a `paddle.nn.Layer` instance. Default: `nn.loss.CrossEntropyLoss()`. + intermediate_name(str|optional): the name of intermediate `Linear` layer in feed-forward. Default: `linear1`. + output_name(str|optional): the name of output `Linear` layer in feed-forward. Default: `linear2`. 
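+        Returns:
+            tuple: ``(head_importance, neuron_importance)``. ``head_importance`` is a paddle
+                Tensor with shape ``[num_layers, num_heads]``; ``neuron_importance`` is a list
+                containing one numpy array of per-neuron scores for each layer's feed-forward.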
+ """ + head_importance = paddle.zeros( + shape=[num_layers, num_heads], dtype='float32') + head_mask = paddle.ones(shape=[num_layers, num_heads], dtype='float32') + head_mask.stop_gradient = False + + intermediate_weight = [] + intermediate_bias = [] + output_weight = [] + + for name, w in model.named_parameters(): + if intermediate_name in name: + if len(w.shape) > 1: + intermediate_weight.append(w) + else: + intermediate_bias.append(w) + + if output_name in name: + if len(w.shape) > 1: + output_weight.append(w) + + neuron_importance = [] + for w in intermediate_weight: + neuron_importance.append(np.zeros(shape=[w.shape[1]], dtype='float32')) + + for batch in data_loader: + input_ids, segment_ids, labels = batch + logits = model(input_ids, segment_ids, attention_mask=[None, head_mask]) + loss = loss_fct(logits, labels) + loss.backward() + head_importance += paddle.abs(paddle.to_tensor(head_mask.gradient())) + + for w1, b1, w2, current_importance in zip( + intermediate_weight, intermediate_bias, output_weight, + neuron_importance): + current_importance += np.abs( + (np.sum(w1.numpy() * w1.gradient(), axis=0) + b1.numpy() * + b1.gradient())) + current_importance += np.abs( + np.sum(w2.numpy() * w2.gradient(), axis=1)) + + return head_importance, neuron_importance + + +def reorder_head(layer, index): + """ + Reorder head weights according index. + + Args: + layer(paddle.nn.Layer): the instance of `paddle.nn.MultiHeadAttention` layer. + index(list): the sort indices of multi-head. + """ + assert isinstance(layer, nn.MultiHeadAttention), \ + "layer in reorder_head must be the instance of `paddle.nn.MultiHeadAttention`." + n, a = layer.num_heads, layer.head_dim + idx = paddle.reshape( + paddle.index_select( + paddle.reshape( + paddle.arange( + 0, n * a, dtype='int64'), shape=[n, a]), + index=index, + axis=0), + shape=[-1]) + + def reorder_head_matrix(linearLayer, index, dim=1): + W = paddle.index_select(linearLayer.weight, index, axis=dim).detach() + if linearLayer.bias is not None: + if dim == 0: + b = paddle.assign(linearLayer.bias).detach() + else: + b = paddle.assign( + paddle.index_select( + linearLayer.bias, index, axis=0)).detach() + + linearLayer.weight.stop_gradient = True + linearLayer.weight.set_value(W) + linearLayer.weight.stop_gradient = False + if linearLayer.bias is not None: + linearLayer.bias.stop_gradient = True + linearLayer.bias.set_value(b) + linearLayer.bias.stop_gradient = False + + reorder_head_matrix( + layer.q_proj.fn if hasattr(layer.q_proj, 'fn') else layer.q_proj, idx) + reorder_head_matrix( + layer.k_proj.fn if hasattr(layer.k_proj, 'fn') else layer.k_proj, idx) + reorder_head_matrix( + layer.v_proj.fn if hasattr(layer.v_proj, 'fn') else layer.v_proj, idx) + reorder_head_matrix( + layer.out_proj.fn if hasattr(layer.out_proj, 'fn') else layer.out_proj, + idx, + dim=0) + + +def reorder_neuron(layer, index, dim=0): + """ + Reorder feed-forward weights according index. + + Args: + layer(paddle.nn.Layer): the instance of `paddle.nn.Linear` layer. + index(list): the sort indices of feed-forward. + dim(int): select weights according to the dim. 
+ """ + linearLayer = layer.fn if hasattr(layer, 'fn') else layer + W = paddle.index_select(linearLayer.weight, index, axis=dim).detach() + if linearLayer.bias is not None: + if dim == 0: + b = paddle.assign(linearLayer.bias).detach() + else: + b = paddle.assign( + paddle.index_select( + linearLayer.bias, index, axis=0)).detach() + linearLayer.weight.stop_gradient = True + linearLayer.weight.set_value(W) + linearLayer.weight.stop_gradient = False + + if linearLayer.bias is not None: + linearLayer.bias.stop_gradient = True + linearLayer.bias.set_value(b) + linearLayer.bias.stop_gradient = False + + +### monkey patch for MultiHeadAttention _prepare_qkv to change num_heads. +def _prepare_qkv(self, query, key, value, cache=None): + q = self.q_proj(query) + if hasattr(self.q_proj, + 'fn') and self.q_proj.fn.cur_config['expand_ratio'] != None: + self.num_heads = int(self.num_heads * + self.q_proj.fn.cur_config['expand_ratio']) + q = paddle.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) + q = paddle.transpose(x=q, perm=[0, 2, 1, 3]) + + if isinstance(cache, self.StaticCache): + # for encoder-decoder attention in inference and has cached + k, v = cache.k, cache.v + else: + k, v = self.compute_kv(key, value) + + if isinstance(cache, self.Cache): + # for decoder self-attention in inference + k = paddle.concat([cache.k, k], axis=2) + v = paddle.concat([cache.v, v], axis=2) + cache = self.Cache(k, v) + + return (q, k, v) if cache is None else (q, k, v, cache) + + +### monkey patch for MultiHeadAttention forward to accept head_mask +### attn_mask[0] = attn_mask, attn_mask[1] = head_mask +def _mha_forward(self, query, key, value, attn_mask=None, cache=None): + key = query if key is None else key + value = query if value is None else value + # compute q ,k ,v + if cache is None: + q, k, v = self._prepare_qkv(query, key, value, cache) + else: + q, k, v, cache = self._prepare_qkv(query, key, value, cache) + + # scale dot product attention + # TODO: use paddle.matmul, however it doesn't support `alpha` + product = paddle.fluid.layers.matmul( + x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5) + if attn_mask[0] is not None: + # TODO(guosheng): support bool mask + product = product + attn_mask[0] + weights = F.softmax(product) + if self.dropout: + weights = F.dropout( + weights, + self.dropout, + training=self.training, + mode="upscale_in_train") + + if attn_mask[1] is not None: + weights = weights * attn_mask[1] + + out = paddle.matmul(weights, v) + + # combine heads + out = paddle.transpose(out, perm=[0, 2, 1, 3]) + out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + + # project to output + out = self.out_proj(out) + + outs = [out] + if self.need_weights: + outs.append(weights) + if cache is not None: + outs.append(cache) + + if hasattr(self.q_proj, + 'fn') and self.q_proj.fn.cur_config['expand_ratio'] != None: + self.num_heads = int( + float(self.num_heads) / self.q_proj.fn.cur_config['expand_ratio']) + return out if len(outs) == 1 else tuple(outs) + + +### monkey patch for TransformerEncoder forward to accept head_mask +### attn_mask[0] = attn_mask, attn_mask[1] = head_mask +def _encoder_forward(self, src, src_mask=[None, None]): + output = src + if src_mask[1] is not None: + head_mask = src_mask[1] + if len(head_mask.shape) == 1: + head_mask = paddle.unsqueeze( + paddle.unsqueeze( + paddle.unsqueeze(paddle.unsqueeze(head_mask, 0), 0), -1), + -1) + head_mask = paddle.expand( + head_mask, shape=[self.num_layers] + head_mask.shape[1:]) + elif len(head_mask.shape) == 2: + 
head_mask = paddle.unsqueeze( + paddle.unsqueeze(paddle.unsqueeze(head_mask, 1), -1), -1) + else: + head_mask = [None] * self.num_layers + + for i, mod in enumerate(self.layers): + output = mod(output, src_mask=[src_mask[0], head_mask[i]]) + + if self.norm is not None: + output = self.norm(output) + + return output + + +nn.MultiHeadAttention.forward = _mha_forward +nn.MultiHeadAttention._prepare_qkv = _prepare_qkv +nn.TransformerEncoder.forward = _encoder_forward diff --git a/paddleslim/nas/ofa/utils/utils.py b/paddleslim/nas/ofa/utils/utils.py index 70425aeb..a4ec2f88 100644 --- a/paddleslim/nas/ofa/utils/utils.py +++ b/paddleslim/nas/ofa/utils/utils.py @@ -12,6 +12,52 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import paddle +from ....common import get_logger + + +def get_paddle_version(): + import paddle + pd_ver = 185 + if hasattr(paddle, 'nn'): + if hasattr(paddle.nn, 'Conv1D'): ### judge 2.0 alpha + pd_ver = 200 + + return pd_ver + + +pd_ver = get_paddle_version() +if pd_ver == 185: + Layer = paddle.fluid.dygraph.Layer +else: + Layer = paddle.nn.Layer + +_logger = get_logger(__name__, level=logging.INFO) + +__all__ = ['set_state_dict'] + + +def set_state_dict(model, state_dict): + """ + Set state dict from origin model to supernet model. + + Args: + model(paddle.nn.Layer): model after convert to supernet. + state_dict(dict): dict with the type of {name: param} in origin model. + """ + assert isinstance(model, Layer) + assert isinstance(state_dict, dict) + for name, param in model.named_parameters(): + tmp_n = name.split('.')[:-2] + [name.split('.')[-1]] + tmp_n = '.'.join(tmp_n) + if name in state_dict: + param.set_value(state_dict[name]) + elif tmp_n in state_dict: + param.set_value(state_dict[tmp_n]) + else: + _logger.info('{} is not in state_dict'.format(tmp_n)) + def compute_start_end(kernel_size, sub_kernel_size): center = kernel_size // 2 @@ -44,13 +90,3 @@ def search_idx(num, sorted_nestlist): return idx, phase_idx assert num > max_num return len(sorted_nestlist) - 1, max_idx - - -def get_paddle_version(): - import paddle - pd_ver = 185 - if hasattr(paddle, 'nn'): - if hasattr(paddle.nn, 'Conv1D'): ### judge 2.0 alpha - pd_ver = 200 - - return pd_ver diff --git a/tests/test_ofa.py b/tests/test_ofa.py index 7e95ae92..e9ae2559 100644 --- a/tests/test_ofa.py +++ b/tests/test_ofa.py @@ -243,7 +243,8 @@ class TestOFA(unittest.TestCase): default_distill_config = { 'lambda_distill': 0.01, 'teacher_model': self.teacher_model, - 'mapping_layers': ['models.0.fn'] + 'mapping_layers': ['models.0.fn'], + 'mapping_op': 'conv2d' } self.distill_config = DistillConfig(**default_distill_config) self.elastic_order = ['kernel_size', 'width', 'depth'] @@ -289,7 +290,6 @@ class TestOFACase1(TestOFA): self.model = ModelLinear() self.teacher_model = ModelLinear() data_np = np.random.random((3, 64)).astype(np.int64) - self.data = paddle.to_tensor(data_np) def init_config(self): @@ -305,12 +305,14 @@ class TestOFACase1(TestOFA): default_distill_config = { 'lambda_distill': 0.01, 'teacher_model': self.teacher_model, + 'mapping_op': 'linear', + 'mapping_layers': ['models.3.fn'], } self.distill_config = DistillConfig(**default_distill_config) self.elastic_order = None -class TestOFACase2(TestOFACase1): +class TestOFACase2(TestOFA): def init_model_and_data(self): self.model = ModelLinear1() self.teacher_model = ModelLinear1() @@ -318,6 +320,23 @@ class TestOFACase2(TestOFACase1): self.data = paddle.to_tensor(data_np) + def 
init_config(self): + default_run_config = { + 'train_batch_size': 1, + 'n_epochs': [[2, 5]], + 'init_learning_rate': [[0.003, 0.001]], + 'dynamic_batch_size': [1], + 'total_images': 1, + } + self.run_config = RunConfig(**default_run_config) + default_distill_config = { + 'lambda_distill': 0.01, + 'teacher_model': self.teacher_model, + 'mapping_layers': ['models.3.fn'], + } + self.distill_config = DistillConfig(**default_distill_config) + self.elastic_order = None + class TestOFACase3(unittest.TestCase): def test_ofa(self): @@ -326,7 +345,7 @@ class TestOFACase3(unittest.TestCase): ofa_model.set_net_config({'expand_ratio': None}) -class TestOFACase3(unittest.TestCase): +class TestOFACase4(unittest.TestCase): def test_ofa(self): self.model = ModelConv2() diff --git a/tests/test_ofa_utils.py b/tests/test_ofa_utils.py new file mode 100644 index 00000000..e368774a --- /dev/null +++ b/tests/test_ofa_utils.py @@ -0,0 +1,129 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +sys.path.append("../") +import unittest +import numpy as np +import paddle +import paddle.nn as nn +from paddle.vision.models import mobilenet_v1 +from paddleslim.nas.ofa.convert_super import Convert, supernet +from paddleslim.nas.ofa.utils import compute_neuron_head_importance, reorder_head, reorder_neuron, set_state_dict + + +class TestComputeImportance(unittest.TestCase): + def setUp(self): + self.model = self.init_model() + self.data_loader = self.init_data() + + def init_model(self): + class TestModel(nn.Layer): + def __init__(self): + super(TestModel, self).__init__() + encoder_layer = nn.TransformerEncoderLayer( + 312, + 12, + 1024, + dropout=0.1, + activation='gelu', + attn_dropout=0.1, + act_dropout=0) + self.encoder = nn.TransformerEncoder(encoder_layer, 3) + self.fc = nn.Linear(312, 3) + + def forward(self, + input_ids, + segment_ids, + attention_mask=[None, None]): + src = input_ids + segment_ids + out = self.encoder(src, attention_mask) + out = self.fc(out[:, 0]) + return out + + return TestModel() + + def init_data(self): + batch_size = 16 + hidden_size = 312 + d_model = 26 + input_ids = np.random.rand(batch_size, d_model, + hidden_size).astype("float32") + segment_ids = np.random.rand(batch_size, d_model, + hidden_size).astype("float32") + labels = np.random.randint(0, high=3, size=(batch_size, 1)) + data = ((paddle.to_tensor(input_ids), paddle.to_tensor(segment_ids), + paddle.to_tensor(labels)), ) + return data + + def reorder_reorder_neuron_head(self, model, head_importance, + neuron_importance): + # reorder heads and ffn neurons + for layer, current_importance in enumerate(neuron_importance): + # reorder heads + idx = paddle.argsort(head_importance[layer], descending=True) + reorder_head(model.encoder.layers[layer].self_attn, idx) + # reorder neurons + idx = paddle.argsort( + paddle.to_tensor(current_importance), descending=True) + reorder_neuron(model.encoder.layers[layer].linear1, idx, dim=1) + 
reorder_neuron(model.encoder.layers[layer].linear2, idx, dim=0) + + def test_compute(self): + head_importance, neuron_importance = compute_neuron_head_importance( + task_name='xnli', + model=self.model, + data_loader=self.data_loader, + num_layers=3, + num_heads=12) + assert (len(head_importance) == 3) + assert (len(neuron_importance) == 3) + self.reorder_reorder_neuron_head(self.model, head_importance, + neuron_importance) + + +class TestComputeImportanceCase1(TestComputeImportance): + def test_compute(self): + for batch in self.data_loader: + input_ids, segment_ids, labels = batch + logits = self.model( + input_ids, segment_ids, attention_mask=[None, None]) + assert logits.shape[1] == 3 + + +class TestComputeImportanceCase2(TestComputeImportance): + def test_compute(self): + head_mask = paddle.ones(shape=[12], dtype='float32') + for batch in self.data_loader: + input_ids, segment_ids, labels = batch + logits = self.model( + input_ids, segment_ids, attention_mask=[None, head_mask]) + assert logits.shape[1] == 3 + + +class TestSetStateDict(unittest.TestCase): + def setUp(self): + self.model = mobilenet_v1() + self.origin_weights = {} + for name, param in self.model.named_parameters(): + self.origin_weights[name] = param + + def test_set_state_dict(self): + sp_net_config = supernet(expand_ratio=[0.5, 1.0]) + sp_model = Convert(sp_net_config).convert(self.model) + set_state_dict(sp_model, self.origin_weights) + + +if __name__ == '__main__': + unittest.main() -- GitLab
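
For reviewers, the sketch below pulls the new pieces together: the `mapping_op` field added to `DistillConfig`, the head/neuron importance utilities in `nlp_utils.py`, and `set_state_dict` from `utils.py`. It is assembled from the new unit tests above and is not part of the patch; `build_model()`, `train_loader`, the `mapping_layers` entry, and the `task_name`/size values are illustrative placeholders, and the `OFA`/`RunConfig`/`DistillConfig` imports follow the usage in `tests/test_ofa.py`.

import paddle
from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
from paddleslim.nas.ofa.convert_super import Convert, supernet
from paddleslim.nas.ofa.utils import (compute_neuron_head_importance,
                                      reorder_head, reorder_neuron,
                                      set_state_dict)

# Placeholders: any paddle.nn.Layer wrapping an nn.TransformerEncoder, and a
# data loader yielding (input_ids, segment_ids, labels) batches.
model = build_model()
teacher_model = build_model()

# 1. Rank attention heads and FFN neurons, then reorder them so the most
#    important units come first and survive width shrinking.
head_importance, neuron_importance = compute_neuron_head_importance(
    task_name='xnli',
    model=model,
    data_loader=train_loader,
    num_layers=3,
    num_heads=12)
for layer, importance in enumerate(neuron_importance):
    idx = paddle.argsort(head_importance[layer], descending=True)
    reorder_head(model.encoder.layers[layer].self_attn, idx)
    idx = paddle.argsort(paddle.to_tensor(importance), descending=True)
    reorder_neuron(model.encoder.layers[layer].linear1, idx, dim=1)
    reorder_neuron(model.encoder.layers[layer].linear2, idx, dim=0)

# 2. Convert the reordered model into a supernet and copy the original weights
#    into it (set_state_dict matches names across the added '.fn' level).
origin_weights = dict(model.named_parameters())
sp_model = Convert(supernet(expand_ratio=[0.5, 1.0])).convert(model)
set_state_dict(sp_model, origin_weights)

# 3. Distill through SuperLinear mapping layers via the new mapping_op option;
#    the mapping_layers entry below is illustrative.
run_config = RunConfig(train_batch_size=1, n_epochs=[[2, 5]],
                       init_learning_rate=[[0.003, 0.001]],
                       dynamic_batch_size=[1], total_images=1)
distill_config = DistillConfig(lambda_distill=0.1,
                               teacher_model=teacher_model,
                               mapping_layers=['encoder.layers.0.self_attn.q_proj.fn'],
                               mapping_op='linear')
ofa_model = OFA(sp_model, run_config=run_config, distill_config=distill_config)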