提交 edb230a9 编写于 作者: L LielinJiang

Merge branch 'master' of https://github.com/PaddlePaddle/PaddleGAN into release/0.1.0

...@@ -18,6 +18,8 @@ PaddleGAN is an development kit of Generative Adversarial Network based on Paddl ...@@ -18,6 +18,8 @@ PaddleGAN is an development kit of Generative Adversarial Network based on Paddl
![](./docs/imgs/sr_demo.png) ![](./docs/imgs/sr_demo.png)
### Motion driving
![](./docs/imgs/first_order.gif)
Features: Features:
......
...@@ -17,6 +17,8 @@ PaddleGAN 是一个基于飞桨的生成对抗网络开发工具包. ...@@ -17,6 +17,8 @@ PaddleGAN 是一个基于飞桨的生成对抗网络开发工具包.
### 超分辨率 ### 超分辨率
![](./docs/imgs/sr_demo.png) ![](./docs/imgs/sr_demo.png)
### 动作驱动
![](./docs/imgs/first_order.gif)
特性: 特性:
......
...@@ -77,10 +77,13 @@ class Trainer: ...@@ -77,10 +77,13 @@ class Trainer:
self.model.set_input(data) self.model.set_input(data)
self.model.optimize_parameters() self.model.optimize_parameters()
batch_cost_averager.record(time.time() - step_start_time) batch_cost_averager.record(
time.time() - step_start_time,
num_samples=self.cfg.get('batch_size', 1))
if i % self.log_interval == 0: if i % self.log_interval == 0:
self.data_time = reader_cost_averager.get_average() self.data_time = reader_cost_averager.get_average()
self.step_time = batch_cost_averager.get_average() self.step_time = batch_cost_averager.get_average()
self.ips = batch_cost_averager.get_ips_average()
self.print_log() self.print_log()
reader_cost_averager.reset() reader_cost_averager.reset()
...@@ -197,11 +200,14 @@ class Trainer: ...@@ -197,11 +200,14 @@ class Trainer:
for k, v in losses.items(): for k, v in losses.items():
message += '%s: %.3f ' % (k, v) message += '%s: %.3f ' % (k, v)
if hasattr(self, 'step_time'):
message += 'batch_cost: %.5f sec ' % self.step_time
if hasattr(self, 'data_time'): if hasattr(self, 'data_time'):
message += 'reader cost: %.5fs ' % self.data_time message += 'reader_cost: %.5f sec ' % self.data_time
if hasattr(self, 'step_time'): if hasattr(self, 'ips'):
message += 'batch cost: %.5fs' % self.step_time message += 'ips: %.5f images/s' % self.ips
# print the message # print the message
self.logger.info(message) self.logger.info(message)
......
...@@ -22,12 +22,20 @@ class TimeAverager(object): ...@@ -22,12 +22,20 @@ class TimeAverager(object):
def reset(self): def reset(self):
self._cnt = 0 self._cnt = 0
self._total_time = 0 self._total_time = 0
self._total_samples = 0
def record(self, usetime): def record(self, usetime, num_samples=None):
self._cnt += 1 self._cnt += 1
self._total_time += usetime self._total_time += usetime
if num_samples:
self._total_samples += num_samples
def get_average(self): def get_average(self):
if self._cnt == 0: if self._cnt == 0:
return 0 return 0
return self._total_time / self._cnt return self._total_time / float(self._cnt)
def get_ips_average(self):
if not self._total_samples or self._cnt == 0:
return 0
return float(self._total_samples) / self._total_time
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册