未验证 提交 aabe1db1 编写于 作者: Y Yu Yang 提交者: GitHub

Feature/simple gan for api (#6149)

* Expose sigmoid_cross_entropy_with_logits

Also, change the `labels` to `label` for api consistency

* Very simple GAN based on pure FC layers
上级 813bbf40
import errno
import math
import os
import matplotlib
import numpy
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
NOISE_SIZE = 100
NUM_PASS = 1000
NUM_REAL_IMGS_IN_BATCH = 121
NUM_TRAIN_TIMES_OF_DG = 3
LEARNING_RATE = 2e-5
def D(x):
hidden = fluid.layers.fc(input=x,
size=200,
act='relu',
param_attr='D.w1',
bias_attr='D.b1')
logits = fluid.layers.fc(input=hidden,
size=1,
act=None,
param_attr='D.w2',
bias_attr='D.b2')
return logits
def G(x):
hidden = fluid.layers.fc(input=x,
size=200,
act='relu',
param_attr='G.w1',
bias_attr='G.b1')
img = fluid.layers.fc(input=hidden,
size=28 * 28,
act='tanh',
param_attr='G.w2',
bias_attr='G.b2')
return img
def plot(gen_data):
gen_data.resize(gen_data.shape[0], 28, 28)
n = int(math.ceil(math.sqrt(gen_data.shape[0])))
fig = plt.figure(figsize=(n, n))
gs = gridspec.GridSpec(n, n)
gs.update(wspace=0.05, hspace=0.05)
for i, sample in enumerate(gen_data):
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
return fig
def main():
try:
os.makedirs("./out")
except OSError as e:
if e.errno != errno.EEXIST:
raise
startup_program = fluid.Program()
d_program = fluid.Program()
dg_program = fluid.Program()
with fluid.program_guard(d_program, startup_program):
img = fluid.layers.data(name='img', shape=[784], dtype='float32')
d_loss = fluid.layers.sigmoid_cross_entropy_with_logits(
x=D(img),
label=fluid.layers.data(
name='label', shape=[1], dtype='float32'))
d_loss = fluid.layers.mean(x=d_loss)
with fluid.program_guard(dg_program, startup_program):
noise = fluid.layers.data(
name='noise', shape=[NOISE_SIZE], dtype='float32')
g_img = G(x=noise)
g_program = dg_program.clone()
dg_loss = fluid.layers.sigmoid_cross_entropy_with_logits(
x=D(g_img),
label=fluid.layers.fill_constant_batch_size_like(
input=noise, dtype='float32', shape=[-1, 1], value=1.0))
dg_loss = fluid.layers.mean(x=dg_loss)
opt = fluid.optimizer.Adam(learning_rate=LEARNING_RATE)
opt.minimize(loss=d_loss, startup_program=startup_program)
opt.minimize(
loss=dg_loss,
startup_program=startup_program,
parameter_list=[
p.name for p in g_program.global_block().all_parameters()
])
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
num_true = NUM_REAL_IMGS_IN_BATCH
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=60000),
batch_size=num_true)
for pass_id in range(NUM_PASS):
for batch_id, data in enumerate(train_reader()):
num_true = len(data)
n = numpy.random.uniform(
low=-1.0, high=1.0,
size=[num_true * NOISE_SIZE]).astype('float32').reshape(
[num_true, NOISE_SIZE])
generated_img = exe.run(g_program,
feed={'noise': n},
fetch_list={g_img})[0]
real_data = numpy.array(map(lambda x: x[0], data)).astype('float32')
real_data = real_data.reshape(num_true, 784)
total_data = numpy.concatenate([real_data, generated_img])
total_label = numpy.concatenate([
numpy.ones(
shape=[real_data.shape[0], 1], dtype='float32'),
numpy.zeros(
shape=[real_data.shape[0], 1], dtype='float32')
])
d_loss_np = exe.run(d_program,
feed={'img': total_data,
'label': total_label},
fetch_list={d_loss})[0]
for _ in xrange(NUM_TRAIN_TIMES_OF_DG):
n = numpy.random.uniform(
low=-1.0, high=1.0,
size=[2 * num_true * NOISE_SIZE]).astype('float32').reshape(
[2 * num_true, NOISE_SIZE, 1, 1])
dg_loss_np = exe.run(dg_program,
feed={'noise': n},
fetch_list={dg_loss})[0]
print("Pass ID={0}, Batch ID={1}, D-Loss={2}, DG-Loss={3}".format(
pass_id, batch_id, d_loss_np, dg_loss_np))
# generate image each batch
fig = plot(generated_img)
plt.savefig(
'out/{0}.png'.format(str(pass_id).zfill(3)), bbox_inches='tight')
plt.close(fig)
if __name__ == '__main__':
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册