提交 1e80b869 编写于 作者: L LielinJiang 提交者: Joejiong

Merge pull request #2 from LielinJiang/multiprocess-dataset

Support multiprocess dict dataset
...@@ -101,3 +101,7 @@ venv.bak/ ...@@ -101,3 +101,7 @@ venv.bak/
# mypy # mypy
.mypy_cache/ .mypy_cache/
# data
data/
output_dir/
\ No newline at end of file
- repo: local
hooks:
- id: yapf
name: yapf
entry: yapf
language: system
args: [-i, --style .style.yapf]
files: \.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
sha: a11d9314b22d8f8c7556443875b731ef05965464
hooks:
- id: check-merge-conflict
- id: check-symlinks
- id: end-of-file-fixer
- id: trailing-whitespace
- id: detect-private-key
- id: check-symlinks
- id: check-added-large-files
- repo: local
hooks:
- id: flake8
name: flake8
entry: flake8
language: system
args:
- --count
- --select=E9,F63,F7,F82
- --show-source
- --statistics
files: \.py$
- repo: local
hooks:
- id: copyright_checker
name: copyright_checker
entry: python ./.copyright.hook
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude: (?!.*third_party)^.*$
\ No newline at end of file
[style]
based_on_style = pep8
column_limit = 80
\ No newline at end of file
English | [简体中文](./README.md)
# PaddleGAN # PaddleGAN
still under development!! still under development!!
## Download Dataset
This script downloads one of several paired-image datasets for the image-to-image translation task.
```
cd PaddleGAN/script/
bash pix2pix_download.sh [cityscapes|facades|edges2handbags|edges2shoes|maps]
```
## Train ## Train
``` ```
python -u tools/main.py --config-file configs/cyclegan-cityscapes.yaml python -u tools/main.py --config-file configs/cyclegan-cityscapes.yaml
......
...@@ -26,7 +26,7 @@ model: ...@@ -26,7 +26,7 @@ model:
dataset: dataset:
train: train:
name: UnalignedDataset name: UnpairedDataset
dataroot: data/cityscapes dataroot: data/cityscapes
phase: train phase: train
max_dataset_size: inf max_dataset_size: inf
......
...@@ -26,7 +26,7 @@ model: ...@@ -26,7 +26,7 @@ model:
dataset: dataset:
train: train:
name: UnalignedDataset name: UnpairedDataset
dataroot: data/horse2zebra dataroot: data/horse2zebra
phase: train phase: train
max_dataset_size: inf max_dataset_size: inf
......
...@@ -23,7 +23,7 @@ model: ...@@ -23,7 +23,7 @@ model:
dataset: dataset:
train: train:
name: AlignedDataset name: PairedDataset
dataroot: data/cityscapes dataroot: data/cityscapes
phase: train phase: train
max_dataset_size: inf max_dataset_size: inf
...@@ -38,7 +38,7 @@ dataset: ...@@ -38,7 +38,7 @@ dataset:
preprocess: resize_and_crop preprocess: resize_and_crop
no_flip: False no_flip: False
test: test:
name: AlignedDataset name: PairedDataset
dataroot: data/cityscapes/ dataroot: data/cityscapes/
phase: test phase: test
max_dataset_size: inf max_dataset_size: inf
......
...@@ -23,7 +23,7 @@ model: ...@@ -23,7 +23,7 @@ model:
dataset: dataset:
train: train:
name: AlignedDataset name: PairedDataset
dataroot: data/cityscapes dataroot: data/cityscapes
phase: train phase: train
max_dataset_size: inf max_dataset_size: inf
...@@ -38,7 +38,7 @@ dataset: ...@@ -38,7 +38,7 @@ dataset:
preprocess: resize_and_crop preprocess: resize_and_crop
no_flip: False no_flip: False
test: test:
name: AlignedDataset name: PairedDataset
dataroot: data/cityscapes/ dataroot: data/cityscapes/
phase: test phase: test
max_dataset_size: inf max_dataset_size: inf
......
epochs: 200
isTrain: True
output_dir: output_dir
lambda_L1: 100
model:
name: Pix2PixModel
generator:
name: UnetGenerator
norm_type: batch
input_nc: 3
output_nc: 3
num_downs: 8 #unet256
ngf: 64
use_dropout: False
discriminator:
name: NLayerDiscriminator
ndf: 64
n_layers: 3
input_nc: 6
norm_type: batch
gan_mode: vanilla
dataset:
train:
name: PairedDataset
dataroot: data/facades/
phase: train
max_dataset_size: inf
direction: BtoA
input_nc: 3
output_nc: 3
serial_batches: False
pool_size: 0
transform:
load_size: 286
crop_size: 256
preprocess: resize_and_crop
no_flip: False
test:
name: PairedDataset
dataroot: data/facades/
phase: test
max_dataset_size: inf
direction: BtoA
input_nc: 3
output_nc: 3
serial_batches: True
pool_size: 50
transform:
load_size: 256
crop_size: 256
preprocess: resize_and_crop
no_flip: True
optimizer:
name: Adam
beta1: 0.5
lr_scheduler:
name: linear
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
log_config:
interval: 100
visiual_interval: 500
snapshot_config:
interval: 5
from .unaligned_dataset import UnalignedDataset from .unpaired_dataset import UnpairedDataset
from .single_dataset import SingleDataset from .single_dataset import SingleDataset
from .aligned_dataset import AlignedDataset from .paired_dataset import PairedDataset
import time
import paddle import paddle
import numbers import numbers
import numpy as np import numpy as np
...@@ -23,7 +24,7 @@ class DictDataset(paddle.io.Dataset): ...@@ -23,7 +24,7 @@ class DictDataset(paddle.io.Dataset):
for k, v in single_item.items(): for k, v in single_item.items():
if not isinstance(v, (numbers.Number, np.ndarray)): if not isinstance(v, (numbers.Number, np.ndarray)):
self.non_tensor_dict.update({k: {}}) setattr(self, k, Manager().dict())
self.non_tensor_keys_set.add(k) self.non_tensor_keys_set.add(k)
else: else:
self.tensor_keys_set.add(k) self.tensor_keys_set.add(k)
...@@ -38,9 +39,7 @@ class DictDataset(paddle.io.Dataset): ...@@ -38,9 +39,7 @@ class DictDataset(paddle.io.Dataset):
if isinstance(v, (numbers.Number, np.ndarray)): if isinstance(v, (numbers.Number, np.ndarray)):
tmp_list.append(v) tmp_list.append(v)
else: else:
tmp_dict = self.non_tensor_dict[k] getattr(self, k).update({index: v})
tmp_dict.update({index: v})
self.non_tensor_dict[k] = tmp_dict
tmp_list.append(index) tmp_list.append(index)
return tuple(tmp_list) return tuple(tmp_list)
...@@ -50,11 +49,11 @@ class DictDataset(paddle.io.Dataset): ...@@ -50,11 +49,11 @@ class DictDataset(paddle.io.Dataset):
def reset(self): def reset(self):
for k in self.non_tensor_keys_set: for k in self.non_tensor_keys_set:
self.non_tensor_dict[k] = {} setattr(self, k, Manager().dict())
class DictDataLoader(): class DictDataLoader():
def __init__(self, dataset, batch_size, is_train, num_workers=0): def __init__(self, dataset, batch_size, is_train, num_workers=4):
self.dataset = DictDataset(dataset) self.dataset = DictDataset(dataset)
...@@ -97,7 +96,7 @@ class DictDataLoader(): ...@@ -97,7 +96,7 @@ class DictDataLoader():
if isinstance(indexs, paddle.Variable): if isinstance(indexs, paddle.Variable):
indexs = indexs.numpy() indexs = indexs.numpy()
current_items = [] current_items = []
items = self.dataset.non_tensor_dict[key] items = getattr(self.dataset, key)
for index in indexs: for index in indexs:
current_items.append(items[index]) current_items.append(items[index])
...@@ -105,6 +104,7 @@ class DictDataLoader(): ...@@ -105,6 +104,7 @@ class DictDataLoader():
return current_items return current_items
def build_dataloader(cfg, is_train=True): def build_dataloader(cfg, is_train=True):
dataset = DATASETS.get(cfg.name)(cfg) dataset = DATASETS.get(cfg.name)(cfg)
......
...@@ -8,19 +8,19 @@ from .builder import DATASETS ...@@ -8,19 +8,19 @@ from .builder import DATASETS
@DATASETS.register() @DATASETS.register()
class AlignedDataset(BaseDataset): class PairedDataset(BaseDataset):
"""A dataset class for paired image dataset. """A dataset class for paired image dataset.
""" """
def __init__(self, opt): def __init__(self, cfg):
"""Initialize this dataset class. """Initialize this dataset class.
Args: Args:
cfg (dict) -- stores all the experiment flags cfg (dict) -- stores all the experiment flags
""" """
BaseDataset.__init__(self, opt) BaseDataset.__init__(self, cfg)
self.dir_AB = os.path.join(opt.dataroot, opt.phase) # get the image directory self.dir_AB = os.path.join(cfg.dataroot, cfg.phase) # get the image directory
self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size)) # get image paths self.AB_paths = sorted(make_dataset(self.dir_AB, cfg.max_dataset_size)) # get image paths
assert(self.cfg.transform.load_size >= self.cfg.transform.crop_size) # crop_size should be smaller than the size of loaded image assert(self.cfg.transform.load_size >= self.cfg.transform.crop_size) # crop_size should be smaller than the size of loaded image
self.input_nc = self.cfg.output_nc if self.cfg.direction == 'BtoA' else self.cfg.input_nc self.input_nc = self.cfg.output_nc if self.cfg.direction == 'BtoA' else self.cfg.input_nc
self.output_nc = self.cfg.input_nc if self.cfg.direction == 'BtoA' else self.cfg.output_nc self.output_nc = self.cfg.input_nc if self.cfg.direction == 'BtoA' else self.cfg.output_nc
......
...@@ -8,7 +8,7 @@ from .builder import DATASETS ...@@ -8,7 +8,7 @@ from .builder import DATASETS
@DATASETS.register() @DATASETS.register()
class UnalignedDataset(BaseDataset): class UnpairedDataset(BaseDataset):
""" """
""" """
......
...@@ -5,7 +5,7 @@ from .builder import MODELS ...@@ -5,7 +5,7 @@ from .builder import MODELS
from .generators.builder import build_generator from .generators.builder import build_generator
from .discriminators.builder import build_discriminator from .discriminators.builder import build_discriminator
from .losses import GANLoss from .losses import GANLoss
# from ..modules.nn import L1Loss
from ..solver import build_optimizer from ..solver import build_optimizer
from ..utils.image_pool import ImagePool from ..utils.image_pool import ImagePool
...@@ -27,7 +27,7 @@ class CycleGANModel(BaseModel): ...@@ -27,7 +27,7 @@ class CycleGANModel(BaseModel):
"""Initialize the CycleGAN class. """Initialize the CycleGAN class.
Parameters: Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions opt (config)-- stores all the experiment flags; needs to be a subclass of Dict
""" """
BaseModel.__init__(self, opt) BaseModel.__init__(self, opt)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
...@@ -35,12 +35,15 @@ class CycleGANModel(BaseModel): ...@@ -35,12 +35,15 @@ class CycleGANModel(BaseModel):
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
visual_names_A = ['real_A', 'fake_B', 'rec_A'] visual_names_A = ['real_A', 'fake_B', 'rec_A']
visual_names_B = ['real_B', 'fake_A', 'rec_B'] visual_names_B = ['real_B', 'fake_A', 'rec_B']
if self.isTrain and self.opt.lambda_identity > 0.0: # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B)
# if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B)
if self.isTrain and self.opt.lambda_identity > 0.0:
visual_names_A.append('idt_B') visual_names_A.append('idt_B')
visual_names_B.append('idt_A') visual_names_B.append('idt_A')
self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B # combine visualizations for A and B
# specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>. self.visual_names = visual_names_A + visual_names_B
# specify the models you want to save to the disk.
if self.isTrain: if self.isTrain:
self.model_names = ['G_A', 'G_B', 'D_A', 'D_B'] self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
else: # during test time, only load Gs else: # during test time, only load Gs
...@@ -59,22 +62,22 @@ class CycleGANModel(BaseModel): ...@@ -59,22 +62,22 @@ class CycleGANModel(BaseModel):
if self.isTrain: if self.isTrain:
if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels
assert(opt.dataset.train.input_nc == opt.dataset.train.output_nc) assert(opt.dataset.train.input_nc == opt.dataset.train.output_nc)
self.fake_A_pool = ImagePool(opt.dataset.train.pool_size) # create image buffer to store previously generated images # create image buffer to store previously generated images
self.fake_B_pool = ImagePool(opt.dataset.train.pool_size) # create image buffer to store previously generated images self.fake_A_pool = ImagePool(opt.dataset.train.pool_size)
# create image buffer to store previously generated images
self.fake_B_pool = ImagePool(opt.dataset.train.pool_size)
# define loss functions # define loss functions
self.criterionGAN = GANLoss(opt.model.gan_mode, [[[[1.0]]]], [[[[0.0]]]])#.to(self.device) # define GAN loss. self.criterionGAN = GANLoss(opt.model.gan_mode)
self.criterionCycle = paddle.nn.L1Loss() self.criterionCycle = paddle.nn.L1Loss()
self.criterionIdt = paddle.nn.L1Loss() self.criterionIdt = paddle.nn.L1Loss()
self.optimizer_G = build_optimizer(opt.optimizer, parameter_list=self.netG_A.parameters() + self.netG_B.parameters()) self.optimizer_G = build_optimizer(opt.optimizer, parameter_list=self.netG_A.parameters() + self.netG_B.parameters())
self.optimizer_D = build_optimizer(opt.optimizer, parameter_list=self.netD_A.parameters() + self.netD_B.parameters()) self.optimizer_D = build_optimizer(opt.optimizer, parameter_list=self.netD_A.parameters() + self.netD_B.parameters())
# self.optimizer_DA = build_optimizer(opt.optimizer, parameter_list=self.netD_A.parameters())
# self.optimizer_DB = build_optimizer(opt.optimizer, parameter_list=self.netD_B.parameters())
self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D) self.optimizers.append(self.optimizer_D)
# self.optimizers.append(self.optimizer_DA)
# self.optimizers.append(self.optimizer_DB) self.optimizer_names.extend(['optimizer_G', 'optimizer_D'])
self.optimizer_names.extend(['optimizer_G', 'optimizer_D'])#A', 'optimizer_DB'])
def set_input(self, input): def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps. """Unpack input data from the dataloader and perform necessary pre-processing steps.
...@@ -102,7 +105,7 @@ class CycleGANModel(BaseModel): ...@@ -102,7 +105,7 @@ class CycleGANModel(BaseModel):
self.image_paths = input['A_paths'] self.image_paths = input['A_paths']
elif 'B_paths' in input: elif 'B_paths' in input:
self.image_paths = input['B_paths'] self.image_paths = input['B_paths']
# self.image_paths = input['A_paths' if AtoB else 'B_paths']
def forward(self): def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>.""" """Run forward pass; called by both functions <optimize_parameters> and <test>."""
...@@ -115,20 +118,6 @@ class CycleGANModel(BaseModel): ...@@ -115,20 +118,6 @@ class CycleGANModel(BaseModel):
self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B)) self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
# def forward_test(self, input):
# input = paddle.imperative.to_variable(input)
# net_g = getattr(self, 'netG_' + self.opt.dataset.test.direction[0])
# return net_g(input)
# def test(self, input):
# """Forward function used in test time.
# This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
# It also calls <compute_visuals> to produce additional visualization results
# """
# with paddle.imperative.no_grad():
# return self.forward_test(input)
def backward_D_basic(self, netD, real, fake): def backward_D_basic(self, netD, real, fake):
"""Calculate GAN loss for the discriminator """Calculate GAN loss for the discriminator
...@@ -193,27 +182,26 @@ class CycleGANModel(BaseModel): ...@@ -193,27 +182,26 @@ class CycleGANModel(BaseModel):
def optimize_parameters(self): def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration""" """Calculate losses, gradients, and update network weights; called in every training iteration"""
# forward # forward
self.forward() # compute fake images and reconstruction images. # compute fake images and reconstruction images.
self.forward()
# G_A and G_B # G_A and G_B
self.set_requires_grad([self.netD_A, self.netD_B], False) # Ds require no gradients when optimizing Gs # Ds require no gradients when optimizing Gs
self.optimizer_G.clear_gradients() #zero_grad() # set G_A and G_B's gradients to zero self.set_requires_grad([self.netD_A, self.netD_B], False)
self.backward_G() # calculate gradients for G_A and G_B # set G_A and G_B's gradients to zero
self.optimizer_G.minimize(self.loss_G) #step() # update G_A and G_B's weights self.optimizer_G.clear_gradients()
# self.optimizer_G.clear_gradients() # calculate gradients for G_A and G_B
# self.optimizer_G.clear_gradients() self.backward_G()
# update G_A and G_B's weights
self.optimizer_G.minimize(self.loss_G)
# D_A and D_B # D_A and D_B
self.set_requires_grad([self.netD_A, self.netD_B], True) self.set_requires_grad([self.netD_A, self.netD_B], True)
# self.set_requires_grad(self.netD_A, True)
self.optimizer_D.clear_gradients() #zero_grad() # set D_A and D_B's gradients to zero # set D_A and D_B's gradients to zero
self.backward_D_A() # calculate gradients for D_A self.optimizer_D.clear_gradients()
self.backward_D_B() # calculate graidents for D_B # calculate gradients for D_A
self.optimizer_D.minimize(self.loss_D_A + self.loss_D_B) # update D_A and D_B's weights self.backward_D_A()
# self.backward_D_A() # calculate gradients for D_A # calculate graidents for D_B
# self.optimizer_DA.minimize(self.loss_D_A) #step() # update D_A and D_B's weights self.backward_D_B()
# self.optimizer_DA.clear_gradients() #zero_g # update D_A and D_B's weights
# self.set_requires_grad(self.netD_B, True) self.optimizer_D.minimize(self.loss_D_A + self.loss_D_B)
# self.optimizer_DB.clear_gradients() #zero_grad() # set D_A and D_B's gradients to zero
# self.backward_D_B() # calculate graidents for D_B
# self.optimizer_DB.minimize(self.loss_D_B) #step() # update D_A and D_B's weights
# self.optimizer_DB.clear_gradients() #zero_grad() # set D_A and D_B's gradients to zero
...@@ -4,6 +4,7 @@ import numpy as np ...@@ -4,6 +4,7 @@ import numpy as np
from ..modules.nn import BCEWithLogitsLoss from ..modules.nn import BCEWithLogitsLoss
class GANLoss(paddle.fluid.dygraph.Layer): class GANLoss(paddle.fluid.dygraph.Layer):
"""Define different GAN objectives. """Define different GAN objectives.
...@@ -23,16 +24,14 @@ class GANLoss(paddle.fluid.dygraph.Layer): ...@@ -23,16 +24,14 @@ class GANLoss(paddle.fluid.dygraph.Layer):
LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
""" """
super(GANLoss, self).__init__() super(GANLoss, self).__init__()
self.real_label = paddle.fluid.dygraph.to_variable(np.array(target_real_label)) self.target_real_label = target_real_label
self.fake_label = paddle.fluid.dygraph.to_variable(np.array(target_fake_label)) self.target_fake_label = target_fake_label
# self.real_label.stop_gradients = True
# self.fake_label.stop_gradients = True
self.gan_mode = gan_mode self.gan_mode = gan_mode
if gan_mode == 'lsgan': if gan_mode == 'lsgan':
self.loss = nn.MSELoss() self.loss = nn.MSELoss()
elif gan_mode == 'vanilla': elif gan_mode == 'vanilla':
self.loss = BCEWithLogitsLoss()#nn.BCEWithLogitsLoss() self.loss = BCEWithLogitsLoss()
elif gan_mode in ['wgangp']: elif gan_mode in ['wgangp']:
self.loss = None self.loss = None
else: else:
...@@ -50,14 +49,16 @@ class GANLoss(paddle.fluid.dygraph.Layer): ...@@ -50,14 +49,16 @@ class GANLoss(paddle.fluid.dygraph.Layer):
""" """
if target_is_real: if target_is_real:
target_tensor = paddle.fill_constant(shape=paddle.shape(prediction), value=1.0, dtype='float32')#self.real_label if not hasattr(self, 'target_real_tensor'):
self.target_real_tensor = paddle.fill_constant(shape=paddle.shape(prediction), value=self.target_real_label, dtype='float32')
target_tensor = self.target_real_tensor
else: else:
target_tensor = paddle.fill_constant(shape=paddle.shape(prediction), value=0.0, dtype='float32')#self.fake_label if not hasattr(self, 'target_fake_tensor'):
self.target_fake_tensor = paddle.fill_constant(shape=paddle.shape(prediction), value=self.target_fake_label, dtype='float32')
target_tensor = self.target_fake_tensor
# target_tensor = paddle.cast(target_tensor, prediction.dtype)
# target_tensor = paddle.expand_as(target_tensor, prediction)
# target_tensor.stop_gradient = True # target_tensor.stop_gradient = True
return target_tensor#paddle.expand_as(target_tensor, prediction) return target_tensor
def __call__(self, prediction, target_is_real): def __call__(self, prediction, target_is_real):
"""Calculate loss given Discriminator's output and grount truth labels. """Calculate loss given Discriminator's output and grount truth labels.
......
# import torch
# import paddle
# from .base_model import BaseModel
# from . import networks
import paddle import paddle
from .base_model import BaseModel from .base_model import BaseModel
...@@ -9,7 +5,7 @@ from .builder import MODELS ...@@ -9,7 +5,7 @@ from .builder import MODELS
from .generators.builder import build_generator from .generators.builder import build_generator
from .discriminators.builder import build_discriminator from .discriminators.builder import build_discriminator
from .losses import GANLoss from .losses import GANLoss
# from ..modules.nn import L1Loss
from ..solver import build_optimizer from ..solver import build_optimizer
from ..utils.image_pool import ImagePool from ..utils.image_pool import ImagePool
...@@ -18,10 +14,10 @@ from ..utils.image_pool import ImagePool ...@@ -18,10 +14,10 @@ from ..utils.image_pool import ImagePool
class Pix2PixModel(BaseModel): class Pix2PixModel(BaseModel):
""" This class implements the pix2pix model, for learning a mapping from input images to output images given paired data. """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.
The model training requires '--dataset_mode aligned' dataset. The model training requires 'paired' dataset.
By default, it uses a '--netG unet256' U-Net generator, By default, it uses a '--netG unet256' U-Net generator,
a '--netD basic' discriminator (PatchGAN), a '--netD basic' discriminator (from PatchGAN),
and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper). and a vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper).
pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
""" """
...@@ -30,41 +26,37 @@ class Pix2PixModel(BaseModel): ...@@ -30,41 +26,37 @@ class Pix2PixModel(BaseModel):
"""Initialize the pix2pix class. """Initialize the pix2pix class.
Parameters: Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions opt (config dict)-- stores all the experiment flags; needs to be a subclass of Dict
""" """
BaseModel.__init__(self, opt) BaseModel.__init__(self, opt)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake'] self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
self.visual_names = ['real_A', 'fake_B', 'real_B'] self.visual_names = ['real_A', 'fake_B', 'real_B']
# specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks> # specify the models you want to save to the disk.
if self.isTrain: if self.isTrain:
self.model_names = ['G', 'D'] self.model_names = ['G', 'D']
else: # during test time, only load G else:
# during test time, only load G
self.model_names = ['G'] self.model_names = ['G']
# define networks (both generator and discriminator) # define networks (both generator and discriminator)
self.netG = build_generator(opt.model.generator) self.netG = build_generator(opt.model.generator)
# self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
# not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
# define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
if self.isTrain:
self.netD = build_discriminator(opt.model.discriminator) self.netD = build_discriminator(opt.model.discriminator)
# self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
# opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
if self.isTrain: if self.isTrain:
# define loss functions # define loss functions
self.criterionGAN = GANLoss(opt.model.gan_mode, [[[[1.0]]]], [[[[0.0]]]])#.to(self.device) self.criterionGAN = GANLoss(opt.model.gan_mode)
self.criterionL1 = paddle.nn.L1Loss() self.criterionL1 = paddle.nn.L1Loss()
# initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
# self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) # build optimizers
# self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_G = build_optimizer(opt.optimizer, parameter_list=self.netG.parameters())
# FIXME: step per epoch self.optimizer_D = build_optimizer(opt.optimizer, parameter_list=self.netD.parameters())
# lr_scheduler_g = self.build_lr_scheduler(opt.lr, step_per_epoch=2975)
# lr_scheduler_d = self.build_lr_scheduler(opt.lr, step_per_epoch=2975)
# lr_scheduler = self.build_lr_scheduler()
self.optimizer_G = build_optimizer(opt.optimizer, parameter_list=self.netG.parameters()) #paddle.optimizer.Adam(learning_rate=lr_scheduler_g, parameter_list=self.netG.parameters(), beta1=opt.beta1)
self.optimizer_D = build_optimizer(opt.optimizer, parameter_list=self.netD.parameters()) #paddle.optimizer.Adam(learning_rate=lr_scheduler_d, parameter_list=self.netD.parameters(), beta1=opt.beta1)
self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D) self.optimizers.append(self.optimizer_D)
...@@ -78,16 +70,12 @@ class Pix2PixModel(BaseModel): ...@@ -78,16 +70,12 @@ class Pix2PixModel(BaseModel):
The option 'direction' can be used to swap images in domain A and domain B. The option 'direction' can be used to swap images in domain A and domain B.
""" """
# AtoB = self.opt.direction == 'AtoB'
# self.real_A = input['A' if AtoB else 'B'].to(self.device)
# self.real_B = input['B' if AtoB else 'A'].to(self.device)
# self.image_paths = input['A_paths' if AtoB else 'B_paths']
AtoB = self.opt.dataset.train.direction == 'AtoB' AtoB = self.opt.dataset.train.direction == 'AtoB'
self.real_A = paddle.imperative.to_variable(input['A' if AtoB else 'B']) self.real_A = paddle.imperative.to_variable(input['A' if AtoB else 'B'])
self.real_B = paddle.imperative.to_variable(input['B' if AtoB else 'A']) self.real_B = paddle.imperative.to_variable(input['B' if AtoB else 'A'])
self.image_paths = input['A_paths' if AtoB else 'B_paths'] self.image_paths = input['A_paths' if AtoB else 'B_paths']
# self.real_A = paddle.imperative.to_variable(input[0] if AtoB else input[1])
# self.real_B = paddle.imperative.to_variable(input[1] if AtoB else input[0])
def forward(self): def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>.""" """Run forward pass; called by both functions <optimize_parameters> and <test>."""
...@@ -97,19 +85,11 @@ class Pix2PixModel(BaseModel): ...@@ -97,19 +85,11 @@ class Pix2PixModel(BaseModel):
input = paddle.imperative.to_variable(input) input = paddle.imperative.to_variable(input)
return self.netG(input) return self.netG(input)
# def test(self, input):
# """Forward function used in test time.
# This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
# It also calls <compute_visuals> to produce additional visualization results
# """
# with paddle.imperative.no_grad():
# return self.forward_test(input)
def backward_D(self): def backward_D(self):
"""Calculate GAN loss for the discriminator""" """Calculate GAN loss for the discriminator"""
# Fake; stop backprop to the generator by detaching fake_B # Fake; stop backprop to the generator by detaching fake_B
fake_AB = paddle.concat((self.real_A, self.fake_B), 1) # we use conditional GANs; we need to feed both input and output to the discriminator # use conditional GANs; we need to feed both input and output to the discriminator
fake_AB = paddle.concat((self.real_A, self.fake_B), 1)
pred_fake = self.netD(fake_AB.detach()) pred_fake = self.netD(fake_AB.detach())
self.loss_D_fake = self.criterionGAN(pred_fake, False) self.loss_D_fake = self.criterionGAN(pred_fake, False)
# Real # Real
...@@ -134,16 +114,17 @@ class Pix2PixModel(BaseModel): ...@@ -134,16 +114,17 @@ class Pix2PixModel(BaseModel):
self.loss_G.backward() self.loss_G.backward()
def optimize_parameters(self): def optimize_parameters(self):
self.forward() # compute fake images: G(A) # compute fake images: G(A)
self.forward()
# update D # update D
self.set_requires_grad(self.netD, True) # enable backprop for D self.set_requires_grad(self.netD, True)
self.optimizer_D.clear_gradients() # set D's gradients to zero self.optimizer_D.clear_gradients()
self.backward_D() # calculate gradients for D self.backward_D()
self.optimizer_D.minimize(self.loss_D) # update D's weights self.optimizer_D.minimize(self.loss_D)
# self.netD.clear_gradients()
# self.optimizer_D.clear_gradients()
# update G # update G
self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G self.set_requires_grad(self.netD, False)
self.optimizer_G.clear_gradients() # set G's gradients to zero self.optimizer_G.clear_gradients()
self.backward_G() # calculate graidents for G self.backward_G()
self.optimizer_G.minimize(self.loss_G) # udpate G's weights self.optimizer_G.minimize(self.loss_G)
# Download one pix2pix paired-image dataset and unpack it into ../data/.
#
# Usage: bash pix2pix_download.sh [cityscapes|facades|edges2handbags|edges2shoes|maps]
#
# The archive is fetched into the current directory, extracted under ../data/,
# and deleted once extraction has succeeded.
FILE=$1

# Refuse to run without a dataset name; otherwise the URL below would point at
# a non-existent ".tar.gz" file.
if [ -z "$FILE" ]; then
    echo "Usage: bash $0 [cityscapes|facades|edges2handbags|edges2shoes|maps]" >&2
    exit 1
fi

URL=https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/$FILE.tar.gz
TAR_FILE=./$FILE.tar.gz
DATA_DIR=../data/

# NOTE(review): --no-check-certificate disables TLS verification. It is kept to
# preserve the existing behavior; drop it if the host's certificate validates.
if ! wget -N "$URL" -O "$TAR_FILE" --no-check-certificate; then
    echo "Failed to download $URL" >&2
    # wget -O leaves an empty or partial file behind on failure.
    rm -f "$TAR_FILE"
    exit 1
fi

# The extraction directory may not exist on a fresh checkout.
mkdir -p "$DATA_DIR"

# Keep the archive when extraction fails so the problem can be inspected.
if ! tar -zxvf "$TAR_FILE" -C "$DATA_DIR"; then
    echo "Failed to extract $TAR_FILE into $DATA_DIR" >&2
    exit 1
fi

# The previous version also created and then removed ./$FILE/ without using it,
# which deleted any existing directory of that name; that step is dropped.
rm "$TAR_FILE"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册