未验证 提交 d60dee05 编写于 作者: Y Yiqun Liu 提交者: GitHub

Add calcutation and printing of ips for gan models (#4944)

上级 7196201c
......@@ -338,17 +338,19 @@ class CycleGAN(object):
feed={"input_A": tensor_A,
"fake_pool_A": fake_pool_A})[0]
batch_cost_averager.record(time.time() - batch_start)
batch_cost_averager.record(
time.time() - batch_start, num_samples=self.cfg.batch_size)
if batch_id % self.cfg.print_freq == 0:
print("epoch{}: batch{}: \n\
d_A_loss: {}; g_A_loss: {}; g_A_cyc_loss: {}; g_A_idt_loss: {}; \n\
d_B_loss: {}; g_B_loss: {}; g_B_cyc_loss: {}; g_B_idt_loss: {}; \n\
reader_cost: {}, Batch_time_cost: {}"
d_A_loss: {:.5f}; g_A_loss: {:.5f}; g_A_cyc_loss: {:.5f}; g_A_idt_loss: {:.5f}; \n\
d_B_loss: {:.5f}; g_B_loss: {:.5f}; g_B_cyc_loss: {:.5f}; g_B_idt_loss: {:.5f}; \n\
batch_cost: {:.5f} sec, reader_cost: {:.5f} sec, ips: {:.5f} images/sec"
.format(epoch_id, batch_id, d_A_loss[0], g_A_loss[
0], g_A_cyc_loss[0], g_A_idt_loss[0], d_B_loss[0],
g_B_loss[0], g_B_cyc_loss[0], g_B_idt_loss[0],
batch_cost_averager.get_average(),
reader_cost_averager.get_average(),
batch_cost_averager.get_average()))
batch_cost_averager.get_ips_average()))
reader_cost_averager.reset()
batch_cost_averager.reset()
......
......@@ -292,16 +292,18 @@ class Pix2pix(object):
],
feed=tensor)
batch_cost_averager.record(time.time() - batch_start)
batch_cost_averager.record(
time.time() - batch_start, num_samples=self.cfg.batch_size)
if batch_id % self.cfg.print_freq == 0:
print("epoch{}: batch{}: \n\
g_loss_gan: {}; g_loss_l1: {}; \n\
d_loss_real: {}; d_loss_fake: {}; \n\
reader_cost: {}, Batch_time_cost: {}"
g_loss_gan: {:.5f}; g_loss_l1: {:.5f}; \n\
d_loss_real: {:.5f}; d_loss_fake: {:.5f}; \n\
batch_cost: {:.5f} sec, reader_cost: {:.5f} sec, ips: {:.5f} images/sec"
.format(epoch_id, batch_id, g_loss_gan[0], g_loss_l1[
0], d_loss_real[0], d_loss_fake[0],
batch_cost_averager.get_average(),
reader_cost_averager.get_average(),
batch_cost_averager.get_average()))
batch_cost_averager.get_ips_average()))
reader_cost_averager.reset()
batch_cost_averager.reset()
......
......@@ -376,19 +376,21 @@ class STGAN(object):
g_loss_fake, g_loss_rec, g_loss_cls = exe.run(
gen_trainer_program, fetch_list=d_fetches, feed=data)
print("epoch{}: batch{}: \n\
g_loss_fake: {}; g_loss_rec: {}; g_loss_cls: {}"
g_loss_fake: {:.5f}; g_loss_rec: {:.5f}; g_loss_cls: {:.5f}"
.format(epoch_id, batch_id, g_loss_fake[0],
g_loss_rec[0], g_loss_cls[0]))
batch_cost_averager.record(time.time() - batch_start)
batch_cost_averager.record(
time.time() - batch_start, num_samples=self.cfg.batch_size)
if (batch_id + 1) % self.cfg.print_freq == 0:
print("epoch{}: batch{}: \n\
d_loss: {}; d_loss_real: {}; d_loss_fake: {}; d_loss_cls: {}; d_loss_gp: {} \n\
reader_cost: {}, Batch_time_cost: {}"
d_loss: {:.5f}; d_loss_real: {:.5f}; d_loss_fake: {:.5f}; d_loss_cls: {:.5f}; d_loss_gp: {:.5f} \n\
batch_cost: {:.5f} sec, reader_cost: {:.5f} sec, ips: {:.5f} images/sec"
.format(epoch_id, batch_id, d_loss[0], d_loss_real[0],
d_loss_fake[0], d_loss_cls[0], d_loss_gp[0],
batch_cost_averager.get_average(),
reader_cost_averager.get_average(),
batch_cost_averager.get_average()))
batch_cost_averager.get_ips_average()))
reader_cost_averager.reset()
batch_cost_averager.reset()
......
......@@ -345,19 +345,21 @@ class StarGAN(object):
],
feed=data)
print("epoch{}: batch{}: \n\
g_loss_fake: {}; g_loss_rec: {}; g_loss_cls: {}"
g_loss_fake: {:.5f}; g_loss_rec: {:.5f}; g_loss_cls: {:.5f}"
.format(epoch_id, batch_id, g_loss_fake[0],
g_loss_rec[0], g_loss_cls[0]))
batch_cost_averager.record(time.time() - batch_start)
batch_cost_averager.record(
time.time() - batch_start, num_samples=self.cfg.batch_size)
if (batch_id + 1) % self.cfg.print_freq == 0:
print("epoch{}: batch{}: \n\
d_loss_real: {}; d_loss_fake: {}; d_loss_cls: {}; d_loss_gp: {} \n\
reader_cost: {}, Batch_time_cost: {}"
d_loss_real: {:.5f}; d_loss_fake: {:.5f}; d_loss_cls: {:.5f}; d_loss_gp: {:.5f} \n\
batch_cost: {:.5f} sec, reader_cost: {:.5f} sec, ips: {:.5f} images/sec"
.format(epoch_id, batch_id, d_loss_real[0],
d_loss_fake[0], d_loss_cls[0], d_loss_gp[0],
batch_cost_averager.get_average(),
reader_cost_averager.get_average(),
batch_cost_averager.get_average()))
batch_cost_averager.get_ips_average()))
reader_cost_averager.reset()
batch_cost_averager.reset()
......
......@@ -22,12 +22,20 @@ class TimeAverager(object):
def reset(self):
self._cnt = 0
self._total_time = 0
self._total_samples = 0
def record(self, usetime):
def record(self, usetime, num_samples=None):
self._cnt += 1
self._total_time += usetime
if num_samples:
self._total_samples += num_samples
def get_average(self):
if self._cnt == 0:
return 0
return self._total_time / self._cnt
return self._total_time / float(self._cnt)
def get_ips_average(self):
if not self._total_samples or self._cnt == 0:
return 0
return float(self._total_samples) / self._total_time
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册