提交 5ed8185b 编写于 作者: X xixiaoyao

Revert "refactor demos"

This reverts commit cdee76f2.
上级 cdee76f2
......@@ -3,7 +3,7 @@ task_instance: "mrqa"
save_path: "output_model/firstrun"
backbone: "bert"
backbone_config_path: "../../pretrain_model/bert/bert_config.json"
backbone_config_path: "pretrain_model/bert/bert_config.json"
batch_size: 4
num_epochs: 2
......
......@@ -5,9 +5,9 @@ mix_ratio: 1.0, 0.5, 0.5
save_path: "output_model/secondrun"
backbone: "ernie"
backbone_config_path: "../../pretrain_model/ernie/ernie_config.json"
backbone_config_path: "pretrain_model/ernie/ernie_config.json"
vocab_path: "../../pretrain_model/ernie/vocab.txt"
vocab_path: "pretrain_model/ernie/vocab.txt"
do_lower_case: True
max_seq_len: 512
......
......@@ -5,9 +5,9 @@ task_reuse_tag: 0,0,1,1,0,2
save_path: "output_model/thirdrun"
backbone: "ernie"
backbone_config_path: "../../pretrain_model/ernie/ernie_config.json"
backbone_config_path: "pretrain_model/ernie/ernie_config.json"
vocab_path: "../../pretrain_model/ernie/vocab.txt"
vocab_path: "pretrain_model/ernie/vocab.txt"
do_lower_case: True
max_seq_len: 512
......
# coding: utf-8
f='mrqa-combined.train.raw.json'
import json
a=json.load(open(f))
a=a['data']
writer = open('train.json','w')
for s in a:
p = s['paragraphs']
assert len(p) == 1
p = p[0]
q = {}
q['context'] = p['context']
q['qa_list'] = p['qas']
writer.write(json.dumps(q)+'\n')
writer.close()
此差异已折叠。
此差异已折叠。
export CUDA_VISIBLE_DEVICES=0,1
python run.py
import paddlepalm as palm
if __name__ == '__main__':
controller = palm.Controller('config.yaml', task_dir='tasks')
controller.load_pretrain('../../pretrain_model/ernie/params')
controller = palm.Controller('config_demo1.yaml', task_dir='demo1_tasks')
controller.load_pretrain('pretrain_model/bert/params')
controller.train()
......@@ -2,7 +2,7 @@ train_file: data/mrqa/train.json
reader: mrc
paradigm: mrc
vocab_path: "../../pretrain_model/bert/vocab.txt"
vocab_path: "pretrain_model/bert/vocab.txt"
do_lower_case: True
max_seq_len: 512
doc_stride: 128
......
import paddlepalm as palm
if __name__ == '__main__':
controller = palm.Controller('config.yaml', task_dir='tasks')
controller.load_pretrain('../../pretrain_model/ernie/params')
controller = palm.Controller('config_demo2.yaml', task_dir='demo2_tasks')
controller.load_pretrain('pretrain_model/ernie/params')
controller.train()
controller = palm.Controller(config='config.yaml', task_dir='tasks', for_train=False)
controller = palm.Controller(config='config_demo2.yaml', task_dir='demo2_tasks', for_train=False)
controller.pred('mrqa', inference_model_dir='output_model/secondrun/mrqa/infer_model')
import paddlepalm as palm
if __name__ == '__main__':
controller = palm.Controller('config.yaml')
controller.load_pretrain('../../pretrain_model/bert/params')
controller = palm.Controller('config_demo3.yaml', task_dir='demo3_tasks')
controller.load_pretrain('pretrain_model/ernie/params')
controller.train()
export CUDA_VISIBLE_DEVICES=0
python run.py
python demo1.py
export CUDA_VISIBLE_DEVICES=0
python demo2.py
export CUDA_VISIBLE_DEVICES=0
python run.py
python demo3.py
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册