未验证 提交 da02f748 编写于 作者: Z zhengya01 提交者: GitHub

Merge pull request #4 from PaddlePaddle/develop

update
...@@ -165,7 +165,8 @@ def train(args): ...@@ -165,7 +165,8 @@ def train(args):
'conditions': conditions_data}, 'conditions': conditions_data},
fetch_list={dg_loss})[0][0] fetch_list={dg_loss})[0][0]
losses[1].append(dg_loss_n) losses[1].append(dg_loss_n)
t_time += (time.time() - s_time) batch_time = time.time() - s_time
t_time += batch_time
...@@ -180,8 +181,9 @@ def train(args): ...@@ -180,8 +181,9 @@ def train(args):
fetch_list={g_img})[0] fetch_list={g_img})[0]
total_images = np.concatenate([real_image, generated_images]) total_images = np.concatenate([real_image, generated_images])
fig = plot(total_images) fig = plot(total_images)
msg = "Epoch ID={0}\n Batch ID={1}\n D-Loss={2}\n DG-Loss={3}\n gen={4}".format( msg = "Epoch ID={0}\n Batch ID={1}\n D-Loss={2}\n DG-Loss={3}\n gen={4}\n " \
pass_id, batch_id, d_loss_n, dg_loss_n, check(generated_images)) "Batch_time_cost={5:.2f}".format(
pass_id, batch_id, d_loss_n, dg_loss_n, check(generated_images), batch_time)
print(msg) print(msg)
plt.title(msg) plt.title(msg)
plt.savefig( plt.savefig(
......
...@@ -187,10 +187,12 @@ def train(args): ...@@ -187,10 +187,12 @@ def train(args):
fetch_list=[d_A_trainer.d_loss_A], fetch_list=[d_A_trainer.d_loss_A],
feed={"input_A": tensor_A, feed={"input_A": tensor_A,
"fake_pool_A": fake_pool_A})[0] "fake_pool_A": fake_pool_A})[0]
t_time += (time.time() - s_time) batch_time = time.time() - s_time
print("epoch{}; batch{}; g_A_loss: {}; d_B_loss: {}; g_B_loss: {}; d_A_loss: {};".format( t_time += batch_time
print("epoch{}; batch{}; g_A_loss: {}; d_B_loss: {}; g_B_loss: {}; d_A_loss: {}; "
"Batch_time_cost: {:.2f}".format(
epoch, batch_id, g_A_loss[0], d_B_loss[0], g_B_loss[0], epoch, batch_id, g_A_loss[0], d_B_loss[0], g_B_loss[0],
d_A_loss[0])) d_A_loss[0], batch_time))
losses[0].append(g_A_loss[0]) losses[0].append(g_A_loss[0])
losses[1].append(d_A_loss[0]) losses[1].append(d_A_loss[0])
sys.stdout.flush() sys.stdout.flush()
......
...@@ -390,6 +390,8 @@ def train(args): ...@@ -390,6 +390,8 @@ def train(args):
else: else:
global_step, last_cost = train_with_feed(global_step) global_step, last_cost = train_with_feed(global_step)
train_time += time.time() - begin_time train_time += time.time() - begin_time
print("Pass {0}, pass_time_cost {1}"
.format(epoch, "%2.2f sec" % time.time() -begin_time ))
# For internal continuous evaluation # For internal continuous evaluation
if "CE_MODE_X" in os.environ: if "CE_MODE_X" in os.environ:
print("kpis train_cost %f" % last_cost) print("kpis train_cost %f" % last_cost)
......
...@@ -446,7 +446,9 @@ def train(logger, args): ...@@ -446,7 +446,9 @@ def train(logger, args):
logger.info('Dev eval result: {}'.format( logger.info('Dev eval result: {}'.format(
bleu_rouge)) bleu_rouge))
pass_end_time = time.time() pass_end_time = time.time()
time_consumed = pass_end_time - pass_start_time
logger.info('epoch: {0}, epoch_time_cost: {1:.2f}'.format(
pass_id, time_consumed))
logger.info('Evaluating the model after epoch {}'.format( logger.info('Evaluating the model after epoch {}'.format(
pass_id)) pass_id))
if brc_data.dev_set is not None: if brc_data.dev_set is not None:
...@@ -459,7 +461,7 @@ def train(logger, args): ...@@ -459,7 +461,7 @@ def train(logger, args):
else: else:
logger.warning( logger.warning(
'No dev set is loaded for evaluation in the dataset!') 'No dev set is loaded for evaluation in the dataset!')
time_consumed = pass_end_time - pass_start_time
logger.info('Average train loss for epoch {} is {}'.format( logger.info('Average train loss for epoch {} is {}'.format(
pass_id, "%.10f" % (1.0 * total_loss / total_num))) pass_id, "%.10f" % (1.0 * total_loss / total_num)))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册