提交 07c143e6 编写于 作者: L LielinJiang

init commit

上级 22554686
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
\ No newline at end of file
# PaddleGAN
\ No newline at end of file
# PaddleGAN
This project is still under development.
## Train
```
python -u tools/main.py --config-file configs/cyclegan-cityscapes.yaml
```
Continue training from the last checkpoint:
```
python -u tools/main.py --config-file configs/cyclegan-cityscapes.yaml --resume your_checkpoint_path
```
Train on multiple GPUs:
```
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch tools/main.py --config-file configs/pix2pix-cityscapes.yaml
```
## Evaluate
```
python tools/main.py --config-file configs/cyclegan-cityscapes.yaml --evaluate-only --load your_weight_path
```
epochs: 200
isTrain: True
output_dir: output_dir
lambda_A: 10.0
lambda_B: 10.0
lambda_identity: 0.5
model:
name: CycleGANModel
defaults: &defaults
norm_type: instance
input_nc: 3
generator:
name: ResnetGenerator
output_nc: 3
n_blocks: 9
ngf: 64
use_dropout: False
<<: *defaults
discriminator:
name: NLayerDiscriminator
ndf: 64
n_layers: 3
<<: *defaults
gan_mode: lsgan
dataset:
train:
name: UnalignedDataset
dataroot: data/cityscapes
phase: train
max_dataset_size: inf
direction: AtoB
input_nc: 3
output_nc: 3
serial_batches: False
pool_size: 50
transform:
load_size: 286
crop_size: 256
preprocess: resize_and_crop
no_flip: False
test:
name: SingleDataset
dataroot: data/cityscapes/testB
max_dataset_size: inf
direction: BtoA
input_nc: 3
output_nc: 3
serial_batches: False
pool_size: 50
transform:
load_size: 256
crop_size: 256
preprocess: resize_and_crop
no_flip: True
optimizer:
name: Adam
beta1: 0.5
lr_scheduler:
name: linear
learning_rate: 0.0004
start_epoch: 100
decay_epochs: 100
log_config:
interval: 100
visiual_interval: 500
snapshot_config:
interval: 5
epochs: 200
isTrain: True
output_dir: output_dir
lambda_A: 10.0
lambda_B: 10.0
lambda_identity: 0.5
model:
name: CycleGANModel
defaults: &defaults
norm_type: instance
input_nc: 3
generator:
name: ResnetGenerator
output_nc: 3
n_blocks: 9
ngf: 64
use_dropout: False
<<: *defaults
discriminator:
name: NLayerDiscriminator
ndf: 64
n_layers: 3
<<: *defaults
gan_mode: lsgan
dataset:
train:
name: UnalignedDataset
dataroot: data/horse2zebra
phase: train
max_dataset_size: inf
direction: AtoB
input_nc: 3
output_nc: 3
serial_batches: False
pool_size: 50
transform:
load_size: 286
crop_size: 256
preprocess: resize_and_crop
no_flip: False
test:
name: SingleDataset
dataroot: data/horse2zebra/testA
max_dataset_size: inf
direction: AtoB
input_nc: 3
output_nc: 3
serial_batches: False
pool_size: 50
transform:
load_size: 256
crop_size: 256
preprocess: resize_and_crop
no_flip: True
optimizer:
name: Adam
beta1: 0.5
lr_scheduler:
name: linear
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
log_config:
interval: 100
visiual_interval: 500
snapshot_config:
interval: 5
epochs: 200
isTrain: True
output_dir: output_dir
lambda_L1: 100
model:
name: Pix2PixModel
generator:
name: UnetGenerator
norm_type: batch
input_nc: 3
output_nc: 3
num_downs: 8 #unet256
ngf: 64
use_dropout: False
discriminator:
name: NLayerDiscriminator
ndf: 64
n_layers: 3
input_nc: 6
norm_type: batch
gan_mode: vanilla
dataset:
train:
name: AlignedDataset
dataroot: data/cityscapes
phase: train
max_dataset_size: inf
direction: BtoA
input_nc: 3
output_nc: 3
serial_batches: False
pool_size: 0
transform:
load_size: 286
crop_size: 256
preprocess: resize_and_crop
no_flip: False
test:
name: AlignedDataset
dataroot: data/cityscapes/
phase: test
max_dataset_size: inf
direction: BtoA
input_nc: 3
output_nc: 3
serial_batches: True
pool_size: 50
transform:
load_size: 256
crop_size: 256
preprocess: resize_and_crop
no_flip: True
optimizer:
name: Adam
beta1: 0.5
lr_scheduler:
name: linear
learning_rate: 0.0004
start_epoch: 100
decay_epochs: 100
log_config:
interval: 100
visiual_interval: 500
snapshot_config:
interval: 5
epochs: 200
isTrain: True
output_dir: output_dir
lambda_L1: 100
model:
name: Pix2PixModel
generator:
name: UnetGenerator
norm_type: batch
input_nc: 3
output_nc: 3
num_downs: 8 #unet256
ngf: 64
use_dropout: False
discriminator:
name: NLayerDiscriminator
ndf: 64
n_layers: 3
input_nc: 6
norm_type: batch
gan_mode: vanilla
dataset:
train:
name: AlignedDataset
dataroot: data/cityscapes
phase: train
max_dataset_size: inf
direction: BtoA
input_nc: 3
output_nc: 3
serial_batches: False
pool_size: 0
transform:
load_size: 286
crop_size: 256
preprocess: resize_and_crop
no_flip: False
test:
name: AlignedDataset
dataroot: data/cityscapes/
phase: test
max_dataset_size: inf
direction: BtoA
input_nc: 3
output_nc: 3
serial_batches: True
pool_size: 50
transform:
load_size: 256
crop_size: 256
preprocess: resize_and_crop
no_flip: True
optimizer:
name: Adam
beta1: 0.5
lr_scheduler:
name: linear
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
log_config:
interval: 100
visiual_interval: 500
snapshot_config:
interval: 5
from .unaligned_dataset import UnalignedDataset
from .single_dataset import SingleDataset
from .aligned_dataset import AlignedDataset
import cv2
import paddle
import os.path
from .base_dataset import BaseDataset, get_params, get_transform
from .image_folder import make_dataset
from .builder import DATASETS
@DATASETS.register()
class AlignedDataset(BaseDataset):
    """Dataset of paired images stored side by side in a single file.

    Every image found under ``<dataroot>/<phase>`` is read with OpenCV and
    split down the middle: the left half is returned as ``A`` and the right
    half as ``B``.
    """

    def __init__(self, opt):
        """Collect the image paths and resolve the channel counts.

        Args:
            opt: dataset config. This class reads ``dataroot``, ``phase``,
                ``max_dataset_size``, ``direction``, ``input_nc``,
                ``output_nc`` and ``transform`` from it.
        """
        BaseDataset.__init__(self, opt)
        self.dir_AB = os.path.join(opt.dataroot, opt.phase)  # image directory
        self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size))
        # The crop window must fit inside the resized image.
        assert self.cfg.transform.load_size >= self.cfg.transform.crop_size
        btoa = self.cfg.direction == 'BtoA'
        # With direction 'BtoA' the roles of the two channel counts are swapped.
        self.input_nc = self.cfg.output_nc if btoa else self.cfg.input_nc
        self.output_nc = self.cfg.input_nc if btoa else self.cfg.output_nc

    def __getitem__(self, index):
        """Load one paired image and return its two transformed halves.

        Args:
            index: position of the image in ``self.AB_paths``.

        Returns:
            A tuple ``(A, B, index)``; ``index`` lets callers recover the file
            path through ``get_path_by_indexs``.

        Raises:
            IOError: if OpenCV cannot read the image file.
        """
        AB_path = self.AB_paths[index]
        AB = cv2.imread(AB_path)
        if AB is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError('Failed to read image: %s' % AB_path)

        # Split the combined image into its left (A) and right (B) halves.
        h, w = AB.shape[:2]
        w2 = int(w / 2)
        A = AB[:h, :w2, :]
        B = AB[:h, w2:, :]

        # Draw the crop/flip parameters once so both halves get the same ones.
        transform_params = get_params(self.cfg.transform, (w2, h))
        A_transform = get_transform(self.cfg.transform, transform_params,
                                    grayscale=(self.input_nc == 1))
        B_transform = get_transform(self.cfg.transform, transform_params,
                                    grayscale=(self.output_nc == 1))
        A = A_transform(A)
        B = B_transform(B)
        return A, B, index

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.AB_paths)

    def get_path_by_indexs(self, indexs):
        """Map a batch of sample indices back to their image paths.

        Args:
            indexs: an iterable of indices, or a ``paddle.Variable`` holding
                them (converted with ``.numpy()`` first).

        Returns:
            A list of paths, one per index, in the same order.
        """
        if isinstance(indexs, paddle.Variable):
            indexs = indexs.numpy()
        return [self.AB_paths[index] for index in indexs]
# code was heavily based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
import random
import numpy as np
from paddle.io import Dataset
from PIL import Image
import cv2
import paddle.incubate.hapi.vision.transforms as transforms
from .transforms import transforms as T
from abc import ABC, abstractmethod
class BaseDataset(Dataset, ABC):
    """Abstract parent of the dataset classes in this package.

    Subclasses must provide ``__len__`` and ``__getitem__``.
    """

    def __init__(self, cfg):
        """Keep the dataset config on the instance.

        Args:
            cfg: dataset config; stored as ``self.cfg``, and its ``dataroot``
                attribute is stored as ``self.root``.
        """
        self.cfg = cfg
        self.root = cfg.dataroot

    @abstractmethod
    def __len__(self):
        """Return the total number of images in the dataset."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return the sample stored at ``index``.

        Args:
            index: integer position of the requested sample.

        Returns:
            Whatever the concrete dataset defines; this base version
            returns None.
        """
        return None
def get_params(cfg, size):
    """Draw random crop and flip parameters for an image of the given size.

    Args:
        cfg: transform config providing ``preprocess``, ``load_size`` and
            ``crop_size``.
        size: ``(width, height)`` of the source image.

    Returns:
        A dict with ``'crop_pos'`` (an ``(x, y)`` tuple) and ``'flip'``
        (a bool that is True with probability about one half).
    """
    width, height = size
    mode = cfg.preprocess

    # Work out the image size after the resize step implied by the mode.
    if mode == 'resize_and_crop':
        scaled_w = scaled_h = cfg.load_size
    elif mode == 'scale_width_and_crop':
        scaled_w = cfg.load_size
        scaled_h = cfg.load_size * height // width
    else:
        scaled_w, scaled_h = width, height

    # The crop origin is clamped to 0 when the image is smaller than the crop.
    crop_x = random.randint(0, max(0, scaled_w - cfg.crop_size))
    crop_y = random.randint(0, max(0, scaled_h - cfg.crop_size))
    return {'crop_pos': (crop_x, crop_y), 'flip': random.random() > 0.5}
def get_transform(cfg, params=None, grayscale=False, method=cv2.INTER_CUBIC, convert=True):
    """Compose the preprocessing pipeline described by ``cfg``.

    Args:
        cfg: transform config providing ``preprocess``, ``load_size``,
            ``crop_size`` and ``no_flip``.
        params: optional dict from ``get_params``; when given, its
            ``'crop_pos'`` and ``'flip'`` entries replace the random choices.
        grayscale: not supported yet; a truthy value only prints a notice.
        method: interpolation flag handed to the resize transform.
        convert: when True, append the channel permutation and the
            normalization step.

    Returns:
        A ``transforms.Compose`` wrapping the selected steps.
    """
    steps = []
    mode = cfg.preprocess

    if grayscale:
        print('grayscale not support for now!!!')

    if 'resize' in mode:
        target = (cfg.load_size, cfg.load_size)
        steps.append(transforms.Resize(target, method))
    elif 'scale_width' in mode:
        print('scale_width not support for now!!!')

    if 'crop' in mode:
        if params is not None:
            # Fixed crop so that paired images are cut at the same position.
            steps.append(T.Crop(params['crop_pos'], cfg.crop_size))
        else:
            steps.append(T.RandomCrop(cfg.crop_size))

    if mode == 'none':
        print('preprocess not support for now!!!')

    if not cfg.no_flip:
        if params is None:
            steps.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            # Probability 1.0 turns the random flip into an unconditional one.
            steps.append(transforms.RandomHorizontalFlip(1.0))

    if convert:
        steps.append(transforms.Permute(to_rgb=True))
        steps.append(transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5)))

    return transforms.Compose(steps)
def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize ``img`` so that both sides are multiples of ``base``.

    Args:
        img: image exposing ``size`` as ``(width, height)`` and ``resize``.
        base: the multiple each side is rounded to.
        method: resampling filter passed to ``img.resize``.

    Returns:
        ``img`` itself when no adjustment is needed, otherwise the resized image.
    """
    orig_w, orig_h = img.size
    new_h = int(round(orig_h / base) * base)
    new_w = int(round(orig_w / base) * base)
    if (new_h, new_w) == (orig_h, orig_w):
        return img

    __print_size_warning(orig_w, orig_h, new_w, new_h)
    return img.resize((new_w, new_h), method)
def __scale_width(img, target_size, crop_size, method=Image.BICUBIC):
    """Resize ``img`` to width ``target_size``, scaling the height to match.

    The height follows the aspect ratio but is never smaller than
    ``crop_size``. The image is returned untouched when its width already
    equals ``target_size`` and its height is at least ``crop_size``.
    """
    orig_w, orig_h = img.size
    already_fits = orig_w == target_size and orig_h >= crop_size
    if already_fits:
        return img

    new_h = int(max(target_size * orig_h / orig_w, crop_size))
    return img.resize((target_size, new_h), method)
def __crop(img, pos, size):
    """Cut a ``size`` x ``size`` square out of ``img`` with its corner at ``pos``.

    Args:
        img: image exposing ``size`` as ``(width, height)`` and ``crop``.
        pos: ``(x, y)`` of the top-left corner of the crop window.
        size: side length of the square window.

    Returns:
        The cropped image, or ``img`` unchanged when neither side exceeds ``size``.
    """
    img_w, img_h = img.size
    if img_w <= size and img_h <= size:
        return img

    left, top = pos
    return img.crop((left, top, left + size, top + size))
def __flip(img, flip):
    """Return ``img`` mirrored left-to-right when ``flip`` is truthy, else ``img``."""
    return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img
def __print_size_warning(ow, oh, w, h):
    """Print a one-time notice that an image size was adjusted.

    A ``has_printed`` attribute set on the function itself records that the
    notice was already shown, so later calls stay silent.
    """
    if hasattr(__print_size_warning, 'has_printed'):
        return

    sizes = (ow, oh, w, h)
    print("The image size needs to be a multiple of 4. "
          "The loaded image size was (%d, %d), so it was adjusted to "
          "(%d, %d). This adjustment will be done to all images "
          "whose sizes are not multiples of 4" % sizes)
    __print_size_warning.has_printed = True
import paddle
import numbers
import numpy as np
from paddle.imperative import ParallelEnv
from paddle.incubate.hapi.distributed import DistributedBatchSampler
from ..utils.registry import Registry
# Registry that dataset classes add themselves to via ``@DATASETS.register()``.
DATASETS = Registry("DATASETS")


def build_dataloader(cfg, is_train=True):
    """Instantiate the dataset named in ``cfg`` and wrap it in a DataLoader.

    Args:
        cfg: dataset config; ``cfg.name`` selects the registered dataset
            class and the optional ``batch_size`` entry defaults to 1.
        is_train: when truthy, batches are shuffled and the last incomplete
            batch is dropped.

    Returns:
        A ``paddle.io.DataLoader`` driven by a ``DistributedBatchSampler``.
    """
    dataset_cls = DATASETS.get(cfg.name)
    dataset = dataset_cls(cfg)
    batch_size = cfg.get('batch_size', 1)

    # Multi-process runs use this process's device id; otherwise device 0.
    if ParallelEnv().nranks > 1:
        place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
    else:
        place = paddle.fluid.CUDAPlace(0)

    training = bool(is_train)
    batch_sampler = DistributedBatchSampler(dataset,
                                            batch_size=batch_size,
                                            shuffle=training,
                                            drop_last=training)

    return paddle.io.DataLoader(dataset,
                                batch_sampler=batch_sampler,
                                places=place,
                                num_workers=0)
\ No newline at end of file
"""A modified image folder class
We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both current directory and its subdirectories.
"""
from paddle.io import Dataset
from PIL import Image
import os
import os.path
# File-name suffixes (case-sensitive) that mark a file as an image.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.tif', '.TIF', '.tiff', '.TIFF',
]


def is_image_file(filename):
    """Return True when ``filename`` ends with one of ``IMG_EXTENSIONS``."""
    return filename.endswith(tuple(IMG_EXTENSIONS))


def make_dataset(dir, max_dataset_size=float("inf")):
    """Recursively list the image files below ``dir``.

    Args:
        dir: directory to walk, including its subdirectories.
        max_dataset_size: upper bound on the number of returned paths. Any
            value accepted by ``float()`` works, including ``float("inf")``
            and the string ``"inf"``.

    Returns:
        A list of image paths, truncated to at most ``max_dataset_size`` entries.

    Raises:
        AssertionError: if ``dir`` is not an existing directory.
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                images.append(os.path.join(root, fname))

    limit = float(max_dataset_size)
    if limit >= len(images):
        return images
    # A slice bound must be an int; the previous float bound raised TypeError
    # for every finite limit below the number of images.
    return images[:max(0, int(limit))]
def default_loader(path):
    """Open the image at ``path`` and return it converted to RGB.

    The file is opened in a ``with`` block so the underlying handle is closed
    as soon as the converted copy has been produced.
    """
    with Image.open(path) as img:
        return img.convert('RGB')
class ImageFolder(Dataset):
    """Dataset over every image found below a root directory."""

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        """Scan ``root`` for images and store the loading options.

        Args:
            root: directory searched (recursively) by ``make_dataset``.
            transform: optional callable applied to each loaded image.
            return_paths: when True, ``__getitem__`` also returns the path.
            loader: callable that turns a path into an image.

        Raises:
            RuntimeError: if no image is found under ``root``.
        """
        found = make_dataset(root)
        if not found:
            message = ("Found 0 images in: " + root + "\n"
                       "Supported image extensions are: " +
                       ",".join(IMG_EXTENSIONS))
            raise RuntimeError(message)

        self.root = root
        self.imgs = found
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        """Load and optionally transform the image at ``index``.

        Returns:
            ``(img, path)`` when ``return_paths`` is set, otherwise ``img``.
        """
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return (img, path) if self.return_paths else img

    def __len__(self):
        """Return the number of images that were found."""
        return len(self.imgs)
import cv2
import paddle
from .base_dataset import BaseDataset, get_transform
from .image_folder import make_dataset
from .builder import DATASETS
@DATASETS.register()
class SingleDataset(BaseDataset):
    """Dataset that serves images from a single directory."""

    def __init__(self, cfg):
        """Collect the image paths and build the transform pipeline.

        Args:
            cfg: dataset config. This class reads ``dataroot``,
                ``max_dataset_size``, ``direction``, ``input_nc``,
                ``output_nc`` and ``transform`` from it.
        """
        BaseDataset.__init__(self, cfg)
        self.A_paths = sorted(make_dataset(cfg.dataroot, cfg.max_dataset_size))
        # With direction 'BtoA' the configured output channels are the input.
        input_nc = self.cfg.output_nc if self.cfg.direction == 'BtoA' else self.cfg.input_nc
        self.transform = get_transform(cfg.transform, grayscale=(input_nc == 1))

    def __getitem__(self, index):
        """Load and transform the image at ``index``.

        Args:
            index: position of the image in ``self.A_paths``.

        Returns:
            A tuple ``(A, index)``; ``index`` lets callers recover the file
            path through ``get_path_by_indexs``.

        Raises:
            IOError: if OpenCV cannot read the image file.
        """
        A_path = self.A_paths[index]
        A_img = cv2.imread(A_path)
        if A_img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError('Failed to read image: %s' % A_path)
        A = self.transform(A_img)
        return (A, index)

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.A_paths)

    def get_path_by_indexs(self, indexs):
        """Map a batch of sample indices back to their image paths.

        Args:
            indexs: an iterable of indices, or a ``paddle.Variable`` holding
                them (converted with ``.numpy()`` first).

        Returns:
            A list of paths, one per index, in the same order.
        """
        if isinstance(indexs, paddle.Variable):
            indexs = indexs.numpy()
        return [self.A_paths[index] for index in indexs]
import random
class RandomCrop(object):
    """Crop a randomly positioned window out of an ``(H, W, C)`` array."""

    def __init__(self, output_size):
        """Store the window size.

        Args:
            output_size: an int for a square window, or a ``(height, width)``
                pair that is stored as given.
        """
        self.output_size = (
            (output_size, output_size) if isinstance(output_size, int) else output_size
        )

    def _get_params(self, img):
        """Pick the crop window for ``img``.

        Returns:
            ``(top, left, height, width)`` of the window. When the image
            already has the target size the window covers the whole image.
        """
        img_h, img_w, _ = img.shape
        crop_h, crop_w = self.output_size
        if (img_w, img_h) == (crop_w, crop_h):
            return 0, 0, img_h, img_w

        top = random.randint(0, img_h - crop_h)
        left = random.randint(0, img_w - crop_w)
        return top, left, crop_h, crop_w

    def __call__(self, img):
        """Return the randomly chosen window of ``img``."""
        top, left, height, width = self._get_params(img)
        return img[top:top + height, left:left + width]
class Crop():
    """Crop a fixed square window out of an ``(H, W, C)`` array."""

    def __init__(self, pos, size):
        """Store the window.

        Args:
            pos: ``(x, y)`` of the top-left corner of the window.
            size: side length of the square window.
        """
        self.pos = pos
        self.size = size

    def __call__(self, img):
        """Return the window of ``img``, or ``img`` itself if it is not larger
        than the window in either dimension."""
        img_h, img_w, _ = img.shape
        left, top = self.pos
        side = self.size
        if img_w <= side and img_h <= side:
            return img
        return img[top: top + side, left: left + side]
\ No newline at end of file
import cv2
import random
import os.path
from .base_dataset import BaseDataset, get_transform
from .image_folder import make_dataset
from .builder import DATASETS
@DATASETS.register()
class UnalignedDataset(BaseDataset):
    """Dataset of unpaired images drawn from two separate directories.

    Images are read from ``<dataroot>/<phase>A`` and ``<dataroot>/<phase>B``.
    """

    def __init__(self, cfg):
        """Collect the image paths of both domains and build the transforms.

        Args:
            cfg: dataset config. This class reads ``dataroot``, ``phase``,
                ``max_dataset_size``, ``direction``, ``input_nc``,
                ``output_nc``, ``serial_batches`` and ``transform`` from it.
        """
        BaseDataset.__init__(self, cfg)
        self.dir_A = os.path.join(cfg.dataroot, cfg.phase + 'A')  # e.g. '/path/to/data/trainA'
        self.dir_B = os.path.join(cfg.dataroot, cfg.phase + 'B')  # e.g. '/path/to/data/trainB'

        self.A_paths = sorted(make_dataset(self.dir_A, cfg.max_dataset_size))
        self.B_paths = sorted(make_dataset(self.dir_B, cfg.max_dataset_size))
        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)

        btoA = self.cfg.direction == 'BtoA'
        # With direction 'BtoA' the roles of the two channel counts are swapped.
        input_nc = self.cfg.output_nc if btoA else self.cfg.input_nc
        output_nc = self.cfg.input_nc if btoA else self.cfg.output_nc
        self.transform_A = get_transform(self.cfg.transform, grayscale=(input_nc == 1))
        self.transform_B = get_transform(self.cfg.transform, grayscale=(output_nc == 1))

        self.reset_paths()

    def reset_paths(self):
        """Reset ``self.path_dict`` to an empty dict."""
        self.path_dict = {}

    def __getitem__(self, index):
        """Return one image from each domain.

        Args:
            index (int): sample index; wrapped modulo the size of domain A.

        Returns:
            A tuple ``(A, B)`` of transformed images.

        Raises:
            IOError: if OpenCV cannot read one of the image files.
        """
        A_path = self.A_paths[index % self.A_size]  # keep the index in range
        if self.cfg.serial_batches:
            index_B = index % self.B_size
        else:
            # Pick B at random so that A and B do not form fixed pairs.
            index_B = random.randint(0, self.B_size - 1)
        B_path = self.B_paths[index_B]

        A_img = cv2.imread(A_path)
        if A_img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError('Failed to read image: %s' % A_path)
        B_img = cv2.imread(B_path)
        if B_img is None:
            raise IOError('Failed to read image: %s' % B_path)

        A = self.transform_A(A_img)
        B = self.transform_B(B_img)
        return A, B

    def __len__(self):
        """Return the dataset length.

        The two domains may hold different numbers of images, so the larger
        of the two counts is used.
        """
        return max(self.A_size, self.B_size)
import os
import time
import logging
from paddle.imperative import ParallelEnv
from ..datasets.builder import build_dataloader
from ..models.builder import build_model
from ..utils.visual import tensor2img, save_image
from ..utils.filesystems import save, load, makedirs
class Trainer:
    """Drive training, testing, logging, visualization and checkpointing."""

    def __init__(self, cfg):
        """Build the train dataloader and the model from ``cfg``.

        Args:
            cfg: experiment config. This class reads ``dataset``,
                ``optimizer``, ``output_dir``, ``epochs``,
                ``snapshot_config.interval``, ``log_config.interval``,
                ``log_config.visiual_interval`` and ``isTrain``.
        """
        # build train dataloader
        self.train_dataloader = build_dataloader(cfg.dataset.train)

        if 'lr_scheduler' in cfg.optimizer:
            # Publish the number of iterations per epoch through the config.
            cfg.optimizer.lr_scheduler.step_per_epoch = len(self.train_dataloader)

        # build model
        self.model = build_model(cfg)

        self.logger = logging.getLogger(__name__)

        # base config
        self.output_dir = cfg.output_dir
        self.epochs = cfg.epochs
        self.start_epoch = 0
        self.current_epoch = 0
        self.batch_id = 0
        self.weight_interval = cfg.snapshot_config.interval
        self.log_interval = cfg.log_config.interval
        self.visual_interval = cfg.log_config.visiual_interval

        self.cfg = cfg

        self.local_rank = ParallelEnv().local_rank

    def train(self):
        """Run the training loop from ``start_epoch`` up to ``epochs``."""
        for epoch in range(self.start_epoch, self.epochs):
            start_time = time.time()
            self.current_epoch = epoch
            for i, data in enumerate(self.train_dataloader):
                self.batch_id = i
                # unpack data from dataset and apply preprocessing
                self.model.set_input(data)
                self.model.optimize_parameters()

                if i % self.log_interval == 0:
                    self.print_log()

                if i % self.visual_interval == 0:
                    self.visual('visual_train')

            self.logger.info('train one epoch time: {}'.format(time.time() - start_time))
            # Weights-only snapshots are kept (keep=-1); the resumable
            # checkpoint is written every epoch and replaces the previous one.
            if epoch % self.weight_interval == 0:
                self.save(epoch, 'weight', keep=-1)
            self.save(epoch)

    def test(self):
        """Run the model over the test set and save the resulting images.

        Raises:
            NotImplementedError: if the test dataset is neither
                ``AlignedDataset`` nor ``SingleDataset``.
        """
        if not hasattr(self, 'test_dataloader'):
            self.test_dataloader = build_dataloader(self.cfg.dataset.test, is_train=False)

        # data[0]: img, data[1]: img path index
        # test batch size must be 1
        for i, data in enumerate(self.test_dataloader):
            self.batch_id = i
            # FIXME: dataloader not support map input, hard code now!!!
            if self.cfg.dataset.test.name == 'AlignedDataset':
                if self.cfg.dataset.test.direction == 'BtoA':
                    fake = self.model.test(data[1])
                else:
                    fake = self.model.test(data[0])
            elif self.cfg.dataset.test.name == 'SingleDataset':
                fake = self.model.test(data[0])
            else:
                # Without this branch ``fake`` would be undefined (or stale
                # from the previous batch) for any other dataset type.
                raise NotImplementedError(
                    'test() does not support dataset: %s' % self.cfg.dataset.test.name)

            current_paths = self.test_dataloader.dataset.get_path_by_indexs(data[-1])

            visual_results = {}
            for j in range(len(current_paths)):
                name = os.path.basename(current_paths[j])
                name = os.path.splitext(name)[0]
                visual_results.update({name + '_fakeB': fake[j]})
                # NOTE(review): realA/realB store data[1]/data[0] as whole
                # batch entries; for SingleDataset data[1] is the index
                # rather than an image. Confirm the intended behaviour.
                visual_results.update({name + '_realA': data[1]})
                visual_results.update({name + '_realB': data[0]})

            self.visual('visual_test', visual_results=visual_results)

            if i % self.log_interval == 0:
                self.logger.info('Test iter: [%d/%d]' % (i, len(self.test_dataloader)))

    def print_log(self):
        """Log the current epoch, iteration, learning rate and losses."""
        losses = self.model.get_current_losses()
        message = 'Epoch: %d, iters: %d ' % (self.current_epoch, self.batch_id)

        message += '%s: %.6f ' % ('lr', self.current_learning_rate)

        for k, v in losses.items():
            message += '%s: %.3f ' % (k, v)

        # print the message
        self.logger.info(message)

    @property
    def current_learning_rate(self):
        """Learning rate reported by the model's first optimizer."""
        return self.model.optimizers[0].current_step_lr()

    def visual(self, results_dir, visual_results=None):
        """Save images as PNG files under ``<output_dir>/<results_dir>``.

        Args:
            results_dir: sub-directory of ``output_dir`` to write into.
            visual_results: mapping from label to image; when None, the
                model's current visuals are used.
        """
        self.model.compute_visuals()

        if visual_results is None:
            visual_results = self.model.get_current_visuals()

        # Training images are prefixed with the epoch number.
        if self.cfg.isTrain:
            msg = 'epoch%.3d_' % self.current_epoch
        else:
            msg = ''

        makedirs(os.path.join(self.output_dir, results_dir))
        for label, image in visual_results.items():
            image_numpy = tensor2img(image)
            img_path = os.path.join(self.output_dir, results_dir, msg + '%s.png' % (label))
            save_image(image_numpy, img_path)

    def save(self, epoch, name='checkpoint', keep=1):
        """Write network (and optionally optimizer) state to ``output_dir``.

        Only the process with ``local_rank == 0`` writes anything.

        Args:
            epoch: epoch number used in the file name.
            name: ``'weight'`` stores network states only; ``'checkpoint'``
                also stores the epoch and the optimizer states.
            keep: for checkpoints, when positive, the file written ``keep``
                epochs earlier is removed.
        """
        if self.local_rank != 0:
            return

        assert name in ['checkpoint', 'weight']

        state_dicts = {}
        save_filename = 'epoch_%s_%s.pkl' % (epoch, name)
        save_path = os.path.join(self.output_dir, save_filename)
        for net_name in self.model.model_names:
            if isinstance(net_name, str):
                net = getattr(self.model, 'net' + net_name)
                state_dicts['net' + net_name] = net.state_dict()

        if name == 'weight':
            save(state_dicts, save_path)
            return

        state_dicts['epoch'] = epoch

        for opt_name in self.model.optimizer_names:
            if isinstance(opt_name, str):
                opt = getattr(self.model, opt_name)
                state_dicts[opt_name] = opt.state_dict()

        save(state_dicts, save_path)

        if keep > 0:
            # Best-effort cleanup: a failed removal is logged, not raised.
            try:
                checkpoint_name_to_be_removed = os.path.join(
                    self.output_dir, 'epoch_%s_%s.pkl' % (epoch - keep, name))
                if os.path.exists(checkpoint_name_to_be_removed):
                    os.remove(checkpoint_name_to_be_removed)
            except OSError as e:
                self.logger.info('remove old checkpoints error: {}'.format(e))

    def resume(self, checkpoint_path):
        """Restore networks, optimizers and the start epoch from a checkpoint."""
        state_dicts = load(checkpoint_path)
        if state_dicts.get('epoch', None) is not None:
            # Continue with the epoch after the one that was saved.
            self.start_epoch = state_dicts['epoch'] + 1

        for name in self.model.model_names:
            if isinstance(name, str):
                net = getattr(self.model, 'net' + name)
                net.set_dict(state_dicts['net' + name])

        for name in self.model.optimizer_names:
            if isinstance(name, str):
                opt = getattr(self.model, name)
                opt.set_dict(state_dicts[name])

    def load(self, weight_path):
        """Restore only the network weights from ``weight_path``."""
        state_dicts = load(weight_path)

        for name in self.model.model_names:
            if isinstance(name, str):
                net = getattr(self.model, 'net' + name)
                net.set_dict(state_dicts['net' + name])
\ No newline at end of file
from .base_model import BaseModel
from .cycle_gan_model import CycleGANModel
from .pix2pix_model import Pix2PixModel
# code was heavily based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
import os
import paddle
import numpy as np
from collections import OrderedDict
from abc import ABC, abstractmethod
class BaseModel(ABC):
    """This class is an abstract base class (ABC) for models.

    To create a subclass, you need to implement the following functions:
        -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>: unpack data from dataset and apply preprocessing.
        -- <forward>: produce intermediate results.
        -- <optimize_parameters>: calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt -- experiment config; ``isTrain``, ``output_dir`` and
                ``model.name`` are read here.

        A subclass should call <BaseModel.__init__(self, opt)> first and then
        fill in these lists:
            -- self.loss_names (str list): losses to report; each name maps to the attribute 'loss_<name>'.
            -- self.model_names (str list): networks in use; each name maps to the attribute 'net<name>'.
            -- self.visual_names (str list): attribute names of the images to display and save.
            -- self.optimizers (optimizer list): the optimizers in use.
        """
        self.opt = opt
        self.isTrain = opt.isTrain
        self.save_dir = os.path.join(opt.output_dir, opt.model.name)  # save all the checkpoints to save_dir

        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.optimizer_names = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase.

        Returns:
            the parser; this base implementation returns it unchanged.
        """
        return parser

    @abstractmethod
    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input: the data itself and its metadata information.
        """
        pass

    @abstractmethod
    def forward(self):
        """Run forward pass."""
        pass

    @abstractmethod
    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        pass

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    def test(self, input):
        """Forward function used in test time.

        Runs <forward_test> on ``input`` inside no_grad() so no intermediate
        steps are saved for backprop, then calls <compute_visuals>.

        Returns:
            whatever the subclass's <forward_test> returns.
        """
        with paddle.imperative.no_grad():
            # The input was previously dropped, which made this call fail
            # for any forward_test(input) implementation.
            output = self.forward_test(input)
            self.compute_visuals()
        return output

    def compute_visuals(self):
        """Calculate additional output images for visualization (no-op by default)."""
        pass

    def get_image_paths(self):
        """Return image paths that are used to load current data"""
        return self.image_paths

    def get_current_visuals(self):
        """Return an OrderedDict mapping each visual name to the attribute of that name."""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        """Return an OrderedDict mapping each loss name to 'loss_<name>' as a float."""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                # float(...) works for both scalar tensor and float number
                errors_ret[name] = float(getattr(self, 'loss_' + name))
        return errors_ret

    def set_requires_grad(self, nets, requires_grad=False):
        """Set the ``trainable`` flag on every parameter of the given networks.

        Parameters:
            nets (network list)   -- a network or a list of networks; None entries are skipped
            requires_grad (bool)  -- value assigned to each parameter's ``trainable``
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.trainable = requires_grad
import paddle
from ..utils.registry import Registry
# Registry that model classes add themselves to via ``@MODELS.register()``.
MODELS = Registry("MODEL")


def build_model(cfg):
    """Instantiate the model class registered under ``cfg.model.name``.

    Args:
        cfg: full experiment config; it is passed unchanged to the model
            constructor.

    Returns:
        The constructed model instance.
    """
    model_cls = MODELS.get(cfg.model.name)
    return model_cls(cfg)
\ No newline at end of file
import paddle
from .base_model import BaseModel
from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .losses import GANLoss
# from ..modules.nn import L1Loss
from ..solver import build_optimizer
from ..utils.image_pool import ImagePool
@MODELS.register()
class CycleGANModel(BaseModel):
"""
This class implements the CycleGAN model, for learning image-to-image translation without paired data.
The model training requires '--dataset_mode unaligned' dataset.
By default, it uses a '--netG resnet_9blocks' ResNet generator,
a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
and a least-square GANs objective ('--gan_mode lsgan').
CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
"""
def __init__(self, opt):
    """Initialize the CycleGAN class.

    Parameters:
        opt -- experiment config; the model, dataset, optimizer and
            lambda_identity entries are read here.
    """
    BaseModel.__init__(self, opt)

    # Losses reported through <BaseModel.get_current_losses>.
    self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']

    # Images reported through <BaseModel.get_current_visuals>.
    visuals_a = ['real_A', 'fake_B', 'rec_A']
    visuals_b = ['real_B', 'fake_A', 'rec_B']
    if self.isTrain and self.opt.lambda_identity > 0.0:
        # The identity outputs are shown only when the identity loss is used.
        visuals_a.append('idt_B')
        visuals_b.append('idt_A')
    self.visual_names = visuals_a + visuals_b

    # Networks to save and load; discriminators are only listed for training.
    self.model_names = ['G_A', 'G_B', 'D_A', 'D_B'] if self.isTrain else ['G_A', 'G_B']

    # The naming is different from those used in the paper.
    # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
    self.netG_A = build_generator(opt.model.generator)
    self.netG_B = build_generator(opt.model.generator)

    if not self.isTrain:
        return

    # Everything below is needed for training only.
    self.netD_A = build_discriminator(opt.model.discriminator)
    self.netD_B = build_discriminator(opt.model.discriminator)

    if opt.lambda_identity > 0.0:
        # The identity loss requires equal input and output channel counts.
        assert(opt.dataset.train.input_nc == opt.dataset.train.output_nc)

    # Buffers of previously generated images.
    self.fake_A_pool = ImagePool(opt.dataset.train.pool_size)
    self.fake_B_pool = ImagePool(opt.dataset.train.pool_size)

    # Loss functions.
    self.criterionGAN = GANLoss(opt.model.gan_mode, [[[[1.0]]]], [[[[0.0]]]])
    self.criterionCycle = paddle.nn.L1Loss()
    self.criterionIdt = paddle.nn.L1Loss()

    # One optimizer for both generators, one for both discriminators.
    generator_params = self.netG_A.parameters() + self.netG_B.parameters()
    self.optimizer_G = build_optimizer(opt.optimizer, parameter_list=generator_params)
    discriminator_params = self.netD_A.parameters() + self.netD_B.parameters()
    self.optimizer_D = build_optimizer(opt.optimizer, parameter_list=discriminator_params)

    self.optimizers.append(self.optimizer_G)
    self.optimizers.append(self.optimizer_D)
    self.optimizer_names.extend(['optimizer_G', 'optimizer_D'])
def set_input(self, input):
    """Store one dataloader batch as ``self.real_A`` and ``self.real_B``.

    Parameters:
        input: indexable batch; elements 0 and 1 are read (presumably the
            domain-A and domain-B images -- confirm against the dataset).

    When ``opt.dataset.train.direction`` is anything other than ``'AtoB'``
    the two elements are swapped.
    """
    forward_direction = self.opt.dataset.train.direction == 'AtoB'
    a_index, b_index = (0, 1) if forward_direction else (1, 0)
    self.real_A = paddle.imperative.to_variable(input[a_index])
    self.real_B = paddle.imperative.to_variable(input[b_index])
def forward(self):
    """Run both generators over the current batch.

    Reads ``real_A`` / ``real_B`` and stores the translated images
    (``fake_B``, ``fake_A``) and the cycle reconstructions
    (``rec_A``, ``rec_B``) on ``self``.
    """
    a_to_b = self.netG_A
    b_to_a = self.netG_B

    # A -> B -> A
    self.fake_B = a_to_b(self.real_A)
    self.rec_A = b_to_a(self.fake_B)

    # B -> A -> B
    self.fake_A = b_to_a(self.real_B)
    self.rec_B = a_to_b(self.fake_A)
def forward_test(self, input):
    """Translate ``input`` with the generator selected by the test direction.

    The generator attribute is ``netG_`` followed by the first character of
    ``opt.dataset.test.direction``.
    """
    variable = paddle.imperative.to_variable(input)
    generator_name = 'netG_' + self.opt.dataset.test.direction[0]
    generator = getattr(self, generator_name)
    return generator(variable)
def test(self, input):
    """Inference entry point: run ``forward_test`` inside ``no_grad``.

    Returns whatever ``forward_test`` returns for ``input``.
    """
    with paddle.imperative.no_grad():
        output = self.forward_test(input)
        return output
def backward_D_basic(self, netD, real, fake):
    """Compute one discriminator's GAN loss and back-propagate it.

    Parameters:
        netD -- the discriminator to evaluate
        real -- real images
        fake -- generated images; ``fake.detach()`` is what is fed to ``netD``

    Returns:
        The loss: the average of the real-target and fake-target terms.
        ``backward()`` has already been called on it.
    """
    real_term = self.criterionGAN(netD(real), True)
    fake_term = self.criterionGAN(netD(fake.detach()), False)

    loss_D = (real_term + fake_term) * 0.5
    loss_D.backward()
    return loss_D
def backward_D_A(self):
    """Compute and back-propagate the loss of discriminator ``netD_A``.

    The fake sample is obtained by querying ``fake_B_pool`` with the current
    ``fake_B``; the result is stored in ``self.loss_D_A``.
    """
    pooled_fake = self.fake_B_pool.query(self.fake_B)
    self.loss_D_A = self.backward_D_basic(
        self.netD_A,
        self.real_B,
        pooled_fake,
    )
def backward_D_B(self):
    """Compute and back-propagate the loss of discriminator ``netD_B``.

    The fake sample is obtained by querying ``fake_A_pool`` with the current
    ``fake_A``; the result is stored in ``self.loss_D_B``.
    """
    pooled_fake = self.fake_A_pool.query(self.fake_A)
    self.loss_D_B = self.backward_D_basic(
        self.netD_B,
        self.real_A,
        pooled_fake,
    )
def backward_G(self):
    """Compute the combined generator loss and back-propagate it.

    The total is the sum of two adversarial terms, two cycle-consistency
    terms and (when ``opt.lambda_identity > 0``) two identity terms. Every
    term is stored on ``self`` as ``loss_*``; the sum is ``self.loss_G``.
    """
    options = self.opt
    lambda_idt = options.lambda_identity
    lambda_A = options.lambda_A
    lambda_B = options.lambda_B

    if lambda_idt > 0:
        # Feeding a generator an image from its target domain should leave
        # it unchanged: ||G_A(B) - B|| and ||G_B(A) - A||.
        self.idt_A = self.netG_A(self.real_B)
        self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
        self.idt_B = self.netG_B(self.real_A)
        self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
    else:
        self.loss_idt_A = 0
        self.loss_idt_B = 0

    # Adversarial terms: D_A(G_A(A)) and D_B(G_B(B)) scored against "real".
    self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
    self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)

    # Cycle terms: ||G_B(G_A(A)) - A|| and ||G_A(G_B(B)) - B||.
    self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
    self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B

    # Accumulate left to right, starting from the first adversarial term.
    total = self.loss_G_A
    for term in (self.loss_G_B, self.loss_cycle_A, self.loss_cycle_B,
                 self.loss_idt_A, self.loss_idt_B):
        total = total + term
    self.loss_G = total
    self.loss_G.backward()
def optimize_parameters(self):
    """Run one training iteration: forward pass, then update Gs, then Ds."""
    self.forward()  # fake and reconstructed images for this batch

    discriminators = (self.netD_A, self.netD_B)

    # --- generators G_A and G_B ---
    # The discriminators only score the generators here, so they need no
    # gradients of their own.
    self.set_requires_grad(list(discriminators), False)
    self.optimizer_G.clear_gradients()
    self.backward_G()
    self.optimizer_G.minimize(self.loss_G)

    # --- discriminators D_A and D_B ---
    self.set_requires_grad(list(discriminators), True)
    self.optimizer_D.clear_gradients()
    self.backward_D_A()
    self.backward_D_B()
    combined_d_loss = self.loss_D_A + self.loss_D_B
    self.optimizer_D.minimize(combined_d_loss)
from .nlayers import NLayerDiscriminator
\ No newline at end of file
import copy
from ...utils.registry import Registry
DISCRIMINATORS = Registry("DISCRIMINATOR")
def build_discriminator(cfg):
    """Instantiate the discriminator described by ``cfg``.

    ``cfg['name']`` selects the class from the ``DISCRIMINATORS`` registry;
    every remaining entry is passed to its constructor as a keyword
    argument. ``cfg`` itself is left untouched (a deep copy is consumed).
    """
    kwargs = copy.deepcopy(cfg)
    discriminator_cls = DISCRIMINATORS.get(kwargs.pop('name'))
    return discriminator_cls(**kwargs)
import paddle
import functools
import numpy as np
import paddle.nn as nn
from ...modules.nn import ReflectionPad2d, LeakyReLU, Tanh, Dropout, BCEWithLogitsLoss, Conv2DTranspose, Conv2D, Pad2D, MSELoss
from ...modules.norm import build_norm_layer
from .builder import DISCRIMINATORS
@DISCRIMINATORS.register()
class NLayerDiscriminator(paddle.fluid.dygraph.Layer):
    """PatchGAN discriminator: a stack of 4x4 convolutions ending in a 1-channel map."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_type='instance'):
        """Build the convolution stack.

        Args:
            input_nc (int)  -- number of channels in the input images
            ndf (int)       -- number of filters produced by the first conv layer
            n_layers (int)  -- number of conv stages before the final 1-channel conv
            norm_type (str) -- normalization name passed to ``build_norm_layer``
        """
        super(NLayerDiscriminator, self).__init__()
        norm_layer = build_norm_layer(norm_type)
        # Convs followed by a norm layer only carry a bias for InstanceNorm.
        norm_cls = norm_layer.func if type(norm_layer) is functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm

        kernel = 4
        pad = 1
        layers = [
            Conv2D(input_nc, ndf, filter_size=kernel, stride=2, padding=pad),
            LeakyReLU(0.2, True),
        ]

        # Stride-2 stages double the width each time, capped at 8 * ndf;
        # they are followed by one stride-1 stage.
        in_mult = 1
        stages = []
        for depth in range(1, n_layers):
            out_mult = min(2 ** depth, 8)
            stages.append((in_mult, out_mult, 2))
            in_mult = out_mult
        out_mult = min(2 ** n_layers, 8)
        stages.append((in_mult, out_mult, 1))

        for prev_mult, next_mult, stride in stages:
            layers.append(Conv2D(ndf * prev_mult, ndf * next_mult, filter_size=kernel,
                                 stride=stride, padding=pad, bias_attr=use_bias))
            layers.append(norm_layer(ndf * next_mult))
            layers.append(LeakyReLU(0.2, True))

        # Final projection to a single-channel prediction map.
        layers.append(Conv2D(ndf * out_mult, 1, filter_size=kernel, stride=1, padding=pad))
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Return the discriminator's prediction map for ``input``."""
        return self.model(input)
\ No newline at end of file
from .resnet import ResnetGenerator
from .unet import UnetGenerator
\ No newline at end of file
import copy
from ...utils.registry import Registry
GENERATORS = Registry("GENERATOR")
def build_generator(cfg):
    """Instantiate the generator described by ``cfg``.

    ``cfg['name']`` selects the class from the ``GENERATORS`` registry; every
    remaining entry is passed to its constructor as a keyword argument.
    ``cfg`` itself is left untouched (a deep copy is consumed).
    """
    kwargs = copy.deepcopy(cfg)
    generator_cls = GENERATORS.get(kwargs.pop('name'))
    return generator_cls(**kwargs)
import paddle
import paddle.nn as nn
import functools
from ...modules.nn import ReflectionPad2d, LeakyReLU, Tanh, Dropout, BCEWithLogitsLoss, Conv2DTranspose, Conv2D, Pad2D, MSELoss
from ...modules.norm import build_norm_layer
from .builder import GENERATORS
@GENERATORS.register()
class ResnetGenerator(paddle.fluid.dygraph.Layer):
    """Generator made of downsampling convs, a run of Resnet blocks, and upsampling convs.

    Code and idea from Justin Johnson's neural style transfer project
    (https://github.com/jcjohnson/fast-neural-style).
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_type='instance', use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Assemble the generator.

        Args:
            input_nc (int)      -- number of channels in the input images
            output_nc (int)     -- number of channels in the output images
            ngf (int)           -- number of filters produced by the first conv layer
            norm_type (str)     -- normalization name passed to ``build_norm_layer``
            use_dropout (bool)  -- whether the Resnet blocks include dropout
            n_blocks (int)      -- number of Resnet blocks
            padding_type (str)  -- padding used inside the Resnet blocks: reflect | replicate | zero
        """
        assert n_blocks >= 0
        super(ResnetGenerator, self).__init__()

        norm_layer = build_norm_layer(norm_type)
        norm_cls = norm_layer.func if type(norm_layer) is functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm
        print('norm layer:', norm_layer, 'use bias:', use_bias)

        # Stem: 7x7 conv on a reflection-padded input.
        layers = []
        layers.append(ReflectionPad2d(3))
        layers.append(nn.Conv2D(input_nc, ngf, filter_size=7, padding=0, bias_attr=use_bias))
        layers.append(norm_layer(ngf))
        layers.append(nn.ReLU())

        n_downsampling = 2

        # Downsampling: each stage halves the resolution and doubles the width.
        for stage in range(n_downsampling):
            width = ngf * (2 ** stage)
            layers.append(nn.Conv2D(width, width * 2, filter_size=3, stride=2, padding=1, bias_attr=use_bias))
            layers.append(norm_layer(width * 2))
            layers.append(nn.ReLU())

        # Residual trunk at the narrowest resolution.
        trunk_width = ngf * (2 ** n_downsampling)
        for _ in range(n_blocks):
            layers.append(ResnetBlock(trunk_width, padding_type=padding_type, norm_layer=norm_layer,
                                      use_dropout=use_dropout, use_bias=use_bias))

        # Upsampling: transposed conv followed by one extra row/column of
        # constant padding, then norm and ReLU.
        for stage in range(n_downsampling):
            width = ngf * (2 ** (n_downsampling - stage))
            half_width = int(width / 2)
            layers.append(nn.Conv2DTranspose(width, half_width,
                                             filter_size=3, stride=2,
                                             padding=1,
                                             bias_attr=use_bias))
            layers.append(Pad2D(paddings=[0, 1, 0, 1], mode='constant', pad_value=0.0))
            layers.append(norm_layer(half_width))
            layers.append(nn.ReLU())

        # Head: 7x7 conv to the output channels, squashed by tanh.
        layers.append(ReflectionPad2d(3))
        layers.append(nn.Conv2D(ngf, output_nc, filter_size=7, padding=0))
        layers.append(Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the full layer stack."""
        return self.model(x)
class ResnetBlock(paddle.fluid.dygraph.Layer):
    """Residual block: two 3x3 convolutions wrapped in a skip connection."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Initialize the Resnet block.

        The conv stack is built by ``build_conv_block``; the skip connection
        is applied in ``forward``.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    @staticmethod
    def _make_padding(padding_type):
        """Return ``(explicit_pad_layers, conv_padding)`` for one 3x3 conv."""
        if padding_type == 'reflect':
            return [ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            # There is no ReplicationPad2d in this project; border replication
            # is expressed with the Pad2D wrapper in 'edge' mode instead.
            return [Pad2D(paddings=[1, 1, 1, 1], mode='edge')], 0
        if padding_type == 'zero':
            # Zero padding is done by the convolution itself.
            return [], 1
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Construct the convolutional part of the block.

        Parameters:
            dim (int)           -- number of channels in (and out of) each conv layer
            padding_type (str)  -- reflect | replicate | zero
            norm_layer          -- callable creating a normalization layer for ``dim`` channels
            use_dropout (bool)  -- insert ``Dropout(0.5)`` after the first ReLU
            use_bias (bool)     -- passed to the conv layers as ``bias_attr``

        Returns:
            ``nn.Sequential``: pad, conv, norm, ReLU, [dropout], pad, conv, norm.

        Raises:
            NotImplementedError: if ``padding_type`` is not one of the three names.
        """
        conv_block = []

        pad_layers, p = self._make_padding(padding_type)
        conv_block += pad_layers
        conv_block += [nn.Conv2D(dim, dim, filter_size=3, padding=p, bias_attr=use_bias), norm_layer(dim), nn.ReLU()]
        if use_dropout:
            conv_block += [Dropout(0.5)]

        pad_layers, p = self._make_padding(padding_type)
        conv_block += pad_layers
        conv_block += [nn.Conv2D(dim, dim, filter_size=3, padding=p, bias_attr=use_bias), norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Return ``x`` plus the conv block's output (skip connection)."""
        out = x + self.conv_block(x)
        return out
import paddle
import paddle.nn as nn
import functools
from ...modules.nn import ReflectionPad2d, LeakyReLU, Tanh, Dropout, Conv2DTranspose, Conv2D
from ...modules.norm import build_norm_layer
from .builder import GENERATORS
@GENERATORS.register()
class UnetGenerator(paddle.fluid.dygraph.Layer):
    """U-Net generator assembled from nested ``UnetSkipConnectionBlock`` layers."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_type='batch', use_dropout=False):
        """Build the U-Net from the innermost block outwards.

        Args:
            input_nc (int)      -- number of channels in the input images
            output_nc (int)     -- number of channels in the output images
            num_downs (int)     -- number of downsamplings; ``num_downs - 5``
                                   intermediate ``ngf * 8`` blocks are inserted
            ngf (int)           -- number of filters of the outermost block's inner conv
            norm_type (str)     -- normalization name passed to ``build_norm_layer``
            use_dropout (bool)  -- enable dropout in the intermediate ``ngf * 8`` blocks
        """
        super(UnetGenerator, self).__init__()
        norm_layer = build_norm_layer(norm_type)

        # Innermost block first; every later block wraps the previous one.
        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                        norm_layer=norm_layer, innermost=True)

        # Intermediate blocks that keep the width at ngf * 8.
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=block,
                                            norm_layer=norm_layer, use_dropout=use_dropout)

        # Narrow step by step from ngf * 8 down to ngf.
        for outer_mult, inner_mult in ((4, 8), (2, 4), (1, 2)):
            block = UnetSkipConnectionBlock(ngf * outer_mult, ngf * inner_mult, input_nc=None,
                                            submodule=block, norm_layer=norm_layer)

        # Outermost block maps input_nc channels in and output_nc channels out.
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block,
                                             outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        """Run ``input`` through the nested blocks."""
        return self.model(input)
class UnetSkipConnectionBlock(paddle.fluid.dygraph.Layer):
    """One level of the U-Net, with a skip connection around it.

        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm, use_dropout=False):
        """Build the down path, the nested submodule and the up path.

        Parameters:
            outer_nc (int)  -- number of filters of the outer (up) conv layer
            inner_nc (int)  -- number of filters of the inner (down) conv layer
            input_nc (int)  -- number of input channels; defaults to ``outer_nc``
            submodule (UnetSkipConnectionBlock) -- block nested inside this one
            outermost (bool) -- this is the outermost block
            innermost (bool) -- this is the innermost block
            norm_layer       -- callable creating a normalization layer
            use_dropout (bool) -- append ``Dropout(0.5)`` (intermediate blocks only)
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost

        norm_cls = norm_layer.func if type(norm_layer) is functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm

        if input_nc is None:
            input_nc = outer_nc

        downconv = Conv2D(input_nc, inner_nc, filter_size=4,
                          stride=2, padding=1, bias_attr=use_bias)
        downrelu = LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        up_kwargs = dict(filter_size=4, stride=2, padding=1)
        if outermost:
            # No norm on either side; output goes through tanh.
            upconv = Conv2DTranspose(inner_nc * 2, outer_nc, **up_kwargs)
            layers = [downconv, submodule, uprelu, upconv, Tanh()]
        elif innermost:
            # No submodule, so the up conv sees inner_nc channels, not 2x.
            upconv = Conv2DTranspose(inner_nc, outer_nc, bias_attr=use_bias, **up_kwargs)
            layers = [downrelu, downconv, uprelu, upconv, upnorm]
        else:
            upconv = Conv2DTranspose(inner_nc * 2, outer_nc, bias_attr=use_bias, **up_kwargs)
            layers = [downrelu, downconv, downnorm, submodule, uprelu, upconv, upnorm]
            if use_dropout:
                layers.append(Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the block output, concatenated with ``x`` on axis 1 unless outermost."""
        if not self.outermost:
            return paddle.concat([x, self.model(x)], 1)
        return self.model(x)
import paddle
import paddle.nn as nn
import numpy as np
from ..modules.nn import BCEWithLogitsLoss
class GANLoss(paddle.fluid.dygraph.Layer):
    """GAN objective selectable by name: ``lsgan``, ``vanilla`` or ``wgangp``.

    The class creates the target tensor itself, so callers only say whether
    the prediction should be judged against "real" or "fake".
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Select the loss for ``gan_mode``.

        Parameters:
            gan_mode (str)      -- vanilla | lsgan | wgangp
            target_real_label   -- stored as ``self.real_label``
            target_fake_label   -- stored as ``self.fake_label``

        Raises:
            NotImplementedError: for any other ``gan_mode``.

        NOTE(review): ``get_target_tensor`` fills with the constants 1.0 / 0.0
        and does not read ``real_label`` / ``fake_label``.
        Do not put a sigmoid at the end of the discriminator: lsgan needs
        none and the vanilla mode uses ``BCEWithLogitsLoss``.
        """
        super(GANLoss, self).__init__()
        to_variable = paddle.fluid.dygraph.to_variable
        self.real_label = to_variable(np.array(target_real_label))
        self.fake_label = to_variable(np.array(target_fake_label))
        self.gan_mode = gan_mode

        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            # wgangp is computed directly from the prediction in __call__.
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create a float32 tensor shaped like ``prediction`` holding the target.

        Parameters:
            prediction (tensor)   -- typically a discriminator output
            target_is_real (bool) -- fill with 1.0 when true, 0.0 otherwise
        """
        fill_value = 1.0 if target_is_real else 0.0
        return paddle.fill_constant(shape=paddle.shape(prediction), value=fill_value, dtype='float32')

    def __call__(self, prediction, target_is_real):
        """Compute the loss for a discriminator output.

        Parameters:
            prediction (tensor)   -- typically a discriminator output
            target_is_real (bool) -- whether the target is "real"

        Returns:
            lsgan / vanilla: the configured loss against the target tensor.
            wgangp: the mean prediction, negated when the target is real.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target)
        elif self.gan_mode == 'wgangp':
            loss = -prediction.mean() if target_is_real else prediction.mean()
        return loss
\ No newline at end of file
# import torch
# import paddle
# from .base_model import BaseModel
# from . import networks
import paddle
from .base_model import BaseModel
from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .losses import GANLoss
# from ..modules.nn import L1Loss
from ..solver import build_optimizer
from ..utils.image_pool import ImagePool
@MODELS.register()
class Pix2PixModel(BaseModel):
    """pix2pix: learn a mapping from input images to output images from paired data.

    The generator and discriminator are built from ``opt.model.generator``
    and ``opt.model.discriminator``; the GAN objective is chosen by
    ``opt.model.gan_mode`` and combined with an L1 term weighted by
    ``opt.lambda_L1``.

    pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
    """

    def __init__(self, opt):
        """Create the networks and, in training mode, losses and optimizers.

        Parameters:
            opt -- configuration object holding all experiment settings
        """
        BaseModel.__init__(self, opt)
        # Names consumed by the base class: losses to report, images to
        # visualise and networks to save / load.
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # Only the generator is needed at test time.
        self.model_names = ['G', 'D'] if self.isTrain else ['G']

        self.netG = build_generator(opt.model.generator)
        if not self.isTrain:
            return

        # The discriminator is conditional: it is fed the input and the
        # output image concatenated on the channel axis.
        self.netD = build_discriminator(opt.model.discriminator)

        self.criterionGAN = GANLoss(opt.model.gan_mode, [[[[1.0]]]], [[[[0.0]]]])
        self.criterionL1 = paddle.nn.L1Loss()

        self.optimizer_G = build_optimizer(opt.optimizer, parameter_list=self.netG.parameters())
        self.optimizer_D = build_optimizer(opt.optimizer, parameter_list=self.netD.parameters())
        for optimizer in (self.optimizer_G, self.optimizer_D):
            self.optimizers.append(optimizer)
        self.optimizer_names.extend(['optimizer_G', 'optimizer_D'])

    def set_input(self, input):
        """Store one dataloader batch as ``self.real_A`` and ``self.real_B``.

        Parameters:
            input: indexable batch; elements 0 and 1 are read. They are
                swapped when ``opt.dataset.train.direction`` is not ``'AtoB'``.
        """
        forward_direction = self.opt.dataset.train.direction == 'AtoB'
        a_index, b_index = (0, 1) if forward_direction else (1, 0)
        self.real_A = paddle.imperative.to_variable(input[a_index])
        self.real_B = paddle.imperative.to_variable(input[b_index])

    def forward(self):
        """Generate ``fake_B = G(real_A)``."""
        self.fake_B = self.netG(self.real_A)

    def forward_test(self, input):
        """Return the generator output for ``input``."""
        variable = paddle.imperative.to_variable(input)
        return self.netG(variable)

    def test(self, input):
        """Inference entry point: run ``forward_test`` inside ``no_grad``."""
        with paddle.imperative.no_grad():
            output = self.forward_test(input)
            return output

    def backward_D(self):
        """Compute the discriminator loss and back-propagate it."""
        # Fake pair; detach so no gradient reaches the generator.
        fake_pair = paddle.concat((self.real_A, self.fake_B), 1)
        self.loss_D_fake = self.criterionGAN(self.netD(fake_pair.detach()), False)

        # Real pair.
        real_pair = paddle.concat((self.real_A, self.real_B), 1)
        self.loss_D_real = self.criterionGAN(self.netD(real_pair), True)

        # Average of the two terms.
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Compute the generator's GAN + L1 loss and back-propagate it."""
        # Adversarial term: the generated pair should be scored as real.
        fake_pair = paddle.concat((self.real_A, self.fake_B), 1)
        self.loss_G_GAN = self.criterionGAN(self.netD(fake_pair), True)

        # Reconstruction term: G(A) should match B.
        l1_distance = self.criterionL1(self.fake_B, self.real_B)
        self.loss_G_L1 = l1_distance * self.opt.lambda_L1

        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training iteration: forward pass, update D, then update G."""
        self.forward()

        # --- discriminator ---
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.clear_gradients()
        self.backward_D()
        self.optimizer_D.minimize(self.loss_D)

        # --- generator (D only scores, so it needs no gradients) ---
        self.set_requires_grad(self.netD, False)
        self.optimizer_G.clear_gradients()
        self.backward_G()
        self.optimizer_G.minimize(self.loss_G)
import paddle
from paddle.fluid.dygraph import Layer
from paddle import fluid
class MSELoss():
    """Callable wrapper around ``fluid.layers.mse_loss``."""

    def __call__(self, prediction, label):
        """Return ``fluid.layers.mse_loss(prediction, label)``."""
        loss = fluid.layers.mse_loss(prediction, label)
        return loss
class L1Loss():
    """Callable computing the mean absolute difference of two tensors."""

    def __call__(self, prediction, label):
        """Return ``reduce_mean(|prediction - label|)``."""
        abs_diff = fluid.layers.elementwise_sub(prediction, label, act='abs')
        return fluid.layers.reduce_mean(abs_diff)
class ReflectionPad2d(Layer):
    """Layer applying ``pad2d`` in reflect mode with the same size on all four sides."""

    def __init__(self, size):
        super(ReflectionPad2d, self).__init__()
        # Amount of padding applied to each of the four sides.
        self.size = size

    def forward(self, x):
        """Return ``x`` reflect-padded by ``self.size`` on every side."""
        paddings = [self.size for _ in range(4)]
        return fluid.layers.pad2d(x, paddings, mode="reflect")
class LeakyReLU(Layer):
    """Layer wrapper around ``fluid.layers.leaky_relu``."""

    def __init__(self, alpha, inplace=False):
        # ``inplace`` is accepted but not used by this wrapper.
        super(LeakyReLU, self).__init__()
        self.alpha = alpha

    def forward(self, x):
        """Apply leaky ReLU with slope ``self.alpha`` to ``x``."""
        activated = fluid.layers.leaky_relu(x, self.alpha)
        return activated
class Tanh(Layer):
    """Layer wrapper around ``fluid.layers.tanh``."""

    def __init__(self):
        super(Tanh, self).__init__()

    def forward(self, x):
        """Apply tanh to ``x``."""
        squashed = fluid.layers.tanh(x)
        return squashed
class Dropout(Layer):
    """Layer wrapper around ``fluid.layers.dropout``."""

    def __init__(self, prob, mode='upscale_in_train'):
        super(Dropout, self).__init__()
        # ``prob`` is passed as the dropout probability and ``mode`` as
        # ``dropout_implementation``.
        self.prob, self.mode = prob, mode

    def forward(self, x):
        """Apply dropout to ``x`` with the configured probability and mode."""
        dropped = fluid.layers.dropout(x, self.prob, dropout_implementation=self.mode)
        return dropped
class BCEWithLogitsLoss():
    """Sigmoid cross-entropy on raw logits with a selectable reduction."""

    def __init__(self, weight=None, reduction='mean'):
        """Store the loss configuration.

        Parameters:
            weight    -- stored on the instance; not used by ``__call__``
            reduction -- ``'sum'`` or ``'mean'``; any other value leaves the
                         element-wise loss unreduced
        """
        self.weight = weight
        # Honour the caller's choice (it used to be forced to 'mean').
        self.reduction = reduction

    def __call__(self, x, label):
        """Return the loss of logits ``x`` against ``label``, reduced as configured."""
        out = paddle.fluid.layers.sigmoid_cross_entropy_with_logits(x, label)
        if self.reduction == 'sum':
            return fluid.layers.reduce_sum(out)
        if self.reduction == 'mean':
            return fluid.layers.reduce_mean(out)
        return out
# class BCEWithLogitsLoss(fluid.dygraph.Layer):
# def __init__(self, weight=None, reduction='mean'):
# if reduction not in ['sum', 'mean', 'none']:
# raise ValueError(
# "The value of 'reduction' in bce_loss should be 'sum', 'mean' or 'none', but "
# "received %s, which is not allowed." % reduction)
# super(BCEWithLogitsLoss, self).__init__()
# # self.weight = weight
# # self.reduction = reduction
# self.bce_loss = paddle.nn.BCELoss(weight, reduction)
# def forward(self, input, label):
# input = paddle.nn.functional.sigmoid(input, True)
# return self.bce_loss(input, label)
def initial_type(
        input,
        op_type,
        fan_out,
        init="normal",
        use_bias=False,
        filter_size=0,
        stddev=0.02,
        name=None):
    """Build the weight and bias ``ParamAttr`` for a layer.

    Parameters:
        input       -- only read when ``init == "kaiming"``: ``input.shape``
                       supplies the fan-in unless ``op_type`` is ``'deconv'``
        op_type     -- ``'conv'``, ``'deconv'`` or anything else (treated as
                       a dense layer over ``input``)
        fan_out     -- number of output filters; used as fan-in base for ``'deconv'``
        init        -- ``"kaiming"`` for uniform init in ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``;
                       any other value gives normal init with ``stddev``
        use_bias    -- when truthy a bias attribute is created, otherwise ``False`` is returned
        filter_size -- kernel size used in the fan-in computation
        stddev      -- standard deviation of the normal initializer
        name        -- unused

    Returns:
        ``(param_attr, bias_attr)``; ``bias_attr`` is ``False`` when no bias is wanted.
    """
    # Imported here because the module's import block does not provide `math`.
    import math

    if init == "kaiming":
        if op_type == 'conv':
            fan_in = input.shape[1] * filter_size * filter_size
        elif op_type == 'deconv':
            fan_in = fan_out * filter_size * filter_size
        elif len(input.shape) > 2:
            fan_in = input.shape[1] * input.shape[2] * input.shape[3]
        else:
            fan_in = input.shape[1]
        bound = 1 / math.sqrt(fan_in)
        param_attr = fluid.ParamAttr(
            initializer=fluid.initializer.Uniform(
                low=-bound, high=bound))
        if use_bias:
            bias_attr = fluid.ParamAttr(
                initializer=fluid.initializer.Uniform(
                    low=-bound, high=bound))
        else:
            bias_attr = False
    else:
        param_attr = fluid.ParamAttr(
            initializer=fluid.initializer.NormalInitializer(
                loc=0.0, scale=stddev))
        if use_bias:
            bias_attr = fluid.ParamAttr(
                initializer=fluid.initializer.Constant(0.0))
        else:
            bias_attr = False
    return param_attr, bias_attr
class Conv2D(paddle.nn.Conv2D):
    """``paddle.nn.Conv2D`` whose parameter attributes come from ``initial_type``.

    Any ``param_attr`` passed in is replaced; ``bias_attr`` only decides
    whether a bias is created (it is disabled when ``bias_attr`` equals
    ``False``).
    """

    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 padding=0,
                 stride=1,
                 dilation=1,
                 groups=1,
                 param_attr=None,
                 bias_attr=None,
                 use_cudnn=True,
                 act=None,
                 data_format="NCHW",
                 dtype='float32',
                 init_type='normal'):
        wants_bias = bool(bias_attr != False)
        # NOTE(review): `input` is not a parameter here, so this passes the
        # Python builtin. `initial_type` only reads it for init == "kaiming",
        # which would fail on `input.shape` -- confirm before using that mode.
        param_attr, bias_attr = initial_type(
            input=input,
            op_type='conv',
            fan_out=num_filters,
            init=init_type,
            use_bias=wants_bias,
            filter_size=filter_size)

        super(Conv2D, self).__init__(
            num_channels, num_filters, filter_size,
            padding, stride, dilation, groups,
            param_attr, bias_attr,
            use_cudnn, act, data_format, dtype)
class Conv2DTranspose(paddle.nn.Conv2DTranspose):
    """``paddle.nn.Conv2DTranspose`` whose parameter attributes come from ``initial_type``.

    Any ``param_attr`` passed in is replaced; ``bias_attr`` only decides
    whether a bias is created (it is disabled when ``bias_attr`` equals
    ``False``).
    """

    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 output_size=None,
                 padding=0,
                 stride=1,
                 dilation=1,
                 groups=1,
                 param_attr=None,
                 bias_attr=None,
                 use_cudnn=True,
                 act=None,
                 data_format="NCHW",
                 dtype='float32',
                 init_type='normal'):
        wants_bias = bool(bias_attr != False)
        # NOTE(review): `input` is not a parameter here, so this passes the
        # Python builtin. For op_type 'deconv' `initial_type` does not read
        # it, so it has no effect on the result.
        param_attr, bias_attr = initial_type(
            input=input,
            op_type='deconv',
            fan_out=num_filters,
            init=init_type,
            use_bias=wants_bias,
            filter_size=filter_size)

        super(Conv2DTranspose, self).__init__(
            num_channels, num_filters, filter_size, output_size,
            padding, stride, dilation, groups,
            param_attr, bias_attr,
            use_cudnn, act, data_format, dtype)
class Pad2D(fluid.dygraph.Layer):
    """Layer wrapper around ``fluid.layers.pad2d``."""

    def __init__(self, paddings, mode, pad_value=0.0):
        """Store the padding configuration.

        Parameters:
            paddings  -- padding sizes forwarded to ``pad2d``
            mode      -- padding mode forwarded to ``pad2d``
            pad_value -- fill value forwarded to ``pad2d``
        """
        super(Pad2D, self).__init__()
        self.paddings = paddings
        self.mode = mode
        # Previously accepted but dropped; keep it so it reaches pad2d.
        self.pad_value = pad_value

    def forward(self, x):
        """Return ``x`` padded according to the stored configuration."""
        return fluid.layers.pad2d(x, self.paddings, self.mode, pad_value=self.pad_value)
\ No newline at end of file
import paddle
import functools
import paddle.nn as nn
class Identity(paddle.fluid.dygraph.Layer):
    """Layer that returns its input unchanged."""

    def forward(self, x):
        """Return ``x`` as is."""
        return x
def build_norm_layer(norm_type='instance'):
    """Return a callable that creates a normalization layer.

    Args:
        norm_type (str) -- batch | instance | none

    * ``batch``: ``nn.BatchNorm`` with scale initialised from N(1.0, 0.02),
      zero bias and ``trainable_statistics=True``.
    * ``instance``: ``nn.InstanceNorm`` with scale fixed at 1 and bias fixed
      at 0 (both non-trainable, learning rate 0).
    * ``none``: a function that ignores its argument and returns ``Identity()``.

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    initializers = paddle.fluid.initializer

    if norm_type == 'batch':
        scale_attr = paddle.ParamAttr(initializer=initializers.NormalInitializer(1.0, 0.02))
        shift_attr = paddle.ParamAttr(initializer=initializers.Constant(0.0))
        return functools.partial(nn.BatchNorm, param_attr=scale_attr, bias_attr=shift_attr,
                                 trainable_statistics=True)

    if norm_type == 'instance':
        scale_attr = paddle.ParamAttr(initializer=initializers.Constant(1.0),
                                      learning_rate=0.0, trainable=False)
        shift_attr = paddle.ParamAttr(initializer=initializers.Constant(0.0),
                                      learning_rate=0.0, trainable=False)
        return functools.partial(nn.InstanceNorm, param_attr=scale_attr, bias_attr=shift_attr)

    if norm_type == 'none':
        def norm_layer(x):
            return Identity()
        return norm_layer

    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
\ No newline at end of file
from .optimizer import build_optimizer
\ No newline at end of file
import paddle
def build_lr_scheduler(cfg):
    """Build a learning-rate scheduler from a config mapping.

    Args:
        cfg (dict): must contain ``name``; every other item is passed as a
            keyword argument to the scheduler constructor. The mapping is
            not modified.

    Returns:
        A ``LinearDecay`` instance when ``name`` is ``'linear'``.

    Raises:
        KeyError: if ``cfg`` has no ``name`` entry.
        NotImplementedError: if ``name`` is not a supported scheduler.
    """
    # Work on a shallow copy so the caller's config keeps its 'name' entry
    # and can be reused afterwards.
    scheduler_kwargs = dict(cfg)
    name = scheduler_kwargs.pop('name')

    # TODO: add more learning rate scheduler
    if name == 'linear':
        return LinearDecay(**scheduler_kwargs)
    raise NotImplementedError(
        'learning rate scheduler [{}] is not supported'.format(name))
class LinearDecay(paddle.fluid.dygraph.learning_rate_scheduler.LearningRateDecay):
    """Learning-rate schedule that scales the base rate by a per-epoch factor.

    With ``epoch = step_num // step_per_epoch`` the factor is
    ``1.0 - max(0, epoch + 1 - start_epoch) / (decay_epochs + 1)``.

    Args:
        learning_rate: base learning rate that the factor multiplies.
        step_per_epoch: number of steps that make up one epoch.
        start_epoch: epoch offset subtracted before the decay term grows.
        decay_epochs: controls the slope via the ``decay_epochs + 1`` divisor.
    """

    def __init__(self, learning_rate, step_per_epoch, start_epoch, decay_epochs):
        super(LinearDecay, self).__init__()
        self.learning_rate = learning_rate
        self.step_per_epoch = step_per_epoch
        self.start_epoch = start_epoch
        self.decay_epochs = decay_epochs

    def step(self):
        """Return the learning-rate variable for the current step.

        ``self.step_num`` and ``self.create_lr_var`` are not defined here;
        presumably they come from the ``LearningRateDecay`` base class.
        """
        epoch = int(self.step_num // self.step_per_epoch)
        epochs_into_decay = max(0, epoch + 1 - self.start_epoch)
        factor = 1.0 - epochs_into_decay / float(self.decay_epochs + 1)
        return self.create_lr_var(factor * self.learning_rate)
\ No newline at end of file
import copy
import paddle
from .lr_scheduler import build_lr_scheduler
def build_optimizer(cfg, parameter_list=None):
    """Build a ``paddle.optimizer`` optimizer from a config mapping.

    Args:
        cfg (dict): must contain ``name`` (attribute name of the optimizer
            class in ``paddle.optimizer``) and ``lr_scheduler`` (config for
            ``build_lr_scheduler``). Remaining items are forwarded as keyword
            arguments to the optimizer. ``cfg`` itself is not modified.
        parameter_list: forwarded to the optimizer as ``parameter_list``.

    Returns:
        The constructed optimizer, with the built scheduler as its first
        positional argument.

    Raises:
        ValueError: if ``cfg`` has no ``lr_scheduler`` entry.
        KeyError: if ``cfg`` has no ``name`` entry.
    """
    cfg_copy = copy.deepcopy(cfg)

    lr_scheduler_cfg = cfg_copy.pop('lr_scheduler', None)
    if lr_scheduler_cfg is None:
        # Without this check the None would reach build_lr_scheduler and
        # fail there with an unhelpful AttributeError.
        raise ValueError(
            "optimizer config must contain an 'lr_scheduler' section")
    lr_scheduler = build_lr_scheduler(lr_scheduler_cfg)

    opt_name = cfg_copy.pop('name')
    optimizer_cls = getattr(paddle.optimizer, opt_name)
    return optimizer_cls(
        lr_scheduler, parameter_list=parameter_list, **cfg_copy)
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import yaml
__all__ = ['get_config']
class AttrDict(dict):
    """``dict`` whose items can also be read and written as attributes."""

    def __getattr__(self, key):
        """Look up ``key`` as a dict item.

        Raises:
            AttributeError: when the item is absent, so ``hasattr`` and
                ``getattr`` with a default behave as usual.
        """
        try:
            value = self[key]
        except KeyError:
            raise AttributeError(key)
        return value

    def __setattr__(self, key, value):
        """Store ``value`` in the instance ``__dict__`` when ``key`` already
        lives there, otherwise as a dict item."""
        target = self.__dict__ if key in self.__dict__ else self
        target[key] = value
def create_attr_dict(yaml_config):
    """Normalize a loaded config mapping in place.

    For every item of ``yaml_config``:

    * a plain ``dict`` value is replaced by an ``AttrDict`` and processed
      recursively;
    * a ``str`` value is parsed with ``ast.literal_eval`` and replaced by the
      resulting Python literal when parsing succeeds; otherwise the string is
      kept unchanged.

    Args:
        yaml_config (dict): mapping to normalize; modified in place.
    """
    from ast import literal_eval

    for key, value in yaml_config.items():
        if type(value) is dict:
            yaml_config[key] = value = AttrDict(value)
        if isinstance(value, str):
            try:
                value = literal_eval(value)
            except Exception:
                # Not a Python literal: keep the raw string. Only ordinary
                # errors are swallowed so that KeyboardInterrupt/SystemExit
                # are no longer hidden.
                pass
        if isinstance(value, AttrDict):
            create_attr_dict(yaml_config[key])
        else:
            yaml_config[key] = value
def parse_config(cfg_file):
    """Read a YAML file and return its content as a normalized ``AttrDict``.

    The file is parsed with ``yaml.SafeLoader``; the top-level mapping is
    wrapped in ``AttrDict`` and handed to ``create_attr_dict``.
    """
    with open(cfg_file, 'r') as stream:
        raw_config = yaml.load(stream, Loader=yaml.SafeLoader)
        config = AttrDict(raw_config)
    create_attr_dict(config)
    return config
def override(dl, ks, v):
    """
    Recursively replace one value inside a nested dict/list structure.

    Args:
        dl(dict or list): dict or list to be replaced
        ks(list): list of keys; entries addressing a list are parsed as
            indices
        v(str): value to be replaced; parsed as a Python literal when
            possible, otherwise stored as the raw string

    Raises:
        AssertionError: if ``dl`` is neither list nor dict, ``ks`` is empty,
            a list index is out of range, or the final dict key does not
            exist.
    """
    from ast import literal_eval

    def str2num(v):
        # literal_eval (as in create_attr_dict) instead of eval: eval turned
        # values such as 'max' or 'os' into live objects and could execute
        # arbitrary expressions taken from the command line.
        try:
            return literal_eval(v)
        except Exception:
            return v

    assert isinstance(dl, (list, dict)), (
        "{} should be a list or a dict".format(dl))
    assert len(ks) > 0, ('lenght of keys should larger than 0')
    if isinstance(dl, list):
        k = str2num(ks[0])
        if len(ks) == 1:
            assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
            dl[k] = str2num(v)
        else:
            override(dl[k], ks[1:], v)
    else:
        if len(ks) == 1:
            # Only existing keys may be overridden; new keys are rejected.
            assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
            dl[ks[0]] = str2num(v)
        else:
            override(dl[ks[0]], ks[1:], v)
def override_config(config, options=None):
    """
    Apply a list of ``key=value`` overrides to a config.

    Args:
        config(dict): dict to be replaced
        options(list): list of pairs(key0.key1.idx.key2=value)
            such as: [
                'topk=2',
                'VALID.transforms.1.ResizeImage.resize_short=300'
            ]

    Returns:
        config(dict): the same ``config`` object, after all overrides

    Raises:
        AssertionError: if an option is not a str or does not contain
            exactly one ``=``.
    """
    if options is None:
        return config

    for opt in options:
        assert isinstance(opt, str), (
            "option({}) should be a str".format(opt))
        assert "=" in opt, (
            "option({}) should contain a ="
            "to distinguish between key and value".format(opt))
        parts = opt.split('=')
        assert len(parts) == 2, ("there can be only a = in the option")
        dotted_key, raw_value = parts
        # The dotted key addresses one leaf of the nested config.
        override(config, dotted_key.split('.'), raw_value)
    return config
def get_config(fname, overrides=None, show=True):
    """
    Load the config file ``fname`` and apply optional overrides.

    Args:
        fname(str): path of the config file
        overrides(list): ``key=value`` strings passed to ``override_config``
        show(bool): accepted but not used by this function

    Returns:
        The parsed config with overrides applied.

    Raises:
        AssertionError: if ``fname`` does not exist.
    """
    file_found = os.path.exists(fname)
    assert file_found, 'config file({}) is not exist'.format(fname)

    cfg = parse_config(fname)
    override_config(cfg, overrides)
    return cfg
\ No newline at end of file
import os
import six
import pickle
import paddle
def makedirs(dir):
    """Create directory ``dir`` (including parents) unless the path exists.

    The directory is created first and an "already exists" failure is
    tolerated afterwards. This avoids the check-then-create race when several
    processes try to create the same directory at once.

    Raises:
        OSError: if creation fails and the path still does not exist.
    """
    try:
        os.makedirs(dir)
    except OSError:
        # As before, an existing path is not an error; anything else is.
        if not os.path.exists(dir):
            raise
def save(state_dicts, file_name):
    """Pickle ``state_dicts`` to ``file_name`` (pickle protocol 2).

    Args:
        state_dicts (dict): either a single state dict (at least one value is
            a paddle Variable/VarBase) or a mapping whose values are state
            dicts and/or plain picklable objects.
        file_name (str): destination path, opened in binary write mode.
    """
    def convert(state_dict):
        # Build a picklable copy of one state dict: tensors become numpy
        # arrays and their names are collected into a side table.
        model_dict = {}
        name_table = {}
        for k, v in state_dict.items():
            if isinstance(v, (paddle.framework.Variable, paddle.imperative.core.VarBase)):
                model_dict[k] = v.numpy()
            else:
                # NOTE(review): nesting reconstructed from a source whose
                # indentation was lost. As read here, the first non-tensor
                # value prints a debug line and returns the ORIGINAL
                # state_dict unconverted -- confirm this is intended.
                model_dict[k] = v
                print('enter k', k)
                return state_dict
            name_table[k] = v.name
        # Maps each key to the name of the tensor stored under it.
        model_dict["StructuredToParameterName@@"] = name_table
        return model_dict
    final_dict = {}
    for k, v in state_dicts.items():
        if isinstance(v, (paddle.framework.Variable, paddle.imperative.core.VarBase)):
            # A tensor at the top level: treat the whole argument as one
            # state dict and discard anything collected so far.
            final_dict = convert(state_dicts)
            break
        elif isinstance(v, dict):
            final_dict[k] = convert(v)
        else:
            final_dict[k] = v
    with open(file_name, 'wb') as f:
        pickle.dump(final_dict, f, protocol=2)
def load(file_name):
    """Unpickle and return the object stored in ``file_name``.

    On Python 3 the file is decoded with ``encoding='latin1'``; on Python 2
    the plain ``pickle.load`` is used.

    NOTE: unpickling can execute arbitrary code, so only load files from a
    trusted source.
    """
    with open(file_name, 'rb') as f:
        if six.PY2:
            state_dicts = pickle.load(f)
        else:
            state_dicts = pickle.load(f, encoding='latin1')
    return state_dicts
\ No newline at end of file
import random
import paddle
class ImagePool():
    """History buffer of previously generated images.

    The buffer lets discriminators be updated with a mix of older generated
    images and the ones produced by the latest generators.
    """

    def __init__(self, pool_size):
        """Set up the pool.

        Parameters:
            pool_size (int) -- capacity of the buffer; with pool_size=0 no
                               buffer is created
        """
        self.pool_size = pool_size
        if self.pool_size > 0:
            # Storage exists only when the pool can actually hold images.
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch drawn from the current images and the buffer.

        Parameters:
            images: the latest generated images from the generator

        With pool_size=0 the input is returned unchanged. Otherwise each
        image gets a leading axis via ``paddle.unsqueeze`` and is processed
        in turn:
          * while the buffer is not full, the image is stored and returned;
          * once full, with probability 0.5 a randomly chosen stored image is
            returned and replaced by the current one, else the current image
            is returned.
        The selected images are concatenated along axis 0.
        """
        if self.pool_size == 0:
            return images

        selected = []
        for image in images:
            image = paddle.unsqueeze(image, 0)

            if self.num_imgs < self.pool_size:
                # Still filling the buffer: keep the image and pass it on.
                self.num_imgs += 1
                self.images.append(image)
                selected.append(image)
                continue

            if random.uniform(0, 1) > 0.5:
                # Swap the current image with a random stored one
                # (randint bounds are inclusive).
                slot = random.randint(0, self.pool_size - 1)
                # FIXME: the stored image is handed out without a clone
                previous = self.images[slot]
                self.images[slot] = image
                selected.append(previous)
            else:
                selected.append(image)

        return paddle.concat(selected, 0)
import logging
import os
import sys
from paddle.imperative import ParallelEnv
def setup_logger(output=None, name="ppgan"):
    """
    Initialize the logger ``name`` at DEBUG level and attach its handlers.

    A stdout handler is added only on the process whose local rank is 0; a
    file handler is added on every process when ``output`` is given.

    Args:
        output (str): a file name or a directory to save log. If None, will not save log file.
            If ends with ".txt" or ".log", assumed to be a file name.
            Otherwise, logs will be saved to `output/log.txt`.
            Processes with local rank > 0 append ".rank<N>" to the file name.
        name (str): the root module name of this logger

    Returns:
        logging.Logger: a logger
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    plain_formatter = logging.Formatter(
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
    )

    # stdout logging: master only
    local_rank = ParallelEnv().local_rank
    if local_rank == 0:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(plain_formatter)
        logger.addHandler(ch)

    # file logging: all workers
    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "log.txt")
        if local_rank > 0:
            filename = filename + ".rank{}".format(local_rank)

        # A bare file name such as "train.log" has an empty dirname, and
        # os.makedirs('') raises; only create a directory when there is one.
        log_dir = os.path.dirname(filename)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        fh = logging.FileHandler(filename, mode='a')
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)

    return logger
\ No newline at end of file
import argparse
def parse_args():
    """Parse the command line of the training/evaluation entry point.

    Returns:
        argparse.Namespace with ``config_file``, ``no_cuda``, ``resume``,
        ``load``, ``val_interval``, ``evaluate_only`` and ``opts`` (the
        remaining positional arguments).
    """
    # Description and help texts were copied from another project and
    # described the wrong behavior for --load and --evaluate-only.
    parser = argparse.ArgumentParser(description='PaddleGAN')
    parser.add_argument('--config-file', metavar="FILE",
                        help='config file path')
    # cuda setting
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    # checkpoint and log
    parser.add_argument('--resume', type=str, default=None,
                        help='put the path to resuming file if needed')
    parser.add_argument('--load', type=str, default=None,
                        help='put the path to the weights file to load if needed')
    # for evaluation
    parser.add_argument('--val-interval', type=int, default=1,
                        help='run validation every interval')
    parser.add_argument('--evaluate-only', action='store_true', default=False,
                        help='only run evaluation, do not train')
    # config options
    parser.add_argument('opts', help='See config for all options',
                        default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args()
    return args
\ No newline at end of file
class Registry(object):
    """
    Name -> object mapping used to look up registered functions and classes.

    To create a registry:

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object with a decorator:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or with a plain call:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry, used in error messages
        """
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj):
        """Store ``obj`` under ``name``; a name can be registered only once."""
        already_registered = name in self._obj_map
        assert not already_registered, (
            "An object named '{}' was already registered in '{}' registry!".format(name, self._name))
        self._obj_map[name] = obj

    def register(self, obj=None, name=None):
        """
        Register ``obj`` under ``name``, or under ``obj.__name__`` when no
        name is given.

        Called with an object, registers it and returns None. Called without
        one, returns a decorator that registers its target and returns the
        target unchanged.
        """
        if obj is not None:
            # used as a function call
            key = obj.__name__ if name is None else name
            self._do_register(key, obj)
            return None

        # used as a decorator
        def deco(func_or_class, name=name):
            key = func_or_class.__name__ if name is None else name
            self._do_register(key, func_or_class)
            return func_or_class

        return deco

    def get(self, name):
        """Return the object registered under ``name``.

        Raises:
            KeyError: if nothing is registered under ``name``.
        """
        found = self._obj_map.get(name)
        if found is not None:
            return found
        raise KeyError("No object named '{}' found in '{}' registry!".format(name, self._name))
\ No newline at end of file
import os
import time
import paddle
from paddle.imperative import ParallelEnv
from .logger import setup_logger
def setup(args, cfg):
    """Prepare the run: finalize the config, start logging, pick the device.

    Args:
        args: parsed command-line arguments. ``evaluate_only`` is read; the
            optional ``no_cuda`` flag selects CPU execution when set.
        cfg: config object; ``isTrain``, ``timestamp`` and ``output_dir`` are
            updated in place.
    """
    if args.evaluate_only:
        cfg.isTrain = False

    # Each run writes into its own "<model name>-<timestamp>" sub-directory.
    cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime())
    cfg.output_dir = os.path.join(cfg.output_dir, str(cfg.model.name) + cfg.timestamp)

    logger = setup_logger(cfg.output_dir)
    logger.info('Configs: {}'.format(cfg))

    if getattr(args, 'no_cuda', False):
        # --no-cuda used to be parsed but ignored; honor it by running on CPU.
        place = paddle.fluid.CPUPlace()
    else:
        env = ParallelEnv()
        # Multi-process runs use the device assigned to this process,
        # single-process runs use GPU 0.
        place = paddle.fluid.CUDAPlace(env.dev_id) \
            if env.nranks > 1 else paddle.fluid.CUDAPlace(0)

    paddle.enable_imperative(place)
import numpy as np
from PIL import Image
def tensor2img(input_image, imtype=np.uint8):
    """Convert a tensor into a numpy image array.

    Parameters:
        input_image (tensor) -- the input image tensor; a numpy array is
                                accepted too and is only cast to ``imtype``
        imtype (type)        -- the desired type of the converted numpy array

    For a non-numpy input, ``input_image.numpy()`` is taken; a 4-D array is
    reduced to its first element, a single-channel array is tiled to three
    channels, the axes are reordered from (0, 1, 2) to (1, 2, 0) and values
    are mapped with ``(x + 1) / 2.0 * 255.0`` before the cast.
    """
    if isinstance(input_image, np.ndarray):
        # Already a numpy array: no post-processing, just the cast.
        return input_image.astype(imtype)

    array = input_image.numpy()
    if len(array.shape) == 4:
        array = array[0]
    if array.shape[0] == 1:
        # grayscale to RGB
        array = np.tile(array, (3, 1, 1))

    # post-processing: transpose and scaling
    channels_last = np.transpose(array, (1, 2, 0))
    scaled = (channels_last + 1) / 2.0 * 255.0
    return scaled.astype(imtype)
def save_image(image_numpy, image_path, aspect_ratio=1.0):
    """Save a numpy image to the disk

    Parameters:
        image_numpy (numpy array) -- input numpy array, (height, width) or
                                     (height, width, channels)
        image_path (str)          -- the path of the image
        aspect_ratio (float)      -- > 1.0 stretches the height by this
                                     factor, < 1.0 stretches the width by
                                     1 / aspect_ratio, 1.0 keeps the size
    """
    image_pil = Image.fromarray(image_numpy)
    # Only the first two axes are needed; unpacking three values crashed on
    # 2-D (grayscale) arrays.
    h, w = image_numpy.shape[:2]

    # PIL's resize takes (width, height). The previous code passed the numpy
    # height as PIL width, which transposed non-square images; for square
    # images the result is unchanged.
    if aspect_ratio > 1.0:
        image_pil = image_pil.resize((w, int(h * aspect_ratio)), Image.BICUBIC)
    if aspect_ratio < 1.0:
        image_pil = image_pil.resize((int(w / aspect_ratio), h), Image.BICUBIC)
    image_pil.save(image_path)
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
cur_path = os.path.abspath(os.path.dirname(__file__))
root_path = os.path.split(cur_path)[0]
sys.path.append(root_path)
from ppgan.utils.options import parse_args
from ppgan.utils.config import get_config
from ppgan.utils.setup import setup
from ppgan.engine.trainer import Trainer
def main(args, cfg):
    """Run training or evaluation according to the parsed arguments.

    Args:
        args: parsed command-line arguments; ``resume``, ``load`` and
            ``evaluate_only`` are read here.
        cfg: config object passed to ``setup`` and ``Trainer``.
    """
    # init environment, include logger, dynamic graph, seed, device, train or test mode...
    setup(args, cfg)

    # build trainer
    trainer = Trainer(cfg)

    if args.resume:
        # continue train or evaluate, checkpoint need contain epoch and optimizer info
        trainer.resume(args.resume)
    elif args.load:
        # evaluate or finetune, only load generator weights
        trainer.load(args.load)

    if not args.evaluate_only:
        trainer.train()
        return

    trainer.test()
if __name__ == '__main__':
    args = parse_args()
    # Forward the trailing "key=value" command-line options so they override
    # values from the config file; previously they were parsed and then
    # silently dropped.
    cfg = get_config(args.config_file, overrides=args.opts)
    main(args, cfg)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册