Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
19fe4fbc
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
大约 1 年 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
19fe4fbc
编写于
6月 22, 2021
作者:
L
lzzyzlbb
提交者:
GitHub
6月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add fid for style gan (#347)
* add fid for style gan * add fid for style gan * add fid for style gan
上级
e5d59918
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
43 addition
and
25 deletion
+43
-25
configs/stylegan_v2_256_ffhq.yaml
configs/stylegan_v2_256_ffhq.yaml
+9
-0
ppgan/metrics/fid.py
ppgan/metrics/fid.py
+19
-24
ppgan/models/styleganv2_model.py
ppgan/models/styleganv2_model.py
+15
-1
未找到文件。
configs/stylegan_v2_256_ffhq.yaml
浏览文件 @
19fe4fbc
...
@@ -23,6 +23,7 @@ model:
...
@@ -23,6 +23,7 @@ model:
params
:
params
:
gen_iters
:
4
gen_iters
:
4
disc_iters
:
16
disc_iters
:
16
max_eval_steps
:
50000
export_model
:
export_model
:
-
{
name
:
'
gen'
,
inputs_num
:
2
}
-
{
name
:
'
gen'
,
inputs_num
:
2
}
...
@@ -72,3 +73,11 @@ log_config:
...
@@ -72,3 +73,11 @@ log_config:
snapshot_config
:
snapshot_config
:
interval
:
5000
interval
:
5000
validate
:
interval
:
5000
save_imig
:
False
metrics
:
fid
:
# metric name, can be arbitrary
name
:
FID
batch_size
:
4
ppgan/metrics/fid.py
浏览文件 @
19fe4fbc
...
@@ -48,7 +48,7 @@ class FID(paddle.metric.Metric):
...
@@ -48,7 +48,7 @@ class FID(paddle.metric.Metric):
self
.
premodel_path
=
premodel_path
self
.
premodel_path
=
premodel_path
if
model
is
None
:
if
model
is
None
:
block_idx
=
InceptionV3
.
BLOCK_INDEX_BY_DIM
[
dims
]
block_idx
=
InceptionV3
.
BLOCK_INDEX_BY_DIM
[
dims
]
model
=
InceptionV3
([
block_idx
])
model
=
InceptionV3
([
block_idx
]
,
normalize_input
=
False
)
if
premodel_path
is
None
:
if
premodel_path
is
None
:
premodel_path
=
get_weights_path_from_url
(
INCEPTIONV3_WEIGHT_URL
)
premodel_path
=
get_weights_path_from_url
(
INCEPTIONV3_WEIGHT_URL
)
self
.
model
=
model
self
.
model
=
model
...
@@ -63,18 +63,15 @@ class FID(paddle.metric.Metric):
...
@@ -63,18 +63,15 @@ class FID(paddle.metric.Metric):
self
.
results
=
[]
self
.
results
=
[]
def
update
(
self
,
preds
,
gts
):
def
update
(
self
,
preds
,
gts
):
if
len
(
preds
.
shape
)
>=
4
:
preds_inception
,
gts_inception
=
calculate_inception_val
(
self
.
preds
.
append
(
preds
)
preds
,
gts
,
self
.
batch_size
,
self
.
model
,
self
.
use_GPU
,
self
.
dims
)
self
.
gts
.
append
(
gts
)
self
.
preds
.
append
(
preds_inception
)
else
:
self
.
gts
.
append
(
gts_inception
)
for
i
in
range
(
preds
.
shape
[
0
]):
self
.
preds
.
append
(
preds
[
i
,:,:,:,:])
self
.
gts
.
append
(
gts
[
i
,:,:,:,:])
def
accumulate
(
self
):
def
accumulate
(
self
):
self
.
preds
=
paddle
.
concat
(
self
.
preds
,
axis
=
0
)
self
.
preds
=
np
.
concatenate
(
self
.
preds
,
axis
=
0
)
self
.
gts
=
paddle
.
concat
(
self
.
gts
,
axis
=
0
)
self
.
gts
=
np
.
concatenate
(
self
.
gts
,
axis
=
0
)
value
=
calculate_fid_given_img
(
self
.
preds
,
self
.
gts
,
self
.
batch_size
,
self
.
model
,
self
.
use_GPU
,
self
.
dims
)
value
=
calculate_fid_given_img
(
self
.
preds
,
self
.
gts
)
self
.
reset
()
self
.
reset
()
return
value
return
value
...
@@ -82,8 +79,6 @@ class FID(paddle.metric.Metric):
...
@@ -82,8 +79,6 @@ class FID(paddle.metric.Metric):
return
'FID'
return
'FID'
def
_calculate_frechet_distance
(
mu1
,
sigma1
,
mu2
,
sigma2
,
eps
=
1e-6
):
def
_calculate_frechet_distance
(
mu1
,
sigma1
,
mu2
,
sigma2
,
eps
=
1e-6
):
m1
=
np
.
atleast_1d
(
mu1
)
m1
=
np
.
atleast_1d
(
mu1
)
m2
=
np
.
atleast_1d
(
mu2
)
m2
=
np
.
atleast_1d
(
mu2
)
...
@@ -111,7 +106,6 @@ def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
...
@@ -111,7 +106,6 @@ def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
m
=
np
.
max
(
np
.
abs
(
covmean
.
imag
))
m
=
np
.
max
(
np
.
abs
(
covmean
.
imag
))
raise
ValueError
(
'Imaginary component {}'
.
format
(
m
))
raise
ValueError
(
'Imaginary component {}'
.
format
(
m
))
covmean
=
covmean
.
real
covmean
=
covmean
.
real
tr_covmean
=
np
.
trace
(
covmean
)
tr_covmean
=
np
.
trace
(
covmean
)
return
(
diff
.
dot
(
diff
)
+
np
.
trace
(
sigma1
)
+
np
.
trace
(
sigma2
)
-
return
(
diff
.
dot
(
diff
)
+
np
.
trace
(
sigma1
)
+
np
.
trace
(
sigma2
)
-
...
@@ -132,32 +126,32 @@ def _get_activations_from_ims(img, model, batch_size, dims, use_gpu):
...
@@ -132,32 +126,32 @@ def _get_activations_from_ims(img, model, batch_size, dims, use_gpu):
images
=
img
[
start
:
end
]
images
=
img
[
start
:
end
]
if
images
.
shape
[
1
]
!=
3
:
if
images
.
shape
[
1
]
!=
3
:
images
=
images
.
transpose
((
0
,
3
,
1
,
2
))
images
=
images
.
transpose
((
0
,
3
,
1
,
2
))
images
=
paddle
.
to_tensor
(
images
)
images
=
paddle
.
to_tensor
(
images
)
pred
=
model
(
images
)[
0
][
0
]
pred
=
model
(
images
)[
0
][
0
]
pred_arr
[
start
:
end
]
=
pred
.
reshape
([
end
-
start
,
-
1
]).
cpu
().
numpy
()
pred_arr
[
start
:
end
]
=
pred
.
reshape
([
end
-
start
,
-
1
]).
cpu
().
numpy
()
return
pred_arr
return
pred_arr
def
_compute_statistic_of_img
(
img
,
model
,
batch_size
,
dims
,
use_gpu
):
def
_compute_statistic_of_img
(
act
):
act
=
_get_activations_from_ims
(
img
,
model
,
batch_size
,
dims
,
use_gpu
)
mu
=
np
.
mean
(
act
,
axis
=
0
)
mu
=
np
.
mean
(
act
,
axis
=
0
)
sigma
=
np
.
cov
(
act
,
rowvar
=
False
)
sigma
=
np
.
cov
(
act
,
rowvar
=
False
)
return
mu
,
sigma
return
mu
,
sigma
def
calculate_inception_val
(
img_fake
,
def
calculate_fid_given_img
(
img_fake
,
img_real
,
img_real
,
batch_size
,
batch_size
,
model
,
model
,
use_gpu
=
True
,
use_gpu
=
True
,
dims
=
2048
):
dims
=
2048
):
act_fake
=
_get_activations_from_ims
(
img_fake
,
model
,
batch_size
,
dims
,
use_gpu
)
act_real
=
_get_activations_from_ims
(
img_real
,
model
,
batch_size
,
dims
,
use_gpu
)
return
act_fake
,
act_real
m1
,
s1
=
_compute_statistic_of_img
(
img_fake
,
model
,
batch_size
,
dims
,
def
calculate_fid_given_img
(
act_fake
,
act_real
):
use_gpu
)
m2
,
s2
=
_compute_statistic_of_img
(
img_real
,
model
,
batch_size
,
dims
,
use_gpu
)
m1
,
s1
=
_compute_statistic_of_img
(
act_fake
)
m2
,
s2
=
_compute_statistic_of_img
(
act_real
)
fid_value
=
_calculate_frechet_distance
(
m1
,
s1
,
m2
,
s2
)
fid_value
=
_calculate_frechet_distance
(
m1
,
s1
,
m2
,
s2
)
return
fid_value
return
fid_value
...
@@ -305,3 +299,4 @@ def calculate_fid_given_paths(paths,
...
@@ -305,3 +299,4 @@ def calculate_fid_given_paths(paths,
fid_value
=
_calculate_frechet_distance
(
m1
,
s1
,
m2
,
s2
)
fid_value
=
_calculate_frechet_distance
(
m1
,
s1
,
m2
,
s2
)
return
fid_value
return
fid_value
ppgan/models/styleganv2_model.py
浏览文件 @
19fe4fbc
...
@@ -79,7 +79,8 @@ class StyleGAN2Model(BaseModel):
...
@@ -79,7 +79,8 @@ class StyleGAN2Model(BaseModel):
r1_reg_weight
=
10.
,
r1_reg_weight
=
10.
,
path_reg_weight
=
2.
,
path_reg_weight
=
2.
,
path_batch_shrink
=
2.
,
path_batch_shrink
=
2.
,
params
=
None
):
params
=
None
,
max_eval_steps
=
50000
):
"""Initialize the CycleGAN class.
"""Initialize the CycleGAN class.
Args:
Args:
...
@@ -107,6 +108,7 @@ class StyleGAN2Model(BaseModel):
...
@@ -107,6 +108,7 @@ class StyleGAN2Model(BaseModel):
self
.
mean_path_length
=
0
self
.
mean_path_length
=
0
self
.
nets
[
'gen'
]
=
build_generator
(
generator
)
self
.
nets
[
'gen'
]
=
build_generator
(
generator
)
self
.
max_eval_steps
=
max_eval_steps
# define discriminators
# define discriminators
if
discriminator
:
if
discriminator
:
...
@@ -280,3 +282,15 @@ class StyleGAN2Model(BaseModel):
...
@@ -280,3 +282,15 @@ class StyleGAN2Model(BaseModel):
self
.
visual_items
[
'fake_img_ema'
]
=
sample
self
.
visual_items
[
'fake_img_ema'
]
=
sample
self
.
current_iter
+=
1
self
.
current_iter
+=
1
def
test_iter
(
self
,
metrics
=
None
):
self
.
nets
[
'gen_ema'
].
eval
()
batch
=
self
.
real_img
.
shape
[
0
]
noises
=
[
paddle
.
randn
([
batch
,
self
.
num_style_feat
])]
fake_img
,
_
=
self
.
nets
[
'gen_ema'
](
noises
)
with
paddle
.
no_grad
():
if
metrics
is
not
None
:
for
metric
in
metrics
.
values
():
metric
.
update
(
fake_img
,
self
.
real_img
)
self
.
nets
[
'gen_ema'
].
train
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录