From 77d5d10662da22804e56ac111e56f407a9aefc8b Mon Sep 17 00:00:00 2001
From: Meiyim
Date: Mon, 15 Jun 2020 10:30:33 +0800
Subject: [PATCH] update pretrain demo (#491)

* update pretrain demo

* modeul path fix for ernie-gen

* dygraph distributed cls + init_checkpoint option
---
 demo/finetune_classifier_dygraph_distributed.py | 7 +++++++
 demo/pretrain/README.md                         | 8 ++++----
 demo/pretrain/make_pretrain_data.py             | 2 +-
 demo/seq2seq/finetune_seq2seq_dygraph.py        | 4 ++--
 4 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/demo/finetune_classifier_dygraph_distributed.py b/demo/finetune_classifier_dygraph_distributed.py
index fc86f44..42e8eed 100644
--- a/demo/finetune_classifier_dygraph_distributed.py
+++ b/demo/finetune_classifier_dygraph_distributed.py
@@ -53,6 +53,7 @@ if __name__ == '__main__':
     parser.add_argument('--lr', type=float, default=5e-5, help='learning rate')
     parser.add_argument('--save_dir', type=str, default=None, help='model output directory')
     parser.add_argument('--wd', type=int, default=0.01, help='weight decay, aka L2 regularizer')
+    parser.add_argument('--init_checkpoint', type=str, default=None, help='checkpoint to warm start from')
 
     args = parser.parse_args()
 
@@ -99,6 +100,12 @@ if __name__ == '__main__':
     with FD.guard(place):
         ctx = FD.parallel.prepare_context()
         model = ErnieModelForSequenceClassification.from_pretrained(args.from_pretrained, num_labels=3, name='')
+
+        if args.init_checkpoint is not None:
+            log.info('loading checkpoint from %s' % args.init_checkpoint)
+            sd, _ = FD.load_dygraph(args.init_checkpoint)
+            model.set_dict(sd)
+
         model = FD.parallel.DataParallel(model, ctx)
 
         g_clip = F.clip.GradientClipByGlobalNorm(1.0) #experimental
diff --git a/demo/pretrain/README.md b/demo/pretrain/README.md
index 6e04f8c..df69d93 100644
--- a/demo/pretrain/README.md
+++ b/demo/pretrain/README.md
@@ -28,7 +28,7 @@ example:
 make pretrain data with:
 
 ```script
-python3 ernie/pretrain/make_pretrain_data.py input_file output_file.gz --vocab ./pretrained/vocab.txt
+python3 ./demo/pretrain/make_pretrain_data.py input_file output_file.gz --vocab /path/to/ernie1.0/vocab.txt
 ```
 
 2. run distributed pretrain
@@ -36,9 +36,9 @@ python3 ernie/pretrain/make_pretrain_data.py input_file output_file.gz --vocab
 
 ```sript
 python3 -m paddle.distributed.launch \
-./ernie/pretrain/pretrain_dygraph.py \
-    --data_dir data/* \
-    --from_pretrained ./ernie_1.0_pretrain_dir/
+./demo/pretrain/pretrain_dygraph.py \
+    --data_dir "data/*.gz" \
+    --from_pretrained /path/to/ernie1.0_pretrain_dir/
 ```
 
diff --git a/demo/pretrain/make_pretrain_data.py b/demo/pretrain/make_pretrain_data.py
index 3070be0..96be1d4 100644
--- a/demo/pretrain/make_pretrain_data.py
+++ b/demo/pretrain/make_pretrain_data.py
@@ -124,7 +124,7 @@ if __name__ == '__main__':
     log.setLevel(logging.DEBUG)
 
-    from tokenizing_ernie import _wordpiece
+    from ernie.tokenizing_ernie import _wordpiece
     pat = re.compile(r'([a-zA-Z0-9]+|\S)')
 
     vocab = {j.strip().split(b'\t')[0].decode('utf8'): i for i, j in enumerate(open(args.vocab, 'rb'))}
diff --git a/demo/seq2seq/finetune_seq2seq_dygraph.py b/demo/seq2seq/finetune_seq2seq_dygraph.py
index f9ddebc..77140b9 100644
--- a/demo/seq2seq/finetune_seq2seq_dygraph.py
+++ b/demo/seq2seq/finetune_seq2seq_dygraph.py
@@ -37,7 +37,7 @@ from ernie.modeling_ernie import _build_linear, _build_ln, append_name
 from ernie.tokenizing_ernie import ErnieTokenizer
 from ernie.optimization import AdamW, LinearDecay
 
-from experimental.seq2seq.decode import beam_search_infilling, post_process
+from demo.seq2seq.decode import beam_search_infilling, post_process
 from propeller import log
 
 import propeller.paddle as propeller
@@ -295,7 +295,7 @@ if __name__ == '__main__':
     parser.add_argument('--predict_output_dir', type=str, default=None, help='predict file output directory')
     parser.add_argument('--attn_token', type=str, default='[ATTN]', help='if [ATTN] not in vocab, you can specified [MAKK] as attn-token')
     parser.add_argument('--inference_model_dir', type=str, default=None, help='inference model output directory')
-    parser.add_argument('--init_checkpoint', type=str, default=None)
+    parser.add_argument('--init_checkpoint', type=str, default=None, help='checkpoint to warm start from')
     parser.add_argument('--save_dir', type=str, default=None, help='model output directory')
     parser.add_argument('--wd', type=float, default=0.01, help='weight decay, aka L2 regularizer')
--
GitLab
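
A minimal usage sketch, not taken from the patch itself: the command below shows how the new `--init_checkpoint` option of `demo/finetune_classifier_dygraph_distributed.py` might be exercised, mirroring the `paddle.distributed.launch` invocation the README uses for the pretrain script. The paths are placeholders, and any data-related arguments the demo defines outside the hunks shown here are omitted.

```script
# Hypothetical launch of the distributed classification demo with the new
# --init_checkpoint flag; all paths are placeholders, and arguments not
# visible in this patch (e.g. the demo's data inputs) are left out.
python3 -m paddle.distributed.launch \
./demo/finetune_classifier_dygraph_distributed.py \
    --from_pretrained /path/to/ernie1.0_pretrain_dir/ \
    --init_checkpoint /path/to/previously/saved/model \
    --lr 5e-5 \
    --wd 0.01 \
    --save_dir ./output
```

Per the diff, when `--init_checkpoint` is supplied the demo loads the state dict with `FD.load_dygraph` and applies it via `model.set_dict` before wrapping the model in `FD.parallel.DataParallel`, so the warm start takes effect on every worker before gradient synchronization begins.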