diff --git a/PaddleNLP/benchmark/transformer/static/train.py b/PaddleNLP/benchmark/transformer/static/train.py index dcb44554e884421b45844f91ed2291ad7fe81319..1b40bc017ffbea2f735f0ef6abbd9aa80745982b 100644 --- a/PaddleNLP/benchmark/transformer/static/train.py +++ b/PaddleNLP/benchmark/transformer/static/train.py @@ -119,10 +119,12 @@ def do_train(args): batch_id = 0 batch_start = time.time() pass_start_time = batch_start - for data in train_loader(): + for data in train_loader: # NOTE: used for benchmark and use None as default. if args.max_iter and step_idx == args.max_iter: return + if trainer_count == 1: + data = [data] train_reader_cost = time.time() - batch_start outs = exe.run(compiled_train_program,