未验证 提交 0666e2e8 编写于 作者: L LielinJiang 提交者: GitHub

increase validate interval (#382)

上级 f84ddda4
......@@ -61,7 +61,7 @@ dataset:
keys: [image, image]
test:
name: PairedDataset
dataroot: data/cityscapes/test
dataroot: data/cityscapes/val
num_workers: 4
batch_size: 1
preprocess:
......@@ -112,11 +112,9 @@ snapshot_config:
interval: 5
validate:
interval: 500
interval: 29750
save_img: false
metrics:
fid: # metric name, can be arbitrary
name: FID
batch_size: 8
......@@ -58,7 +58,7 @@ dataset:
keys: [image, image]
test:
name: PairedDataset
dataroot: data/cityscapes/test
dataroot: data/cityscapes/val
num_workers: 4
batch_size: 1
preprocess:
......@@ -109,11 +109,9 @@ snapshot_config:
interval: 5
validate:
interval: 500
interval: 29750
save_img: false
metrics:
fid: # metric name, can be arbitrary
name: FID
batch_size: 8
......@@ -109,11 +109,9 @@ snapshot_config:
interval: 5
validate:
interval: 500
interval: 4000
save_img: false
metrics:
fid: # metric name, can be arbitrary
name: FID
batch_size: 8
......@@ -90,7 +90,7 @@ snapshot_config:
interval: 5000
validate:
interval: 5000
interval: 50000
save_imig: False
metrics:
fid: # metric name, can be arbitrary
......
......@@ -39,9 +39,15 @@ inceptionV3 pretrain model is convert from pytorch, pretrain_model url is https:
"""
INCEPTIONV3_WEIGHT_URL = "https://paddlegan.bj.bcebos.com/InceptionV3.pdparams"
@METRICS.register()
class FID(paddle.metric.Metric):
def __init__(self, batch_size=1, use_GPU=True, dims = 2048, premodel_path=None, model=None):
def __init__(self,
batch_size=1,
use_GPU=True,
dims=2048,
premodel_path=None,
model=None):
self.batch_size = batch_size
self.use_GPU = use_GPU
self.dims = dims
......@@ -118,7 +124,7 @@ def _get_activations_from_ims(img, model, batch_size, dims, use_gpu):
pred_arr = np.empty((n_used_img, dims))
for i in tqdm(range(n_batches)):
for i in range(n_batches):
start = i * batch_size
end = start + batch_size
if end > len(img):
......@@ -138,16 +144,20 @@ def _compute_statistic_of_img(act):
sigma = np.cov(act, rowvar=False)
return mu, sigma
def calculate_inception_val(img_fake,
img_real,
batch_size,
model,
use_gpu = True,
dims = 2048):
act_fake = _get_activations_from_ims(img_fake, model, batch_size, dims, use_gpu)
act_real = _get_activations_from_ims(img_real, model, batch_size, dims, use_gpu)
use_gpu=True,
dims=2048):
act_fake = _get_activations_from_ims(img_fake, model, batch_size, dims,
use_gpu)
act_real = _get_activations_from_ims(img_real, model, batch_size, dims,
use_gpu)
return act_fake, act_real
def calculate_fid_given_img(act_fake, act_real):
m1, s1 = _compute_statistic_of_img(act_fake)
......@@ -299,4 +309,3 @@ def calculate_fid_given_paths(paths,
fid_value = _calculate_frechet_distance(m1, s1, m2, s2)
return fid_value
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册