Commit d8aada07 authored by wangyang59

added cifar data into demo/gan

Parent fb0d80d5
@@ -2,5 +2,7 @@ output/
 *.png
 .pydevproject
 .project
-train.log
+*.log
+*.pyc
 data/raw_data/
+data/cifar-10-batches-py/
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
tar zxf cifar-10-python.tar.gz
rm cifar-10-python.tar.gz
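For reference, the tarball unpacks into cifar-10-batches-py/, matching the new .gitignore entry above. A minimal sketch (not part of the commit) for sanity-checking one downloaded batch with cPickle, the same module the trainer uses:

import cPickle

# Each of data_batch_1 .. data_batch_5 holds 10000 images, one per row,
# as 32*32*3 = 3072 uint8 values: the red plane, then green, then blue.
fo = open("data/cifar-10-batches-py/data_batch_1", "rb")
batch = cPickle.load(fo)
fo.close()
print batch["data"].shape  # (10000, 3072)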
@@ -12,10 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from paddle.trainer_config_helpers import *
-from paddle.trainer_config_helpers.activations import LinearActivation
-from numpy.distutils.system_info import tmp

 mode = get_config_arg("mode", str, "generator")
+dataSource = get_config_arg("data", str, "mnist")
 assert mode in set(["generator",
                     "discriminator",
                     "generator_training",
@@ -30,8 +29,12 @@ print('mode=%s' % mode)
 noise_dim = 100
 gf_dim = 64
 df_dim = 64
-sample_dim = 28  # image dim
-c_dim = 1  # image color
+if dataSource == "mnist":
+    sample_dim = 28  # image dim
+    c_dim = 1  # image color
+else:
+    sample_dim = 32
+    c_dim = 3
 s2, s4 = int(sample_dim/2), int(sample_dim/4)
 s8, s16 = int(sample_dim/8), int(sample_dim/16)
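The config now derives its dimensions from the data argument that the trainer passes through parse_config; get_config_arg("data", str, "mnist") falls back to MNIST when no argument is given. A short worked example of the resulting sizes for the CIFAR branch (a sketch; the s2..s16 names come from the lines above):

# dataSource == "cifar" gives sample_dim = 32, c_dim = 3, so the
# feature-map pyramid doubles from 2 up to the full 32x32 image:
s2, s4 = int(32 / 2), int(32 / 4)    # 16, 8
s8, s16 = int(32 / 8), int(32 / 16)  # 4, 2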
...
@@ -16,31 +16,13 @@ import argparse
 import itertools
 import random
 import numpy
+import cPickle
 import sys,os,gc
 from PIL import Image

 from paddle.trainer.config_parser import parse_config
 from paddle.trainer.config_parser import logger
 import py_paddle.swig_paddle as api
-from py_paddle import DataProviderConverter
-import matplotlib.pyplot as plt
-
-def plot2DScatter(data, outputfile):
-    # Generate some test data
-    x = data[:, 0]
-    y = data[:, 1]
-    print "The mean vector is %s" % numpy.mean(data, 0)
-    print "The std vector is %s" % numpy.std(data, 0)
-
-    heatmap, xedges, yedges = numpy.histogram2d(x, y, bins=50)
-    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
-
-    plt.clf()
-    plt.scatter(x, y)
-    # plt.show()
-    plt.savefig(outputfile, bbox_inches='tight')

 def CHECK_EQ(a, b):
     assert a == b, "a=%s, b=%s" % (a, b)
@@ -94,18 +76,39 @@ def load_mnist_data(imageFile):
     f.close()
     return data

+def load_cifar_data(cifar_path):
+    batch_size = 10000
+    data = numpy.zeros((5*batch_size, 32*32*3), dtype="float32")
+    for i in range(1, 6):
+        file = cifar_path + "/data_batch_" + str(i)
+        fo = open(file, 'rb')
+        dict = cPickle.load(fo)
+        fo.close()
+        data[(i - 1)*batch_size:(i*batch_size), :] = dict["data"]
+    data = data / 255.0 * 2.0 - 1.0
+    return data
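A quick check of what the new loader returns, assuming the batches were fetched with the download script above (an illustrative sketch, not part of the commit):

data_np = load_cifar_data("./data/cifar-10-batches-py/")
print data_np.shape                 # (50000, 3072): five batches of 10000 rows
print data_np.min(), data_np.max()  # -1.0 and 1.0: pixels rescaled from [0, 255]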
 def merge(images, size):
-    h, w = 28, 28
-    img = numpy.zeros((h * size[0], w * size[1]))
+    if images.shape[1] == 28*28:
+        h, w, c = 28, 28, 1
+    else:
+        h, w, c = 32, 32, 3
+    img = numpy.zeros((h * size[0], w * size[1], c))
     for idx in xrange(size[0] * size[1]):
         i = idx % size[1]
         j = idx // size[1]
-        img[j*h:j*h+h, i*w:i*w+w] = (images[idx, :].reshape((h, w)) + 1.0) / 2.0 * 255.0
-    return img
+        img[j*h:j*h+h, i*w:i*w+w, :] = \
+            ((images[idx, :].reshape((h, w, c), order="F").transpose(1, 0, 2) + 1.0) / 2.0 * 255.0)
+    return img.astype('uint8')
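The order="F" reshape followed by transpose(1, 0, 2) is equivalent to the more familiar CHW-to-HWC conversion, because each CIFAR row stores the red plane, then green, then blue, each 32x32 in row-major order. A sketch of the equivalence:

row = data_np[0, :]  # one image as a flat 3072-vector in [-1, 1]
a = row.reshape((3, 32, 32)).transpose(1, 2, 0)             # CHW -> HWC
b = row.reshape((32, 32, 3), order="F").transpose(1, 0, 2)  # as in merge()
assert numpy.array_equal(a, b)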

 def saveImages(images, path):
     merged_img = merge(images, [8, 8])
-    im = Image.fromarray(merged_img).convert('RGB')
+    if merged_img.shape[2] == 1:
+        im = Image.fromarray(numpy.squeeze(merged_img)).convert('RGB')
+    else:
+        im = Image.fromarray(merged_img, mode="RGB")
     im.save(path)
 def get_real_samples(batch_size, data_np):
@@ -115,9 +118,9 @@ def get_real_samples(batch_size, data_np):
 def get_noise(batch_size, noise_dim):
     return numpy.random.normal(size=(batch_size, noise_dim)).astype('float32')

-def get_sample_noise(batch_size):
-    return numpy.random.normal(size=(batch_size, 28*28),
-                               scale=0.1).astype('float32')
+def get_sample_noise(batch_size, sample_dim):
+    return numpy.random.normal(size=(batch_size, sample_dim),
+                               scale=0.01).astype('float32')

 def get_fake_samples(generator_machine, batch_size, noise):
     gen_inputs = api.Arguments.createArguments(1)
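get_sample_noise now sizes the noise by the discriminator's input layer instead of the hard-coded 28*28, and its scale drops from 0.1 to 0.01. This noise is mixed into both the real and generated batches below, so the discriminator never sees perfectly clean inputs; a hedged sketch of the positive case (prepare_discriminator_data_batch_pos lives elsewhere in this file):

real_samples = get_real_samples(batch_size, data_np)
sample_noise = get_sample_noise(batch_size, sample_dim)
noisy_real = real_samples + sample_noise  # instance noise on the real images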
@@ -177,15 +180,31 @@ def get_layer_size(model_conf, layer_name):

 def main():
-    api.initPaddle('--use_gpu=1', '--dot_period=10', '--log_period=100')
-    gen_conf = parse_config("gan_conf_image.py", "mode=generator_training")
-    dis_conf = parse_config("gan_conf_image.py", "mode=discriminator_training")
-    generator_conf = parse_config("gan_conf_image.py", "mode=generator")
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-d", "--dataSource", help="mnist or cifar")
+    parser.add_argument("--useGpu", default="1",
+                        help="1 means use gpu for training")
+    args = parser.parse_args()
+    dataSource = args.dataSource
+    useGpu = args.useGpu
+    assert dataSource in ["mnist", "cifar"]
+    assert useGpu in ["0", "1"]
+
+    api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100')
+    gen_conf = parse_config("gan_conf_image.py", "mode=generator_training,data=" + dataSource)
+    dis_conf = parse_config("gan_conf_image.py", "mode=discriminator_training,data=" + dataSource)
+    generator_conf = parse_config("gan_conf_image.py", "mode=generator,data=" + dataSource)
     batch_size = dis_conf.opt_config.batch_size
     noise_dim = get_layer_size(gen_conf.model_config, "noise")
     sample_dim = get_layer_size(dis_conf.model_config, "sample")
-    data_np = load_mnist_data("./data/raw_data/train-images-idx3-ubyte")
+
+    if dataSource == "mnist":
+        data_np = load_mnist_data("./data/raw_data/train-images-idx3-ubyte")
+    else:
+        data_np = load_cifar_data("./data/cifar-10-batches-py/")
+
+    if not os.path.exists("./%s_samples/" % dataSource):
+        os.makedirs("./%s_samples/" % dataSource)

     # this creates a gradient machine for the discriminator
     dis_training_machine = api.GradientMachine.createFromConfigProto(
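With the new command-line flags, a typical invocation looks like the following (a usage sketch; the script name gan_trainer_image.py is assumed, inferred from the gan_conf_image.py config it parses):

# Run from the demo directory after fetching the data:
#   python gan_trainer_image.py -d cifar --useGpu 1
# Sample grids are then written to ./cifar_samples/train_pass<N>.png,
# per the saveImages() call near the end of main().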
@@ -224,12 +243,12 @@ def main():
         #     generator_machine, batch_size, noise_dim, sample_dim)
         # dis_loss = get_training_loss(dis_training_machine, data_batch_dis)
         noise = get_noise(batch_size, noise_dim)

-        sample_noise = get_sample_noise(batch_size)
+        sample_noise = get_sample_noise(batch_size, sample_dim)
         data_batch_dis_pos = prepare_discriminator_data_batch_pos(
             batch_size, data_np, sample_noise)
         dis_loss_pos = get_training_loss(dis_training_machine, data_batch_dis_pos)

-        sample_noise = get_sample_noise(batch_size)
+        sample_noise = get_sample_noise(batch_size, sample_dim)
         data_batch_dis_neg = prepare_discriminator_data_batch_neg(
             generator_machine, batch_size, noise, sample_noise)
         dis_loss_neg = get_training_loss(dis_training_machine, data_batch_dis_neg)
@@ -271,7 +290,7 @@ def main():
         fake_samples = get_fake_samples(generator_machine, batch_size, noise)
-        saveImages(fake_samples, "train_pass%s.png" % train_pass)
+        saveImages(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
     dis_trainer.finishTrain()
     gen_trainer.finishTrain()
...