Unverified commit 7a529cf1, authored by whs, committed by GitHub

Add distillation for BERT (#330)

* Fix import in py3.

* Update bert distillation.

* Add option to fix embedding in bert distillation.

* Fix distillation in adabert searching

* Remove unused code
Parent 491c3489
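
For orientation, this commit wires a layer-wise knowledge-distillation term into the classifier loss. A minimal numpy sketch of the combined objective, assuming a mixing weight like the `self._gamma` and temperature `self.T` referenced in the diff below (illustrative only, not the repository code):

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def distill_loss(s_logit, t_logit, labels, gamma=0.5, T=5.0):
    # soften the teacher distribution with temperature T
    t_probs = softmax(t_logit / T)
    s_probs = softmax(s_logit)
    # soft-label cross-entropy between student and teacher (the KD term)
    kd = -(t_probs * np.log(s_probs)).sum(axis=1).mean()
    # ordinary hard-label cross-entropy against the ground-truth labels
    ce = -np.log(s_probs[np.arange(len(labels)), labels]).mean()
    return (1 - gamma) * ce + gamma * kd

# toy usage: batch of 4, 3 classes
s = np.random.randn(4, 3)
t = np.random.randn(4, 3)
y = np.array([0, 2, 1, 0])
print(distill_loss(s, t, y))
```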
......@@ -52,7 +52,10 @@ class AdaBERTClassifier(Layer):
search_layer=False,
teacher_model=None,
data_dir=None,
use_fixed_gumbel=False):
use_fixed_gumbel=False,
gumbel_alphas=None,
fix_emb=False,
t=5.0):
super(AdaBERTClassifier, self).__init__()
self._n_layer = n_layer
self._num_labels = num_labels
......@@ -65,21 +68,37 @@ class AdaBERTClassifier(Layer):
self._teacher_model = teacher_model
self._data_dir = data_dir
self.use_fixed_gumbel = use_fixed_gumbel
#print(
# "----------------------load teacher model and test----------------------------------------"
#)
#self.teacher = BERTClassifier(num_labels, model_path=self._teacher_model)
#self.teacher.test(self._data_dir)
#print(
# "----------------------finish load teacher model and test----------------------------------------"
#)
self.T = t
print(
"----------------------load teacher model and test----------------------------------------"
)
self.teacher = BERTClassifier(
num_labels, model_path=self._teacher_model)
self.teacher.test(self._data_dir)
print(
"----------------------finish load teacher model and test----------------------------------------"
)
self.student = BertModelLayer(
n_layer=self._n_layer,
emb_size=self._emb_size,
hidden_size=self._hidden_size,
conv_type=self._conv_type,
search_layer=self._search_layer,
use_fixed_gumbel=self.use_fixed_gumbel)
use_fixed_gumbel=self.use_fixed_gumbel,
gumbel_alphas=gumbel_alphas)
for s_emb, t_emb in zip(self.student.emb_names(),
self.teacher.emb_names()):
t_emb.stop_gradient = True
if fix_emb:
s_emb.stop_gradient = True
print(
"Assigning embedding[{}] from teacher to embedding[{}] in student.".
format(t_emb.name, s_emb.name))
fluid.layers.assign(input=t_emb, output=s_emb)
print(
"Assigned embedding[{}] from teacher to embedding[{}] in student.".
format(t_emb.name, s_emb.name))
self.cls_fc = list()
for i in range(self._n_layer):
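
The new embedding-sharing logic above copies the teacher's word/position/sentence embedding tables into the student and, with `fix_emb=True`, freezes the student copies as well. A condensed sketch of that loop (the helper name is mine; `fluid.layers.assign` is the same in-place copy used in the diff):

```python
import paddle.fluid as fluid

def copy_and_optionally_freeze(teacher_emb_params, student_emb_params, fix_emb=False):
    """Copy teacher embedding weights into the student, as in AdaBERTClassifier.__init__."""
    for t_emb, s_emb in zip(teacher_emb_params, student_emb_params):
        t_emb.stop_gradient = True        # the teacher tables are never updated
        if fix_emb:
            s_emb.stop_gradient = True    # optionally freeze the student copies too
        fluid.layers.assign(input=t_emb, output=s_emb)  # in-place weight copy
```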
......@@ -107,6 +126,11 @@ class AdaBERTClassifier(Layer):
def genotype(self):
return self.arch_parameters()
def ce(self, logits):
logits = np.exp(logits - np.max(logits))
logits = logits / logits.sum(axis=0)
return logits
def loss(self, data_ids):
src_ids = data_ids[0]
position_ids = data_ids[1]
......@@ -114,13 +138,53 @@ class AdaBERTClassifier(Layer):
input_mask = data_ids[3]
labels = data_ids[4]
enc_output = self.student(
s_logits = self.student(
src_ids, position_ids, sentence_ids, flops=[], model_size=[])
self.teacher.eval()
total_loss, t_logits, t_losses, accuracys, num_seqs = self.teacher(
data_ids)
# define kd loss
kd_losses = []
kd_weights = []
for i in range(len(s_logits)):
j = int(np.ceil(i * (float(len(t_logits)) / len(s_logits))))
kd_weights.append(t_losses[j].numpy())
kd_weights = np.array(kd_weights)
kd_weights = self.ce(-kd_weights)
s_probs = None
for i in range(len(s_logits)):
j = int(np.ceil(i * (float(len(t_logits)) / len(s_logits))))
t_logit = t_logits[j]
s_logit = s_logits[i]
t_logit.stop_gradient = True
t_probs = fluid.layers.softmax(t_logit / self.T)
s_probs = fluid.layers.softmax(s_logit)
#kd_loss = -t_probs * fluid.layers.log(s_probs)
kd_loss = fluid.layers.cross_entropy(
input=s_probs, label=t_probs, soft_label=True)
kd_loss = fluid.layers.reduce_sum(kd_loss, dim=1)
kd_loss = fluid.layers.mean(kd_loss)
# print("kd_loss[{}] = {}; kd_weights[{}] = {}".format(i, kd_loss.numpy(), i, kd_weights[i]))
# tmp = kd_loss * kd_weights[i]
tmp = fluid.layers.scale(kd_loss, scale=kd_weights[i])
# print("kd_loss[{}] = {}".format(i, tmp.numpy()))
kd_losses.append(tmp)
kd_loss = fluid.layers.sum(kd_losses)
# print("kd_loss = {}".format(kd_loss.numpy()))
ce_loss, probs = fluid.layers.softmax_with_cross_entropy(
logits=enc_output, label=labels, return_softmax=True)
loss = fluid.layers.mean(x=ce_loss)
logits=s_logits[-1], label=labels, return_softmax=True)
ce_loss = fluid.layers.mean(x=ce_loss)
num_seqs = fluid.layers.create_tensor(dtype='int64')
accuracy = fluid.layers.accuracy(
input=probs, label=labels, total=num_seqs)
return loss, accuracy
loss = (1 - self._gamma) * ce_loss + self._gamma * kd_loss
# return ce_loss, accuracy, None, None
return loss, accuracy, ce_loss, kd_loss
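
Two details of the loss above are easy to miss: the helper named `ce` is really a numerically stable softmax, used to turn the negative per-layer teacher losses into distillation weights, and each intermediate student classifier is paired with a teacher layer through a ceiling mapping. A small numpy sketch of both (function names here are mine):

```python
import numpy as np

def layer_weights(t_losses):
    # softmax over negative teacher losses (what the `ce` helper computes):
    # teacher layers with lower loss receive larger distillation weight
    x = -np.asarray(t_losses, dtype=np.float64)
    x = np.exp(x - np.max(x))
    return x / x.sum(axis=0)

def layer_mapping(n_student, n_teacher):
    # pair student layer i with teacher layer ceil(i * n_teacher / n_student)
    return [int(np.ceil(i * (float(n_teacher) / n_student))) for i in range(n_student)]

# 4 student classifiers distilled from a 12-layer teacher
print(layer_mapping(4, 12))            # -> [0, 3, 6, 9]
print(layer_weights([0.9, 0.5, 0.7]))  # sums to 1; the 0.5-loss layer gets the largest weight
```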
......@@ -43,7 +43,8 @@ class BertModelLayer(Layer):
conv_type="conv_bn",
search_layer=False,
use_fp16=False,
use_fixed_gumbel=False):
use_fixed_gumbel=False,
gumbel_alphas=None):
super(BertModelLayer, self).__init__()
self._emb_size = emb_size
......@@ -56,9 +57,9 @@ class BertModelLayer(Layer):
self.use_fixed_gumbel = use_fixed_gumbel
self._word_emb_name = "word_embedding"
self._pos_emb_name = "pos_embedding"
self._sent_emb_name = "sent_embedding"
self._word_emb_name = "s_word_embedding"
self._pos_emb_name = "s_pos_embedding"
self._sent_emb_name = "s_sent_embedding"
self._dtype = "float16" if use_fp16 else "float32"
self._conv_type = conv_type
......@@ -93,7 +94,12 @@ class BertModelLayer(Layer):
n_layer=self._n_layer,
hidden_size=self._hidden_size,
search_layer=self._search_layer,
use_fixed_gumbel=self.use_fixed_gumbel)
use_fixed_gumbel=self.use_fixed_gumbel,
gumbel_alphas=gumbel_alphas)
def emb_names(self):
return self._src_emb.parameters() + self._pos_emb.parameters(
) + self._sent_emb.parameters()
def max_flops(self):
return self._encoder.max_flops
......@@ -152,7 +158,6 @@ class BertModelLayer(Layer):
emb_out_1 = self._emb_fac(src_emb_1)
# (bs, seq_len, 768)
enc_output = self._encoder(
emb_out_0, emb_out_1, flops=flops, model_size=model_size)
return enc_output
enc_outputs = self._encoder(
emb_out, flops=flops, model_size=model_size)
return enc_outputs
......@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import numpy as np
from collections import Iterable
import paddle
import paddle.fluid as fluid
......@@ -81,7 +82,13 @@ class MixedOp(fluid.dygraph.Layer):
# return out
for i in range(len(self._ops)):
if weights[i].numpy() != 0:
if isinstance(weights, Iterable):
weights_i = weights[i]
else:
weights_i = weights[i].numpy()
if weights_i != 0:
return self._ops[i](x) * weights[i]
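
The `Iterable` check above lets `MixedOp` accept either a plain numpy row of fixed gumbel weights or a dygraph tensor; in both cases only the candidate op with a non-zero weight is evaluated. A toy sketch of that selection, ignoring the tensor case:

```python
import numpy as np

def mixed_op(x, ops, weights):
    # apply only the op whose gumbel weight is non-zero, scaled by that weight
    w = np.asarray(weights, dtype=np.float64)
    for i, op in enumerate(ops):
        if w[i] != 0:
            return op(x) * w[i]

print(mixed_op(2.0, [lambda v: v + 1, lambda v: v * v], [0.0, 1.0]))  # -> 4.0
```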
......@@ -212,7 +219,8 @@ class EncoderLayer(Layer):
hidden_size=768,
name="encoder",
search_layer=True,
use_fixed_gumbel=False):
use_fixed_gumbel=False,
gumbel_alphas=None):
super(EncoderLayer, self).__init__()
self._n_layer = n_layer
self._hidden_size = hidden_size
......@@ -259,26 +267,38 @@ class EncoderLayer(Layer):
# dtype="float32",
# default_initializer=NormalInitializer(
# loc=0.0, scale=1e-3))
self.BN = BatchNorm(
num_channels=self._n_channel,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=1),
trainable=False),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0),
trainable=False))
self.pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
self.bns = []
self.outs = []
for i in range(self._n_layer):
self.out = Linear(
self._n_channel,
3,
param_attr=ParamAttr(initializer=MSRA()),
bias_attr=ParamAttr(initializer=MSRA()))
bn = BatchNorm(
num_channels=self._n_channel,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=1),
trainable=False),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0),
trainable=False))
self.bns.append(bn)
out = Linear(
self._n_channel,
3,
param_attr=ParamAttr(initializer=MSRA()),
bias_attr=ParamAttr(initializer=MSRA()))
self.outs.append(out)
self.use_fixed_gumbel = use_fixed_gumbel
self.gumbel_alphas = gumbel_softmax(self.alphas).detach()
self.gumbel_alphas = gumbel_softmax(self.alphas)
if gumbel_alphas is not None:
self.gumbel_alphas = np.array(gumbel_alphas).reshape(
self.alphas.shape)
else:
self.gumbel_alphas = gumbel_softmax(self.alphas)
self.gumbel_alphas.stop_gradient = True
print("gumbel_alphas: {}".format(self.gumbel_alphas))
def forward(self, enc_input_0, enc_input_1, flops=[], model_size=[]):
alphas = self.gumbel_alphas if self.use_fixed_gumbel else gumbel_softmax(
......@@ -293,14 +313,19 @@ class EncoderLayer(Layer):
s0 = self.stem(s0)
s1 = self.stem(s1)
# (bs, n_channel, seq_len, 1)
if self.use_fixed_gumbel:
alphas = self.gumbel_alphas
else:
alphas = gumbel_softmax(self.alphas)
s0 = s1 = tmp
outputs = []
for i in range(self._n_layer):
s0, s1 = s1, self._cells[i](s0, s1, alphas)
# (bs, n_channel, seq_len, 1)
s1 = self.BN(s1)
outputs = self.pool2d_avg(s1)
outputs = fluid.layers.reshape(outputs, shape=[-1, 0])
outputs = self.out(outputs)
tmp = self.bns[i](s1)
tmp = self.pool2d_avg(tmp)
# (bs, n_channel, seq_len, 1)
tmp = fluid.layers.reshape(tmp, shape=[-1, 0])
tmp = self.outs[i](tmp)
outputs.append(tmp)
return outputs
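
Architecture weights are handled just above: the encoder either reshapes a user-supplied `gumbel_alphas` array or draws one gumbel-softmax sample from `self.alphas` and freezes it, and the forward pass now returns one pooled, linearly classified output per cell for the layer-wise distillation loss. A rough numpy sketch of gumbel-softmax sampling, assuming the usual formulation rather than the repository's own `gumbel_softmax`:

```python
import numpy as np

def gumbel_softmax_np(alphas, temperature=1.0, eps=1e-10):
    # add Gumbel(0, 1) noise to the logits, then take a temperature-scaled softmax
    u = np.random.uniform(size=np.shape(alphas))
    g = -np.log(-np.log(u + eps) + eps)
    x = (np.asarray(alphas) + g) / temperature
    x = np.exp(x - x.max(axis=-1, keepdims=True))
    return x / x.sum(axis=-1, keepdims=True)

# "fixed gumbel" mode: sample once and reuse the same weights at every step
fixed = gumbel_softmax_np(np.random.randn(2, 5))
print(fixed.shape, fixed.sum(axis=-1))  # (2, 5); each row sums to 1
```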
......@@ -101,6 +101,9 @@ class BERTClassifier(Layer):
"You should load pretrained model for training this teacher model."
)
def emb_names(self):
return self.cls_model.emb_names()
def forward(self, input):
return self.cls_model(input)
......
......@@ -122,6 +122,10 @@ class BertModelLayer(Layer):
postprocess_cmd="dan",
param_initializer=self._param_initializer)
def emb_names(self):
return self._src_emb.parameters() + self._pos_emb.parameters(
) + self._sent_emb.parameters()
def forward(self, src_ids, position_ids, sentence_ids, input_mask):
"""
forward
......
......@@ -64,6 +64,9 @@ class ClsModelLayer(Layer):
fc = self.add_sublayer("cls_fc_%d" % i, fc)
self.cls_fc.append(fc)
def emb_names(self):
return self.bert_layer.emb_names()
def forward(self, data_ids):
"""
forward
......
......@@ -44,6 +44,10 @@ class DataProcessor(object):
self.num_examples = {'train': -1, 'dev': -1, 'test': -1}
self.current_train_epoch = -1
def get_train_aug_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
......@@ -110,7 +114,7 @@ class DataProcessor(object):
def get_num_examples(self, phase):
"""Get number of examples for train, dev or test."""
if phase not in ['train', 'dev', 'test']:
if phase not in ['train', 'dev', 'test', 'train_aug']:
raise ValueError(
"Unknown phase, which should be in ['train', 'dev', 'test'].")
return self.num_examples[phase]
......@@ -138,6 +142,9 @@ class DataProcessor(object):
if phase == 'train':
examples = self.get_train_examples(self.data_dir)
self.num_examples['train'] = len(examples)
elif phase == 'train_aug':
examples = self.get_train_aug_examples(self.data_dir)
self.num_examples['train'] = len(examples)
elif phase == 'dev':
examples = self.get_dev_examples(self.data_dir)
self.num_examples['dev'] = len(examples)
......@@ -337,6 +344,11 @@ class XnliProcessor(DataProcessor):
class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version)."""
def get_train_aug_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train_aug.tsv")), "train")
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
......