未验证 提交 281ae11d 编写于 作者: C ccmeteorljh 提交者: GitHub

Merge pull request #1595 from ccmeteorljh/develop

add models time statistics for benchmark
......@@ -165,7 +165,8 @@ def train(args):
'conditions': conditions_data},
fetch_list={dg_loss})[0][0]
losses[1].append(dg_loss_n)
t_time += (time.time() - s_time)
batch_time = time.time() - s_time
t_time += batch_time
......@@ -180,8 +181,9 @@ def train(args):
fetch_list={g_img})[0]
total_images = np.concatenate([real_image, generated_images])
fig = plot(total_images)
msg = "Epoch ID={0}\n Batch ID={1}\n D-Loss={2}\n DG-Loss={3}\n gen={4}".format(
pass_id, batch_id, d_loss_n, dg_loss_n, check(generated_images))
msg = "Epoch ID={0}\n Batch ID={1}\n D-Loss={2}\n DG-Loss={3}\n gen={4}\n " \
"Batch_time_cost={5:.2f}".format(
pass_id, batch_id, d_loss_n, dg_loss_n, check(generated_images), batch_time)
print(msg)
plt.title(msg)
plt.savefig(
......
......@@ -187,10 +187,12 @@ def train(args):
fetch_list=[d_A_trainer.d_loss_A],
feed={"input_A": tensor_A,
"fake_pool_A": fake_pool_A})[0]
t_time += (time.time() - s_time)
print("epoch{}; batch{}; g_A_loss: {}; d_B_loss: {}; g_B_loss: {}; d_A_loss: {};".format(
batch_time = time.time() - s_time
t_time += batch_time
print("epoch{}; batch{}; g_A_loss: {}; d_B_loss: {}; g_B_loss: {}; d_A_loss: {}; "
"Batch_time_cost: {:.2f}".format(
epoch, batch_id, g_A_loss[0], d_B_loss[0], g_B_loss[0],
d_A_loss[0]))
d_A_loss[0], batch_time))
losses[0].append(g_A_loss[0])
losses[1].append(d_A_loss[0])
sys.stdout.flush()
......
......@@ -390,6 +390,8 @@ def train(args):
else:
global_step, last_cost = train_with_feed(global_step)
train_time += time.time() - begin_time
print("Pass {0}, pass_time_cost {1}"
.format(epoch, "%2.2f sec" % time.time() -begin_time ))
# For internal continuous evaluation
if "CE_MODE_X" in os.environ:
print("kpis train_cost %f" % last_cost)
......
......@@ -446,7 +446,9 @@ def train(logger, args):
logger.info('Dev eval result: {}'.format(
bleu_rouge))
pass_end_time = time.time()
time_consumed = pass_end_time - pass_start_time
logger.info('epoch: {0}, epoch_time_cost: {1:.2f}'.format(
pass_id, time_consumed))
logger.info('Evaluating the model after epoch {}'.format(
pass_id))
if brc_data.dev_set is not None:
......@@ -459,7 +461,7 @@ def train(logger, args):
else:
logger.warning(
'No dev set is loaded for evaluation in the dataset!')
time_consumed = pass_end_time - pass_start_time
logger.info('Average train loss for epoch {} is {}'.format(
pass_id, "%.10f" % (1.0 * total_loss / total_num)))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册