diff --git a/.gitignore b/.gitignore
index c49200d1dd623f2f0fd00f084d8ce67228426b9a..b6e8399fa221ed1e6fc9aad73c09c45dad370e2b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,6 @@
 *.pyc
 __pycache__
-pretrain_model
+pretrain
 output_model
 build
 dist
diff --git a/README.md b/README.md
index 6e888bcab9835897c0729d03d21b9e0c7213ae60..ce5a7aa128a55570d6d47dfb3c84ea8d9384aae5 100644
--- a/README.md
+++ b/README.md
@@ -163,7 +163,7 @@ max_seq_len: 512
 max_query_len: 64
 doc_stride: 128 # 在MRQA数据集中,存在较长的文档,因此我们这里使用滑动窗口处理样本,滑动步长设置为128
 do_lower_case: True
-vocab_path: "../../pretrain_model/bert/vocab.txt"
+vocab_path: "../../pretrain/bert-en-uncased-large/vocab.txt"
 ```
 
 更详细的任务实例配置方法(为任务实例选择合适的reader、paradigm和backbone)可参考[这里](#readerbackbone与paradigm的选择)
@@ -178,7 +178,7 @@ task_instance: "mrqa"
 save_path: "output_model/firstrun"
 
 backbone: "bert"
-backbone_config_path: "../../pretrain_model/bert/bert_config.json"
+backbone_config_path: "../../pretrain/bert-en-uncased-large/bert_config.json"
 
 optimizer: "adam"
 learning_rate: 3e-5
@@ -204,7 +204,7 @@ import paddlepalm as palm
 
 if __name__ == '__main__':
     controller = palm.Controller('config.yaml')
-    controller.load_pretrain('../../pretrain_model/bert/params')
+    controller.load_pretrain('../../pretrain/bert-en-uncased-large/params')
     controller.train()
 ```
 
@@ -271,9 +271,9 @@ target_tag: 1,0,0
 save_path: "output_model/secondrun"
 
 backbone: "ernie"
-backbone_config_path: "../../pretrain_model/ernie/ernie_config.json"
+backbone_config_path: "../../pretrain/ernie-en-uncased-large/ernie_config.json"
 
-vocab_path: "../../pretrain_model/ernie/vocab.txt"
+vocab_path: "../../pretrain/ernie-en-uncased-large/vocab.txt"
 do_lower_case: True
 max_seq_len: 512 # 写入全局配置文件的参数会被自动广播到各个任务实例
 
@@ -308,7 +308,7 @@ import paddlepalm as palm
 
 if __name__ == '__main__':
     controller = palm.Controller('config.yaml', task_dir='tasks')
-    controller.load_pretrain('../../pretrain_model/ernie/params')
+    controller.load_pretrain('../../pretrain/ernie-en-uncased-large/params')
     controller.train()
 ```
 
@@ -400,9 +400,9 @@ task_reuse_tag: 0, 0, 1, 1, 0, 2
 save_path: "output_model/secondrun"
 
 backbone: "ernie"
-backbone_config_path: "../../pretrain_model/ernie/ernie_config.json"
+backbone_config_path: "../../pretrain/ernie-en-uncased-large/ernie_config.json"
 
-vocab_path: "../../pretrain_model/ernie/vocab.txt"
+vocab_path: "../../pretrain/ernie-en-uncased-large/vocab.txt"
 do_lower_case: True
 max_seq_len: 512 # 写入全局配置文件的参数会被自动广播到各个任务实例
 
@@ -422,7 +422,7 @@ import paddlepalm as palm
 
 if __name__ == '__main__':
     controller = palm.Controller('config.yaml', task_dir='tasks')
-    controller.load_pretrain('../../pretrain_model/ernie/params')
+    controller.load_pretrain('../../pretrain/ernie-en-uncased-large/params')
     controller.train()
 ```
 
diff --git a/demo/demo1/config.yaml b/demo/demo1/config.yaml
index 56cb9476124c378c7440171317063e2d7aabe688..0cca90ff519fc9606214dd3aad4c5763092e9bcc 100644
--- a/demo/demo1/config.yaml
+++ b/demo/demo1/config.yaml
@@ -2,8 +2,8 @@ task_instance: "mrqa"
 
 save_path: "output_model/firstrun"
 
-backbone: "bert"
-backbone_config_path: "../../pretrain_model/bert/bert_config.json"
+backbone: "bert-en-uncased-large"
+backbone_config_path: "../../pretrain/bert-en-uncased-large/bert_config.json"
 
 batch_size: 4
 num_epochs: 2
diff --git a/demo/demo1/mrqa.yaml b/demo/demo1/mrqa.yaml
index 36bb6e29793971db30370845d2cceb53ed5d0ba5..ce4044232cedccbbaecedb64ffa4a5f4320859be 100644
--- a/demo/demo1/mrqa.yaml
+++ b/demo/demo1/mrqa.yaml
@@ -2,7 +2,7 @@ train_file: data/mrqa/train.json
 
 reader: mrc
 paradigm: mrc
-vocab_path: "../../pretrain_model/bert/vocab.txt"
+vocab_path: "../../pretrain/bert-en-uncased-large/vocab.txt"
 do_lower_case: True
 max_seq_len: 512
 doc_stride: 128
diff --git a/demo/demo2/config.yaml b/demo/demo2/config.yaml
index 2ea38ff15ff53b59cc059561d01754a69d9b1caf..2bd539051b4dbcc7b65763599f69bf81402f6047 100644
--- a/demo/demo2/config.yaml
+++ b/demo/demo2/config.yaml
@@ -4,15 +4,15 @@ mix_ratio: 1.0, 0.5, 0.5
 
 save_path: "output_model/secondrun"
 
-backbone: "ernie"
-backbone_config_path: "../../pretrain_model/ernie/ernie_config.json"
+backbone: "ernie-en-uncased-large"
+backbone_config_path: "../../pretrain/ernie-en-uncased-large/ernie_config.json"
 
-vocab_path: "../../pretrain_model/ernie/vocab.txt"
+vocab_path: "../../pretrain/ernie-en-uncased-large/vocab.txt"
 do_lower_case: True
 max_seq_len: 512
 
 batch_size: 4
-num_epochs: 2
+num_epochs: 0.1
 optimizer: "adam"
 learning_rate: 3e-5
 warmup_proportion: 0.1
diff --git a/demo/demo3/config.yaml b/demo/demo3/config.yaml
index 0cd1ba8a9cb331b5b6f61d2bb9f47817145671ea..0595e58eb2c099807f1a9f790428406f134efe89 100644
--- a/demo/demo3/config.yaml
+++ b/demo/demo3/config.yaml
@@ -4,10 +4,10 @@ task_reuse_tag: 0,0,1,1,0,2
 
 save_path: "output_model/thirdrun"
 
-backbone: "ernie"
-backbone_config_path: "../../pretrain_model/ernie/ernie_config.json"
+backbone: "ernie-en-uncased-large"
+backbone_config_path: "../../pretrain/ernie-en-uncased-large/ernie_config.json"
 
-vocab_path: "../../pretrain_model/ernie/vocab.txt"
+vocab_path: "../../pretrain/ernie-en-uncased-large/vocab.txt"
 do_lower_case: True
 max_seq_len: 512
 
diff --git a/paddlepalm/mtl_controller.py b/paddlepalm/mtl_controller.py
index 0cac9fddb60489791205b146bcd2049d19bc8f90..1d229c753a6815031e980a6025a03c671e88dd8c 100755
--- a/paddlepalm/mtl_controller.py
+++ b/paddlepalm/mtl_controller.py
@@ -522,15 +522,15 @@ class Controller(object):
             inst.reader['pred'] = pred_reader
         return pred_prog
 
-    def load_pretrain(self, pretrain_model_path=None):
+    def load_pretrain(self, pretrain_path=None):
         # load pretrain model (or ckpt)
-        if pretrain_model_path is None:
-            assert 'pretrain_model_path' in self.main_conf, "pretrain_model_path NOT set."
-            pretrain_model_path = self.main_conf['pretrain_model_path']
+        if pretrain_path is None:
+            assert 'pretrain_path' in self.main_conf, "pretrain_path NOT set."
+            pretrain_path = self.main_conf['pretrain_path']
         init_pretraining_params(
             self.exe,
-            pretrain_model_path,
+            pretrain_path,
             main_program=fluid.default_startup_program())