From 0666e2e8b76fddc046f66075b72830e0bff0fe32 Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Wed, 4 Aug 2021 16:51:16 +0800 Subject: [PATCH] increase validate interval (#382) --- configs/pix2pix_cityscapes.yaml | 6 ++--- configs/pix2pix_cityscapes_2gpus.yaml | 6 ++--- configs/pix2pix_facades.yaml | 4 +-- configs/stylegan_v2_256_ffhq.yaml | 2 +- ppgan/metrics/fid.py | 35 +++++++++++++++++---------- 5 files changed, 28 insertions(+), 25 deletions(-) diff --git a/configs/pix2pix_cityscapes.yaml b/configs/pix2pix_cityscapes.yaml index 11ad4fc..ba237d0 100644 --- a/configs/pix2pix_cityscapes.yaml +++ b/configs/pix2pix_cityscapes.yaml @@ -61,7 +61,7 @@ dataset: keys: [image, image] test: name: PairedDataset - dataroot: data/cityscapes/test + dataroot: data/cityscapes/val num_workers: 4 batch_size: 1 preprocess: @@ -112,11 +112,9 @@ snapshot_config: interval: 5 validate: - interval: 500 + interval: 29750 save_img: false metrics: fid: # metric name, can be arbitrary name: FID batch_size: 8 - - diff --git a/configs/pix2pix_cityscapes_2gpus.yaml b/configs/pix2pix_cityscapes_2gpus.yaml index c242287..824f0c4 100644 --- a/configs/pix2pix_cityscapes_2gpus.yaml +++ b/configs/pix2pix_cityscapes_2gpus.yaml @@ -58,7 +58,7 @@ dataset: keys: [image, image] test: name: PairedDataset - dataroot: data/cityscapes/test + dataroot: data/cityscapes/val num_workers: 4 batch_size: 1 preprocess: @@ -109,11 +109,9 @@ snapshot_config: interval: 5 validate: - interval: 500 + interval: 29750 save_img: false metrics: fid: # metric name, can be arbitrary name: FID batch_size: 8 - - diff --git a/configs/pix2pix_facades.yaml b/configs/pix2pix_facades.yaml index 223b713..1654ac2 100644 --- a/configs/pix2pix_facades.yaml +++ b/configs/pix2pix_facades.yaml @@ -109,11 +109,9 @@ snapshot_config: interval: 5 validate: - interval: 500 + interval: 4000 save_img: false metrics: fid: # metric name, can be arbitrary name: FID batch_size: 8 - - diff --git a/configs/stylegan_v2_256_ffhq.yaml b/configs/stylegan_v2_256_ffhq.yaml index c84c0ba..32c23c3 100644 --- a/configs/stylegan_v2_256_ffhq.yaml +++ b/configs/stylegan_v2_256_ffhq.yaml @@ -90,7 +90,7 @@ snapshot_config: interval: 5000 validate: - interval: 5000 + interval: 50000 save_imig: False metrics: fid: # metric name, can be arbitrary diff --git a/ppgan/metrics/fid.py b/ppgan/metrics/fid.py index ccf100d..1d36ca3 100644 --- a/ppgan/metrics/fid.py +++ b/ppgan/metrics/fid.py @@ -39,9 +39,15 @@ inceptionV3 pretrain model is convert from pytorch, pretrain_model url is https: """ INCEPTIONV3_WEIGHT_URL = "https://paddlegan.bj.bcebos.com/InceptionV3.pdparams" + @METRICS.register() class FID(paddle.metric.Metric): - def __init__(self, batch_size=1, use_GPU=True, dims = 2048, premodel_path=None, model=None): + def __init__(self, + batch_size=1, + use_GPU=True, + dims=2048, + premodel_path=None, + model=None): self.batch_size = batch_size self.use_GPU = use_GPU self.dims = dims @@ -55,8 +61,8 @@ class FID(paddle.metric.Metric): param_dict = paddle.load(premodel_path) self.model.load_dict(param_dict) self.model.eval() - self.reset() - + self.reset() + def reset(self): self.preds = [] self.gts = [] @@ -72,7 +78,7 @@ class FID(paddle.metric.Metric): self.preds = np.concatenate(self.preds, axis=0) self.gts = np.concatenate(self.gts, axis=0) value = calculate_fid_given_img(self.preds, self.gts) - self.reset() + self.reset() return value def name(self): @@ -115,10 +121,10 @@ def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): def _get_activations_from_ims(img, model, batch_size, dims, use_gpu): n_batches = (len(img) + batch_size - 1) // batch_size n_used_img = len(img) - + pred_arr = np.empty((n_used_img, dims)) - - for i in tqdm(range(n_batches)): + + for i in range(n_batches): start = i * batch_size end = start + batch_size if end > len(img): @@ -126,7 +132,7 @@ def _get_activations_from_ims(img, model, batch_size, dims, use_gpu): images = img[start:end] if images.shape[1] != 3: images = images.transpose((0, 3, 1, 2)) - + images = paddle.to_tensor(images) pred = model(images)[0][0] pred_arr[start:end] = pred.reshape([end - start, -1]).cpu().numpy() @@ -138,16 +144,20 @@ def _compute_statistic_of_img(act): sigma = np.cov(act, rowvar=False) return mu, sigma + def calculate_inception_val(img_fake, img_real, batch_size, model, - use_gpu = True, - dims = 2048): - act_fake = _get_activations_from_ims(img_fake, model, batch_size, dims, use_gpu) - act_real = _get_activations_from_ims(img_real, model, batch_size, dims, use_gpu) + use_gpu=True, + dims=2048): + act_fake = _get_activations_from_ims(img_fake, model, batch_size, dims, + use_gpu) + act_real = _get_activations_from_ims(img_real, model, batch_size, dims, + use_gpu) return act_fake, act_real + def calculate_fid_given_img(act_fake, act_real): m1, s1 = _compute_statistic_of_img(act_fake) @@ -299,4 +309,3 @@ def calculate_fid_given_paths(paths, fid_value = _calculate_frechet_distance(m1, s1, m2, s2) return fid_value - -- GitLab