diff --git a/demo/ofa/bert/export_model.py b/demo/ofa/bert/export_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9763ef14fc75bb53f87737e13eb00956fa1d559c
--- /dev/null
+++ b/demo/ofa/bert/export_model.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import os
+import random
+import time
+import json
+from functools import partial
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from paddlenlp.transformers import BertModel, BertForSequenceClassification, BertTokenizer
+from paddlenlp.utils.log import logger
+from paddleslim.nas.ofa import OFA, utils
+from paddleslim.nas.ofa.convert_super import Convert, supernet
+from paddleslim.nas.ofa.layers import BaseBlock
+
+MODEL_CLASSES = {"bert": (BertForSequenceClassification, BertTokenizer), }
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    # Required parameters
+    parser.add_argument(
+        "--model_type",
+        default=None,
+        type=str,
+        required=True,
+        help="Model type selected in the list: " +
+        ", ".join(MODEL_CLASSES.keys()), )
+    parser.add_argument(
+        "--model_name_or_path",
+        default=None,
+        type=str,
+        required=True,
+        help="Path to pre-trained model or shortcut name selected in the list: "
+        + ", ".join(
+            sum([
+                list(classes[-1].pretrained_init_configuration.keys())
+                for classes in MODEL_CLASSES.values()
+            ], [])), )
+    parser.add_argument(
+        "--sub_model_output_dir",
+        default=None,
+        type=str,
+        help="The output directory where the sub model predictions and checkpoints will be written.",
+    )
+    parser.add_argument(
+        "--static_sub_model",
+        default=None,
+        type=str,
+        help="The output directory where the static sub model will be written. If set to None, the static model is not exported.",
+    )
+    parser.add_argument(
+        "--max_seq_length",
+        default=128,
+        type=int,
+        help="The maximum total input sequence length after tokenization. Sequences longer "
+        "than this will be truncated, sequences shorter will be padded.", )
+    parser.add_argument(
+        "--n_gpu",
+        type=int,
+        default=1,
+        help="number of gpus to use, 0 for cpu.")
+    parser.add_argument(
+        '--width_mult',
+        type=float,
+        default=1.0,
+        help="width mult of the sub model to export")
+    args = parser.parse_args()
+    return args
+
+
+def export_static_model(model, model_path, max_seq_length):
+    input_shape = [
+        paddle.static.InputSpec(
+            shape=[None, max_seq_length], dtype='int64'),
+        paddle.static.InputSpec(
+            shape=[None, max_seq_length], dtype='int64')
+    ]
+    net = paddle.jit.to_static(model, input_spec=input_shape)
+    paddle.jit.save(net, model_path)
+
+
+def do_train(args):
+    paddle.set_device("gpu" if args.n_gpu else "cpu")
+    args.model_type = args.model_type.lower()
+    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+    config_path = os.path.join(args.model_name_or_path, 'model_config.json')
+    with open(config_path) as f:
+        cfg_dict = dict(json.load(f))
+    num_labels = cfg_dict['num_classes']
+
+    model = model_class.from_pretrained(
+        args.model_name_or_path, num_classes=num_labels)
+
+    origin_model = model_class.from_pretrained(
+        args.model_name_or_path, num_classes=num_labels)
+
+    sp_config = supernet(expand_ratio=[1.0, args.width_mult])
+    model = Convert(sp_config).convert(model)
+
+    ofa_model = OFA(model)
+
+    sd = paddle.load(
+        os.path.join(args.model_name_or_path, 'model_state.pdparams'))
+    ofa_model.model.set_state_dict(sd)
+    best_config = utils.dynabert_config(ofa_model, args.width_mult)
+    ofa_model.export(
+        origin_model,
+        best_config,
+        input_shapes=[[1, args.max_seq_length], [1, args.max_seq_length]],
+        input_dtypes=['int64', 'int64'])
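+    # ofa_model.export() prunes the attention weights when width_mult < 1.0,
+    # so num_heads of every MultiHeadAttention must be scaled down to match
+    # the exported sub model before the static-graph export below.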
+    for name, sublayer in origin_model.named_sublayers():
+        if isinstance(sublayer, paddle.nn.MultiHeadAttention):
+            sublayer.num_heads = int(args.width_mult * sublayer.num_heads)
+
+    if args.static_sub_model is not None:
+        export_static_model(origin_model, args.static_sub_model,
+                            args.max_seq_length)
+
+
+def print_arguments(args):
+    """print arguments"""
+    print('----------- Configuration Arguments -----------')
+    for arg, value in sorted(vars(args).items()):
+        print('%s: %s' % (arg, value))
+    print('------------------------------------------------')
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    print_arguments(args)
+    do_train(args)
diff --git a/demo/ofa/bert/run_glue_ofa.py b/demo/ofa/bert/run_glue_ofa.py
index ef35e041b10276afa059d993d708feff2eb03dc3..1618ee586984625925e55c601fb5e066586ab81c 100644
--- a/demo/ofa/bert/run_glue_ofa.py
+++ b/demo/ofa/bert/run_glue_ofa.py
@@ -31,6 +31,7 @@ from paddlenlp.utils.log import logger
 from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman
 import paddlenlp.datasets as datasets
 from paddleslim.nas.ofa import OFA, DistillConfig, utils
+from paddleslim.nas.ofa.utils import nlp_utils
 from paddleslim.nas.ofa.convert_super import Convert, supernet
 
 TASK_CLASSES = {
@@ -215,13 +216,13 @@ def reorder_neuron_head(model, head_importance, neuron_importance):
     for layer, current_importance in enumerate(neuron_importance):
         # reorder heads
         idx = paddle.argsort(head_importance[layer], descending=True)
-        utils.reorder_head(model.bert.encoder.layers[layer].self_attn, idx)
+        nlp_utils.reorder_head(model.bert.encoder.layers[layer].self_attn, idx)
         # reorder neurons
         idx = paddle.argsort(
             paddle.to_tensor(current_importance), descending=True)
-        utils.reorder_neuron(
+        nlp_utils.reorder_neuron(
             model.bert.encoder.layers[layer].linear1.fn, idx, dim=1)
-        utils.reorder_neuron(
+        nlp_utils.reorder_neuron(
             model.bert.encoder.layers[layer].linear2.fn, idx, dim=0)
@@ -422,7 +423,7 @@ def do_train(args):
 
     # Step6: Calculate the importance of neurons and head,
     # and then reorder them according to the importance.
-    head_importance, neuron_importance = utils.compute_neuron_head_importance(
+    head_importance, neuron_importance = nlp_utils.compute_neuron_head_importance(
         args.task_name,
         ofa_model.model,
         dev_data_loader,
@@ -512,7 +513,7 @@ def do_train(args):
                          dev_data_loader,
                          width_mult=100)
             for idx, width_mult in enumerate(args.width_mult_list):
-                net_config = apply_config(ofa_model, width_mult)
+                net_config = utils.dynabert_config(ofa_model, width_mult)
                 ofa_model.set_net_config(net_config)
                 tic_eval = time.time()
                 if args.task_name == "mnli":
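For reference, the net_config consumed by ofa_model.set_net_config() above (and produced by
utils.dynabert_config) is a plain dict keyed by supernet block name. A minimal sketch of its
shape, with hypothetical block names and ratios rather than ones taken from a real model:

    # Hypothetical per-block sub-network config, as passed to set_net_config().
    net_config = {
        'embedding_0': {'expand_ratio': 1.0},  # embeddings stay at full width
        'linear_0': {'expand_ratio': 0.75},    # attention projections follow width_mult
        'linear_4': {'expand_ratio': 0.75},    # FFN width follows width_mult
    }
    ofa_model.set_net_config(net_config)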
diff --git a/demo/ofa/bert/run_glue_ofa_depth.py b/demo/ofa/bert/run_glue_ofa_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..b37e234cdb9598cf43817b925d63147d70667d3a
--- /dev/null
+++ b/demo/ofa/bert/run_glue_ofa_depth.py
@@ -0,0 +1,606 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import os
+import random
+import time
+import math
+from functools import partial
+
+import numpy as np
+import paddle
+import paddle.nn.functional as F
+from paddle.io import DataLoader
+from paddle.metric import Metric, Accuracy, Precision, Recall
+
+from paddlenlp.data import Stack, Tuple, Pad
+from paddlenlp.data.sampler import SamplerHelper
+from paddlenlp.transformers import BertModel, BertForSequenceClassification, BertTokenizer
+from paddlenlp.utils.log import logger
+from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman
+import paddlenlp.datasets as datasets
+from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig, utils
+from paddleslim.nas.ofa.utils import nlp_utils
+from paddleslim.nas.ofa.convert_super import Convert, supernet
+
+TASK_CLASSES = {
+    "cola": (datasets.GlueCoLA, Mcc),
+    "sst-2": (datasets.GlueSST2, Accuracy),
+    "mrpc": (datasets.GlueMRPC, AccuracyAndF1),
+    "sts-b": (datasets.GlueSTSB, PearsonAndSpearman),
+    "qqp": (datasets.GlueQQP, AccuracyAndF1),
+    "mnli": (datasets.GlueMNLI, Accuracy),
+    "qnli": (datasets.GlueQNLI, Accuracy),
+    "rte": (datasets.GlueRTE, Accuracy),
+}
+
+MODEL_CLASSES = {"bert": (BertForSequenceClassification, BertTokenizer), }
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    # Required parameters
+    parser.add_argument(
+        "--task_name",
+        default=None,
+        type=str,
+        required=True,
+        help="The name of the task to train selected in the list: " +
+        ", ".join(TASK_CLASSES.keys()), )
+    parser.add_argument(
+        "--model_type",
+        default=None,
+        type=str,
+        required=True,
+        help="Model type selected in the list: " +
+        ", ".join(MODEL_CLASSES.keys()), )
+    parser.add_argument(
+        "--model_name_or_path",
+        default=None,
+        type=str,
+        required=True,
+        help="Path to pre-trained model or shortcut name selected in the list: "
+        + ", ".join(
+            sum([
+                list(classes[-1].pretrained_init_configuration.keys())
+                for classes in MODEL_CLASSES.values()
+            ], [])), )
+    parser.add_argument(
+        "--output_dir",
+        default=None,
+        type=str,
+        required=True,
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument(
+        "--max_seq_length",
+        default=128,
+        type=int,
+        help="The maximum total input sequence length after tokenization. Sequences longer "
+        "than this will be truncated, sequences shorter will be padded.", )
+    parser.add_argument(
+        "--batch_size",
+        default=8,
+        type=int,
+        help="Batch size per GPU/CPU for training.", )
+    parser.add_argument(
+        "--learning_rate",
+        default=5e-5,
+        type=float,
+        help="The initial learning rate for Adam.")
+    parser.add_argument(
+        "--weight_decay",
+        default=0.0,
+        type=float,
+        help="Weight decay if we apply some.")
+    parser.add_argument(
+        "--adam_epsilon",
+        default=1e-8,
+        type=float,
+        help="Epsilon for Adam optimizer.")
+    parser.add_argument(
+        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+    parser.add_argument(
+        "--lambda_logit",
+        default=1.0,
+        type=float,
+        help="lambda for logit loss.")
+    parser.add_argument(
+        "--lambda_rep",
+        default=0.1,
+        type=float,
+        help="lambda for hidden state distillation loss.")
+    parser.add_argument(
+        "--num_train_epochs",
+        default=3,
+        type=int,
+        help="Total number of training epochs to perform.", )
+    parser.add_argument(
+        "--max_steps",
+        default=-1,
+        type=int,
+        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
+    )
+    parser.add_argument(
+        "--warmup_steps",
+        default=0,
+        type=int,
+        help="Linear warmup over warmup_steps.")
+    parser.add_argument(
+        "--logging_steps",
+        type=int,
+        default=500,
+        help="Log every X updates steps.")
+    parser.add_argument(
+        "--save_steps",
+        type=int,
+        default=500,
+        help="Save checkpoint every X updates steps.")
+    parser.add_argument(
+        "--seed", type=int, default=42, help="random seed for initialization")
+    parser.add_argument(
+        "--n_gpu",
+        type=int,
+        default=1,
+        help="number of gpus to use, 0 for cpu.")
+    parser.add_argument(
+        '--width_mult_list',
+        nargs='+',
+        type=float,
+        default=[1.0, 5 / 6, 2 / 3, 0.5],
+        help="width mult list in compress")
+    parser.add_argument(
+        '--depth_mult_list',
+        nargs='+',
+        type=float,
+        default=[1.0, 0.75, 0.5],
+        help="depth mult list in compress")
+    args = parser.parse_args()
+    return args
+
+
+def set_seed(args):
+    random.seed(args.seed + paddle.distributed.get_rank())
+    np.random.seed(args.seed + paddle.distributed.get_rank())
+    paddle.seed(args.seed + paddle.distributed.get_rank())
+
+
+def evaluate(model,
+             criterion,
+             metric,
+             data_loader,
+             width_mult=1.0,
+             depth_mult=1.0):
+    with paddle.no_grad():
+        model.eval()
+        metric.reset()
+        for batch in data_loader:
+            input_ids, segment_ids, labels = batch
+            logits = model(input_ids, segment_ids, attention_mask=[None, None])
+            if isinstance(logits, tuple):
+                logits = logits[0]
+            loss = criterion(logits, labels)
+            correct = metric.compute(logits, labels)
+            metric.update(correct)
+        results = metric.accumulate()
+        print(
+            "depth_mult: %f, width_mult: %f, eval loss: %f, %s: %s\n" %
+            (depth_mult, width_mult, loss.numpy(), metric.name(), results),
+            end='')
+        model.train()
+
+
+### monkey patch for bert forward to accept [attention_mask, head_mask] as attention_mask
+def bert_forward(self,
+                 input_ids,
+                 token_type_ids=None,
+                 position_ids=None,
+                 attention_mask=None,
+                 depth_mult=1.0):
+    # Default to None rather than a mutable list: attention_mask is mutated
+    # below, so a shared list default would leak state between calls.
+    if attention_mask is None:
+        attention_mask = [None, None]
+    wtype = self.pooler.dense.fn.weight.dtype if hasattr(
+        self.pooler.dense, 'fn') else self.pooler.dense.weight.dtype
+    if attention_mask[0] is None:
+        attention_mask[0] = paddle.unsqueeze(
+            (input_ids == self.pad_token_id).astype(wtype) * -1e9, axis=[1, 2])
+    embedding_output = self.embeddings(
+        input_ids=input_ids,
+        position_ids=position_ids,
+        token_type_ids=token_type_ids)
+    encoder_outputs = self.encoder(
+        embedding_output, attention_mask, depth_mult=depth_mult)
+    sequence_output = encoder_outputs
+    pooled_output = self.pooler(sequence_output)
+    return sequence_output, pooled_output
+
+
+BertModel.forward = bert_forward
+
+
+def transformer_encoder_forward(self, src, src_mask=None, depth_mult=1.):
+    output = src
+
+    depth = round(self.num_layers * depth_mult)
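+    # Keep an evenly spaced subset of encoder layers. For num_layers = 12:
+    #   depth_mult = 0.75 keeps indices [0, 1, 3, 4, 5, 7, 8, 9, 11]
+    #   depth_mult = 0.5  keeps indices [1, 3, 5, 7, 9, 11]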
+    kept_layers_index = []
+    for i in range(1, depth + 1):
+        kept_layers_index.append(math.floor(i / depth_mult) - 1)
+
+    for i in kept_layers_index:
+        output = self.layers[i](output, src_mask=src_mask)
+
+    if self.norm is not None:
+        output = self.norm(output)
+
+    return output
+
+
+paddle.nn.TransformerEncoder.forward = transformer_encoder_forward
+
+
+def sequence_forward(self,
+                     input_ids,
+                     token_type_ids=None,
+                     position_ids=None,
+                     attention_mask=None,
+                     depth=1.0):
+    _, pooled_output = self.bert(
+        input_ids,
+        token_type_ids=token_type_ids,
+        position_ids=position_ids,
+        attention_mask=attention_mask,
+        depth_mult=depth)
+
+    pooled_output = self.dropout(pooled_output)
+    logits = self.classifier(pooled_output)
+    return logits
+
+
+BertForSequenceClassification.forward = sequence_forward
+
+
+def soft_cross_entropy(inp, target):
+    inp_likelihood = F.log_softmax(inp, axis=-1)
+    target_prob = F.softmax(target, axis=-1)
+    return -1. * paddle.mean(paddle.sum(inp_likelihood * target_prob, axis=-1))
+
+
+def convert_example(example,
+                    tokenizer,
+                    label_list,
+                    max_seq_length=512,
+                    is_test=False):
+    """convert a glue example into necessary features"""
+
+    def _truncate_seqs(seqs, max_seq_length):
+        if len(seqs) == 1:  # single sentence
+            # Account for [CLS] and [SEP] with "- 2"
+            seqs[0] = seqs[0][0:(max_seq_length - 2)]
+        else:  # sentence pair
+            # Account for [CLS], [SEP], [SEP] with "- 3"
+            tokens_a, tokens_b = seqs
+            max_seq_length -= 3
+            while True:  # truncate with longest_first strategy
+                total_length = len(tokens_a) + len(tokens_b)
+                if total_length <= max_seq_length:
+                    break
+                if len(tokens_a) > len(tokens_b):
+                    tokens_a.pop()
+                else:
+                    tokens_b.pop()
+        return seqs
+
+    def _concat_seqs(seqs, separators, seq_mask=0, separator_mask=1):
+        concat = sum((seq + sep for sep, seq in zip(separators, seqs)), [])
+        segment_ids = sum(([i] * (len(seq) + len(sep)) for i, (sep, seq) in
+                           enumerate(zip(separators, seqs))), [])
+        if isinstance(seq_mask, int):
+            seq_mask = [[seq_mask] * len(seq) for seq in seqs]
+        if isinstance(separator_mask, int):
+            separator_mask = [[separator_mask] * len(sep) for sep in separators]
+        p_mask = sum((s_mask + mask for sep, seq, s_mask, mask in
+                      zip(separators, seqs, seq_mask, separator_mask)), [])
+        return concat, segment_ids, p_mask
+
+    if not is_test:
+        # `label_list == None` is for regression task
+        label_dtype = "int64" if label_list else "float32"
+        # get the label
+        label = example[-1]
+        example = example[:-1]
+        # create label maps if classification task
+        if label_list:
+            label_map = {}
+            for (i, l) in enumerate(label_list):
+                label_map[l] = i
+            label = label_map[label]
+        label = np.array([label], dtype=label_dtype)
+
+    # tokenize raw text
+    tokens_raw = [tokenizer(l) for l in example]
+    # truncate to the truncate_length
+    tokens_trun = _truncate_seqs(tokens_raw, max_seq_length)
+    # concatenate the sequences with special tokens
+    tokens_trun[0] = [tokenizer.cls_token] + tokens_trun[0]
+    tokens, segment_ids, _ = _concat_seqs(tokens_trun, [[tokenizer.sep_token]] *
+                                          len(tokens_trun))
+    # convert the tokens to ids
+    input_ids = tokenizer.convert_tokens_to_ids(tokens)
+    valid_length = len(input_ids)
+    # The mask has 1 for real tokens and 0 for padding tokens. Only real
+    # tokens are attended to.
+    # input_mask = [1] * len(input_ids)
+    if not is_test:
+        return input_ids, segment_ids, valid_length, label
+    else:
+        return input_ids, segment_ids, valid_length
+
+
+def do_train(args):
+    paddle.set_device("gpu" if args.n_gpu else "cpu")
+    if paddle.distributed.get_world_size() > 1:
+        paddle.distributed.init_parallel_env()
+
+    set_seed(args)
+
+    args.task_name = args.task_name.lower()
+    dataset_class, metric_class = TASK_CLASSES[args.task_name]
+    args.model_type = args.model_type.lower()
+    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+
+    train_ds = dataset_class.get_datasets(['train'])
+
+    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
+    trans_func = partial(
+        convert_example,
+        tokenizer=tokenizer,
+        label_list=train_ds.get_labels(),
+        max_seq_length=args.max_seq_length)
+    train_ds = train_ds.apply(trans_func, lazy=True)
+    train_batch_sampler = paddle.io.DistributedBatchSampler(
+        train_ds, batch_size=args.batch_size, shuffle=True)
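+    # The collate function stacks each field, then drops field 2 (the sequence
+    # length): only input_ids, segment_ids and the label reach the model.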
+    batchify_fn = lambda samples, fn=Tuple(
+        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
+        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # segment
+        Stack(),  # length
+        Stack(dtype="int64" if train_ds.get_labels() else "float32")  # label
+    ): [data for i, data in enumerate(fn(samples)) if i != 2]
+    train_data_loader = DataLoader(
+        dataset=train_ds,
+        batch_sampler=train_batch_sampler,
+        collate_fn=batchify_fn,
+        num_workers=0,
+        return_list=True)
+    if args.task_name == "mnli":
+        dev_dataset_matched, dev_dataset_mismatched = dataset_class.get_datasets(
+            ["dev_matched", "dev_mismatched"])
+        dev_dataset_matched = dev_dataset_matched.apply(trans_func, lazy=True)
+        dev_dataset_mismatched = dev_dataset_mismatched.apply(
+            trans_func, lazy=True)
+        dev_batch_sampler_matched = paddle.io.BatchSampler(
+            dev_dataset_matched, batch_size=args.batch_size, shuffle=False)
+        dev_data_loader_matched = DataLoader(
+            dataset=dev_dataset_matched,
+            batch_sampler=dev_batch_sampler_matched,
+            collate_fn=batchify_fn,
+            num_workers=0,
+            return_list=True)
+        dev_batch_sampler_mismatched = paddle.io.BatchSampler(
+            dev_dataset_mismatched, batch_size=args.batch_size, shuffle=False)
+        dev_data_loader_mismatched = DataLoader(
+            dataset=dev_dataset_mismatched,
+            batch_sampler=dev_batch_sampler_mismatched,
+            collate_fn=batchify_fn,
+            num_workers=0,
+            return_list=True)
+    else:
+        dev_dataset = dataset_class.get_datasets(["dev"])
+        dev_dataset = dev_dataset.apply(trans_func, lazy=True)
+        dev_batch_sampler = paddle.io.BatchSampler(
+            dev_dataset, batch_size=args.batch_size, shuffle=False)
+        dev_data_loader = DataLoader(
+            dataset=dev_dataset,
+            batch_sampler=dev_batch_sampler,
+            collate_fn=batchify_fn,
+            num_workers=0,
+            return_list=True)
+
+    num_labels = 1 if train_ds.get_labels() is None else len(
+        train_ds.get_labels())
+
+    # Step1: Initialize the origin BERT model.
+    model = model_class.from_pretrained(
+        args.model_name_or_path, num_classes=num_labels)
+    if paddle.distributed.get_world_size() > 1:
+        model = paddle.DataParallel(model)
+
+    # Step2: Convert origin model to supernet.
+    sp_config = supernet(expand_ratio=args.width_mult_list)
+    model = Convert(sp_config).convert(model)
+
+    # Use weights saved in the dictionary to initialize supernet.
+    weights_path = os.path.join(args.model_name_or_path, 'model_state.pdparams')
+    origin_weights = paddle.load(weights_path)
+    model.set_state_dict(origin_weights)
+
+    # Step3: Define teacher model.
+    teacher_model = model_class.from_pretrained(
+        args.model_name_or_path, num_classes=num_labels)
+    new_dict = utils.utils.remove_model_fn(teacher_model, origin_weights)
+    teacher_model.set_state_dict(new_dict)
+    del origin_weights, new_dict
+
+    default_run_config = {'elastic_depth': args.depth_mult_list}
+    run_config = RunConfig(**default_run_config)
+
+    # Step4: Config about distillation.
+    mapping_layers = ['bert.embeddings']
+    for idx in range(model.bert.config['num_hidden_layers']):
+        mapping_layers.append('bert.encoder.layers.{}'.format(idx))
+
+    default_distill_config = {
+        'lambda_distill': args.lambda_rep,
+        'teacher_model': teacher_model,
+        'mapping_layers': mapping_layers,
+    }
+    distill_config = DistillConfig(**default_distill_config)
+
+    # Step5: Config in supernet training.
+    ofa_model = OFA(model,
+                    run_config=run_config,
+                    distill_config=distill_config,
+                    elastic_order=['depth'])
+
+    criterion = paddle.nn.loss.CrossEntropyLoss() if train_ds.get_labels(
+    ) else paddle.nn.loss.MSELoss()
+
+    metric = metric_class()
+
+    if args.task_name == "mnli":
+        dev_data_loader = (dev_data_loader_matched, dev_data_loader_mismatched)
+
+    lr_scheduler = paddle.optimizer.lr.LambdaDecay(
+        args.learning_rate,
+        lambda current_step, num_warmup_steps=args.warmup_steps,
+        num_training_steps=args.max_steps if args.max_steps > 0 else
+        (len(train_data_loader) * args.num_train_epochs): float(
+            current_step) / float(max(1, num_warmup_steps))
+        if current_step < num_warmup_steps else max(
+            0.0,
+            float(num_training_steps - current_step) / float(
+                max(1, num_training_steps - num_warmup_steps))))
+
+    optimizer = paddle.optimizer.AdamW(
+        learning_rate=lr_scheduler,
+        epsilon=args.adam_epsilon,
+        parameters=ofa_model.model.parameters(),
+        weight_decay=args.weight_decay,
+        apply_decay_param_fun=lambda x: x in [
+            p.name for n, p in ofa_model.model.named_parameters()
+            if not any(nd in n for nd in ["bias", "norm"])
+        ])
+
+    global_step = 0
+    tic_train = time.time()
+    for epoch in range(args.num_train_epochs):
+        # Step6: Set current epoch and task.
+        ofa_model.set_epoch(epoch)
+        ofa_model.set_task('depth')
+
+        for step, batch in enumerate(train_data_loader):
+            global_step += 1
+            input_ids, segment_ids, labels = batch
+
+            for depth_mult in args.depth_mult_list:
+                for width_mult in args.width_mult_list:
+                    # Step7: Compute the sub-network config from width_mult
+                    # and depth_mult, and use this config in supernet training.
+                    net_config = utils.dynabert_config(ofa_model, width_mult,
+                                                       depth_mult)
+                    ofa_model.set_net_config(net_config)
+                    logits, teacher_logits = ofa_model(
+                        input_ids, segment_ids, attention_mask=[None, None])
+                    rep_loss = ofa_model.calc_distill_loss()
+                    if args.task_name == 'sts-b':
+                        logit_loss = 0.0
+                    else:
+                        logit_loss = soft_cross_entropy(logits,
+                                                        teacher_logits.detach())
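+                    # The training signal is pure distillation: hidden-state
+                    # (rep) loss plus scaled soft-logit loss; no hard-label
+                    # cross-entropy term is used for the sub-networks.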
+                    loss = rep_loss + args.lambda_logit * logit_loss
+                    loss.backward()
+                    optimizer.step()
+                    lr_scheduler.step()
+                    ofa_model.model.clear_gradients()
+
+            if global_step % args.logging_steps == 0:
+                if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0:
+                    logger.info(
+                        "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
+                        % (global_step, epoch, step, loss,
+                           args.logging_steps / (time.time() - tic_train)))
+                tic_train = time.time()
+
+            if global_step % args.save_steps == 0:
+                if args.task_name == "mnli":
+                    evaluate(
+                        teacher_model,
+                        criterion,
+                        metric,
+                        dev_data_loader_matched,
+                        width_mult=100)
+                    evaluate(
+                        teacher_model,
+                        criterion,
+                        metric,
+                        dev_data_loader_mismatched,
+                        width_mult=100)
+                else:
+                    evaluate(
+                        teacher_model,
+                        criterion,
+                        metric,
+                        dev_data_loader,
+                        width_mult=100)
+                for depth_mult in args.depth_mult_list:
+                    for width_mult in args.width_mult_list:
+                        net_config = utils.dynabert_config(
+                            ofa_model, width_mult, depth_mult)
+                        ofa_model.set_net_config(net_config)
+                        tic_eval = time.time()
+                        if args.task_name == "mnli":
+                            evaluate(ofa_model, criterion, metric,
+                                     dev_data_loader_matched, width_mult,
+                                     depth_mult)
+                            evaluate(ofa_model, criterion, metric,
+                                     dev_data_loader_mismatched, width_mult,
+                                     depth_mult)
+                            print("eval done total : %s s" %
+                                  (time.time() - tic_eval))
+                        else:
+                            evaluate(ofa_model, criterion, metric,
+                                     dev_data_loader, width_mult, depth_mult)
+                            print("eval done total : %s s" %
+                                  (time.time() - tic_eval))
+
+                        if (not args.n_gpu > 1
+                            ) or paddle.distributed.get_rank() == 0:
+                            output_dir = os.path.join(args.output_dir,
+                                                      "model_%d" % global_step)
+                            if not os.path.exists(output_dir):
+                                os.makedirs(output_dir)
+                            # need better way to get inner model of DataParallel
+                            model_to_save = model._layers if isinstance(
+                                model, paddle.DataParallel) else model
+                            model_to_save.save_pretrained(output_dir)
+                            tokenizer.save_pretrained(output_dir)
+
+
+def print_arguments(args):
+    """print arguments"""
+    print('----------- Configuration Arguments -----------')
+    for arg, value in sorted(vars(args).items()):
+        print('%s: %s' % (arg, value))
+    print('------------------------------------------------')
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    print_arguments(args)
+    if args.n_gpu > 1:
+        paddle.distributed.spawn(do_train, args=(args, ), nprocs=args.n_gpu)
+    else:
+        do_train(args)
diff --git a/demo/ofa/ernie/ernie_supernet/modeling_ernie_supernet.py b/demo/ofa/ernie/ernie_supernet/modeling_ernie_supernet.py
index 3b2259a3caf5ce8be09df636a01b931347d66e07..a232bbe789828a38b35afb7ae5cd219d95c0a848 100644
--- a/demo/ofa/ernie/ernie_supernet/modeling_ernie_supernet.py
+++ b/demo/ofa/ernie/ernie_supernet/modeling_ernie_supernet.py
@@ -26,6 +26,10 @@ import logging
 import logging
 from functools import partial
 import six
+if six.PY2:
+    from pathlib2 import Path
+else:
+    from pathlib import Path
 
 import paddle.fluid.dygraph as D
 import paddle.fluid as F
@@ -288,8 +292,16 @@ def get_config(pretrain_dir_or_url):
         'ernie-2.0-large-en': bce + 'model-ernie2.0-large-en.1.tar.gz',
         'ernie-tiny': bce + 'model-ernie_tiny.1.tar.gz',
     }
-    url = resource_map[pretrain_dir_or_url]
-    pretrain_dir = _fetch_from_remote(url, False)
+
+    if not Path(pretrain_dir_or_url).exists() and str(
+            pretrain_dir_or_url) in resource_map:
+        url = resource_map[pretrain_dir_or_url]
+        pretrain_dir = _fetch_from_remote(url, False)
+    else:
+        log.info('pretrain dir %s not in %s, read from local' %
+                 (pretrain_dir_or_url, repr(resource_map)))
+        pretrain_dir = Path(pretrain_dir_or_url)
+
     config_path = os.path.join(pretrain_dir, 'ernie_config.json')
     if not os.path.exists(config_path):
         raise ValueError('config path not found: %s' % config_path)
diff --git a/docs/zh_cn/api_cn/ofa_api.rst b/docs/zh_cn/api_cn/ofa_api.rst
index ff0e58924a0eea074b3b477be3dbd19d04f82ced..304cbb040cc7f84a36946d92e4ca5ca7ba70c198 100644
--- a/docs/zh_cn/api_cn/ofa_api.rst
+++ b/docs/zh_cn/api_cn/ofa_api.rst
@@ -88,9 +88,14 @@ OFA实例
 
   .. code-block:: python
 
-     from paddlslim.nas.ofa import OFA
-
-     ofa_model = OFA(model)
+     from paddle.vision.models import mobilenet_v1
+     from paddleslim.nas.ofa import OFA
+     from paddleslim.nas.ofa.convert_super import Convert, supernet
+
+     model = mobilenet_v1()
+     sp_net_config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
+     sp_model = Convert(sp_net_config).convert(model)
+     ofa_model = OFA(sp_model)
 ..
 
   .. py:method:: set_epoch(epoch)
@@ -140,7 +145,7 @@ OFA实例
 
   .. code-block:: python
 
-     config = ofa_model.current_config
+     config = {'conv2d_0': {'expand_ratio': 2}, 'conv2d_1': {'expand_ratio': 2}}
      ofa_model.set_net_config(config)
 
   .. py:method:: calc_distill_loss()
@@ -159,15 +164,25 @@ OFA实例
   .. py:method:: search()
   ### TODO
 
-  .. py:method:: export(config)
+  .. py:method:: export(origin_model, config, input_shapes, input_dtypes, load_weights_from_supernet=True)
 
-  根据传入的子网络配置导出当前子网络的参数。
+  根据传入的原始模型结构、子网络配置、模型输入的形状和类型导出当前子网络。导出的子网络可以进一步训练、预测，或者调用框架的动静转换功能转为静态图模型。
 
   **参数：**
 
-  - **config(dict):** 某个子网络每层的配置。
+  - **origin_model(paddle.nn.Layer):** 原始模型实例，子模型会直接在原始模型的基础上进行修改。
+  - **config(dict):** 某个子网络每层的配置，可以通过 OFA 实例的 current_config 属性得到，也可以由 paddleslim.nas.ofa.utils.dynabert_config 等接口生成。
+  - **input_shapes(list|list(list)):** 模型输入的形状。
+  - **input_dtypes(list):** 模型输入的类型。
+  - **load_weights_from_supernet(bool, optional):** 是否从超网络加载参数。若为False，则不从超网络加载参数，只根据config裁剪原始模型的网络结构；若为True，则用超网络参数初始化原始模型，并根据config裁剪原始模型的网络结构。默认：True。
 
   **返回：**
-  TODO
+  子模型实例。
 
   **示例代码：**
-  TODO
+
+  .. code-block:: python
+
+     from paddle.vision.models import mobilenet_v1
+     origin_model = mobilenet_v1()
+
+     config = {'conv2d_0': {'expand_ratio': 2}, 'conv2d_1': {'expand_ratio': 2}}
+     origin_model = ofa_model.export(origin_model, config, input_shapes=[1, 3, 28, 28], input_dtypes=['float32'])
diff --git a/paddleslim/analysis/flops.py b/paddleslim/analysis/flops.py
index 7e01d12c64d6bb731788f157c22cd35ea6d459f4..194efadec2ebc22016c8847e7c872b74a9544dc3 100644
--- a/paddleslim/analysis/flops.py
+++ b/paddleslim/analysis/flops.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import paddle
 import numpy as np
-import paddle.jit as jit
 from ..core import GraphWrapper, dygraph2program
 
 __all__ = ["flops", "dygraph_flops"]
diff --git a/paddleslim/nas/ofa/__init__.py b/paddleslim/nas/ofa/__init__.py
index 8f047bd6381aad07d38bd0367ba6e6f30d628c7d..21e19995aedd48d1a048aad0dca86d54b2275a38 100644
--- a/paddleslim/nas/ofa/__init__.py
+++ b/paddleslim/nas/ofa/__init__.py
@@ -14,6 +14,8 @@
 
 from .ofa import OFA, RunConfig, DistillConfig
 from .convert_super import supernet
+from .utils.special_config import *
+from .get_sub_model import *
 from .utils.utils import get_paddle_version
 pd_ver = get_paddle_version()
diff --git a/paddleslim/nas/ofa/convert_super.py b/paddleslim/nas/ofa/convert_super.py
index 0d74283d19159823687df84c35c90b9a571d4d07..df460720889b9ec320b4573618dc25905e8aabc0 100644
--- a/paddleslim/nas/ofa/convert_super.py
+++ b/paddleslim/nas/ofa/convert_super.py
@@ -579,14 +579,10 @@ class Convert:
             new_attr_name = []
             if pd_ver == 185:
                 new_attr_name += [
-                    'size', 'is_sparse', 'is_distributed', 'param_attr',
-                    'dtype'
+                    'is_sparse', 'is_distributed', 'param_attr', 'dtype'
                 ]
             else:
-                new_attr_name += [
-                    'num_embeddings', 'embedding_dim', 'sparse',
-                    'weight_attr', 'name'
-                ]
+                new_attr_name += ['sparse', 'weight_attr', 'name']
 
             self._change_name(layer, pd_ver, has_bias=False)
diff --git a/paddleslim/nas/ofa/get_sub_model.py b/paddleslim/nas/ofa/get_sub_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b93e694703e4a7c87775303397f7ce6c7f6b5e4
--- /dev/null
+++ b/paddleslim/nas/ofa/get_sub_model.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+
+__all__ = ['get_prune_params_config', 'prune_params']
+
+
+def get_prune_params_config(graph, origin_model_config):
+    param_config = {}
+    precursor = None
+    for op in graph.ops():
+        ### TODO(ceci3):
+        ### 1. fix config when this op is concat by graph.pre_ops(op)
+        ### 2. add kernel_size in config
+        ### 3. add channel in config
+        for inp in op.all_inputs():
+            n_ops = graph.next_ops(op)
+            if inp._var.name in origin_model_config.keys():
+                if 'expand_ratio' in origin_model_config[inp._var.name].keys():
+                    tmp = origin_model_config[inp._var.name]['expand_ratio']
+                    if len(inp._var.shape) > 1:
+                        if inp._var.name in param_config.keys():
+                            param_config[inp._var.name].append(tmp)
+                        ### first op
+                        else:
+                            param_config[inp._var.name] = [precursor, tmp]
+                    else:
+                        param_config[inp._var.name] = [tmp]
+                    precursor = tmp
+                else:
+                    precursor = None
+            for n_op in n_ops:
+                for next_inp in n_op.all_inputs():
+                    if next_inp._var.persistable:
+                        if next_inp._var.name in origin_model_config.keys():
+                            if 'expand_ratio' in origin_model_config[
+                                    next_inp._var.name].keys():
+                                tmp = origin_model_config[next_inp._var.name][
+                                    'expand_ratio']
+                                pre = tmp if precursor is None else precursor
+                                if len(next_inp._var.shape) > 1:
+                                    param_config[next_inp._var.name] = [pre]
+                                else:
+                                    param_config[next_inp._var.name] = [tmp]
+                        else:
+                            if len(next_inp._var.
+                                   shape) > 1 and precursor is not None:
+                                param_config[
+                                    next_inp._var.name] = [precursor, None]
+                            else:
+                                param_config[next_inp._var.name] = [precursor]
+
+    return param_config
+
+
+def prune_params(model, param_config, super_model_sd=None):
+    for name, param in model.named_parameters():
+        t_value = param.value().get_tensor()
+        value = np.array(t_value).astype("float32")
+
+        if super_model_sd is not None:
+            super_t_value = super_model_sd[name].value().get_tensor()
+            super_value = np.array(super_t_value).astype("float32")
+
+        if param.name in param_config.keys():
+            if len(param_config[param.name]) > 1:
+                in_exp = param_config[param.name][0]
+                out_exp = param_config[param.name][1]
+                in_chn = int(value.shape[0]) if in_exp is None else int(
+                    value.shape[0] * in_exp)
+                out_chn = int(value.shape[1]) if out_exp is None else int(
+                    value.shape[1] * out_exp)
+                prune_value = super_value[:in_chn, :out_chn, ...] \
+                    if super_model_sd is not None else value[:in_chn, :out_chn, ...]
+            else:
+                out_chn = int(value.shape[0]) if param_config[param.name][
+                    0] is None else int(value.shape[0] *
+                                        param_config[param.name][0])
+                prune_value = super_value[:out_chn, ...] \
+                    if super_model_sd is not None else value[:out_chn, ...]
+        else:
+            prune_value = super_value if super_model_sd is not None else value
+
+        p = t_value._place()
+        if p.is_cpu_place():
+            place = paddle.CPUPlace()
+        elif p.is_cuda_pinned_place():
+            place = paddle.CUDAPinnedPlace()
+        else:
+            place = paddle.CUDAPlace(p.gpu_device_id())
+        t_value.set(prune_value, place)
+        if param.trainable:
+            param.clear_gradient()
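A toy illustration of the slicing performed by prune_params above for a 2-D parameter; the
shapes, ratios, and config entry are made up for the example:

    import numpy as np

    # Supernet Linear weight with shape [in_features, out_features] = [4, 8].
    super_value = np.arange(32, dtype="float32").reshape(4, 8)

    # A param_config entry of [in_expand, out_expand] = [None, 0.5] means:
    # keep every input channel, keep the first half of the output channels.
    in_exp, out_exp = None, 0.5
    in_chn = super_value.shape[0] if in_exp is None else int(super_value.shape[0] * in_exp)
    out_chn = super_value.shape[1] if out_exp is None else int(super_value.shape[1] * out_exp)
    pruned = super_value[:in_chn, :out_chn]  # shape (4, 4), written back via t_value.set()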
diff --git a/paddleslim/nas/ofa/ofa.py b/paddleslim/nas/ofa/ofa.py
index f0d94321f3ddbdc387ccd3420bb87b35418a532c..e7c345e99fbe1179ec64c425c1dc4185cca5040b 100644
--- a/paddleslim/nas/ofa/ofa.py
+++ b/paddleslim/nas/ofa/ofa.py
@@ -17,7 +17,7 @@ import numpy as np
 from collections import namedtuple
 import paddle
 import paddle.fluid as fluid
-from .utils.utils import get_paddle_version
+from .utils.utils import get_paddle_version, remove_model_fn
 pd_ver = get_paddle_version()
 if pd_ver == 185:
     from .layers_old import BaseBlock, SuperConv2D, SuperLinear
@@ -27,6 +27,8 @@ else:
     Layer = paddle.nn.Layer
 from .utils.utils import search_idx
 from ...common import get_logger
+from ...core import GraphWrapper, dygraph2program
+from .get_sub_model import get_prune_params_config, prune_params
 
 _logger = get_logger(__name__, level=logging.INFO)
@@ -125,8 +127,14 @@ class OFA(OFABase):
     Examples:
         .. code-block:: python
 
-          from paddlslim.nas.ofa import OFA
-          ofa_model = OFA(model)
+          from paddle.vision.models import mobilenet_v1
+          from paddleslim.nas.ofa import OFA
+          from paddleslim.nas.ofa.convert_super import Convert, supernet
+
+          model = mobilenet_v1()
+          sp_net_config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
+          sp_model = Convert(sp_net_config).convert(model)
+          ofa_model = OFA(sp_model)
 
     """
@@ -206,8 +214,6 @@ class OFA(OFABase):
         self.model.train()
 
     def _prepare_distill(self):
-        self.Tacts, self.Sacts = {}, {}
-
         if self.distill_config.teacher_model == None:
             logging.error(
                 'If you want to add distill, please input instance of teacher model'
@@ -257,6 +263,11 @@ class OFA(OFABase):
                 self.netAs_param.extend(netA.parameters())
             self.netAs.append(netA)
 
+    def _reset_hook_before_forward(self):
+        self.Tacts, self.Sacts = {}, {}
+        mapping_layers = getattr(self.distill_config, 'mapping_layers', None)
+        if mapping_layers is not None:
+
             def get_activation(mem, name):
                 def get_output_hook(layer, input, output):
                     mem[name] = output
@@ -369,6 +380,9 @@ class OFA(OFABase):
         assert len(self.netAs) > 0
         for i, netA in enumerate(self.netAs):
             n = self.distill_config.mapping_layers[i]
+            ### add for elastic depth
+            if n not in self.Sacts.keys():
+                continue
             Tact = self.Tacts[n]
             Sact = self.Sacts[n]
             if isinstance(netA, SuperConv2D):
@@ -397,9 +411,64 @@ class OFA(OFABase):
     def search(self, eval_func, condition):
         pass
 
-    ### TODO: complete it
-    def export(self, config):
-        pass
+    def _export_sub_model_config(self, origin_model, config, input_shapes,
+                                 input_dtypes):
+        super_model_config = {}
+        for name, sublayer in self.model.named_sublayers():
+            if isinstance(sublayer, BaseBlock):
+                for param in sublayer.parameters():
+                    super_model_config[name] = sublayer.key
+
+        for name, value in super_model_config.items():
+            super_model_config[name] = config[value] if value in config.keys(
+            ) else {}
+
+        origin_model_config = {}
+        for name, sublayer in origin_model.named_sublayers():
+            for param in sublayer.parameters(include_sublayers=False):
+                if name in super_model_config.keys():
+                    origin_model_config[param.name] = super_model_config[name]
+
+        program = dygraph2program(
+            origin_model, inputs=input_shapes, dtypes=input_dtypes)
+        graph = GraphWrapper(program)
+        param_prune_config = get_prune_params_config(graph, origin_model_config)
+        return param_prune_config
+
+    def export(self,
+               origin_model,
+               config,
+               input_shapes,
+               input_dtypes,
+               load_weights_from_supernet=True):
+        """
+        Export the sub-model weights according to the origin model and the sub-model config.
+        Parameters:
+            origin_model(paddle.nn.Layer): the instance of the original model.
+            config(dict): the config of the sub-model; it can be obtained from OFA.get_current_config or from a helper such as paddleslim.nas.ofa.utils.dynabert_config(width_mult).
+            input_shapes(list|list(list)): the shape of all inputs.
+            input_dtypes(list): the dtype of all inputs.
+            load_weights_from_supernet(bool, optional): whether to load weights from the supernet. Default: True.
+        Examples:
+            .. code-block:: python
+
+              from paddle.vision.models import mobilenet_v1
+              origin_model = mobilenet_v1()
+
+              config = {'conv2d_0': {'expand_ratio': 2}, 'conv2d_1': {'expand_ratio': 2}}
+              origin_model = ofa_model.export(origin_model, config, input_shapes=[1, 3, 28, 28], input_dtypes=['float32'])
+        """
+        super_sd = None
+        if load_weights_from_supernet:
+            super_sd = remove_model_fn(origin_model, self.model.state_dict())
+
+        param_config = self._export_sub_model_config(origin_model, config,
+                                                     input_shapes, input_dtypes)
+        prune_params(origin_model, param_config, super_sd)
+        return origin_model
+
+    @property
+    def get_current_config(self):
+        return self.current_config
 
     def set_net_config(self, net_config):
         """
@@ -408,7 +477,7 @@ class OFA(OFABase):
             net_config(dict): specify the config of the sub-network.
         Examples:
             .. code-block:: python
-              config = ofa_model.current_config
+              config = {'conv2d_0': {'expand_ratio': 2}, 'conv2d_1': {'expand_ratio': 2}}
               ofa_model.set_net_config(config)
         """
         self.net_config = net_config
@@ -417,6 +486,7 @@ class OFA(OFABase):
         # ===================== teacher process =====================
         teacher_output = None
         if self._add_teacher:
+            self._reset_hook_before_forward()
             teacher_output = self.ofa_teacher_model.model.forward(*inputs,
                                                                   **kwargs)
         # ============================================================
diff --git a/paddleslim/nas/ofa/utils/__init__.py b/paddleslim/nas/ofa/utils/__init__.py
index 12386e5c93b8f3149dd67952fd7b04b6441b88b0..ed7b69e13b486152df886a5a4e5da81dfbf6e121 100644
--- a/paddleslim/nas/ofa/utils/__init__.py
+++ b/paddleslim/nas/ofa/utils/__init__.py
@@ -17,5 +17,3 @@
 from .special_config import *
 
 from .utils import get_paddle_version
 pd_ver = get_paddle_version()
-if pd_ver == 200:
-    from .nlp_utils import *
diff --git a/paddleslim/nas/ofa/utils/special_config.py b/paddleslim/nas/ofa/utils/special_config.py
index 26d4be42af3aec975944cbc387565090892482a9..ae2192af8affdb39a4c0aaf9201627ded5057124 100644
--- a/paddleslim/nas/ofa/utils/special_config.py
+++ b/paddleslim/nas/ofa/utils/special_config.py
@@ -27,10 +27,17 @@ def dynabert_config(model, width_mult, depth_mult=1.0):
                 return True
         return False
 
+    start_idx = 0
+    for idx, (block_k, block_v) in enumerate(model.layers.items()):
+        if 'linear' in block_k:
+            start_idx = int(block_k.split('_')[1])
+            break
+
     for idx, (block_k, block_v) in enumerate(model.layers.items()):
         if isinstance(block_v, dict) and len(block_v.keys()) != 0:
             name, name_idx = block_k.split('_'), int(block_k.split('_')[1])
-            if fix_exp(name_idx) or 'emb' in block_k or idx >= block_name:
+            if fix_exp(name_idx -
+                       start_idx) or 'emb' in block_k or idx >= block_name:
                 block_v['expand_ratio'] = 1.0
             else:
                 block_v['expand_ratio'] = width_mult
diff --git a/paddleslim/nas/ofa/utils/utils.py b/paddleslim/nas/ofa/utils/utils.py
index a4ec2f88e73a0fead6c0edf482e11bb3e0684b03..b4d92e3f7aeee0d81d23196fc6961116c37e1867 100644
--- a/paddleslim/nas/ofa/utils/utils.py
+++ b/paddleslim/nas/ofa/utils/utils.py
@@ -59,6 +59,25 @@ def set_state_dict(model, state_dict):
             _logger.info('{} is not in state_dict'.format(tmp_n))
 
 
+def remove_model_fn(model, sd):
+    new_dict = {}
+    keys = []
+    for name, param in model.named_parameters():
+        keys.append(name)
+    for name, param in sd.items():
+        if name.split('.')[-2] == 'fn':
+            tmp_n = name.split('.')[:-2] + [name.split('.')[-1]]
+            tmp_n = '.'.join(tmp_n)
+        if name in keys:
+            new_dict[name] = param
+        elif tmp_n in keys:
+            new_dict[tmp_n] = param
+        else:
+            _logger.debug('{} is not in state_dict'.format(tmp_n))
+    return new_dict
+
+
 def compute_start_end(kernel_size, sub_kernel_size):
     center = kernel_size // 2
     sub_center = sub_kernel_size // 2
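The 'fn' handling in remove_model_fn mirrors how convert_super wraps layers: parameters of a
converted supernet pick up an extra 'fn' segment in their state-dict keys, which has to be
stripped before the weights can be loaded into a plain (origin or teacher) model. A minimal
sketch of the key mapping, with an illustrative key name:

    sd_key = 'models.1.fn.weight'          # key as it appears in the supernet state dict
    parts = sd_key.split('.')
    if parts[-2] == 'fn':
        plain_key = '.'.join(parts[:-2] + parts[-1:])
    # plain_key == 'models.1.weight', the matching name in the plain model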
diff --git a/tests/test_ofa.py b/tests/test_ofa.py
index 5c86132f4b6893c94cb273239c9a3438786be193..fa125e01eced28ead3d47673f72590ddc7a8318b 100644
--- a/tests/test_ofa.py
+++ b/tests/test_ofa.py
@@ -139,7 +139,7 @@ class ModelConv2(nn.Layer):
 class ModelLinear(nn.Layer):
     def __init__(self):
         super(ModelLinear, self).__init__()
-        with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
+        with supernet(expand_ratio=(1.0, 2.0, 4.0)) as ofa_super:
             models = []
             models += [nn.Embedding(num_embeddings=64, embedding_dim=64)]
             models += [nn.Linear(64, 128)]
@@ -167,6 +167,22 @@ class ModelLinear(nn.Layer):
         return inputs
 
 
+class ModelOriginLinear(nn.Layer):
+    def __init__(self):
+        super(ModelOriginLinear, self).__init__()
+        models = []
+        models += [nn.Embedding(num_embeddings=64, embedding_dim=64)]
+        models += [nn.Linear(64, 128)]
+        models += [nn.LayerNorm(128)]
+        models += [nn.Linear(128, 256)]
+        models += [nn.Linear(256, 256)]
+
+        self.models = paddle.nn.Sequential(*models)
+
+    def forward(self, inputs):
+        return self.models(inputs)
+
+
 class ModelLinear1(nn.Layer):
     def __init__(self):
         super(ModelLinear1, self).__init__()
@@ -373,5 +389,40 @@ class TestOFACase4(unittest.TestCase):
         self.model = ModelConv2()
 
 
+class TestExport(unittest.TestCase):
+    def setUp(self):
+        self._init_model()
+
+    def _init_model(self):
+        self.origin_model = ModelOriginLinear()
+        model = ModelLinear()
+        self.ofa_model = OFA(model)
+
+    def test_ofa(self):
+        config = {
+            'embedding_1': {
+                'expand_ratio': 2.0
+            },
+            'linear_3': {
+                'expand_ratio': 2.0
+            },
+            'linear_4': {},
+            'linear_5': {}
+        }
+        origin_dict = {}
+        for name, param in self.origin_model.named_parameters():
+            origin_dict[name] = param.shape
+        self.ofa_model.export(
+            self.origin_model,
+            config,
+            input_shapes=[[1, 64]],
+            input_dtypes=['int64'])
+        for name, param in self.origin_model.named_parameters():
+            if name in config.keys():
+                if 'expand_ratio' in config[name]:
+                    assert origin_dict[name][-1] == param.shape[-1] * config[
+                        name]['expand_ratio']
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/test_ofa_utils.py b/tests/test_ofa_utils.py
index eb0394b31e6486a2241f0dca88567baf32064887..6d7e70836eceaf50adb18cf022d761f6280c21df 100644
--- a/tests/test_ofa_utils.py
+++ b/tests/test_ofa_utils.py
@@ -20,7 +20,8 @@ import paddle
 import paddle.nn as nn
 from paddle.vision.models import mobilenet_v1
 from paddleslim.nas.ofa.convert_super import Convert, supernet
-from paddleslim.nas.ofa.utils import compute_neuron_head_importance, reorder_head, reorder_neuron, set_state_dict, dynabert_config
+from paddleslim.nas.ofa.utils import set_state_dict, dynabert_config
+from paddleslim.nas.ofa.utils.nlp_utils import compute_neuron_head_importance, reorder_head, reorder_neuron
 from paddleslim.nas.ofa import OFA
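Taken together, the new export path can be exercised the same way TestExport does; a minimal
end-to-end sketch reusing the test models and config above:

    from paddleslim.nas.ofa import OFA

    origin_model = ModelOriginLinear()   # plain model, defined in tests/test_ofa.py
    ofa_model = OFA(ModelLinear())       # its supernet counterpart

    config = {
        'embedding_1': {'expand_ratio': 2.0},
        'linear_3': {'expand_ratio': 2.0},
        'linear_4': {},
        'linear_5': {}
    }
    sub_model = ofa_model.export(
        origin_model, config, input_shapes=[[1, 64]], input_dtypes=['int64'])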