Unverified commit 0666e2e8, authored by LielinJiang, committed by GitHub

increase validate interval (#382)

Parent f84ddda4
@@ -61,7 +61,7 @@ dataset:
         keys: [image, image]
   test:
     name: PairedDataset
-    dataroot: data/cityscapes/test
+    dataroot: data/cityscapes/val
     num_workers: 4
     batch_size: 1
     preprocess:
@@ -112,11 +112,9 @@ snapshot_config:
   interval: 5
 validate:
-  interval: 500
+  interval: 29750
   save_img: false
   metrics:
     fid: # metric name, can be arbitrary
       name: FID
       batch_size: 8
@@ -58,7 +58,7 @@ dataset:
         keys: [image, image]
   test:
     name: PairedDataset
-    dataroot: data/cityscapes/test
+    dataroot: data/cityscapes/val
     num_workers: 4
     batch_size: 1
     preprocess:
@@ -109,11 +109,9 @@ snapshot_config:
   interval: 5
 validate:
-  interval: 500
+  interval: 29750
   save_img: false
   metrics:
     fid: # metric name, can be arbitrary
       name: FID
       batch_size: 8
@@ -109,11 +109,9 @@ snapshot_config:
   interval: 5
 validate:
-  interval: 500
+  interval: 4000
   save_img: false
   metrics:
     fid: # metric name, can be arbitrary
       name: FID
       batch_size: 8
@@ -90,7 +90,7 @@ snapshot_config:
   interval: 5000
 validate:
-  interval: 5000
+  interval: 50000
   save_imig: False
   metrics:
     fid: # metric name, can be arbitrary
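All four config hunks make the same kind of change: validation (including the FID metric whose code follows) runs once every validate.interval training iterations, so raising the interval simply makes evaluation less frequent. Below is a minimal sketch of how such an interval setting is typically consumed by a training loop; the stub functions and the total iteration count are placeholders for illustration, not PaddleGAN's actual trainer code.

    # Hypothetical illustration of an interval-gated validation check.
    # total_iters and the two stub functions are assumptions for the sketch.
    def train_one_iter(step: int) -> None:
        pass  # stand-in for one generator/discriminator update

    def run_validation(step: int) -> None:
        print(f"validating at iter {step}")  # stand-in for FID evaluation / image saving

    total_iters = 59500            # assumed total training length
    validate_interval = 29750      # value set by this commit for the pix2pix cityscapes configs

    for step in range(1, total_iters + 1):
        train_one_iter(step)
        if step % validate_interval == 0:
            run_validation(step)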
@@ -39,9 +39,15 @@ inceptionV3 pretrain model is convert from pytorch, pretrain_model url is https:
 """
 INCEPTIONV3_WEIGHT_URL = "https://paddlegan.bj.bcebos.com/InceptionV3.pdparams"
 @METRICS.register()
 class FID(paddle.metric.Metric):
-    def __init__(self, batch_size=1, use_GPU=True, dims = 2048, premodel_path=None, model=None):
+    def __init__(self,
+                 batch_size=1,
+                 use_GPU=True,
+                 dims=2048,
+                 premodel_path=None,
+                 model=None):
         self.batch_size = batch_size
         self.use_GPU = use_GPU
         self.dims = dims
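In the validate.metrics section of the configs above, the outer key (fid) is only a label, name: FID selects this registered class, and the remaining keys are passed to its constructor. A rough sketch of that convention follows; the build_metric helper and the import path are assumptions for illustration, not PaddleGAN's actual factory code.

    # Hypothetical sketch of turning one validate.metrics entry into a metric instance.
    from ppgan.metrics.fid import FID  # module path assumed from the class shown above

    def build_metric(cfg: dict):
        cfg = dict(cfg)                   # e.g. {"name": "FID", "batch_size": 8}
        name = cfg.pop("name")            # "FID" selects the class registered via @METRICS.register()
        return {"FID": FID}[name](**cfg)  # -> FID(batch_size=8)

    fid_metric = build_metric({"name": "FID", "batch_size": 8})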
@@ -55,8 +61,8 @@ class FID(paddle.metric.Metric):
             param_dict = paddle.load(premodel_path)
             self.model.load_dict(param_dict)
             self.model.eval()
-        self.reset()
+        self.reset()
 
     def reset(self):
         self.preds = []
         self.gts = []
@@ -72,7 +78,7 @@ class FID(paddle.metric.Metric):
         self.preds = np.concatenate(self.preds, axis=0)
         self.gts = np.concatenate(self.gts, axis=0)
         value = calculate_fid_given_img(self.preds, self.gts)
-        self.reset()
+        self.reset()
         return value
 
     def name(self):
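Since FID subclasses paddle.metric.Metric, it is presumably driven through the usual update / accumulate / reset cycle; only reset(), name(), and the concatenate-and-score step appear in this excerpt, so the update() call in the sketch below is an assumption about how batches are buffered.

    # Hedged usage sketch; update()'s signature, the import path, and the weight path are assumptions.
    import numpy as np
    from ppgan.metrics.fid import FID  # module path assumed

    fake = np.random.rand(8, 3, 256, 256).astype("float32")  # one batch of generated images (NCHW)
    real = np.random.rand(8, 3, 256, 256).astype("float32")  # matching batch of real images

    fid = FID(batch_size=8, premodel_path="InceptionV3.pdparams")  # weights path is illustrative
    fid.update(fake, real)    # assumed to buffer this batch into self.preds / self.gts
    score = fid.accumulate()  # concatenates the buffers, computes FID, then calls reset() as shown above
    print(fid.name(), score)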
@@ -115,10 +121,10 @@ def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
 def _get_activations_from_ims(img, model, batch_size, dims, use_gpu):
     n_batches = (len(img) + batch_size - 1) // batch_size
     n_used_img = len(img)
     pred_arr = np.empty((n_used_img, dims))
-    for i in tqdm(range(n_batches)):
+    for i in range(n_batches):
         start = i * batch_size
         end = start + batch_size
         if end > len(img):
@@ -126,7 +132,7 @@ def _get_activations_from_ims(img, model, batch_size, dims, use_gpu):
         images = img[start:end]
         if images.shape[1] != 3:
             images = images.transpose((0, 3, 1, 2))
         images = paddle.to_tensor(images)
         pred = model(images)[0][0]
         pred_arr[start:end] = pred.reshape([end - start, -1]).cpu().numpy()
@@ -138,16 +144,20 @@ def _compute_statistic_of_img(act):
     sigma = np.cov(act, rowvar=False)
     return mu, sigma
 
 def calculate_inception_val(img_fake,
                             img_real,
                             batch_size,
                             model,
-                            use_gpu = True,
-                            dims = 2048):
-    act_fake = _get_activations_from_ims(img_fake, model, batch_size, dims, use_gpu)
-    act_real = _get_activations_from_ims(img_real, model, batch_size, dims, use_gpu)
+                            use_gpu=True,
+                            dims=2048):
+    act_fake = _get_activations_from_ims(img_fake, model, batch_size, dims,
+                                         use_gpu)
+    act_real = _get_activations_from_ims(img_real, model, batch_size, dims,
+                                         use_gpu)
     return act_fake, act_real
 
 def calculate_fid_given_img(act_fake, act_real):
     m1, s1 = _compute_statistic_of_img(act_fake)
@@ -299,4 +309,3 @@ def calculate_fid_given_paths(paths,
     fid_value = _calculate_frechet_distance(m1, s1, m2, s2)
     return fid_value
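For context, _calculate_frechet_distance evaluates the standard FID formula, d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^(1/2)), over the mean and covariance statistics produced by _compute_statistic_of_img. A minimal numpy/scipy sketch of that formula (not the repo's exact implementation) is:

    # Standard Frechet (FID) distance between two Gaussians N(mu1, sigma1) and N(mu2, sigma2).
    import numpy as np
    from scipy import linalg

    def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
        diff = mu1 - mu2
        # Matrix square root of the covariance product; may be complex or ill-conditioned.
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            # Nudge the diagonals for numerical stability, as most FID implementations do.
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

    if __name__ == "__main__":
        rng = np.random.default_rng(0)
        a = rng.standard_normal((64, 8))        # stand-in activations for real images
        b = rng.standard_normal((64, 8)) + 0.5  # stand-in activations for generated images
        print(frechet_distance(a.mean(axis=0), np.cov(a, rowvar=False),
                               b.mean(axis=0), np.cov(b, rowvar=False)))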