未验证 提交 a00c8afb 编写于 作者: W wanghuancoder 提交者: GitHub

fix resnet50 usetime statistics (#4838)

上级 4257b82d
......@@ -234,6 +234,7 @@ def train(args):
train_batch_id = 0
train_batch_time_record = []
train_batch_metrics_record = []
train_batch_time_print_step = []
if not args.use_dali:
train_iter = train_data_loader()
......@@ -256,8 +257,19 @@ def train(args):
np.array(train_batch_metrics), axis=1)
train_batch_metrics_record.append(train_batch_metrics_avg)
if trainer_id == 0:
print_info("batch", train_batch_metrics_avg, train_batch_elapse,
pass_id, train_batch_id, args.print_step)
if train_batch_id % args.print_step == 0:
if len(train_batch_time_print_step) == 0:
train_batch_time_print_step_avg = train_batch_elapse
else:
train_batch_time_print_step_avg = np.mean(
train_batch_time_print_step)
train_batch_time_print_step = []
print_info("batch", train_batch_metrics_avg,
train_batch_time_print_step_avg, pass_id,
train_batch_id, args.print_step)
else:
train_batch_time_print_step.append(train_batch_elapse)
sys.stdout.flush()
train_batch_id += 1
t1 = time.time()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册