diff --git a/configs/dcgan_mnist.yaml b/configs/dcgan_mnist.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..89a2ad4fb4d0f4711ed9f0f3c01b8de1191cb375
--- /dev/null
+++ b/configs/dcgan_mnist.yaml
@@ -0,0 +1,70 @@
+epochs: 200
+output_dir: output_dir
+
+model:
+  name: DCGANModel
+  generator:
+    name: DCGenerator
+    norm_type: batch
+    input_nz: 100
+    input_nc: 1
+    output_nc: 1
+    ngf: 64
+  discriminator:
+    name: DCDiscriminator
+    norm_type: batch
+    ndf: 64
+    input_nc: 1
+  gan_mode: vanilla #wgangp
+
+dataset:
+  train:
+    name: SingleDataset
+    dataroot: data/mnist/train
+    phase: train
+    max_dataset_size: inf
+    direction: AtoB
+    input_nc: 1
+    output_nc: 1
+    batch_size: 128
+    serial_batches: False
+    transforms:
+      - name: Resize
+        size: [64, 64]
+        interpolation: 'bicubic' #cv2.INTER_CUBIC
+      - name: Transpose
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
+  test:
+    name: SingleDataset
+    dataroot: data/mnist/test
+    max_dataset_size: inf
+    input_nc: 1
+    output_nc: 1
+    serial_batches: False
+    transforms:
+      - name: Resize
+        size: [64, 64]
+        interpolation: 'bicubic' #cv2.INTER_CUBIC
+      - name: Transpose
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
+
+optimizer:
+  name: Adam
+  beta1: 0.5
+
+lr_scheduler:
+  name: linear
+  learning_rate: 0.00002
+  start_epoch: 100
+  decay_epochs: 100
+
+log_config:
+  interval: 100
+  visiual_interval: 500
+
+snapshot_config:
+  interval: 5
diff --git a/docs/en_US/install.md b/docs/en_US/install.md
index 748749da0dcd36cbcc043c4f2815665dc7272480..94de4f6c9bdefd4eb848563bf06c30db93c65c0f
--- a/docs/en_US/install.md
+++ b/docs/en_US/install.md
@@ -24,8 +24,11 @@ Note: command above will install paddle with cuda10.2,if your installed cuda i
 install
 python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda9-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post90-cp36-cp36m-linux_x86_64.whl
 
+Visit home page of [paddlepaddle](https://www.paddlepaddle.org.cn/install/quick) for support of other systems, such as Windows 10.
 
-### 2. Install through pip
+### 2. Install PaddleGAN
+
+#### 2.1 Install through pip
 
 ```
 # only support Python3
@@ -39,7 +42,7 @@ git clone https://github.com/PaddlePaddle/PaddleGAN
 cd PaddleGAN
 ```
 
-### 3. Install through source code
+#### 2.2 Install through source code
 
 ```
 git clone https://github.com/PaddlePaddle/PaddleGAN
diff --git a/docs/zh_CN/install.md b/docs/zh_CN/install.md
index c03ab636263700d09dd41ad07a113d1099cbc049..70c4205473a33d17578a3b0e3bfc9ff8cd0865de
--- a/docs/zh_CN/install.md
+++ b/docs/zh_CN/install.md
@@ -23,8 +23,11 @@ pip install -U paddlepaddle-gpu==2.0.0rc0
 install
 python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda9-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post90-cp36-cp36m-linux_x86_64.whl
 
-### 2. 通过Pip安装
+支持更多系统的安装教程请前往[paddlepaddle官网](https://www.paddlepaddle.org.cn/install/quick)
+### 2. 安装PaddleGAN
+
+#### 2.1 通过Pip安装
 
 ```
 # only support Python3
 python3 -m pip install --upgrade ppgan
@@ -37,7 +40,7 @@ git clone https://github.com/PaddlePaddle/PaddleGAN
 cd PaddleGAN
 ```
 
-### 3. 通过源码安装PaddleGAN
+#### 2.2 通过源码安装
 
 ```
 git clone https://github.com/PaddlePaddle/PaddleGAN
diff --git a/ppgan/models/__init__.py b/ppgan/models/__init__.py
index 9b7b4bc99c9f00c49684fa03e26420a9619f0b0a..4a2d433c3c8bf60b7ee0f79c4845b4707c4974c9
--- a/ppgan/models/__init__.py
+++ b/ppgan/models/__init__.py
@@ -18,3 +18,4 @@ from .pix2pix_model import Pix2PixModel
 from .srgan_model import SRGANModel
 from .sr_model import SRModel
 from .makeup_model import MakeupModel
+from .dc_gan_model import DCGANModel
diff --git a/ppgan/models/dc_gan_model.py b/ppgan/models/dc_gan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e279527357056ba65d915e8ac73727923014195d
--- /dev/null
+++ b/ppgan/models/dc_gan_model.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from .base_model import BaseModel
+
+from .builder import MODELS
+from .generators.builder import build_generator
+from .discriminators.builder import build_discriminator
+from .losses import GANLoss
+
+from ..solver import build_optimizer
+from ..modules.init import init_weights
+
+
+@MODELS.register()
+class DCGANModel(BaseModel):
+    """This class implements the DCGAN model, for learning a distribution from input images.
+
+    Training the model requires an image dataset.
+    By default, it uses a DCGenerator generator,
+    a DCDiscriminator discriminator,
+    and a vanilla GAN loss (the cross-entropy objective used in the original GAN paper).
+
+    DCGAN paper: https://arxiv.org/pdf/1511.06434
+    """
+    def __init__(self, cfg):
+        """Initialize the DCGAN class.
+
+        Parameters:
+            cfg (dict) -- config dict that stores all the experiment flags
+        """
+        super(DCGANModel, self).__init__(cfg)
+        # define networks (both generator and discriminator)
+        self.nets['netG'] = build_generator(cfg.model.generator)
+        init_weights(self.nets['netG'])
+        self.cfg = cfg
+        if self.is_train:
+            self.nets['netD'] = build_discriminator(cfg.model.discriminator)
+            init_weights(self.nets['netD'])
+
+        if self.is_train:
+            self.losses = {}
+            # define loss functions
+            self.criterionGAN = GANLoss(cfg.model.gan_mode)
+
+            # build optimizers
+            self.build_lr_scheduler()
+            self.optimizers['optimizer_G'] = build_optimizer(
+                cfg.optimizer,
+                self.lr_scheduler,
+                parameter_list=self.nets['netG'].parameters())
+            self.optimizers['optimizer_D'] = build_optimizer(
+                cfg.optimizer,
+                self.lr_scheduler,
+                parameter_list=self.nets['netD'].parameters())
+
+    def set_input(self, input):
+        """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+        Parameters:
+            input (dict): includes the data itself and its metadata information.
+        """
+        # get 1-channel gray image, or 3-channel color image
+        self.real = paddle.to_tensor(input['A'][:, 0:self.cfg.model.generator.input_nc, :, :])
+        self.image_paths = input['A_paths']
+
+    def forward(self):
+        """Run forward pass; called by optimize_parameters()."""
+
+        # generate random noise and fake image
+        self.z = paddle.rand(shape=(self.real.shape[0], self.cfg.model.generator.input_nz, 1, 1))
+        self.fake = self.nets['netG'](self.z)
+
+        # put items to visual dict
+        self.visual_items['real'] = self.real
+        self.visual_items['fake'] = self.fake
+
+    def backward_D(self):
+        """Calculate GAN loss for the discriminator"""
+        # Fake; stop backprop to the generator by detaching fake
+        pred_fake = self.nets['netD'](self.fake.detach())
+        self.loss_D_fake = self.criterionGAN(pred_fake, False)
+
+        pred_real = self.nets['netD'](self.real)
+        self.loss_D_real = self.criterionGAN(pred_real, True)
+
+        # combine loss and calculate gradients
+        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
+
+        self.loss_D.backward()
+
+        self.losses['D_fake_loss'] = self.loss_D_fake
+        self.losses['D_real_loss'] = self.loss_D_real
+
+    def backward_G(self):
+        """Calculate GAN loss for the generator"""
+        # G(A) should fake the discriminator
+        pred_fake = self.nets['netD'](self.fake)
+        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
+
+        # combine loss and calculate gradients
+        self.loss_G = self.loss_G_GAN
+
+        self.loss_G.backward()
+
+        self.losses['G_adv_loss'] = self.loss_G_GAN
+
+    def optimize_parameters(self):
+        # compute fake images: G(A)
+        self.forward()
+
+        # update D
+        self.set_requires_grad(self.nets['netD'], True)
+        self.set_requires_grad(self.nets['netG'], False)
+        self.optimizers['optimizer_D'].clear_grad()
+        self.backward_D()
+        self.optimizers['optimizer_D'].step()
+
+        # update G
+        self.set_requires_grad(self.nets['netD'], False)
+        self.set_requires_grad(self.nets['netG'], True)
+        self.optimizers['optimizer_G'].clear_grad()
+        self.backward_G()
+        self.optimizers['optimizer_G'].step()
\ No newline at end of file
diff --git a/ppgan/models/discriminators/__init__.py b/ppgan/models/discriminators/__init__.py
index 436724f6561d094f23ba1e9993a4be46c1f88d5b..caa2f8b410f7b9a7c49b099dd08b6a4b70c2391f
--- a/ppgan/models/discriminators/__init__.py
+++ b/ppgan/models/discriminators/__init__.py
@@ -13,3 +13,4 @@
 # limitations under the License.
 
 from .nlayers import NLayerDiscriminator
+from .dcdiscriminator import DCDiscriminator
diff --git a/ppgan/models/discriminators/dcdiscriminator.py b/ppgan/models/discriminators/dcdiscriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..f66b49a84ebca5b5d14e6f2f00ff065e035973bd
--- /dev/null
+++ b/ppgan/models/discriminators/dcdiscriminator.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import functools
+import numpy as np
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from paddle.nn import BatchNorm2D
+from ...modules.norm import build_norm_layer
+
+from .builder import DISCRIMINATORS
+
+
+@DISCRIMINATORS.register()
+class DCDiscriminator(nn.Layer):
+    """Defines a DCGAN discriminator"""
+    def __init__(self, input_nc, ndf=64, norm_type='instance'):
+        """Construct a DCGAN discriminator
+
+        Parameters:
+            input_nc (int)  -- the number of channels in input images
+            ndf (int)       -- the number of filters in the first conv layer
+            norm_type (str) -- normalization layer type
+        """
+        super(DCDiscriminator, self).__init__()
+        norm_layer = build_norm_layer(norm_type)
+        if type(norm_layer) == functools.partial:
+            # no need to use bias as BatchNorm2D has affine parameters
+            use_bias = norm_layer.func == nn.BatchNorm2D
+        else:
+            use_bias = norm_layer == nn.BatchNorm2D
+
+        kw = 4
+        padw = 1
+
+        sequence = [
+            nn.Conv2D(input_nc,
+                      ndf,
+                      kernel_size=kw,
+                      stride=2,
+                      padding=padw,
+                      bias_attr=use_bias),
+            nn.LeakyReLU(0.2)
+        ]
+
+        nf_mult = 1
+        nf_mult_prev = 1
+        n_downsampling = 4
+
+        for n in range(1, n_downsampling):  # gradually increase the number of filters
+            nf_mult_prev = nf_mult
+            nf_mult = min(2**n, 8)
+            if norm_type == 'batch':
+                sequence += [
+                    nn.Conv2D(ndf * nf_mult_prev,
+                              ndf * nf_mult,
+                              kernel_size=kw,
+                              stride=2,
+                              padding=padw),
+                    BatchNorm2D(ndf * nf_mult),
+                    nn.LeakyReLU(0.2)
+                ]
+            else:
+                sequence += [
+                    nn.Conv2D(ndf * nf_mult_prev,
+                              ndf * nf_mult,
+                              kernel_size=kw,
+                              stride=2,
+                              padding=padw,
+                              bias_attr=use_bias),
+                    norm_layer(ndf * nf_mult),
+                    nn.LeakyReLU(0.2)
+                ]
+
+        nf_mult_prev = nf_mult
+
+        sequence += [
+            nn.Conv2D(ndf * nf_mult_prev,
+                      1,
+                      kernel_size=kw,
+                      stride=1,
+                      padding=0)
+        ]  # output 1 channel prediction map
+
+        self.model = nn.Sequential(*sequence)
+
+    def forward(self, input):
+        """Standard forward."""
+        return self.model(input)
diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py
index 6f6b433a35b579dfcce0e1760480eb2beac613af..e407542479e8fe1707b9977be2a18246d91a9654
--- a/ppgan/models/generators/__init__.py
+++ b/ppgan/models/generators/__init__.py
@@ -16,3 +16,4 @@ from .resnet import ResnetGenerator
 from .unet import UnetGenerator
 from .rrdb_net import RRDBNet
 from .makeup import GeneratorPSGANAttention
+from .dcgenerator import DCGenerator
diff --git a/ppgan/models/generators/dcgenerator.py b/ppgan/models/generators/dcgenerator.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bbdbb33e84358583d0d30f3596c0880e2a52711
--- /dev/null
+++ b/ppgan/models/generators/dcgenerator.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import functools
+
+from paddle.nn import BatchNorm2D
+from ...modules.norm import build_norm_layer
+
+from .builder import GENERATORS
+
+
+@GENERATORS.register()
+class DCGenerator(nn.Layer):
+    """DCGAN generator: maps a latent noise vector to an image through a stack of transposed-convolution blocks.
+
+    DCGAN paper: https://arxiv.org/pdf/1511.06434
+    """
+    def __init__(self,
+                 input_nz,
+                 input_nc,
+                 output_nc,
+                 ngf=64,
+                 norm_type='batch',
+                 padding_type='reflect'):
+        """Construct a DCGenerator generator
+
+        Args:
+            input_nz (int)     -- the dimension of the input noise vector
+            input_nc (int)     -- the number of channels in input images
+            output_nc (int)    -- the number of channels in output images
+            ngf (int)          -- the number of filters in the last conv layer
+            norm_type (str)    -- normalization layer type
+            padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
+        """
+        super(DCGenerator, self).__init__()
+
+        norm_layer = build_norm_layer(norm_type)
+        if type(norm_layer) == functools.partial:
+            use_bias = norm_layer.func == nn.BatchNorm2D
+        else:
+            use_bias = norm_layer == nn.BatchNorm2D
+
+        mult = 8
+        n_downsampling = 4
+
+        if norm_type == 'batch':
+            model = [
+                nn.Conv2DTranspose(input_nz,
+                                   ngf * mult,
+                                   kernel_size=4,
+                                   stride=1,
+                                   padding=0,
+                                   bias_attr=use_bias),
+                BatchNorm2D(ngf * mult),
+                nn.ReLU()
+            ]
+        else:
+            model = [
+                nn.Conv2DTranspose(input_nz,
+                                   ngf * mult,
+                                   kernel_size=4,
+                                   stride=1,
+                                   padding=0,
+                                   bias_attr=use_bias),
+                norm_layer(ngf * mult),
+                nn.ReLU()
+            ]
+
+        for i in range(1, n_downsampling):  # add upsampling layers
+            mult = 2**(n_downsampling - i)
+            output_size = 2**(i + 2)
+            if norm_type == 'batch':
+                model += [
+                    nn.Conv2DTranspose(ngf * mult,
+                                       ngf * mult // 2,
+                                       kernel_size=4,
+                                       stride=2,
+                                       padding=1,
+                                       bias_attr=use_bias),
+                    BatchNorm2D(ngf * mult // 2),
+                    nn.ReLU()
+                ]
+            else:
+                model += [
+                    nn.Conv2DTranspose(ngf * mult,
+                                       int(ngf * mult // 2),
+                                       kernel_size=4,
+                                       stride=2,
+                                       padding=1,
+                                       bias_attr=use_bias),
+                    norm_layer(int(ngf * mult // 2)),
+                    nn.ReLU()
+                ]
+
+        output_size = 2**(6)
+        model += [
+            nn.Conv2DTranspose(ngf,
+                               output_nc,
+                               kernel_size=4,
+                               stride=2,
+                               padding=1,
+                               bias_attr=use_bias),
+            nn.Tanh()
+        ]
+
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x):
+        """Standard forward"""
+        return self.model(x)
\ No newline at end of file
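
For reviewers who want a quick sanity check of the two new networks outside the trainer, here is a minimal smoke-test sketch (not part of the patch). It assumes `paddle` and this repo's `ppgan` package are importable, and reuses the hyper-parameters from `configs/dcgan_mnist.yaml` (100-dim noise, 1-channel 64x64 images, `ngf` = `ndf` = 64, batch norm).

```python
import paddle

# Module paths come from the files added in this patch.
from ppgan.models.generators.dcgenerator import DCGenerator
from ppgan.models.discriminators.dcdiscriminator import DCDiscriminator

# Hyper-parameters mirroring configs/dcgan_mnist.yaml.
netG = DCGenerator(input_nz=100, input_nc=1, output_nc=1, ngf=64, norm_type='batch')
netD = DCDiscriminator(input_nc=1, ndf=64, norm_type='batch')

# DCGANModel.forward draws noise of shape [N, input_nz, 1, 1].
z = paddle.rand(shape=(8, 100, 1, 1))
fake = netG(z)      # expected [8, 1, 64, 64]: a 4x4 seed upsampled four times by stride-2 transposed convs
score = netD(fake)  # expected [8, 1, 1, 1]: per-image prediction map that GANLoss consumes

print(fake.shape, score.shape)
```

Training itself stays config-driven: pointing the repo's usual training entry point at `configs/dcgan_mnist.yaml` builds the same networks through `build_generator`/`build_discriminator` and runs the alternating D/G updates implemented in `DCGANModel.optimize_parameters`.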