diff --git a/README.md b/README.md index 1fcc4243acc5bcc9ca25781e8b779d9b415c4d23..87bbcb42cecf172c83bb38acb47646b0d205fa66 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,8 @@ PaddleGAN is an development kit of Generative Adversarial Network based on Paddl ![](./docs/imgs/sr_demo.png) +### Motion driving +![](./docs/imgs/first_order.gif) Features: diff --git a/README_cn.md b/README_cn.md index 6b5bc0344ba1e0d94563573ca569421f702d2a74..ad40a4868953eb9ae68829e24bb8475fd0285141 100644 --- a/README_cn.md +++ b/README_cn.md @@ -17,6 +17,8 @@ PaddleGAN 是一个基于飞桨的生成对抗网络开发工具包. ### 超分辨率 ![](./docs/imgs/sr_demo.png) +### 动作驱动 +![](./docs/imgs/first_order.gif) 特性: diff --git a/docs/imgs/first_order.gif b/docs/imgs/first_order.gif new file mode 100644 index 0000000000000000000000000000000000000000..9b6b609939f4e9e71ffe2afdb2b9f68ad0585c16 Binary files /dev/null and b/docs/imgs/first_order.gif differ diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py index e400e4bcd9023f678a16da172a04d93a47ba705f..da5ee8961c0a22b5e6638149de17e1069f8c1617 100644 --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -77,10 +77,13 @@ class Trainer: self.model.set_input(data) self.model.optimize_parameters() - batch_cost_averager.record(time.time() - step_start_time) + batch_cost_averager.record( + time.time() - step_start_time, + num_samples=self.cfg.get('batch_size', 1)) if i % self.log_interval == 0: self.data_time = reader_cost_averager.get_average() self.step_time = batch_cost_averager.get_average() + self.ips = batch_cost_averager.get_ips_average() self.print_log() reader_cost_averager.reset() @@ -197,11 +200,14 @@ class Trainer: for k, v in losses.items(): message += '%s: %.3f ' % (k, v) + if hasattr(self, 'step_time'): + message += 'batch_cost: %.5f sec ' % self.step_time + if hasattr(self, 'data_time'): - message += 'reader cost: %.5fs ' % self.data_time + message += 'reader_cost: %.5f sec ' % self.data_time - if hasattr(self, 'step_time'): - message += 'batch cost: %.5fs' % self.step_time + if hasattr(self, 'ips'): + message += 'ips: %.5f images/s' % self.ips # print the message self.logger.info(message) diff --git a/ppgan/utils/timer.py b/ppgan/utils/timer.py index 6b277e5e291c70327fb08e899f4606ae08f68f5c..838dc752aab5edbdbc7948456c0bb6a102d12952 100644 --- a/ppgan/utils/timer.py +++ b/ppgan/utils/timer.py @@ -22,12 +22,20 @@ class TimeAverager(object): def reset(self): self._cnt = 0 self._total_time = 0 + self._total_samples = 0 - def record(self, usetime): + def record(self, usetime, num_samples=None): self._cnt += 1 self._total_time += usetime + if num_samples: + self._total_samples += num_samples def get_average(self): if self._cnt == 0: return 0 - return self._total_time / self._cnt + return self._total_time / float(self._cnt) + + def get_ips_average(self): + if not self._total_samples or self._cnt == 0: + return 0 + return float(self._total_samples) / self._total_time