diff --git a/ppgan/metric/compute_fid.py b/ppgan/metric/compute_fid.py index c8fc8059e2658768b1f07436f2bca6e08446014c..8d000a0d5235daf5cbb27bb5c421caaaad0743f1 100644 --- a/ppgan/metric/compute_fid.py +++ b/ppgan/metric/compute_fid.py @@ -16,6 +16,7 @@ import os import fnmatch import numpy as np import cv2 +from PIL import Image from cv2 import imread from scipy import linalg import paddle.fluid as fluid @@ -128,7 +129,7 @@ def calculate_fid_given_img(img_fake, return fid_value -def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path): +def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path, style=None): if len(files) % batch_size != 0: print(('Warning: number of images is not a multiple of the ' 'batch size. Some samples are going to be ignored.')) @@ -144,8 +145,23 @@ def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path): for i in tqdm(range(n_batches)): start = i * batch_size end = start + batch_size - images = np.array( - [imread(str(f)).astype(np.float32) for f in files[start:end]]) + + # same as stargan-v2 official implementation: resize to 256 first, then resize to 299 + if style == 'stargan': + img_list = [] + for f in files[start:end]: + im = Image.open(str(f)).convert('RGB') + if im.size[0] != 299: + im = im.resize((256, 256), 2) + im = im.resize((299, 299), 2) + + img_list.append(np.array(im).astype('float32')) + + images = np.array( + img_list) + else: + images = np.array( + [imread(str(f)).astype(np.float32) for f in files[start:end]]) if len(images.shape) != 4: images = imread(str(files[start])) @@ -155,33 +171,53 @@ def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path): images = images.transpose((0, 3, 1, 2)) images /= 255 - images = to_variable(images) - param_dict, _ = fluid.load_dygraph(premodel_path) - model.set_dict(param_dict) - model.eval() + # imagenet normalization + if style == 'stargan': + mean = np.array([0.485, 0.456, 0.406]).astype('float32') + std = np.array([0.229, 0.224, 0.225]).astype('float32') + images[:] = (images[:] - mean[:, None, None]) / std[:, None, None] - pred = model(images)[0][0].numpy() + if style=='stargan': + pred_arr[start:end] = inception_infer(images, premodel_path) + else: + with fluid.dygraph.guard(): + images = to_variable(images) + param_dict, _ = fluid.load_dygraph(premodel_path) + model.set_dict(param_dict) + model.eval() - pred_arr[start:end] = pred.reshape(end - start, -1) + pred = model(images)[0][0].numpy() + + pred_arr[start:end] = pred.reshape(end - start, -1) return pred_arr +def inception_infer(x, model_path): + exe = fluid.Executor() + [inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model(model_path, exe) + results = exe.run(inference_program, + feed={feed_target_names[0]: x}, + fetch_list=fetch_targets) + return results[0] + + def _calculate_activation_statistics(files, model, premodel_path, batch_size=50, dims=2048, - use_gpu=False): + use_gpu=False, + style = None): act = _get_activations(files, model, batch_size, dims, use_gpu, - premodel_path) + premodel_path, style) mu = np.mean(act, axis=0) sigma = np.cov(act, rowvar=False) return mu, sigma def _compute_statistics_of_path(path, model, batch_size, dims, use_gpu, - premodel_path): + premodel_path, style=None): if path.endswith('.npz'): f = np.load(path) m, s = f['mu'][:], f['sigma'][:] @@ -193,7 +229,7 @@ def _compute_statistics_of_path(path, model, batch_size, dims, use_gpu, filenames, '*.jpg') or fnmatch.filter(filenames, '*.png'): files.append(os.path.join(root, filename)) m, s = _calculate_activation_statistics(files, model, premodel_path, - batch_size, dims, use_gpu) + batch_size, dims, use_gpu, style) return m, s @@ -202,7 +238,8 @@ def calculate_fid_given_paths(paths, batch_size, use_gpu, dims, - model=None): + model=None, + style = None): assert os.path.exists( premodel_path ), 'pretrain_model path {} is not exists! Please download it first'.format( @@ -216,9 +253,9 @@ def calculate_fid_given_paths(paths, model = InceptionV3([block_idx], class_dim=1008) m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, dims, - use_gpu, premodel_path) + use_gpu, premodel_path, style) m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, dims, - use_gpu, premodel_path) + use_gpu, premodel_path, style) fid_value = _calculate_frechet_distance(m1, s1, m2, s2) return fid_value diff --git a/ppgan/metric/test_fid_score.py b/ppgan/metric/test_fid_score.py index e8abccaaf3e8c4bda5a5c51e7621014b12a0664d..36412a55104085154dc7dac3b7d923a369ceab07 100644 --- a/ppgan/metric/test_fid_score.py +++ b/ppgan/metric/test_fid_score.py @@ -38,6 +38,9 @@ def parse_args(): type=int, default=1, help='sample number in a batch for inference.') + parser.add_argument('--style', + type=str, + help='calculation style: stargan or default (gan-compression style)') args = parser.parse_args() return args @@ -50,10 +53,9 @@ def main(): inference_model_path = args.inference_model batch_size = args.batch_size - with fluid.dygraph.guard(): - fid_value = calculate_fid_given_paths(paths, inference_model_path, - batch_size, args.use_gpu, 2048) - print('FID: ', fid_value) + fid_value = calculate_fid_given_paths(paths, inference_model_path, + batch_size, args.use_gpu, 2048, style=args.style) + print('FID: ', fid_value) if __name__ == "__main__":