未验证 提交 1f335bbe 编写于 作者: L lzzyzlbb 提交者: GitHub

fix fid (#336)

* fix fid

* fix fid

* add pixel2pixel facades model
上级 bab376f4
......@@ -107,3 +107,13 @@ log_config:
snapshot_config:
interval: 5
validate:
interval: 500
save_img: false
metrics:
fid: # metric name, can be arbitrary
name: FID
batch_size: 8
......@@ -107,3 +107,13 @@ log_config:
snapshot_config:
interval: 5
validate:
interval: 500
save_img: false
metrics:
fid: # metric name, can be arbitrary
name: FID
batch_size: 8
......@@ -107,3 +107,13 @@ log_config:
snapshot_config:
interval: 5
validate:
interval: 500
save_img: false
metrics:
fid: # metric name, can be arbitrary
name: FID
batch_size: 8
......@@ -43,6 +43,7 @@
| 模型 | 数据集 | 下载地址 |
|---|---|---|
| Pix2Pix_cityscapes | cityscapes | [Pix2Pix_cityscapes](https://paddlegan.bj.bcebos.com/models/Pix2Pix_cityscapes.pdparams)
| Pix2Pix_facedes | facades | [Pix2Pix_facades](https://paddlegan.bj.bcebos.com/models/Pixel2Pixel_facades.pdparams)
......
......@@ -44,6 +44,7 @@
| 模型 | 数据集 | 下载地址 |
|---|---|---|
| Pix2Pix_cityscapes | cityscapes | [Pix2Pix_cityscapes](https://paddlegan.bj.bcebos.com/models/Pix2Pix_cityscapes.pdparams)
| Pix2Pix_facedes | facades | [Pix2Pix_facades](https://paddlegan.bj.bcebos.com/models/Pixel2Pixel_facades.pdparams)
# 2 CycleGAN
......
......@@ -53,21 +53,30 @@ class FID(paddle.metric.Metric):
premodel_path = get_weights_path_from_url(INCEPTIONV3_WEIGHT_URL)
self.model = model
param_dict = paddle.load(premodel_path)
model.load_dict(param_dict)
model.eval()
self.model.load_dict(param_dict)
self.model.eval()
self.reset()
def reset(self):
self.preds = []
self.gts = []
self.results = []
def update(self, preds, gts):
value = calculate_fid_given_img(preds, gts, self.batch_size, self.model, self.use_GPU, self.dims)
self.results.append(value)
if len(preds.shape) >=4:
self.preds.append(preds)
self.gts.append(gts)
else:
for i in range(preds.shape[0]):
self.preds.append(preds[i,:,:,:,:])
self.gts.append(gts[i,:,:,:,:])
def accumulate(self):
if len(self.results) <= 0:
return 0.
return np.mean(self.results)
self.preds = paddle.concat(self.preds, axis=0)
self.gts = paddle.concat(self.gts, axis=0)
value = calculate_fid_given_img(self.preds, self.gts, self.batch_size, self.model, self.use_GPU, self.dims)
self.reset()
return value
def name(self):
return 'FID'
......@@ -123,7 +132,6 @@ def _get_activations_from_ims(img, model, batch_size, dims, use_gpu):
images = img[start:end]
if images.shape[1] != 3:
images = images.transpose((0, 3, 1, 2))
images /= 255
images = paddle.to_tensor(images)
pred = model(images)[0][0]
......
......@@ -141,5 +141,10 @@ class Pix2PixModel(BaseModel):
optimizers['optimG'].step()
def test_iter(self, metrics=None):
self.nets['netG'].eval()
self.forward()
with paddle.no_grad():
self.forward()
if metrics is not None:
for metric in metrics.values():
metric.update(self.fake_B, self.real_B)
self.nets['netG'].train()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册