diff --git a/configs/dcgan_mnist.yaml b/configs/dcgan_mnist.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..89a2ad4fb4d0f4711ed9f0f3c01b8de1191cb375
--- /dev/null
+++ b/configs/dcgan_mnist.yaml
@@ -0,0 +1,70 @@
+epochs: 200
+output_dir: output_dir
+
+model:
+ name: DCGANModel
+ generator:
+ name: DCGenerator
+ norm_type: batch
+ input_nz: 100
+ input_nc: 1
+ output_nc: 1
+ ngf: 64
+ discriminator:
+ name: DCDiscriminator
+ norm_type: batch
+ ndf: 64
+ input_nc: 1
+ gan_mode: vanilla #wgangp
+
+dataset:
+ train:
+ name: SingleDataset
+ dataroot: data/mnist/train
+ phase: train
+ max_dataset_size: inf
+ direction: AtoB
+ input_nc: 1
+ output_nc: 1
+ batch_size: 128
+ serial_batches: False
+ transforms:
+ - name: Resize
+ size: [64, 64]
+ interpolation: 'bicubic' #cv2.INTER_CUBIC
+ - name: Transpose
+ - name: Normalize
+ mean: [127.5, 127.5, 127.5]
+ std: [127.5, 127.5, 127.5]
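+        # (x - 127.5) / 127.5 maps 8-bit pixel values from [0, 255] to [-1, 1],
+        # matching the Tanh output range of the generator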
+ test:
+ name: SingleDataset
+ dataroot: data/mnist/test
+ max_dataset_size: inf
+ input_nc: 1
+ output_nc: 1
+ serial_batches: False
+ transforms:
+ - name: Resize
+ size: [64, 64]
+ interpolation: 'bicubic' #cv2.INTER_CUBIC
+ - name: Transpose
+ - name: Normalize
+ mean: [127.5, 127.5, 127.5]
+ std: [127.5, 127.5, 127.5]
+
+optimizer:
+ name: Adam
+ beta1: 0.5
+
+lr_scheduler:
+ name: linear
+ learning_rate: 0.00002
+ start_epoch: 100
+ decay_epochs: 100
+
+log_config:
+ interval: 100
+ visiual_interval: 500
+
+snapshot_config:
+ interval: 5
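+
+# A minimal usage sketch (the entry point below is an assumption and may differ
+# between PaddleGAN versions):
+#   python -u tools/main.py --config-file configs/dcgan_mnist.yaml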
diff --git a/docs/en_US/install.md b/docs/en_US/install.md
index 748749da0dcd36cbcc043c4f2815665dc7272480..94de4f6c9bdefd4eb848563bf06c30db93c65c0f 100644
--- a/docs/en_US/install.md
+++ b/docs/en_US/install.md
@@ -24,8 +24,11 @@ Note: command above will install paddle with cuda10.2,if your installed cuda i
install python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda9-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post90-cp36-cp36m-linux_x86_64.whl
|
+Visit the [paddlepaddle](https://www.paddlepaddle.org.cn/install/quick) home page for installation guides for other systems, such as Windows 10.
-### 2. Install through pip
+### 2. Install PaddleGAN
+
+#### 2.1 Install through pip
```
# only support Python3
@@ -39,7 +42,7 @@ git clone https://github.com/PaddlePaddle/PaddleGAN
cd PaddleGAN
```
-### 3. Install through source code
+#### 2.2 Install through source code
```
git clone https://github.com/PaddlePaddle/PaddleGAN
diff --git a/docs/zh_CN/install.md b/docs/zh_CN/install.md
index c03ab636263700d09dd41ad07a113d1099cbc049..70c4205473a33d17578a3b0e3bfc9ff8cd0865de 100644
--- a/docs/zh_CN/install.md
+++ b/docs/zh_CN/install.md
@@ -23,8 +23,11 @@ pip install -U paddlepaddle-gpu==2.0.0rc0
install python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda9-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post90-cp36-cp36m-linux_x86_64.whl
|
-### 2. 通过Pip安装
+支持更多系统的安装教程请前往[paddlepaddle官网](https://www.paddlepaddle.org.cn/install/quick)
+### 2. 安装PaddleGAN
+
+#### 2.1 通过Pip安装
```
# only support Python3
python3 -m pip install --upgrade ppgan
@@ -37,7 +40,7 @@ git clone https://github.com/PaddlePaddle/PaddleGAN
cd PaddleGAN
```
-### 3. 通过源码安装PaddleGAN
+#### 2.2 通过源码安装
```
git clone https://github.com/PaddlePaddle/PaddleGAN
diff --git a/ppgan/models/__init__.py b/ppgan/models/__init__.py
index 9b7b4bc99c9f00c49684fa03e26420a9619f0b0a..4a2d433c3c8bf60b7ee0f79c4845b4707c4974c9 100644
--- a/ppgan/models/__init__.py
+++ b/ppgan/models/__init__.py
@@ -18,3 +18,4 @@ from .pix2pix_model import Pix2PixModel
from .srgan_model import SRGANModel
from .sr_model import SRModel
from .makeup_model import MakeupModel
+from .dc_gan_model import DCGANModel
diff --git a/ppgan/models/dc_gan_model.py b/ppgan/models/dc_gan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e279527357056ba65d915e8ac73727923014195d
--- /dev/null
+++ b/ppgan/models/dc_gan_model.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from .base_model import BaseModel
+
+from .builder import MODELS
+from .generators.builder import build_generator
+from .discriminators.builder import build_discriminator
+from .losses import GANLoss
+
+from ..solver import build_optimizer
+from ..modules.init import init_weights
+
+
+@MODELS.register()
+class DCGANModel(BaseModel):
+ """ This class implements the DCGAN model, for learning a distribution from input images.
+
+ The model training requires dataset.
+ By default, it uses a '--netG DCGenerator' generator,
+ a '--netD DCDiscriminator' discriminator,
+ and a vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper).
+
+ DCGAN paper: https://arxiv.org/pdf/1511.06434
+ """
+ def __init__(self, cfg):
+ """Initialize the DCGAN class.
+
+ Parameters:
+            cfg (dict) -- config dict that stores all the experiment flags; needs to be a subclass of Dict
+ """
+ super(DCGANModel, self).__init__(cfg)
+ # define networks (both generator and discriminator)
+ self.nets['netG'] = build_generator(cfg.model.generator)
+ init_weights(self.nets['netG'])
+ self.cfg = cfg
+ if self.is_train:
+ self.nets['netD'] = build_discriminator(cfg.model.discriminator)
+ init_weights(self.nets['netD'])
+
+ if self.is_train:
+ self.losses = {}
+ # define loss functions
+ self.criterionGAN = GANLoss(cfg.model.gan_mode)
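+            # 'vanilla' is the cross-entropy objective described in the class
+            # docstring; the config comment lists 'wgangp' as an alternative gan_mode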
+
+ # build optimizers
+ self.build_lr_scheduler()
+ self.optimizers['optimizer_G'] = build_optimizer(
+ cfg.optimizer,
+ self.lr_scheduler,
+ parameter_list=self.nets['netG'].parameters())
+ self.optimizers['optimizer_D'] = build_optimizer(
+ cfg.optimizer,
+ self.lr_scheduler,
+ parameter_list=self.nets['netD'].parameters())
+
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+ Parameters:
+ input (dict): include the data itself and its metadata information.
+ """
+ # get 1-channel gray image, or 3-channel color image
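+        # only the first input_nc channels of the loaded tensor are kept
+        # (input_nc = 1 for grayscale MNIST)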
+ self.real = paddle.to_tensor(input['A'][:,0:self.cfg.model.generator.input_nc,:,:])
+ self.image_paths = input['A_paths']
+
+ def forward(self):
+ """Run forward pass; called by both functions and ."""
+
+ # generate random noise and fake image
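+        # z has shape (N, input_nz, 1, 1); the 1x1 spatial size matches the first
+        # transposed convolution of DCGenerator (kernel 4, stride 1, padding 0)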
+ self.z = paddle.rand(shape=(self.real.shape[0],self.cfg.model.generator.input_nz,1,1))
+ self.fake = self.nets['netG'](self.z)
+
+ # put items to visual dict
+ self.visual_items['real'] = self.real
+ self.visual_items['fake'] = self.fake
+
+ def backward_D(self):
+ """Calculate GAN loss for the discriminator"""
+ # Fake; stop backprop to the generator by detaching fake
+ pred_fake = self.nets['netD'](self.fake.detach())
+ self.loss_D_fake = self.criterionGAN(pred_fake, False)
+
+ pred_real = self.nets['netD'](self.real)
+ self.loss_D_real = self.criterionGAN(pred_real, True)
+
+ # combine loss and calculate gradients
+ self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
+
+ self.loss_D.backward()
+
+ self.losses['D_fake_loss'] = self.loss_D_fake
+ self.losses['D_real_loss'] = self.loss_D_real
+
+ def backward_G(self):
+ """Calculate GAN loss for the generator"""
+ # G(A) should fake the discriminator
+ pred_fake = self.nets['netD'](self.fake)
+ self.loss_G_GAN = self.criterionGAN(pred_fake, True)
+
+ # combine loss and calculate gradients
+ self.loss_G = self.loss_G_GAN
+
+ self.loss_G.backward()
+
+ self.losses['G_adv_loss'] = self.loss_G_GAN
+
+ def optimize_parameters(self):
+ # compute fake images: G(A)
+ self.forward()
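+        # the fake batch produced above is reused twice: detached for the D update,
+        # and kept in the graph for the G update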
+
+ #update D
+ self.set_requires_grad(self.nets['netD'], True)
+ self.set_requires_grad(self.nets['netG'], False)
+ self.optimizers['optimizer_D'].clear_grad()
+ self.backward_D()
+ self.optimizers['optimizer_D'].step()
+
+ # update G
+ self.set_requires_grad(self.nets['netD'], False)
+ self.set_requires_grad(self.nets['netG'], True)
+ self.optimizers['optimizer_G'].clear_grad()
+ self.backward_G()
+ self.optimizers['optimizer_G'].step()
\ No newline at end of file
diff --git a/ppgan/models/discriminators/__init__.py b/ppgan/models/discriminators/__init__.py
index 436724f6561d094f23ba1e9993a4be46c1f88d5b..caa2f8b410f7b9a7c49b099dd08b6a4b70c2391f 100644
--- a/ppgan/models/discriminators/__init__.py
+++ b/ppgan/models/discriminators/__init__.py
@@ -13,3 +13,4 @@
# limitations under the License.
from .nlayers import NLayerDiscriminator
+from .dcdiscriminator import DCDiscriminator
diff --git a/ppgan/models/discriminators/dcdiscriminator.py b/ppgan/models/discriminators/dcdiscriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..f66b49a84ebca5b5d14e6f2f00ff065e035973bd
--- /dev/null
+++ b/ppgan/models/discriminators/dcdiscriminator.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import functools
+import numpy as np
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from paddle.nn import BatchNorm2D
+from ...modules.norm import build_norm_layer
+
+from .builder import DISCRIMINATORS
+
+
+@DISCRIMINATORS.register()
+class DCDiscriminator(nn.Layer):
+ """Defines a DCGAN discriminator"""
+ def __init__(self, input_nc, ndf=64, norm_type='instance'):
+ """Construct a DCGAN discriminator
+
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ ndf (int) -- the number of filters in the last conv layer
+ norm_type (str) -- normalization layer type
+ """
+ super(DCDiscriminator, self).__init__()
+ norm_layer = build_norm_layer(norm_type)
+        if type(norm_layer) == functools.partial:
+            # no need to use bias as BatchNorm2D has affine parameters
+            use_bias = norm_layer.func != nn.BatchNorm2D
+        else:
+            use_bias = norm_layer != nn.BatchNorm2D
+
+ kw = 4
+ padw = 1
+
+ sequence = [
+ nn.Conv2D(input_nc,
+ ndf,
+ kernel_size=kw,
+ stride=2,
+ padding=padw,
+ bias_attr=use_bias),
+ nn.LeakyReLU(0.2)
+ ]
+
+ nf_mult = 1
+ nf_mult_prev = 1
+ n_downsampling = 4
+
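+        # with ndf = 64 the channel count doubles at each extra stride-2 conv:
+        # 64 -> 128 -> 256 -> 512 (capped at ndf * 8)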
+ for n in range(1, n_downsampling): # gradually increase the number of filters
+ nf_mult_prev = nf_mult
+ nf_mult = min(2**n, 8)
+ if norm_type == 'batch':
+ sequence += [
+ nn.Conv2D(ndf * nf_mult_prev,
+ ndf * nf_mult,
+ kernel_size=kw,
+ stride=2,
+ padding=padw),
+ BatchNorm2D(ndf * nf_mult),
+ nn.LeakyReLU(0.2)
+ ]
+ else:
+ sequence += [
+ nn.Conv2D(ndf * nf_mult_prev,
+ ndf * nf_mult,
+ kernel_size=kw,
+ stride=2,
+ padding=padw,
+ bias_attr=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2)
+ ]
+
+ nf_mult_prev = nf_mult
+
+ sequence += [
+ nn.Conv2D(ndf * nf_mult_prev,
+ 1,
+ kernel_size=kw,
+ stride=1,
+ padding=0)
+ ] # output 1 channel prediction map
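+        # for the 64x64 inputs produced by the config's Resize transform, the four
+        # stride-2 convs shrink the feature map 64 -> 32 -> 16 -> 8 -> 4 and the final
+        # 4x4 conv (stride 1, no padding) reduces it to a single (N, 1, 1, 1) score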
+
+ self.model = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ """Standard forward."""
+ return self.model(input)
diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py
index 6f6b433a35b579dfcce0e1760480eb2beac613af..e407542479e8fe1707b9977be2a18246d91a9654 100644
--- a/ppgan/models/generators/__init__.py
+++ b/ppgan/models/generators/__init__.py
@@ -16,3 +16,4 @@ from .resnet import ResnetGenerator
from .unet import UnetGenerator
from .rrdb_net import RRDBNet
from .makeup import GeneratorPSGANAttention
+from .dcgenerator import DCGenerator
diff --git a/ppgan/models/generators/dcgenerator.py b/ppgan/models/generators/dcgenerator.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bbdbb33e84358583d0d30f3596c0880e2a52711
--- /dev/null
+++ b/ppgan/models/generators/dcgenerator.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import functools
+
+from paddle.nn import BatchNorm2D
+from ...modules.norm import build_norm_layer
+
+from .builder import GENERATORS
+
+
+@GENERATORS.register()
+class DCGenerator(nn.Layer):
+ """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
+
+ code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
+ """
+ def __init__(self,
+ input_nz,
+ input_nc,
+ output_nc,
+ ngf=64,
+ norm_type='batch',
+ padding_type='reflect'):
+ """Construct a DCGenerator generator
+
+ Args:
+            input_nz (int)      -- the dimensionality of the input noise
+            input_nc (int)      -- the number of channels in input images
+            output_nc (int)     -- the number of channels in output images
+            ngf (int)           -- the number of filters in the last conv layer
+            norm_type (str)     -- the type of normalization layer used in conv layers
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
+ """
+ super(DCGenerator, self).__init__()
+
+ norm_layer = build_norm_layer(norm_type)
+        if type(norm_layer) == functools.partial:
+            # no need to use bias as BatchNorm2D has affine parameters
+            use_bias = norm_layer.func != nn.BatchNorm2D
+        else:
+            use_bias = norm_layer != nn.BatchNorm2D
+
+ mult = 8
+ n_downsampling = 4
+
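+        # the first transposed conv projects the (N, input_nz, 1, 1) noise to a
+        # 4x4 feature map with ngf * 8 channels (512 when ngf = 64)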
+ if norm_type == 'batch':
+ model = [
+ nn.Conv2DTranspose(input_nz,
+ ngf * mult,
+ kernel_size=4,
+ stride=1,
+ padding=0,
+ bias_attr=use_bias),
+ BatchNorm2D(ngf * mult),
+ nn.ReLU()
+ ]
+ else:
+ model = [
+ nn.Conv2DTranspose(input_nz,
+ ngf * mult,
+ kernel_size=4,
+ stride=1,
+ padding=0,
+ bias_attr=use_bias),
+ norm_layer(ngf * mult),
+ nn.ReLU()
+ ]
+
+ for i in range(1,n_downsampling): # add upsampling layers
+ mult = 2**(n_downsampling - i)
+ output_size = 2**(i+2)
+ if norm_type == 'batch':
+ model += [
+ nn.Conv2DTranspose(ngf * mult,
+ ngf * mult//2,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ bias_attr=use_bias),
+ BatchNorm2D(ngf * mult//2),
+ nn.ReLU()
+ ]
+ else:
+ model += [
+ nn.Conv2DTranspose(ngf * mult,
+ int(ngf * mult//2),
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ bias_attr=use_bias),
+ norm_layer(int(ngf * mult // 2)),
+ nn.ReLU()
+ ]
+
+ output_size = 2**(6)
+ model += [
+ nn.Conv2DTranspose(ngf ,
+ output_nc,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ bias_attr=use_bias),
+ nn.Tanh()
+ ]
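+        # the upsampling loop grows the feature map 4 -> 8 -> 16 -> 32; this last
+        # transposed conv produces the 64x64 output and Tanh bounds it to [-1, 1],
+        # consistent with the Normalize(mean=std=127.5) preprocessing in the config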
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, x):
+ """Standard forward"""
+ return self.model(x)
\ No newline at end of file