Unverified commit 59cf316a, authored by Jie Han, committed by GitHub

docs: update install instruction (#88)

* docs: update install instruction

1. The paddlepaddle 2.0.0rc installation instructions can be found on the main page
2. Only two steps now (install paddlepaddle, then paddlegan) instead of three

* docs: update en install
* feat: add dcgan
* reform: debug with the MNIST dataset
validate performance with different options against the original structure and settings
Parent 1893271b
epochs: 200
output_dir: output_dir

model:
  name: DCGANModel
  generator:
    name: DCGenerator
    norm_type: batch
    input_nz: 100
    input_nc: 1
    output_nc: 1
    ngf: 64
  discriminator:
    name: DCDiscriminator
    norm_type: batch
    ndf: 64
    input_nc: 1
  gan_mode: vanilla #wgangp

dataset:
  train:
    name: SingleDataset
    dataroot: data/mnist/train
    phase: train
    max_dataset_size: inf
    direction: AtoB
    input_nc: 1
    output_nc: 1
    batch_size: 128
    serial_batches: False
    transforms:
      - name: Resize
        size: [64, 64]
        interpolation: 'bicubic' #cv2.INTER_CUBIC
      - name: Transpose
      - name: Normalize
        mean: [127.5, 127.5, 127.5]
        std: [127.5, 127.5, 127.5]
  test:
    name: SingleDataset
    dataroot: data/mnist/test
    max_dataset_size: inf
    input_nc: 1
    output_nc: 1
    serial_batches: False
    transforms:
      - name: Resize
        size: [64, 64]
        interpolation: 'bicubic' #cv2.INTER_CUBIC
      - name: Transpose
      - name: Normalize
        mean: [127.5, 127.5, 127.5]
        std: [127.5, 127.5, 127.5]

optimizer:
  name: Adam
  beta1: 0.5

lr_scheduler:
  name: linear
  learning_rate: 0.00002
  start_epoch: 100
  decay_epochs: 100

log_config:
  interval: 100
  visiual_interval: 500

snapshot_config:
  interval: 5
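The `model.generator` and `model.discriminator` blocks above are consumed by PaddleGAN's builders: `name` selects a registered class and the remaining keys become constructor arguments. Note also that `Normalize` with mean and std 127.5 maps 8-bit pixels from [0, 255] to [-1, 1], which matches the Tanh output of the generator added in this commit. A minimal sketch of that mapping, assuming `ppgan` and `paddlepaddle` are importable (this is not the PaddleGAN training entry point):

```
# Minimal sketch: how the `model.generator` block maps onto the DCGenerator class.
# Assumes ppgan and paddlepaddle are installed (see the install docs below).
from ppgan.models.generators import DCGenerator

generator_cfg = {        # copied from the `model.generator` section above
    'name': 'DCGenerator',
    'norm_type': 'batch',
    'input_nz': 100,
    'input_nc': 1,
    'output_nc': 1,
    'ngf': 64,
}
# `name` picks the registered class; the remaining keys become keyword arguments,
# which is roughly what build_generator() does with this block.
kwargs = {k: v for k, v in generator_cfg.items() if k != 'name'}
netG = DCGenerator(**kwargs)
# build_discriminator() treats the `model.discriminator` block the same way.
```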
......
@@ -24,8 +24,11 @@ Note: command above will install paddle with cuda10.2,if your installed cuda i
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda9-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post90-cp36-cp36m-linux_x86_64.whl
</code></pre> </details> </td> </tr></tbody></table>
Visit the home page of [paddlepaddle](https://www.paddlepaddle.org.cn/install/quick) for support of other systems, such as Windows 10.
### 2. Install through pip
### 2. Install PaddleGAN
#### 2.1 Install through pip
```
# only support Python3
......
@@ -39,7 +42,7 @@ git clone https://github.com/PaddlePaddle/PaddleGAN
cd PaddleGAN
```
### 3. Install through source code
#### 2.2 Install through source code
```
git clone https://github.com/PaddlePaddle/PaddleGAN
......
@@ -23,8 +23,11 @@ pip install -U paddlepaddle-gpu==2.0.0rc0
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda9-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post90-cp36-cp36m-linux_x86_64.whl
</code></pre> </details> </td> </tr></tbody></table>
### 2. Install through pip
For installation guides covering more systems, please visit the [paddlepaddle official site](https://www.paddlepaddle.org.cn/install/quick)
### 2. Install PaddleGAN
#### 2.1 Install through pip
```
# only support Python3
python3 -m pip install --upgrade ppgan
......
@@ -37,7 +40,7 @@ git clone https://github.com/PaddlePaddle/PaddleGAN
cd PaddleGAN
```
### 3. Install PaddleGAN from source code
#### 2.2 Install from source code
```
git clone https://github.com/PaddlePaddle/PaddleGAN
......
@@ -18,3 +18,4 @@ from .pix2pix_model import Pix2PixModel
from .srgan_model import SRGANModel
from .sr_model import SRModel
from .makeup_model import MakeupModel
from .dc_gan_model import DCGANModel
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

from .base_model import BaseModel
from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .losses import GANLoss

from ..solver import build_optimizer
from ..modules.init import init_weights


@MODELS.register()
class DCGANModel(BaseModel):
    """This class implements the DCGAN model, for learning a distribution from input images.

    The model training requires a dataset.
    By default, it uses a '--netG DCGenerator' generator,
    a '--netD DCDiscriminator' discriminator,
    and a vanilla GAN loss (the cross-entropy objective used in the original GAN paper).

    DCGAN paper: https://arxiv.org/pdf/1511.06434
    """
    def __init__(self, cfg):
        """Initialize the DCGAN class.

        Parameters:
            cfg (config dict) -- stores all the experiment flags; needs to be a subclass of Dict
        """
        super(DCGANModel, self).__init__(cfg)

        # define networks (both generator and discriminator)
        self.nets['netG'] = build_generator(cfg.model.generator)
        init_weights(self.nets['netG'])
        self.cfg = cfg

        if self.is_train:
            self.nets['netD'] = build_discriminator(cfg.model.discriminator)
            init_weights(self.nets['netD'])

        if self.is_train:
            self.losses = {}
            # define loss functions
            self.criterionGAN = GANLoss(cfg.model.gan_mode)

            # build optimizers
            self.build_lr_scheduler()
            self.optimizers['optimizer_G'] = build_optimizer(
                cfg.optimizer,
                self.lr_scheduler,
                parameter_list=self.nets['netG'].parameters())
            self.optimizers['optimizer_D'] = build_optimizer(
                cfg.optimizer,
                self.lr_scheduler,
                parameter_list=self.nets['netD'].parameters())

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): includes the data itself and its metadata information.
        """
        # get 1-channel gray image, or 3-channel color image
        self.real = paddle.to_tensor(
            input['A'][:, 0:self.cfg.model.generator.input_nc, :, :])
        self.image_paths = input['A_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        # generate random noise and fake image
        self.z = paddle.rand(
            shape=(self.real.shape[0], self.cfg.model.generator.input_nz, 1, 1))
        self.fake = self.nets['netG'](self.z)

        # put items to visual dict
        self.visual_items['real'] = self.real
        self.visual_items['fake'] = self.fake

    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        # Fake; stop backprop to the generator by detaching fake
        pred_fake = self.nets['netD'](self.fake.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        pred_real = self.nets['netD'](self.real)
        self.loss_D_real = self.criterionGAN(pred_real, True)

        # combine loss and calculate gradients
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

        self.losses['D_fake_loss'] = self.loss_D_fake
        self.losses['D_real_loss'] = self.loss_D_real

    def backward_G(self):
        """Calculate GAN loss for the generator"""
        # G(A) should fake the discriminator
        pred_fake = self.nets['netD'](self.fake)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # combine loss and calculate gradients
        self.loss_G = self.loss_G_GAN
        self.loss_G.backward()

        self.losses['G_adv_loss'] = self.loss_G_GAN

    def optimize_parameters(self):
        # compute fake images: G(A)
        self.forward()

        # update D
        self.set_requires_grad(self.nets['netD'], True)
        self.set_requires_grad(self.nets['netG'], False)
        self.optimizers['optimizer_D'].clear_grad()
        self.backward_D()
        self.optimizers['optimizer_D'].step()

        # update G
        self.set_requires_grad(self.nets['netD'], False)
        self.set_requires_grad(self.nets['netG'], True)
        self.optimizers['optimizer_G'].clear_grad()
        self.backward_G()
        self.optimizers['optimizer_G'].step()
\ No newline at end of file
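For readers skimming `optimize_parameters`, the sketch below reproduces its alternating update outside the PaddleGAN trainer: the discriminator is trained on real images and detached fakes, then the generator is trained to fool the discriminator. This is a simplified stand-in, assuming `ppgan` and `paddlepaddle` are installed: a random tensor replaces a normalized MNIST batch, plain `paddle.optimizer.Adam` (with the config's learning rate and beta1) replaces `build_optimizer` plus the linear scheduler, and the `set_requires_grad` freezing done by the real model is omitted.

```
# Simplified, standalone illustration of the D/G update in DCGANModel.optimize_parameters().
import paddle
from ppgan.models.generators import DCGenerator
from ppgan.models.discriminators import DCDiscriminator
from ppgan.models.losses import GANLoss

netG = DCGenerator(input_nz=100, input_nc=1, output_nc=1, ngf=64, norm_type='batch')
netD = DCDiscriminator(input_nc=1, ndf=64, norm_type='batch')
criterion = GANLoss('vanilla')   # same call convention as criterionGAN above

opt_G = paddle.optimizer.Adam(learning_rate=2e-5, beta1=0.5, parameters=netG.parameters())
opt_D = paddle.optimizer.Adam(learning_rate=2e-5, beta1=0.5, parameters=netD.parameters())

real_batch = paddle.rand([128, 1, 64, 64])         # stand-in for a normalized MNIST batch
z = paddle.rand([real_batch.shape[0], 100, 1, 1])  # noise, as in DCGANModel.forward()
fake = netG(z)

# update D: real images -> True, detached fakes -> False
opt_D.clear_grad()
loss_D = 0.5 * (criterion(netD(fake.detach()), False) + criterion(netD(real_batch), True))
loss_D.backward()
opt_D.step()

# update G: the (non-detached) fakes should be classified as real
opt_G.clear_grad()
loss_G = criterion(netD(fake), True)
loss_G.backward()
opt_G.step()
```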
......
@@ -13,3 +13,4 @@
# limitations under the License.
from .nlayers import NLayerDiscriminator
from .dcdiscriminator import DCDiscriminator
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import functools
import numpy as np
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import BatchNorm2D

from ...modules.norm import build_norm_layer
from .builder import DISCRIMINATORS


@DISCRIMINATORS.register()
class DCDiscriminator(nn.Layer):
    """Defines a DCGAN discriminator"""
    def __init__(self, input_nc, ndf=64, norm_type='instance'):
        """Construct a DCGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            norm_type (str) -- normalization layer type
        """
        super(DCDiscriminator, self).__init__()
        norm_layer = build_norm_layer(norm_type)
        if type(
                norm_layer
        ) == functools.partial:  # no need to use bias as BatchNorm2D has affine parameters
            use_bias = norm_layer.func == nn.BatchNorm2D
        else:
            use_bias = norm_layer == nn.BatchNorm2D

        kw = 4
        padw = 1

        sequence = [
            nn.Conv2D(input_nc,
                      ndf,
                      kernel_size=kw,
                      stride=2,
                      padding=padw,
                      bias_attr=use_bias),
            nn.LeakyReLU(0.2)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        n_downsampling = 4

        for n in range(1, n_downsampling):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            if norm_type == 'batch':
                sequence += [
                    nn.Conv2D(ndf * nf_mult_prev,
                              ndf * nf_mult,
                              kernel_size=kw,
                              stride=2,
                              padding=padw),
                    BatchNorm2D(ndf * nf_mult),
                    nn.LeakyReLU(0.2)
                ]
            else:
                sequence += [
                    nn.Conv2D(ndf * nf_mult_prev,
                              ndf * nf_mult,
                              kernel_size=kw,
                              stride=2,
                              padding=padw,
                              bias_attr=use_bias),
                    norm_layer(ndf * nf_mult),
                    nn.LeakyReLU(0.2)
                ]

        nf_mult_prev = nf_mult

        sequence += [
            nn.Conv2D(ndf * nf_mult_prev,
                      1,
                      kernel_size=kw,
                      stride=1,
                      padding=0)
        ]  # output 1 channel prediction map

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
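As a quick sanity check of the stack above: with four stride-2 convolutions, a 64x64 input (the size produced by the `Resize` transform in the config) shrinks to 4x4, and the final kernel-4, stride-1, padding-0 convolution collapses it to a single realness logit per image. A minimal sketch, assuming `ppgan` is installed:

```
# Sketch: output shape of DCDiscriminator for the 64x64, 1-channel inputs from the config.
# Spatial sizes: 64 -> 32 -> 16 -> 8 -> 4 (stride-2 convs), then 4 -> 1 (kernel 4, stride 1).
import paddle
from ppgan.models.discriminators import DCDiscriminator

netD = DCDiscriminator(input_nc=1, ndf=64, norm_type='batch')
x = paddle.rand([8, 1, 64, 64])   # stand-in for a normalized MNIST batch
print(netD(x).shape)              # [8, 1, 1, 1] -- one realness logit per image
```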
......
@@ -16,3 +16,4 @@ from .resnet import ResnetGenerator
from .unet import UnetGenerator
from .rrdb_net import RRDBNet
from .makeup import GeneratorPSGANAttention
from .dcgenerator import DCGenerator
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import functools
from paddle.nn import BatchNorm2D

from ...modules.norm import build_norm_layer
from .builder import GENERATORS


@GENERATORS.register()
class DCGenerator(nn.Layer):
    """DCGAN generator: maps a noise vector to an image through a stack of
    transposed convolutions.

    DCGAN paper: https://arxiv.org/pdf/1511.06434
    """
    def __init__(self,
                 input_nz,
                 input_nc,
                 output_nc,
                 ngf=64,
                 norm_type='batch',
                 padding_type='reflect'):
        """Construct a DCGenerator generator

        Args:
            input_nz (int)     -- the dimension of the input noise
            input_nc (int)     -- the number of channels in input images
            output_nc (int)    -- the number of channels in output images
            ngf (int)          -- the number of filters in the last conv layer
            norm_type (str)    -- normalization layer type
            padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
        """
        super(DCGenerator, self).__init__()

        norm_layer = build_norm_layer(norm_type)
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.BatchNorm2D
        else:
            use_bias = norm_layer == nn.BatchNorm2D

        mult = 8
        n_downsampling = 4

        # project the [N, input_nz, 1, 1] noise to a (ngf * 8) x 4 x 4 feature map
        if norm_type == 'batch':
            model = [
                nn.Conv2DTranspose(input_nz,
                                   ngf * mult,
                                   kernel_size=4,
                                   stride=1,
                                   padding=0,
                                   bias_attr=use_bias),
                BatchNorm2D(ngf * mult),
                nn.ReLU()
            ]
        else:
            model = [
                nn.Conv2DTranspose(input_nz,
                                   ngf * mult,
                                   kernel_size=4,
                                   stride=1,
                                   padding=0,
                                   bias_attr=use_bias),
                norm_layer(ngf * mult),
                nn.ReLU()
            ]

        for i in range(1, n_downsampling):  # add upsampling layers
            mult = 2**(n_downsampling - i)
            output_size = 2**(i + 2)
            if norm_type == 'batch':
                model += [
                    nn.Conv2DTranspose(ngf * mult,
                                       ngf * mult // 2,
                                       kernel_size=4,
                                       stride=2,
                                       padding=1,
                                       bias_attr=use_bias),
                    BatchNorm2D(ngf * mult // 2),
                    nn.ReLU()
                ]
            else:
                model += [
                    nn.Conv2DTranspose(ngf * mult,
                                       int(ngf * mult // 2),
                                       kernel_size=4,
                                       stride=2,
                                       padding=1,
                                       bias_attr=use_bias),
                    norm_layer(int(ngf * mult // 2)),
                    nn.ReLU()
                ]

        # final transposed conv maps ngf channels to output_nc and squashes to [-1, 1] with Tanh
        output_size = 2**(6)
        model += [
            nn.Conv2DTranspose(ngf,
                               output_nc,
                               kernel_size=4,
                               stride=2,
                               padding=1,
                               bias_attr=use_bias),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Standard forward"""
        return self.model(x)
\ No newline at end of file
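The generator mirrors the discriminator: the noise is treated as a 1x1 'image' with `input_nz` channels, projected to a 4x4 feature map by the first transposed convolution, then doubled in size by each subsequent one until it reaches 64x64. A minimal sketch, assuming `ppgan` is installed:

```
# Sketch: shape progression through DCGenerator with the config values above (ngf=64).
# [N,100,1,1] -> [N,512,4,4] -> [N,256,8,8] -> [N,128,16,16] -> [N,64,32,32] -> [N,1,64,64]
import paddle
from ppgan.models.generators import DCGenerator

netG = DCGenerator(input_nz=100, input_nc=1, output_nc=1, ngf=64, norm_type='batch')
z = paddle.rand([8, 100, 1, 1])   # noise vector viewed as a 1x1 spatial map
print(netG(z).shape)              # [8, 1, 64, 64], values in [-1, 1] from the final Tanh
```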