提交 a9c5a794 编写于 作者: X xixiaoyao

add multihead

上级 aa7f8ed7
此差异已折叠。
此差异已折叠。
# coding=utf-8
import paddlepalm as palm
import json
if __name__ == '__main__':
max_seqlen = 512
batch_size = 32
batch_size = 4
num_epochs = 2
lr = 1e-3
vocab_path = './pretrain/ernie/vocab.txt'
match_reader = palm.reader.match(train_file, vocab, \
max_seqlen, file_format='csv', tokenizer='wordpiece', \
lang='en', shuffle_train=True)
mrc_reader = palm.reader.mrc(train_file, phase='train')
mlm_reader = palm.reader.mlm(train_file, phase='train')
palm.reader.
train_file = './data/cls4mrqa/train.tsv'
predict_file = './data/cls4mrqa/dev.tsv'
match = palm.tasktype.cls(num_classes=4)
mrc = palm.tasktype.match(learning_strategy='pairwise')
mlm = palm.tasktype.mlm()
mlm.print()
config = json.load(open('./pretrain/ernie/ernie_config.json'))
# ernie = palm.backbone.ERNIE(...)
ernie = palm.backbone.ERNIE.from_config(config)
bb_flags = palm.load_json('./pretrain/ernie/ernie_config.json')
bb = palm.backbone.ernie(bb_flags['xx'], xxx)
bb.print()
# cls_reader2 = palm.reader.cls(train_file_topic, vocab_path, batch_size, max_seqlen)
# cls_reader3 = palm.reader.cls(train_file_subj, vocab_path, batch_size, max_seqlen)
# topic_trainer = palm.Trainer('topic_cls', cls_reader2, cls)
# subj_trainer = palm.Trainer('subj_cls', cls_reader3, cls)
match4mrqa = palm.Task('match4mrqa', match_reader, match_tt)
mrc4mrqa = palm.Task('match4mrqa', match_reader, match_tt)
# 创建该分类任务的reader,由诸多参数控制数据集读入格式、文件数量、预处理规则等
cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen)
cls_reader2 = palm.reader.ClassifyReader(vocab_path, max_seqlen)
predict_cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, phase='predict')
print(cls_reader.outputs_attr)
print(predict_cls_reader.outputs_attr)
# 不同的backbone会对任务reader有不同的特征要求,例如对于分类任务,基本的输入feature为token_ids和label_ids,但是对于BERT,还要求从输入中额外提取position、segment、input_mask等特征,因此经过register后,reader会自动补充backbone所要求的字段
cls_reader.register_with(ernie)
print(cls_reader.outputs_attr)
print(predict_cls_reader.outputs_attr)
# match4mrqa.reuse_with(mrc4mrqa)
print("preparing data...")
print(cls_reader.num_examples)
cls_reader.load_data(train_file, batch_size)
cls_reader2.load_data(train_file, batch_size)
print(cls_reader.num_examples)
print('done!')
# 创建任务头(task head),如分类、匹配、机器阅读理解等。每个任务头有跟该任务相关的必选/可选参数。注意,任务头与reader是解耦合的,只要任务头依赖的数据集侧的字段能被reader提供,那么就是合法的
cls_head = palm.head.Classify(4, 1024, 0.1)
cls_head2 = palm.head.Classify(4, 1024, 0.1)
controller = palm.Controller([mrqa, match4mrqa, mlm4mrqa])
# 根据reader和任务头来创建一个训练器trainer,trainer代表了一个训练任务,内部维护着训练进程、和任务的关键信息,并完成合法性校验,该任务的模型保存、载入等相关规则控制
trainer = palm.Trainer('cls')
trainer2 = palm.Trainer('senti_cls')
mh_trainer = palm.MultiHeadTrainer([trainer, trainer2])
loss = controller.build_forward(bb, mask_task=[])
# match4mrqa.reuse_head_with(mrc4mrqa)
n_steps = controller.estimate_train_steps(basetask=mrqa, num_epochs=2, batch_size=8, dev_count=4)
adam = palm.optimizer.Adam(loss)
sched = palm.schedualer.LinearWarmup(learning_rate, max_train_steps=n_steps, warmup_steps=0.1*n_steps)
controller.build_backward(optimizer=adam, schedualer=sched, weight_decay=0.001, use_ema=True, ema_decay=0.999)
# data_vars = cls_reader.build()
# output_vars = ernie.build(data_vars)
# cls_head.build({'backbone': output_vars, 'reader': data_vars})
controller.random_init_params()
controller.load_pretrain('../../pretrain_model/ernie/params')
controller.train()
loss_var = mh_trainer.build_forward(ernie, [cls_head, cls_head2])
# controller.build_forward()
# Error! a head/backbone can be only build once! Try NOT to call build_forward method for any Trainer!
# n_steps = cls_reader.num_examples * num_epochs // batch_size
# warmup_steps = int(0.1 * n_steps)
# print(warmup_steps)
# sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps)
sched = None
adam = palm.optimizer.Adam(loss_var, lr, sched)
mh_trainer.build_backward(optimizer=adam, weight_decay=0.001)
mh_trainer.random_init_params()
mh_trainer.load_pretrain('pretrain/ernie/params')
# trainer.train(iterator_fn, print_steps=1, save_steps=5, save_path='outputs', save_type='ckpt,predict')
mh_trainer.fit_readers_with_mixratio([cls_reader, cls_reader2], 'cls', 2)
mh_trainer.train(print_steps=1)
# trainer.save()
# print('prepare to predict...')
# pred_ernie = palm.backbone.ERNIE.from_config(config, phase='pred')
# cls_pred_head = palm.head.Classify(4, 1024, phase='pred')
# trainer.build_predict_forward(pred_ernie, cls_pred_head)
# predict_cls_reader.load_data(predict_file, 8)
# print(predict_cls_reader.num_examples)
# predict_cls_reader.register_with(pred_ernie)
# trainer.fit_reader(predict_cls_reader, phase='predict')
# print('predicting..')
# trainer.predict(print_steps=20)
# controller = palm.Controller([mrqa, match4mrqa, mlm4mrqa])
# loss = controller.build_forward(bb, mask_task=[])
# n_steps = controller.estimate_train_steps(basetask=mrqa, num_epochs=2, batch_size=8, dev_count=4)
# adam = palm.optimizer.Adam(loss)
# sched = palm.schedualer.LinearWarmup(learning_rate, max_train_steps=n_steps, warmup_steps=0.1*n_steps)
#
# controller.build_backward(optimizer=adam, schedualer=sched, weight_decay=0.001, use_ema=True, ema_decay=0.999)
# controller.random_init_params()
# controller.load_pretrain('../../pretrain_model/ernie/params')
# controller.train()
......
export CUDA_VISIBLE_DEVICES=0
export CUDA_VISIBLE_DEVICES=3
python run.py
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册