Unverified · Commit 3d1bd8ae authored by Bai Yifan, committed by GitHub

Merge branch 'develop' into fix_conflicts

@@ -53,7 +53,10 @@ class AdaBERTClassifier(Layer):
                  search_layer=False,
                  teacher_model=None,
                  data_dir=None,
-                 use_fixed_gumbel=False):
+                 use_fixed_gumbel=False,
+                 gumbel_alphas=None,
+                 fix_emb=False,
+                 t=5.0):
         super(AdaBERTClassifier, self).__init__()
         self._n_layer = n_layer
         self._num_labels = num_labels
@@ -66,7 +69,8 @@ class AdaBERTClassifier(Layer):
         self._teacher_model = teacher_model
         self._data_dir = data_dir
         self.use_fixed_gumbel = use_fixed_gumbel
-        self.T = 1.0
+        self.T = t
         print(
             "----------------------load teacher model and test----------------------------------------"
         )
@@ -74,7 +78,7 @@ class AdaBERTClassifier(Layer):
             num_labels, model_path=self._teacher_model)
         # global setting, will be overwritten when training(about 1% acc loss)
         self.teacher.eval()
-        #self.teacher.test(self._data_dir)
+        self.teacher.test(self._data_dir)
         print(
             "----------------------finish load teacher model and test----------------------------------------"
         )
@@ -84,7 +88,21 @@ class AdaBERTClassifier(Layer):
             hidden_size=self._hidden_size,
             conv_type=self._conv_type,
             search_layer=self._search_layer,
-            use_fixed_gumbel=self.use_fixed_gumbel)
+            use_fixed_gumbel=self.use_fixed_gumbel,
+            gumbel_alphas=gumbel_alphas)
+
+        for s_emb, t_emb in zip(self.student.emb_names(),
+                                self.teacher.emb_names()):
+            t_emb.stop_gradient = True
+            if fix_emb:
+                s_emb.stop_gradient = True
+            print(
+                "Assigning embedding[{}] from teacher to embedding[{}] in student.".
+                format(t_emb.name, s_emb.name))
+            fluid.layers.assign(input=t_emb, output=s_emb)
+            print(
+                "Assigned embedding[{}] from teacher to embedding[{}] in student.".
+                format(t_emb.name, s_emb.name))
 
         fix_emb = False
         for s_emb, t_emb in zip(self.student.emb_names(),
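For readers skimming this hunk: the added loop warm-starts the student's embedding parameters from the teacher's and optionally freezes them. A minimal standalone sketch of that pattern in Paddle 1.x dygraph (the sizes and variable names below are illustrative, not the repo's):

    import paddle.fluid as fluid
    from paddle.fluid.dygraph import Embedding

    with fluid.dygraph.guard():
        teacher_emb = Embedding(size=[100, 8])  # illustrative [vocab, dim]
        student_emb = Embedding(size=[100, 8])
        t_param = teacher_emb.parameters()[0]
        s_param = student_emb.parameters()[0]
        t_param.stop_gradient = True            # never update the teacher side
        fluid.layers.assign(input=t_param, output=s_param)  # copy weights over
        s_param.stop_gradient = True            # the fix_emb=True case above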
@@ -155,3 +173,4 @@ class AdaBERTClassifier(Layer):
         total_loss = (1 - self._gamma) * ce_loss + self._gamma * kd_loss
         return total_loss, accuracy, ce_loss, kd_loss, s_logits
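The loss in the last hunk mixes a hard-label cross-entropy term and a distillation term with weight `self._gamma`, and the new `t` argument is the usual distillation temperature. A minimal sketch of that combination in fluid, assuming plain logits-based distillation (the helper below is illustrative, not the repo's loss implementation):

    import paddle.fluid as fluid

    def distill_loss(s_logits, t_logits, label, t=5.0, gamma=0.5):
        # Soft targets: teacher probabilities softened by temperature t.
        t_probs = fluid.layers.softmax(t_logits / t)
        t_probs.stop_gradient = True
        s_log_probs = fluid.layers.log(fluid.layers.softmax(s_logits / t))
        # KD term: cross entropy against the softened teacher distribution.
        kd_loss = fluid.layers.reduce_mean(
            fluid.layers.reduce_sum(-t_probs * s_log_probs, dim=1))
        # CE term: ordinary cross entropy on the hard labels.
        ce_loss = fluid.layers.reduce_mean(
            fluid.layers.softmax_with_cross_entropy(s_logits, label))
        return (1 - gamma) * ce_loss + gamma * kd_loss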
@@ -43,7 +43,8 @@ class BertModelLayer(Layer):
             conv_type="conv_bn",
             search_layer=False,
             use_fp16=False,
-            use_fixed_gumbel=False):
+            use_fixed_gumbel=False,
+            gumbel_alphas=None):
         super(BertModelLayer, self).__init__()
         self._emb_size = emb_size
@@ -93,7 +94,12 @@ class BertModelLayer(Layer):
             n_layer=self._n_layer,
             hidden_size=self._hidden_size,
             search_layer=self._search_layer,
-            use_fixed_gumbel=self.use_fixed_gumbel)
+            use_fixed_gumbel=self.use_fixed_gumbel,
+            gumbel_alphas=gumbel_alphas)
+
+    def emb_names(self):
+        return self._src_emb.parameters() + self._pos_emb.parameters(
+        ) + self._sent_emb.parameters()
 
     def emb_names(self):
         return self._src_emb.parameters() + self._pos_emb.parameters(
......
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 import numpy as np
+from collections import Iterable
 
 import paddle
 import paddle.fluid as fluid
@@ -203,7 +204,8 @@ class EncoderLayer(Layer):
                  hidden_size=768,
                  name="encoder",
                  search_layer=True,
-                 use_fixed_gumbel=False):
+                 use_fixed_gumbel=False,
+                 gumbel_alphas=None):
         super(EncoderLayer, self).__init__()
         self._n_layer = n_layer
         self._hidden_size = hidden_size
@@ -260,8 +262,8 @@ class EncoderLayer(Layer):
             default_initializer=NormalInitializer(
                 loc=0.0, scale=1e-3))
 
         self.pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
 
         self.bns = []
         self.outs = []
         for i in range(self._n_layer):
@@ -303,6 +305,21 @@ class EncoderLayer(Layer):
             self.outs.append(out)
 
         self.use_fixed_gumbel = use_fixed_gumbel
-        self.gumbel_alphas = gumbel_softmax(self.alphas)
+        if gumbel_alphas is not None:
+            self.gumbel_alphas = np.array(gumbel_alphas).reshape(
+                self.alphas.shape)
+        else:
+            self.gumbel_alphas = gumbel_softmax(self.alphas)
+        self.gumbel_alphas.stop_gradient = True
+        print("gumbel_alphas: {}".format(self.gumbel_alphas))
 
-    def forward(self, enc_input_0, enc_input_1, epoch, flops=[],
-                model_size=[]):
+    def forward(self, enc_input_0, enc_input_1, flops=[], model_size=[]):
         alphas = self.gumbel_alphas if self.use_fixed_gumbel else gumbel_softmax(
             self.alphas, epoch)
......
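For context on `use_fixed_gumbel`: the constructor now samples (or accepts) one Gumbel-softmax draw over the architecture parameters and reuses it on every forward pass instead of resampling each step. A minimal NumPy sketch of the standard Gumbel-softmax, assuming the usual formulation (this is not the repo's `gumbel_softmax` helper, which also takes an epoch-dependent temperature):

    import numpy as np

    def gumbel_softmax_np(alphas, temperature=1.0, eps=1e-10):
        # Add Gumbel(0, 1) noise to the logits, then apply a
        # temperature-scaled softmax over the last axis.
        u = np.random.uniform(0.0, 1.0, size=np.shape(alphas))
        gumbel = -np.log(-np.log(u + eps) + eps)
        y = (np.asarray(alphas) + gumbel) / temperature
        y = y - y.max(axis=-1, keepdims=True)   # for numerical stability
        e = np.exp(y)
        return e / e.sum(axis=-1, keepdims=True)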
@@ -100,17 +100,32 @@ class DARTSearch(object):
     def train_one_epoch(self, train_loader, valid_loader, architect, optimizer,
                         epoch):
         objs = AvgrageMeter()
-        ce_losses = AvgrageMeter()
-        kd_losses = AvgrageMeter()
-        e_losses = AvgrageMeter()
+        top1 = AvgrageMeter()
+        top5 = AvgrageMeter()
         self.model.train()
 
-        step_id = 0
-        for train_data, valid_data in zip(train_loader(), valid_loader()):
+        for step_id, (
+                train_data,
+                valid_data) in enumerate(zip(train_loader(), valid_loader())):
+            train_image, train_label = train_data
+            valid_image, valid_label = valid_data
+            train_image = to_variable(train_image)
+            train_label = to_variable(train_label)
+            train_label.stop_gradient = True
+            valid_image = to_variable(valid_image)
+            valid_label = to_variable(valid_label)
+            valid_label.stop_gradient = True
+            n = train_image.shape[0]
 
             if epoch >= self.epochs_no_archopt:
-                architect.step(train_data, valid_data)
+                architect.step(train_image, train_label, valid_image,
+                               valid_label)
 
-            loss, ce_loss, kd_loss, e_loss = self.model.loss(train_data)
+            logits = self.model(train_image)
+            prec1 = fluid.layers.accuracy(input=logits, label=train_label, k=1)
+            prec5 = fluid.layers.accuracy(input=logits, label=train_label, k=5)
+            loss = fluid.layers.reduce_mean(
+                fluid.layers.softmax_with_cross_entropy(logits, train_label))
 
             if self.use_data_parallel:
                 loss = self.model.scale_loss(loss)
@@ -122,22 +137,18 @@ class DARTSearch(object):
             optimizer.minimize(loss)
             self.model.clear_gradients()
 
-            batch_size = train_data[0].shape[0]
-            objs.update(loss.numpy(), batch_size)
-            ce_losses.update(ce_loss.numpy(), batch_size)
-            kd_losses.update(kd_loss.numpy(), batch_size)
-            e_losses.update(e_loss.numpy(), batch_size)
+            objs.update(loss.numpy(), n)
+            top1.update(prec1.numpy(), n)
+            top5.update(prec5.numpy(), n)
 
             if step_id % self.log_freq == 0:
                 #logger.info("Train Epoch {}, Step {}, loss {:.6f}; ce: {:.6f}; kd: {:.6f}; e: {:.6f}".format(
                 #    epoch, step_id, objs.avg[0], ce_losses.avg[0], kd_losses.avg[0], e_losses.avg[0]))
                 logger.info(
-                    "Train Epoch {}, Step {}, loss {}; ce: {}; kd: {}; e: {}".
-                    format(epoch, step_id,
-                           loss.numpy(),
-                           ce_loss.numpy(), kd_loss.numpy(), e_loss.numpy()))
-            step_id += 1
-        return objs.avg[0]
+                    "Train Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
+                    format(epoch, step_id, objs.avg[0], top1.avg[0], top5.avg[
+                        0]))
+        return top1.avg[0]
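The `objs`/`top1`/`top5` meters above accumulate batch-size-weighted running averages, which is why every `update` call passes `n`. A minimal sketch of such a meter, assuming an implementation along the lines of the DARTS reference code (the spelling `AvgrageMeter` follows the source):

    class AvgrageMeter(object):
        """Running weighted average, updated with (value, batch_size)."""

        def __init__(self):
            self.sum = 0.0
            self.cnt = 0
            self.avg = 0.0

        def update(self, val, n=1):
            # val may be a scalar or a 1-element numpy array, as here,
            # in which case avg is also an array and is read as avg[0].
            self.sum += val * n
            self.cnt += n
            self.avg = self.sum / self.cnt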
@@ -145,7 +156,7 @@ class DARTSearch(object):
     def valid_one_epoch(self, valid_loader, epoch):
         objs = AvgrageMeter()
         top5 = AvgrageMeter()
         self.model.eval()
 
-        for step_id, valid_data in enumerate(valid_loader):
+        for step_id, (image, label) in enumerate(valid_loader):
             image = to_variable(image)
             label = to_variable(label)
             n = image.shape[0]
@@ -235,12 +246,14 @@ class DARTSearch(object):
             genotype = get_genotype(base_model)
             logger.info('genotype = %s', genotype)
 
-            self.train_one_epoch(train_loader, valid_loader, architect,
-                                 optimizer, epoch)
+            train_top1 = self.train_one_epoch(train_loader, valid_loader,
+                                              architect, optimizer, epoch)
+            logger.info("Epoch {}, train_acc {:.6f}".format(epoch, train_top1))
 
             if epoch == self.num_epochs - 1:
-                # valid_top1 = self.valid_one_epoch(valid_loader, epoch)
-                logger.info("Epoch {}, valid_acc {:.6f}".format(epoch, 1))
+                valid_top1 = self.valid_one_epoch(valid_loader, epoch)
+                logger.info("Epoch {}, valid_acc {:.6f}".format(epoch,
+                                                                valid_top1))
 
             if save_parameters:
                 fluid.save_dygraph(
                     self.model.state_dict(),
......
@@ -542,7 +542,7 @@ class depthwise_conv2d(PruneWorker):
         self._visit(filter_var, 0)
 
         new_groups = filter_var.shape()[0] - len(pruned_idx)
-        op.set_attr("groups", new_groups)
+        self.op.set_attr("groups", new_groups)
 
         for op in filter_var.outputs():
             self._prune_op(op, filter_var, 0, pruned_idx)
......
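The one-line fix above sets the new `groups` on the pruned op itself (`self.op`) rather than on a local `op` name that is only bound by the following loop. It matters because a depthwise conv2d requires `groups` to equal the number of output filters, so pruning filters must shrink `groups` in lockstep; a minimal sketch of that invariant:

    def pruned_depthwise_groups(filter_shape, pruned_idx):
        # Depthwise conv2d: groups == number of output filters, so
        # removing filters must reduce groups by the same count.
        return filter_shape[0] - len(pruned_idx)

    assert pruned_depthwise_groups([32, 1, 3, 3], pruned_idx=[0, 7]) == 30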
@@ -45,6 +45,10 @@ class DataProcessor(object):
         self.num_examples = {'train': -1, 'dev': -1, 'test': -1}
         self.current_train_epoch = -1
 
+    def get_train_aug_examples(self, data_dir):
+        """Gets a collection of `InputExample`s for the train set."""
+        raise NotImplementedError()
+
     def get_train_examples(self, data_dir):
         """Gets a collection of `InputExample`s for the train set."""
         raise NotImplementedError()
@@ -111,9 +115,9 @@ class DataProcessor(object):
     def get_num_examples(self, phase):
         """Get number of examples for train, dev or test."""
-        #if phase not in ['train', 'dev', 'test']:
-        #    raise ValueError(
-        #        "Unknown phase, which should be in ['train', 'dev', 'test'].")
+        if phase not in ['train', 'dev', 'test', 'train_aug']:
+            raise ValueError(
+                "Unknown phase, which should be in ['train', 'dev', 'test'].")
         return self.num_examples[phase]
 
     def get_train_progress(self):
@@ -141,6 +145,9 @@ class DataProcessor(object):
         if phase == 'train':
             examples = self.get_train_examples(self.data_dir)
             self.num_examples['train'] = len(examples)
+        elif phase == 'train_aug':
+            examples = self.get_train_aug_examples(self.data_dir)
+            self.num_examples['train'] = len(examples)
         elif phase == 'dev':
             examples = self.get_dev_examples(self.data_dir)
             self.num_examples['dev'] = len(examples)
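A self-contained sketch of the phase dispatch added above (a stand-in, not the repo's `DataProcessor`): a `train_aug` phase loads a different example source while, as in the hunk, reusing the `'train'` counter:

    class PhaseDispatcher(object):
        def __init__(self):
            self.num_examples = {'train': -1, 'dev': -1, 'test': -1}

        def _load(self, phase):
            # Hypothetical in-memory sources; the real code reads TSV files.
            return {
                'train': ['a', 'b', 'c'],
                'train_aug': ['a', 'b', 'c', 'a2', 'b2'],
                'dev': ['d'],
                'test': ['e'],
            }[phase]

        def examples_for(self, phase):
            examples = self._load(phase)
            # As in the diff, 'train_aug' updates the 'train' count.
            key = 'train' if phase == 'train_aug' else phase
            self.num_examples[key] = len(examples)
            return examples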
@@ -377,6 +384,11 @@ class XnliProcessor(DataProcessor):
 class MnliProcessor(DataProcessor):
     """Processor for the MultiNLI data set (GLUE version)."""
 
+    def get_train_aug_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(
+            self._read_tsv(os.path.join(data_dir, "train_aug.tsv")), "train")
+
     def get_train_examples(self, data_dir):
         """See base class."""
         return self._create_examples(
......
@@ -33,7 +33,7 @@ class TestPrune(unittest.TestCase):
         sum2 = conv4 + sum1
         conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
         conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
-        self.assertTrue(1597440 == flops(main_program))
+        self.assertTrue(792576 == flops(main_program))
 
 if __name__ == '__main__':
......
@@ -57,7 +57,7 @@ class TestPrune(unittest.TestCase):
         conv_op = graph.var("conv4_weights").outputs()[0]
         walker = conv2d_walker(conv_op, [])
         walker.prune(graph.var("conv4_weights"), pruned_axis=0, pruned_idx=[])
-        print walker.pruned_params
+        print(walker.pruned_params)
 
 if __name__ == '__main__':
......