提交 68d54ecb 编写于 作者: W wangyang59

code modification

上级 a0411e3c
此差异已折叠。
...@@ -92,25 +92,30 @@ def conv_bn(input, ...@@ -92,25 +92,30 @@ def conv_bn(input,
nameApx = "_conv" nameApx = "_conv"
if bn: if bn:
conv = img_conv_layer( conv_act = LinearActivation()
input, else:
filter_size=filter_size, conv_act = act
num_filters=num_filters,
name=name + nameApx, conv = img_conv_layer(
num_channels=channels, input,
act=LinearActivation(), filter_size=filter_size,
groups=1, num_filters=num_filters,
stride=stride, name=name + nameApx,
padding=padding, num_channels=channels,
bias_attr=bias_attr, act=conv_act,
param_attr=param_attr, groups=1,
shared_biases=True, stride=stride,
layer_attr=None, padding=padding,
filter_size_y=None, bias_attr=bias_attr,
stride_y=None, param_attr=param_attr,
padding_y=None, shared_biases=True,
trans=trans) layer_attr=None,
filter_size_y=None,
stride_y=None,
padding_y=None,
trans=trans)
if bn:
conv_bn = batch_norm_layer( conv_bn = batch_norm_layer(
conv, conv,
act=act, act=act,
...@@ -118,27 +123,8 @@ def conv_bn(input, ...@@ -118,27 +123,8 @@ def conv_bn(input,
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=param_attr_bn, param_attr=param_attr_bn,
use_global_stats=False) use_global_stats=False)
return conv_bn return conv_bn
else: else:
conv = img_conv_layer(
input,
filter_size=filter_size,
num_filters=num_filters,
name=name + nameApx,
num_channels=channels,
act=act,
groups=1,
stride=stride,
padding=padding,
bias_attr=bias_attr,
param_attr=param_attr,
shared_biases=True,
layer_attr=None,
filter_size_y=None,
stride_y=None,
padding_y=None,
trans=trans)
return conv return conv
......
...@@ -25,24 +25,6 @@ import py_paddle.swig_paddle as api ...@@ -25,24 +25,6 @@ import py_paddle.swig_paddle as api
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
def plot2DScatter(data, outputfile):
'''
Plot the data as a 2D scatter plot and save to outputfile
data needs to be two dimensinoal
'''
x = data[:, 0]
y = data[:, 1]
logger.info("The mean vector is %s" % numpy.mean(data, 0))
logger.info("The std vector is %s" % numpy.std(data, 0))
heatmap, xedges, yedges = numpy.histogram2d(x, y, bins=50)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
plt.clf()
plt.scatter(x, y)
plt.savefig(outputfile, bbox_inches='tight')
def CHECK_EQ(a, b): def CHECK_EQ(a, b):
assert a == b, "a=%s, b=%s" % (a, b) assert a == b, "a=%s, b=%s" % (a, b)
...@@ -80,6 +62,12 @@ def print_parameters(src): ...@@ -80,6 +62,12 @@ def print_parameters(src):
) )
# synthesize 2-D uniform data
def load_uniform_data():
data = numpy.random.rand(1000000, 2).astype('float32')
return data
def load_mnist_data(imageFile): def load_mnist_data(imageFile):
f = open(imageFile, "rb") f = open(imageFile, "rb")
f.read(16) f.read(16)
...@@ -111,10 +99,22 @@ def load_cifar_data(cifar_path): ...@@ -111,10 +99,22 @@ def load_cifar_data(cifar_path):
return data return data
# synthesize 2-D uniform data def plot2DScatter(data, outputfile):
def load_uniform_data(): '''
data = numpy.random.rand(1000000, 2).astype('float32') Plot the data as a 2D scatter plot and save to outputfile
return data data needs to be two dimensinoal
'''
x = data[:, 0]
y = data[:, 1]
logger.info("The mean vector is %s" % numpy.mean(data, 0))
logger.info("The std vector is %s" % numpy.std(data, 0))
heatmap, xedges, yedges = numpy.histogram2d(x, y, bins=50)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
plt.clf()
plt.scatter(x, y)
plt.savefig(outputfile, bbox_inches='tight')
def merge(images, size): def merge(images, size):
...@@ -140,6 +140,13 @@ def save_images(images, path): ...@@ -140,6 +140,13 @@ def save_images(images, path):
im.save(path) im.save(path)
def save_results(samples, path, data_source):
if data_source == "uniform":
plot2DScatter(samples, path)
else:
save_images(samples, path)
def get_real_samples(batch_size, data_np): def get_real_samples(batch_size, data_np):
return data_np[numpy.random.choice( return data_np[numpy.random.choice(
data_np.shape[0], batch_size, replace=False), :] data_np.shape[0], batch_size, replace=False), :]
...@@ -210,9 +217,14 @@ def main(): ...@@ -210,9 +217,14 @@ def main():
parser.add_argument( parser.add_argument(
"--use_gpu", default="1", help="1 means use gpu for training") "--use_gpu", default="1", help="1 means use gpu for training")
parser.add_argument("--gpu_id", default="0", help="the gpu_id parameter") parser.add_argument("--gpu_id", default="0", help="the gpu_id parameter")
parser.add_argument(
"--model_dir",
default="",
help="model path for generating samples, empty means training mode")
args = parser.parse_args() args = parser.parse_args()
data_source = args.data_source data_source = args.data_source
use_gpu = args.use_gpu use_gpu = args.use_gpu
model_dir = args.model_dir
assert data_source in ["mnist", "cifar", "uniform"] assert data_source in ["mnist", "cifar", "uniform"]
assert use_gpu in ["0", "1"] assert use_gpu in ["0", "1"]
...@@ -237,6 +249,8 @@ def main(): ...@@ -237,6 +249,8 @@ def main():
dis_conf = parse_config(conf, dis_conf = parse_config(conf,
"mode=discriminator_training,data=" + data_source) "mode=discriminator_training,data=" + data_source)
generator_conf = parse_config(conf, "mode=generator,data=" + data_source) generator_conf = parse_config(conf, "mode=generator,data=" + data_source)
logger.info(str(generator_conf.model_config))
batch_size = dis_conf.opt_config.batch_size batch_size = dis_conf.opt_config.batch_size
noise_dim = get_layer_size(gen_conf.model_config, "noise") noise_dim = get_layer_size(gen_conf.model_config, "noise")
...@@ -253,15 +267,21 @@ def main(): ...@@ -253,15 +267,21 @@ def main():
# this create a gradient machine for generator # this create a gradient machine for generator
gen_training_machine = api.GradientMachine.createFromConfigProto( gen_training_machine = api.GradientMachine.createFromConfigProto(
gen_conf.model_config) gen_conf.model_config)
# generator_machine is used to generate data only, which is used for # generator_machine is used to generate data only, which is used for
# training discriminator # training discriminator
logger.info(str(generator_conf.model_config))
generator_machine = api.GradientMachine.createFromConfigProto( generator_machine = api.GradientMachine.createFromConfigProto(
generator_conf.model_config) generator_conf.model_config)
dis_trainer = api.Trainer.create(dis_conf, dis_training_machine) # In the generating settings, use previously trained model to generate
# fake samples
if model_dir != "":
generator_machine.loadParameters(model_dir)
noise = get_noise(batch_size, noise_dim)
fake_samples = get_fake_samples(generator_machine, batch_size, noise)
save_results(fake_samples, "./generated_samples.png", data_source)
return
dis_trainer = api.Trainer.create(dis_conf, dis_training_machine)
gen_trainer = api.Trainer.create(gen_conf, gen_training_machine) gen_trainer = api.Trainer.create(gen_conf, gen_training_machine)
dis_trainer.startTrain() dis_trainer.startTrain()
...@@ -325,8 +345,6 @@ def main(): ...@@ -325,8 +345,6 @@ def main():
curr_train = "gen" curr_train = "gen"
curr_strike = 1 curr_strike = 1
gen_trainer.trainOneDataBatch(batch_size, data_batch_gen) gen_trainer.trainOneDataBatch(batch_size, data_batch_gen)
# TODO: add API for paddle to allow true parameter sharing between different GradientMachines
# so that we do not need to copy shared parameters.
copy_shared_parameters(gen_training_machine, copy_shared_parameters(gen_training_machine,
dis_training_machine) dis_training_machine)
copy_shared_parameters(gen_training_machine, generator_machine) copy_shared_parameters(gen_training_machine, generator_machine)
...@@ -335,12 +353,8 @@ def main(): ...@@ -335,12 +353,8 @@ def main():
gen_trainer.finishTrainPass() gen_trainer.finishTrainPass()
# At the end of each pass, save the generated samples/images # At the end of each pass, save the generated samples/images
fake_samples = get_fake_samples(generator_machine, batch_size, noise) fake_samples = get_fake_samples(generator_machine, batch_size, noise)
if data_source == "uniform": save_results(fake_samples, "./%s_samples/train_pass%s.png" %
plot2DScatter(fake_samples, "./%s_samples/train_pass%s.png" % (data_source, train_pass), data_source)
(data_source, train_pass))
else:
save_images(fake_samples, "./%s_samples/train_pass%s.png" %
(data_source, train_pass))
dis_trainer.finishTrain() dis_trainer.finishTrain()
gen_trainer.finishTrain() gen_trainer.finishTrain()
......
文件已添加
gan/image/gan_conf_graph.png

79.2 KB | W: | H:

gan/image/gan_conf_graph.png

125.7 KB | W: | H:

gan/image/gan_conf_graph.png
gan/image/gan_conf_graph.png
gan/image/gan_conf_graph.png
gan/image/gan_conf_graph.png
  • 2-up
  • Swipe
  • Onion skin
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册