未验证 提交 1f335bbe 编写于 作者: L lzzyzlbb 提交者: GitHub

fix fid (#336)

* fix fid

* fix fid

* add pixel2pixel facades model
上级 bab376f4
...@@ -107,3 +107,13 @@ log_config: ...@@ -107,3 +107,13 @@ log_config:
snapshot_config: snapshot_config:
interval: 5 interval: 5
validate:
interval: 500
save_img: false
metrics:
fid: # metric name, can be arbitrary
name: FID
batch_size: 8
...@@ -107,3 +107,13 @@ log_config: ...@@ -107,3 +107,13 @@ log_config:
snapshot_config: snapshot_config:
interval: 5 interval: 5
validate:
interval: 500
save_img: false
metrics:
fid: # metric name, can be arbitrary
name: FID
batch_size: 8
...@@ -107,3 +107,13 @@ log_config: ...@@ -107,3 +107,13 @@ log_config:
snapshot_config: snapshot_config:
interval: 5 interval: 5
validate:
interval: 500
save_img: false
metrics:
fid: # metric name, can be arbitrary
name: FID
batch_size: 8
...@@ -43,6 +43,7 @@ ...@@ -43,6 +43,7 @@
| 模型 | 数据集 | 下载地址 | | 模型 | 数据集 | 下载地址 |
|---|---|---| |---|---|---|
| Pix2Pix_cityscapes | cityscapes | [Pix2Pix_cityscapes](https://paddlegan.bj.bcebos.com/models/Pix2Pix_cityscapes.pdparams) | Pix2Pix_cityscapes | cityscapes | [Pix2Pix_cityscapes](https://paddlegan.bj.bcebos.com/models/Pix2Pix_cityscapes.pdparams)
| Pix2Pix_facedes | facades | [Pix2Pix_facades](https://paddlegan.bj.bcebos.com/models/Pixel2Pixel_facades.pdparams)
......
...@@ -44,6 +44,7 @@ ...@@ -44,6 +44,7 @@
| 模型 | 数据集 | 下载地址 | | 模型 | 数据集 | 下载地址 |
|---|---|---| |---|---|---|
| Pix2Pix_cityscapes | cityscapes | [Pix2Pix_cityscapes](https://paddlegan.bj.bcebos.com/models/Pix2Pix_cityscapes.pdparams) | Pix2Pix_cityscapes | cityscapes | [Pix2Pix_cityscapes](https://paddlegan.bj.bcebos.com/models/Pix2Pix_cityscapes.pdparams)
| Pix2Pix_facedes | facades | [Pix2Pix_facades](https://paddlegan.bj.bcebos.com/models/Pixel2Pixel_facades.pdparams)
# 2 CycleGAN # 2 CycleGAN
......
...@@ -53,21 +53,30 @@ class FID(paddle.metric.Metric): ...@@ -53,21 +53,30 @@ class FID(paddle.metric.Metric):
premodel_path = get_weights_path_from_url(INCEPTIONV3_WEIGHT_URL) premodel_path = get_weights_path_from_url(INCEPTIONV3_WEIGHT_URL)
self.model = model self.model = model
param_dict = paddle.load(premodel_path) param_dict = paddle.load(premodel_path)
model.load_dict(param_dict) self.model.load_dict(param_dict)
model.eval() self.model.eval()
self.reset() self.reset()
def reset(self): def reset(self):
self.preds = []
self.gts = []
self.results = [] self.results = []
def update(self, preds, gts): def update(self, preds, gts):
value = calculate_fid_given_img(preds, gts, self.batch_size, self.model, self.use_GPU, self.dims) if len(preds.shape) >=4:
self.results.append(value) self.preds.append(preds)
self.gts.append(gts)
else:
for i in range(preds.shape[0]):
self.preds.append(preds[i,:,:,:,:])
self.gts.append(gts[i,:,:,:,:])
def accumulate(self): def accumulate(self):
if len(self.results) <= 0: self.preds = paddle.concat(self.preds, axis=0)
return 0. self.gts = paddle.concat(self.gts, axis=0)
return np.mean(self.results) value = calculate_fid_given_img(self.preds, self.gts, self.batch_size, self.model, self.use_GPU, self.dims)
self.reset()
return value
def name(self): def name(self):
return 'FID' return 'FID'
...@@ -123,7 +132,6 @@ def _get_activations_from_ims(img, model, batch_size, dims, use_gpu): ...@@ -123,7 +132,6 @@ def _get_activations_from_ims(img, model, batch_size, dims, use_gpu):
images = img[start:end] images = img[start:end]
if images.shape[1] != 3: if images.shape[1] != 3:
images = images.transpose((0, 3, 1, 2)) images = images.transpose((0, 3, 1, 2))
images /= 255
images = paddle.to_tensor(images) images = paddle.to_tensor(images)
pred = model(images)[0][0] pred = model(images)[0][0]
......
...@@ -141,5 +141,10 @@ class Pix2PixModel(BaseModel): ...@@ -141,5 +141,10 @@ class Pix2PixModel(BaseModel):
optimizers['optimG'].step() optimizers['optimG'].step()
def test_iter(self, metrics=None): def test_iter(self, metrics=None):
self.nets['netG'].eval()
self.forward()
with paddle.no_grad(): with paddle.no_grad():
self.forward() if metrics is not None:
for metric in metrics.values():
metric.update(self.fake_B, self.real_B)
self.nets['netG'].train()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册