# coding=utf-8
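"""Train two classification heads jointly on a shared ERNIE backbone with
PaddlePALM's MultiHeadTrainer."""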
import paddlepalm as palm
import json

if __name__ == '__main__':

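    # Basic training hyper-parameters and the vocab file of the pretrained ERNIE model.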
    max_seqlen = 512
    batch_size = 4
    num_epochs = 2
    lr = 1e-3
    vocab_path = './pretrain/ernie/vocab.txt'

    train_file = './data/cls4mrqa/train.tsv'
    predict_file = './data/cls4mrqa/dev.tsv'

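    # Load the pretrained ERNIE config and build the shared backbone from it.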
    with open('./pretrain/ernie/ernie_config.json') as f:
        config = json.load(f)
    # ernie = palm.backbone.ERNIE(...)
    ernie = palm.backbone.ERNIE.from_config(config)

    # cls_reader2 = palm.reader.cls(train_file_topic, vocab_path, batch_size, max_seqlen)
    # cls_reader3 = palm.reader.cls(train_file_subj, vocab_path, batch_size, max_seqlen)
    # topic_trainer = palm.Trainer('topic_cls', cls_reader2, cls)
    # subj_trainer = palm.Trainer('subj_cls', cls_reader3, cls)

    # Create the reader for the classification task; its parameters control the
    # dataset input format, the number of input files, preprocessing rules, etc.
    cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen)
    cls_reader2 = palm.reader.ClassifyReader(vocab_path, max_seqlen)
    print(cls_reader.outputs_attr)
    # Different backbones require different features from a task reader. For a
    # classification task the basic input features are token_ids and label_ids, but
    # BERT-style backbones additionally need position, segment, input_mask and similar
    # features extracted from the input; after register_with, the reader automatically
    # supplies whatever fields the backbone requires.
    cls_reader.register_with(ernie)
    cls_reader2.register_with(ernie)
    print(cls_reader.outputs_attr)

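    # Load the training data; num_examples is only populated after load_data.
    # Both readers load the same training file here.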
    print("preparing data...")
    print(cls_reader.num_examples)
    cls_reader.load_data(train_file, batch_size)
    cls_reader2.load_data(train_file, batch_size)
    print(cls_reader.num_examples)
    print('done!')

    # Create the task heads (e.g. classification, matching, machine reading
    # comprehension). Each head has required/optional arguments specific to its task.
    # Heads are decoupled from readers: any head is valid as long as the dataset-side
    # fields it depends on can be supplied by the reader.
    cls_head = palm.head.Classify(4, 1024, 0.1)
    cls_head2 = palm.head.Classify(4, 1024, 0.1)

    # Build a trainer from the reader and task head. A trainer represents one training
    # task: it maintains the training state and key task information, performs validity
    # checks, and controls model saving/loading for that task.
    trainer = palm.Trainer('cls')
    trainer2 = palm.Trainer('senti_cls')
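    # Combine the two single-task trainers into one multi-head trainer that
    # trains both heads jointly on the shared backbone.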
    mh_trainer = palm.MultiHeadTrainer([trainer, trainer2])

    # match4mrqa.reuse_head_with(mrc4mrqa)

    # data_vars = cls_reader.build()
    # output_vars = ernie.build(data_vars)
    # cls_head.build({'backbone': output_vars, 'reader': data_vars})

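    # Build the forward graph: both task heads run on top of the shared ERNIE
    # backbone, and the returned variable is the loss to optimize.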
    loss_var = mh_trainer.build_forward(ernie, [cls_head, cls_head2])

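    # Derive the total number of optimization steps from the 'cls' reader, warm the
    # learning rate up over the first 10% of steps, then decay it linearly.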
    n_steps = cls_reader.num_examples * num_epochs // batch_size
    warmup_steps = int(0.1 * n_steps)
    print(warmup_steps)
    sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps)

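    # Attach Adam with the scheduled learning rate and build the backward /
    # optimization graph with weight decay.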
    adam = palm.optimizer.Adam(loss_var, lr, sched)

    mh_trainer.build_backward(optimizer=adam, weight_decay=0.001)
    
    # mh_trainer.random_init_params()
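    # Initialize from the pretrained ERNIE parameters rather than random initialization.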
    mh_trainer.load_pretrain('pretrain/ernie/params')

    # trainer.train(iterator_fn, print_steps=1, save_steps=5, save_path='outputs', save_type='ckpt,predict')
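    # Feed both readers to the multi-head trainer; batches are sampled from the two
    # tasks according to their mix ratio, with training length counted in epochs of
    # the reference task 'cls'.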
    mh_trainer.fit_readers_with_mixratio([cls_reader, cls_reader2], 'cls', num_epochs)
    mh_trainer.train(print_steps=1)
    # trainer.save()