Unverified commit 7a529cf1 authored by whs, committed by GitHub

Add distillation for BERT (#330)

* Fix import in py3.

* Update bert distillation.

* Add option to fix embedding in bert distillation.

* Fix distillation in AdaBERT searching.

* Remove unused code.
Parent 491c3489
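Summary of the distillation objective this commit adds to AdaBERTClassifier.loss: each intermediate student classifier is matched to a teacher layer, its soft cross-entropy against the temperature-scaled (T) teacher distribution is weighted by a softmax over the negative per-layer teacher losses, and the result is mixed with the hard-label loss via self._gamma (defined elsewhere in the class). A minimal numpy sketch of that objective; names, shapes, and the gamma value are illustrative, not the framework API:

import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def distill_loss(s_logits, t_logits, t_losses, labels, T=5.0, gamma=0.5):
    # s_logits / t_logits: lists of [batch, n_classes] arrays (one per layer);
    # t_losses: scalar loss per teacher layer; gamma is a placeholder mixing weight.
    mapping = [int(np.ceil(i * (float(len(t_logits)) / len(s_logits))))
               for i in range(len(s_logits))]
    # weight each student layer by a softmax over the negative losses of its matched teacher layer
    w = softmax(-np.array([t_losses[j] for j in mapping]), axis=0)
    kd = 0.0
    for i, j in enumerate(mapping):
        t_probs = softmax(t_logits[j] / T)   # softened teacher targets
        s_probs = softmax(s_logits[i])
        kd += w[i] * np.mean(-(t_probs * np.log(s_probs)).sum(axis=1))
    onehot = np.eye(s_logits[-1].shape[1])[labels]
    ce = np.mean(-(onehot * np.log(softmax(s_logits[-1]))).sum(axis=1))
    return (1 - gamma) * ce + gamma * kd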
@@ -52,7 +52,10 @@ class AdaBERTClassifier(Layer):
                  search_layer=False,
                  teacher_model=None,
                  data_dir=None,
-                 use_fixed_gumbel=False):
+                 use_fixed_gumbel=False,
+                 gumbel_alphas=None,
+                 fix_emb=False,
+                 t=5.0):
         super(AdaBERTClassifier, self).__init__()
         self._n_layer = n_layer
         self._num_labels = num_labels
@@ -65,21 +68,37 @@ class AdaBERTClassifier(Layer):
         self._teacher_model = teacher_model
         self._data_dir = data_dir
         self.use_fixed_gumbel = use_fixed_gumbel
-        #print(
-        #    "----------------------load teacher model and test----------------------------------------"
-        #)
-        #self.teacher = BERTClassifier(num_labels, model_path=self._teacher_model)
-        #self.teacher.test(self._data_dir)
-        #print(
-        #    "----------------------finish load teacher model and test----------------------------------------"
-        #)
+        self.T = t
+        print(
+            "----------------------load teacher model and test----------------------------------------"
+        )
+        self.teacher = BERTClassifier(
+            num_labels, model_path=self._teacher_model)
+        self.teacher.test(self._data_dir)
+        print(
+            "----------------------finish load teacher model and test----------------------------------------"
+        )
         self.student = BertModelLayer(
             n_layer=self._n_layer,
             emb_size=self._emb_size,
             hidden_size=self._hidden_size,
             conv_type=self._conv_type,
             search_layer=self._search_layer,
-            use_fixed_gumbel=self.use_fixed_gumbel)
+            use_fixed_gumbel=self.use_fixed_gumbel,
+            gumbel_alphas=gumbel_alphas)
+
+        for s_emb, t_emb in zip(self.student.emb_names(),
+                                self.teacher.emb_names()):
+            t_emb.stop_gradient = True
+            if fix_emb:
+                s_emb.stop_gradient = True
+            print(
+                "Assigning embedding[{}] from teacher to embedding[{}] in student.".
+                format(t_emb.name, s_emb.name))
+            fluid.layers.assign(input=t_emb, output=s_emb)
+            print(
+                "Assigned embedding[{}] from teacher to embedding[{}] in student.".
+                format(t_emb.name, s_emb.name))
 
         self.cls_fc = list()
         for i in range(self._n_layer):
@@ -107,6 +126,11 @@ class AdaBERTClassifier(Layer):
     def genotype(self):
         return self.arch_parameters()
 
+    def ce(self, logits):
+        logits = np.exp(logits - np.max(logits))
+        logits = logits / logits.sum(axis=0)
+        return logits
+
     def loss(self, data_ids):
         src_ids = data_ids[0]
         position_ids = data_ids[1]
@@ -114,13 +138,53 @@ class AdaBERTClassifier(Layer):
         input_mask = data_ids[3]
         labels = data_ids[4]
 
-        enc_output = self.student(
+        s_logits = self.student(
             src_ids, position_ids, sentence_ids, flops=[], model_size=[])
 
+        self.teacher.eval()
+        total_loss, t_logits, t_losses, accuracys, num_seqs = self.teacher(
+            data_ids)
+
+        # define kd loss
+        kd_losses = []
+        kd_weights = []
+        for i in range(len(s_logits)):
+            j = int(np.ceil(i * (float(len(t_logits)) / len(s_logits))))
+            kd_weights.append(t_losses[j].numpy())
+
+        kd_weights = np.array(kd_weights)
+        kd_weights = self.ce(-kd_weights)
+
+        s_probs = None
+        for i in range(len(s_logits)):
+            j = int(np.ceil(i * (float(len(t_logits)) / len(s_logits))))
+            t_logit = t_logits[j]
+            s_logit = s_logits[i]
+            t_logit.stop_gradient = True
+            t_probs = fluid.layers.softmax(t_logit / self.T)
+            s_probs = fluid.layers.softmax(s_logit)
+            #kd_loss = -t_probs * fluid.layers.log(s_probs)
+            kd_loss = fluid.layers.cross_entropy(
+                input=s_probs, label=t_probs, soft_label=True)
+            kd_loss = fluid.layers.reduce_sum(kd_loss, dim=1)
+            kd_loss = fluid.layers.mean(kd_loss)
+            # print("kd_loss[{}] = {}; kd_weights[{}] = {}".format(i, kd_loss.numpy(), i, kd_weights[i]))
+            # tmp = kd_loss * kd_weights[i]
+            tmp = fluid.layers.scale(kd_loss, scale=kd_weights[i])
+            # print("kd_loss[{}] = {}".format(i, tmp.numpy()))
+            kd_losses.append(tmp)
+        kd_loss = fluid.layers.sum(kd_losses)
+        # print("kd_loss = {}".format(kd_loss.numpy()))
+
         ce_loss, probs = fluid.layers.softmax_with_cross_entropy(
-            logits=enc_output, label=labels, return_softmax=True)
-        loss = fluid.layers.mean(x=ce_loss)
+            logits=s_logits[-1], label=labels, return_softmax=True)
+        ce_loss = fluid.layers.mean(x=ce_loss)
         num_seqs = fluid.layers.create_tensor(dtype='int64')
         accuracy = fluid.layers.accuracy(
             input=probs, label=labels, total=num_seqs)
-        return loss, accuracy
+
+        loss = (1 - self._gamma) * ce_loss + self._gamma * kd_loss
+        # return ce_loss, accuracy, None, None
+        return loss, accuracy, ce_loss, kd_loss
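A quick worked example of the student-to-teacher layer mapping used in the loop above, assuming 12 teacher logits and 4 student logits (hypothetical depths, not values from this diff):

import numpy as np

n_t, n_s = 12, 4  # hypothetical teacher/student classifier counts
mapping = [int(np.ceil(i * (float(n_t) / n_s))) for i in range(n_s)]
print(mapping)  # [0, 3, 6, 9]: student layers 0..3 distill from teacher layers 0, 3, 6, 9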
@@ -43,7 +43,8 @@ class BertModelLayer(Layer):
                  conv_type="conv_bn",
                  search_layer=False,
                  use_fp16=False,
-                 use_fixed_gumbel=False):
+                 use_fixed_gumbel=False,
+                 gumbel_alphas=None):
         super(BertModelLayer, self).__init__()
 
         self._emb_size = emb_size
@@ -56,9 +57,9 @@ class BertModelLayer(Layer):
         self.use_fixed_gumbel = use_fixed_gumbel
 
-        self._word_emb_name = "word_embedding"
-        self._pos_emb_name = "pos_embedding"
-        self._sent_emb_name = "sent_embedding"
+        self._word_emb_name = "s_word_embedding"
+        self._pos_emb_name = "s_pos_embedding"
+        self._sent_emb_name = "s_sent_embedding"
         self._dtype = "float16" if use_fp16 else "float32"
 
         self._conv_type = conv_type
@@ -93,7 +94,12 @@ class BertModelLayer(Layer):
             n_layer=self._n_layer,
             hidden_size=self._hidden_size,
             search_layer=self._search_layer,
-            use_fixed_gumbel=self.use_fixed_gumbel)
+            use_fixed_gumbel=self.use_fixed_gumbel,
+            gumbel_alphas=gumbel_alphas)
+
+    def emb_names(self):
+        return self._src_emb.parameters() + self._pos_emb.parameters(
+        ) + self._sent_emb.parameters()
 
     def max_flops(self):
         return self._encoder.max_flops
@@ -152,7 +158,6 @@ class BertModelLayer(Layer):
         emb_out_1 = self._emb_fac(src_emb_1)
         # (bs, seq_len, 768)
 
-        enc_output = self._encoder(
-            emb_out_0, emb_out_1, flops=flops, model_size=model_size)
-
-        return enc_output
+        enc_outputs = self._encoder(
+            emb_out, flops=flops, model_size=model_size)
+        return enc_outputs
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 import numpy as np
+from collections import Iterable
 
 import paddle
 import paddle.fluid as fluid
@@ -81,7 +82,13 @@ class MixedOp(fluid.dygraph.Layer):
         # return out
 
         for i in range(len(self._ops)):
-            if weights[i].numpy() != 0:
+
+            if isinstance(weights, Iterable):
+                weights_i = weights[i]
+            else:
+                weights_i = weights[i].numpy()
+
+            if weights_i != 0:
                 return self._ops[i](x) * weights[i]
@@ -212,7 +219,8 @@ class EncoderLayer(Layer):
                  hidden_size=768,
                  name="encoder",
                  search_layer=True,
-                 use_fixed_gumbel=False):
+                 use_fixed_gumbel=False,
+                 gumbel_alphas=None):
         super(EncoderLayer, self).__init__()
         self._n_layer = n_layer
         self._hidden_size = hidden_size
@@ -259,26 +267,38 @@ class EncoderLayer(Layer):
         #     dtype="float32",
         #     default_initializer=NormalInitializer(
         #         loc=0.0, scale=1e-3))
 
-        self.BN = BatchNorm(
-            num_channels=self._n_channel,
-            param_attr=fluid.ParamAttr(
-                initializer=fluid.initializer.Constant(value=1),
-                trainable=False),
-            bias_attr=fluid.ParamAttr(
-                initializer=fluid.initializer.Constant(value=0),
-                trainable=False))
-
         self.pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
 
-        self.out = Linear(
-            self._n_channel,
-            3,
-            param_attr=ParamAttr(initializer=MSRA()),
-            bias_attr=ParamAttr(initializer=MSRA()))
+        self.bns = []
+        self.outs = []
+        for i in range(self._n_layer):
+            bn = BatchNorm(
+                num_channels=self._n_channel,
+                param_attr=fluid.ParamAttr(
+                    initializer=fluid.initializer.Constant(value=1),
+                    trainable=False),
+                bias_attr=fluid.ParamAttr(
+                    initializer=fluid.initializer.Constant(value=0),
+                    trainable=False))
+            self.bns.append(bn)
+
+            out = Linear(
+                self._n_channel,
+                3,
+                param_attr=ParamAttr(initializer=MSRA()),
+                bias_attr=ParamAttr(initializer=MSRA()))
+            self.outs.append(out)
 
         self.use_fixed_gumbel = use_fixed_gumbel
-        self.gumbel_alphas = gumbel_softmax(self.alphas).detach()
+        self.gumbel_alphas = gumbel_softmax(self.alphas)
+        if gumbel_alphas is not None:
+            self.gumbel_alphas = np.array(gumbel_alphas).reshape(
+                self.alphas.shape)
+        else:
+            self.gumbel_alphas = gumbel_softmax(self.alphas)
+        self.gumbel_alphas.stop_gradient = True
+        print("gumbel_alphas: {}".format(self.gumbel_alphas))
 
     def forward(self, enc_input_0, enc_input_1, flops=[], model_size=[]):
         alphas = self.gumbel_alphas if self.use_fixed_gumbel else gumbel_softmax(
@@ -293,14 +313,19 @@ class EncoderLayer(Layer):
         s0 = self.stem(s0)
         s1 = self.stem(s1)
         # (bs, n_channel, seq_len, 1)
 
+        if self.use_fixed_gumbel:
+            alphas = self.gumbel_alphas
+        else:
+            alphas = gumbel_softmax(self.alphas)
+
+        s0 = s1 = tmp
+        outputs = []
         for i in range(self._n_layer):
             s0, s1 = s1, self._cells[i](s0, s1, alphas)
-            # (bs, n_channel, seq_len, 1)
-
-        s1 = self.BN(s1)
-
-        outputs = self.pool2d_avg(s1)
-        outputs = fluid.layers.reshape(outputs, shape=[-1, 0])
-        outputs = self.out(outputs)
+            tmp = self.bns[i](s1)
+            tmp = self.pool2d_avg(tmp)
+            # (bs, n_channel, seq_len, 1)
+            tmp = fluid.layers.reshape(tmp, shape=[-1, 0])
+            tmp = self.outs[i](tmp)
+            outputs.append(tmp)
 
         return outputs
@@ -101,6 +101,9 @@ class BERTClassifier(Layer):
                 "You should load pretrained model for training this teacher model."
             )
 
+    def emb_names(self):
+        return self.cls_model.emb_names()
+
     def forward(self, input):
         return self.cls_model(input)
......
@@ -122,6 +122,10 @@ class BertModelLayer(Layer):
             postprocess_cmd="dan",
             param_initializer=self._param_initializer)
 
+    def emb_names(self):
+        return self._src_emb.parameters() + self._pos_emb.parameters(
+        ) + self._sent_emb.parameters()
+
     def forward(self, src_ids, position_ids, sentence_ids, input_mask):
         """
         forward
......
@@ -64,6 +64,9 @@ class ClsModelLayer(Layer):
             fc = self.add_sublayer("cls_fc_%d" % i, fc)
             self.cls_fc.append(fc)
 
+    def emb_names(self):
+        return self.bert_layer.emb_names()
+
     def forward(self, data_ids):
         """
         forward
......
@@ -44,6 +44,10 @@ class DataProcessor(object):
         self.num_examples = {'train': -1, 'dev': -1, 'test': -1}
         self.current_train_epoch = -1
 
+    def get_train_aug_examples(self, data_dir):
+        """Gets a collection of `InputExample`s for the train set."""
+        raise NotImplementedError()
+
     def get_train_examples(self, data_dir):
         """Gets a collection of `InputExample`s for the train set."""
         raise NotImplementedError()
@@ -110,7 +114,7 @@ class DataProcessor(object):
     def get_num_examples(self, phase):
         """Get number of examples for train, dev or test."""
-        if phase not in ['train', 'dev', 'test']:
+        if phase not in ['train', 'dev', 'test', 'train_aug']:
             raise ValueError(
                 "Unknown phase, which should be in ['train', 'dev', 'test'].")
         return self.num_examples[phase]
@@ -138,6 +142,9 @@ class DataProcessor(object):
         if phase == 'train':
             examples = self.get_train_examples(self.data_dir)
             self.num_examples['train'] = len(examples)
+        elif phase == 'train_aug':
+            examples = self.get_train_aug_examples(self.data_dir)
+            self.num_examples['train'] = len(examples)
         elif phase == 'dev':
             examples = self.get_dev_examples(self.data_dir)
             self.num_examples['dev'] = len(examples)
@@ -337,6 +344,11 @@ class XnliProcessor(DataProcessor):
 class MnliProcessor(DataProcessor):
     """Processor for the MultiNLI data set (GLUE version)."""
 
+    def get_train_aug_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(
+            self._read_tsv(os.path.join(data_dir, "train_aug.tsv")), "train")
+
     def get_train_examples(self, data_dir):
         """See base class."""
         return self._create_examples(
......