未验证 提交 19fe4fbc 编写于 作者: L lzzyzlbb 提交者: GitHub

add fid for style gan (#347)

* add fid for style gan

* add fid for style gan

* add fid for style gan
上级 e5d59918
...@@ -23,6 +23,7 @@ model: ...@@ -23,6 +23,7 @@ model:
params: params:
gen_iters: 4 gen_iters: 4
disc_iters: 16 disc_iters: 16
max_eval_steps: 50000
export_model: export_model:
- {name: 'gen', inputs_num: 2} - {name: 'gen', inputs_num: 2}
...@@ -72,3 +73,11 @@ log_config: ...@@ -72,3 +73,11 @@ log_config:
snapshot_config: snapshot_config:
interval: 5000 interval: 5000
validate:
interval: 5000
save_imig: False
metrics:
fid: # metric name, can be arbitrary
name: FID
batch_size: 4
...@@ -48,7 +48,7 @@ class FID(paddle.metric.Metric): ...@@ -48,7 +48,7 @@ class FID(paddle.metric.Metric):
self.premodel_path = premodel_path self.premodel_path = premodel_path
if model is None: if model is None:
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
model = InceptionV3([block_idx]) model = InceptionV3([block_idx], normalize_input=False)
if premodel_path is None: if premodel_path is None:
premodel_path = get_weights_path_from_url(INCEPTIONV3_WEIGHT_URL) premodel_path = get_weights_path_from_url(INCEPTIONV3_WEIGHT_URL)
self.model = model self.model = model
...@@ -63,18 +63,15 @@ class FID(paddle.metric.Metric): ...@@ -63,18 +63,15 @@ class FID(paddle.metric.Metric):
self.results = [] self.results = []
def update(self, preds, gts): def update(self, preds, gts):
if len(preds.shape) >=4: preds_inception, gts_inception = calculate_inception_val(
self.preds.append(preds) preds, gts, self.batch_size, self.model, self.use_GPU, self.dims)
self.gts.append(gts) self.preds.append(preds_inception)
else: self.gts.append(gts_inception)
for i in range(preds.shape[0]):
self.preds.append(preds[i,:,:,:,:])
self.gts.append(gts[i,:,:,:,:])
def accumulate(self): def accumulate(self):
self.preds = paddle.concat(self.preds, axis=0) self.preds = np.concatenate(self.preds, axis=0)
self.gts = paddle.concat(self.gts, axis=0) self.gts = np.concatenate(self.gts, axis=0)
value = calculate_fid_given_img(self.preds, self.gts, self.batch_size, self.model, self.use_GPU, self.dims) value = calculate_fid_given_img(self.preds, self.gts)
self.reset() self.reset()
return value return value
...@@ -82,8 +79,6 @@ class FID(paddle.metric.Metric): ...@@ -82,8 +79,6 @@ class FID(paddle.metric.Metric):
return 'FID' return 'FID'
def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
m1 = np.atleast_1d(mu1) m1 = np.atleast_1d(mu1)
m2 = np.atleast_1d(mu2) m2 = np.atleast_1d(mu2)
...@@ -111,7 +106,6 @@ def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): ...@@ -111,7 +106,6 @@ def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
m = np.max(np.abs(covmean.imag)) m = np.max(np.abs(covmean.imag))
raise ValueError('Imaginary component {}'.format(m)) raise ValueError('Imaginary component {}'.format(m))
covmean = covmean.real covmean = covmean.real
tr_covmean = np.trace(covmean) tr_covmean = np.trace(covmean)
return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) -
...@@ -132,32 +126,32 @@ def _get_activations_from_ims(img, model, batch_size, dims, use_gpu): ...@@ -132,32 +126,32 @@ def _get_activations_from_ims(img, model, batch_size, dims, use_gpu):
images = img[start:end] images = img[start:end]
if images.shape[1] != 3: if images.shape[1] != 3:
images = images.transpose((0, 3, 1, 2)) images = images.transpose((0, 3, 1, 2))
images = paddle.to_tensor(images) images = paddle.to_tensor(images)
pred = model(images)[0][0] pred = model(images)[0][0]
pred_arr[start:end] = pred.reshape([end - start, -1]).cpu().numpy() pred_arr[start:end] = pred.reshape([end - start, -1]).cpu().numpy()
return pred_arr return pred_arr
def _compute_statistic_of_img(img, model, batch_size, dims, use_gpu): def _compute_statistic_of_img(act):
act = _get_activations_from_ims(img, model, batch_size, dims, use_gpu)
mu = np.mean(act, axis=0) mu = np.mean(act, axis=0)
sigma = np.cov(act, rowvar=False) sigma = np.cov(act, rowvar=False)
return mu, sigma return mu, sigma
def calculate_inception_val(img_fake,
def calculate_fid_given_img(img_fake,
img_real, img_real,
batch_size, batch_size,
model, model,
use_gpu = True, use_gpu = True,
dims = 2048): dims = 2048):
act_fake = _get_activations_from_ims(img_fake, model, batch_size, dims, use_gpu)
act_real = _get_activations_from_ims(img_real, model, batch_size, dims, use_gpu)
return act_fake, act_real
m1, s1 = _compute_statistic_of_img(img_fake, model, batch_size, dims, def calculate_fid_given_img(act_fake, act_real):
use_gpu)
m2, s2 = _compute_statistic_of_img(img_real, model, batch_size, dims,
use_gpu)
m1, s1 = _compute_statistic_of_img(act_fake)
m2, s2 = _compute_statistic_of_img(act_real)
fid_value = _calculate_frechet_distance(m1, s1, m2, s2) fid_value = _calculate_frechet_distance(m1, s1, m2, s2)
return fid_value return fid_value
...@@ -305,3 +299,4 @@ def calculate_fid_given_paths(paths, ...@@ -305,3 +299,4 @@ def calculate_fid_given_paths(paths,
fid_value = _calculate_frechet_distance(m1, s1, m2, s2) fid_value = _calculate_frechet_distance(m1, s1, m2, s2)
return fid_value return fid_value
...@@ -79,7 +79,8 @@ class StyleGAN2Model(BaseModel): ...@@ -79,7 +79,8 @@ class StyleGAN2Model(BaseModel):
r1_reg_weight=10., r1_reg_weight=10.,
path_reg_weight=2., path_reg_weight=2.,
path_batch_shrink=2., path_batch_shrink=2.,
params=None): params=None,
max_eval_steps=50000):
"""Initialize the CycleGAN class. """Initialize the CycleGAN class.
Args: Args:
...@@ -107,6 +108,7 @@ class StyleGAN2Model(BaseModel): ...@@ -107,6 +108,7 @@ class StyleGAN2Model(BaseModel):
self.mean_path_length = 0 self.mean_path_length = 0
self.nets['gen'] = build_generator(generator) self.nets['gen'] = build_generator(generator)
self.max_eval_steps = max_eval_steps
# define discriminators # define discriminators
if discriminator: if discriminator:
...@@ -280,3 +282,15 @@ class StyleGAN2Model(BaseModel): ...@@ -280,3 +282,15 @@ class StyleGAN2Model(BaseModel):
self.visual_items['fake_img_ema'] = sample self.visual_items['fake_img_ema'] = sample
self.current_iter += 1 self.current_iter += 1
def test_iter(self, metrics=None):
self.nets['gen_ema'].eval()
batch = self.real_img.shape[0]
noises = [paddle.randn([batch, self.num_style_feat])]
fake_img, _ = self.nets['gen_ema'](noises)
with paddle.no_grad():
if metrics is not None:
for metric in metrics.values():
metric.update(fake_img, self.real_img)
self.nets['gen_ema'].train()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册