From 19fe4fbc10ea4394638e227abe5daf5c026a671f Mon Sep 17 00:00:00 2001 From: lzzyzlbb <287246233@qq.com> Date: Tue, 22 Jun 2021 16:05:43 +0800 Subject: [PATCH] add fid for style gan (#347) * add fid for style gan * add fid for style gan * add fid for style gan --- configs/stylegan_v2_256_ffhq.yaml | 9 +++++++ ppgan/metrics/fid.py | 43 ++++++++++++++----------------- ppgan/models/styleganv2_model.py | 16 +++++++++++- 3 files changed, 43 insertions(+), 25 deletions(-) diff --git a/configs/stylegan_v2_256_ffhq.yaml b/configs/stylegan_v2_256_ffhq.yaml index 1265589..13d91f7 100644 --- a/configs/stylegan_v2_256_ffhq.yaml +++ b/configs/stylegan_v2_256_ffhq.yaml @@ -23,6 +23,7 @@ model: params: gen_iters: 4 disc_iters: 16 + max_eval_steps: 50000 export_model: - {name: 'gen', inputs_num: 2} @@ -72,3 +73,11 @@ log_config: snapshot_config: interval: 5000 + +validate: + interval: 5000 + save_imig: False + metrics: + fid: # metric name, can be arbitrary + name: FID + batch_size: 4 diff --git a/ppgan/metrics/fid.py b/ppgan/metrics/fid.py index 0635ca3..ccf100d 100644 --- a/ppgan/metrics/fid.py +++ b/ppgan/metrics/fid.py @@ -48,7 +48,7 @@ class FID(paddle.metric.Metric): self.premodel_path = premodel_path if model is None: block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] - model = InceptionV3([block_idx]) + model = InceptionV3([block_idx], normalize_input=False) if premodel_path is None: premodel_path = get_weights_path_from_url(INCEPTIONV3_WEIGHT_URL) self.model = model @@ -63,18 +63,15 @@ class FID(paddle.metric.Metric): self.results = [] def update(self, preds, gts): - if len(preds.shape) >=4: - self.preds.append(preds) - self.gts.append(gts) - else: - for i in range(preds.shape[0]): - self.preds.append(preds[i,:,:,:,:]) - self.gts.append(gts[i,:,:,:,:]) - + preds_inception, gts_inception = calculate_inception_val( + preds, gts, self.batch_size, self.model, self.use_GPU, self.dims) + self.preds.append(preds_inception) + self.gts.append(gts_inception) + def accumulate(self): - self.preds = paddle.concat(self.preds, axis=0) - self.gts = paddle.concat(self.gts, axis=0) - value = calculate_fid_given_img(self.preds, self.gts, self.batch_size, self.model, self.use_GPU, self.dims) + self.preds = np.concatenate(self.preds, axis=0) + self.gts = np.concatenate(self.gts, axis=0) + value = calculate_fid_given_img(self.preds, self.gts) self.reset() return value @@ -82,8 +79,6 @@ class FID(paddle.metric.Metric): return 'FID' - - def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): m1 = np.atleast_1d(mu1) m2 = np.atleast_1d(mu2) @@ -111,7 +106,6 @@ def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): m = np.max(np.abs(covmean.imag)) raise ValueError('Imaginary component {}'.format(m)) covmean = covmean.real - tr_covmean = np.trace(covmean) return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - @@ -132,32 +126,32 @@ def _get_activations_from_ims(img, model, batch_size, dims, use_gpu): images = img[start:end] if images.shape[1] != 3: images = images.transpose((0, 3, 1, 2)) - + images = paddle.to_tensor(images) pred = model(images)[0][0] pred_arr[start:end] = pred.reshape([end - start, -1]).cpu().numpy() return pred_arr -def _compute_statistic_of_img(img, model, batch_size, dims, use_gpu): - act = _get_activations_from_ims(img, model, batch_size, dims, use_gpu) +def _compute_statistic_of_img(act): mu = np.mean(act, axis=0) sigma = np.cov(act, rowvar=False) return mu, sigma - -def calculate_fid_given_img(img_fake, +def calculate_inception_val(img_fake, img_real, batch_size, model, use_gpu = True, dims = 2048): + act_fake = _get_activations_from_ims(img_fake, model, batch_size, dims, use_gpu) + act_real = _get_activations_from_ims(img_real, model, batch_size, dims, use_gpu) + return act_fake, act_real - m1, s1 = _compute_statistic_of_img(img_fake, model, batch_size, dims, - use_gpu) - m2, s2 = _compute_statistic_of_img(img_real, model, batch_size, dims, - use_gpu) +def calculate_fid_given_img(act_fake, act_real): + m1, s1 = _compute_statistic_of_img(act_fake) + m2, s2 = _compute_statistic_of_img(act_real) fid_value = _calculate_frechet_distance(m1, s1, m2, s2) return fid_value @@ -305,3 +299,4 @@ def calculate_fid_given_paths(paths, fid_value = _calculate_frechet_distance(m1, s1, m2, s2) return fid_value + diff --git a/ppgan/models/styleganv2_model.py b/ppgan/models/styleganv2_model.py index 36071e0..827e6ed 100644 --- a/ppgan/models/styleganv2_model.py +++ b/ppgan/models/styleganv2_model.py @@ -79,7 +79,8 @@ class StyleGAN2Model(BaseModel): r1_reg_weight=10., path_reg_weight=2., path_batch_shrink=2., - params=None): + params=None, + max_eval_steps=50000): """Initialize the CycleGAN class. Args: @@ -107,6 +108,7 @@ class StyleGAN2Model(BaseModel): self.mean_path_length = 0 self.nets['gen'] = build_generator(generator) + self.max_eval_steps = max_eval_steps # define discriminators if discriminator: @@ -280,3 +282,15 @@ class StyleGAN2Model(BaseModel): self.visual_items['fake_img_ema'] = sample self.current_iter += 1 + + def test_iter(self, metrics=None): + self.nets['gen_ema'].eval() + batch = self.real_img.shape[0] + noises = [paddle.randn([batch, self.num_style_feat])] + fake_img, _ = self.nets['gen_ema'](noises) + with paddle.no_grad(): + if metrics is not None: + for metric in metrics.values(): + metric.update(fake_img, self.real_img) + self.nets['gen_ema'].train() + -- GitLab