Commit f7e57cb6 authored by baiyfbupt

add mrpc exp

Parent 5e416847
......
@@ -81,20 +81,33 @@ def main():
     BERT_BASE_PATH = "./data/pretrained_models/uncased_L-12_H-768_A-12"
     vocab_path = BERT_BASE_PATH + "/vocab.txt"
     data_dir = "./data/glue_data/MNLI/"
     teacher_model_dir = "./data/teacher_model/steps_23000"
     do_lower_case = True
     num_samples = 392702
     # augmented dataset nums
     # num_samples = 8016987
     max_seq_len = 128
     batch_size = 192
     hidden_size = 768
     emb_size = 768
     max_layer = 8
     epoch = 80
     log_freq = 10
+    task_name = 'mnli'
+    if task_name == 'mrpc':
+        data_dir = "./data/glue_data/MRPC/"
+        teacher_model_dir = "./data/teacher_model/mrpc"
+        num_samples = 3668
+        max_layer = 4
+        processor_func = MrpcProcessor
+    elif task_name == 'mnli':
+        data_dir = "./data/glue_data/MNLI/"
+        teacher_model_dir = "./data/teacher_model/step_23000"
+        num_samples = 392702
+        max_layer = 8
+        processor_func = MnliProcessor
     device_num = fluid.dygraph.parallel.Env().nranks
     use_fixed_gumbel = True
     train_phase = "train"
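
The added branch hard-codes one configuration per GLUE task: MRPC (3,668 training pairs, a 4-layer student) and MNLI (392,702 pairs, an 8-layer student). A minimal sketch of the same dispatch as a lookup table follows; it is a hypothetical refactor, not part of the commit, and assumes only the task names and the MrpcProcessor/MnliProcessor reader classes already used in this script:

# Hypothetical table form of the per-task branch above; same data, not in the commit.
TASK_CONFIGS = {
    "mrpc": dict(
        data_dir="./data/glue_data/MRPC/",
        teacher_model_dir="./data/teacher_model/mrpc",
        num_samples=3668,      # MRPC training-set size
        max_layer=4,
        processor_func=MrpcProcessor),
    "mnli": dict(
        data_dir="./data/glue_data/MNLI/",
        teacher_model_dir="./data/teacher_model/step_23000",
        num_samples=392702,    # MNLI training-set size
        max_layer=8,
        processor_func=MnliProcessor),
}
cfg = TASK_CONFIGS[task_name]
data_dir, num_samples = cfg["data_dir"], cfg["num_samples"]
teacher_model_dir, max_layer = cfg["teacher_model_dir"], cfg["max_layer"]
processor_func = cfg["processor_func"]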
......
@@ -107,9 +120,10 @@ def main():
     np.random.seed(1)
     fluid.default_main_program().random_seed = 1
     model = AdaBERTClassifier(
-        3,
+        2,
         n_layer=max_layer,
         hidden_size=hidden_size,
+        task_name=task_name,
         emb_size=emb_size,
         teacher_model=teacher_model_dir,
         data_dir=data_dir,
......
@@ -130,7 +144,7 @@ def main():
         regularization=fluid.regularizer.L2DecayRegularizer(3e-4),
         parameter_list=model_parameters)
-    processor = MnliProcessor(
+    processor = processor_func(
         data_dir=data_dir,
         vocab_path=vocab_path,
         max_seq_len=max_seq_len,
......
@@ -172,12 +186,16 @@ def main():
         strategy = fluid.dygraph.parallel.prepare_context()
         model = fluid.dygraph.parallel.DataParallel(model, strategy)
+    best_valid_acc = 0
     for epoch_id in range(epoch):
         train_one_epoch(model, train_loader, optimizer, epoch_id,
                         use_data_parallel, log_freq)
         loss, acc = valid_one_epoch(model, dev_loader, epoch_id, log_freq)
-        logger.info("dev set, ce_loss {:.6f}; acc: {:.6f};".format(loss,
-                                                                   acc))
+        if acc > best_valid_acc:
+            best_valid_acc = acc
+        logger.info(
+            "dev set, ce_loss {:.6f}; acc {:.6f}, best_acc {:.6f};".format(
+                loss, acc, best_valid_acc))
 
 if __name__ == '__main__':
......
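
The training loop now carries a running best_valid_acc so each epoch's log also reports the best dev accuracy seen so far. A common follow-up, sketched here purely as an assumption (the commit itself only logs the value), is to checkpoint the model whenever the best accuracy improves, using Paddle 1.x dygraph saving:

if acc > best_valid_acc:
    best_valid_acc = acc
    # Hypothetical checkpoint step, not in this commit: persist the
    # current parameters when the dev accuracy reaches a new best.
    fluid.save_dygraph(model.state_dict(), "./checkpoints/best_model")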
......
@@ -222,9 +222,9 @@ def main():
                                                                        acc))
         if use_data_parallel:
-            print(model.student._encoder.alphas.numpy())
-        else:
             print(model._layers.student._encoder.alphas.numpy())
+        else:
+            print(model.student._encoder.alphas.numpy())
         print("=" * 100)
......
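
The swap above fixes which branch unwraps the model: fluid.dygraph.parallel.DataParallel wraps the original Layer, so user-defined attributes such as student are only reachable through the wrapper's _layers field, while the unwrapped model exposes them directly. A minimal sketch of the access pattern, assuming the names shown in the diff:

# Plain dygraph Layer: custom attributes are accessed directly.
model = AdaBERTClassifier(2, n_layer=max_layer)
if use_data_parallel:
    strategy = fluid.dygraph.parallel.prepare_context()
    model = fluid.dygraph.parallel.DataParallel(model, strategy)
    # The wrapper forwards forward(), but not custom attributes,
    # so the inner Layer must be reached via _layers.
    alphas = model._layers.student._encoder.alphas
else:
    alphas = model.student._encoder.alphas
print(alphas.numpy())  # architecture weights learned by the search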
......
@@ -49,6 +49,7 @@ class AdaBERTClassifier(Layer):
                  hidden_size=768,
                  gamma=0.8,
                  beta=4,
+                 task_name='mnli',
                  conv_type="conv_bn",
                  search_layer=False,
                  teacher_model=None,
......
@@ -75,7 +76,7 @@ class AdaBERTClassifier(Layer):
             "----------------------load teacher model and test----------------------------------------"
         )
         self.teacher = BERTClassifier(
-            num_labels, model_path=self._teacher_model)
+            num_labels, task_name=task_name, model_path=self._teacher_model)
         # global setting, will be overwritten when training(about 1% acc loss)
         self.teacher.eval()
         self.teacher.test(self._data_dir)
......
@@ -83,6 +84,7 @@ class AdaBERTClassifier(Layer):
             "----------------------finish load teacher model and test----------------------------------------"
         )
         self.student = BertModelLayer(
+            num_labels=num_labels,
             n_layer=self._n_layer,
             emb_size=self._emb_size,
             hidden_size=self._hidden_size,
......
......
@@ -32,6 +32,7 @@ from .transformer_encoder import EncoderLayer
 class BertModelLayer(Layer):
     def __init__(self,
+                 num_labels,
                  emb_size=128,
                  hidden_size=768,
                  n_layer=12,
......
@@ -91,6 +92,7 @@ class BertModelLayer(Layer):
             param_attr=fluid.ParamAttr(name="s_emb_factorization"))
         self._encoder = EncoderLayer(
+            num_labels=num_labels,
             n_layer=self._n_layer,
             hidden_size=self._hidden_size,
             search_layer=self._search_layer,
......
......
@@ -200,6 +200,7 @@ class EncoderLayer(Layer):
     """
     def __init__(self,
+                 num_labels,
                  n_layer,
                  hidden_size=768,
                  name="encoder",
......
@@ -276,7 +277,7 @@ class EncoderLayer(Layer):
                     trainable=False))
             out = Linear(
                 self._n_channel,
-                3,
+                num_labels,
                 param_attr=ParamAttr(initializer=MSRA()),
                 bias_attr=ParamAttr(initializer=MSRA()))
             self.bns.append(bn)
......
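
Together, the last four hunks thread num_labels from AdaBERTClassifier through BertModelLayer down into EncoderLayer, replacing the hard-coded 3-way output (MNLI's entailment/neutral/contradiction) so that a 2-way task like MRPC gets a correctly sized head. A self-contained sketch of the resulting classification head, assuming Paddle 1.x dygraph APIs; n_channel's value here is assumed for illustration:

import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr

num_labels = 2   # 2 for MRPC (paraphrase or not), 3 for MNLI
n_channel = 768  # encoder output width; assumed for this sketch

with fluid.dygraph.guard():
    # Task-sized head as built in EncoderLayer: MSRA-initialized
    # weights and bias, output width set by the task's label count.
    out = Linear(
        n_channel,
        num_labels,
        param_attr=ParamAttr(initializer=MSRA()),
        bias_attr=ParamAttr(initializer=MSRA()))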