From fc6af3653cf73218fb6d25207ba4f0f977eae9e0 Mon Sep 17 00:00:00 2001 From: Mark Ma <519329064@qq.com> Date: Sun, 30 Aug 2020 15:07:15 +0800 Subject: [PATCH] add stargan-v2 style FID calculation. add --style command line option to let user choose stargan or gan-compression style (by default gan-compression style will be used). move `dygraph.guard()` declaration into fid module for two reason: 1. the inference model didn't work in dygraph mode, so we dynamically choose whether to use dygraph mode after style is determined. 2. easier to use for end user (no need to call fluid.dygraph.guard() explicitly) --- ppgan/metric/compute_fid.py | 69 ++++++++++++++++++++++++++-------- ppgan/metric/test_fid_score.py | 10 +++-- 2 files changed, 59 insertions(+), 20 deletions(-) diff --git a/ppgan/metric/compute_fid.py b/ppgan/metric/compute_fid.py index c8fc805..8d000a0 100644 --- a/ppgan/metric/compute_fid.py +++ b/ppgan/metric/compute_fid.py @@ -16,6 +16,7 @@ import os import fnmatch import numpy as np import cv2 +from PIL import Image from cv2 import imread from scipy import linalg import paddle.fluid as fluid @@ -128,7 +129,7 @@ def calculate_fid_given_img(img_fake, return fid_value -def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path): +def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path, style=None): if len(files) % batch_size != 0: print(('Warning: number of images is not a multiple of the ' 'batch size. Some samples are going to be ignored.')) @@ -144,8 +145,23 @@ def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path): for i in tqdm(range(n_batches)): start = i * batch_size end = start + batch_size - images = np.array( - [imread(str(f)).astype(np.float32) for f in files[start:end]]) + + # same as stargan-v2 official implementation: resize to 256 first, then resize to 299 + if style == 'stargan': + img_list = [] + for f in files[start:end]: + im = Image.open(str(f)).convert('RGB') + if im.size[0] != 299: + im = im.resize((256, 256), 2) + im = im.resize((299, 299), 2) + + img_list.append(np.array(im).astype('float32')) + + images = np.array( + img_list) + else: + images = np.array( + [imread(str(f)).astype(np.float32) for f in files[start:end]]) if len(images.shape) != 4: images = imread(str(files[start])) @@ -155,33 +171,53 @@ def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path): images = images.transpose((0, 3, 1, 2)) images /= 255 - images = to_variable(images) - param_dict, _ = fluid.load_dygraph(premodel_path) - model.set_dict(param_dict) - model.eval() + # imagenet normalization + if style == 'stargan': + mean = np.array([0.485, 0.456, 0.406]).astype('float32') + std = np.array([0.229, 0.224, 0.225]).astype('float32') + images[:] = (images[:] - mean[:, None, None]) / std[:, None, None] - pred = model(images)[0][0].numpy() + if style=='stargan': + pred_arr[start:end] = inception_infer(images, premodel_path) + else: + with fluid.dygraph.guard(): + images = to_variable(images) + param_dict, _ = fluid.load_dygraph(premodel_path) + model.set_dict(param_dict) + model.eval() - pred_arr[start:end] = pred.reshape(end - start, -1) + pred = model(images)[0][0].numpy() + + pred_arr[start:end] = pred.reshape(end - start, -1) return pred_arr +def inception_infer(x, model_path): + exe = fluid.Executor() + [inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model(model_path, exe) + results = exe.run(inference_program, + feed={feed_target_names[0]: x}, + fetch_list=fetch_targets) + return results[0] + + def _calculate_activation_statistics(files, model, premodel_path, batch_size=50, dims=2048, - use_gpu=False): + use_gpu=False, + style = None): act = _get_activations(files, model, batch_size, dims, use_gpu, - premodel_path) + premodel_path, style) mu = np.mean(act, axis=0) sigma = np.cov(act, rowvar=False) return mu, sigma def _compute_statistics_of_path(path, model, batch_size, dims, use_gpu, - premodel_path): + premodel_path, style=None): if path.endswith('.npz'): f = np.load(path) m, s = f['mu'][:], f['sigma'][:] @@ -193,7 +229,7 @@ def _compute_statistics_of_path(path, model, batch_size, dims, use_gpu, filenames, '*.jpg') or fnmatch.filter(filenames, '*.png'): files.append(os.path.join(root, filename)) m, s = _calculate_activation_statistics(files, model, premodel_path, - batch_size, dims, use_gpu) + batch_size, dims, use_gpu, style) return m, s @@ -202,7 +238,8 @@ def calculate_fid_given_paths(paths, batch_size, use_gpu, dims, - model=None): + model=None, + style = None): assert os.path.exists( premodel_path ), 'pretrain_model path {} is not exists! Please download it first'.format( @@ -216,9 +253,9 @@ def calculate_fid_given_paths(paths, model = InceptionV3([block_idx], class_dim=1008) m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, dims, - use_gpu, premodel_path) + use_gpu, premodel_path, style) m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, dims, - use_gpu, premodel_path) + use_gpu, premodel_path, style) fid_value = _calculate_frechet_distance(m1, s1, m2, s2) return fid_value diff --git a/ppgan/metric/test_fid_score.py b/ppgan/metric/test_fid_score.py index e8abcca..36412a5 100644 --- a/ppgan/metric/test_fid_score.py +++ b/ppgan/metric/test_fid_score.py @@ -38,6 +38,9 @@ def parse_args(): type=int, default=1, help='sample number in a batch for inference.') + parser.add_argument('--style', + type=str, + help='calculation style: stargan or default (gan-compression style)') args = parser.parse_args() return args @@ -50,10 +53,9 @@ def main(): inference_model_path = args.inference_model batch_size = args.batch_size - with fluid.dygraph.guard(): - fid_value = calculate_fid_given_paths(paths, inference_model_path, - batch_size, args.use_gpu, 2048) - print('FID: ', fid_value) + fid_value = calculate_fid_given_paths(paths, inference_model_path, + batch_size, args.use_gpu, 2048, style=args.style) + print('FID: ', fid_value) if __name__ == "__main__": -- GitLab