未验证 提交 0666e2e8 编写于 作者: L LielinJiang 提交者: GitHub

increase validate interval (#382)

上级 f84ddda4
...@@ -61,7 +61,7 @@ dataset: ...@@ -61,7 +61,7 @@ dataset:
keys: [image, image] keys: [image, image]
test: test:
name: PairedDataset name: PairedDataset
dataroot: data/cityscapes/test dataroot: data/cityscapes/val
num_workers: 4 num_workers: 4
batch_size: 1 batch_size: 1
preprocess: preprocess:
...@@ -112,11 +112,9 @@ snapshot_config: ...@@ -112,11 +112,9 @@ snapshot_config:
interval: 5 interval: 5
validate: validate:
interval: 500 interval: 29750
save_img: false save_img: false
metrics: metrics:
fid: # metric name, can be arbitrary fid: # metric name, can be arbitrary
name: FID name: FID
batch_size: 8 batch_size: 8
...@@ -58,7 +58,7 @@ dataset: ...@@ -58,7 +58,7 @@ dataset:
keys: [image, image] keys: [image, image]
test: test:
name: PairedDataset name: PairedDataset
dataroot: data/cityscapes/test dataroot: data/cityscapes/val
num_workers: 4 num_workers: 4
batch_size: 1 batch_size: 1
preprocess: preprocess:
...@@ -109,11 +109,9 @@ snapshot_config: ...@@ -109,11 +109,9 @@ snapshot_config:
interval: 5 interval: 5
validate: validate:
interval: 500 interval: 29750
save_img: false save_img: false
metrics: metrics:
fid: # metric name, can be arbitrary fid: # metric name, can be arbitrary
name: FID name: FID
batch_size: 8 batch_size: 8
...@@ -109,11 +109,9 @@ snapshot_config: ...@@ -109,11 +109,9 @@ snapshot_config:
interval: 5 interval: 5
validate: validate:
interval: 500 interval: 4000
save_img: false save_img: false
metrics: metrics:
fid: # metric name, can be arbitrary fid: # metric name, can be arbitrary
name: FID name: FID
batch_size: 8 batch_size: 8
...@@ -90,7 +90,7 @@ snapshot_config: ...@@ -90,7 +90,7 @@ snapshot_config:
interval: 5000 interval: 5000
validate: validate:
interval: 5000 interval: 50000
save_imig: False save_imig: False
metrics: metrics:
fid: # metric name, can be arbitrary fid: # metric name, can be arbitrary
......
...@@ -39,9 +39,15 @@ inceptionV3 pretrain model is convert from pytorch, pretrain_model url is https: ...@@ -39,9 +39,15 @@ inceptionV3 pretrain model is convert from pytorch, pretrain_model url is https:
""" """
INCEPTIONV3_WEIGHT_URL = "https://paddlegan.bj.bcebos.com/InceptionV3.pdparams" INCEPTIONV3_WEIGHT_URL = "https://paddlegan.bj.bcebos.com/InceptionV3.pdparams"
@METRICS.register() @METRICS.register()
class FID(paddle.metric.Metric): class FID(paddle.metric.Metric):
def __init__(self, batch_size=1, use_GPU=True, dims = 2048, premodel_path=None, model=None): def __init__(self,
batch_size=1,
use_GPU=True,
dims=2048,
premodel_path=None,
model=None):
self.batch_size = batch_size self.batch_size = batch_size
self.use_GPU = use_GPU self.use_GPU = use_GPU
self.dims = dims self.dims = dims
...@@ -55,8 +61,8 @@ class FID(paddle.metric.Metric): ...@@ -55,8 +61,8 @@ class FID(paddle.metric.Metric):
param_dict = paddle.load(premodel_path) param_dict = paddle.load(premodel_path)
self.model.load_dict(param_dict) self.model.load_dict(param_dict)
self.model.eval() self.model.eval()
self.reset() self.reset()
def reset(self): def reset(self):
self.preds = [] self.preds = []
self.gts = [] self.gts = []
...@@ -72,7 +78,7 @@ class FID(paddle.metric.Metric): ...@@ -72,7 +78,7 @@ class FID(paddle.metric.Metric):
self.preds = np.concatenate(self.preds, axis=0) self.preds = np.concatenate(self.preds, axis=0)
self.gts = np.concatenate(self.gts, axis=0) self.gts = np.concatenate(self.gts, axis=0)
value = calculate_fid_given_img(self.preds, self.gts) value = calculate_fid_given_img(self.preds, self.gts)
self.reset() self.reset()
return value return value
def name(self): def name(self):
...@@ -115,10 +121,10 @@ def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): ...@@ -115,10 +121,10 @@ def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
def _get_activations_from_ims(img, model, batch_size, dims, use_gpu): def _get_activations_from_ims(img, model, batch_size, dims, use_gpu):
n_batches = (len(img) + batch_size - 1) // batch_size n_batches = (len(img) + batch_size - 1) // batch_size
n_used_img = len(img) n_used_img = len(img)
pred_arr = np.empty((n_used_img, dims)) pred_arr = np.empty((n_used_img, dims))
for i in tqdm(range(n_batches)): for i in range(n_batches):
start = i * batch_size start = i * batch_size
end = start + batch_size end = start + batch_size
if end > len(img): if end > len(img):
...@@ -126,7 +132,7 @@ def _get_activations_from_ims(img, model, batch_size, dims, use_gpu): ...@@ -126,7 +132,7 @@ def _get_activations_from_ims(img, model, batch_size, dims, use_gpu):
images = img[start:end] images = img[start:end]
if images.shape[1] != 3: if images.shape[1] != 3:
images = images.transpose((0, 3, 1, 2)) images = images.transpose((0, 3, 1, 2))
images = paddle.to_tensor(images) images = paddle.to_tensor(images)
pred = model(images)[0][0] pred = model(images)[0][0]
pred_arr[start:end] = pred.reshape([end - start, -1]).cpu().numpy() pred_arr[start:end] = pred.reshape([end - start, -1]).cpu().numpy()
...@@ -138,16 +144,20 @@ def _compute_statistic_of_img(act): ...@@ -138,16 +144,20 @@ def _compute_statistic_of_img(act):
sigma = np.cov(act, rowvar=False) sigma = np.cov(act, rowvar=False)
return mu, sigma return mu, sigma
def calculate_inception_val(img_fake, def calculate_inception_val(img_fake,
img_real, img_real,
batch_size, batch_size,
model, model,
use_gpu = True, use_gpu=True,
dims = 2048): dims=2048):
act_fake = _get_activations_from_ims(img_fake, model, batch_size, dims, use_gpu) act_fake = _get_activations_from_ims(img_fake, model, batch_size, dims,
act_real = _get_activations_from_ims(img_real, model, batch_size, dims, use_gpu) use_gpu)
act_real = _get_activations_from_ims(img_real, model, batch_size, dims,
use_gpu)
return act_fake, act_real return act_fake, act_real
def calculate_fid_given_img(act_fake, act_real): def calculate_fid_given_img(act_fake, act_real):
m1, s1 = _compute_statistic_of_img(act_fake) m1, s1 = _compute_statistic_of_img(act_fake)
...@@ -299,4 +309,3 @@ def calculate_fid_given_paths(paths, ...@@ -299,4 +309,3 @@ def calculate_fid_given_paths(paths,
fid_value = _calculate_frechet_distance(m1, s1, m2, s2) fid_value = _calculate_frechet_distance(m1, s1, m2, s2)
return fid_value return fid_value
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册