diff --git a/.gitignore b/.gitignore
index 3a301952651a7459af6aa13482b5cf7813b0e9d9..b245b85b056b3433442c99cc1cd7e3c8d63b75a5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -100,4 +100,8 @@ venv.bak/
 /site
 
 # mypy
-.mypy_cache/
\ No newline at end of file
+.mypy_cache/
+
+# data
+data/
+output_dir/
\ No newline at end of file
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d8f849f3205d59f42a06a722e333f7ecf45d3b13
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,41 @@
+- repo: local
+  hooks:
+  - id: yapf
+    name: yapf
+    entry: yapf
+    language: system
+    args: [-i, --style .style.yapf]
+    files: \.py$
+
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  sha: a11d9314b22d8f8c7556443875b731ef05965464
+  hooks:
+  - id: check-merge-conflict
+  - id: check-symlinks
+  - id: end-of-file-fixer
+  - id: trailing-whitespace
+  - id: detect-private-key
+  - id: check-symlinks
+  - id: check-added-large-files
+
+- repo: local
+  hooks:
+  - id: flake8
+    name: flake8
+    entry: flake8
+    language: system
+    args:
+    - --count
+    - --select=E9,F63,F7,F82
+    - --show-source
+    - --statistics
+    files: \.py$
+
+- repo: local
+  hooks:
+  - id: copyright_checker
+    name: copyright_checker
+    entry: python ./.copyright.hook
+    language: system
+    files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
+    exclude: (?!.*third_party)^.*$
\ No newline at end of file
diff --git a/.style.yapf b/.style.yapf
new file mode 100644
index 0000000000000000000000000000000000000000..b62febf509036e6b75d4d3ffa76754d6e2e80d98
--- /dev/null
+++ b/.style.yapf
@@ -0,0 +1,3 @@
+[style]
+based_on_style = pep8
+column_limit = 80
\ No newline at end of file
diff --git a/README.md b/README.md
index 5103bf78b1fd98a55a9634a9185250957bdbcb10..f4e16b4fc9f0e33a01d6664600b77b0ce563317a 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,16 @@
+English | [简体中文](./README.md)
+
 # PaddleGAN
 
 still under development!!
 
+## Download Dataset
+This script downloads several paired-image datasets for the image-to-image translation task.
+
+```
+cd PaddleGAN/script/
+bash pix2pix_download.sh [cityscapes|facades|edges2handbags|edges2shoes|maps]
+```
 ## Train
 ```
 python -u tools/main.py --config-file configs/cyclegan-cityscapes.yaml
diff --git a/configs/cyclegan-cityscapes.yaml b/configs/cyclegan_cityscapes.yaml
similarity index 97%
rename from configs/cyclegan-cityscapes.yaml
rename to configs/cyclegan_cityscapes.yaml
index fd48f9f5905967e11bf4db336a4dbef23d0cd49f..cbd1b84c1373045fef07d17b42a2bf8e45ea643b 100644
--- a/configs/cyclegan-cityscapes.yaml
+++ b/configs/cyclegan_cityscapes.yaml
@@ -26,7 +26,7 @@ model:
 
 dataset:
   train:
-    name: UnalignedDataset
+    name: UnpairedDataset
     dataroot: data/cityscapes
     phase: train
     max_dataset_size: inf
diff --git a/configs/cyclegan-horse2zebra.yaml b/configs/cyclegan_horse2zebra.yaml
similarity index 97%
rename from configs/cyclegan-horse2zebra.yaml
rename to configs/cyclegan_horse2zebra.yaml
index 7cd27e3049a7c6df38b72831cc3ff6f412ff6005..72c7cf6ccf3a9b058914498a25c8165400d3cb4e 100644
--- a/configs/cyclegan-horse2zebra.yaml
+++ b/configs/cyclegan_horse2zebra.yaml
@@ -26,7 +26,7 @@ model:
 
 dataset:
   train:
-    name: UnalignedDataset
+    name: UnpairedDataset
     dataroot: data/horse2zebra
     phase: train
     max_dataset_size: inf
diff --git a/configs/pix2pix-cityscapes.yaml b/configs/pix2pix_cityscapes.yaml
similarity index 95%
rename from configs/pix2pix-cityscapes.yaml
rename to configs/pix2pix_cityscapes.yaml
index 3131d0b078c972c8892f28dd8a27ffb45e1ac8a1..b074fddf678d19f5fd4de2d7084ab7b0ad24be32 100644
--- a/configs/pix2pix-cityscapes.yaml
+++ b/configs/pix2pix_cityscapes.yaml
@@ -23,7 +23,7 @@ model:
 
 dataset:
   train:
-    name: AlignedDataset
+    name: PairedDataset
     dataroot: data/cityscapes
     phase: train
     max_dataset_size: inf
@@ -38,7 +38,7 @@ dataset:
       preprocess: resize_and_crop
       no_flip: False
   test:
-    name: AlignedDataset
+    name: PairedDataset
     dataroot: data/cityscapes/
     phase: test
     max_dataset_size: inf
diff --git a/configs/pix2pix-cityscapes-2gpus.yaml b/configs/pix2pix_cityscapes_2gpus.yaml
similarity index 95%
rename from configs/pix2pix-cityscapes-2gpus.yaml
rename to configs/pix2pix_cityscapes_2gpus.yaml
index 5b785c16b9a05b2874c54937728fa794bd3cb494..387f16bd2de43cd150a852b238057850607911be 100644
--- a/configs/pix2pix-cityscapes-2gpus.yaml
+++ b/configs/pix2pix_cityscapes_2gpus.yaml
@@ -23,7 +23,7 @@ model:
 
 dataset:
   train:
-    name: AlignedDataset
+    name: PairedDataset
     dataroot: data/cityscapes
     phase: train
     max_dataset_size: inf
@@ -38,7 +38,7 @@ dataset:
       preprocess: resize_and_crop
       no_flip: False
   test:
-    name: AlignedDataset
+    name: PairedDataset
     dataroot: data/cityscapes/
     phase: test
     max_dataset_size: inf
diff --git a/configs/pix2pix_facades.yaml b/configs/pix2pix_facades.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..06a8403f97723664e5fe330e7bafdd0d3171aa89
--- /dev/null
+++ b/configs/pix2pix_facades.yaml
@@ -0,0 +1,70 @@
+epochs: 200
+isTrain: True
+output_dir: output_dir
+lambda_L1: 100
+
+model:
+  name: Pix2PixModel
+  generator:
+    name: UnetGenerator
+    norm_type: batch
+    input_nc: 3
+    output_nc: 3
+    num_downs: 8 #unet256
+    ngf: 64
+    use_dropout: False
+  discriminator:
+    name: NLayerDiscriminator
+    ndf: 64
+    n_layers: 3
+    input_nc: 6
+    norm_type: batch
+  gan_mode: vanilla
+
+dataset:
+  train:
+    name: PairedDataset
+    dataroot: data/facades/
+    phase: train
+    max_dataset_size: inf
+    direction: BtoA
+    input_nc: 3
+    output_nc: 3
+    serial_batches: False
+    pool_size: 0
+    transform:
+      load_size: 286
+      crop_size: 256
+      preprocess: resize_and_crop
+      no_flip: False
+  test:
+    name: PairedDataset
+    dataroot: data/facades/
+    phase: test
+    max_dataset_size: inf
+    direction: BtoA
+    input_nc: 3
+    output_nc: 3
+    serial_batches: True
+    pool_size: 50
+    transform:
+      load_size: 256
+      crop_size: 256
+      preprocess: resize_and_crop
+      no_flip: True
+
+optimizer:
+  name: Adam
+  beta1: 0.5
+  lr_scheduler:
+    name: linear
+    learning_rate: 0.0002
+    start_epoch: 100
+    decay_epochs: 100
+
+log_config:
+  interval: 100
+  visiual_interval: 500
+
+snapshot_config:
+  interval: 5
diff --git a/ppgan/datasets/__init__.py b/ppgan/datasets/__init__.py
index 178cb4d9709b0b8c0f68df9170be80692593092e..9b807e9be0c83dda6415ebf01418cc77b8f463ba 100644
--- a/ppgan/datasets/__init__.py
+++ b/ppgan/datasets/__init__.py
@@ -1,3 +1,3 @@
-from .unaligned_dataset import UnalignedDataset
+from .unpaired_dataset import UnpairedDataset
 from .single_dataset import SingleDataset
-from .aligned_dataset import AlignedDataset
+from .paired_dataset import PairedDataset
diff --git a/ppgan/datasets/builder.py b/ppgan/datasets/builder.py
index 7dc8be530d0be6e1ba48fa7007208c4e8e9b842a..7db96f7e7084f2b6e6f52133bc16fb136b2bb06f 100644
--- a/ppgan/datasets/builder.py
+++ b/ppgan/datasets/builder.py
@@ -1,3 +1,4 @@
+import time
 import paddle
 import numbers
 import numpy as np
@@ -23,7 +24,7 @@ class DictDataset(paddle.io.Dataset):
 
         for k, v in single_item.items():
             if not isinstance(v, (numbers.Number, np.ndarray)):
-                self.non_tensor_dict.update({k: {}})
+                setattr(self, k, Manager().dict())
                 self.non_tensor_keys_set.add(k)
             else:
                 self.tensor_keys_set.add(k)
@@ -38,9 +39,7 @@ class DictDataset(paddle.io.Dataset):
             if isinstance(v, (numbers.Number, np.ndarray)):
                 tmp_list.append(v)
             else:
-                tmp_dict = self.non_tensor_dict[k]
-                tmp_dict.update({index: v})
-                self.non_tensor_dict[k] = tmp_dict
+                getattr(self, k).update({index: v})
                 tmp_list.append(index)
 
         return tuple(tmp_list)
@@ -50,11 +49,11 @@ class DictDataset(paddle.io.Dataset):
 
     def reset(self):
         for k in self.non_tensor_keys_set:
-            self.non_tensor_dict[k] = {}
+            setattr(self, k, Manager().dict())
 
 
 class DictDataLoader():
-    def __init__(self, dataset, batch_size, is_train, num_workers=0):
+    def __init__(self, dataset, batch_size, is_train, num_workers=4):
 
         self.dataset = DictDataset(dataset)
 
@@ -97,7 +96,7 @@ class DictDataLoader():
         if isinstance(indexs, paddle.Variable):
             indexs = indexs.numpy()
         current_items = []
-        items = self.dataset.non_tensor_dict[key]
+        items = getattr(self.dataset, key)
 
         for index in indexs:
             current_items.append(items[index])
@@ -105,6 +104,7 @@
 
         return current_items
 
 
+
 def build_dataloader(cfg, is_train=True):
     dataset = DATASETS.get(cfg.name)(cfg)
diff --git a/ppgan/datasets/aligned_dataset.py b/ppgan/datasets/paired_dataset.py
similarity index 91%
rename from ppgan/datasets/aligned_dataset.py
rename to ppgan/datasets/paired_dataset.py
index 8c8f8ce4e62abf6b51cf652c5be0db28d54c805e..368f8371178ab771d3139103992a97abc3ee0fe8 100644
--- a/ppgan/datasets/aligned_dataset.py
+++ b/ppgan/datasets/paired_dataset.py
@@ -8,19 +8,19 @@ from .builder import DATASETS
 
 
 @DATASETS.register()
-class AlignedDataset(BaseDataset):
+class PairedDataset(BaseDataset):
     """A dataset class for paired image dataset.
     """
 
-    def __init__(self, opt):
+    def __init__(self, cfg):
         """Initialize this dataset class.
 
         Args:
            cfg (dict) -- stores all the experiment flags
        """
-        BaseDataset.__init__(self, opt)
-        self.dir_AB = os.path.join(opt.dataroot, opt.phase)  # get the image directory
-        self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size))  # get image paths
+        BaseDataset.__init__(self, cfg)
+        self.dir_AB = os.path.join(cfg.dataroot, cfg.phase)  # get the image directory
+        self.AB_paths = sorted(make_dataset(self.dir_AB, cfg.max_dataset_size))  # get image paths
         assert(self.cfg.transform.load_size >= self.cfg.transform.crop_size)  # crop_size should be smaller than the size of loaded image
         self.input_nc = self.cfg.output_nc if self.cfg.direction == 'BtoA' else self.cfg.input_nc
         self.output_nc = self.cfg.input_nc if self.cfg.direction == 'BtoA' else self.cfg.output_nc
diff --git a/ppgan/datasets/unaligned_dataset.py b/ppgan/datasets/unpaired_dataset.py
similarity index 98%
rename from ppgan/datasets/unaligned_dataset.py
rename to ppgan/datasets/unpaired_dataset.py
index da673a07216cf6ae827b5431d6fd42f86ca1a0dc..5cabc5391b84e9f6aa55e0925d4202c7b3d09418 100644
--- a/ppgan/datasets/unaligned_dataset.py
+++ b/ppgan/datasets/unpaired_dataset.py
@@ -8,7 +8,7 @@ from .builder import DATASETS
 
 
 @DATASETS.register()
-class UnalignedDataset(BaseDataset):
+class UnpairedDataset(BaseDataset):
     """
     """
 
diff --git a/ppgan/models/cycle_gan_model.py b/ppgan/models/cycle_gan_model.py
index 2d4b9e3998a5cce832aa54911ace28699193c2d5..98b83b1896ea7412ae03f939ee8a455ab8b93b44 100644
--- a/ppgan/models/cycle_gan_model.py
+++ b/ppgan/models/cycle_gan_model.py
@@ -5,7 +5,7 @@ from .builder import MODELS
 from .generators.builder import build_generator
 from .discriminators.builder import build_discriminator
 from .losses import GANLoss
-# from ..modules.nn import L1Loss
+
 from ..solver import build_optimizer
 from ..utils.image_pool import ImagePool
 
@@ -27,7 +27,7 @@ class CycleGANModel(BaseModel):
         """Initialize the CycleGAN class.
 
         Parameters:
-            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+            opt (config)-- stores all the experiment flags; needs to be a subclass of Dict
         """
         BaseModel.__init__(self, opt)
         # specify the training losses you want to print out. The training/test scripts will call
@@ -35,12 +35,15 @@ class CycleGANModel(BaseModel):
         # specify the images you want to save/display. The training/test scripts will call
         visual_names_A = ['real_A', 'fake_B', 'rec_A']
         visual_names_B = ['real_B', 'fake_A', 'rec_B']
-        if self.isTrain and self.opt.lambda_identity > 0.0:  # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B)
+
+        # if identity loss is used, we also visualize idt_B=G_A(B) and idt_A=G_A(B)
+        if self.isTrain and self.opt.lambda_identity > 0.0:
             visual_names_A.append('idt_B')
             visual_names_B.append('idt_A')
 
-        self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B
-        # specify the models you want to save to the disk. The training/test scripts will call and .
+        # combine visualizations for A and B
+        self.visual_names = visual_names_A + visual_names_B
+        # specify the models you want to save to the disk.
         if self.isTrain:
             self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
         else:  # during test time, only load Gs
@@ -59,22 +62,22 @@ class CycleGANModel(BaseModel):
         if self.isTrain:
             if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
                 assert(opt.dataset.train.input_nc == opt.dataset.train.output_nc)
-            self.fake_A_pool = ImagePool(opt.dataset.train.pool_size)  # create image buffer to store previously generated images
-            self.fake_B_pool = ImagePool(opt.dataset.train.pool_size)  # create image buffer to store previously generated images
+            # create image buffer to store previously generated images
+            self.fake_A_pool = ImagePool(opt.dataset.train.pool_size)
+            # create image buffer to store previously generated images
+            self.fake_B_pool = ImagePool(opt.dataset.train.pool_size)
             # define loss functions
-            self.criterionGAN = GANLoss(opt.model.gan_mode, [[[[1.0]]]], [[[[0.0]]]])#.to(self.device)  # define GAN loss.
+            self.criterionGAN = GANLoss(opt.model.gan_mode)
             self.criterionCycle = paddle.nn.L1Loss()
             self.criterionIdt = paddle.nn.L1Loss()
 
             self.optimizer_G = build_optimizer(opt.optimizer, parameter_list=self.netG_A.parameters() + self.netG_B.parameters())
             self.optimizer_D = build_optimizer(opt.optimizer, parameter_list=self.netD_A.parameters() + self.netD_B.parameters())
-            # self.optimizer_DA = build_optimizer(opt.optimizer, parameter_list=self.netD_A.parameters())
-            # self.optimizer_DB = build_optimizer(opt.optimizer, parameter_list=self.netD_B.parameters())
+
             self.optimizers.append(self.optimizer_G)
             self.optimizers.append(self.optimizer_D)
-            # self.optimizers.append(self.optimizer_DA)
-            # self.optimizers.append(self.optimizer_DB)
-            self.optimizer_names.extend(['optimizer_G', 'optimizer_D'])#A', 'optimizer_DB'])
+
+            self.optimizer_names.extend(['optimizer_G', 'optimizer_D'])
 
     def set_input(self, input):
         """Unpack input data from the dataloader and perform necessary pre-processing steps.
@@ -102,7 +105,7 @@ class CycleGANModel(BaseModel):
             self.image_paths = input['A_paths']
         elif 'B_paths' in input:
            self.image_paths = input['B_paths']
-        # self.image_paths = input['A_paths' if AtoB else 'B_paths']
+
 
     def forward(self):
         """Run forward pass; called by both functions and ."""
@@ -115,20 +118,6 @@ class CycleGANModel(BaseModel):
 
         self.rec_B = self.netG_A(self.fake_A)  # G_A(G_B(B))
 
-    # def forward_test(self, input):
-    #     input = paddle.imperative.to_variable(input)
-    #     net_g = getattr(self, 'netG_' + self.opt.dataset.test.direction[0])
-    #     return net_g(input)
-
-    # def test(self, input):
-    #     """Forward function used in test time.
-
-    #     This function wraps function in no_grad() so we don't save intermediate steps for backprop
-    #     It also calls to produce additional visualization results
-    #     """
-    #     with paddle.imperative.no_grad():
-    #         return self.forward_test(input)
-
     def backward_D_basic(self, netD, real, fake):
         """Calculate GAN loss for the discriminator
 
@@ -193,27 +182,26 @@
     def optimize_parameters(self):
         """Calculate losses, gradients, and update network weights; called in every training iteration"""
         # forward
-        self.forward()  # compute fake images and reconstruction images.
+        # compute fake images and reconstruction images.
+        self.forward()
         # G_A and G_B
-        self.set_requires_grad([self.netD_A, self.netD_B], False)  # Ds require no gradients when optimizing Gs
-        self.optimizer_G.clear_gradients()  #zero_grad()  # set G_A and G_B's gradients to zero
-        self.backward_G()  # calculate gradients for G_A and G_B
-        self.optimizer_G.minimize(self.loss_G)  #step()  # update G_A and G_B's weights
-        # self.optimizer_G.clear_gradients()
-        # self.optimizer_G.clear_gradients()
+        # Ds require no gradients when optimizing Gs
+        self.set_requires_grad([self.netD_A, self.netD_B], False)
+        # set G_A and G_B's gradients to zero
+        self.optimizer_G.clear_gradients()
+        # calculate gradients for G_A and G_B
+        self.backward_G()
+        # update G_A and G_B's weights
+        self.optimizer_G.minimize(self.loss_G)
         # D_A and D_B
         self.set_requires_grad([self.netD_A, self.netD_B], True)
-        # self.set_requires_grad(self.netD_A, True)
-        self.optimizer_D.clear_gradients()  #zero_grad()  # set D_A and D_B's gradients to zero
-        self.backward_D_A()  # calculate gradients for D_A
-        self.backward_D_B()  # calculate graidents for D_B
-        self.optimizer_D.minimize(self.loss_D_A + self.loss_D_B)  # update D_A and D_B's weights
-        # self.backward_D_A()  # calculate gradients for D_A
-        # self.optimizer_DA.minimize(self.loss_D_A)  #step()  # update D_A and D_B's weights
-        # self.optimizer_DA.clear_gradients()  #zero_g
-        # self.set_requires_grad(self.netD_B, True)
-        # self.optimizer_DB.clear_gradients()  #zero_grad()  # set D_A and D_B's gradients to zero
-
-        # self.backward_D_B()  # calculate graidents for D_B
-        # self.optimizer_DB.minimize(self.loss_D_B)  #step()  # update D_A and D_B's weights
-        # self.optimizer_DB.clear_gradients()  #zero_grad()  # set D_A and D_B's gradients to zero
+
+        # set D_A and D_B's gradients to zero
+        self.optimizer_D.clear_gradients()
+        # calculate gradients for D_A
+        self.backward_D_A()
+        # calculate gradients for D_B
+        self.backward_D_B()
+        # update D_A and D_B's weights
+        self.optimizer_D.minimize(self.loss_D_A + self.loss_D_B)
+
diff --git a/ppgan/models/losses.py b/ppgan/models/losses.py
index 2ad6459231e3deb6cf2fc8845380f838c3c35d3c..75d7e00f63b73b68c71ac67b8979809dfd3f1983 100644
--- a/ppgan/models/losses.py
+++ b/ppgan/models/losses.py
@@ -4,6 +4,7 @@ import numpy as np
 
 from ..modules.nn import BCEWithLogitsLoss
 
+
 class GANLoss(paddle.fluid.dygraph.Layer):
     """Define different GAN objectives.
 
@@ -23,16 +24,14 @@ class GANLoss(paddle.fluid.dygraph.Layer):
         LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
""" super(GANLoss, self).__init__() - self.real_label = paddle.fluid.dygraph.to_variable(np.array(target_real_label)) - self.fake_label = paddle.fluid.dygraph.to_variable(np.array(target_fake_label)) - # self.real_label.stop_gradients = True - # self.fake_label.stop_gradients = True + self.target_real_label = target_real_label + self.target_fake_label = target_fake_label self.gan_mode = gan_mode if gan_mode == 'lsgan': self.loss = nn.MSELoss() elif gan_mode == 'vanilla': - self.loss = BCEWithLogitsLoss()#nn.BCEWithLogitsLoss() + self.loss = BCEWithLogitsLoss() elif gan_mode in ['wgangp']: self.loss = None else: @@ -50,14 +49,16 @@ class GANLoss(paddle.fluid.dygraph.Layer): """ if target_is_real: - target_tensor = paddle.fill_constant(shape=paddle.shape(prediction), value=1.0, dtype='float32')#self.real_label + if not hasattr(self, 'target_real_tensor'): + self.target_real_tensor = paddle.fill_constant(shape=paddle.shape(prediction), value=self.target_real_label, dtype='float32') + target_tensor = self.target_real_tensor else: - target_tensor = paddle.fill_constant(shape=paddle.shape(prediction), value=0.0, dtype='float32')#self.fake_label + if not hasattr(self, 'target_fake_tensor'): + self.target_fake_tensor = paddle.fill_constant(shape=paddle.shape(prediction), value=self.target_fake_label, dtype='float32') + target_tensor = self.target_fake_tensor - # target_tensor = paddle.cast(target_tensor, prediction.dtype) - # target_tensor = paddle.expand_as(target_tensor, prediction) # target_tensor.stop_gradient = True - return target_tensor#paddle.expand_as(target_tensor, prediction) + return target_tensor def __call__(self, prediction, target_is_real): """Calculate loss given Discriminator's output and grount truth labels. diff --git a/ppgan/models/pix2pix_model.py b/ppgan/models/pix2pix_model.py index ab3ff48bd05808bd3259339b21c18679ae989057..9419b0f08516c0616c672988e89cb3556ce137f8 100644 --- a/ppgan/models/pix2pix_model.py +++ b/ppgan/models/pix2pix_model.py @@ -1,7 +1,3 @@ -# import torch -# import paddle -# from .base_model import BaseModel -# from . import networks import paddle from .base_model import BaseModel @@ -9,7 +5,7 @@ from .builder import MODELS from .generators.builder import build_generator from .discriminators.builder import build_discriminator from .losses import GANLoss -# from ..modules.nn import L1Loss + from ..solver import build_optimizer from ..utils.image_pool import ImagePool @@ -18,10 +14,10 @@ from ..utils.image_pool import ImagePool class Pix2PixModel(BaseModel): """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data. - The model training requires '--dataset_mode aligned' dataset. + The model training requires 'paired' dataset. By default, it uses a '--netG unet256' U-Net generator, - a '--netD basic' discriminator (PatchGAN), - and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper). + a '--netD basic' discriminator (from PatchGAN), + and a vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper). pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf """ @@ -30,41 +26,37 @@ class Pix2PixModel(BaseModel): """Initialize the pix2pix class. Parameters: - opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + opt (config dict)-- stores all the experiment flags; needs to be a subclass of Dict """ BaseModel.__init__(self, opt) # specify the training losses you want to print out. 
         self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
         # specify the images you want to save/display. The training/test scripts will call
         self.visual_names = ['real_A', 'fake_B', 'real_B']
-        # specify the models you want to save to the disk. The training/test scripts will call and
+        # specify the models you want to save to the disk.
         if self.isTrain:
             self.model_names = ['G', 'D']
-        else:  # during test time, only load G
+        else:
+            # during test time, only load G
             self.model_names = ['G']
+
         # define networks (both generator and discriminator)
         self.netG = build_generator(opt.model.generator)
-        # self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
-        #                               not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
-        if self.isTrain:  # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
+
+        # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
+        if self.isTrain:
             self.netD = build_discriminator(opt.model.discriminator)
-            # self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
-            #                               opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
+
         if self.isTrain:
             # define loss functions
-            self.criterionGAN = GANLoss(opt.model.gan_mode, [[[[1.0]]]], [[[[0.0]]]])#.to(self.device)
+            self.criterionGAN = GANLoss(opt.model.gan_mode)
             self.criterionL1 = paddle.nn.L1Loss()
-            # initialize optimizers; schedulers will be automatically created by function .
-            # self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
-            # self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
-            # FIXME: step per epoch
-            # lr_scheduler_g = self.build_lr_scheduler(opt.lr, step_per_epoch=2975)
-            # lr_scheduler_d = self.build_lr_scheduler(opt.lr, step_per_epoch=2975)
-            # lr_scheduler = self.build_lr_scheduler()
-            self.optimizer_G = build_optimizer(opt.optimizer, parameter_list=self.netG.parameters())  #paddle.optimizer.Adam(learning_rate=lr_scheduler_g, parameter_list=self.netG.parameters(), beta1=opt.beta1)
-            self.optimizer_D = build_optimizer(opt.optimizer, parameter_list=self.netD.parameters())  #paddle.optimizer.Adam(learning_rate=lr_scheduler_d, parameter_list=self.netD.parameters(), beta1=opt.beta1)
+
+            # build optimizers
+            self.optimizer_G = build_optimizer(opt.optimizer, parameter_list=self.netG.parameters())
+            self.optimizer_D = build_optimizer(opt.optimizer, parameter_list=self.netD.parameters())
 
             self.optimizers.append(self.optimizer_G)
             self.optimizers.append(self.optimizer_D)
@@ -78,16 +70,12 @@ class Pix2PixModel(BaseModel):
 
         The option 'direction' can be used to swap images in domain A and domain B.
""" - # AtoB = self.opt.direction == 'AtoB' - # self.real_A = input['A' if AtoB else 'B'].to(self.device) - # self.real_B = input['B' if AtoB else 'A'].to(self.device) - # self.image_paths = input['A_paths' if AtoB else 'B_paths'] + AtoB = self.opt.dataset.train.direction == 'AtoB' self.real_A = paddle.imperative.to_variable(input['A' if AtoB else 'B']) self.real_B = paddle.imperative.to_variable(input['B' if AtoB else 'A']) self.image_paths = input['A_paths' if AtoB else 'B_paths'] - # self.real_A = paddle.imperative.to_variable(input[0] if AtoB else input[1]) - # self.real_B = paddle.imperative.to_variable(input[1] if AtoB else input[0]) + def forward(self): """Run forward pass; called by both functions and .""" @@ -96,20 +84,12 @@ class Pix2PixModel(BaseModel): def forward_test(self, input): input = paddle.imperative.to_variable(input) return self.netG(input) - - # def test(self, input): - # """Forward function used in test time. - - # This function wraps function in no_grad() so we don't save intermediate steps for backprop - # It also calls to produce additional visualization results - # """ - # with paddle.imperative.no_grad(): - # return self.forward_test(input) def backward_D(self): """Calculate GAN loss for the discriminator""" # Fake; stop backprop to the generator by detaching fake_B - fake_AB = paddle.concat((self.real_A, self.fake_B), 1) # we use conditional GANs; we need to feed both input and output to the discriminator + # use conditional GANs; we need to feed both input and output to the discriminator + fake_AB = paddle.concat((self.real_A, self.fake_B), 1) pred_fake = self.netD(fake_AB.detach()) self.loss_D_fake = self.criterionGAN(pred_fake, False) # Real @@ -134,16 +114,17 @@ class Pix2PixModel(BaseModel): self.loss_G.backward() def optimize_parameters(self): - self.forward() # compute fake images: G(A) + # compute fake images: G(A) + self.forward() + # update D - self.set_requires_grad(self.netD, True) # enable backprop for D - self.optimizer_D.clear_gradients() # set D's gradients to zero - self.backward_D() # calculate gradients for D - self.optimizer_D.minimize(self.loss_D) # update D's weights - # self.netD.clear_gradients() - # self.optimizer_D.clear_gradients() + self.set_requires_grad(self.netD, True) + self.optimizer_D.clear_gradients() + self.backward_D() + self.optimizer_D.minimize(self.loss_D) + # update G - self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G - self.optimizer_G.clear_gradients() # set G's gradients to zero - self.backward_G() # calculate graidents for G - self.optimizer_G.minimize(self.loss_G) # udpate G's weights + self.set_requires_grad(self.netD, False) + self.optimizer_G.clear_gradients() + self.backward_G() + self.optimizer_G.minimize(self.loss_G) diff --git a/script/pix2pix_download.sh b/script/pix2pix_download.sh new file mode 100644 index 0000000000000000000000000000000000000000..d3cc8ce9244ce9644aa25e2841b3d42bf5395f4e --- /dev/null +++ b/script/pix2pix_download.sh @@ -0,0 +1,9 @@ +FILE=$1 +URL=https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/$FILE.tar.gz +TAR_FILE=./$FILE.tar.gz +TARGET_DIR=./$FILE/ +wget -N $URL -O $TAR_FILE --no-check-certificate +mkdir $TARGET_DIR +tar -zxvf $TAR_FILE -C ../data/ +rm $TAR_FILE +rm -rf $TARGET_DIR