未验证 提交 1727ebff 编写于 作者: W wawltor 提交者: GitHub

change the format for the bert benchmark (#5160)

change the format for the bert benchmark (#5160)
上级 73cb3a0f
......@@ -172,6 +172,7 @@ def reset_program_state_dict(model, state_dict):
loc=0.0, scale=scale, size=p.shape).astype(dtype_str)
return new_state_dict
def create_strategy():
"""
Create build strategy and exec strategy.
......@@ -361,21 +362,34 @@ def do_train(args):
data_holders, worker_init,
paddle.static.cuda_places())
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
reader_start = time.time()
for step, batch in enumerate(train_data_loader):
train_reader_cost += time.time() - reader_start
global_step += 1
train_start = time.time()
loss_return = exe.run(main_program,
feed=batch,
fetch_list=[loss])
train_run_cost += time.time() - train_start
total_samples += args.batch_size
# In the new 2.0 api, must call this function to change the learning_rate
lr_scheduler.step()
if global_step % args.logging_steps == 0:
time_cost = time.time() - tic_train
print(
"global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s, ips: %.2f sequences/s"
% (global_step, epoch, step, loss_return[0],
args.logging_steps / time_cost,
args.logging_steps * args.batch_size / time_cost))
tic_train = time.time()
"tobal step: %d, epoch: %d, batch: %d, loss: %f, "
"avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f sequences/sec"
%
(global_step, epoch, step, loss_return[0],
train_reader_cost / args.logging_steps,
(train_reader_cost + train_run_cost) /
args.logging_steps, total_samples / args.logging_steps,
total_samples / (train_reader_cost + train_run_cost)))
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
if global_step % args.save_steps == 0:
if worker_index == 0:
output_dir = os.path.join(args.output_dir,
......@@ -386,8 +400,10 @@ def do_train(args):
paddle.fluid.io.save_params(exe, output_dir)
tokenizer.save_pretrained(output_dir)
if global_step >= args.max_steps:
reader_start = time.time()
del train_data_loader
return
reader_start = time.time()
del train_data_loader
train_data_loader, data_file = dataset_future.result(timeout=None)
epoch += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册