未验证 提交 de6cb8d2 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #13 from gbstack/master

Add StarGAN-v2 style FID calculation
...@@ -8,3 +8,12 @@ wget https://paddlegan.bj.bcebos.com/InceptionV3.pdparams ...@@ -8,3 +8,12 @@ wget https://paddlegan.bj.bcebos.com/InceptionV3.pdparams
``` ```
python test_fid_score.py --image_data_path1 /path/to/dataset1 --image_data_path2 /path/to/dataset2 --inference_model ./InceptionV3.pdparams python test_fid_score.py --image_data_path1 /path/to/dataset1 --image_data_path2 /path/to/dataset2 --inference_model ./InceptionV3.pdparams
``` ```
### Inception-V3 weights converted from torchvision
Download: https://aistudio.baidu.com/aistudio/datasetdetail/51890
This model weights file is converted from official torchvision inception-v3 model. And both BigGAN and StarGAN-v2 is using it to calculate FID score.
Note that this model weights is different from above one (which is converted from tensorflow unofficial version)
...@@ -16,15 +16,18 @@ import os ...@@ -16,15 +16,18 @@ import os
import fnmatch import fnmatch
import numpy as np import numpy as np
import cv2 import cv2
from PIL import Image
from cv2 import imread from cv2 import imread
from scipy import linalg from scipy import linalg
import paddle.fluid as fluid import paddle.fluid as fluid
from inception import InceptionV3 from inception import InceptionV3
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
try:
def tqdm(x): from tqdm import tqdm
return x except:
def tqdm(x):
return x
""" based on https://github.com/mit-han-lab/gan-compression/blob/master/metric/fid_score.py """ based on https://github.com/mit-han-lab/gan-compression/blob/master/metric/fid_score.py
...@@ -128,7 +131,7 @@ def calculate_fid_given_img(img_fake, ...@@ -128,7 +131,7 @@ def calculate_fid_given_img(img_fake,
return fid_value return fid_value
def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path): def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path, style=None):
if len(files) % batch_size != 0: if len(files) % batch_size != 0:
print(('Warning: number of images is not a multiple of the ' print(('Warning: number of images is not a multiple of the '
'batch size. Some samples are going to be ignored.')) 'batch size. Some samples are going to be ignored.'))
...@@ -144,8 +147,23 @@ def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path): ...@@ -144,8 +147,23 @@ def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path):
for i in tqdm(range(n_batches)): for i in tqdm(range(n_batches)):
start = i * batch_size start = i * batch_size
end = start + batch_size end = start + batch_size
images = np.array(
[imread(str(f)).astype(np.float32) for f in files[start:end]]) # same as stargan-v2 official implementation: resize to 256 first, then resize to 299
if style == 'stargan':
img_list = []
for f in files[start:end]:
im = Image.open(str(f)).convert('RGB')
if im.size[0] != 299:
im = im.resize((256, 256), 2)
im = im.resize((299, 299), 2)
img_list.append(np.array(im).astype('float32'))
images = np.array(
img_list)
else:
images = np.array(
[imread(str(f)).astype(np.float32) for f in files[start:end]])
if len(images.shape) != 4: if len(images.shape) != 4:
images = imread(str(files[start])) images = imread(str(files[start]))
...@@ -155,33 +173,53 @@ def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path): ...@@ -155,33 +173,53 @@ def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path):
images = images.transpose((0, 3, 1, 2)) images = images.transpose((0, 3, 1, 2))
images /= 255 images /= 255
images = to_variable(images) # imagenet normalization
param_dict, _ = fluid.load_dygraph(premodel_path) if style == 'stargan':
model.set_dict(param_dict) mean = np.array([0.485, 0.456, 0.406]).astype('float32')
model.eval() std = np.array([0.229, 0.224, 0.225]).astype('float32')
images[:] = (images[:] - mean[:, None, None]) / std[:, None, None]
pred = model(images)[0][0].numpy() if style=='stargan':
pred_arr[start:end] = inception_infer(images, premodel_path)
else:
with fluid.dygraph.guard():
images = to_variable(images)
param_dict, _ = fluid.load_dygraph(premodel_path)
model.set_dict(param_dict)
model.eval()
pred_arr[start:end] = pred.reshape(end - start, -1) pred = model(images)[0][0].numpy()
pred_arr[start:end] = pred.reshape(end - start, -1)
return pred_arr return pred_arr
def inception_infer(x, model_path):
exe = fluid.Executor()
[inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model(model_path, exe)
results = exe.run(inference_program,
feed={feed_target_names[0]: x},
fetch_list=fetch_targets)
return results[0]
def _calculate_activation_statistics(files, def _calculate_activation_statistics(files,
model, model,
premodel_path, premodel_path,
batch_size=50, batch_size=50,
dims=2048, dims=2048,
use_gpu=False): use_gpu=False,
style = None):
act = _get_activations(files, model, batch_size, dims, use_gpu, act = _get_activations(files, model, batch_size, dims, use_gpu,
premodel_path) premodel_path, style)
mu = np.mean(act, axis=0) mu = np.mean(act, axis=0)
sigma = np.cov(act, rowvar=False) sigma = np.cov(act, rowvar=False)
return mu, sigma return mu, sigma
def _compute_statistics_of_path(path, model, batch_size, dims, use_gpu, def _compute_statistics_of_path(path, model, batch_size, dims, use_gpu,
premodel_path): premodel_path, style=None):
if path.endswith('.npz'): if path.endswith('.npz'):
f = np.load(path) f = np.load(path)
m, s = f['mu'][:], f['sigma'][:] m, s = f['mu'][:], f['sigma'][:]
...@@ -193,7 +231,7 @@ def _compute_statistics_of_path(path, model, batch_size, dims, use_gpu, ...@@ -193,7 +231,7 @@ def _compute_statistics_of_path(path, model, batch_size, dims, use_gpu,
filenames, '*.jpg') or fnmatch.filter(filenames, '*.png'): filenames, '*.jpg') or fnmatch.filter(filenames, '*.png'):
files.append(os.path.join(root, filename)) files.append(os.path.join(root, filename))
m, s = _calculate_activation_statistics(files, model, premodel_path, m, s = _calculate_activation_statistics(files, model, premodel_path,
batch_size, dims, use_gpu) batch_size, dims, use_gpu, style)
return m, s return m, s
...@@ -202,7 +240,8 @@ def calculate_fid_given_paths(paths, ...@@ -202,7 +240,8 @@ def calculate_fid_given_paths(paths,
batch_size, batch_size,
use_gpu, use_gpu,
dims, dims,
model=None): model=None,
style = None):
assert os.path.exists( assert os.path.exists(
premodel_path premodel_path
), 'pretrain_model path {} is not exists! Please download it first'.format( ), 'pretrain_model path {} is not exists! Please download it first'.format(
...@@ -211,14 +250,15 @@ def calculate_fid_given_paths(paths, ...@@ -211,14 +250,15 @@ def calculate_fid_given_paths(paths,
if not os.path.exists(p): if not os.path.exists(p):
raise RuntimeError('Invalid path: %s' % p) raise RuntimeError('Invalid path: %s' % p)
if model is None: if model is None and style != 'stargan':
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] with fluid.dygraph.guard():
model = InceptionV3([block_idx], class_dim=1008) block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
model = InceptionV3([block_idx], class_dim=1008)
m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, dims, m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, dims,
use_gpu, premodel_path) use_gpu, premodel_path, style)
m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, dims, m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, dims,
use_gpu, premodel_path) use_gpu, premodel_path, style)
fid_value = _calculate_frechet_distance(m1, s1, m2, s2) fid_value = _calculate_frechet_distance(m1, s1, m2, s2)
return fid_value return fid_value
...@@ -38,6 +38,9 @@ def parse_args(): ...@@ -38,6 +38,9 @@ def parse_args():
type=int, type=int,
default=1, default=1,
help='sample number in a batch for inference.') help='sample number in a batch for inference.')
parser.add_argument('--style',
type=str,
help='calculation style: stargan or default (gan-compression style)')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -50,10 +53,9 @@ def main(): ...@@ -50,10 +53,9 @@ def main():
inference_model_path = args.inference_model inference_model_path = args.inference_model
batch_size = args.batch_size batch_size = args.batch_size
with fluid.dygraph.guard(): fid_value = calculate_fid_given_paths(paths, inference_model_path,
fid_value = calculate_fid_given_paths(paths, inference_model_path, batch_size, args.use_gpu, 2048, style=args.style)
batch_size, args.use_gpu, 2048) print('FID: ', fid_value)
print('FID: ', fid_value)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册