diff --git a/examples/multi-task/run.py b/examples/multi-task/run.py index fb76f3de32fe7bd704c7a6f9fac71eb6c74697bf..b2e86c3061f84a444b232e6454ab1132e19f4a41 100644 --- a/examples/multi-task/run.py +++ b/examples/multi-task/run.py @@ -22,28 +22,25 @@ if __name__ == '__main__': train_slot = './data/atis/atis_slot/train.tsv' train_intent = './data/atis/atis_intent/train.tsv' predict_file = './data/atis/atis_slot/test.tsv' - save_path = './outputs/' pred_output = './outputs/predict/' - save_type = 'ckpt' - pre_params = './pretrain/ERNIE-v2-en-base/params' config = json.load(open('./pretrain/ERNIE-v2-en-base/ernie_config.json')) input_dim = config['hidden_size'] # ----------------------- for training ----------------------- - # step 1-1: create readers for training + # step 1-1: create readers seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed) cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, seed=random_seed) - # step 1-2: load the training data + # step 1-2: load train data seq_label_reader.load_data(train_slot, file_format='tsv', num_epochs=None, batch_size=batch_size) cls_reader.load_data(train_intent, batch_size=batch_size, num_epochs=None) # step 2: create a backbone of the model to extract text features ernie = palm.backbone.ERNIE.from_config(config) - # step 3: register the backbone in readers + # step 3: register readers with ernie backbone seq_label_reader.register_with(ernie) cls_reader.register_with(ernie) @@ -51,7 +48,7 @@ if __name__ == '__main__': seq_label_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob) cls_head = palm.head.Classify(num_classes_intent, input_dim, dropout_prob) - # step 5-1: create a task trainer + # step 5-1: create task trainers and multiHeadTrainer trainer_seq_label = palm.Trainer("slot", mix_ratio=1.0) trainer_cls = palm.Trainer("intent", mix_ratio=1.0) trainer = palm.MultiHeadTrainer([trainer_seq_label, trainer_cls]) @@ -60,23 +57,21 @@ if __name__ == '__main__': loss2 = trainer_seq_label.build_forward(ernie, seq_label_head) loss_var = trainer.build_forward() - # step 6-1*: use warmup + # step 6-1*: enable warmup for better fine-tuning n_steps = seq_label_reader.num_examples * 1.5 * num_epochs // batch_size warmup_steps = int(0.1 * n_steps) sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps) - # step 6-2: create a optimizer + # step 6-2: build a optimizer adam = palm.optimizer.Adam(loss_var, lr, sched) - # step 6-3: build backward + # step 6-3: build backward graph trainer.build_backward(optimizer=adam, weight_decay=weight_decay) - # step 7: fit prepared reader and data + # step 7: fit readers to trainer trainer.fit_readers_with_mixratio([seq_label_reader, cls_reader], "slot", num_epochs) - # step 8-1*: load pretrained parameters - trainer.load_pretrain(pre_params) - # step 8-2*: set saver to save model - save_steps = int(n_steps-batch_size) // 2 - # save_steps = 10 - trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type) + # step 8-1*: load pretrained model + trainer.load_pretrain('./pretrain/ERNIE-v2-en-base') + # step 8-2*: set saver to save models during training + trainer.set_saver(save_path='./outputs/', save_steps=300) # step 8-3: start training - trainer.train(print_steps=print_steps) + trainer.train(print_steps=10)