Commit 93af332e authored by wangyang59

unified cifar/mnist/uniform gan training in demo

Parent 4878f078
@@ -4,5 +4,5 @@ output/
 .project
 *.log
 *.pyc
-data/raw_data/
+data/mnist_data/
 data/cifar-10-batches-py/
@@ -2,9 +2,9 @@
 # This scripts downloads the mnist data and unzips it.
 set -e
 DIR="$( cd "$(dirname "$0")" ; pwd -P )"
-rm -rf "$DIR/raw_data"
-mkdir "$DIR/raw_data"
-cd "$DIR/raw_data"
+rm -rf "$DIR/mnist_data"
+mkdir "$DIR/mnist_data"
+cd "$DIR/mnist_data"
 echo "Downloading..."
...
@@ -32,7 +32,7 @@ sample_dim = 2
 settings(
     batch_size=128,
     learning_rate=1e-4,
-    learning_method=AdamOptimizer()
+    learning_method=AdamOptimizer(beta1=0.7)
 )
 
 def discriminator(sample):
@@ -47,16 +47,15 @@ def discriminator(sample):
     bias_attr = ParamAttr(is_static=is_generator_training,
                           initial_mean=1.0,
                           initial_std=0)
 
    hidden = fc_layer(input=sample, name="dis_hidden", size=hidden_dim,
                      bias_attr=bias_attr,
                      param_attr=param_attr,
                      act=ReluActivation())
-                     #act=LinearActivation())
 
    hidden2 = fc_layer(input=hidden, name="dis_hidden2", size=hidden_dim,
                       bias_attr=bias_attr,
                       param_attr=param_attr,
-                      #act=ReluActivation())
                       act=LinearActivation())
 
    hidden_bn = batch_norm_layer(hidden2,
@@ -88,12 +87,10 @@ def generator(noise):
                      bias_attr=bias_attr,
                      param_attr=param_attr,
                      act=ReluActivation())
-                     #act=LinearActivation())
 
    hidden2 = fc_layer(input=hidden, name="gen_hidden2", size=hidden_dim,
                       bias_attr=bias_attr,
                       param_attr=param_attr,
-                      #act=ReluActivation())
                       act=LinearActivation())
 
    hidden_bn = batch_norm_layer(hidden2,
...
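Note on the pattern above: the demo parses this one config file several times with different `mode` values, and the network that is *not* currently being trained gets its parameters declared with `is_static=True`, so the optimizer leaves them untouched. A minimal sketch of that switch, using only names visible in the hunks (assumes Paddle v1's `get_config_arg` idiom; not the complete file):

```python
# mode arrives via parse_config(..., "mode=..."); get_config_arg is the
# standard Paddle v1 way to read it (assumption: exact call as in other demos).
mode = get_config_arg("mode", str, "generator_training")
is_generator_training = mode == "generator_training"
is_discriminator_training = mode == "discriminator_training"

# Freeze the discriminator while the generator trains, and vice versa.
dis_param_attr = ParamAttr(is_static=is_generator_training)
gen_param_attr = ParamAttr(is_static=is_discriminator_training)
```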
@@ -113,7 +113,6 @@ def generator(noise):
                   size=s8 * s8 * gf_dim * 4,
                   bias_attr=bias_attr,
                   param_attr=param_attr,
-                  #act=ReluActivation())
                   act=LinearActivation())
 
    h1_bn = batch_norm_layer(h1,
@@ -235,13 +234,8 @@ if is_discriminator_training:
    sample = data_layer(name="sample", size=sample_dim * sample_dim*c_dim)
 
 if is_generator_training or is_discriminator_training:
-    sample_noise = data_layer(name="sample_noise",
-                              size=sample_dim * sample_dim * c_dim)
     label = data_layer(name="label", size=1)
-    prob = discriminator(addto_layer([sample, sample_noise],
-                                     act=LinearActivation(),
-                                     name="add",
-                                     bias_attr=False))
+    prob = discriminator(sample)
     cost = cross_entropy(input=prob, label=label)
     classification_error_evaluator(input=prob, label=label, name=mode+'_error')
     outputs(cost)
...
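For the record, the lines deleted above implemented instance noise: an extra `sample_noise` input was summed onto every discriminator input through `addto_layer`. A NumPy restatement of what that pair of layers computed (a sketch for intuition only; `scale=0.01` comes from the `get_sample_noise` helper removed further down this page):

```python
import numpy

def add_instance_noise(samples, scale=0.01):
    # Equivalent of the removed sample_noise data layer + addto_layer:
    # perturb each discriminator input with small Gaussian noise.
    noise = numpy.random.normal(size=samples.shape, scale=scale).astype('float32')
    return samples + noise
```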
@@ -71,7 +71,7 @@ def print_parameters(src):
        print "value is %s \n" % p.getBuf(api.PARAMETER_VALUE).copyToNumpyArray()
 
 def get_real_samples(batch_size, sample_dim):
-    return numpy.random.rand(batch_size, sample_dim).astype('float32') * 10.0 - 10.0
+    return numpy.random.rand(batch_size, sample_dim).astype('float32')
     # return numpy.random.normal(loc=100.0, scale=100.0, size=(batch_size, sample_dim)).astype('float32')
 
 def get_fake_samples(generator_machine, batch_size, noise_dim, sample_dim):
@@ -106,7 +106,7 @@ def prepare_discriminator_data_batch_pos(batch_size, noise_dim, sample_dim):
     labels = numpy.ones(batch_size, dtype='int32')
     inputs = api.Arguments.createArguments(2)
     inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(real_samples))
-    inputs.setSlotIds(1, api.IVector.createGpuVectorFromNumy(labels))
+    inputs.setSlotIds(1, api.IVector.createGpuVectorFromNumpy(labels))
     return inputs
 
 def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise_dim, sample_dim):
@@ -114,7 +114,7 @@ def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise_dim, sample_dim):
     labels = numpy.zeros(batch_size, dtype='int32')
     inputs = api.Arguments.createArguments(2)
     inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(fake_samples))
-    inputs.setSlotIds(1, api.IVector.createGpuVectorFromNumy(labels))
+    inputs.setSlotIds(1, api.IVector.createGpuVectorFromNumpy(labels))
     return inputs
 
 def prepare_generator_data_batch(batch_size, dim):
@@ -122,7 +122,7 @@ def prepare_generator_data_batch(batch_size, dim):
     label = numpy.ones(batch_size, dtype='int32')
     inputs = api.Arguments.createArguments(2)
     inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(noise))
-    inputs.setSlotIds(1, api.IVector.createGpuVectorFromNumy(label))
+    inputs.setSlotIds(1, api.IVector.createGpuVectorFromNumpy(label))
     return inputs
@@ -140,7 +140,8 @@ def get_layer_size(model_conf, layer_name):
 def main():
-    api.initPaddle('--use_gpu=1', '--dot_period=100', '--log_period=10000')
+    api.initPaddle('--use_gpu=1', '--dot_period=10', '--log_period=100',
+                   '--gpu_id=2')
     gen_conf = parse_config("gan_conf.py", "mode=generator_training")
     dis_conf = parse_config("gan_conf.py", "mode=discriminator_training")
     generator_conf = parse_config("gan_conf.py", "mode=generator")
@@ -175,10 +176,10 @@ def main():
     curr_strike = 0
     MAX_strike = 5
 
-    for train_pass in xrange(10):
+    for train_pass in xrange(100):
         dis_trainer.startTrainPass()
         gen_trainer.startTrainPass()
-        for i in xrange(100000):
+        for i in xrange(1000):
             # data_batch_dis = prepare_discriminator_data_batch(
             #     generator_machine, batch_size, noise_dim, sample_dim)
             # dis_loss = get_training_loss(dis_training_machine, data_batch_dis)
@@ -199,7 +200,7 @@ def main():
             if i % 1000 == 0:
                 print "d_loss is %s g_loss is %s" % (dis_loss, gen_loss)
 
-            if (not (curr_train == "dis" and curr_strike == MAX_strike)) and ((curr_train == "gen" and curr_strike == MAX_strike) or dis_loss > 0.690 or dis_loss > gen_loss):
+            if (not (curr_train == "dis" and curr_strike == MAX_strike)) and ((curr_train == "gen" and curr_strike == MAX_strike) or dis_loss > gen_loss):
                 if curr_train == "dis":
                     curr_strike += 1
                 else:
...
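The condition rewritten above is the demo's alternation heuristic: train the discriminator while it is losing to the generator, but cap either network at MAX_strike consecutive batches. Restated as a standalone helper for readability (a sketch; the actual loop interleaves this decision with the Paddle trainer calls):

```python
def pick_network(curr_train, curr_strike, MAX_strike, dis_loss, gen_loss):
    # Mirrors the updated `if` in the trainer: prefer the discriminator
    # when it is behind (dis_loss > gen_loss) or the generator's streak
    # is exhausted, unless the discriminator's own streak is exhausted.
    dis_exhausted = curr_train == "dis" and curr_strike == MAX_strike
    gen_exhausted = curr_train == "gen" and curr_strike == MAX_strike
    if (not dis_exhausted) and (gen_exhausted or dis_loss > gen_loss):
        return "dis"
    return "gen"
```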
@@ -13,16 +13,29 @@
 # limitations under the License.
 
 import argparse
+import itertools
 import random
 import numpy
 import cPickle
-import sys,os,gc
+import sys,os
 from PIL import Image
 
 from paddle.trainer.config_parser import parse_config
 from paddle.trainer.config_parser import logger
 import py_paddle.swig_paddle as api
+import matplotlib.pyplot as plt
+
+def plot2DScatter(data, outputfile):
+    x = data[:, 0]
+    y = data[:, 1]
+    print "The mean vector is %s" % numpy.mean(data, 0)
+    print "The std vector is %s" % numpy.std(data, 0)
+
+    heatmap, xedges, yedges = numpy.histogram2d(x, y, bins=50)
+    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
+
+    plt.clf()
+    plt.scatter(x, y)
+    plt.savefig(outputfile, bbox_inches='tight')
 
 def CHECK_EQ(a, b):
     assert a == b, "a=%s, b=%s" % (a, b)
@@ -60,7 +73,6 @@ def load_mnist_data(imageFile):
     # Define number of samples for train/test
     if "train" in imageFile:
-        #n = 60000
         n = 60000
     else:
         n = 10000
@@ -89,6 +101,11 @@ def load_cifar_data(cifar_path):
     data = data / 255.0 * 2.0 - 1.0
     return data
 
+# synthesize 2-D uniform data
+def load_uniform_data():
+    data = numpy.random.rand(1000000, 2).astype('float32')
+    return data
+
 def merge(images, size):
     if images.shape[1] == 28*28:
         h, w, c = 28, 28, 1
@@ -98,7 +115,6 @@ def merge(images, size):
     for idx in xrange(size[0] * size[1]):
         i = idx % size[1]
         j = idx // size[1]
-        #img[j*h:j*h+h, i*w:i*w+w, :] = (images[idx, :].reshape((h, w, c), order="F") + 1.0) / 2.0 * 255.0
         img[j*h:j*h+h, i*w:i*w+w, :] = \
             ((images[idx, :].reshape((h, w, c), order="F").transpose(1, 0, 2) + 1.0) / 2.0 * 255.0)
     return img.astype('uint8')
@@ -118,13 +134,9 @@ def get_real_samples(batch_size, data_np):
 def get_noise(batch_size, noise_dim):
     return numpy.random.normal(size=(batch_size, noise_dim)).astype('float32')
 
-def get_sample_noise(batch_size, sample_dim):
-    return numpy.random.normal(size=(batch_size, sample_dim),
-                               scale=0.01).astype('float32')
-
 def get_fake_samples(generator_machine, batch_size, noise):
     gen_inputs = api.Arguments.createArguments(1)
-    gen_inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(noise))
+    gen_inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(noise))
     gen_outputs = api.Arguments.createArguments(0)
     generator_machine.forward(gen_inputs, gen_outputs, api.PASS_TEST)
     fake_samples = gen_outputs.getSlotValue(0).copyToNumpyMat()
@@ -136,33 +148,27 @@ def get_training_loss(training_machine, inputs):
     loss = outputs.getSlotValue(0).copyToNumpyMat()
     return numpy.mean(loss)
 
-def prepare_discriminator_data_batch_pos(batch_size, data_np, sample_noise):
+def prepare_discriminator_data_batch_pos(batch_size, data_np):
     real_samples = get_real_samples(batch_size, data_np)
     labels = numpy.ones(batch_size, dtype='int32')
-    inputs = api.Arguments.createArguments(3)
-    inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(real_samples))
-    inputs.setSlotValue(1, api.Matrix.createGpuDenseFromNumpy(sample_noise))
-    inputs.setSlotIds(2, api.IVector.createGpuVectorFromNumpy(labels))
+    inputs = api.Arguments.createArguments(2)
+    inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(real_samples))
+    inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(labels))
     return inputs
 
-def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise,
-                                         sample_noise):
+def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise):
     fake_samples = get_fake_samples(generator_machine, batch_size, noise)
-    #print fake_samples.shape
     labels = numpy.zeros(batch_size, dtype='int32')
-    inputs = api.Arguments.createArguments(3)
-    inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(fake_samples))
-    inputs.setSlotValue(1, api.Matrix.createGpuDenseFromNumpy(sample_noise))
-    inputs.setSlotIds(2, api.IVector.createGpuVectorFromNumpy(labels))
+    inputs = api.Arguments.createArguments(2)
+    inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(fake_samples))
+    inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(labels))
     return inputs
 
-def prepare_generator_data_batch(batch_size, noise, sample_noise):
+def prepare_generator_data_batch(batch_size, noise):
     label = numpy.ones(batch_size, dtype='int32')
-    #label = numpy.zeros(batch_size, dtype='int32')
-    inputs = api.Arguments.createArguments(3)
-    inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(noise))
-    inputs.setSlotValue(1, api.Matrix.createGpuDenseFromNumpy(sample_noise))
-    inputs.setSlotIds(2, api.IVector.createGpuVectorFromNumpy(label))
+    inputs = api.Arguments.createArguments(2)
+    inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(noise))
+    inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(label))
     return inputs
@@ -181,7 +187,7 @@ def get_layer_size(model_conf, layer_name):
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("-d", "--dataSource", help="mnist or cifar")
+    parser.add_argument("-d", "--dataSource", help="mnist or cifar or uniform")
     parser.add_argument("--useGpu", default="1",
                         help="1 means use gpu for training")
     parser.add_argument("--gpuId", default="0",
@@ -189,22 +195,31 @@ def main():
     args = parser.parse_args()
     dataSource = args.dataSource
     useGpu = args.useGpu
-    assert dataSource in ["mnist", "cifar"]
+    assert dataSource in ["mnist", "cifar", "uniform"]
     assert useGpu in ["0", "1"]
 
     api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100',
                    '--gpu_id=' + args.gpuId)
-    gen_conf = parse_config("gan_conf_image.py", "mode=generator_training,data=" + dataSource)
-    dis_conf = parse_config("gan_conf_image.py", "mode=discriminator_training,data=" + dataSource)
-    generator_conf = parse_config("gan_conf_image.py", "mode=generator,data=" + dataSource)
+
+    if dataSource == "uniform":
+        conf = "gan_conf.py"
+        num_iter = 10000
+    else:
+        conf = "gan_conf_image.py"
+        num_iter = 1000
+    gen_conf = parse_config(conf, "mode=generator_training,data=" + dataSource)
+    dis_conf = parse_config(conf, "mode=discriminator_training,data=" + dataSource)
+    generator_conf = parse_config(conf, "mode=generator,data=" + dataSource)
     batch_size = dis_conf.opt_config.batch_size
     noise_dim = get_layer_size(gen_conf.model_config, "noise")
-    sample_dim = get_layer_size(dis_conf.model_config, "sample")
 
     if dataSource == "mnist":
-        data_np = load_mnist_data("./data/raw_data/train-images-idx3-ubyte")
-    else:
+        data_np = load_mnist_data("./data/mnist_data/train-images-idx3-ubyte")
+    elif dataSource == "cifar":
         data_np = load_cifar_data("./data/cifar-10-batches-py/")
+    else:
+        data_np = load_uniform_data()
 
     if not os.path.exists("./%s_samples/" % dataSource):
         os.makedirs("./%s_samples/" % dataSource)
@@ -234,39 +249,37 @@ def main():
     copy_shared_parameters(gen_training_machine, dis_training_machine)
     copy_shared_parameters(gen_training_machine, generator_machine)
 
+    # constrain that either discriminator or generator can not be trained
+    # consecutively more than MAX_strike times
     curr_train = "dis"
     curr_strike = 0
-    MAX_strike = 10
+    MAX_strike = 5
 
     for train_pass in xrange(100):
         dis_trainer.startTrainPass()
         gen_trainer.startTrainPass()
-        for i in xrange(1000):
+        for i in xrange(num_iter):
-            # data_batch_dis = prepare_discriminator_data_batch(
-            #     generator_machine, batch_size, noise_dim, sample_dim)
-            # dis_loss = get_training_loss(dis_training_machine, data_batch_dis)
             noise = get_noise(batch_size, noise_dim)
-            sample_noise = get_sample_noise(batch_size, sample_dim)
             data_batch_dis_pos = prepare_discriminator_data_batch_pos(
-                batch_size, data_np, sample_noise)
+                batch_size, data_np)
             dis_loss_pos = get_training_loss(dis_training_machine, data_batch_dis_pos)
 
-            sample_noise = get_sample_noise(batch_size, sample_dim)
             data_batch_dis_neg = prepare_discriminator_data_batch_neg(
-                generator_machine, batch_size, noise, sample_noise)
+                generator_machine, batch_size, noise)
            dis_loss_neg = get_training_loss(dis_training_machine, data_batch_dis_neg)
 
            dis_loss = (dis_loss_pos + dis_loss_neg) / 2.0
 
            data_batch_gen = prepare_generator_data_batch(
-                batch_size, noise, sample_noise)
+                batch_size, noise)
            gen_loss = get_training_loss(gen_training_machine, data_batch_gen)
 
            if i % 100 == 0:
                print "d_pos_loss is %s d_neg_loss is %s" % (dis_loss_pos, dis_loss_neg)
                print "d_loss is %s g_loss is %s" % (dis_loss, gen_loss)
 
-            if (not (curr_train == "dis" and curr_strike == MAX_strike)) and ((curr_train == "gen" and curr_strike == MAX_strike) or dis_loss_neg > gen_loss):
+            if (not (curr_train == "dis" and curr_strike == MAX_strike)) and \
+               ((curr_train == "gen" and curr_strike == MAX_strike) or dis_loss > gen_loss):
                if curr_train == "dis":
                    curr_strike += 1
                else:
@@ -274,8 +287,6 @@ def main():
                    curr_strike = 1
                dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_neg)
                dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_pos)
-                # dis_loss = numpy.mean(dis_trainer.getForwardOutput()[0]["value"])
-                # print "getForwardOutput loss is %s" % dis_loss
                copy_shared_parameters(dis_training_machine, gen_training_machine)
            else:
@@ -290,9 +301,10 @@ def main():
        dis_trainer.finishTrainPass()
        gen_trainer.finishTrainPass()
        fake_samples = get_fake_samples(generator_machine, batch_size, noise)
-        saveImages(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
+        if dataSource == "uniform":
+            plot2DScatter(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
+        else:
+            saveImages(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
    dis_trainer.finishTrain()
    gen_trainer.finishTrain()
...
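Two quick checks for the new uniform mode. First, switching from `createGpuDenseFromNumpy` to `createDenseFromNumpy` (and likewise for the label vectors) appears to be what lets the `--useGpu 0` path work: the generic factories allocate on whichever device `initPaddle` configured. Second, since `load_uniform_data` draws from U(0,1), generator output can be sanity-checked against that distribution's moments; a self-contained sketch (`fake_samples` stands in for the generator output that `plot2DScatter` prints):

```python
import numpy

def check_uniform_moments(fake_samples, atol=0.05):
    # U(0,1) has mean 1/2 and std sqrt(1/12) ~= 0.2887 per dimension.
    mean = numpy.mean(fake_samples, 0)
    std = numpy.std(fake_samples, 0)
    ok = numpy.allclose(mean, 0.5, atol=atol) and \
         numpy.allclose(std, (1.0 / 12.0) ** 0.5, atol=atol)
    print "mean=%s std=%s within_tolerance=%s" % (mean, std, ok)
    return ok
```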