提交 68d54ecb 编写于 作者: W wangyang59

code modification

上级 a0411e3c
此差异已折叠。
......@@ -92,25 +92,30 @@ def conv_bn(input,
nameApx = "_conv"
if bn:
conv = img_conv_layer(
input,
filter_size=filter_size,
num_filters=num_filters,
name=name + nameApx,
num_channels=channels,
act=LinearActivation(),
groups=1,
stride=stride,
padding=padding,
bias_attr=bias_attr,
param_attr=param_attr,
shared_biases=True,
layer_attr=None,
filter_size_y=None,
stride_y=None,
padding_y=None,
trans=trans)
conv_act = LinearActivation()
else:
conv_act = act
conv = img_conv_layer(
input,
filter_size=filter_size,
num_filters=num_filters,
name=name + nameApx,
num_channels=channels,
act=conv_act,
groups=1,
stride=stride,
padding=padding,
bias_attr=bias_attr,
param_attr=param_attr,
shared_biases=True,
layer_attr=None,
filter_size_y=None,
stride_y=None,
padding_y=None,
trans=trans)
if bn:
conv_bn = batch_norm_layer(
conv,
act=act,
......@@ -118,27 +123,8 @@ def conv_bn(input,
bias_attr=bias_attr,
param_attr=param_attr_bn,
use_global_stats=False)
return conv_bn
else:
conv = img_conv_layer(
input,
filter_size=filter_size,
num_filters=num_filters,
name=name + nameApx,
num_channels=channels,
act=act,
groups=1,
stride=stride,
padding=padding,
bias_attr=bias_attr,
param_attr=param_attr,
shared_biases=True,
layer_attr=None,
filter_size_y=None,
stride_y=None,
padding_y=None,
trans=trans)
return conv
......
......@@ -25,24 +25,6 @@ import py_paddle.swig_paddle as api
import matplotlib.pyplot as plt
def plot2DScatter(data, outputfile):
'''
Plot the data as a 2D scatter plot and save to outputfile
data needs to be two dimensinoal
'''
x = data[:, 0]
y = data[:, 1]
logger.info("The mean vector is %s" % numpy.mean(data, 0))
logger.info("The std vector is %s" % numpy.std(data, 0))
heatmap, xedges, yedges = numpy.histogram2d(x, y, bins=50)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
plt.clf()
plt.scatter(x, y)
plt.savefig(outputfile, bbox_inches='tight')
def CHECK_EQ(a, b):
assert a == b, "a=%s, b=%s" % (a, b)
......@@ -80,6 +62,12 @@ def print_parameters(src):
)
# synthesize 2-D uniform data
def load_uniform_data():
data = numpy.random.rand(1000000, 2).astype('float32')
return data
def load_mnist_data(imageFile):
f = open(imageFile, "rb")
f.read(16)
......@@ -111,10 +99,22 @@ def load_cifar_data(cifar_path):
return data
# synthesize 2-D uniform data
def load_uniform_data():
data = numpy.random.rand(1000000, 2).astype('float32')
return data
def plot2DScatter(data, outputfile):
'''
Plot the data as a 2D scatter plot and save to outputfile
data needs to be two dimensinoal
'''
x = data[:, 0]
y = data[:, 1]
logger.info("The mean vector is %s" % numpy.mean(data, 0))
logger.info("The std vector is %s" % numpy.std(data, 0))
heatmap, xedges, yedges = numpy.histogram2d(x, y, bins=50)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
plt.clf()
plt.scatter(x, y)
plt.savefig(outputfile, bbox_inches='tight')
def merge(images, size):
......@@ -140,6 +140,13 @@ def save_images(images, path):
im.save(path)
def save_results(samples, path, data_source):
if data_source == "uniform":
plot2DScatter(samples, path)
else:
save_images(samples, path)
def get_real_samples(batch_size, data_np):
return data_np[numpy.random.choice(
data_np.shape[0], batch_size, replace=False), :]
......@@ -210,9 +217,14 @@ def main():
parser.add_argument(
"--use_gpu", default="1", help="1 means use gpu for training")
parser.add_argument("--gpu_id", default="0", help="the gpu_id parameter")
parser.add_argument(
"--model_dir",
default="",
help="model path for generating samples, empty means training mode")
args = parser.parse_args()
data_source = args.data_source
use_gpu = args.use_gpu
model_dir = args.model_dir
assert data_source in ["mnist", "cifar", "uniform"]
assert use_gpu in ["0", "1"]
......@@ -237,6 +249,8 @@ def main():
dis_conf = parse_config(conf,
"mode=discriminator_training,data=" + data_source)
generator_conf = parse_config(conf, "mode=generator,data=" + data_source)
logger.info(str(generator_conf.model_config))
batch_size = dis_conf.opt_config.batch_size
noise_dim = get_layer_size(gen_conf.model_config, "noise")
......@@ -253,15 +267,21 @@ def main():
# this create a gradient machine for generator
gen_training_machine = api.GradientMachine.createFromConfigProto(
gen_conf.model_config)
# generator_machine is used to generate data only, which is used for
# training discriminator
logger.info(str(generator_conf.model_config))
generator_machine = api.GradientMachine.createFromConfigProto(
generator_conf.model_config)
dis_trainer = api.Trainer.create(dis_conf, dis_training_machine)
# In the generating settings, use previously trained model to generate
# fake samples
if model_dir != "":
generator_machine.loadParameters(model_dir)
noise = get_noise(batch_size, noise_dim)
fake_samples = get_fake_samples(generator_machine, batch_size, noise)
save_results(fake_samples, "./generated_samples.png", data_source)
return
dis_trainer = api.Trainer.create(dis_conf, dis_training_machine)
gen_trainer = api.Trainer.create(gen_conf, gen_training_machine)
dis_trainer.startTrain()
......@@ -325,8 +345,6 @@ def main():
curr_train = "gen"
curr_strike = 1
gen_trainer.trainOneDataBatch(batch_size, data_batch_gen)
# TODO: add API for paddle to allow true parameter sharing between different GradientMachines
# so that we do not need to copy shared parameters.
copy_shared_parameters(gen_training_machine,
dis_training_machine)
copy_shared_parameters(gen_training_machine, generator_machine)
......@@ -335,12 +353,8 @@ def main():
gen_trainer.finishTrainPass()
# At the end of each pass, save the generated samples/images
fake_samples = get_fake_samples(generator_machine, batch_size, noise)
if data_source == "uniform":
plot2DScatter(fake_samples, "./%s_samples/train_pass%s.png" %
(data_source, train_pass))
else:
save_images(fake_samples, "./%s_samples/train_pass%s.png" %
(data_source, train_pass))
save_results(fake_samples, "./%s_samples/train_pass%s.png" %
(data_source, train_pass), data_source)
dis_trainer.finishTrain()
gen_trainer.finishTrain()
......
文件已添加
gan/image/gan_conf_graph.png

79.2 KB | W: | H:

gan/image/gan_conf_graph.png

125.7 KB | W: | H:

gan/image/gan_conf_graph.png
gan/image/gan_conf_graph.png
gan/image/gan_conf_graph.png
gan/image/gan_conf_graph.png
  • 2-up
  • Swipe
  • Onion skin
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册