diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py
index 0944087a7855d858bbd20a8f75ff1457e79a5cba..b518c6ba85044ded0dc105d8d2c29f6a28d08d0a 100644
--- a/ppgan/engine/trainer.py
+++ b/ppgan/engine/trainer.py
@@ -55,11 +55,8 @@ class Trainer:
 
     def distributed_data_parallel(self):
         strategy = paddle.distributed.prepare_context()
-        for name in self.model.model_names:
-            if isinstance(name, str):
-                net = getattr(self.model, 'net' + name)
-                setattr(self.model, 'net' + name,
-                        paddle.DataParallel(net, strategy))
+        for net_name, net in self.model.nets.items():
+            self.model.nets[net_name] = paddle.DataParallel(net, strategy)
 
     def train(self):
         reader_cost_averager = TimeAverager()
@@ -77,9 +74,9 @@ class Trainer:
                 self.model.set_input(data)
                 self.model.optimize_parameters()
 
-                batch_cost_averager.record(
-                    time.time() - step_start_time,
-                    num_samples=self.cfg.get('batch_size', 1))
+                batch_cost_averager.record(time.time() - step_start_time,
+                                           num_samples=self.cfg.get(
+                                               'batch_size', 1))
                 if i % self.log_interval == 0:
                     self.data_time = reader_cost_averager.get_average()
                     self.step_time = batch_cost_averager.get_average()
@@ -94,8 +91,8 @@ class Trainer:
 
                 step_start_time = time.time()
 
-            self.logger.info(
-                'train one epoch time: {}'.format(time.time() - start_time))
+            self.logger.info('train one epoch time: {}'.format(time.time() -
+                                                               start_time))
             if self.validate_interval > -1 and epoch % self.validate_interval:
                 self.validate()
             self.model.lr_scheduler.step()
@@ -105,8 +102,8 @@ class Trainer:
 
     def validate(self):
         if not hasattr(self, 'val_dataloader'):
-            self.val_dataloader = build_dataloader(
-                self.cfg.dataset.val, is_train=False)
+            self.val_dataloader = build_dataloader(self.cfg.dataset.val,
+                                                   is_train=False)
 
         metric_result = {}
 
@@ -152,8 +149,8 @@ class Trainer:
             self.visual('visual_val', visual_results=visual_results)
 
             if i % self.log_interval == 0:
-                self.logger.info(
-                    'val iter: [%d/%d]' % (i, len(self.val_dataloader)))
+                self.logger.info('val iter: [%d/%d]' %
+                                 (i, len(self.val_dataloader)))
 
         for metric_name in metric_result.keys():
             metric_result[metric_name] /= len(self.val_dataloader.dataset)
@@ -163,8 +160,8 @@ class Trainer:
 
     def test(self):
         if not hasattr(self, 'test_dataloader'):
-            self.test_dataloader = build_dataloader(
-                self.cfg.dataset.test, is_train=False)
+            self.test_dataloader = build_dataloader(self.cfg.dataset.test,
+                                                    is_train=False)
 
         # data[0]: img, data[1]: img path index
         # test batch size must be 1
@@ -188,8 +185,8 @@ class Trainer:
             self.visual('visual_test', visual_results=visual_results)
 
             if i % self.log_interval == 0:
-                self.logger.info(
-                    'Test iter: [%d/%d]' % (i, len(self.test_dataloader)))
+                self.logger.info('Test iter: [%d/%d]' %
+                                 (i, len(self.test_dataloader)))
 
     def print_log(self):
         losses = self.model.get_current_losses()
@@ -277,13 +274,13 @@ class Trainer:
         self.start_epoch = state_dicts['epoch'] + 1
 
         for net_name, net in self.model.nets.items():
-            net.set_dict(state_dicts[net_name])
+            net.set_state_dict(state_dicts[net_name])
 
         for opt_name, opt in self.model.optimizers.items():
-            opt.set_dict(state_dicts[opt_name])
+            opt.set_state_dict(state_dicts[opt_name])
 
     def load(self, weight_path):
         state_dicts = load(weight_path)
 
         for net_name, net in self.model.nets.items():
-            net.set_dict(state_dicts[net_name])
+            net.set_state_dict(state_dicts[net_name])
diff --git a/ppgan/utils/filesystem.py b/ppgan/utils/filesystem.py
index a4f94636cc221abbd0e859b2aa969a492b7d1034..98495b23f6bf1f12b4c464d87f3438f695b6ae04 100644
--- a/ppgan/utils/filesystem.py
+++ b/ppgan/utils/filesystem.py
@@ -6,7 +6,11 @@ import paddle
 
 def makedirs(dir):
+    """Create directory `dir` (and missing parents) if it does not exist."""
     if not os.path.exists(dir):
-        os.makedirs(dir)
+        # Several trainer processes (multi-GPU runs) may reach this point at
+        # once; exist_ok=True tolerates losing that race without hiding
+        # unrelated OS errors the way a bare `except: pass` would.
+        os.makedirs(dir, exist_ok=True)
 
 
 def save(state_dicts, file_name):