未验证 提交 95f46d33 编写于 作者: R ruri 提交者: GitHub

fix fetch bug (#4076)

上级 171ec35f
......@@ -225,16 +225,14 @@ def train(args):
return
train_batch_metrics = exe.run(compiled_train_prog,
feed=batch,
fetch_list=train_fetch_list
if pass_id % args.print_step == 0 else
[])
fetch_list=train_fetch_list)
t2 = time.time()
train_batch_elapse = t2 - t1
train_batch_time_record.append(train_batch_elapse)
if pass_id % args.print_step == 0:
train_batch_metrics_avg = np.mean(
np.array(train_batch_metrics), axis=1)
train_batch_metrics_record.append(train_batch_metrics_avg)
train_batch_metrics_avg = np.mean(
np.array(train_batch_metrics), axis=1)
train_batch_metrics_record.append(train_batch_metrics_avg)
if trainer_id == 0:
print_info("batch", train_batch_metrics_avg, train_batch_elapse,
pass_id, train_batch_id, args.print_step)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册