diff --git a/demo/demo2/paddlepalm b/demo/demo2/paddlepalm new file mode 120000 index 0000000000000000000000000000000000000000..02071c94fa2adcef66be6763eaec487eb5b478e3 --- /dev/null +++ b/demo/demo2/paddlepalm @@ -0,0 +1 @@ +../../paddlepalm/ \ No newline at end of file diff --git a/demo/demo2/run.py b/demo/demo2/run.py index f1041bfa1bde0b76b822e0aca26418e185b9d68d..678a09cc6139c9b4174c696bea9e94b0cbab039d 100644 --- a/demo/demo2/run.py +++ b/demo/demo2/run.py @@ -2,7 +2,12 @@ import paddlepalm as palm if __name__ == '__main__': - match_reader = palm.reader.match(train_file, file_format='csv', tokenizer='wordpiece', lang='en') + max_seqlen = 512 + batch_size = 32 + + match_reader = palm.reader.match(train_file, vocab, \ + max_seqlen, file_format='csv', tokenizer='wordpiece', \ + lang='en', shuffle_train=True) mrc_reader = palm.reader.mrc(train_file, phase='train') mlm_reader = palm.reader.mlm(train_file, phase='train') palm.reader. diff --git a/demo/demo3/.run.py.swp b/demo/demo3/.run.py.swp new file mode 100644 index 0000000000000000000000000000000000000000..2370d0471d868b0af30f7542375a1231fa978d03 Binary files /dev/null and b/demo/demo3/.run.py.swp differ diff --git a/demo/demo3/run.py b/demo/demo3/run.py index 242f5d51c9c6aa57cf670bd83dd8fbb9ba1fc7ae..678a09cc6139c9b4174c696bea9e94b0cbab039d 100644 --- a/demo/demo3/run.py +++ b/demo/demo3/run.py @@ -1,7 +1,52 @@ import paddlepalm as palm if __name__ == '__main__': - controller = palm.Controller('config.yaml', task_dir='tasks') + + max_seqlen = 512 + batch_size = 32 + + match_reader = palm.reader.match(train_file, vocab, \ + max_seqlen, file_format='csv', tokenizer='wordpiece', \ + lang='en', shuffle_train=True) + mrc_reader = palm.reader.mrc(train_file, phase='train') + mlm_reader = palm.reader.mlm(train_file, phase='train') + palm.reader. + + match = palm.tasktype.cls(num_classes=4) + mrc = palm.tasktype.match(learning_strategy='pairwise') + mlm = palm.tasktype.mlm() + mlm.print() + + + bb_flags = palm.load_json('./pretrain/ernie/ernie_config.json') + bb = palm.backbone.ernie(bb_flags['xx'], xxx) + bb.print() + + match4mrqa = palm.Task('match4mrqa', match_reader, match_tt) + mrc4mrqa = palm.Task('match4mrqa', match_reader, match_tt) + + # match4mrqa.reuse_with(mrc4mrqa) + + + controller = palm.Controller([mrqa, match4mrqa, mlm4mrqa]) + + loss = controller.build_forward(bb, mask_task=[]) + + n_steps = controller.estimate_train_steps(basetask=mrqa, num_epochs=2, batch_size=8, dev_count=4) + adam = palm.optimizer.Adam(loss) + sched = palm.schedualer.LinearWarmup(learning_rate, max_train_steps=n_steps, warmup_steps=0.1*n_steps) + + controller.build_backward(optimizer=adam, schedualer=sched, weight_decay=0.001, use_ema=True, ema_decay=0.999) + + controller.random_init_params() controller.load_pretrain('../../pretrain_model/ernie/params') controller.train() + + + + + # controller = palm.Controller(config='config.yaml', task_dir='tasks', for_train=False) + # controller.pred('mrqa', inference_model_dir='output_model/secondrun/mrqa/infer_model') + +