提交 886bb7d4 编写于 作者: J james plur

delete todo

上级 fc15e809
......@@ -34,13 +34,12 @@ python dcgan.py
- `-m` 多机训练
其他需要注意的是:
- 训练将会默认使用minst数据集,如果第一次使用脚本,将会默认将数据集下载到`.data/`目录
- 训练结束后,将会默认保存模型到`.checkpoint/`目录下
- 模型的结构和参数参考了tensorflow的[官方示例](https://www.tensorflow.org/tutorials/generative/dcgan),可以通过-c参数来跟tensorflow的实现进行对齐测试
- 模型会定期将生成的图片存储到`.gout/`目录,并在训练结束后生成图片演化的动图
- 模型会定期将生成的图片存储到`.gout/`目录,并在训练结束后生成图片演化的动图,生成动图的过程会依赖python包`imageio`
![](https://raw.githubusercontent.com/JamiePlur/picgo/master/20200615170256.png)
......@@ -73,32 +72,3 @@ python pix2pix.py
- 模型会在训练中定期将生成的图片存储到`.gout/`目录
![image-20200701153752019](https://raw.githubusercontent.com/JamiePlur/picgo/master/20200701153829.png)
## [TODO]
### GAN精度评价
- GAN的精度评价只存在于基本gan结构下,风格迁移没有定量的精度评价
- GAN最主要的评价指标主要为inception score和Fréchet Inception Distance,都需要借助预训练的inceptionV3模型,并在imagenet分类数据集上使用,目前还没有实现
- 比较SOTA的GAN模型为了冲榜,都会附上性能指标
- 可以参考一些论文对GAN的评估工作,比较有代表性的是[How good is my gan?](https://lear.inrialpes.fr/people/alahari/papers/shmelkov18.pdf)
![image-20200701103940612](https://raw.githubusercontent.com/JamiePlur/picgo/master/20200701154135.png)
### GAN的预训练模型
- 目前还没有可用的预训练模型,自行训练的主要困难有:
- GAN涉及的数据集非常多,数据处理的流水线没有建立起来
- GAN的调参非常困难,稳定性差
- 只有在大型数据集上的预训练模型才有意义,训练mnist数据集这种模型没有泛化价值
### CycleGAN
- cyclegan已经搭建完成,但是由于模型较大,跑不起来
- gan的模型普遍较大
\ No newline at end of file
......@@ -63,7 +63,7 @@ class DCGAN:
nodes.append(addr_dict)
flow.env.machine(nodes)
def train(self, epochs=1, model_dir=None, save=False):
def train(self, epochs=1, model_dir=None, save=True):
func_config = flow.FunctionConfig()
func_config.default_data_type(flow.float)
func_config.default_distribute_strategy(flow.distribute.consistent_strategy())
......@@ -149,6 +149,15 @@ class DCGAN:
self._eval_model_and_save_images(
eval_generator, batch_idx + 1, epoch_idx + 1
)
if save:
from datetime import datetime
if not os.path.exists("checkpoint"):
os.mkdir("checkpoint")
check_point.save(
"checkpoint/dcgan_{}".format(
str(datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))
)
)
def save_to_gif(self):
anim_file = "dcgan.gif"
......
......@@ -103,8 +103,8 @@ def conv2d(
return output
def batchnorm(input, name, axis=1):
return flow.layers.batch_normalization(input, axis=axis, name=name)
def batchnorm(input, name, axis=1, reuse=False):
return flow.layers.batch_normalization(input, axis=axis)
def dense(
......
......@@ -285,6 +285,8 @@ class Pix2Pix:
_PATH = os.path.join(os.getcwd(), "data/facades.tar.gz")
_URL = "https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz"
path_to_zip = tf.keras.utils.get_file(_PATH, origin=_URL, extract=True)
else:
print("Found Facades - skip download")
input_imgs, real_imgs = [], []
for d in os.listdir(os.path.join("data/facades/", mode)):
......@@ -437,7 +439,7 @@ class Pix2Pix:
if not os.path.exists("checkpoint"):
os.mkdir("checkpoint")
check_point.save(
"checkpoint/cp_{}".format(
"checkpoint/pix2pix_{}".format(
str(datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))
)
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册