未验证 提交 ea7d592a 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Update run.py

上级 746ce70f
......@@ -22,28 +22,25 @@ if __name__ == '__main__':
train_slot = './data/atis/atis_slot/train.tsv'
train_intent = './data/atis/atis_intent/train.tsv'
predict_file = './data/atis/atis_slot/test.tsv'
save_path = './outputs/'
pred_output = './outputs/predict/'
save_type = 'ckpt'
pre_params = './pretrain/ERNIE-v2-en-base/params'
config = json.load(open('./pretrain/ERNIE-v2-en-base/ernie_config.json'))
input_dim = config['hidden_size']
# ----------------------- for training -----------------------
# step 1-1: create readers for training
# step 1-1: create readers
seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed)
cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, seed=random_seed)
# step 1-2: load the training data
# step 1-2: load train data
seq_label_reader.load_data(train_slot, file_format='tsv', num_epochs=None, batch_size=batch_size)
cls_reader.load_data(train_intent, batch_size=batch_size, num_epochs=None)
# step 2: create a backbone of the model to extract text features
ernie = palm.backbone.ERNIE.from_config(config)
# step 3: register the backbone in readers
# step 3: register readers with ernie backbone
seq_label_reader.register_with(ernie)
cls_reader.register_with(ernie)
......@@ -51,7 +48,7 @@ if __name__ == '__main__':
seq_label_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob)
cls_head = palm.head.Classify(num_classes_intent, input_dim, dropout_prob)
# step 5-1: create a task trainer
# step 5-1: create task trainers and multiHeadTrainer
trainer_seq_label = palm.Trainer("slot", mix_ratio=1.0)
trainer_cls = palm.Trainer("intent", mix_ratio=1.0)
trainer = palm.MultiHeadTrainer([trainer_seq_label, trainer_cls])
......@@ -60,23 +57,21 @@ if __name__ == '__main__':
loss2 = trainer_seq_label.build_forward(ernie, seq_label_head)
loss_var = trainer.build_forward()
# step 6-1*: use warmup
# step 6-1*: enable warmup for better fine-tuning
n_steps = seq_label_reader.num_examples * 1.5 * num_epochs // batch_size
warmup_steps = int(0.1 * n_steps)
sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps)
# step 6-2: create a optimizer
# step 6-2: build a optimizer
adam = palm.optimizer.Adam(loss_var, lr, sched)
# step 6-3: build backward
# step 6-3: build backward graph
trainer.build_backward(optimizer=adam, weight_decay=weight_decay)
# step 7: fit prepared reader and data
# step 7: fit readers to trainer
trainer.fit_readers_with_mixratio([seq_label_reader, cls_reader], "slot", num_epochs)
# step 8-1*: load pretrained parameters
trainer.load_pretrain(pre_params)
# step 8-2*: set saver to save model
save_steps = int(n_steps-batch_size) // 2
# save_steps = 10
trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
# step 8-1*: load pretrained model
trainer.load_pretrain('./pretrain/ERNIE-v2-en-base')
# step 8-2*: set saver to save models during training
trainer.set_saver(save_path='./outputs/', save_steps=300)
# step 8-3: start training
trainer.train(print_steps=print_steps)
trainer.train(print_steps=10)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册