diff --git a/configs/pix2pix_cityscapes.yaml b/configs/pix2pix_cityscapes.yaml index 11ad4fc39b1eac72498a4f1b715118eb14468f34..ba237d050cac8e066da9e9d3410e8751ba504b3b 100644 --- a/configs/pix2pix_cityscapes.yaml +++ b/configs/pix2pix_cityscapes.yaml @@ -61,7 +61,7 @@ dataset: keys: [image, image] test: name: PairedDataset - dataroot: data/cityscapes/test + dataroot: data/cityscapes/val num_workers: 4 batch_size: 1 preprocess: @@ -112,11 +112,9 @@ snapshot_config: interval: 5 validate: - interval: 500 + interval: 29750 save_img: false metrics: fid: # metric name, can be arbitrary name: FID batch_size: 8 - - diff --git a/configs/pix2pix_cityscapes_2gpus.yaml b/configs/pix2pix_cityscapes_2gpus.yaml index c242287f861bbbe288eded9d82c03a3198d2b6d8..824f0c487aa389e1e9ef54903854caa630174935 100644 --- a/configs/pix2pix_cityscapes_2gpus.yaml +++ b/configs/pix2pix_cityscapes_2gpus.yaml @@ -58,7 +58,7 @@ dataset: keys: [image, image] test: name: PairedDataset - dataroot: data/cityscapes/test + dataroot: data/cityscapes/val num_workers: 4 batch_size: 1 preprocess: @@ -109,11 +109,9 @@ snapshot_config: interval: 5 validate: - interval: 500 + interval: 29750 save_img: false metrics: fid: # metric name, can be arbitrary name: FID batch_size: 8 - - diff --git a/configs/pix2pix_facades.yaml b/configs/pix2pix_facades.yaml index 223b7133b730a0629183a639755d66d3a4ddc0c8..1654ac213325bad448f98e35f17bbb485cbf7588 100644 --- a/configs/pix2pix_facades.yaml +++ b/configs/pix2pix_facades.yaml @@ -109,11 +109,9 @@ snapshot_config: interval: 5 validate: - interval: 500 + interval: 4000 save_img: false metrics: fid: # metric name, can be arbitrary name: FID batch_size: 8 - - diff --git a/configs/stylegan_v2_256_ffhq.yaml b/configs/stylegan_v2_256_ffhq.yaml index c84c0ba4de18176d1e14ae58944d48153bfe3395..32c23c3d119315fbd8a065f9d5bb2730b486fd42 100644 --- a/configs/stylegan_v2_256_ffhq.yaml +++ b/configs/stylegan_v2_256_ffhq.yaml @@ -90,7 +90,7 @@ snapshot_config: interval: 5000 validate: - interval: 5000 + interval: 50000 save_imig: False metrics: fid: # metric name, can be arbitrary diff --git a/ppgan/metrics/fid.py b/ppgan/metrics/fid.py index ccf100d5fc592af8b214d8165eb6094d36082bec..1d36ca35193cd447d85c270e5f6a7eb5b439521c 100644 --- a/ppgan/metrics/fid.py +++ b/ppgan/metrics/fid.py @@ -39,9 +39,15 @@ inceptionV3 pretrain model is convert from pytorch, pretrain_model url is https: """ INCEPTIONV3_WEIGHT_URL = "https://paddlegan.bj.bcebos.com/InceptionV3.pdparams" + @METRICS.register() class FID(paddle.metric.Metric): - def __init__(self, batch_size=1, use_GPU=True, dims = 2048, premodel_path=None, model=None): + def __init__(self, + batch_size=1, + use_GPU=True, + dims=2048, + premodel_path=None, + model=None): self.batch_size = batch_size self.use_GPU = use_GPU self.dims = dims @@ -55,8 +61,8 @@ class FID(paddle.metric.Metric): param_dict = paddle.load(premodel_path) self.model.load_dict(param_dict) self.model.eval() - self.reset() - + self.reset() + def reset(self): self.preds = [] self.gts = [] @@ -72,7 +78,7 @@ class FID(paddle.metric.Metric): self.preds = np.concatenate(self.preds, axis=0) self.gts = np.concatenate(self.gts, axis=0) value = calculate_fid_given_img(self.preds, self.gts) - self.reset() + self.reset() return value def name(self): @@ -115,10 +121,10 @@ def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): def _get_activations_from_ims(img, model, batch_size, dims, use_gpu): n_batches = (len(img) + batch_size - 1) // batch_size n_used_img = len(img) - + pred_arr = np.empty((n_used_img, dims)) - - for i in tqdm(range(n_batches)): + + for i in range(n_batches): start = i * batch_size end = start + batch_size if end > len(img): @@ -126,7 +132,7 @@ def _get_activations_from_ims(img, model, batch_size, dims, use_gpu): images = img[start:end] if images.shape[1] != 3: images = images.transpose((0, 3, 1, 2)) - + images = paddle.to_tensor(images) pred = model(images)[0][0] pred_arr[start:end] = pred.reshape([end - start, -1]).cpu().numpy() @@ -138,16 +144,20 @@ def _compute_statistic_of_img(act): sigma = np.cov(act, rowvar=False) return mu, sigma + def calculate_inception_val(img_fake, img_real, batch_size, model, - use_gpu = True, - dims = 2048): - act_fake = _get_activations_from_ims(img_fake, model, batch_size, dims, use_gpu) - act_real = _get_activations_from_ims(img_real, model, batch_size, dims, use_gpu) + use_gpu=True, + dims=2048): + act_fake = _get_activations_from_ims(img_fake, model, batch_size, dims, + use_gpu) + act_real = _get_activations_from_ims(img_real, model, batch_size, dims, + use_gpu) return act_fake, act_real + def calculate_fid_given_img(act_fake, act_real): m1, s1 = _compute_statistic_of_img(act_fake) @@ -299,4 +309,3 @@ def calculate_fid_given_paths(paths, fid_value = _calculate_frechet_distance(m1, s1, m2, s2) return fid_value -