Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
OneFlow-Benchmark
提交
886bb7d4
O
OneFlow-Benchmark
项目概览
Oneflow-Inc
/
OneFlow-Benchmark
上一次同步 2 年多
通知
1
Star
92
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
OneFlow-Benchmark
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
886bb7d4
编写于
7月 01, 2020
作者:
J
james plur
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
delete todo
上级
fc15e809
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
16 addition
and
35 deletion
+16
-35
Generative/README.md
Generative/README.md
+1
-31
Generative/dcgan.py
Generative/dcgan.py
+10
-1
Generative/layers.py
Generative/layers.py
+2
-2
Generative/pix2pix.py
Generative/pix2pix.py
+3
-1
未找到文件。
Generative/README.md
浏览文件 @
886bb7d4
...
...
@@ -34,13 +34,12 @@ python dcgan.py
-
`-m`
多机训练
其他需要注意的是:
-
训练将会默认使用minst数据集,如果第一次使用脚本,将会默认将数据集下载到
`.data/`
目录
-
训练结束后,将会默认保存模型到
`.checkpoint/`
目录下
-
模型的结构和参数参考了tensorflow的
[
官方示例
](
https://www.tensorflow.org/tutorials/generative/dcgan
)
,可以通过-c参数来跟tensorflow的实现进行对齐测试
-
模型会定期将生成的图片存储到
`.gout/`
目录,并在训练结束后生成图片演化的动图
-
模型会定期将生成的图片存储到
`.gout/`
目录,并在训练结束后生成图片演化的动图
,生成动图的过程会依赖python包
`imageio`
![](
https://raw.githubusercontent.com/JamiePlur/picgo/master/20200615170256.png
)
...
...
@@ -73,32 +72,3 @@ python pix2pix.py
-
模型会在训练中定期将生成的图片存储到
`.gout/`
目录
![
image-20200701153752019
](
https://raw.githubusercontent.com/JamiePlur/picgo/master/20200701153829.png
)
## [TODO]
### GAN精度评价
-
GAN的精度评价只存在于基本gan结构下,风格迁移没有定量的精度评价
-
GAN最主要的评价指标主要为inception score和Fréchet Inception Distance,都需要借助预训练的inceptionV3模型,并在imagenet分类数据集上使用,目前还没有实现
-
比较SOTA的GAN模型为了冲榜,都会附上性能指标
-
可以参考一些论文对GAN的评估工作,比较有代表性的是
[
How good is my gan?
](
https://lear.inrialpes.fr/people/alahari/papers/shmelkov18.pdf
)
![
image-20200701103940612
](
https://raw.githubusercontent.com/JamiePlur/picgo/master/20200701154135.png
)
### GAN的预训练模型
-
目前还没有可用的预训练模型,自行训练的主要困难有:
-
GAN涉及的数据集非常多,数据处理的流水线没有建立起来
-
GAN的调参非常困难,稳定性差
-
只有在大型数据集上的预训练模型才有意义,训练mnist数据集这种模型没有泛化价值
### CycleGAN
-
cyclegan已经搭建完成,但是由于模型较大,跑不起来
-
gan的模型普遍较大
\ No newline at end of file
Generative/dcgan.py
浏览文件 @
886bb7d4
...
...
@@ -63,7 +63,7 @@ class DCGAN:
nodes
.
append
(
addr_dict
)
flow
.
env
.
machine
(
nodes
)
def
train
(
self
,
epochs
=
1
,
model_dir
=
None
,
save
=
Fals
e
):
def
train
(
self
,
epochs
=
1
,
model_dir
=
None
,
save
=
Tru
e
):
func_config
=
flow
.
FunctionConfig
()
func_config
.
default_data_type
(
flow
.
float
)
func_config
.
default_distribute_strategy
(
flow
.
distribute
.
consistent_strategy
())
...
...
@@ -149,6 +149,15 @@ class DCGAN:
self
.
_eval_model_and_save_images
(
eval_generator
,
batch_idx
+
1
,
epoch_idx
+
1
)
if
save
:
from
datetime
import
datetime
if
not
os
.
path
.
exists
(
"checkpoint"
):
os
.
mkdir
(
"checkpoint"
)
check_point
.
save
(
"checkpoint/dcgan_{}"
.
format
(
str
(
datetime
.
now
().
strftime
(
"%Y-%m-%d-%H:%M:%S"
))
)
)
def
save_to_gif
(
self
):
anim_file
=
"dcgan.gif"
...
...
Generative/layers.py
浏览文件 @
886bb7d4
...
...
@@ -103,8 +103,8 @@ def conv2d(
return
output
def
batchnorm
(
input
,
name
,
axis
=
1
):
return
flow
.
layers
.
batch_normalization
(
input
,
axis
=
axis
,
name
=
name
)
def
batchnorm
(
input
,
name
,
axis
=
1
,
reuse
=
False
):
return
flow
.
layers
.
batch_normalization
(
input
,
axis
=
axis
)
def
dense
(
...
...
Generative/pix2pix.py
浏览文件 @
886bb7d4
...
...
@@ -285,6 +285,8 @@ class Pix2Pix:
_PATH
=
os
.
path
.
join
(
os
.
getcwd
(),
"data/facades.tar.gz"
)
_URL
=
"https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz"
path_to_zip
=
tf
.
keras
.
utils
.
get_file
(
_PATH
,
origin
=
_URL
,
extract
=
True
)
else
:
print
(
"Found Facades - skip download"
)
input_imgs
,
real_imgs
=
[],
[]
for
d
in
os
.
listdir
(
os
.
path
.
join
(
"data/facades/"
,
mode
)):
...
...
@@ -437,7 +439,7 @@ class Pix2Pix:
if
not
os
.
path
.
exists
(
"checkpoint"
):
os
.
mkdir
(
"checkpoint"
)
check_point
.
save
(
"checkpoint/
cp
_{}"
.
format
(
"checkpoint/
pix2pix
_{}"
.
format
(
str
(
datetime
.
now
().
strftime
(
"%Y-%m-%d-%H:%M:%S"
))
)
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录