未验证 提交 90b85001 编写于 作者: Y Yibing Liu 提交者: GitHub

Fix bert training (#3044)

上级 aab36287
......@@ -138,6 +138,8 @@ def predict_wrapper(args,
max_seq_len=args.max_seq_len,
is_test=True)
pyreader.decorate_batch_generator(data_reader.data_generator())
if args.do_test:
assert args.init_checkpoint is not None, "[FATAL] Please use --init_checkpoint '/path/to/checkpoints' \
to specify you pretrained model checkpoints"
......@@ -145,8 +147,6 @@ def predict_wrapper(args,
init_pretraining_params(exe, args.init_checkpoint, test_prog)
def predict(exe=exe, pyreader=pyreader):
pyreader.decorate_batch_generator(data_reader.data_generator())
pyreader.start()
cost = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册