Commit 93af332e authored by wangyang59

unified cifar/mnist/uniform gan training in demo

Parent 4878f078
......@@ -4,5 +4,5 @@ output/
.project
*.log
*.pyc
data/raw_data/
data/mnist_data/
data/cifar-10-batches-py/
......@@ -2,9 +2,9 @@
# This script downloads the mnist data and unzips it.
set -e
DIR="$( cd "$(dirname "$0")" ; pwd -P )"
rm -rf "$DIR/raw_data"
mkdir "$DIR/raw_data"
cd "$DIR/raw_data"
rm -rf "$DIR/mnist_data"
mkdir "$DIR/mnist_data"
cd "$DIR/mnist_data"
echo "Downloading..."
......
......@@ -32,7 +32,7 @@ sample_dim = 2
settings(
batch_size=128,
learning_rate=1e-4,
learning_method=AdamOptimizer()
learning_method=AdamOptimizer(beta1=0.7)
)
def discriminator(sample):
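Note: the only functional change in this hunk is Adam's beta1, lowered from the 0.9 default to 0.7 so the first-moment (momentum) estimate forgets stale gradients faster; this helps because the GAN objective keeps shifting as the opponent network updates. For reference, a minimal NumPy sketch of the Adam update this parameter feeds into (standalone illustration, not the Paddle optimizer):

    import numpy

    def adam_step(w, grad, m, v, t, lr=1e-4, beta1=0.7, beta2=0.999, eps=1e-8):
        # Exponential moving averages of the gradient and its square.
        m = beta1 * m + (1.0 - beta1) * grad
        v = beta2 * v + (1.0 - beta2) * grad ** 2
        # Bias-corrected estimates; t is the 1-based step count.
        m_hat = m / (1.0 - beta1 ** t)
        v_hat = v / (1.0 - beta2 ** t)
        # Smaller beta1 makes m_hat track recent gradients more closely.
        w = w - lr * m_hat / (numpy.sqrt(v_hat) + eps)
        return w, m, v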
......@@ -47,16 +47,15 @@ def discriminator(sample):
bias_attr = ParamAttr(is_static=is_generator_training,
initial_mean=1.0,
initial_std=0)
hidden = fc_layer(input=sample, name="dis_hidden", size=hidden_dim,
bias_attr=bias_attr,
param_attr=param_attr,
act=ReluActivation())
#act=LinearActivation())
hidden2 = fc_layer(input=hidden, name="dis_hidden2", size=hidden_dim,
bias_attr=bias_attr,
param_attr=param_attr,
#act=ReluActivation())
act=LinearActivation())
hidden_bn = batch_norm_layer(hidden2,
......@@ -88,12 +87,10 @@ def generator(noise):
bias_attr=bias_attr,
param_attr=param_attr,
act=ReluActivation())
#act=LinearActivation())
hidden2 = fc_layer(input=hidden, name="gen_hidden2", size=hidden_dim,
bias_attr=bias_attr,
param_attr=param_attr,
#act=ReluActivation())
act=LinearActivation())
hidden_bn = batch_norm_layer(hidden2,
......
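Note: in both the discriminator and the generator, the second fully connected layer now uses a linear activation whose output feeds a batch_norm_layer, i.e. normalization is applied to pre-activation values, while the first layer keeps its ReLU. A hedged NumPy sketch of what batch normalization computes at training time (gamma and beta are the learned scale and shift):

    import numpy

    def batch_norm_train(x, gamma, beta, eps=1e-5):
        # x has shape (batch, features); normalize each feature over the
        # batch, then apply the learned affine transform.
        mean = numpy.mean(x, axis=0)
        var = numpy.var(x, axis=0)
        x_hat = (x - mean) / numpy.sqrt(var + eps)
        return gamma * x_hat + beta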
......@@ -113,7 +113,6 @@ def generator(noise):
size=s8 * s8 * gf_dim * 4,
bias_attr=bias_attr,
param_attr=param_attr,
#act=ReluActivation())
act=LinearActivation())
h1_bn = batch_norm_layer(h1,
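Note: in the image config, the noise vector is first projected by a fully connected layer of size s8 * s8 * gf_dim * 4 so it can be reshaped into a small seed feature map and then upsampled to the full image, the usual DCGAN arrangement where s8 is the output side length divided by 8. A quick arithmetic sketch with hypothetical values (neither gf_dim nor the image size is shown in this hunk):

    output_size = 32                 # hypothetical, e.g. CIFAR-10
    gf_dim = 64                      # hypothetical base filter count
    s8 = output_size // 8            # 4: side of the seed feature map
    proj = s8 * s8 * gf_dim * 4      # 4 * 4 * 256 = 4096 fc units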
......@@ -235,13 +234,8 @@ if is_discriminator_training:
sample = data_layer(name="sample", size=sample_dim * sample_dim*c_dim)
if is_generator_training or is_discriminator_training:
sample_noise = data_layer(name="sample_noise",
size=sample_dim * sample_dim * c_dim)
label = data_layer(name="label", size=1)
prob = discriminator(addto_layer([sample, sample_noise],
act=LinearActivation(),
name="add",
bias_attr=False))
prob = discriminator(sample)
cost = cross_entropy(input=prob, label=label)
classification_error_evaluator(input=prob, label=label, name=mode+'_error')
outputs(cost)
......
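Note: this hunk removes the instance-noise trick. The discriminator previously saw the sample plus a small Gaussian sample_noise, combined through an addto_layer, which blurs the real and fake distributions so they overlap and keeps the discriminator's gradients informative; after the change it sees the raw sample directly. Conceptually, the deleted path amounted to this (NumPy sketch, not the Paddle graph; the 0.01 scale comes from get_sample_noise below):

    import numpy

    def add_instance_noise(sample, scale=0.01):
        # Perturb the discriminator input with small Gaussian noise.
        noise = numpy.random.normal(scale=scale, size=sample.shape)
        return sample + noise.astype('float32')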
......@@ -71,7 +71,7 @@ def print_parameters(src):
print "value is %s \n" % p.getBuf(api.PARAMETER_VALUE).copyToNumpyArray()
def get_real_samples(batch_size, sample_dim):
return numpy.random.rand(batch_size, sample_dim).astype('float32') * 10.0 - 10.0
return numpy.random.rand(batch_size, sample_dim).astype('float32')
# return numpy.random.normal(loc=100.0, scale=100.0, size=(batch_size, sample_dim)).astype('float32')
def get_fake_samples(generator_machine, batch_size, noise_dim, sample_dim):
......@@ -106,7 +106,7 @@ def prepare_discriminator_data_batch_pos(batch_size, noise_dim, sample_dim):
labels = numpy.ones(batch_size, dtype='int32')
inputs = api.Arguments.createArguments(2)
inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(real_samples))
inputs.setSlotIds(1, api.IVector.createGpuVectorFromNumy(labels))
inputs.setSlotIds(1, api.IVector.createGpuVectorFromNumpy(labels))
return inputs
def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise_dim, sample_dim):
......@@ -114,7 +114,7 @@ def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise_di
labels = numpy.zeros(batch_size, dtype='int32')
inputs = api.Arguments.createArguments(2)
inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(fake_samples))
inputs.setSlotIds(1, api.IVector.createGpuVectorFromNumy(labels))
inputs.setSlotIds(1, api.IVector.createGpuVectorFromNumpy(labels))
return inputs
def prepare_generator_data_batch(batch_size, dim):
......@@ -122,7 +122,7 @@ def prepare_generator_data_batch(batch_size, dim):
label = numpy.ones(batch_size, dtype='int32')
inputs = api.Arguments.createArguments(2)
inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(noise))
inputs.setSlotIds(1, api.IVector.createGpuVectorFromNumy(label))
inputs.setSlotIds(1, api.IVector.createGpuVectorFromNumpy(label))
return inputs
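Note: the three fixes above correct the same misspelled API name, createGpuVectorFromNumy, to createGpuVectorFromNumpy. All of these helpers share one pattern for feeding NumPy data into a Paddle gradient machine: allocate an api.Arguments with one slot per data layer, put dense float features in slot 0 and integer labels in slot 1. Condensed (assumes py_paddle is installed and initPaddle has run):

    import py_paddle.swig_paddle as api

    def make_labeled_batch(features, labels):
        # features: float32 array (batch, dim); labels: int32 array (batch,).
        inputs = api.Arguments.createArguments(2)
        inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(features))
        inputs.setSlotIds(1, api.IVector.createGpuVectorFromNumpy(labels))
        return inputs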
......@@ -140,7 +140,8 @@ def get_layer_size(model_conf, layer_name):
def main():
api.initPaddle('--use_gpu=1', '--dot_period=100', '--log_period=10000')
api.initPaddle('--use_gpu=1', '--dot_period=10', '--log_period=100',
'--gpu_id=2')
gen_conf = parse_config("gan_conf.py", "mode=generator_training")
dis_conf = parse_config("gan_conf.py", "mode=discriminator_training")
generator_conf = parse_config("gan_conf.py", "mode=generator")
......@@ -175,10 +176,10 @@ def main():
curr_strike = 0
MAX_strike = 5
for train_pass in xrange(10):
for train_pass in xrange(100):
dis_trainer.startTrainPass()
gen_trainer.startTrainPass()
for i in xrange(100000):
for i in xrange(1000):
# data_batch_dis = prepare_discriminator_data_batch(
# generator_machine, batch_size, noise_dim, sample_dim)
# dis_loss = get_training_loss(dis_training_machine, data_batch_dis)
......@@ -199,7 +200,7 @@ def main():
if i % 1000 == 0:
print "d_loss is %s g_loss is %s" % (dis_loss, gen_loss)
if (not (curr_train == "dis" and curr_strike == MAX_strike)) and ((curr_train == "gen" and curr_strike == MAX_strike) or dis_loss > 0.690 or dis_loss > gen_loss):
if (not (curr_train == "dis" and curr_strike == MAX_strike)) and ((curr_train == "gen" and curr_strike == MAX_strike) or dis_loss > gen_loss):
if curr_train == "dis":
curr_strike += 1
else:
......
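Note: the rewritten condition drops the dis_loss > 0.690 escape hatch (0.690 is just under ln 2 ≈ 0.693, the cross-entropy of a discriminator at chance), leaving a simpler rule: train the discriminator whenever its loss exceeds the generator's, except that neither network may train more than MAX_strike rounds in a row. Pulled out as a standalone helper, the decision reads (names follow the trainer code):

    def choose_trainee(curr_train, curr_strike, MAX_strike, dis_loss, gen_loss):
        # Returns the updated (curr_train, curr_strike) pair for this round.
        train_dis = (not (curr_train == "dis" and curr_strike == MAX_strike)) and \
                    ((curr_train == "gen" and curr_strike == MAX_strike) or
                     dis_loss > gen_loss)
        if train_dis:
            curr_strike = curr_strike + 1 if curr_train == "dis" else 1
            curr_train = "dis"
        else:
            curr_strike = curr_strike + 1 if curr_train == "gen" else 1
            curr_train = "gen"
        return curr_train, curr_strike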
......@@ -13,16 +13,29 @@
# limitations under the License.
import argparse
import itertools
import random
import numpy
import cPickle
import sys,os,gc
import sys,os
from PIL import Image
from paddle.trainer.config_parser import parse_config
from paddle.trainer.config_parser import logger
import py_paddle.swig_paddle as api
import matplotlib.pyplot as plt
def plot2DScatter(data, outputfile):
x = data[:, 0]
y = data[:, 1]
print "The mean vector is %s" % numpy.mean(data, 0)
print "The std vector is %s" % numpy.std(data, 0)
heatmap, xedges, yedges = numpy.histogram2d(x, y, bins=50)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
plt.clf()
plt.scatter(x, y)
plt.savefig(outputfile, bbox_inches='tight')
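Note: plot2DScatter gives the new uniform mode a visual check; it prints the per-dimension mean and std of a generated batch and saves a scatter plot. (The histogram2d and extent values are computed but never used by the scatter call.) A quick usage sketch, assuming the output directory already exists:

    samples = numpy.random.rand(1000, 2).astype('float32')
    plot2DScatter(samples, "./uniform_samples/sanity.png")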
def CHECK_EQ(a, b):
assert a == b, "a=%s, b=%s" % (a, b)
......@@ -60,7 +73,6 @@ def load_mnist_data(imageFile):
# Define number of samples for train/test
if "train" in imageFile:
#n = 60000
n = 60000
else:
n = 10000
......@@ -89,6 +101,11 @@ def load_cifar_data(cifar_path):
data = data / 255.0 * 2.0 - 1.0
return data
# synthesize 2-D uniform data
def load_uniform_data():
data = numpy.random.rand(1000000, 2).astype('float32')
return data
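Note: load_uniform_data draws one million points from Uniform[0,1)^2; this is the real distribution the generator must imitate in uniform mode. The statistics printed by plot2DScatter can be checked against the target directly:

    data = load_uniform_data()
    print "mean %s std %s" % (numpy.mean(data, 0), numpy.std(data, 0))
    # Expect roughly [0.5, 0.5] and [0.289, 0.289]; the std of U[0,1)
    # is 1 / sqrt(12) ~= 0.2887.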
def merge(images, size):
if images.shape[1] == 28*28:
h, w, c = 28, 28, 1
......@@ -98,7 +115,6 @@ def merge(images, size):
for idx in xrange(size[0] * size[1]):
i = idx % size[1]
j = idx // size[1]
#img[j*h:j*h+h, i*w:i*w+w, :] = (images[idx, :].reshape((h, w, c), order="F") + 1.0) / 2.0 * 255.0
img[j*h:j*h+h, i*w:i*w+w, :] = \
((images[idx, :].reshape((h, w, c), order="F").transpose(1, 0, 2) + 1.0) / 2.0 * 255.0)
return img.astype('uint8')
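Note: merge tiles a batch of flattened images with values in [-1, 1] into a size[0] x size[1] grid rescaled to uint8 in [0, 255]; the added transpose(1, 0, 2) corrects the orientation flip introduced by the column-major (order="F") reshape. Hypothetical usage on random MNIST-shaped data, assuming img in merge is allocated as (h * size[0], w * size[1], c):

    fake = numpy.random.rand(64, 28 * 28).astype('float32') * 2.0 - 1.0
    grid = merge(fake, [8, 8])                         # (224, 224, 1) uint8
    Image.fromarray(grid.squeeze()).save("grid.png")   # PIL, imported above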
......@@ -118,13 +134,9 @@ def get_real_samples(batch_size, data_np):
def get_noise(batch_size, noise_dim):
return numpy.random.normal(size=(batch_size, noise_dim)).astype('float32')
def get_sample_noise(batch_size, sample_dim):
return numpy.random.normal(size=(batch_size, sample_dim),
scale=0.01).astype('float32')
def get_fake_samples(generator_machine, batch_size, noise):
gen_inputs = api.Arguments.createArguments(1)
gen_inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(noise))
gen_inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(noise))
gen_outputs = api.Arguments.createArguments(0)
generator_machine.forward(gen_inputs, gen_outputs, api.PASS_TEST)
fake_samples = gen_outputs.getSlotValue(0).copyToNumpyMat()
......@@ -136,33 +148,27 @@ def get_training_loss(training_machine, inputs):
loss = outputs.getSlotValue(0).copyToNumpyMat()
return numpy.mean(loss)
def prepare_discriminator_data_batch_pos(batch_size, data_np, sample_noise):
def prepare_discriminator_data_batch_pos(batch_size, data_np):
real_samples = get_real_samples(batch_size, data_np)
labels = numpy.ones(batch_size, dtype='int32')
inputs = api.Arguments.createArguments(3)
inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(real_samples))
inputs.setSlotValue(1, api.Matrix.createGpuDenseFromNumpy(sample_noise))
inputs.setSlotIds(2, api.IVector.createGpuVectorFromNumpy(labels))
inputs = api.Arguments.createArguments(2)
inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(real_samples))
inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(labels))
return inputs
def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise,
sample_noise):
def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise):
fake_samples = get_fake_samples(generator_machine, batch_size, noise)
#print fake_samples.shape
labels = numpy.zeros(batch_size, dtype='int32')
inputs = api.Arguments.createArguments(3)
inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(fake_samples))
inputs.setSlotValue(1, api.Matrix.createGpuDenseFromNumpy(sample_noise))
inputs.setSlotIds(2, api.IVector.createGpuVectorFromNumpy(labels))
inputs = api.Arguments.createArguments(2)
inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(fake_samples))
inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(labels))
return inputs
def prepare_generator_data_batch(batch_size, noise, sample_noise):
def prepare_generator_data_batch(batch_size, noise):
label = numpy.ones(batch_size, dtype='int32')
#label = numpy.zeros(batch_size, dtype='int32')
inputs = api.Arguments.createArguments(3)
inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(noise))
inputs.setSlotValue(1, api.Matrix.createGpuDenseFromNumpy(sample_noise))
inputs.setSlotIds(2, api.IVector.createGpuVectorFromNumpy(label))
inputs = api.Arguments.createArguments(2)
inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(noise))
inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(label))
return inputs
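Note: after this refactor all three prepare_* helpers build two-slot batches, features plus labels, instead of the old three-slot form that also carried sample_noise. Real batches are labeled 1, fake batches 0, and the generator's batch labels its own fakes 1 so that fooling the discriminator lowers gen_loss. One discriminator round then chains them as the main loop below does (sketch; names as defined above and in main()):

    noise = get_noise(batch_size, noise_dim)
    pos = prepare_discriminator_data_batch_pos(batch_size, data_np)
    neg = prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise)
    dis_trainer.trainOneDataBatch(batch_size, neg)
    dis_trainer.trainOneDataBatch(batch_size, pos)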
......@@ -181,7 +187,7 @@ def get_layer_size(model_conf, layer_name):
def main():
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--dataSource", help="mnist or cifar")
parser.add_argument("-d", "--dataSource", help="mnist or cifar or uniform")
parser.add_argument("--useGpu", default="1",
help="1 means use gpu for training")
parser.add_argument("--gpuId", default="0",
......@@ -189,22 +195,31 @@ def main():
args = parser.parse_args()
dataSource = args.dataSource
useGpu = args.useGpu
assert dataSource in ["mnist", "cifar"]
assert dataSource in ["mnist", "cifar", "uniform"]
assert useGpu in ["0", "1"]
api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100',
'--gpu_id=' + args.gpuId)
gen_conf = parse_config("gan_conf_image.py", "mode=generator_training,data=" + dataSource)
dis_conf = parse_config("gan_conf_image.py", "mode=discriminator_training,data=" + dataSource)
generator_conf = parse_config("gan_conf_image.py", "mode=generator,data=" + dataSource)
if dataSource == "uniform":
conf = "gan_conf.py"
num_iter = 10000
else:
conf = "gan_conf_image.py"
num_iter = 1000
gen_conf = parse_config(conf, "mode=generator_training,data=" + dataSource)
dis_conf = parse_config(conf, "mode=discriminator_training,data=" + dataSource)
generator_conf = parse_config(conf, "mode=generator,data=" + dataSource)
batch_size = dis_conf.opt_config.batch_size
noise_dim = get_layer_size(gen_conf.model_config, "noise")
sample_dim = get_layer_size(dis_conf.model_config, "sample")
if dataSource == "mnist":
data_np = load_mnist_data("./data/raw_data/train-images-idx3-ubyte")
else:
data_np = load_mnist_data("./data/mnist_data/train-images-idx3-ubyte")
elif dataSource == "cifar":
data_np = load_cifar_data("./data/cifar-10-batches-py/")
else:
data_np = load_uniform_data()
if not os.path.exists("./%s_samples/" % dataSource):
os.makedirs("./%s_samples/" % dataSource)
......@@ -234,48 +249,44 @@ def main():
copy_shared_parameters(gen_training_machine, dis_training_machine)
copy_shared_parameters(gen_training_machine, generator_machine)
# constrain so that neither the discriminator nor the generator
# is trained consecutively more than MAX_strike times
curr_train = "dis"
curr_strike = 0
MAX_strike = 10
MAX_strike = 5
for train_pass in xrange(100):
dis_trainer.startTrainPass()
gen_trainer.startTrainPass()
for i in xrange(1000):
# data_batch_dis = prepare_discriminator_data_batch(
# generator_machine, batch_size, noise_dim, sample_dim)
# dis_loss = get_training_loss(dis_training_machine, data_batch_dis)
for i in xrange(num_iter):
noise = get_noise(batch_size, noise_dim)
sample_noise = get_sample_noise(batch_size, sample_dim)
data_batch_dis_pos = prepare_discriminator_data_batch_pos(
batch_size, data_np, sample_noise)
batch_size, data_np)
dis_loss_pos = get_training_loss(dis_training_machine, data_batch_dis_pos)
sample_noise = get_sample_noise(batch_size, sample_dim)
data_batch_dis_neg = prepare_discriminator_data_batch_neg(
generator_machine, batch_size, noise, sample_noise)
generator_machine, batch_size, noise)
dis_loss_neg = get_training_loss(dis_training_machine, data_batch_dis_neg)
dis_loss = (dis_loss_pos + dis_loss_neg) / 2.0
data_batch_gen = prepare_generator_data_batch(
batch_size, noise, sample_noise)
batch_size, noise)
gen_loss = get_training_loss(gen_training_machine, data_batch_gen)
if i % 100 == 0:
print "d_pos_loss is %s d_neg_loss is %s" % (dis_loss_pos, dis_loss_neg)
print "d_loss is %s g_loss is %s" % (dis_loss, gen_loss)
if (not (curr_train == "dis" and curr_strike == MAX_strike)) and ((curr_train == "gen" and curr_strike == MAX_strike) or dis_loss_neg > gen_loss):
if (not (curr_train == "dis" and curr_strike == MAX_strike)) and \
((curr_train == "gen" and curr_strike == MAX_strike) or dis_loss > gen_loss):
if curr_train == "dis":
curr_strike += 1
else:
curr_train = "dis"
curr_strike = 1
dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_neg)
dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_pos)
# dis_loss = numpy.mean(dis_trainer.getForwardOutput()[0]["value"])
# print "getForwardOutput loss is %s" % dis_loss
dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_pos)
copy_shared_parameters(dis_training_machine, gen_training_machine)
else:
......@@ -290,10 +301,11 @@ def main():
dis_trainer.finishTrainPass()
gen_trainer.finishTrainPass()
fake_samples = get_fake_samples(generator_machine, batch_size, noise)
saveImages(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
if dataSource == "uniform":
plot2DScatter(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
else:
saveImages(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
dis_trainer.finishTrain()
gen_trainer.finishTrain()
......
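Note: for reference, the quantities compared in the training loop reduce to the standard cross-entropy GAN losses: dis_loss averages the real and fake halves, and gen_loss is the discriminator's cross-entropy on fakes labeled as real. A hedged NumPy sketch (illustration only; the actual losses come out of the Paddle graph):

    import numpy

    def gan_losses(prob_real, prob_fake, eps=1e-8):
        # prob_*: discriminator sigmoid outputs in (0, 1).
        dis_loss_pos = -numpy.mean(numpy.log(prob_real + eps))        # real -> 1
        dis_loss_neg = -numpy.mean(numpy.log(1.0 - prob_fake + eps))  # fake -> 0
        dis_loss = (dis_loss_pos + dis_loss_neg) / 2.0
        gen_loss = -numpy.mean(numpy.log(prob_fake + eps))            # fool D
        return dis_loss, gen_loss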