diff --git a/PaddleNLP/language_representations_kit/BERT/train.py b/PaddleNLP/language_representations_kit/BERT/train.py index 85833f57fccce7e2bf84f7c8e067bd88244d2714..b3640f324ac05f428e9ca8d49151b1c856fa1116 100644 --- a/PaddleNLP/language_representations_kit/BERT/train.py +++ b/PaddleNLP/language_representations_kit/BERT/train.py @@ -138,6 +138,8 @@ def predict_wrapper(args, max_seq_len=args.max_seq_len, is_test=True) + pyreader.decorate_batch_generator(data_reader.data_generator()) + if args.do_test: assert args.init_checkpoint is not None, "[FATAL] Please use --init_checkpoint '/path/to/checkpoints' \ to specify you pretrained model checkpoints" @@ -145,8 +147,6 @@ def predict_wrapper(args, init_pretraining_params(exe, args.init_checkpoint, test_prog) def predict(exe=exe, pyreader=pyreader): - - pyreader.decorate_batch_generator(data_reader.data_generator()) pyreader.start() cost = 0