未验证 提交 95f46d33 编写于 作者: R ruri 提交者: GitHub

fix fetch bug (#4076)

上级 171ec35f
...@@ -225,16 +225,14 @@ def train(args): ...@@ -225,16 +225,14 @@ def train(args):
return return
train_batch_metrics = exe.run(compiled_train_prog, train_batch_metrics = exe.run(compiled_train_prog,
feed=batch, feed=batch,
fetch_list=train_fetch_list fetch_list=train_fetch_list)
if pass_id % args.print_step == 0 else
[])
t2 = time.time() t2 = time.time()
train_batch_elapse = t2 - t1 train_batch_elapse = t2 - t1
train_batch_time_record.append(train_batch_elapse) train_batch_time_record.append(train_batch_elapse)
if pass_id % args.print_step == 0:
train_batch_metrics_avg = np.mean( train_batch_metrics_avg = np.mean(
np.array(train_batch_metrics), axis=1) np.array(train_batch_metrics), axis=1)
train_batch_metrics_record.append(train_batch_metrics_avg) train_batch_metrics_record.append(train_batch_metrics_avg)
if trainer_id == 0: if trainer_id == 0:
print_info("batch", train_batch_metrics_avg, train_batch_elapse, print_info("batch", train_batch_metrics_avg, train_batch_elapse,
pass_id, train_batch_id, args.print_step) pass_id, train_batch_id, args.print_step)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册