未验证 提交 90b85001 编写于 作者: Y Yibing Liu 提交者: GitHub

Fix bert training (#3044)

上级 aab36287
...@@ -138,6 +138,8 @@ def predict_wrapper(args, ...@@ -138,6 +138,8 @@ def predict_wrapper(args,
max_seq_len=args.max_seq_len, max_seq_len=args.max_seq_len,
is_test=True) is_test=True)
pyreader.decorate_batch_generator(data_reader.data_generator())
if args.do_test: if args.do_test:
assert args.init_checkpoint is not None, "[FATAL] Please use --init_checkpoint '/path/to/checkpoints' \ assert args.init_checkpoint is not None, "[FATAL] Please use --init_checkpoint '/path/to/checkpoints' \
to specify you pretrained model checkpoints" to specify you pretrained model checkpoints"
...@@ -145,8 +147,6 @@ def predict_wrapper(args, ...@@ -145,8 +147,6 @@ def predict_wrapper(args,
init_pretraining_params(exe, args.init_checkpoint, test_prog) init_pretraining_params(exe, args.init_checkpoint, test_prog)
def predict(exe=exe, pyreader=pyreader): def predict(exe=exe, pyreader=pyreader):
pyreader.decorate_batch_generator(data_reader.data_generator())
pyreader.start() pyreader.start()
cost = 0 cost = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册