diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py index 1e6407fc21e2df5606b3eceff25175b3f5e76722..5d83f302c8251ec21cfd1be9f2b30d626ebcedb7 100755 --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -186,9 +186,9 @@ class Trainer: self.model.setup_input(data) self.model.train_iter(self.optimizers) - batch_cost_averager.record(time.time() - step_start_time, - num_samples=self.cfg.get( - 'batch_size', 1)) + batch_cost_averager.record( + time.time() - step_start_time, + num_samples=self.cfg['dataset']['train'].get('batch_size', 1)) step_start_time = time.time() @@ -233,7 +233,8 @@ class Trainer: for i in range(self.max_eval_steps): if self.max_eval_steps < self.log_interval or i % self.log_interval == 0: self.logger.info('Test iter: [%d/%d]' % - (i * self.world_size, self.max_eval_steps * self.world_size)) + (i * self.world_size, + self.max_eval_steps * self.world_size)) data = next(iter_loader) self.model.setup_input(data) @@ -268,7 +269,6 @@ class Trainer: step=self.batch_id, is_save_image=True) - if self.metrics: for metric_name, metric in self.metrics.items(): self.logger.info("Metric {}: {:.4f}".format( @@ -441,4 +441,4 @@ class Trainer: when finish the training need close file handler or other. """ if self.enable_visualdl: - self.vdl_logger.close() \ No newline at end of file + self.vdl_logger.close()