From 1f335bbe716ad78fb7136585d2c8d2b5bdf1373c Mon Sep 17 00:00:00 2001 From: lzzyzlbb <287246233@qq.com> Date: Fri, 4 Jun 2021 15:55:50 +0800 Subject: [PATCH] fix fid (#336) * fix fid * fix fid * add pixel2pixel facades model --- configs/pix2pix_cityscapes.yaml | 10 +++++++++ configs/pix2pix_cityscapes_2gpus.yaml | 10 +++++++++ configs/pix2pix_facades.yaml | 10 +++++++++ docs/en_US/tutorials/pix2pix_cyclegan.md | 1 + docs/zh_CN/tutorials/pix2pix_cyclegan.md | 1 + ppgan/metrics/fid.py | 26 ++++++++++++++++-------- ppgan/models/pix2pix_model.py | 7 ++++++- 7 files changed, 55 insertions(+), 10 deletions(-) diff --git a/configs/pix2pix_cityscapes.yaml b/configs/pix2pix_cityscapes.yaml index 5a24315..47f1716 100644 --- a/configs/pix2pix_cityscapes.yaml +++ b/configs/pix2pix_cityscapes.yaml @@ -107,3 +107,13 @@ log_config: snapshot_config: interval: 5 + +validate: + interval: 500 + save_img: false + metrics: + fid: # metric name, can be arbitrary + name: FID + batch_size: 8 + + diff --git a/configs/pix2pix_cityscapes_2gpus.yaml b/configs/pix2pix_cityscapes_2gpus.yaml index 2d27c3b..c242287 100644 --- a/configs/pix2pix_cityscapes_2gpus.yaml +++ b/configs/pix2pix_cityscapes_2gpus.yaml @@ -107,3 +107,13 @@ log_config: snapshot_config: interval: 5 + +validate: + interval: 500 + save_img: false + metrics: + fid: # metric name, can be arbitrary + name: FID + batch_size: 8 + + diff --git a/configs/pix2pix_facades.yaml b/configs/pix2pix_facades.yaml index b73005d..223b713 100644 --- a/configs/pix2pix_facades.yaml +++ b/configs/pix2pix_facades.yaml @@ -107,3 +107,13 @@ log_config: snapshot_config: interval: 5 + +validate: + interval: 500 + save_img: false + metrics: + fid: # metric name, can be arbitrary + name: FID + batch_size: 8 + + diff --git a/docs/en_US/tutorials/pix2pix_cyclegan.md b/docs/en_US/tutorials/pix2pix_cyclegan.md index 818ea8d..97bf65b 100644 --- a/docs/en_US/tutorials/pix2pix_cyclegan.md +++ b/docs/en_US/tutorials/pix2pix_cyclegan.md @@ -43,6 +43,7 @@ | 模型 | 数据集 | 下载地址 | |---|---|---| | Pix2Pix_cityscapes | cityscapes | [Pix2Pix_cityscapes](https://paddlegan.bj.bcebos.com/models/Pix2Pix_cityscapes.pdparams) +| Pix2Pix_facedes | facades | [Pix2Pix_facades](https://paddlegan.bj.bcebos.com/models/Pixel2Pixel_facades.pdparams) diff --git a/docs/zh_CN/tutorials/pix2pix_cyclegan.md b/docs/zh_CN/tutorials/pix2pix_cyclegan.md index 5a21130..a121cfd 100644 --- a/docs/zh_CN/tutorials/pix2pix_cyclegan.md +++ b/docs/zh_CN/tutorials/pix2pix_cyclegan.md @@ -44,6 +44,7 @@ | 模型 | 数据集 | 下载地址 | |---|---|---| | Pix2Pix_cityscapes | cityscapes | [Pix2Pix_cityscapes](https://paddlegan.bj.bcebos.com/models/Pix2Pix_cityscapes.pdparams) +| Pix2Pix_facedes | facades | [Pix2Pix_facades](https://paddlegan.bj.bcebos.com/models/Pixel2Pixel_facades.pdparams) # 2 CycleGAN diff --git a/ppgan/metrics/fid.py b/ppgan/metrics/fid.py index ab0ae54..0635ca3 100644 --- a/ppgan/metrics/fid.py +++ b/ppgan/metrics/fid.py @@ -53,21 +53,30 @@ class FID(paddle.metric.Metric): premodel_path = get_weights_path_from_url(INCEPTIONV3_WEIGHT_URL) self.model = model param_dict = paddle.load(premodel_path) - model.load_dict(param_dict) - model.eval() + self.model.load_dict(param_dict) + self.model.eval() self.reset() def reset(self): + self.preds = [] + self.gts = [] self.results = [] def update(self, preds, gts): - value = calculate_fid_given_img(preds, gts, self.batch_size, self.model, self.use_GPU, self.dims) - self.results.append(value) - + if len(preds.shape) >=4: + self.preds.append(preds) + self.gts.append(gts) + else: + for i in range(preds.shape[0]): + self.preds.append(preds[i,:,:,:,:]) + self.gts.append(gts[i,:,:,:,:]) + def accumulate(self): - if len(self.results) <= 0: - return 0. - return np.mean(self.results) + self.preds = paddle.concat(self.preds, axis=0) + self.gts = paddle.concat(self.gts, axis=0) + value = calculate_fid_given_img(self.preds, self.gts, self.batch_size, self.model, self.use_GPU, self.dims) + self.reset() + return value def name(self): return 'FID' @@ -123,7 +132,6 @@ def _get_activations_from_ims(img, model, batch_size, dims, use_gpu): images = img[start:end] if images.shape[1] != 3: images = images.transpose((0, 3, 1, 2)) - images /= 255 images = paddle.to_tensor(images) pred = model(images)[0][0] diff --git a/ppgan/models/pix2pix_model.py b/ppgan/models/pix2pix_model.py index d2ec0de..2c8d552 100644 --- a/ppgan/models/pix2pix_model.py +++ b/ppgan/models/pix2pix_model.py @@ -141,5 +141,10 @@ class Pix2PixModel(BaseModel): optimizers['optimG'].step() def test_iter(self, metrics=None): + self.nets['netG'].eval() + self.forward() with paddle.no_grad(): - self.forward() + if metrics is not None: + for metric in metrics.values(): + metric.update(self.fake_B, self.real_B) + self.nets['netG'].train() -- GitLab