Commit 1941c278 authored by wangxiao

change name & path

Parent 2f34f935
 *.pyc
 __pycache__
-pretrain_model
+pretrain
 output_model
 build
 dist
......
@@ -163,7 +163,7 @@ max_seq_len: 512
 max_query_len: 64
 doc_stride: 128 # The MRQA dataset contains long documents, so we process samples with a sliding window; the stride is set to 128
 do_lower_case: True
-vocab_path: "../../pretrain_model/bert/vocab.txt"
+vocab_path: "../../pretrain/bert-en-uncased-large/vocab.txt"
 ```
 For more details on configuring a task instance (choosing a suitable reader, paradigm, and backbone for it), see [here](#readerbackbone与paradigm的选择).
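The `doc_stride` comment above is the whole trick for long documents: windows of at most `max_seq_len` tokens are taken every 128 tokens, so every answer span falls entirely inside at least one window. A minimal sketch of that idea (illustrative only, not PaddlePALM's actual MRC reader, which also prepends the query and remaps answer spans):

```python
def sliding_windows(tokens, max_len=512, doc_stride=128):
    """Split an over-long token sequence into overlapping windows.

    Illustrative sketch of the doc_stride idea only; the real reader
    also handles the query and per-window answer-span remapping.
    """
    windows = []
    start = 0
    while True:
        windows.append(tokens[start:start + max_len])
        if start + max_len >= len(tokens):  # last window reaches the end
            break
        start += doc_stride
    return windows

# A 900-token document yields windows at offsets 0, 128, 256, 384, 512:
print([len(w) for w in sliding_windows(list(range(900)))])
# [512, 512, 512, 512, 388]
```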
@@ -178,7 +178,7 @@ task_instance: "mrqa"
 save_path: "output_model/firstrun"
 backbone: "bert"
-backbone_config_path: "../../pretrain_model/bert/bert_config.json"
+backbone_config_path: "../../pretrain/bert-en-uncased-large/bert_config.json"
 optimizer: "adam"
 learning_rate: 3e-5
@@ -204,7 +204,7 @@ import paddlepalm as palm
 if __name__ == '__main__':
     controller = palm.Controller('config.yaml')
-    controller.load_pretrain('../../pretrain_model/bert/params')
+    controller.load_pretrain('../../pretrain/bert-en-uncased-large/params')
     controller.train()
 ```
@@ -271,9 +271,9 @@ target_tag: 1,0,0
 save_path: "output_model/secondrun"
 backbone: "ernie"
-backbone_config_path: "../../pretrain_model/ernie/ernie_config.json"
-vocab_path: "../../pretrain_model/ernie/vocab.txt"
+backbone_config_path: "../../pretrain/ernie-en-uncased-large/ernie_config.json"
+vocab_path: "../../pretrain/ernie-en-uncased-large/vocab.txt"
 do_lower_case: True
 max_seq_len: 512 # Parameters written to the global config file are automatically broadcast to every task instance
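The broadcast comment describes framework behavior rather than anything the user writes: a key set in the global `config.yaml` is filled into each task instance's config unless the instance overrides it. A toy sketch of those semantics (hypothetical helper, not PaddlePALM's code):

```python
def broadcast_global_conf(global_conf, task_confs):
    """Hypothetical illustration of config broadcasting: a task's own
    setting wins, otherwise the global value is copied in."""
    return {name: {**global_conf, **conf} for name, conf in task_confs.items()}

global_conf = {'max_seq_len': 512, 'do_lower_case': True}
task_confs = {'mrqa': {'doc_stride': 128}, 'mlm4mrqa': {}}
merged = broadcast_global_conf(global_conf, task_confs)
print(merged['mlm4mrqa']['max_seq_len'])  # 512, inherited from the global config
```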
@@ -308,7 +308,7 @@ import paddlepalm as palm
 if __name__ == '__main__':
     controller = palm.Controller('config.yaml', task_dir='tasks')
-    controller.load_pretrain('../../pretrain_model/ernie/params')
+    controller.load_pretrain('../../pretrain/ernie-en-uncased-large/params')
     controller.train()
 ```
@@ -400,9 +400,9 @@ task_reuse_tag: 0, 0, 1, 1, 0, 2
 save_path: "output_model/secondrun"
 backbone: "ernie"
-backbone_config_path: "../../pretrain_model/ernie/ernie_config.json"
-vocab_path: "../../pretrain_model/ernie/vocab.txt"
+backbone_config_path: "../../pretrain/ernie-en-uncased-large/ernie_config.json"
+vocab_path: "../../pretrain/ernie-en-uncased-large/vocab.txt"
 do_lower_case: True
 max_seq_len: 512 # Parameters written to the global config file are automatically broadcast to every task instance
@@ -422,7 +422,7 @@ import paddlepalm as palm
 if __name__ == '__main__':
     controller = palm.Controller('config.yaml', task_dir='tasks')
-    controller.load_pretrain('../../pretrain_model/ernie/params')
+    controller.load_pretrain('../../pretrain/ernie-en-uncased-large/params')
     controller.train()
 ```
......
@@ -2,8 +2,8 @@ task_instance: "mrqa"
 save_path: "output_model/firstrun"
-backbone: "bert"
-backbone_config_path: "../../pretrain_model/bert/bert_config.json"
+backbone: "bert-en-uncased-large"
+backbone_config_path: "../../pretrain/bert-en-uncased-large/bert_config.json"
 batch_size: 4
 num_epochs: 2
......
@@ -2,7 +2,7 @@ train_file: data/mrqa/train.json
 reader: mrc
 paradigm: mrc
-vocab_path: "../../pretrain_model/bert/vocab.txt"
+vocab_path: "../../pretrain/bert-en-uncased-large/vocab.txt"
 do_lower_case: True
 max_seq_len: 512
 doc_stride: 128
......
@@ -4,15 +4,15 @@ mix_ratio: 1.0, 0.5, 0.5
 save_path: "output_model/secondrun"
-backbone: "ernie"
-backbone_config_path: "../../pretrain_model/ernie/ernie_config.json"
-vocab_path: "../../pretrain_model/ernie/vocab.txt"
+backbone: "ernie-en-uncased-large"
+backbone_config_path: "../../pretrain/ernie-en-uncased-large/ernie_config.json"
+vocab_path: "../../pretrain/ernie-en-uncased-large/vocab.txt"
 do_lower_case: True
 max_seq_len: 512
 batch_size: 4
-num_epochs: 2
+num_epochs: 0.1
 optimizer: "adam"
 learning_rate: 3e-5
 warmup_proportion: 0.1
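The `mix_ratio` line in the hunk header controls how often each task supplies a training batch relative to the others. A rough sketch of that sampling semantics (not PaddlePALM's actual scheduler; the task names are assumed from the demo):

```python
import random

def sample_task(tasks, mix_ratio):
    """Pick the task for the next batch with probability proportional
    to its mix_ratio. Illustrative only."""
    return random.choices(tasks, weights=mix_ratio, k=1)[0]

tasks = ['mrqa', 'mlm4mrqa', 'match4mrqa']  # names assumed from the demo
counts = {t: 0 for t in tasks}
for _ in range(10000):
    counts[sample_task(tasks, [1.0, 0.5, 0.5])] += 1
print(counts)  # roughly 5000 / 2500 / 2500 batches
```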
......
@@ -4,10 +4,10 @@ task_reuse_tag: 0,0,1,1,0,2
 save_path: "output_model/thirdrun"
-backbone: "ernie"
-backbone_config_path: "../../pretrain_model/ernie/ernie_config.json"
-vocab_path: "../../pretrain_model/ernie/vocab.txt"
+backbone: "ernie-en-uncased-large"
+backbone_config_path: "../../pretrain/ernie-en-uncased-large/ernie_config.json"
+vocab_path: "../../pretrain/ernie-en-uncased-large/vocab.txt"
 do_lower_case: True
 max_seq_len: 512
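`task_reuse_tag: 0,0,1,1,0,2` in the hunk header marks which task instances share one task head: instances carrying the same tag reuse the same task-specific parameters. A hypothetical grouping sketch of those semantics (not PaddlePALM code):

```python
def group_by_reuse_tag(task_names, reuse_tags):
    """Hypothetical helper showing task_reuse_tag semantics: tasks with
    the same tag share one set of task-head parameters."""
    groups = {}
    for name, tag in zip(task_names, reuse_tags):
        groups.setdefault(tag, []).append(name)
    return groups

# Six instances with tags 0,0,1,1,0,2 form three parameter groups:
print(group_by_reuse_tag(['t1', 't2', 't3', 't4', 't5', 't6'], [0, 0, 1, 1, 0, 2]))
# {0: ['t1', 't2', 't5'], 1: ['t3', 't4'], 2: ['t6']}
```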
......
@@ -522,15 +522,15 @@ class Controller(object):
             inst.reader['pred'] = pred_reader
         return pred_prog
 
-    def load_pretrain(self, pretrain_model_path=None):
+    def load_pretrain(self, pretrain_path=None):
         # load pretrain model (or ckpt)
-        if pretrain_model_path is None:
-            assert 'pretrain_model_path' in self.main_conf, "pretrain_model_path NOT set."
-            pretrain_model_path = self.main_conf['pretrain_model_path']
+        if pretrain_path is None:
+            assert 'pretrain_path' in self.main_conf, "pretrain_path NOT set."
+            pretrain_path = self.main_conf['pretrain_path']
 
         init_pretraining_params(
             self.exe,
-            pretrain_model_path,
+            pretrain_path,
             main_program=fluid.default_startup_program())
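After this rename, callers either pass the parameter directory explicitly or set `pretrain_path` in the global config and call `load_pretrain()` with no argument (the assert above fires if neither is provided). A minimal usage sketch with the demo's path:

```python
import paddlepalm as palm

if __name__ == '__main__':
    controller = palm.Controller('config.yaml', task_dir='tasks')
    # Explicit path to the pretrained parameters...
    controller.load_pretrain('../../pretrain/ernie-en-uncased-large/params')
    # ...or add `pretrain_path: ../../pretrain/ernie-en-uncased-large/params`
    # to config.yaml and call controller.load_pretrain() instead.
    controller.train()
```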
......