未验证 提交 cf1bbfc0 编写于 作者: X Xudong Ma 提交者: GitHub

add implementation of BiBERT (#1102)

上级 63140aed
## BiBERT: Accurate Fully Binarized BERT
Created by [Haotong Qin](https://htqin.github.io/), [Yifu Ding](https://yifu-ding.github.io/), [Mingyuan Zhang](https://scholar.google.com/citations?user=2QLD4fAAAAAJ&hl=en), Qinghua Yan, [Aishan Liu](https://liuaishan.github.io/), Qingqing Dang, [Ziwei Liu](https://liuziwei7.github.io/), and [Xianglong Liu](https://xlliu-beihang.github.io/) from Beihang University, Nanyang Technological University, and Baidu Inc.
![loading-ag-172](./overview.png)
## Introduction
This project is the official implementation of our accepted ICLR 2022 paper *BiBERT: Accurate Fully Binarized BERT* [[PDF](https://openreview.net/forum?id=5xEgrl_5FAJ)]. The large pre-trained BERT has achieved remarkable performance on Natural Language Processing (NLP) tasks but is also computation and memory expensive. As one of the powerful compression approaches, binarization extremely reduces the computation and memory consumption by utilizing 1-bit parameters and bitwise operations. Unfortunately, the full binarization of BERT (i.e., 1-bit weight, embedding, and activation) usually suffer a significant performance drop, and there is rare study addressing this problem. In this paper, with the theoretical justification and empirical analysis, we identify that the severe performance drop can be mainly attributed to the information degradation and optimization direction mismatch respectively in the forward and backward propagation, and propose BiBERT, an accurate fully binarized BERT, to eliminate the performance bottlenecks. Specifically, BiBERT introduces an efficient Bi-Attention structure for maximizing representation information statistically and a Direction-Matching Distillation (DMD) scheme to optimize the full binarized BERT accurately. Extensive experiments show that BiBERT outperforms both the straightforward baseline and existing state-of-the-art quantized BERTs with ultra-low bit activations by convincing margins on the NLP benchmark. As the first fully binarized BERT, our method yields impressive $59.2\times$ and $31.2\times$ saving on FLOPs and model size, demonstrating the vast advantages and potential of the fully binarized BERT model in real-world resource-constrained scenarios.
## Quick start
This tutorial uses the GLUE/SST-2 dataset as an example to perform 1-bit quantization on the BERT model in PaddleNLP.
### Install PaddleNLP 和 Paddle
```shell
pip install paddlenlp
pip install paddlepaddle_gpu
```
### Acquisition of data and training models
When running the experiments in this directory, the dataset and the pretrained model will be automatically downloaded to the path of `paddlenlp.utils.env.DATA_HOME`. For example, under the Linux system, for the SST-2 dataset in GLUE, the default storage path is `~/.paddlenlp/datasets/Glue/SST-2`.
### Fine-tune the pretrained model
Following [GLUE](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/benchmark/glue/run_glue.py), the pre-trained model can be fine-tuned to obtain a full-precision teacher model
```shell
export TASK_NAME=SST-2
export LOG_FILENAME=$(date "+%Y-%m-%d-%H-%M-%S")
python -u ./run_glue.py \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name $TASK_NAME \
--max_seq_length 128 \
--batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--logging_steps 1 \
--save_steps 500 \
--output_dir ./tmp/$TASK_NAME/ \
--device gpu 2>&1 | tee ${LOG_FILENAME}.log
```
### Train a binarized model
```shell
export TASK_NAME=SST-2
export TEACHER_PATH=./tmp/SST-2/sst-2_ft_model_6315.pdparams.pdparams
export LOG_FILENAME=$(date "+%Y-%m-%d-%H-%M-%S")
python my_task_distill.py \
--model_type bert \
--student_model_name_or_path $TEACHER_PATH \
--seed 1000000007 \
--weight_decay 0.01 \
--task_name $TASK_NAME \
--max_seq_length 64 \
--batch_size 16 \
--teacher_model_type bert \
--teacher_path $TEACHER_PATH \
--learning_rate 1e-4 \
--num_train_epochs 10 \
--logging_steps 10 \
--save_steps 10 \
--output_dir ./tmp/$TASK_NAME/ \
--device gpu \
--pred_distill \
--query_distill \
--key_distill \
--value_distill \
--intermediate_distill \
--bi 2>&1 | tee ${LOG_FILENAME}.log
```
Besides, parameters `pred_distill`, `query_distill`, `key_distill`, `value_distill` `intermediate_distill` are used for distillation configuration, and parameter `bi` is used to binarize the model.
## Acknowledgement
The original code is borrowed from [TinyBERT](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/model_compression/tinybert).
## Citation
If you find our work useful in your research, please consider citing:
```shell
@inproceedings{Qin:iclr22,
author = {Haotong Qin and Yifu Ding and Mingyuan Zhang and Qinghua Yan and
Aishan Liu and Qingqing Dang and Ziwei Liu and Xianglong Liu},
title = {BiBERT: Accurate Fully Binarized BERT},
booktitle = {International Conference on Learning Representations (ICLR)},
year = {2022}
}
```
import collections
from webbrowser import get
import paddle
from paddle import tensor
from paddle.autograd import PyLayer
from paddle.fluid import layers
from paddle.nn import functional as F
from paddle.nn.layer.common import Linear, Embedding
from paddle.nn.layer.transformer import MultiHeadAttention, _convert_attention_mask
class BinaryQuantizer(PyLayer):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
out = paddle.sign(input)
return out
@staticmethod
def backward(ctx, grad_output):
input = ctx.saved_tensor()[0]
grad_input = grad_output
grad_input[input >= 1] = 0
grad_input[input <= -1] = 0
return grad_input.clone()
class ZMeanBinaryQuantizer(PyLayer):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
out = (paddle.sign(input) + 1) / 2
return out
@staticmethod
def backward(ctx, grad_output):
input = ctx.saved_tensor()[0]
grad_input = grad_output
grad_input[input >= 1] = 0
grad_input[input <= -1] = 0
return grad_input.clone()
class BiLinear(Linear):
def __init__(self, in_features, out_features, weight_attr=None, bias_attr=None, name=None):
super(BiLinear, self).__init__(in_features, out_features, weight_attr=weight_attr, bias_attr=bias_attr, name=name)
def forward(self, input):
scaling_factor = paddle.mean(self.weight.abs(), axis=1).unsqueeze(1).detach()
real_weights = self.weight - paddle.mean(self.weight, axis=-1).unsqueeze(-1)
binary_weights_no_grad = scaling_factor * paddle.sign(real_weights)
cliped_weights = paddle.clip(real_weights, -1.0, 1.0)
weight = binary_weights_no_grad.detach() - cliped_weights.detach() + cliped_weights
binary_input_no_grad = paddle.sign(input)
cliped_input = paddle.clip(input, -1.0, 1.0)
ba = binary_input_no_grad.detach() - cliped_input.detach() + cliped_input
out = F.linear(x=ba, weight=weight, bias=self.bias, name=self.name)
return out
class BiEmbedding(Embedding):
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, sparse=False, weight_attr=None, name=None):
super(BiEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, sparse, weight_attr, name)
def forward(self, x):
scaling_factor = paddle.mean(self.weight.abs(), axis=1, keepdim=True)
scaling_factor = scaling_factor.detach()
real_weights = self.weight - paddle.mean(self.weight, axis=-1, keepdim=True)
binary_weights_no_grad = scaling_factor * paddle.sign(real_weights)
cliped_weights = paddle.clip(real_weights, -1.0, 1.0)
weight = binary_weights_no_grad.detach() - cliped_weights.detach() + cliped_weights
return F.embedding(x, weight=weight, padding_idx=self._padding_idx, sparse=self._sparse, name=self._name)
class BiMultiHeadAttention(MultiHeadAttention):
# fork from paddle.nn.layer.transformer.MultiHeadAttention
Cache = collections.namedtuple("Cache", ["k", "v"])
StaticCache = collections.namedtuple("StaticCache", ["k", "v"])
def __init__(self, embed_dim, num_heads, dropout=0., kdim=None, vdim=None, need_weights=False, weight_attr=None, bias_attr=None):
super(BiMultiHeadAttention, self).__init__(embed_dim, num_heads, dropout, kdim, vdim, need_weights, weight_attr, bias_attr)
def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
key = query if key is None else key
value = query if value is None else value
# compute q ,k ,v
if cache is None:
q, k, v = self._prepare_qkv(query, key, value, cache)
else:
q, k, v, cache = self._prepare_qkv(query, key, value, cache)
q = BinaryQuantizer.apply(q)
k = BinaryQuantizer.apply(k)
# scale dot product attention
# TODO(guosheng): use tensor.matmul, however it doesn't support `alpha`
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
if attn_mask is not None:
# Support bool or int mask
attn_mask = _convert_attention_mask(attn_mask, product.dtype)
product = product + attn_mask
weights = F.softmax(product)
if self.dropout:
weights = F.dropout(
weights,
self.dropout,
training=self.training,
mode="upscale_in_train")
weights = ZMeanBinaryQuantizer.apply(weights)
v = BinaryQuantizer.apply(v)
out = tensor.matmul(weights, v)
# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
# project to output
out = self.out_proj(out)
outs = [out]
if self.need_weights:
outs.append(weights)
if cache is not None:
outs.append(cache)
return out if len(outs) == 1 else tuple(outs)
def _to_bi_function(model):
for name, layer in model.named_children():
if isinstance(layer, MultiHeadAttention):
new_layer = BiMultiHeadAttention(layer.embed_dim,
layer.num_heads,
layer.dropout,
layer.kdim,
layer.vdim,
layer.need_weights,
layer.q_proj._weight_attr,
layer.q_proj._bias_attr)
new_layer.q_proj = layer.q_proj
new_layer.k_proj = layer.k_proj
new_layer.v_proj = layer.v_proj
new_layer.out_proj = layer.out_proj
model._sub_layers[name] = new_layer
elif isinstance(layer, Embedding):
if name != "word_embeddings": continue
new_layer = BiEmbedding(layer._num_embeddings,
layer._embedding_dim,
layer._padding_idx,
layer._sparse,
layer._weight_attr,
layer._name)
new_layer.weight = layer.weight
model._sub_layers[name] = new_layer
elif isinstance(layer, Linear):
if name == "classifier": continue
new_layer = BiLinear(layer.weight.shape[0],
layer.weight.shape[1],
layer._weight_attr,
layer._bias_attr,
layer.name)
new_layer.weight = layer.weight
new_layer.bias = layer.bias
model._sub_layers[name] = new_layer
import math
def _MultiHeadAttention_forward(self, query, key=None, value=None, attn_mask=None, cache=None):
key = query if key is None else key
value = query if value is None else value
# compute q ,k ,v
if cache is None:
q, k, v = self._prepare_qkv(query, key, value, cache)
else:
q, k, v, cache = self._prepare_qkv(query, key, value, cache)
# distill qxq
query_scores = paddle.matmul(q, tensor.transpose(x=q, perm=[0, 1, 3, 2]))
query_scores = query_scores / math.sqrt(self.head_dim)
# distill kxk
key_scores = paddle.matmul(k, tensor.transpose(x=k, perm=[0, 1, 3, 2]))
key_scores = key_scores / math.sqrt(self.head_dim)
# scale dot product attention
# TODO(guosheng): use tensor.matmul, however it doesn't support `alpha`
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
if attn_mask is not None:
# Support bool or int mask
attn_mask = _convert_attention_mask(attn_mask, product.dtype)
product = product + attn_mask
weights = F.softmax(product)
if self.dropout:
weights = F.dropout(
weights,
self.dropout,
training=self.training,
mode="upscale_in_train")
# distil vxv
value_scores = paddle.matmul(v, tensor.transpose(x=v, perm=[0, 1, 3, 2]))
value_scores = value_scores / math.sqrt(self.head_dim)
out = tensor.matmul(weights, v)
# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
# project to output
out = self.out_proj(out)
outs = [out]
if self.need_weights:
outs.append(weights)
if cache is not None:
outs.append(cache)
self.query_scores = query_scores
self.key_scores = key_scores
self.value_scores = value_scores
return out if len(outs) == 1 else tuple(outs)
def _Bi_MultiHeadAttention_forward(self, query, key=None, value=None, attn_mask=None, cache=None):
key = query if key is None else key
value = query if value is None else value
# compute q ,k ,v
if cache is None:
q, k, v = self._prepare_qkv(query, key, value, cache)
else:
q, k, v, cache = self._prepare_qkv(query, key, value, cache)
# distill qxq
query_scores = paddle.matmul(q, tensor.transpose(x=q, perm=[0, 1, 3, 2]))
query_scores = query_scores / math.sqrt(self.head_dim)
# distill kxk
key_scores = paddle.matmul(k, tensor.transpose(x=k, perm=[0, 1, 3, 2]))
key_scores = key_scores / math.sqrt(self.head_dim)
q = BinaryQuantizer.apply(q)
k = BinaryQuantizer.apply(k)
# scale dot product attention
# TODO(guosheng): use tensor.matmul, however it doesn't support `alpha`
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
if attn_mask is not None:
# Support bool or int mask
attn_mask = _convert_attention_mask(attn_mask, product.dtype)
product = product + attn_mask
# weights = F.softmax(product)
weights = product
if self.dropout:
weights = F.dropout(
weights,
self.dropout,
training=self.training,
mode="upscale_in_train")
# distil vxv
value_scores = paddle.matmul(v, tensor.transpose(x=v, perm=[0, 1, 3, 2]))
value_scores = value_scores / math.sqrt(self.head_dim)
weights = ZMeanBinaryQuantizer.apply(weights)
v = BinaryQuantizer.apply(v)
out = tensor.matmul(weights, v)
# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
# project to output
out = self.out_proj(out)
outs = [out]
if self.need_weights:
outs.append(weights)
if cache is not None:
outs.append(cache)
self.query_scores = query_scores
self.key_scores = key_scores
self.value_scores = value_scores
return out if len(outs) == 1 else tuple(outs)
def _TransformerEncoderLayer_forward(self, src, src_mask=None, cache=None):
src_mask = _convert_attention_mask(src_mask, src.dtype)
residual = src
if self.normalize_before:
src = self.norm1(src)
# Add cache for encoder for the usage like UniLM
if cache is None:
src = self.self_attn(src, src, src, src_mask)
else:
src, incremental_cache = self.self_attn(src, src, src, src_mask,
cache)
src = residual + self.dropout1(src)
if not self.normalize_before:
src = self.norm1(src)
residual = src
if self.normalize_before:
src = self.norm2(src)
src = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = residual + self.dropout2(src)
if not self.normalize_before:
src = self.norm2(src)
self.rep = src
return src if cache is None else (src, incremental_cache)
def _get_attr(model, attr):
res = []
if hasattr(model, attr):
res.append(getattr(model, attr))
for layer in model.children():
res.extend(_get_attr(layer, attr))
return res
def _to_distill_function(model):
from types import MethodType
for layer in model.children():
if isinstance(layer, BiMultiHeadAttention):
layer.forward = MethodType(_Bi_MultiHeadAttention_forward, layer)
elif isinstance(layer, MultiHeadAttention):
layer.forward = MethodType(_MultiHeadAttention_forward, layer)
elif isinstance(layer, paddle.nn.layer.transformer.TransformerEncoderLayer):
layer.forward = MethodType(_TransformerEncoderLayer_forward, layer)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import sys
import random
import time
import math
from functools import partial
import numpy as np
import paddle
from paddle.io import DataLoader
from paddle.metric import Metric, Accuracy, Precision, Recall
from paddlenlp.datasets import load_dataset
from paddlenlp.data import Stack, Tuple, Pad, Dict
from paddlenlp.data.sampler import SamplerHelper
from paddlenlp.transformers import AutoTokenizer, AutoModelForSequenceClassification
from paddlenlp.transformers import LinearDecayWithWarmup
from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
METRIC_CLASSES = {
"cola": Mcc,
"sst-2": Accuracy,
"mrpc": AccuracyAndF1,
"sts-b": PearsonAndSpearman,
"qqp": AccuracyAndF1,
"mnli": Accuracy,
"qnli": Accuracy,
"rte": Accuracy,
}
def parse_args():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--task_name",
default=None,
type=str,
required=True,
help="The name of the task to train selected in the list: " +
", ".join(METRIC_CLASSES.keys()), )
parser.add_argument(
"--model_type",
default=None,
type=str,
required=False,
help="should be remove later")
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pre-trained model")
parser.add_argument(
"--tokenizer_name_or_path",
default=None,
type=str,
required=False,
help="Path to tokenizer")
parser.add_argument(
"--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--max_seq_length",
default=128,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.", )
parser.add_argument(
"--learning_rate",
default=1e-4,
type=float,
help="The initial learning rate for Adam.")
parser.add_argument(
"--num_train_epochs",
default=3,
type=int,
help="Total number of training epochs to perform.", )
parser.add_argument(
"--logging_steps",
type=int,
default=100,
help="Log every X updates steps.")
parser.add_argument(
"--save_steps",
type=int,
default=100,
help="Save checkpoint every X updates steps.")
parser.add_argument(
"--batch_size",
default=32,
type=int,
help="Batch size per GPU/CPU for training.", )
parser.add_argument(
"--weight_decay",
default=0.0,
type=float,
help="Weight decay if we apply some.")
parser.add_argument(
"--warmup_steps",
default=0,
type=int,
help="Linear warmup over warmup_steps. If > 0: Override warmup_proportion"
)
parser.add_argument(
"--warmup_proportion",
default=0.1,
type=float,
help="Linear warmup proportion over total steps.")
parser.add_argument(
"--adam_epsilon",
default=1e-6,
type=float,
help="Epsilon for Adam optimizer.")
parser.add_argument(
"--max_steps",
default=-1,
type=int,
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
)
parser.add_argument(
"--seed", default=42, type=int, help="random seed for initialization")
parser.add_argument(
"--device",
default="gpu",
type=str,
help="The device to select to train the model, is must be cpu/gpu/xpu.")
args = parser.parse_args()
return args
def set_seed(args):
# Use the same data seed(for data shuffle) for all procs to guarantee data
# consistency after sharding.
random.seed(args.seed)
np.random.seed(args.seed)
# Maybe different op seeds(for dropout) for different procs is better. By:
# `paddle.seed(args.seed + paddle.distributed.get_rank())`
paddle.seed(args.seed)
@paddle.no_grad()
def evaluate(model, loss_fct, metric, data_loader):
model.eval()
metric.reset()
for batch in data_loader:
input_ids, segment_ids, labels = batch
logits = model(input_ids, segment_ids)
loss = loss_fct(logits, labels)
correct = metric.compute(logits, labels)
metric.update(correct)
res = metric.accumulate()
if isinstance(metric, AccuracyAndF1):
print(
"eval loss: %f, acc: %s, precision: %s, recall: %s, f1: %s, acc and f1: %s, "
% (
loss.numpy(),
res[0],
res[1],
res[2],
res[3],
res[4], ),
end='')
elif isinstance(metric, Mcc):
print("eval loss: %f, mcc: %s, " % (loss.numpy(), res[0]), end='')
elif isinstance(metric, PearsonAndSpearman):
print(
"eval loss: %f, pearson: %s, spearman: %s, pearson and spearman: %s, "
% (loss.numpy(), res[0], res[1], res[2]),
end='')
else:
print("eval loss: %f, acc: %s, " % (loss.numpy(), res), end='')
model.train()
def convert_example(example,
tokenizer,
label_list,
max_seq_length=512,
is_test=False):
"""convert a glue example into necessary features"""
if not is_test:
# `label_list == None` is for regression task
label_dtype = "int64" if label_list else "float32"
# Get the label
label = example['labels']
label = np.array([label], dtype=label_dtype)
# Convert raw text to feature
if (int(is_test) + len(example)) == 2:
example = tokenizer(example['sentence'], max_seq_len=max_seq_length)
else:
example = tokenizer(
example['sentence1'],
text_pair=example['sentence2'],
max_seq_len=max_seq_length)
if not is_test:
return example['input_ids'], example['token_type_ids'], label
else:
return example['input_ids'], example['token_type_ids']
def do_train(args):
paddle.set_device(args.device)
if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()
set_seed(args)
args.task_name = args.task_name.lower()
metric_class = METRIC_CLASSES[args.task_name]
train_ds = load_dataset('glue', args.task_name, splits="train")
if args.tokenizer_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path)
else:
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
trans_func = partial(
convert_example,
tokenizer=tokenizer,
label_list=train_ds.label_list,
max_seq_length=args.max_seq_length)
train_ds = train_ds.map(trans_func, lazy=True)
train_batch_sampler = paddle.io.DistributedBatchSampler(
train_ds, batch_size=args.batch_size, shuffle=True)
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment
Stack(dtype="int64" if train_ds.label_list else "float32") # label
): fn(samples)
train_data_loader = DataLoader(
dataset=train_ds,
batch_sampler=train_batch_sampler,
collate_fn=batchify_fn,
num_workers=0,
return_list=True)
if args.task_name == "mnli":
dev_ds_matched, dev_ds_mismatched = load_dataset(
'glue', args.task_name, splits=["dev_matched", "dev_mismatched"])
dev_ds_matched = dev_ds_matched.map(trans_func, lazy=True)
dev_ds_mismatched = dev_ds_mismatched.map(trans_func, lazy=True)
dev_batch_sampler_matched = paddle.io.BatchSampler(
dev_ds_matched, batch_size=args.batch_size, shuffle=False)
dev_data_loader_matched = DataLoader(
dataset=dev_ds_matched,
batch_sampler=dev_batch_sampler_matched,
collate_fn=batchify_fn,
num_workers=0,
return_list=True)
dev_batch_sampler_mismatched = paddle.io.BatchSampler(
dev_ds_mismatched, batch_size=args.batch_size, shuffle=False)
dev_data_loader_mismatched = DataLoader(
dataset=dev_ds_mismatched,
batch_sampler=dev_batch_sampler_mismatched,
collate_fn=batchify_fn,
num_workers=0,
return_list=True)
else:
dev_ds = load_dataset('glue', args.task_name, splits='dev')
dev_ds = dev_ds.map(trans_func, lazy=True)
dev_batch_sampler = paddle.io.BatchSampler(
dev_ds, batch_size=args.batch_size, shuffle=False)
dev_data_loader = DataLoader(
dataset=dev_ds,
batch_sampler=dev_batch_sampler,
collate_fn=batchify_fn,
num_workers=0,
return_list=True)
num_classes = 1 if train_ds.label_list == None else len(train_ds.label_list)
model = AutoModelForSequenceClassification.from_pretrained(
args.model_name_or_path, num_classes=num_classes)
if paddle.distributed.get_world_size() > 1:
model = paddle.DataParallel(model)
if args.max_steps > 0:
num_training_steps = args.max_steps
num_train_epochs = math.ceil(num_training_steps /
len(train_data_loader))
else:
num_training_steps = len(train_data_loader) * args.num_train_epochs
num_train_epochs = args.num_train_epochs
warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion
lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps,
warmup)
# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params = [
p.name for n, p in model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
learning_rate=lr_scheduler,
beta1=0.9,
beta2=0.999,
epsilon=args.adam_epsilon,
parameters=model.parameters(),
weight_decay=args.weight_decay,
apply_decay_param_fun=lambda x: x in decay_params)
loss_fct = paddle.nn.loss.CrossEntropyLoss(
) if train_ds.label_list else paddle.nn.loss.MSELoss()
metric = metric_class()
global_step = 0
tic_train = time.time()
for epoch in range(num_train_epochs):
for step, batch in enumerate(train_data_loader):
global_step += 1
input_ids, segment_ids, labels = batch
logits = model(input_ids, segment_ids)
loss = loss_fct(logits, labels)
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.clear_grad()
if global_step % args.logging_steps == 0:
print(
"global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s"
% (global_step, num_training_steps, epoch, step,
paddle.distributed.get_rank(), loss, optimizer.get_lr(),
args.logging_steps / (time.time() - tic_train)))
tic_train = time.time()
if global_step % args.save_steps == 0 or global_step == num_training_steps:
tic_eval = time.time()
if args.task_name == "mnli":
evaluate(model, loss_fct, metric, dev_data_loader_matched)
evaluate(model, loss_fct, metric,
dev_data_loader_mismatched)
print("eval done total : %s s" % (time.time() - tic_eval))
else:
evaluate(model, loss_fct, metric, dev_data_loader)
print("eval done total : %s s" % (time.time() - tic_eval))
if paddle.distributed.get_rank() == 0:
output_dir = os.path.join(args.output_dir,
"%s_ft_model_%d.pdparams" %
(args.task_name, global_step))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Need better way to get inner model of DataParallel
model_to_save = model._layers if isinstance(
model, paddle.DataParallel) else model
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
if global_step >= num_training_steps:
return
def print_arguments(args):
"""print arguments"""
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).items()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
if __name__ == "__main__":
args = parse_args()
print_arguments(args)
do_train(args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import sys
import random
import time
import math
from functools import partial
import numpy as np
import paddle
from paddle.io import DataLoader
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.metric import Accuracy
from paddlenlp.datasets import load_dataset
from paddlenlp.data import Stack, Tuple, Pad, Dict
from paddlenlp.data.sampler import SamplerHelper
from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman
import paddlenlp.transformers as T
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
METRIC_CLASSES = {
"cola": Mcc,
"sst-2": Accuracy,
"mrpc": AccuracyAndF1,
"sts-b": PearsonAndSpearman,
"qqp": AccuracyAndF1,
"mnli": Accuracy,
"qnli": Accuracy,
"rte": Accuracy,
}
MODEL_CLASSES = {
"bert": (T.BertForSequenceClassification, T.BertTokenizer),
"tinybert": (T.TinyBertForSequenceClassification, T.TinyBertTokenizer),
}
def parse_args():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--task_name",
default=None,
type=str,
required=True,
help="The name of the task to train selected in the list: " +
", ".join(METRIC_CLASSES.keys()), )
parser.add_argument(
"--model_type",
default="tinybert",
type=str,
required=True,
help="Model type selected in the list: " +
", ".join(MODEL_CLASSES.keys()), )
parser.add_argument(
"--teacher_model_type",
default="bert",
type=str,
required=True,
help="Model type selected in the list: " +
", ".join(MODEL_CLASSES.keys()), )
parser.add_argument(
"--student_model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pre-trained model or shortcut name selected in the list: "
+ ", ".join(
sum([
list(classes[-1].pretrained_init_configuration.keys())
for classes in MODEL_CLASSES.values()
], [])), )
parser.add_argument(
"--distill_config",
default=None,
type=str,
help="distill config file path")
parser.add_argument(
"--teacher_path",
default=None,
type=str,
required=True,
help="Path to pre-trained model.")
parser.add_argument(
"--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--glue_dir",
default="/root/.paddlenlp/datasets/Glue/",
type=str,
required=False,
help="The Glue directory.", )
parser.add_argument(
"--max_seq_length",
default=128,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.", )
parser.add_argument(
"--learning_rate",
default=1e-4,
type=float,
help="The initial learning rate for Adam.")
parser.add_argument(
"--num_train_epochs",
default=3,
type=int,
help="Total number of training epochs to perform.", )
parser.add_argument(
"--logging_steps",
type=int,
default=100,
help="Log every X updates steps.")
parser.add_argument(
"--save_steps",
type=int,
default=100,
help="Save checkpoint every X updates steps.")
parser.add_argument(
"--batch_size",
default=32,
type=int,
help="Batch size per GPU/CPU for training.", )
parser.add_argument(
"--T",
default=1,
type=int,
help="Temperature for softmax", )
parser.add_argument(
"--use_aug",
action="store_true",
help="Whether to use augmentation data to train.", )
parser.add_argument(
"--weight_decay",
default=0.0,
type=float,
help="Weight decay if we apply some.")
parser.add_argument(
"--warmup_steps",
default=0,
type=int,
help="Linear warmup over warmup_steps. If > 0: Override warmup_proportion"
)
parser.add_argument(
"--warmup_proportion",
default=0.1,
type=float,
help="Linear warmup proportion over total steps.")
parser.add_argument(
"--adam_epsilon",
default=1e-6,
type=float,
help="Epsilon for Adam optimizer.")
parser.add_argument(
"--max_steps",
default=-1,
type=int,
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
)
parser.add_argument(
"--seed", default=42, type=int, help="random seed for initialization")
parser.add_argument(
"--device",
default="gpu",
type=str,
help="The device to select to train the model, is must be cpu/gpu/xpu.")
args = parser.parse_args()
return args
def set_seed(args):
# Use the same data seed(for data shuffle) for all procs to guarantee data
# consistency after sharding.
random.seed(args.seed)
np.random.seed(args.seed)
# Maybe different op seeds(for dropout) for different procs is better. By:
# `paddle.seed(args.seed + paddle.distributed.get_rank())`
paddle.seed(args.seed)
@paddle.no_grad()
def evaluate(model, metric, data_loader):
model.eval()
metric.reset()
for batch in data_loader:
input_ids, segment_ids, labels = batch
logits = model(input_ids, segment_ids)
correct = metric.compute(logits, labels)
metric.update(correct)
res = metric.accumulate()
if isinstance(metric, AccuracyAndF1):
print(
"acc: %s, precision: %s, recall: %s, f1: %s, acc and f1: %s, " % (
res[0],
res[1],
res[2],
res[3],
res[4], ),
end='')
elif isinstance(metric, Mcc):
print("mcc: %s, " % (res[0]), end='')
elif isinstance(metric, PearsonAndSpearman):
print(
"pearson: %s, spearman: %s, pearson and spearman: %s, " %
(res[0], res[1], res[2]),
end='')
else:
print("acc: %s, " % (res), end='')
model.train()
return res[0] if isinstance(metric, (AccuracyAndF1, Mcc,
PearsonAndSpearman)) else res
def convert_example(example,
tokenizer,
label_list,
max_seq_length=512,
is_test=False):
"""convert a glue example into necessary features"""
if not is_test:
# `label_list == None` is for regression task
label_dtype = "int64" if label_list else "float32"
# Get the label
label = example['labels']
label = np.array([label], dtype=label_dtype)
# Convert raw text to feature
if (int(is_test) + len(example)) == 2:
example = tokenizer(example['sentence'], max_seq_len=max_seq_length)
else:
example = tokenizer(
example['sentence1'],
text_pair=example['sentence2'],
max_seq_len=max_seq_length)
if not is_test:
return example['input_ids'], example['token_type_ids'], label
else:
return example['input_ids'], example['token_type_ids']
def do_train(args):
paddle.set_device(args.device)
if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()
set_seed(args)
args.task_name = args.task_name.lower()
metric_class = METRIC_CLASSES[args.task_name]
args.model_type = args.model_type.lower()
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
if args.use_aug:
aug_data_file = os.path.join(
os.path.join(args.glue_dir, args.task_name), "train_aug.tsv"),
train_ds = load_dataset(
'glue', args.task_name, data_files=aug_data_file)
else:
train_ds = load_dataset('glue', args.task_name, splits='train')
tokenizer = tokenizer_class.from_pretrained(args.student_model_name_or_path)
trans_func = partial(
convert_example,
tokenizer=tokenizer,
label_list=train_ds.label_list,
max_seq_length=args.max_seq_length)
train_ds = train_ds.map(trans_func, lazy=True)
train_batch_sampler = paddle.io.DistributedBatchSampler(
train_ds, batch_size=args.batch_size, shuffle=True)
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment
Stack(dtype="int64" if train_ds.label_list else "float32") # label
): fn(samples)
train_data_loader = DataLoader(
dataset=train_ds,
batch_sampler=train_batch_sampler,
collate_fn=batchify_fn,
num_workers=0,
return_list=True)
if args.task_name == "mnli":
dev_ds_matched, dev_ds_mismatched = load_dataset(
'glue', args.task_name, splits=["dev_matched", "dev_mismatched"])
dev_ds_matched = dev_ds_matched.map(trans_func, lazy=True)
dev_ds_mismatched = dev_ds_mismatched.map(trans_func, lazy=True)
dev_batch_sampler_matched = paddle.io.BatchSampler(
dev_ds_matched, batch_size=args.batch_size, shuffle=False)
dev_data_loader_matched = DataLoader(
dataset=dev_ds_matched,
batch_sampler=dev_batch_sampler_matched,
collate_fn=batchify_fn,
num_workers=0,
return_list=True)
dev_batch_sampler_mismatched = paddle.io.BatchSampler(
dev_ds_mismatched, batch_size=args.batch_size, shuffle=False)
dev_data_loader_mismatched = DataLoader(
dataset=dev_ds_mismatched,
batch_sampler=dev_batch_sampler_mismatched,
collate_fn=batchify_fn,
num_workers=0,
return_list=True)
else:
dev_ds = load_dataset('glue', args.task_name, splits='dev')
dev_ds = dev_ds.map(trans_func, lazy=True)
dev_batch_sampler = paddle.io.BatchSampler(
dev_ds, batch_size=args.batch_size, shuffle=False)
dev_data_loader = DataLoader(
dataset=dev_ds,
batch_sampler=dev_batch_sampler,
collate_fn=batchify_fn,
num_workers=0,
return_list=True)
num_classes = 1 if train_ds.label_list == None else len(train_ds.label_list)
student = model_class.from_pretrained(
args.student_model_name_or_path, num_classes=num_classes)
teacher_model_class, _ = MODEL_CLASSES[args.teacher_model_type]
teacher = teacher_model_class.from_pretrained(
args.teacher_path, num_classes=num_classes)
if paddle.distributed.get_world_size() > 1:
student = paddle.DataParallel(student, find_unused_parameters=True)
teacher = paddle.DataParallel(teacher, find_unused_parameters=True)
if args.max_steps > 0:
num_training_steps = args.max_steps
num_train_epochs = math.ceil(num_training_steps /
len(train_data_loader))
else:
num_training_steps = len(train_data_loader) * args.num_train_epochs
num_train_epochs = args.num_train_epochs
warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion
lr_scheduler = T.LinearDecayWithWarmup(args.learning_rate,
num_training_steps, warmup)
### step1: load distill config
assert os.path.exists(
args.distill_config), "distill file {} not exist.".format(
args.distill_config)
### step2: wrap the student model and teacher model by paddleslim.dygraph.dist.Distill
### the distill config need to be passed into it.
distill_model = Distill(
args.distill_config, students=[student], teachers=[teacher])
### step3: add parameter created by align op to optimizer
# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params = [
p.name for n, p in distill_model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
learning_rate=lr_scheduler,
beta1=0.9,
beta2=0.999,
epsilon=args.adam_epsilon,
parameters=distill_model.parameters(),
weight_decay=args.weight_decay,
apply_decay_param_fun=lambda x: x in decay_params)
metric = metric_class()
pad_token_id = 0
global_step = 0
tic_train = time.time()
best_res = 0.0
for epoch in range(num_train_epochs):
for step, batch in enumerate(train_data_loader):
global_step += 1
input_ids, segment_ids, labels = batch
### step4: call distill_model instead of call student model and teacher model independently.
loss, _, _ = distill_model(input_ids, segment_ids)
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.clear_grad()
if global_step % args.logging_steps == 0:
print(
"global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s"
% (global_step, num_training_steps, epoch, step,
paddle.distributed.get_rank(), loss, optimizer.get_lr(),
args.logging_steps / (time.time() - tic_train)))
tic_train = time.time()
if global_step % args.save_steps == 0 or global_step == num_training_steps:
tic_eval = time.time()
if args.task_name == "mnli":
res = evaluate(student, metric, dev_data_loader_matched)
evaluate(student, metric, dev_data_loader_mismatched)
print("eval done total : %s s" % (time.time() - tic_eval))
else:
res = evaluate(student, metric, dev_data_loader)
print("eval done total : %s s" % (time.time() - tic_eval))
if (best_res < res and global_step < num_training_steps or
global_step == num_training_steps
) and paddle.distributed.get_rank() == 0:
if global_step < num_training_steps:
output_dir = os.path.join(args.output_dir,
"distill_model_%d.pdparams" %
(global_step))
else:
output_dir = os.path.join(
args.output_dir, "distill_model_final.pdparams")
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Need better way to get inner model of DataParallel
model_to_save = student._layers if isinstance(
student, paddle.DataParallel) else student
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
best_res = res
if global_step >= num_training_steps:
return
def print_arguments(args):
"""print arguments"""
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).items()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
if __name__ == "__main__":
args = parse_args()
print_arguments(args)
do_train(args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册