未验证 提交 19fe4fbc 编写于 作者: L lzzyzlbb 提交者: GitHub

add fid for style gan (#347)

* add fid for style gan

* add fid for style gan

* add fid for style gan
上级 e5d59918
......@@ -23,6 +23,7 @@ model:
params:
gen_iters: 4
disc_iters: 16
max_eval_steps: 50000
export_model:
- {name: 'gen', inputs_num: 2}
......@@ -72,3 +73,11 @@ log_config:
snapshot_config:
interval: 5000
validate:
interval: 5000
save_imig: False
metrics:
fid: # metric name, can be arbitrary
name: FID
batch_size: 4
......@@ -48,7 +48,7 @@ class FID(paddle.metric.Metric):
self.premodel_path = premodel_path
if model is None:
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
model = InceptionV3([block_idx])
model = InceptionV3([block_idx], normalize_input=False)
if premodel_path is None:
premodel_path = get_weights_path_from_url(INCEPTIONV3_WEIGHT_URL)
self.model = model
......@@ -63,18 +63,15 @@ class FID(paddle.metric.Metric):
self.results = []
def update(self, preds, gts):
if len(preds.shape) >=4:
self.preds.append(preds)
self.gts.append(gts)
else:
for i in range(preds.shape[0]):
self.preds.append(preds[i,:,:,:,:])
self.gts.append(gts[i,:,:,:,:])
preds_inception, gts_inception = calculate_inception_val(
preds, gts, self.batch_size, self.model, self.use_GPU, self.dims)
self.preds.append(preds_inception)
self.gts.append(gts_inception)
def accumulate(self):
self.preds = paddle.concat(self.preds, axis=0)
self.gts = paddle.concat(self.gts, axis=0)
value = calculate_fid_given_img(self.preds, self.gts, self.batch_size, self.model, self.use_GPU, self.dims)
self.preds = np.concatenate(self.preds, axis=0)
self.gts = np.concatenate(self.gts, axis=0)
value = calculate_fid_given_img(self.preds, self.gts)
self.reset()
return value
......@@ -82,8 +79,6 @@ class FID(paddle.metric.Metric):
return 'FID'
def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
m1 = np.atleast_1d(mu1)
m2 = np.atleast_1d(mu2)
......@@ -111,7 +106,6 @@ def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
m = np.max(np.abs(covmean.imag))
raise ValueError('Imaginary component {}'.format(m))
covmean = covmean.real
tr_covmean = np.trace(covmean)
return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) -
......@@ -139,25 +133,25 @@ def _get_activations_from_ims(img, model, batch_size, dims, use_gpu):
return pred_arr
def _compute_statistic_of_img(img, model, batch_size, dims, use_gpu):
act = _get_activations_from_ims(img, model, batch_size, dims, use_gpu)
def _compute_statistic_of_img(act):
mu = np.mean(act, axis=0)
sigma = np.cov(act, rowvar=False)
return mu, sigma
def calculate_fid_given_img(img_fake,
def calculate_inception_val(img_fake,
img_real,
batch_size,
model,
use_gpu = True,
dims = 2048):
act_fake = _get_activations_from_ims(img_fake, model, batch_size, dims, use_gpu)
act_real = _get_activations_from_ims(img_real, model, batch_size, dims, use_gpu)
return act_fake, act_real
m1, s1 = _compute_statistic_of_img(img_fake, model, batch_size, dims,
use_gpu)
m2, s2 = _compute_statistic_of_img(img_real, model, batch_size, dims,
use_gpu)
def calculate_fid_given_img(act_fake, act_real):
m1, s1 = _compute_statistic_of_img(act_fake)
m2, s2 = _compute_statistic_of_img(act_real)
fid_value = _calculate_frechet_distance(m1, s1, m2, s2)
return fid_value
......@@ -305,3 +299,4 @@ def calculate_fid_given_paths(paths,
fid_value = _calculate_frechet_distance(m1, s1, m2, s2)
return fid_value
......@@ -79,7 +79,8 @@ class StyleGAN2Model(BaseModel):
r1_reg_weight=10.,
path_reg_weight=2.,
path_batch_shrink=2.,
params=None):
params=None,
max_eval_steps=50000):
"""Initialize the CycleGAN class.
Args:
......@@ -107,6 +108,7 @@ class StyleGAN2Model(BaseModel):
self.mean_path_length = 0
self.nets['gen'] = build_generator(generator)
self.max_eval_steps = max_eval_steps
# define discriminators
if discriminator:
......@@ -280,3 +282,15 @@ class StyleGAN2Model(BaseModel):
self.visual_items['fake_img_ema'] = sample
self.current_iter += 1
def test_iter(self, metrics=None):
self.nets['gen_ema'].eval()
batch = self.real_img.shape[0]
noises = [paddle.randn([batch, self.num_style_feat])]
fake_img, _ = self.nets['gen_ema'](noises)
with paddle.no_grad():
if metrics is not None:
for metric in metrics.values():
metric.update(fake_img, self.real_img)
self.nets['gen_ema'].train()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册