提交 c3ebff5e 编写于 作者: W wangyang59

modified demo/gan following emailxuwei comments

上级 5aa59796
...@@ -7,7 +7,7 @@ The general training procedures are implemented in gan_trainer.py. The neural ne ...@@ -7,7 +7,7 @@ The general training procedures are implemented in gan_trainer.py. The neural ne
In order to run the model, first download the corresponding data by running the shell script in ./data. In order to run the model, first download the corresponding data by running the shell script in ./data.
Then you can run the command below. The flag -d specifies the training data (cifar, mnist or uniform) and flag --useGpu specifies whether to use gpu for training (0 is cpu, 1 is gpu). Then you can run the command below. The flag -d specifies the training data (cifar, mnist or uniform) and flag --useGpu specifies whether to use gpu for training (0 is cpu, 1 is gpu).
$python gan_trainer.py -d cifar --useGpu 1 $python gan_trainer.py -d cifar --use_gpu 1
The generated images will be stored in ./cifar_samples/ The generated images will be stored in ./cifar_samples/
The corresponding models will be stored in ./cifar_params/ The corresponding models will be stored in ./cifar_params/
\ No newline at end of file
...@@ -31,8 +31,8 @@ def plot2DScatter(data, outputfile): ...@@ -31,8 +31,8 @@ def plot2DScatter(data, outputfile):
''' '''
x = data[:, 0] x = data[:, 0]
y = data[:, 1] y = data[:, 1]
print "The mean vector is %s" % numpy.mean(data, 0) logger.info("The mean vector is %s" % numpy.mean(data, 0))
print "The std vector is %s" % numpy.std(data, 0) logger.info("The std vector is %s" % numpy.std(data, 0))
heatmap, xedges, yedges = numpy.histogram2d(x, y, bins=50) heatmap, xedges, yedges = numpy.histogram2d(x, y, bins=50)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]] extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
...@@ -192,42 +192,42 @@ def get_layer_size(model_conf, layer_name): ...@@ -192,42 +192,42 @@ def get_layer_size(model_conf, layer_name):
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-d", "--dataSource", help="mnist or cifar or uniform") parser.add_argument("-d", "--data_source", help="mnist or cifar or uniform")
parser.add_argument("--useGpu", default="1", parser.add_argument("--use_gpu", default="1",
help="1 means use gpu for training") help="1 means use gpu for training")
parser.add_argument("--gpuId", default="0", parser.add_argument("--gpu_id", default="0",
help="the gpu_id parameter") help="the gpu_id parameter")
args = parser.parse_args() args = parser.parse_args()
dataSource = args.dataSource data_source = args.data_source
useGpu = args.useGpu use_gpu = args.use_gpu
assert dataSource in ["mnist", "cifar", "uniform"] assert data_source in ["mnist", "cifar", "uniform"]
assert useGpu in ["0", "1"] assert use_gpu in ["0", "1"]
if not os.path.exists("./%s_samples/" % dataSource): if not os.path.exists("./%s_samples/" % data_source):
os.makedirs("./%s_samples/" % dataSource) os.makedirs("./%s_samples/" % data_source)
if not os.path.exists("./%s_params/" % dataSource): if not os.path.exists("./%s_params/" % data_source):
os.makedirs("./%s_params/" % dataSource) os.makedirs("./%s_params/" % data_source)
api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100', api.initPaddle('--use_gpu=' + use_gpu, '--dot_period=10', '--log_period=100',
'--gpu_id=' + args.gpuId, '--save_dir=' + "./%s_params/" % dataSource) '--gpu_id=' + args.gpu_id, '--save_dir=' + "./%s_params/" % data_source)
if dataSource == "uniform": if data_source == "uniform":
conf = "gan_conf.py" conf = "gan_conf.py"
num_iter = 10000 num_iter = 10000
else: else:
conf = "gan_conf_image.py" conf = "gan_conf_image.py"
num_iter = 1000 num_iter = 1000
gen_conf = parse_config(conf, "mode=generator_training,data=" + dataSource) gen_conf = parse_config(conf, "mode=generator_training,data=" + data_source)
dis_conf = parse_config(conf, "mode=discriminator_training,data=" + dataSource) dis_conf = parse_config(conf, "mode=discriminator_training,data=" + data_source)
generator_conf = parse_config(conf, "mode=generator,data=" + dataSource) generator_conf = parse_config(conf, "mode=generator,data=" + data_source)
batch_size = dis_conf.opt_config.batch_size batch_size = dis_conf.opt_config.batch_size
noise_dim = get_layer_size(gen_conf.model_config, "noise") noise_dim = get_layer_size(gen_conf.model_config, "noise")
if dataSource == "mnist": if data_source == "mnist":
data_np = load_mnist_data("./data/mnist_data/train-images-idx3-ubyte") data_np = load_mnist_data("./data/mnist_data/train-images-idx3-ubyte")
elif dataSource == "cifar": elif data_source == "cifar":
data_np = load_cifar_data("./data/cifar-10-batches-py/") data_np = load_cifar_data("./data/cifar-10-batches-py/")
else: else:
data_np = load_uniform_data() data_np = load_uniform_data()
...@@ -308,7 +308,9 @@ def main(): ...@@ -308,7 +308,9 @@ def main():
else: else:
curr_train = "gen" curr_train = "gen"
curr_strike = 1 curr_strike = 1
gen_trainer.trainOneDataBatch(batch_size, data_batch_gen) gen_trainer.trainOneDataBatch(batch_size, data_batch_gen)
# TODO: add API for paddle to allow true parameter sharing between different GradientMachines
# so that we do not need to copy shared parameters.
copy_shared_parameters(gen_training_machine, dis_training_machine) copy_shared_parameters(gen_training_machine, dis_training_machine)
copy_shared_parameters(gen_training_machine, generator_machine) copy_shared_parameters(gen_training_machine, generator_machine)
...@@ -316,10 +318,10 @@ def main(): ...@@ -316,10 +318,10 @@ def main():
gen_trainer.finishTrainPass() gen_trainer.finishTrainPass()
# At the end of each pass, save the generated samples/images # At the end of each pass, save the generated samples/images
fake_samples = get_fake_samples(generator_machine, batch_size, noise) fake_samples = get_fake_samples(generator_machine, batch_size, noise)
if dataSource == "uniform": if data_source == "uniform":
plot2DScatter(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass)) plot2DScatter(fake_samples, "./%s_samples/train_pass%s.png" % (data_source, train_pass))
else: else:
save_images(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass)) save_images(fake_samples, "./%s_samples/train_pass%s.png" % (data_source, train_pass))
dis_trainer.finishTrain() dis_trainer.finishTrain()
gen_trainer.finishTrain() gen_trainer.finishTrain()
......
...@@ -33,7 +33,7 @@ P_DECLARE_double(checkgrad_eps); ...@@ -33,7 +33,7 @@ P_DECLARE_double(checkgrad_eps);
P_DECLARE_bool(thread_local_rand_use_global_seed); P_DECLARE_bool(thread_local_rand_use_global_seed);
P_DECLARE_bool(prev_batch_state); P_DECLARE_bool(prev_batch_state);
// Test that the convTrans forward is the same as conv backward // Test that the batchNormLayer can be followed by a ConvLayer
TEST(Layer, batchNorm) { TEST(Layer, batchNorm) {
FLAGS_use_gpu = false; FLAGS_use_gpu = false;
TestConfig configBN; TestConfig configBN;
...@@ -104,7 +104,6 @@ TEST(Layer, batchNorm) { ...@@ -104,7 +104,6 @@ TEST(Layer, batchNorm) {
LayerPtr convLayer; LayerPtr convLayer;
initTestLayer(config, &layerMap, &parameters2, &convLayer); initTestLayer(config, &layerMap, &parameters2, &convLayer);
// Set convLayer outputGrad as convTransLayer input value
bnLayer->forward(PASS_GC); bnLayer->forward(PASS_GC);
convLayer->forward(PASS_GC); convLayer->forward(PASS_GC);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册