提交 fc6af365 编写于 作者: M Mark Ma

add stargan-v2 style FID calculation.

add --style command line option to let user choose stargan or gan-compression style (by default gan-compression style will be used).
move `dygraph.guard()` declaration into fid module for two reason: 1. the inference model didn't work in dygraph mode, so we dynamically choose whether to use dygraph mode after style is determined. 2. easier to use for end user (no need to call fluid.dygraph.guard() explicitly)
上级 5b31853d
......@@ -16,6 +16,7 @@ import os
import fnmatch
import numpy as np
import cv2
from PIL import Image
from cv2 import imread
from scipy import linalg
import paddle.fluid as fluid
......@@ -128,7 +129,7 @@ def calculate_fid_given_img(img_fake,
return fid_value
def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path):
def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path, style=None):
if len(files) % batch_size != 0:
print(('Warning: number of images is not a multiple of the '
'batch size. Some samples are going to be ignored.'))
......@@ -144,8 +145,23 @@ def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path):
for i in tqdm(range(n_batches)):
start = i * batch_size
end = start + batch_size
images = np.array(
[imread(str(f)).astype(np.float32) for f in files[start:end]])
# same as stargan-v2 official implementation: resize to 256 first, then resize to 299
if style == 'stargan':
img_list = []
for f in files[start:end]:
im = Image.open(str(f)).convert('RGB')
if im.size[0] != 299:
im = im.resize((256, 256), 2)
im = im.resize((299, 299), 2)
img_list.append(np.array(im).astype('float32'))
images = np.array(
img_list)
else:
images = np.array(
[imread(str(f)).astype(np.float32) for f in files[start:end]])
if len(images.shape) != 4:
images = imread(str(files[start]))
......@@ -155,33 +171,53 @@ def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path):
images = images.transpose((0, 3, 1, 2))
images /= 255
images = to_variable(images)
param_dict, _ = fluid.load_dygraph(premodel_path)
model.set_dict(param_dict)
model.eval()
# imagenet normalization
if style == 'stargan':
mean = np.array([0.485, 0.456, 0.406]).astype('float32')
std = np.array([0.229, 0.224, 0.225]).astype('float32')
images[:] = (images[:] - mean[:, None, None]) / std[:, None, None]
pred = model(images)[0][0].numpy()
if style=='stargan':
pred_arr[start:end] = inception_infer(images, premodel_path)
else:
with fluid.dygraph.guard():
images = to_variable(images)
param_dict, _ = fluid.load_dygraph(premodel_path)
model.set_dict(param_dict)
model.eval()
pred_arr[start:end] = pred.reshape(end - start, -1)
pred = model(images)[0][0].numpy()
pred_arr[start:end] = pred.reshape(end - start, -1)
return pred_arr
def inception_infer(x, model_path):
exe = fluid.Executor()
[inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model(model_path, exe)
results = exe.run(inference_program,
feed={feed_target_names[0]: x},
fetch_list=fetch_targets)
return results[0]
def _calculate_activation_statistics(files,
model,
premodel_path,
batch_size=50,
dims=2048,
use_gpu=False):
use_gpu=False,
style = None):
act = _get_activations(files, model, batch_size, dims, use_gpu,
premodel_path)
premodel_path, style)
mu = np.mean(act, axis=0)
sigma = np.cov(act, rowvar=False)
return mu, sigma
def _compute_statistics_of_path(path, model, batch_size, dims, use_gpu,
premodel_path):
premodel_path, style=None):
if path.endswith('.npz'):
f = np.load(path)
m, s = f['mu'][:], f['sigma'][:]
......@@ -193,7 +229,7 @@ def _compute_statistics_of_path(path, model, batch_size, dims, use_gpu,
filenames, '*.jpg') or fnmatch.filter(filenames, '*.png'):
files.append(os.path.join(root, filename))
m, s = _calculate_activation_statistics(files, model, premodel_path,
batch_size, dims, use_gpu)
batch_size, dims, use_gpu, style)
return m, s
......@@ -202,7 +238,8 @@ def calculate_fid_given_paths(paths,
batch_size,
use_gpu,
dims,
model=None):
model=None,
style = None):
assert os.path.exists(
premodel_path
), 'pretrain_model path {} is not exists! Please download it first'.format(
......@@ -216,9 +253,9 @@ def calculate_fid_given_paths(paths,
model = InceptionV3([block_idx], class_dim=1008)
m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, dims,
use_gpu, premodel_path)
use_gpu, premodel_path, style)
m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, dims,
use_gpu, premodel_path)
use_gpu, premodel_path, style)
fid_value = _calculate_frechet_distance(m1, s1, m2, s2)
return fid_value
......@@ -38,6 +38,9 @@ def parse_args():
type=int,
default=1,
help='sample number in a batch for inference.')
parser.add_argument('--style',
type=str,
help='calculation style: stargan or default (gan-compression style)')
args = parser.parse_args()
return args
......@@ -50,10 +53,9 @@ def main():
inference_model_path = args.inference_model
batch_size = args.batch_size
with fluid.dygraph.guard():
fid_value = calculate_fid_given_paths(paths, inference_model_path,
batch_size, args.use_gpu, 2048)
print('FID: ', fid_value)
fid_value = calculate_fid_given_paths(paths, inference_model_path,
batch_size, args.use_gpu, 2048, style=args.style)
print('FID: ', fid_value)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册