diff --git a/configs/cyclegan_cityscapes.yaml b/configs/cyclegan_cityscapes.yaml
index cbd1b84c1373045fef07d17b42a2bf8e45ea643b..fdd8e9507a3f1fc8b209bf3f01d4f3d60b1e3f03 100644
--- a/configs/cyclegan_cityscapes.yaml
+++ b/configs/cyclegan_cityscapes.yaml
@@ -28,6 +28,7 @@ dataset:
   train:
     name: UnpairedDataset
     dataroot: data/cityscapes
+    num_workers: 4
     phase: train
     max_dataset_size: inf
     direction: AtoB
diff --git a/configs/pix2pix_cityscapes.yaml b/configs/pix2pix_cityscapes.yaml
index b074fddf678d19f5fd4de2d7084ab7b0ad24be32..8f58fae62349dceed5d7d8bedebe7dcd49fcedb7 100644
--- a/configs/pix2pix_cityscapes.yaml
+++ b/configs/pix2pix_cityscapes.yaml
@@ -25,6 +25,7 @@ dataset:
   train:
     name: PairedDataset
     dataroot: data/cityscapes
+    num_workers: 4
     phase: train
     max_dataset_size: inf
     direction: BtoA
diff --git a/ppgan/datasets/builder.py b/ppgan/datasets/builder.py
index 7db96f7e7084f2b6e6f52133bc16fb136b2bb06f..5c76210aa30dab54fb0357263d20e9c305834f37 100644
--- a/ppgan/datasets/builder.py
+++ b/ppgan/datasets/builder.py
@@ -111,6 +111,6 @@ def build_dataloader(cfg, is_train=True):
     batch_size = cfg.get('batch_size', 1)
     num_workers = cfg.get('num_workers', 0)
 
-    dataloader = DictDataLoader(dataset, batch_size, is_train)
+    dataloader = DictDataLoader(dataset, batch_size, is_train, num_workers)
 
     return dataloader
\ No newline at end of file
diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py
index 35dc5fd8a988148367969ce12511ba5c6eec0e43..ce8f2267ec15751806ba39ab5d1b1da15b1881b5 100644
--- a/ppgan/engine/trainer.py
+++ b/ppgan/engine/trainer.py
@@ -2,8 +2,9 @@
 import os
 import time
 import logging
+import paddle
 
-from paddle.imperative import ParallelEnv
+from paddle.imperative import ParallelEnv, DataParallel
 
 from ..datasets.builder import build_dataloader
 from ..models.builder import build_model
@@ -22,10 +23,13 @@ class Trainer:
 
         # build model
         self.model = build_model(cfg)
+        # multiple gpus prepare
+        if ParallelEnv().nranks > 1:
+            self.distributed_data_parallel()
 
         self.logger = logging.getLogger(__name__)
+
         # base config
-        # self.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime())
         self.output_dir = cfg.output_dir
         self.epochs = cfg.epochs
         self.start_epoch = 0
@@ -37,25 +41,39 @@ class Trainer:
         self.cfg = cfg
 
         self.local_rank = ParallelEnv().local_rank
 
+        # time count
+        self.time_count = {}
+
+    def distributed_data_parallel(self):
+        strategy = paddle.imperative.prepare_context()
+        for name in self.model.model_names:
+            if isinstance(name, str):
+                net = getattr(self.model, 'net' + name)
+                setattr(self.model, 'net' + name, DataParallel(net, strategy))
+
     def train(self):
         for epoch in range(self.start_epoch, self.epochs):
-            start_time = time.time()
             self.current_epoch = epoch
+            start_time = step_start_time = time.time()
             for i, data in enumerate(self.train_dataloader):
+                data_time = time.time()
                 self.batch_id = i
                 # unpack data from dataset and apply preprocessing
                 # data input should be dict
                 self.model.set_input(data)
                 self.model.optimize_parameters()
-
+
+                self.data_time = data_time - step_start_time
+                self.step_time = time.time() - step_start_time
                 if i % self.log_interval == 0:
                     self.print_log()
 
                 if i % self.visual_interval == 0:
                     self.visual('visual_train')
+                step_start_time = time.time()
 
             self.logger.info('train one epoch time: {}'.format(time.time() - start_time))
             if epoch % self.weight_interval == 0:
                 self.save(epoch, 'weight', keep=-1)
 
@@ -98,6 +116,12 @@ class Trainer:
         for k, v in losses.items():
             message += '%s: %.3f ' % (k, v)
 
+        if hasattr(self, 'data_time'):
+            message += 'reader cost: %.5fs ' % self.data_time
+
+        if hasattr(self, 'step_time'):
+            message += 'batch cost: %.5fs' % self.step_time
+
         # print the message
         self.logger.info(message)
 
diff --git a/ppgan/models/cycle_gan_model.py b/ppgan/models/cycle_gan_model.py
index 98b83b1896ea7412ae03f939ee8a455ab8b93b44..d191a80ffb91acba31347d8bf80ee48b74e961fb 100644
--- a/ppgan/models/cycle_gan_model.py
+++ b/ppgan/models/cycle_gan_model.py
@@ -1,4 +1,5 @@
 import paddle
+from paddle.imperative import ParallelEnv
 
 from .base_model import BaseModel
 from .builder import MODELS
@@ -137,7 +138,12 @@ class CycleGANModel(BaseModel):
         loss_D_fake = self.criterionGAN(pred_fake, False)
         # Combined loss and calculate gradients
         loss_D = (loss_D_real + loss_D_fake) * 0.5
-        loss_D.backward()
+        if ParallelEnv().nranks > 1:
+            loss_D = netD.scale_loss(loss_D)
+            loss_D.backward()
+            netD.apply_collective_grads()
+        else:
+            loss_D.backward()
         return loss_D
 
     def backward_D_A(self):
@@ -177,7 +183,14 @@ class CycleGANModel(BaseModel):
         self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
         # combined loss and calculate gradients
         self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
-        self.loss_G.backward()
+
+        if ParallelEnv().nranks > 1:
+            self.loss_G = self.netG_A.scale_loss(self.loss_G)
+            self.loss_G.backward()
+            self.netG_A.apply_collective_grads()
+            self.netG_B.apply_collective_grads()
+        else:
+            self.loss_G.backward()
 
     def optimize_parameters(self):
         """Calculate losses, gradients, and update network weights; called in every training iteration"""
diff --git a/ppgan/models/generators/resnet.py b/ppgan/models/generators/resnet.py
index dc040f84507ad1f5d7556061cff02f211a865613..6a82670a7a08a4e5e6adb167e7bd61b05fc23f0a 100644
--- a/ppgan/models/generators/resnet.py
+++ b/ppgan/models/generators/resnet.py
@@ -36,11 +36,8 @@ class ResnetGenerator(paddle.fluid.dygraph.Layer):
         else:
             use_bias = norm_layer == nn.InstanceNorm
 
-        print('norm layer:', norm_layer, 'use bias:', use_bias)
-
         model = [ReflectionPad2d(3),
                  nn.Conv2D(input_nc, ngf, filter_size=7, padding=0, bias_attr=use_bias),
-                 # nn.nn.Conv2D(input_nc, ngf, filter_size=7, padding=0, bias_attr=use_bias),
                  norm_layer(ngf),
                  nn.ReLU()]
 
@@ -62,8 +59,7 @@ class ResnetGenerator(paddle.fluid.dygraph.Layer):
             model += [
                 nn.Conv2DTranspose(ngf * mult, int(ngf * mult / 2),
                                    filter_size=3, stride=2,
-                                   padding=1, #output_padding=1,
-                                   # padding='same', #output_padding=1,
+                                   padding=1,
                                    bias_attr=use_bias),
                 Pad2D(paddings=[0, 1, 0, 1], mode='constant', pad_value=0.0),
                 norm_layer(int(ngf * mult / 2)),
diff --git a/ppgan/models/pix2pix_model.py b/ppgan/models/pix2pix_model.py
index 9419b0f08516c0616c672988e89cb3556ce137f8..737b4f68f8b90b89c3cc5c34b3e0a37e6f32dfd3 100644
--- a/ppgan/models/pix2pix_model.py
+++ b/ppgan/models/pix2pix_model.py
@@ -1,4 +1,5 @@
 import paddle
+from paddle.imperative import ParallelEnv
 
 from .base_model import BaseModel
 from .builder import MODELS
@@ -43,7 +44,6 @@ class Pix2PixModel(BaseModel):
         # define networks (both generator and discriminator)
         self.netG = build_generator(opt.model.generator)
 
-        # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
         if self.isTrain:
             self.netD = build_discriminator(opt.model.discriminator)
 
@@ -98,7 +98,12 @@ class Pix2PixModel(BaseModel):
         self.loss_D_real = self.criterionGAN(pred_real, True)
         # combine loss and calculate gradients
         self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
-        self.loss_D.backward()
+        if ParallelEnv().nranks > 1:
+            self.loss_D = self.netD.scale_loss(self.loss_D)
+            self.loss_D.backward()
+            self.netD.apply_collective_grads()
+        else:
+            self.loss_D.backward()
 
     def backward_G(self):
         """Calculate GAN and L1 loss for the generator"""
@@ -110,8 +115,13 @@ class Pix2PixModel(BaseModel):
         self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
         # combine loss and calculate gradients
         self.loss_G = self.loss_G_GAN + self.loss_G_L1
-        # self.loss_G = self.loss_G_L1
-        self.loss_G.backward()
+
+        if ParallelEnv().nranks > 1:
+            self.loss_G = self.netG.scale_loss(self.loss_G)
+            self.loss_G.backward()
+            self.netG.apply_collective_grads()
+        else:
+            self.loss_G.backward()
 
     def optimize_parameters(self):
         # compute fake images: G(A)
diff --git a/ppgan/utils/filesystem.py b/ppgan/utils/filesystem.py
index 1eb9f0da435fe8eb945025717f1d591ea2cd9c71..83f0d892ffec6f2389e1d393843ac8ca087b7b9e 100644
--- a/ppgan/utils/filesystem.py
+++ b/ppgan/utils/filesystem.py
@@ -11,15 +11,13 @@ def save(state_dicts, file_name):
 
     def convert(state_dict):
         model_dict = {}
-        # name_table = {}
+
         for k, v in state_dict.items():
             if isinstance(v, (paddle.framework.Variable, paddle.imperative.core.VarBase)):
                 model_dict[k] = v.numpy()
             else:
                 model_dict[k] = v
-        return state_dict
-        # name_table[k] = v.name
-        # model_dict["StructuredToParameterName@@"] = name_table
+        return model_dict
 
     final_dict = {}
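
Reviewer note: the multi-GPU changes in trainer.py, cycle_gan_model.py, and pix2pix_model.py all apply the same old-style Paddle dygraph collective pattern: call prepare_context() once, wrap each sub-network in DataParallel, and on multi-card runs scale the loss before backward() and all-reduce the gradients afterwards. Below is a minimal standalone sketch of that pattern, using only the paddle.imperative API this branch already relies on; the helper names (wrap_models_for_multi_gpu, backward_with_collective_grads) are illustrative, not part of this patch.

    import paddle
    from paddle.imperative import ParallelEnv, DataParallel

    def wrap_models_for_multi_gpu(nets):
        # prepare_context() initializes the collective communicator once;
        # the returned strategy is shared by every wrapped sub-network,
        # mirroring Trainer.distributed_data_parallel() above.
        strategy = paddle.imperative.prepare_context()
        return [DataParallel(net, strategy) for net in nets]

    def backward_with_collective_grads(net, loss):
        # scale_loss() divides the loss so that the averaged gradient after
        # the all-reduce matches the single-card gradient;
        # apply_collective_grads() performs the all-reduce itself.
        if ParallelEnv().nranks > 1:
            loss = net.scale_loss(loss)
            loss.backward()
            net.apply_collective_grads()
        else:
            loss.backward()

This is also why backward_G() in cycle_gan_model.py calls apply_collective_grads() on both netG_A and netG_B: the combined loss produces gradients in both wrapped generators, and each needs its own all-reduce.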
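The reader cost / batch cost figures added to print_log() depend on where the timestamps are taken in Trainer.train(): step_start_time is reset at the bottom of the loop, so the interval up to data_time = time.time() at the top of the next iteration is time spent purely fetching the batch. A stripped-down sketch of the same accounting (timed_loop and run_step are illustrative names, not part of this patch):

    import time

    def timed_loop(dataloader, run_step):
        step_start_time = time.time()
        for data in dataloader:
            # Time since the end of the previous step was spent waiting on
            # the dataloader for `data`: the "reader cost".
            data_time = time.time() - step_start_time

            run_step(data)

            # Full iteration time, fetch included: the "batch cost".
            step_time = time.time() - step_start_time
            print('reader cost: %.5fs batch cost: %.5fs' % (data_time, step_time))

            step_start_time = time.time()

A reader cost that stays high relative to batch cost is the signal that raising num_workers (as the two config changes here do) will help.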