未验证 提交 10cff870 编写于 作者: Y Yizhuang Zhou 提交者: GitHub

feat(gan): add GAN codebase (#31)

上级 b5bd2ae9
Generative Adversarial Networks
---
This directory provides code to build, train and evaluate popular GAN models including DCGAN and WGAN. Most of the code are modified from a well-written and reproducible GAN benchmark [pytorch_mimicry](https://github.com/kwotsin/mimicry).
We provide pretrained DCGAN and WGAN on cifar10. They use similar ResNet backbone and share the same training setting.
![images generated by DCGAN](../../assets/dcgan.png)
#### Training Parameters
| Resolution | Batch Size | Learning Rate | β<sub>1</sub> | β<sub>2</sub> | Decay Policy | n<sub>dis</sub> | n<sub>iter</sub> |
|:----------:|:----------:|:-------------:|:-------------:|:-------------:|:------------:|:---------------:|------------------|
| 32 x 32 | 64 | 2e-4 | 0.0 | 0.9 | Linear | 5 | 100K |
Their FID and Inception Score(IS) are listed below.
#### Metrics
| Metric | Method |
|:--------------------------------:|:---------------------------------------:|
| [Inception Score (IS)](https://arxiv.org/abs/1606.03498) | 50K samples at 10 splits|
| [Fréchet Inception Distance (FID)](https://arxiv.org/abs/1706.08500) | 50K real/generated samples |
| [Kernel Inception Distance (KID)](https://arxiv.org/abs/1801.01401) | 50K real/generated samples, averaged over 10 splits.|
#### Cifar10 Results
| Method | FID Score | IS Score | KID Score |
| :-: | :-: | :-: | :-: |
| DCGAN | 27.2 | 7.0 | 0.0242 |
| WGAN-WC | 30.5 | 6.7 | 0.0249 |
### Generate images with pretrained weights
```python
import megengine.hub as hub
import megengine_mimicry.nets.dcgan.dcgan_cifar as dcgan
import megengine_mimicry.utils.vis as vis
netG = dcgan.DCGANGeneratorCIFAR()
netG.load_state_dict(hub.load_serialized_obj_from_url("https://data.megengine.org.cn/models/weights/dcgan_cifar.pkl"))
images = dcgan_generator.generate_images(num_images=64) # in NCHW format with normalized pixel values in [0, 1]
grid = vis.make_grid(images) # in HW3 format with [0, 255] BGR images for visualization
vis.save_image(grid, "visual.png")
```
### Train and evaluate a DCGAN or WGAN
```bash
# train and evaluate a DCGAN
python3 train_dcgan.py
# train and evaluate a WGAN
python3 train_wgan.py
```
#### Tensorboard visualization
```bash
tensorboard --logdir ./log --bind_all
```
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
from . import nets, training, datasets, metrics
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
"""
Script for loading datasets.
"""
import os
import megengine.data as data
import megengine.data.transform as T
def load_dataset(root, name, **kwargs):
"""
Loads different datasets specifically for GAN training.
By default, all images are normalized to values in the range [-1, 1].
Args:
root (str): Path to where datasets are stored.
name (str): Name of dataset to load.
Returns:
Dataset: Torch Dataset object for a specific dataset.
"""
if name == "cifar10":
return load_cifar10_dataset(root, **kwargs)
else:
raise ValueError("Invalid dataset name {} selected.".format(name))
def load_cifar10_dataset(root=None,
split='train',
download=True,
**kwargs):
"""
Loads the CIFAR-10 dataset.
Args:
root (str): Path to where datasets are stored.
split (str): The split of data to use.
download (bool): If True, downloads the dataset.
Returns:
Dataset: Torch Dataset object.
"""
dataset_dir = root
if dataset_dir and not os.path.exists(dataset_dir):
os.makedirs(dataset_dir)
# Build datasets
if split == "train":
dataset = data.dataset.CIFAR10(root=dataset_dir,
train=True,
download=download,
**kwargs)
elif split == "test":
dataset = data.dataset.CIFAR10(root=dataset_dir,
train=False,
download=download,
**kwargs)
else:
raise ValueError("split argument must one of ['train', 'val']")
return dataset
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
"""
Loads randomly sampled images from datasets for computing metrics.
"""
import os
import numpy as np
import megengine.data.transform as T
from . import data_utils
def get_random_images(dataset, num_samples):
"""
Randomly sample without replacement num_samples images.
Args:
dataset (Dataset): Torch Dataset object for indexing elements.
num_samples (int): The number of images to randomly sample.
Returns:
Tensor: Batch of num_samples images in np array form [N, H, W, C](0-255).
"""
choices = np.random.choice(range(len(dataset)),
size=num_samples,
replace=False)
images = []
for choice in choices:
img = np.array(dataset[choice][0])
img = np.expand_dims(img, axis=0)
images.append(img)
images = np.concatenate(images, axis=0)
return images
def get_cifar10_images(num_samples, root=None, **kwargs):
"""
Loads randomly sampled CIFAR-10 training images.
Args:
num_samples (int): The number of images to randomly sample.
root (str): The root directory where all datasets are stored.
Returns:
Tensor: Batch of num_samples images in np array form.
"""
dataset = data_utils.load_cifar10_dataset(root=root, **kwargs)
images = get_random_images(dataset, num_samples)
return images
def get_dataset_images(dataset_name, num_samples=50000, **kwargs):
"""
Randomly sample num_samples images based on input dataset name.
Args:
dataset_name (str): Dataset name to load images from.
num_samples (int): The number of images to randomly sample.
Returns:
Tensor: Batch of num_samples images from the specific dataset in np array form.
"""
if dataset_name == "cifar10":
images = get_cifar10_images(num_samples, **kwargs)
elif dataset_name == "cifar10_test":
images = get_cifar10_images(num_samples, split='test', **kwargs)
else:
raise ValueError("Invalid dataset name {}.".format(dataset_name))
# Check shape and permute if needed
if images.shape[1] == 3:
images = images.transpose((0, 2, 3, 1))
# Ensure the values lie within the correct range, otherwise there might be some
# preprocessing error from the library causing ill-valued scores.
if np.min(images) < 0 or np.max(images) > 255:
raise ValueError(
'Image pixel values must lie between 0 to 255 inclusive.')
return images
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
from . import fid, kid, inception_score, inception_model
from .compute_fid import *
from .compute_is import *
from .compute_kid import *
from .compute_metrics import *
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
"""
MegEngine interface for computing FID.
"""
import os
import random
import time
import numpy as np
import tensorflow as tf
from ..datasets.image_loader import get_dataset_images
from .fid import fid_utils
from .inception_model import inception_utils
from .utils import _normalize_images
def compute_real_dist_stats(num_samples,
sess,
dataset_name,
batch_size,
stats_file=None,
seed=0,
verbose=True,
log_dir='./log'):
"""
Reads the image data and compute the FID mean and cov statistics
for real images.
Args:
num_samples (int): Number of real images to compute statistics.
sess (Session): TensorFlow session to use.
dataset_name (str): The name of the dataset to load.
batch_size (int): The batch size to feedforward for inference.
stats_file (str): The statistics file to load from if there is already one.
verbose (bool): If True, prints progress of computation.
log_dir (str): Directory where feature statistics can be stored.
Returns:
ndarray: Mean features stored as np array.
ndarray: Covariance of features stored as np array.
"""
# Create custom stats file name
if stats_file is None:
stats_dir = os.path.join(log_dir, 'metrics', 'fid', 'statistics')
if not os.path.exists(stats_dir):
os.makedirs(stats_dir)
stats_file = os.path.join(
stats_dir,
"fid_stats_{}_{}k_run_{}.npz".format(dataset_name,
num_samples // 1000, seed))
if stats_file and os.path.exists(stats_file):
print("INFO: Loading existing statistics for real images...")
f = np.load(stats_file)
m_real, s_real = f['mu'][:], f['sigma'][:]
f.close()
else:
# Obtain the numpy format data
print("INFO: Obtaining images...")
images = get_dataset_images(dataset_name, num_samples=num_samples)
images = images[:, :, :, ::-1] # NOTE: opencv image convert to RGB
# Compute the mean and cov
print("INFO: Computing statistics for real images...")
m_real, s_real = fid_utils.calculate_activation_statistics(
images=images, sess=sess, batch_size=batch_size, verbose=verbose)
if not os.path.exists(stats_file):
print("INFO: Saving statistics for real images...")
np.savez(stats_file, mu=m_real, sigma=s_real)
return m_real, s_real
def compute_gen_dist_stats(netG,
num_samples,
sess,
device,
seed,
batch_size,
print_every=20,
verbose=True):
"""
Directly produces the images and convert them into numpy format without
saving the images on disk.
Args:
netG (Module): Torch Module object representing the generator model.
num_samples (int): The number of fake images for computing statistics.
sess (Session): TensorFlow session to use.
device (str): Device identifier to use for computation.
seed (int): The random seed to use.
batch_size (int): The number of samples per batch for inference.
print_every (int): Interval for printing log.
verbose (bool): If True, prints progress.
Returns:
ndarray: Mean features stored as np array.
ndarray: Covariance of features stored as np array.
"""
# Set model to evaluation mode
netG.eval() # NOTE: in MegEngine this may has no effect
# Inference variables
batch_size = min(num_samples, batch_size)
# Collect all samples()
images = []
start_time = time.time()
for idx in range(num_samples // batch_size):
# Collect fake image
fake_images = netG.generate_images(num_images=batch_size).numpy()
images.append(fake_images)
# Print some statistics
if (idx + 1) % print_every == 0:
end_time = time.time()
print(
"INFO: Generated image {}/{} [Random Seed {}] ({:.4f} sec/idx)"
.format(
(idx + 1) * batch_size, num_samples, seed,
(end_time - start_time) / (print_every * batch_size)))
start_time = end_time
# Produce images in the required (N, H, W, 3) format for FID computation
images = np.concatenate(images, 0) # Gives (N, 3, H, W) BGR
images = _normalize_images(images) # Gives (N, H, W, 3) RGB
# Compute the FID
print("INFO: Computing statistics for fake images...")
m_fake, s_fake = fid_utils.calculate_activation_statistics(
images=images, sess=sess, batch_size=batch_size, verbose=verbose)
return m_fake, s_fake
def fid_score(num_real_samples,
num_fake_samples,
netG,
device,
seed,
dataset_name,
batch_size=50,
verbose=True,
stats_file=None,
log_dir='./log'):
"""
Computes FID stats using functions that store images in memory for speed and fidelity.
Fidelity since by storing images in memory, we don't subject the scores to different read/write
implementations of imaging libraries.
Args:
num_real_samples (int): The number of real images to use for FID.
num_fake_samples (int): The number of fake images to use for FID.
netG (Module): Torch Module object representing the generator model.
device (str): Device identifier to use for computation.
seed (int): The random seed to use.
dataset_name (str): The name of the dataset to load.
batch_size (int): The batch size to feedforward for inference.
verbose (bool): If True, prints progress.
stats_file (str): The statistics file to load from if there is already one.
log_dir (str): Directory where feature statistics can be stored.
Returns:
float: Scalar FID score.
"""
start_time = time.time()
# Make sure the random seeds are fixed
# torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
# Setup directories
inception_path = os.path.join(log_dir, 'metrics', 'inception_model')
# Setup the inception graph
inception_utils.create_inception_graph(inception_path)
# Start producing statistics for real and fake images
if device is not None:
# Avoid unbounded memory usage
gpu_options = tf.compat.v1.GPUOptions(allow_growth=True,
per_process_gpu_memory_fraction=0.15,
visible_device_list=str(device))
config = tf.compat.v1.ConfigProto(gpu_options=gpu_options)
else:
config = tf.compat.v1.ConfigProto(device_count={'GPU': 0})
with tf.compat.v1.Session(config=config) as sess:
sess.run(tf.compat.v1.global_variables_initializer())
m_real, s_real = compute_real_dist_stats(num_samples=num_real_samples,
sess=sess,
dataset_name=dataset_name,
batch_size=batch_size,
verbose=verbose,
stats_file=stats_file,
log_dir=log_dir,
seed=seed)
m_fake, s_fake = compute_gen_dist_stats(netG=netG,
num_samples=num_fake_samples,
sess=sess,
device=device,
seed=seed,
batch_size=batch_size,
verbose=verbose)
FID_score = fid_utils.calculate_frechet_distance(mu1=m_real,
sigma1=s_real,
mu2=m_fake,
sigma2=s_fake)
print("INFO: FID Score: {} [Time Taken: {:.4f} secs]".format(
FID_score,
time.time() - start_time))
return float(FID_score)
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
"""
MegEngine interface for computing Inception Score.
"""
import os
import random
import time
import numpy as np
from .inception_model import inception_utils
from .inception_score import inception_score_utils as tf_inception_score
from .utils import _normalize_images
def inception_score(netG,
device,
num_samples,
batch_size=50,
splits=10,
log_dir='./log',
seed=0,
print_every=20):
"""
Computes the inception score of generated images.
Args:
netG (Module): The generator model to use for generating images.
device (Device): Torch device object to send model and data to.
num_samples (int): The number of samples to generate.
batch_size (int): Batch size per feedforward step for inception model.
splits (int): The number of splits to use for computing IS.
log_dir (str): Path to store metric computation objects.
seed (int): Random seed for generation.
Returns:
Mean and standard deviation of the inception score computed from using
num_samples generated images.
"""
# Make sure the random seeds are fixed
random.seed(seed)
np.random.seed(seed)
# Build inception
inception_path = os.path.join(log_dir, 'metrics/inception_model')
inception_utils.create_inception_graph(inception_path)
# Inference variables
batch_size = min(batch_size, num_samples)
num_batches = num_samples // batch_size
# Get images
images = []
start_time = time.time()
for idx in range(num_batches):
fake_images = netG.generate_images(num_images=batch_size).numpy()
fake_images = _normalize_images(fake_images) # NCHW(BGR) -> NHWC(RGB)
images.append(fake_images)
if (idx + 1) % min(print_every, num_batches) == 0:
end_time = time.time()
print(
"INFO: Generated image {}/{} [Random Seed {}] ({:.4f} sec/idx)"
.format(
(idx + 1) * batch_size, num_samples, seed,
(end_time - start_time) / (print_every * batch_size)))
start_time = end_time
images = np.concatenate(images, axis=0)
IS_score = tf_inception_score.get_inception_score(images,
splits=splits,
device=device)
print("INFO: IS Score: {}".format(IS_score))
return IS_score
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
"""
MegEngine interface for computing KID.
"""
import os
import random
import time
import numpy as np
import tensorflow as tf
from ..datasets.image_loader import get_dataset_images
from .inception_model import inception_utils
from .kid import kid_utils
from .utils import _normalize_images
def compute_real_dist_feat(num_samples,
sess,
dataset_name,
batch_size,
seed=0,
verbose=True,
feat_file=None,
log_dir='./log'):
"""
Reads the image data and compute the real image features.
Args:
num_samples (int): Number of real images to compute features.
sess (Session): TensorFlow session to use.
dataset_name (str): The name of the dataset to load.
batch_size (int): The batch size to feedforward for inference.
feat_file (str): The features file to load from if there is already one.
verbose (bool): If True, prints progress of computation.
log_dir (str): Directory where features can be stored.
Returns:
ndarray: Inception features of real images.
"""
# Create custom feat file name
if feat_file is None:
feat_dir = os.path.join(log_dir, 'metrics', 'kid', 'features')
if not os.path.exists(feat_dir):
os.makedirs(feat_dir)
feat_file = os.path.join(
feat_dir,
"kid_feat_{}_{}k_run_{}.npz".format(dataset_name,
num_samples // 1000, seed))
if feat_file and os.path.exists(feat_file):
print("INFO: Loading existing features for real images...")
f = np.load(feat_file)
real_feat = f['feat'][:]
f.close()
else:
# Obtain the numpy format data
print("INFO: Obtaining images...")
images = get_dataset_images(dataset_name, num_samples=num_samples)
# Compute the mean and cov
print("INFO: Computing features for real images...")
real_feat = inception_utils.get_activations(images=images,
sess=sess,
batch_size=batch_size,
verbose=verbose)
print("INFO: Saving features for real images...")
np.savez(feat_file, feat=real_feat)
return real_feat
def compute_gen_dist_feat(netG,
num_samples,
sess,
device,
seed,
batch_size,
print_every=20,
verbose=True):
"""
Directly produces the images and convert them into numpy format without
saving the images on disk.
Args:
netG (Module): Torch Module object representing the generator model.
num_samples (int): The number of fake images for computing features.
sess (Session): TensorFlow session to use.
device (str): Device identifier to use for computation.
seed (int): The random seed to use.
batch_size (int): The number of samples per batch for inference.
print_every (int): Interval for printing log.
verbose (bool): If True, prints progress.
Returns:
ndarray: Inception features of generated images.
"""
batch_size = min(num_samples, batch_size)
# Set model to evaluation mode
netG.eval()
# Collect num_samples of fake images
images = []
# Collect all samples
start_time = time.time()
for idx in range(num_samples // batch_size):
fake_images = netG.generate_images(num_images=batch_size).numpy()
# Collect fake image
images.append(fake_images)
# Print some statistics
if (idx + 1) % print_every == 0:
end_time = time.time()
print(
"INFO: Generated image {}/{} [Random Seed {}] ({:.4f} sec/idx)"
.format(
(idx + 1) * batch_size, num_samples, seed,
(end_time - start_time) / (print_every * batch_size)))
start_time = end_time
# Produce images in the required (N, H, W, 3) format for kid computation
images = np.concatenate(images, 0) # Gives (N, 3, H, W) BGR
images = _normalize_images(images) # Gives (N, H, W, 3) RGB
# Compute the kid
print("INFO: Computing features for fake images...")
fake_feat = inception_utils.get_activations(images=images,
sess=sess,
batch_size=batch_size,
verbose=verbose)
return fake_feat
def kid_score(num_subsets,
subset_size,
netG,
device,
seed,
dataset_name,
batch_size=50,
verbose=True,
feat_file=None,
log_dir='./log'):
"""
Computes KID score.
Args:
num_subsets (int): Number of subsets to compute average MMD.
subset_size (int): Size of subset for computing MMD.
netG (Module): Torch Module object representing the generator model.
device (str): Device identifier to use for computation.
seed (int): The random seed to use.
dataset_name (str): The name of the dataset to load.
batch_size (int): The batch size to feedforward for inference.
feat_file (str): The path to specific inception features for real images.
log_dir (str): Directory where features can be stored.
verbose (bool): If True, prints progress.
Returns:
tuple: Scalar mean and std of KID scores computed.
"""
start_time = time.time()
# Make sure the random seeds are fixed
random.seed(seed)
np.random.seed(seed)
# Directories
inception_path = os.path.join(log_dir, 'metrics', 'inception_model')
# Setup the inception graph
inception_utils.create_inception_graph(inception_path)
# Decide sample size
num_samples = int(num_subsets * subset_size)
# Start producing features for real and fake images
if device is not None:
# Avoid unbounded memory usage
gpu_options = tf.compat.v1.GPUOptions(allow_growth=True,
per_process_gpu_memory_fraction=0.15,
visible_device_list=str(device))
config = tf.compat.v1.ConfigProto(gpu_options=gpu_options)
else:
config = tf.compat.v1.ConfigProto(device_count={'GPU': 0})
with tf.compat.v1.Session(config=config) as sess:
sess.run(tf.compat.v1.global_variables_initializer())
real_feat = compute_real_dist_feat(num_samples=num_samples,
sess=sess,
dataset_name=dataset_name,
batch_size=batch_size,
verbose=verbose,
feat_file=feat_file,
log_dir=log_dir,
seed=seed)
fake_feat = compute_gen_dist_feat(netG=netG,
num_samples=num_samples,
sess=sess,
device=device,
seed=seed,
batch_size=batch_size,
verbose=verbose)
# Compute the KID score
scores = kid_utils.polynomial_mmd_averages(real_feat,
fake_feat,
n_subsets=num_subsets,
subset_size=subset_size)
mmd_score, mmd_std = float(np.mean(scores)), float(np.std(scores))
print("INFO: KID: {:.4f} ± {:.4f} [Time Taken: {:.4f} secs]".format(
mmd_score, mmd_std,
time.time() - start_time))
return mmd_score, mmd_std
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
"""
Computes different GAN metrics for a generator.
"""
import os
import numpy as np
from . import compute_fid, compute_is, compute_kid
from ..utils import common
def evaluate(metric,
netG,
log_dir,
evaluate_range=None,
evaluate_step=None,
num_runs=3,
start_seed=0,
overwrite=False,
write_to_json=True,
device=None,
**kwargs):
"""
Evaluates a generator over several runs.
Args:
metric (str): The name of the metric for evaluation.
netG (Module): Torch generator model to evaluate.
log_dir (str): The path to the log directory.
evaluate_range (tuple): The 3 valued tuple for defining a for loop.
evaluate_step (int): The specific checkpoint to load. Used in place of evaluate_range.
device (str): Device identifier to use for computation.
num_runs (int): The number of runs to compute FID for each checkpoint.
start_seed (int): Starting random seed to use.
write_to_json (bool): If True, writes to an output json file in log_dir.
overwrite (bool): If True, then overwrites previous metric score.
Returns:
None
"""
if metric == 'kid':
if 'num_subsets' not in kwargs or 'subset_size' not in kwargs:
raise ValueError(
"num_subsets and subset_size must be provided for KID computation.")
elif metric == 'fid':
if 'num_real_samples' not in kwargs or 'num_fake_samples' not in kwargs:
raise ValueError(
"num_real_samples and num_fake_samples must be provided for FID computation.")
elif metric == 'inception_score':
if 'num_samples' not in kwargs:
raise ValueError("num_samples must be provided for IS computation.")
else:
choices = ['fid', 'kid', 'inception_score']
raise ValueError("Invalid metric {} selected. Choose from {}.".format(metric, choices))
if evaluate_range and evaluate_step or not (evaluate_step
or evaluate_range):
raise ValueError(
"Only one of evaluate_step or evaluate_range can be defined.")
if evaluate_range:
if (type(evaluate_range) != tuple
or not all(map(lambda x: type(x) == int, evaluate_range))):
raise ValueError(
"evaluate_range must be a tuple of ints (start, end, step).")
ckpt_dir = os.path.join(log_dir, 'checkpoints', 'netG')
if not os.path.exists(ckpt_dir):
raise ValueError(
"Checkpoint directory {} cannot be found in log_dir.".format(
ckpt_dir))
# Decide naming convention
names_dict = {
'fid': 'FID',
'inception_score': 'Inception Score',
'kid': 'KID',
}
# Set output file and restore if available.
if metric == 'fid':
output_file = os.path.join(
log_dir,
'fid_{}k_{}k.json'.format(kwargs['num_real_samples'] // 1000,
kwargs['num_fake_samples'] // 1000))
elif metric == 'inception_score':
output_file = os.path.join(
log_dir,
'inception_score_{}k.json'.format(kwargs['num_samples'] // 1000))
elif metric == 'kid':
output_file = os.path.join(
log_dir, 'kid_{}k_{}_subsets.json'.format(
kwargs['num_subsets'] * kwargs['subset_size'] // 1000,
kwargs['num_subsets']))
if os.path.exists(output_file):
scores_dict = common.load_from_json(output_file)
scores_dict = dict([(int(k), v) for k, v in scores_dict.items()])
else:
scores_dict = {}
# Evaluate across a range
start, end, interval = evaluate_range or (evaluate_step, evaluate_step,
evaluate_step)
for step in range(start, end + 1, interval):
# Skip computed scores
if step in scores_dict and write_to_json and not overwrite:
print("INFO: {} at step {} has been computed. Skipping...".format(
names_dict[metric], step))
continue
# Load and restore the model checkpoint
ckpt_file = os.path.join(ckpt_dir, 'netG_{}_steps.pth'.format(step))
if not os.path.exists(ckpt_file):
print("INFO: Checkpoint at step {} does not exist. Skipping...".
format(step))
continue
netG.restore_checkpoint(ckpt_file=ckpt_file, optimizer=None)
# Compute score for each seed
scores = []
for seed in range(start_seed, start_seed + num_runs):
print("INFO: Computing {} in memory...".format(names_dict[metric]))
# Obtain only the raw score without var
if metric == "fid":
score = compute_fid.fid_score(netG=netG,
seed=seed,
device=device,
log_dir=log_dir,
**kwargs)
elif metric == "inception_score":
score, _ = compute_is.inception_score(netG=netG,
seed=seed,
device=device,
log_dir=log_dir,
**kwargs)
elif metric == "kid":
score, _ = compute_kid.kid_score(netG=netG,
device=device,
seed=seed,
log_dir=log_dir,
**kwargs)
scores.append(score)
print("INFO: {} (step {}) [seed {}]: {}".format(
names_dict[metric], step, seed, score))
scores_dict[step] = scores
# Print the scores in order
for step in range(start, end + 1, interval):
if step in scores_dict:
scores = scores_dict[step]
mean = np.mean(scores)
std = np.std(scores)
print("INFO: {} (step {}): {} (± {}) ".format(
names_dict[metric], step, mean, std))
# Save to output file
if write_to_json:
common.write_to_json(scores_dict, output_file)
print("INFO: {} Evaluation completed!".format(names_dict[metric]))
return scores_dict
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
"""
Helper functions for calculating FID as adopted from the official FID code:
https://github.com/kwotsin/dissertation/blob/master/eval/TTUR/fid.py
"""
import numpy as np
from scipy import linalg
from ..inception_model import inception_utils
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
"""
Numpy implementation of the Frechet Distance.
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
and X_2 ~ N(mu_2, C_2) is
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
Stable version by Dougal J. Sutherland.
Args:
mu1 : Numpy array containing the activations of the pool_3 layer of the
inception net ( like returned by the function 'get_predictions')
for generated samples.
mu2: The sample mean over activations of the pool_3 layer, precalcualted
on an representive data set.
sigma1 (ndarray): The covariance matrix over activations of the pool_3 layer for
generated samples.
sigma2: The covariance matrix over activations of the pool_3 layer,
precalcualted on an representive data set.
Returns:
np.float64: The Frechet Distance.
"""
if mu1.shape != mu2.shape or sigma1.shape != sigma2.shape:
raise ValueError(
"(mu1, sigma1) should have exactly the same shape as (mu2, sigma2)."
)
mu1 = np.atleast_1d(mu1)
mu2 = np.atleast_1d(mu2)
sigma1 = np.atleast_2d(sigma1)
sigma2 = np.atleast_2d(sigma2)
diff = mu1 - mu2
# Product might be almost singular
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
if not np.isfinite(covmean).all():
print(
"WARNING: fid calculation produces singular product; adding {} to diagonal of cov estimates"
.format(eps))
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
# Numerical error might give slight imaginary component
if np.iscomplexobj(covmean):
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
m = np.max(np.abs(covmean.imag))
raise ValueError("Imaginary component {}".format(m))
covmean = covmean.real
tr_covmean = np.trace(covmean)
return diff.dot(diff) + np.trace(sigma1) + np.trace(
sigma2) - 2 * tr_covmean
def calculate_activation_statistics(images, sess, batch_size=50, verbose=True):
"""
Calculation of the statistics used by the FID.
Args:
images (ndarray): Numpy array of shape (N, H, W, 3) and values in
the range [0, 255].
sess (Session): TensorFlow session object.
batch_size (int): Batch size for inference.
verbose (bool): If True, prints out logging information.
Returns:
ndarray: Mean of inception features from samples.
ndarray: Covariance of inception features from samples.
"""
act = inception_utils.get_activations(images, sess, batch_size, verbose)
mu = np.mean(act, axis=0)
sigma = np.cov(act, rowvar=False)
return mu, sigma
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
"""
Common inception utils for computing metrics, as based on the FID helper code:
https://github.com/kwotsin/dissertation/blob/master/eval/TTUR/fid.py
"""
import os
import pathlib
import tarfile
import time
from urllib import request
import numpy as np
import tensorflow as tf
def _check_or_download_inception(inception_path):
"""
Checks if the path to the inception file is valid, or downloads
the file if it is not present.
Args:
inception_path (str): Directory for storing the inception model.
Returns:
str: File path of the inception protobuf model.
"""
# Build file path of model
inception_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
if inception_path is None:
inception_path = '/tmp'
inception_path = pathlib.Path(inception_path)
model_file = inception_path / 'classify_image_graph_def.pb'
# Download model if required
if not model_file.exists():
print("Downloading Inception model")
fn, _ = request.urlretrieve(inception_url)
with tarfile.open(fn, mode='r') as f:
f.extract('classify_image_graph_def.pb', str(model_file.parent))
return str(model_file)
def _get_inception_layer(sess):
"""
Prepares inception net for batched usage and returns pool_3 layer.
Args:
sess (Session): TensorFlow Session object.
Returns:
TensorFlow graph node representing inception model pool3 layer output.
"""
# Get the output node
layer_name = 'inception_model/pool_3:0'
pool3 = sess.graph.get_tensor_by_name(layer_name)
# Reshape to be batch size agnostic
ops = pool3.graph.get_operations()
for op_idx, op in enumerate(ops):
for o in op.outputs:
shape = o.get_shape()
if len(shape._dims) > 0:
try:
shape = [s.value for s in shape]
except AttributeError: # TF 2 uses None shape directly. No conversion needed.
shape = shape
new_shape = []
for j, s in enumerate(shape):
if s == 1 and j == 0:
new_shape.append(None)
else:
new_shape.append(s)
o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
return pool3
def get_activations(images, sess, batch_size=50, verbose=True):
"""
Calculates the activations of the pool_3 layer for all images.
Args:
images (ndarray): Numpy array of shape (N, C, H, W) with values ranging
in the range [0, 255].
sess (Session): TensorFlow Session object.
batch_size (int): The batch size to use for inference.
verbose (bool): If True, prints out logging data for batch inference.
Returns:
ndarray: Numpy array of shape (N, 2048) representing the pool3 features from the
inception model.
"""
# Get output layer.
inception_layer = _get_inception_layer(sess)
# Inference variables
batch_size = min(batch_size, images.shape[0])
num_batches = images.shape[0] // batch_size
# Get features
pred_arr = np.empty((images.shape[0], 2048))
for i in range(num_batches):
start_time = time.time()
start = i * batch_size
end = start + batch_size
batch = images[start:end]
pred = sess.run(inception_layer,
{'inception_model/ExpandDims:0': batch})
pred_arr[start:end] = pred.reshape(batch_size, -1)
if verbose:
print("\rINFO: Propagated batch %d/%d (%.4f sec/batch)" \
% (i+1, num_batches, time.time()-start_time), end="", flush=True)
return pred_arr
def create_inception_graph(inception_path):
"""
Creates a graph from saved GraphDef file.
Args:
inception_path (str): Directory for storing the inception model.
Returns:
None
"""
if not os.path.exists(inception_path):
os.makedirs(inception_path)
# Get inception model file path
model_file = _check_or_download_inception(inception_path)
# Creates graph from saved graph_def.pb.
with tf.io.gfile.GFile(model_file, 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='inception_model')
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
"""
Helper functions for computing inception score, as based on:
https://github.com/openai/improved-gan/tree/master/inception_score
"""
import time
import numpy as np
import tensorflow as tf
from ..inception_model import inception_utils
def get_predictions(images, device=None, batch_size=50, print_every=20):
"""
Get the output probabilities of images.
Args:
images (ndarray): Batch of images of shape (N, H, W, 3).
device (Device): Torch device object.
batch_size (int): Batch size for inference using inception model.
print_every (int): Prints logging variable every n batch inferences.
Returns:
ndarray: Batch of probabilities of equal size as number of images input.
"""
if device is not None:
# Avoid unbounded memory usage
gpu_options = tf.compat.v1.GPUOptions(allow_growth=True,
per_process_gpu_memory_fraction=0.15,
visible_device_list=str(device))
config = tf.compat.v1.ConfigProto(gpu_options=gpu_options)
else:
config = tf.compat.v1.ConfigProto(device_count={'GPU': 0})
# Inference variables
batch_size = min(batch_size, images.shape[0])
num_batches = images.shape[0] // batch_size
# Get predictions
preds = []
with tf.compat.v1.Session(config=config) as sess:
# Batch input preparation
inception_utils._get_inception_layer(sess)
# Define input/outputs of default graph.
pool3 = sess.graph.get_tensor_by_name('inception_model/pool_3:0')
w = sess.graph.get_operation_by_name(
"inception_model/softmax/logits/MatMul").inputs[1]
logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
softmax = tf.nn.softmax(logits)
# Predict images
start_time = time.time()
for i in range(num_batches):
batch = images[i * batch_size:(i + 1) * batch_size]
# curr_image = np.expand_dims(images[i], axis=0)
pred = sess.run(softmax, {'inception_model/ExpandDims:0': batch})
preds.append(pred)
if (i + 1) % min(print_every, num_batches) == 0:
end_time = time.time()
print("INFO: Processed image {}/{}...({:.4f} sec/idx)".format(
(i + 1) * batch_size, images.shape[0],
(end_time - start_time) / (print_every * batch_size)))
start_time = end_time
preds = np.concatenate(preds, 0)
return preds
def get_inception_score(images, splits=10, device=None):
"""
Computes inception score according to official OpenAI implementation.
Args:
images (ndarray): Batch of images of shape (N, H, W, 3), which should have values
in the range [0, 255].
splits (int): Number of splits to use for computing IS.
device (Device): Torch device object to decide which GPU to use for TF session.
Returns:
tuple: Tuple of mean and standard deviation of the inception score computed.
"""
if np.max(images[0] < 10) and np.max(images[0] < 0):
raise ValueError("Images should have value ranging from 0 to 255.")
# Load graph and get probabilities
preds = get_predictions(images, device=device)
# Compute scores
N = preds.shape[0]
scores = []
for i in range(splits):
part = preds[(i * N // splits):((i + 1) * N // splits), :]
kl = part * (np.log(part) -
np.log(np.expand_dims(np.mean(part, 0), 0)))
kl = np.mean(np.sum(kl, 1))
scores.append(np.exp(kl))
return float(np.mean(scores)), float(np.std(scores))
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
"""
Helper functions for computing FID, as based on:
https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py
"""
import numpy as np
from sklearn.metrics.pairwise import polynomial_kernel
def polynomial_mmd(codes_g, codes_r, degree=3, gamma=None, coef0=1):
"""
Compute MMD between two sets of features.
Polynomial kernel given by:
K(X, Y) = (gamma <X, Y> + coef0)^degree
Args:
codes_g (ndarray): Set of features from 1st distribution.
codes_r (ndarray): Set of features from 2nd distribution.
degree (int): Power of the kernel.
gamma (float): Scaling factor of dot product.
coeff0 (float): Constant factor of kernel.
Returns:
np.float64: Scalar MMD score between features of 2 distributions.
"""
X = codes_g
Y = codes_r
K_XX = polynomial_kernel(X, degree=degree, gamma=gamma, coef0=coef0)
K_YY = polynomial_kernel(Y, degree=degree, gamma=gamma, coef0=coef0)
K_XY = polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=coef0)
return _compute_mmd2(K_XX, K_XY, K_YY)
def polynomial_mmd_averages(codes_g,
codes_r,
n_subsets=50,
subset_size=1000,
**kernel_args):
"""
Computes average MMD between two set of features using n_subsets,
each of which is of subset_size.
Args:
codes_g (ndarray): Set of features from 1st distribution.
codes_r (ndarray): Set of features from 2nd distribution.
n_subsets (int): Number of subsets to compute averages.
subset_size (int): Size of each subset of features to choose.
Returns:
list: List of n_subsets MMD scores.
"""
m = min(codes_g.shape[0], codes_r.shape[0])
mmds = np.zeros(n_subsets)
# Account for inordinately small subset sizes
n_subsets = min(m, n_subsets)
subset_size = min(subset_size, m // n_subsets)
for i in range(n_subsets):
g = codes_g[np.random.choice(len(codes_g), subset_size, replace=False)]
r = codes_r[np.random.choice(len(codes_r), subset_size, replace=False)]
o = polynomial_mmd(g, r, **kernel_args)
mmds[i] = o
return mmds
def _sqn(arr):
flat = np.ravel(arr)
return flat.dot(flat)
def _compute_mmd2(K_XX,
K_XY,
K_YY,
unit_diagonal=False,
mmd_est='unbiased'):
"""
Based on https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd.py
but changed to not compute the full kernel matrix at once.
"""
if mmd_est not in ['unbiased', 'u-statistic']:
raise ValueError(
"mmd_est should be one of [unbiased', 'u-statistic] but got {}.".
format(mmd_est))
m = K_XX.shape[0]
if K_XX.shape != (m, m):
raise ValueError("K_XX shape should be {} but got {} instead.".format(
(m, m), K_XX.shape))
if K_XY.shape != (m, m):
raise ValueError("K_XX shape should be {} but got {} instead.".format(
(m, m), K_XY.shape))
if K_YY.shape != (m, m):
raise ValueError("K_XX shape should be {} but got {} instead.".format(
(m, m), K_YY.shape))
# Get the various sums of kernels that we'll use
# Kts drop the diagonal, but we don't need to compute them explicitly
if unit_diagonal:
diag_X = diag_Y = 1
sum_diag_X = sum_diag_Y = m
sum_diag2_X = sum_diag2_Y = m
else:
diag_X = np.diagonal(K_XX)
diag_Y = np.diagonal(K_YY)
sum_diag_X = diag_X.sum()
sum_diag_Y = diag_Y.sum()
sum_diag2_X = _sqn(diag_X)
sum_diag2_Y = _sqn(diag_Y)
Kt_XX_sums = K_XX.sum(axis=1) - diag_X
Kt_YY_sums = K_YY.sum(axis=1) - diag_Y
K_XY_sums_0 = K_XY.sum(axis=0)
K_XY_sums_1 = K_XY.sum(axis=1)
Kt_XX_sum = Kt_XX_sums.sum()
Kt_YY_sum = Kt_YY_sums.sum()
K_XY_sum = K_XY_sums_0.sum()
if mmd_est == 'biased':
mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m) + (Kt_YY_sum + sum_diag_Y) /
(m * m) - 2 * K_XY_sum / (m * m))
else:
mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m - 1))
if mmd_est == 'unbiased':
mmd2 -= 2 * K_XY_sum / (m * m)
else:
mmd2 -= 2 * (K_XY_sum - np.trace(K_XY)) / (m * (m - 1))
return mmd2
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
import numpy as np
def _normalize_images(images):
"""
Given a tensor of (megengine BGR) images, uses the torchvision
normalization method to convert floating point data to integers. See reference
at: https://pytorch.org/docs/stable/_modules/torchvision/utils.html#save_image
The function uses the normalization from make_grid and save_image functions.
Args:
images (Tensor): Batch of images of shape (N, 3, H, W).
Returns:
ndarray: Batch of normalized (0-255) RGB images of shape (N, H, W, 3).
"""
# Shift the image from [-1, 1] range to [0, 1] range.
min_val = float(images.min())
max_val = float(images.max())
images = (images - min_val) / (max_val - min_val + 1e-5)
images = np.clip(images * 255 + 0.5, 0, 255).astype("uint8")
images = np.transpose(images, [0, 2, 3, 1])
# NOTE: megengine(opencv) uses BGR, while TF uses RGB. Needs conversion.
images = images[:, :, :, ::-1]
return images
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
import os
from abc import abstractmethod
import megengine
import megengine.jit as jit
import megengine.module as M
import numpy as np
class BaseModel(M.Module):
def __init__(self):
super().__init__()
self.train_step = self._reset_jit_graph(self._train_step_implementation)
self.infer_step = self._reset_jit_graph(self._infer_step_implementation)
def _reset_jit_graph(self, impl: callable):
"""create a `jit.trace` object based on abstract graph implementation"""
return jit.trace(impl)
@abstractmethod
def _train_step_implementation(self, *args, **kwargs):
"""Abstract train step function, traced at the beginning of training.
A typical implementation for a classifier could be
```
class Classifier(BaseModel):
def _train_step_implementation(
self,
image: Tensor,
label: Tensor,
opt: Optimizer = None
):
logits = self.forward(image)
loss = F.cross_entropy_with_softmax(logits, label)
if opt is not None:
opt.zero_grad()
opt.backward(loss)
opt.step()
```
This implementation is wrapped in a `megengine.jit.trace` object, which equals to
something like
```
@jit.trace
def train_step(image, label, opt=None):
return _train_step_implemenation(image, label, opt=opt)
```
And we call `model.train_step(np_image, np_label, opt=sgd_optimizer)` to
perform the wrapped training step.
"""
raise NotImplementedError
@abstractmethod
def _infer_step_implementation(self, *args, **kwargs):
"""Abstract infer step function, traced at the beginning of inference.
See document of `_train_step_implementation`.
"""
raise NotImplementedError
def train(self, mode: bool = True):
# when switching mode, graph should be reset
self.train_step = self._reset_jit_graph(self._train_step_implementation)
self.infer_step = self._reset_jit_graph(self._infer_step_implementation)
super().train(mode=mode)
def count_params(self):
r"""
Computes the number of parameters in this model.
Args: None
Returns:
int: Total number of weight parameters for this model.
int: Total number of trainable parameters for this model.
"""
num_total_params = sum(np.prod(p.shape) for p in self.parameters())
num_trainable_params = sum(np.prod(p.shape) for p in self.parameters(requires_grad=True))
return num_total_params, num_trainable_params
def restore_checkpoint(self, ckpt_file, optimizer=None):
r"""
Restores checkpoint from a pth file and restores optimizer state.
Args:
ckpt_file (str): A PyTorch pth file containing model weights.
optimizer (Optimizer): A vanilla optimizer to have its state restored from.
Returns:
int: Global step variable where the model was last checkpointed.
"""
if not ckpt_file:
raise ValueError("No checkpoint file to be restored.")
ckpt_dict = megengine.load(ckpt_file)
# Restore model weights
self.load_state_dict(ckpt_dict['model_state_dict'])
# Restore optimizer status if existing. Evaluation doesn't need this
if optimizer:
optimizer.load_state_dict(ckpt_dict['optimizer_state_dict'])
# Return global step
return ckpt_dict['global_step']
def save_checkpoint(self,
directory,
global_step,
optimizer=None,
name=None):
r"""
Saves checkpoint at a certain global step during training. Optimizer state
is also saved together.
Args:
directory (str): Path to save checkpoint to.
global_step (int): The global step variable during training.
optimizer (Optimizer): Optimizer state to be saved concurrently.
name (str): The name to save the checkpoint file as.
Returns:
None
"""
# Create directory to save to
if not os.path.exists(directory):
os.makedirs(directory)
# Build checkpoint dict to save.
ckpt_dict = {
'model_state_dict':
self.state_dict(),
'optimizer_state_dict':
optimizer.state_dict() if optimizer is not None else None,
'global_step':
global_step
}
# Save the file with specific name
if name is None:
name = "{}_{}_steps.pth".format(
os.path.basename(directory), # netD or netG
global_step)
megengine.save(ckpt_dict, os.path.join(directory, name))
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
import math
import megengine.functional as F
import megengine.module as M
class GBlock(M.Module):
r"""
Residual block for generator.
Uses bilinear (rather than nearest) interpolation, and align_corners
set to False. This is as per how torchvision does upsampling, as seen in:
https://github.com/pytorch/vision/blob/master/torchvision/models/segmentation/_utils.py
Attributes:
in_channels (int): The channel size of input feature map.
out_channels (int): The channel size of output feature map.
hidden_channels (int): The channel size of intermediate feature maps.
upsample (bool): If True, upsamples the input feature map.
num_classes (int): If more than 0, uses conditional batch norm instead.
spectral_norm (bool): If True, uses spectral norm for convolutional layers.
"""
def __init__(self,
in_channels,
out_channels,
hidden_channels=None,
upsample=False):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels if hidden_channels is not None else out_channels
self.learnable_sc = in_channels != out_channels or upsample
self.upsample = upsample
self.c1 = M.Conv2d(self.in_channels,
self.hidden_channels,
3,
1,
padding=1)
self.c2 = M.Conv2d(self.hidden_channels,
self.out_channels,
3,
1,
padding=1)
self.b1 = M.BatchNorm2d(self.in_channels)
self.b2 = M.BatchNorm2d(self.hidden_channels)
self.activation = M.ReLU()
M.init.xavier_uniform_(self.c1.weight, math.sqrt(2.0))
M.init.xavier_uniform_(self.c2.weight, math.sqrt(2.0))
# Shortcut layer
if self.learnable_sc:
self.c_sc = M.Conv2d(in_channels,
out_channels,
1,
1,
padding=0)
M.init.xavier_uniform_(self.c_sc.weight, 1.0)
def _upsample_conv(self, x, conv):
r"""
Helper function for performing convolution after upsampling.
"""
return conv(
F.interpolate(x,
scale_factor=2,
mode='bilinear',
align_corners=False))
def _residual(self, x):
r"""
Helper function for feedforwarding through main layers.
"""
h = x
h = self.b1(h)
h = self.activation(h)
h = self._upsample_conv(h, self.c1) if self.upsample else self.c1(h)
h = self.b2(h)
h = self.activation(h)
h = self.c2(h)
return h
def _shortcut(self, x):
r"""
Helper function for feedforwarding through shortcut layers.
"""
if self.learnable_sc:
x = self._upsample_conv(
x, self.c_sc) if self.upsample else self.c_sc(x)
return x
else:
return x
def forward(self, x):
r"""
Residual block feedforward function.
"""
return self._residual(x) + self._shortcut(x)
class DBlock(M.Module):
"""
Residual block for discriminator.
Attributes:
in_channels (int): The channel size of input feature map.
out_channels (int): The channel size of output feature map.
hidden_channels (int): The channel size of intermediate feature maps.
downsample (bool): If True, downsamples the input feature map.
spectral_norm (bool): If True, uses spectral norm for convolutional layers.
"""
def __init__(self,
in_channels,
out_channels,
hidden_channels=None,
downsample=False):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels if hidden_channels is not None else in_channels
self.downsample = downsample
self.learnable_sc = (in_channels != out_channels) or downsample
# Build the layers
self.c1 = M.Conv2d(self.in_channels, self.hidden_channels, 3, 1,
1)
self.c2 = M.Conv2d(self.hidden_channels, self.out_channels, 3, 1,
1)
self.activation = M.ReLU()
M.init.xavier_uniform_(self.c1.weight, math.sqrt(2.0))
M.init.xavier_uniform_(self.c2.weight, math.sqrt(2.0))
# Shortcut layer
if self.learnable_sc:
self.c_sc = M.Conv2d(in_channels, out_channels, 1, 1, 0)
M.init.xavier_uniform_(self.c_sc.weight, 1.0)
def _residual(self, x):
"""
Helper function for feedforwarding through main layers.
"""
h = x
h = self.activation(h)
h = self.c1(h)
h = self.activation(h)
h = self.c2(h)
if self.downsample:
h = F.avg_pool2d(h, 2)
return h
def _shortcut(self, x):
"""
Helper function for feedforwarding through shortcut layers.
"""
if self.learnable_sc:
x = self.c_sc(x)
return F.avg_pool2d(x, 2) if self.downsample else x
else:
return x
def forward(self, x):
"""
Residual block feedforward function.
"""
# NOTE: to completely reproduce pytorch, we use F.relu(x) to replace x in shortcut
# since pytorch use inplace relu in residual branch.
return self._residual(x) + self._shortcut(F.relu(x))
class DBlockOptimized(M.Module):
"""
Optimized residual block for discriminator. This is used as the first residual block,
where there is a definite downsampling involved. Follows the official SNGAN reference implementation
in chainer.
Attributes:
in_channels (int): The channel size of input feature map.
out_channels (int): The channel size of output feature map.
spectral_norm (bool): If True, uses spectral norm for convolutional layers.
"""
def __init__(self, in_channels, out_channels, spectral_norm=False):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.spectral_norm = spectral_norm
# Build the layers
self.c1 = M.Conv2d(self.in_channels, self.out_channels, 3, 1, 1)
self.c2 = M.Conv2d(self.out_channels, self.out_channels, 3, 1, 1)
self.c_sc = M.Conv2d(self.in_channels, self.out_channels, 1, 1, 0)
self.activation = M.ReLU()
M.init.xavier_uniform_(self.c1.weight, math.sqrt(2.0))
M.init.xavier_uniform_(self.c2.weight, math.sqrt(2.0))
M.init.xavier_uniform_(self.c_sc.weight, 1.0)
def _residual(self, x):
"""
Helper function for feedforwarding through main layers.
"""
h = x
h = self.c1(h)
h = self.activation(h)
h = self.c2(h)
h = F.avg_pool2d(h, 2)
return h
def _shortcut(self, x):
"""
Helper function for feedforwarding through shortcut layers.
"""
return self.c_sc(F.avg_pool2d(x, 2))
def forward(self, x):
"""
Residual block feedforward function.
"""
return self._residual(x) + self._shortcut(x)
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
from .. import gan
class DCGANBaseGenerator(gan.BaseGenerator):
r"""
ResNet backbone generator for ResNet DCGAN.
Attributes:
nz (int): Noise dimension for upsampling.
ngf (int): Variable controlling generator feature map sizes.
bottom_width (int): Starting width for upsampling generator output to an image.
loss_type (str): Name of loss to use for GAN loss.
"""
def __init__(self, nz, ngf, bottom_width, loss_type='ns', **kwargs):
super().__init__(nz=nz,
ngf=ngf,
bottom_width=bottom_width,
loss_type=loss_type,
**kwargs)
class DCGANBaseDiscriminator(gan.BaseDiscriminator):
r"""
ResNet backbone discriminator for ResNet DCGAN.
Attributes:
ndf (int): Variable controlling discriminator feature map sizes.
loss_type (str): Name of loss to use for GAN loss.
"""
def __init__(self, ndf, loss_type='ns', **kwargs):
super().__init__(ndf=ndf, loss_type=loss_type, **kwargs)
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
import megengine.functional as F
import megengine.module as M
from ..blocks import DBlock, DBlockOptimized, GBlock
from . import dcgan_base
class DCGANGeneratorCIFAR(dcgan_base.DCGANBaseGenerator):
r"""
ResNet backbone generator for ResNet DCGAN.
Attributes:
nz (int): Noise dimension for upsampling.
ngf (int): Variable controlling generator feature map sizes.
bottom_width (int): Starting width for upsampling generator output to an image.
loss_type (str): Name of loss to use for GAN loss.
"""
def __init__(self, nz=128, ngf=256, bottom_width=4, **kwargs):
super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs)
# Build the layers
self.l1 = M.Linear(self.nz, (self.bottom_width**2) * self.ngf)
self.block2 = GBlock(self.ngf, self.ngf, upsample=True)
self.block3 = GBlock(self.ngf, self.ngf, upsample=True)
self.block4 = GBlock(self.ngf, self.ngf, upsample=True)
self.b5 = M.BatchNorm2d(self.ngf)
self.c5 = M.Conv2d(self.ngf, 3, 3, 1, padding=1)
self.activation = M.ReLU()
# Initialise the weights
M.init.xavier_uniform_(self.l1.weight, 1.0)
M.init.xavier_uniform_(self.c5.weight, 1.0)
def forward(self, x):
r"""
Feedforwards a batch of noise vectors into a batch of fake images.
Args:
x (Tensor): A batch of noise vectors of shape (N, nz).
Returns:
Tensor: A batch of fake images of shape (N, C, H, W).
"""
h = self.l1(x)
h = h.reshape(x.shape[0], -1, self.bottom_width, self.bottom_width)
h = self.block2(h)
h = self.block3(h)
h = self.block4(h)
h = self.b5(h)
h = self.activation(h)
h = F.sigmoid(self.c5(h)) # sigmoid instead of tanh
return h
class DCGANDiscriminatorCIFAR(dcgan_base.DCGANBaseDiscriminator):
r"""
ResNet backbone discriminator for ResNet DCGAN.
Attributes:
ndf (int): Variable controlling discriminator feature map sizes.
loss_type (str): Name of loss to use for GAN loss.
"""
def __init__(self, ndf=128, **kwargs):
super().__init__(ndf=ndf, **kwargs)
# Build layers
self.block1 = DBlockOptimized(3, self.ndf)
self.block2 = DBlock(self.ndf,
self.ndf,
downsample=True)
self.block3 = DBlock(self.ndf,
self.ndf,
downsample=False)
self.block4 = DBlock(self.ndf,
self.ndf,
downsample=False)
self.l5 = M.Linear(self.ndf, 1)
self.activation = M.ReLU()
# Initialise the weights
M.init.xavier_uniform_(self.l5.weight, 1.0)
def forward(self, x):
r"""
Feedforwards a batch of real/fake images and produces a batch of GAN logits.
Args:
x (Tensor): A batch of images of shape (N, C, H, W).
Returns:
Tensor: A batch of GAN logits of shape (N, 1).
"""
h = x
h = self.block1(h)
h = self.block2(h)
h = self.block3(h)
h = self.block4(h)
h = self.activation(h)
# Global sum pooling
h = h.sum(3).sum(2)
output = self.l5(h)
return output
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
"""
Implementation of Base GAN models.
"""
import megengine
import megengine.functional as F
import megengine.module as M
import megengine.random as R
import numpy as np
from . import losses
from .basemodel import BaseModel
class BaseGenerator(BaseModel):
r"""
Base class for a generic unconditional generator model.
Attributes:
nz (int): Noise dimension for upsampling.
ngf (int): Variable controlling generator feature map sizes.
bottom_width (int): Starting width for upsampling generator output to an image.
loss_type (str): Name of loss to use for GAN loss.
"""
def __init__(self, nz, ngf, bottom_width, loss_type, **kwargs):
super().__init__(**kwargs)
self.nz = nz
self.ngf = ngf
self.bottom_width = bottom_width
self.loss_type = loss_type
def _train_step_implementation(
self,
real_batch,
netD=None,
optG=None):
# Produce fake images
fake_images = self._infer_step_implementation(real_batch)
# Compute output logit of D thinking image real
output = netD(fake_images)
# Compute loss
errG = self.compute_gan_loss(output=output)
optG.zero_grad()
optG.backward(errG)
optG.step()
return errG
def _infer_step_implementation(self, batch):
# Get only batch size from real batch
batch_size = batch.shape[0]
noise = R.gaussian(shape=[batch_size, self.nz])
fake_images = self.forward(noise)
return fake_images
def compute_gan_loss(self, output):
if self.loss_type == "ns":
errG = losses.ns_loss_gen(output)
elif self.loss_type == "wasserstein":
errG = losses.wasserstein_loss_gen(output)
else:
raise ValueError("Invalid loss_type {} selected.".format(
self.loss_type))
return errG
def generate_images(self, num_images):
"""Generate images of shape [`num_images`, C, H, W].
Depending on the final activation function, pixel values are NOT guarenteed
to be within [0, 1].
"""
return self.infer_step(np.empty(num_images, dtype="float32"))
class BaseDiscriminator(BaseModel):
r"""
Base class for a generic unconditional discriminator model.
Attributes:
ndf (int): Variable controlling discriminator feature map sizes.
loss_type (str): Name of loss to use for GAN loss.
"""
def __init__(self, ndf, loss_type, **kwargs):
super().__init__(**kwargs)
self.ndf = ndf
self.loss_type = loss_type
def _train_step_implementation(
self,
real_batch,
netG=None,
optD=None):
# Produce logits for real images
output_real = self._infer_step_implementation(real_batch)
# Produce fake images
fake_images = netG._infer_step_implementation(real_batch)
fake_images = F.zero_grad(fake_images)
# Produce logits for fake images
output_fake = self._infer_step_implementation(fake_images)
# Compute loss for D
errD = self.compute_gan_loss(output_real=output_real,
output_fake=output_fake)
D_x, D_Gz = self.compute_probs(output_real=output_real,
output_fake=output_fake)
# Backprop and update gradients
optD.zero_grad()
optD.backward(errD)
optD.step()
return errD, D_x, D_Gz
def _infer_step_implementation(self, batch):
return self.forward(batch)
def compute_gan_loss(self, output_real, output_fake):
r"""
Computes GAN loss for discriminator.
Args:
output_real (Tensor): A batch of output logits of shape (N, 1) from real images.
output_fake (Tensor): A batch of output logits of shape (N, 1) from fake images.
Returns:
errD (Tensor): A batch of GAN losses for the discriminator.
"""
# Compute loss for D
if self.loss_type == "gan" or self.loss_type == "ns":
errD = losses.minimax_loss_dis(output_fake=output_fake,
output_real=output_real)
elif self.loss_type == "wasserstein":
errD = losses.wasserstein_loss_dis(output_fake=output_fake,
output_real=output_real)
else:
raise ValueError("Invalid loss_type selected.")
return errD
def compute_probs(self, output_real, output_fake):
r"""
Computes probabilities from real/fake images logits.
Args:
output_real (Tensor): A batch of output logits of shape (N, 1) from real images.
output_fake (Tensor): A batch of output logits of shape (N, 1) from fake images.
Returns:
tuple: Average probabilities of real/fake image considered as real for the batch.
"""
D_x = F.sigmoid(output_real).mean()
D_Gz = F.sigmoid(output_fake).mean()
return D_x, D_Gz
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
import megengine.functional as F
from megengine.core.tensor_factory import zeros
def ns_loss_gen(output_fake):
r"""
Non-saturating loss for generator.
Args:
output_fake (Tensor): Discriminator output logits for fake images.
Returns:
Tensor: A scalar tensor loss output.
"""
output_fake = F.sigmoid(output_fake)
return -F.log(output_fake + 1e-8).mean()
# def ns_loss_gen(output_fake):
# """numerical stable version"""
# return F.log(1 + F.exp(-output_fake)).mean()
def _bce_loss_with_logits(output, labels, **kwargs):
r"""
Sigmoid cross entropy with logits, see tensorflow
https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
"""
loss = F.maximum(output, 0) - output * labels + F.log(1 + F.exp(-F.abs(output)))
return loss.mean()
def minimax_loss_dis(output_fake,
output_real,
real_label_val=1.0,
fake_label_val=0.0,
**kwargs):
r"""
Standard minimax loss for GANs through the BCE Loss with logits fn.
Args:
output_fake (Tensor): Discriminator output logits for fake images.
output_real (Tensor): Discriminator output logits for real images.
real_label_val (int): Label for real images.
fake_label_val (int): Label for fake images.
device (torch.device): Torch device object for sending created data.
Returns:
Tensor: A scalar tensor loss output.
"""
# Produce real and fake labels.
fake_labels = zeros((output_fake.shape[0], 1)) + fake_label_val
real_labels = zeros((output_real.shape[0], 1)) + real_label_val
# FF, compute loss and backprop D
errD_fake = _bce_loss_with_logits(output=output_fake,
labels=fake_labels,
**kwargs)
errD_real = _bce_loss_with_logits(output=output_real,
labels=real_labels,
**kwargs)
# Compute cumulative error
loss = errD_real + errD_fake
return loss
def wasserstein_loss_gen(output_fake):
r"""
Computes the wasserstein loss for generator.
Args:
output_fake (Tensor): Discriminator output logits for fake images.
Returns:
Tensor: A scalar tensor loss output.
"""
loss = -output_fake.mean()
return loss
def wasserstein_loss_dis(output_real, output_fake):
r"""
Computes the wasserstein loss for the discriminator.
Args:
output_real (Tensor): Discriminator output logits for real images.
output_fake (Tensor): Discriminator output logits for fake images.
Returns:
Tensor: A scalar tensor loss output.
"""
loss = -1.0 * output_real.mean() + output_fake.mean()
return loss
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
import megengine.functional as F
import megengine.jit as jit
from .. import gan
from ..blocks import DBlock, DBlockOptimized
class WGANBaseGenerator(gan.BaseGenerator):
r"""
ResNet backbone generator for ResNet WGAN.
Attributes:
nz (int): Noise dimension for upsampling.
ngf (int): Variable controlling generator feature map sizes.
bottom_width (int): Starting width for upsampling generator output to an image.
loss_type (str): Name of loss to use for GAN loss.
"""
def __init__(self, nz, ngf, bottom_width, **kwargs):
super().__init__(nz=nz,
ngf=ngf,
bottom_width=bottom_width,
loss_type="wasserstein",
**kwargs)
class WGANBaseDiscriminator(gan.BaseDiscriminator):
r"""
ResNet backbone discriminator for ResNet WGAN.
Attributes:
ndf (int): Variable controlling discriminator feature map sizes.
loss_type (str): Name of loss to use for GAN loss.
"""
def __init__(self, ndf, **kwargs):
super().__init__(ndf=ndf, loss_type="wasserstein", **kwargs)
def _reset_jit_graph(self, impl: callable):
"""We override this func to attach weight clipping after default training step"""
traced_obj = jit.trace(impl)
def _(*args, **kwargs):
ret = traced_obj(*args, **kwargs)
if self.training:
self._apply_lipshitz_constraint() # dynamically apply weight clipping
return ret
return _
def _apply_lipshitz_constraint(self):
"""Weight clipping described in [Wasserstein GAN](https://arxiv.org/abs/1701.07875)"""
for p in self.parameters():
F.add_update(p, F.clamp(p, lower=-3e-2, upper=3e-2), alpha=0)
def layernorm(x):
original_shape = x.shape
x = x.reshape(original_shape[0], -1)
m = F.mean(x, axis=1, keepdims=True)
v = F.mean((x - m) ** 2, axis=1, keepdims=True)
x = (x - m) / F.maximum(F.sqrt(v), 1e-6)
x = x.reshape(original_shape)
return x
class WGANDBlockWithLayerNorm(DBlock):
def _residual(self, x):
h = x
h = layernorm(h)
h = self.activation(h)
h = self.c1(h)
h = layernorm(h)
h = self.activation(h)
h = self.c2(h)
if self.downsample:
h = F.avg_pool2d(h, 2)
return h
class WGANDBlockOptimized(DBlockOptimized):
pass
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
import megengine.functional as F
import megengine.module as M
from ..blocks import GBlock
from . import wgan_base
from .wgan_base import WGANDBlockOptimized as DBlockOptimized
from .wgan_base import WGANDBlockWithLayerNorm as DBlock
class WGANGeneratorCIFAR(wgan_base.WGANBaseGenerator):
r"""
ResNet backbone generator for ResNet WGAN.
Attributes:
nz (int): Noise dimension for upsampling.
ngf (int): Variable controlling generator feature map sizes.
bottom_width (int): Starting width for upsampling generator output to an image.
loss_type (str): Name of loss to use for GAN loss.
"""
def __init__(self, nz=128, ngf=256, bottom_width=4, **kwargs):
super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs)
# Build the layers
self.l1 = M.Linear(self.nz, (self.bottom_width**2) * self.ngf)
self.block2 = GBlock(self.ngf, self.ngf, upsample=True)
self.block3 = GBlock(self.ngf, self.ngf, upsample=True)
self.block4 = GBlock(self.ngf, self.ngf, upsample=True)
self.b5 = M.BatchNorm2d(self.ngf)
self.c5 = M.Conv2d(self.ngf, 3, 3, 1, padding=1)
self.activation = M.ReLU()
# Initialise the weights
M.init.xavier_uniform_(self.l1.weight, 1.0)
M.init.xavier_uniform_(self.c5.weight, 1.0)
def forward(self, x):
r"""
Feedforwards a batch of noise vectors into a batch of fake images.
Args:
x (Tensor): A batch of noise vectors of shape (N, nz).
Returns:
Tensor: A batch of fake images of shape (N, C, H, W).
"""
h = self.l1(x)
h = h.reshape(x.shape[0], -1, self.bottom_width, self.bottom_width)
h = self.block2(h)
h = self.block3(h)
h = self.block4(h)
h = self.b5(h)
h = self.activation(h)
h = F.tanh(self.c5(h))
return h
class WGANDiscriminatorCIFAR(wgan_base.WGANBaseDiscriminator):
r"""
ResNet backbone discriminator for ResNet WGAN.
Attributes:
ndf (int): Variable controlling discriminator feature map sizes.
loss_type (str): Name of loss to use for GAN loss.
"""
def __init__(self, ndf=128, **kwargs):
super().__init__(ndf=ndf, **kwargs)
# Build layers
self.block1 = DBlockOptimized(3, self.ndf)
self.block2 = DBlock(self.ndf,
self.ndf,
downsample=True)
self.block3 = DBlock(self.ndf,
self.ndf,
downsample=False)
self.block4 = DBlock(self.ndf,
self.ndf,
downsample=False)
self.l5 = M.Linear(self.ndf, 1)
self.activation = M.ReLU()
# Initialise the weights
M.init.xavier_uniform_(self.l5.weight, 1.0)
def forward(self, x):
r"""
Feedforwards a batch of real/fake images and produces a batch of GAN logits.
Args:
x (Tensor): A batch of images of shape (N, C, H, W).
Returns:
Tensor: A batch of GAN logits of shape (N, 1).
"""
h = x
h = self.block1(h)
h = self.block2(h)
h = self.block3(h)
h = self.block4(h)
h = self.activation(h)
# Global average pooling
h = h.mean(3).mean(2)
output = self.l5(h)
return output
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
from .trainer import Trainer
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
"""
Implementation of the Logger object for performing training logging and visualisation.
"""
import os
import numpy as np
from tensorboardX import SummaryWriter
from ..utils import vis as vutils
class Logger:
"""
Writes summaries and visualises training progress.
Attributes:
log_dir (str): The path to store logging information.
num_steps (int): Total number of training iterations.
dataset_size (int): The number of examples in the dataset.
device (Device): Torch device object to send data to.
flush_secs (int): Number of seconds before flushing summaries to disk.
writers (dict): A dictionary of tensorboard writers with keys as metric names.
num_epochs (int): The number of epochs, for extra information.
"""
def __init__(self,
log_dir,
num_steps,
dataset_size,
flush_secs=120,
**kwargs):
self.log_dir = log_dir
self.num_steps = num_steps
self.dataset_size = dataset_size
self.flush_secs = flush_secs
# self.num_epochs = self._get_epoch(num_steps)
self.writers = {}
# Create log directory if haven't already
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)
# def _get_epoch(self, steps):
# """
# Helper function for getting epoch.
# """
# return max(int(steps / self.dataset_size), 1)
def _build_writer(self, metric):
writer = SummaryWriter(log_dir=os.path.join(self.log_dir, 'data',
metric),
flush_secs=self.flush_secs)
return writer
def write_summaries(self, log_data, global_step):
"""
Tasks appropriate writers to write the summaries in tensorboard. Creates additional
writers for summary writing if there are new scalars to log in log_data.
Args:
log_data (MetricLog): Dict-like object to collect log data for TB writing.
global_step (int): Global step variable for syncing logs.
Returns:
None
"""
for metric, data in log_data.items():
if metric not in self.writers:
self.writers[metric] = self._build_writer(metric)
# Write with a group name if it exists
name = log_data.get_group_name(metric) or metric
self.writers[metric].add_scalar(name,
log_data[metric],
global_step=global_step)
def close_writers(self):
"""
Closes all writers.
"""
for metric in self.writers:
self.writers[metric].close()
def print_log(self, global_step, log_data, time_taken):
"""
Formats the string to print to stdout based on training information.
Args:
log_data (MetricLog): Dict-like object to collect log data for TB writing.
global_step (int): Global step variable for syncing logs.
time_taken (float): Time taken for one training iteration.
Returns:
str: String to be printed to stdout.
"""
# Basic information
# log_to_show = [
# "INFO: [Epoch {:d}/{:d}][Global Step: {:d}/{:d}]".format(
# self._get_epoch(global_step), self.num_epochs, global_step,
# self.num_steps)
# ]
log_to_show = [
"INFO: [Global Step: {:d}/{:d}]".format(
global_step, self.num_steps)
]
# Display GAN information as fed from user.
GAN_info = [""]
metrics = sorted(log_data.keys())
for metric in metrics:
GAN_info.append('{}: {}'.format(metric, log_data[metric]))
# Add train step time information
GAN_info.append("({:.4f} sec/idx)".format(time_taken))
# Accumulate to log
log_to_show.append("\n| ".join(GAN_info))
# Finally print the output
ret = " ".join(log_to_show)
print(ret)
return ret
# def _get_fixed_noise(self, nz, num_images, output_dir=None):
# """
# Produce the fixed gaussian noise vectors used across all models
# for consistency.
# """
# if output_dir is None:
# output_dir = os.path.join(self.log_dir, 'viz')
# if not os.path.exists(output_dir):
# os.makedirs(output_dir)
# output_file = os.path.join(output_dir,
# 'fixed_noise_nz_{}.pth'.format(nz))
# if os.path.exists(output_file):
# noise = torch.load(output_file)
# else:
# noise = torch.randn((num_images, nz))
# torch.save(noise, output_file)
# return noise.to(self.device)
# def _get_fixed_labels(self, num_images, num_classes):
# """
# Produces fixed class labels for generating fixed images.
# """
# labels = np.array([i % num_classes for i in range(num_images)])
# labels = torch.from_numpy(labels).to(self.device)
# return labels
def vis_images(self, netG, global_step, num_images=64):
"""
Produce visualisations of the G(z), one fixed and one random.
Args:
netG (Module): Generator model object for producing images.
global_step (int): Global step variable for syncing logs.
num_images (int): The number of images to visualise.
Returns:
None
"""
img_dir = os.path.join(self.log_dir, 'images')
if not os.path.exists(img_dir):
os.makedirs(img_dir)
# Generate random images
fake_images = netG.generate_images(num_images=num_images)
# Generate fixed random images
# fixed_noise = self._get_fixed_noise(nz=netG.nz,
# num_images=num_images)
# if hasattr(netG, 'num_classes') and netG.num_classes > 0:
# fixed_labels = self._get_fixed_labels(num_images,
# netG.num_classes)
# fixed_fake_images = netG(fixed_noise,
# fixed_labels).detach().cpu()
# else:
# fixed_fake_images = netG(fixed_noise).detach().cpu()
# Map name to results
images_dict = {
'fake': fake_images
}
# Visualise all results
for name, images in images_dict.items():
images_viz = vutils.make_grid(images,
padding=2,
normalize=True)
vutils.save_image(images_viz,
'{}/{}_samples_step_{}.png'.format(
img_dir, name, global_step))
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
"""
MetricLog object for intelligently logging data to display them more intuitively.
"""
class MetricLog:
"""
A dictionary-like object that logs data, and includes an extra dict to map the metrics
to its group name, if any, and the corresponding precision to print out.
Attributes:
metrics_dict (dict): A dictionary mapping to another dict containing
the corresponding value, precision, and the group this metric belongs to.
"""
def __init__(self, **kwargs):
self.metrics_dict = {}
def add_metric(self, name, value, group=None, precision=4):
"""
Logs metric to internal dict, but with an additional option
of grouping certain metrics together.
Args:
name (str): Name of metric to log.
value (Tensor/Float): Value of the metric to log.
group (str): Name of the group to classify different metrics together.
precision (int): The number of floating point precision to represent the value.
Returns:
None
"""
# Grab tensor values only
try:
value = value.item()
except AttributeError:
value = value
self.metrics_dict[name] = dict(value=value,
group=group,
precision=precision)
def __getitem__(self, key):
return round(self.metrics_dict[key]['value'],
self.metrics_dict[key]['precision'])
def get_group_name(self, name):
"""
Obtains the group name of a particular metric. For example, errD and errG
which represents the discriminator/generator losses could fall under a
group name called "loss".
Args:
name (str): The name of the metric to retrieve group name.
Returns:
str: A string representing the group name of the metric.
"""
return self.metrics_dict[name]['group']
def keys(self):
"""
Dict like functionality for retrieving keys.
"""
return self.metrics_dict.keys()
def items(self):
"""
Dict like functionality for retrieving items.
"""
return self.metrics_dict.items()
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
"""
Implementation of a specific learning rate scheduler for GANs.
"""
class LRScheduler:
"""
Learning rate scheduler for training GANs. Supports GAN specific LR scheduling
policies, such as the linear decay policy using in SN-GAN paper as based on the
original chainer implementation. However, one could safely ignore this class
and instead use the official PyTorch scheduler wrappers around a optimizer
for other scheduling policies.
Attributes:
lr_decay (str): The learning rate decay policy to use.
optD (Optimizer): Torch optimizer object for discriminator.
optG (Optimizer): Torch optimizer object for generator.
num_steps (int): The number of training iterations.
lr_D (float): The initial learning rate of optD.
lr_G (float): The initial learning rate of optG.
"""
def __init__(self, lr_decay, optD, optG, num_steps, **kwargs):
if lr_decay not in [None, 'None', 'linear']:
raise NotImplementedError(
"lr_decay {} is not currently supported.")
self.lr_decay = lr_decay
self.optD = optD
self.optG = optG
self.num_steps = num_steps
# Cache the initial learning rate for uses later
self.lr_D = optD.param_groups[0]['lr']
self.lr_G = optG.param_groups[0]['lr']
def linear_decay(self, optimizer, global_step, lr_value_range,
lr_step_range):
"""
Performs linear decay of the optimizer learning rate based on the number of global
steps taken. Follows SNGAN's chainer implementation of linear decay, as seen in the
chainer references:
https://docs.chainer.org/en/stable/reference/generated/chainer.training.extensions.LinearShift.html
https://github.com/chainer/chainer/blob/v6.2.0/chainer/training/extensions/linear_shift.py#L66
Note: assumes that the optimizer has only one parameter group to update!
Args:
optimizer (Optimizer): Torch optimizer object to update learning rate.
global_step (int): The current global step of the training.
lr_value_range (tuple): A tuple of floats (x,y) to decrease from x to y.
lr_step_range (tuple): A tuple of ints (i, j) to start decreasing
when global_step > i, and until j.
Returns:
float: Float representing the new updated learning rate.
"""
# Compute the new learning rate
v1, v2 = lr_value_range
s1, s2 = lr_step_range
if global_step <= s1:
updated_lr = v1
elif global_step >= s2:
updated_lr = v2
else:
scale_factor = (global_step - s1) / (s2 - s1)
updated_lr = v1 + scale_factor * (v2 - v1)
# Update the learning rate
optimizer.param_groups[0]['lr'] = updated_lr
return updated_lr
def step(self, log_data, global_step):
"""
Takes a step for updating learning rate and updates the input log_data
with the current status.
Args:
log_data (MetricLog): Object for logging the updated learning rate metric.
global_step (int): The current global step of the training.
Returns:
MetricLog: MetricLog object containing the updated learning rate at the current global step.
"""
if self.lr_decay == "linear":
lr_D = self.linear_decay(optimizer=self.optD,
global_step=global_step,
lr_value_range=(self.lr_D, 0.0),
lr_step_range=(0, self.num_steps))
lr_G = self.linear_decay(optimizer=self.optG,
global_step=global_step,
lr_value_range=(self.lr_G, 0.0),
lr_step_range=(0, self.num_steps))
elif self.lr_decay in [None, "None"]:
lr_D = self.lr_D
lr_G = self.lr_G
else:
raise ValueError("Invalid lr_decay method {} selected.".format(
self.lr_decay))
# Update metrics log
log_data.add_metric('lr_D', lr_D, group='lr', precision=6)
log_data.add_metric('lr_G', lr_G, group='lr', precision=6)
return log_data
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
"""
Implementation of Trainer object for training GANs.
"""
import os
import re
import time
import megengine
from ..utils import common
from . import logger, metric_log, scheduler
class Trainer:
"""
Trainer object for constructing the GAN training pipeline.
Attributes:
netD (Module): Torch discriminator model.
netG (Module): Torch generator model.
optD (Optimizer): Torch optimizer object for discriminator.
optG (Optimizer): Torch optimizer object for generator.
dataloader (DataLoader): Torch object for loading data from a dataset object.
num_steps (int): The number of training iterations.
n_dis (int): Number of discriminator update steps per generator training step.
lr_decay (str): The learning rate decay policy to use.
log_dir (str): The path to storing logging information and checkpoints.
logger (Logger): Logger object for visualising training information.
scheduler (LRScheduler): GAN training specific learning rate scheduler object.
params (dict): Dictionary of training hyperparameters.
netD_ckpt_file (str): Custom checkpoint file to restore discriminator from.
netG_ckpt_file (str): Custom checkpoint file to restore generator from.
print_steps (int): Number of training steps before printing training info to stdout.
vis_steps (int): Number of training steps before visualising images with TensorBoard.
flush_secs (int): Number of seconds before flushing summaries to disk.
log_steps (int): Number of training steps before writing summaries to TensorBoard.
save_steps (int): Number of training steps bfeore checkpointing.
save_when_end (bool): If True, saves final checkpoint when training concludes.
"""
def __init__(self,
netD,
netG,
optD,
optG,
dataloader,
num_steps,
log_dir='./log',
n_dis=1,
netG_ckpt_file=None,
netD_ckpt_file=None,
lr_decay=None,
**kwargs):
self.netD = netD
self.netG = netG
self.optD = optD
self.optG = optG
self.n_dis = n_dis
self.lr_decay = lr_decay
self.dataloader = dataloader
self.num_steps = num_steps
self.log_dir = log_dir
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)
# Obtain custom or latest checkpoint files
if netG_ckpt_file:
self.netG_ckpt_dir = os.path.dirname(netG_ckpt_file)
self.netG_ckpt_file = netG_ckpt_file
else:
self.netG_ckpt_dir = os.path.join(self.log_dir, 'checkpoints',
'netG')
self.netG_ckpt_file = self._get_latest_checkpoint(
self.netG_ckpt_dir) # can be None
if netD_ckpt_file:
self.netD_ckpt_dir = os.path.dirname(netD_ckpt_file)
self.netD_ckpt_file = netD_ckpt_file
else:
self.netD_ckpt_dir = os.path.join(self.log_dir, 'checkpoints',
'netD')
self.netD_ckpt_file = self._get_latest_checkpoint(
self.netD_ckpt_dir) # can be None
# Default parameters, unless provided by kwargs
default_params = {
'print_steps': kwargs.get('print_steps', 1),
'vis_steps': kwargs.get('vis_steps', 500),
'flush_secs': kwargs.get('flush_secs', 30),
'log_steps': kwargs.get('log_steps', 50),
'save_steps': kwargs.get('save_steps', 5000),
'save_when_end': kwargs.get('save_when_end', True),
}
for param in default_params:
self.__dict__[param] = default_params[param]
# Hyperparameters for logging experiments
self.params = {
'log_dir': self.log_dir,
'num_steps': self.num_steps,
# 'batch_size': self.dataloader.sampler.batch_size,
'n_dis': self.n_dis,
'lr_decay': self.lr_decay,
# 'optD': optD.__repr__(),
# 'optG': optG.__repr__(),
}
self.params.update(default_params)
# Log training hyperparmaeters
self._log_params(self.params)
# Training helper objects
self.logger = logger.Logger(log_dir=self.log_dir,
num_steps=self.num_steps,
dataset_size=len(self.dataloader.dataset),
flush_secs=self.flush_secs)
self.scheduler = scheduler.LRScheduler(lr_decay=self.lr_decay,
optD=self.optD,
optG=self.optG,
num_steps=self.num_steps)
def _log_params(self, params):
"""
Takes the argument options to save into a json file.
"""
params_file = os.path.join(self.log_dir, 'params.json')
# Check for discrepancy with previous training config.
if os.path.exists(params_file):
check = common.load_from_json(params_file)
if params != check:
diffs = []
for k in params:
if k in check and params[k] != check[k]:
diffs.append('{}: Expected {} but got {}.'.format(
k, check[k], params[k]))
diff_string = '\n'.join(diffs)
raise ValueError(
"Current hyperparameter configuration is different from previously:\n{}"
.format(diff_string))
common.write_to_json(params, params_file)
def _get_latest_checkpoint(self, ckpt_dir):
"""
Given a checkpoint dir, finds the checkpoint with the latest training step.
"""
def _get_step_number(k):
"""
Helper function to get step number from checkpoint files.
"""
search = re.search(r'(\d+)_steps', k)
if search:
return int(search.groups()[0])
else:
return -float('inf')
if not os.path.exists(ckpt_dir):
return None
files = os.listdir(ckpt_dir)
if len(files) == 0:
return None
ckpt_file = max(files, key=lambda x: _get_step_number(x))
return os.path.join(ckpt_dir, ckpt_file)
def _fetch_data(self, iter_dataloader):
"""
Fetches the next set of data and refresh the iterator when it is exhausted.
Follows python EAFP, so no iterator.hasNext() is used.
"""
real_batch = next(iter_dataloader)
if isinstance(real_batch, (tuple, list)): # (image, label)
real_batch = real_batch[0]
return iter_dataloader, real_batch
def _restore_models_and_step(self):
"""
Restores model and optimizer checkpoints and ensures global step is in sync.
"""
global_step_D = global_step_G = 0
if self.netD_ckpt_file and os.path.exists(self.netD_ckpt_file):
print("INFO: Restoring checkpoint for D...")
global_step_D = self.netD.restore_checkpoint(
ckpt_file=self.netD_ckpt_file, optimizer=self.optD)
if self.netG_ckpt_file and os.path.exists(self.netG_ckpt_file):
print("INFO: Restoring checkpoint for G...")
global_step_G = self.netG.restore_checkpoint(
ckpt_file=self.netG_ckpt_file, optimizer=self.optG)
if global_step_G != global_step_D:
raise ValueError('G and D Networks are out of sync.')
else:
global_step = global_step_G # Restores global step
return global_step
def train(self):
"""
Runs the training pipeline with all given parameters in Trainer.
"""
# Restore models
global_step = self._restore_models_and_step()
print("INFO: Starting training from global step {}...".format(
global_step))
try:
start_time = time.time()
# Iterate through data
iter_dataloader = iter(self.dataloader)
while global_step < self.num_steps:
log_data = metric_log.MetricLog() # log data for tensorboard
# -------------------------
# One Training Step
# -------------------------
# Update n_dis times for D
for i in range(self.n_dis):
iter_dataloader, real_batch = self._fetch_data(
iter_dataloader=iter_dataloader)
# -----------------------
# Update G Network
# -----------------------
# Update G, but only once.
if i == 0:
errG = self.netG.train_step(
real_batch,
netD=self.netD,
optG=self.optG)
log_data.add_metric("errG", errG.item(), group="loss")
# ------------------------
# Update D Network
# -----------------------
errD, D_x, D_Gz = self.netD.train_step(real_batch,
netG=self.netG,
optD=self.optD)
log_data.add_metric("errD", errD.item(), group="loss")
log_data.add_metric("D_x", D_x.item(), group="prob")
log_data.add_metric("D_Gz", D_Gz.item(), group="prob")
# --------------------------------
# Update Training Variables
# -------------------------------
global_step += 1
log_data = self.scheduler.step(log_data=log_data,
global_step=global_step)
# -------------------------
# Logging and Metrics
# -------------------------
if global_step % self.log_steps == 0:
self.logger.write_summaries(log_data=log_data,
global_step=global_step)
if global_step % self.print_steps == 0:
curr_time = time.time()
self.logger.print_log(global_step=global_step,
log_data=log_data,
time_taken=(curr_time - start_time) /
self.print_steps)
start_time = curr_time
if global_step % self.vis_steps == 0:
self.logger.vis_images(netG=self.netG,
global_step=global_step)
if global_step % self.save_steps == 0:
print("INFO: Saving checkpoints...")
self.netG.save_checkpoint(directory=self.netG_ckpt_dir,
global_step=global_step,
optimizer=self.optG)
self.netD.save_checkpoint(directory=self.netD_ckpt_dir,
global_step=global_step,
optimizer=self.optD)
# Save models at the very end of training
if self.save_when_end:
print("INFO: Saving final checkpoints...")
self.netG.save_checkpoint(directory=self.netG_ckpt_dir,
global_step=global_step,
optimizer=self.optG)
self.netD.save_checkpoint(directory=self.netD_ckpt_dir,
global_step=global_step,
optimizer=self.optD)
except KeyboardInterrupt:
print("INFO: Saving checkpoints from keyboard interrupt...")
self.netG.save_checkpoint(directory=self.netG_ckpt_dir,
global_step=global_step,
optimizer=self.optG)
self.netD.save_checkpoint(directory=self.netD_ckpt_dir,
global_step=global_step,
optimizer=self.optD)
finally:
self.logger.close_writers()
print("INFO: Training Ended.")
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
from .common import *
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
"""
Script for common utility functions.
"""
import json
import os
import numpy as np
def write_to_json(dict_to_write, output_file):
"""
Outputs a given dictionary as a JSON file with indents.
Args:
dict_to_write (dict): Input dictionary to output.
output_file (str): File path to write the dictionary.
Returns:
None
"""
with open(output_file, 'w') as file:
json.dump(dict_to_write, file, indent=4)
def load_from_json(json_file):
"""
Loads a JSON file as a dictionary and return it.
Args:
json_file (str): Input JSON file to read.
Returns:
dict: Dictionary loaded from the JSON file.
"""
with open(json_file, 'r') as file:
return json.load(file)
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
import cv2
import megengine
def normalize_image(tensor: megengine.Tensor, scale=255):
"""normalize image tensors of any range to [0, scale=255]"""
mi = tensor.min()
ma = tensor.max()
tensor = scale * (tensor - mi) / (ma - mi + 1e-9)
return tensor
def make_grid(
tensor: megengine.Tensor, # [N,C,H,W]
nrow: int = 8,
padding: int = 2,
background: float = 0,
normalize: bool = False,
) -> megengine.Tensor:
"""align [N, C, H, W] image tensor to [H, W, 3] image grids, for visualization"""
if normalize:
tensor = normalize_image(tensor, scale=255) # normalize to 0-255 scale
c = tensor.shape[1]
assert c in (1, 3), "only support color/grayscale images, got channel = {}".format(c)
nmaps = tensor.shape[0]
xmaps = min(nrow, nmaps)
ymaps = int(math.ceil(float(nmaps) / xmaps))
height, width = int(tensor.shape[2] + padding), int(tensor.shape[3] + padding)
num_channels = tensor.shape[1]
grid = megengine.ones((num_channels, height * ymaps + padding, width * xmaps + padding), "float32") * background
k = 0
for y in range(ymaps):
for x in range(xmaps):
if k >= nmaps:
break
grid = grid.set_subtensor(tensor[k])[:,
y * height + padding: (y + 1) * height,
x * width + padding: (x + 1) * width]
k = k + 1
c, h, w = grid.shape
grid = grid.dimshuffle(1, 2, 0) # [C,H,W] -> [H,W,C]
grid = grid.broadcast(h, w, 3) # [H,W,C] -> [H,W,3]
return grid
def save_image(image, path):
if isinstance(image, megengine.Tensor):
image = image.numpy()
cv2.imwrite(path, image)
tensorflow>=2.0
tensorboardX
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
import megengine.data as data
import megengine.data.transform as T
import megengine.optimizer as optim
import megengine_mimicry as mmc
import megengine_mimicry.nets.dcgan.dcgan_cifar as dcgan
dataset = mmc.datasets.load_dataset(root=None, name='cifar10')
dataloader = data.DataLoader(
dataset,
sampler=data.Infinite(data.RandomSampler(dataset, batch_size=64, drop_last=True)),
transform=T.Compose([T.Normalize(std=255), T.ToMode("CHW")]),
num_workers=4
)
netG = dcgan.DCGANGeneratorCIFAR()
netD = dcgan.DCGANDiscriminatorCIFAR()
optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9))
optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))
LOG_DIR = "./log/dcgan_example"
trainer = mmc.training.Trainer(
netD=netD,
netG=netG,
optD=optD,
optG=optG,
n_dis=5,
num_steps=100000,
lr_decay="linear",
dataloader=dataloader,
log_dir=LOG_DIR,
device=0)
trainer.train()
mmc.metrics.compute_metrics.evaluate(
metric="fid",
netG=netG,
log_dir=LOG_DIR,
evaluate_step=100000,
num_runs=1,
device=0,
num_real_samples=50000,
num_fake_samples=50000,
dataset_name="cifar10",
)
mmc.metrics.compute_metrics.evaluate(
metric="inception_score",
netG=netG,
log_dir=LOG_DIR,
evaluate_step=100000,
num_runs=1,
device=0,
num_samples=50000,
)
mmc.metrics.compute_metrics.evaluate(
metric="kid",
netG=netG,
log_dir=LOG_DIR,
evaluate_step=100000,
num_runs=1,
device=0,
num_subsets=50,
subset_size=1000,
dataset_name="cifar10",
)
# Copyright (c) 2020 Kwot Sin Lee
# This code is licensed under MIT license
# (https://github.com/kwotsin/mimicry/blob/master/LICENSE)
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
import megengine.data as data
import megengine.data.transform as T
import megengine.optimizer as optim
import megengine_mimicry as mmc
import megengine_mimicry.nets.wgan.wgan_cifar as wgan
dataset = mmc.datasets.load_dataset(root=None, name='cifar10')
dataloader = data.DataLoader(
dataset,
sampler=data.Infinite(data.RandomSampler(dataset, batch_size=64, drop_last=True)),
transform=T.Compose([T.Normalize(mean=127, std=127), T.ToMode("CHW")]),
num_workers=4
)
netG = wgan.WGANGeneratorCIFAR()
netD = wgan.WGANDiscriminatorCIFAR()
optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9))
optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))
LOG_DIR = "./log/wgan_example"
trainer = mmc.training.Trainer(
netD=netD,
netG=netG,
optD=optD,
optG=optG,
n_dis=5,
num_steps=100000,
lr_decay="linear",
dataloader=dataloader,
log_dir=LOG_DIR,
device=0)
trainer.train()
mmc.metrics.compute_metrics.evaluate(
metric="fid",
netG=netG,
log_dir=LOG_DIR,
evaluate_step=100000,
num_runs=1,
device=0,
num_real_samples=50000,
num_fake_samples=50000,
dataset_name="cifar10",
)
mmc.metrics.compute_metrics.evaluate(
metric="inception_score",
netG=netG,
log_dir=LOG_DIR,
evaluate_step=100000,
num_runs=1,
device=0,
num_samples=50000,
)
mmc.metrics.compute_metrics.evaluate(
metric="kid",
netG=netG,
log_dir=LOG_DIR,
evaluate_step=100000,
num_runs=1,
device=0,
num_subsets=50,
subset_size=1000,
dataset_name="cifar10",
)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册