GAN image demo: add Gaussian "sample noise" to the discriminator input.

Summary of what the hunks below do:
  * gan_conf_image.py: declares a new "sample_noise" data layer and feeds
    addto_layer([sample, sample_noise]) to the discriminator instead of the
    raw sample.
  * gan_trainer_image.py: adds get_sample_noise(), passes the noise as an
    extra input slot (label moves from slot 1 to slot 2) in all three
    prepare_*_data_batch helpers, builds the generator input directly in
    get_fake_samples(), and compares dis_loss_neg (not dis_loss) with
    gen_loss when choosing which network to train next.

NOTE(review): this patch was reconstructed from a copy in which all line
breaks and indentation had been collapsed. Line boundaries were re-derived
from the hunk headers (old/new line counts and start offsets all match).
Indentation, continuation-line alignment, trailing whitespace and the exact
content of blank/whitespace-only lines are inferred, not recovered -- verify
against the target files, or apply with whitespace-insensitive matching.

diff --git a/demo/gan/gan_conf_image.py b/demo/gan/gan_conf_image.py
index 83fa34fabcf83b5c55a853d2800620e086b8045a..e811bb96e88aefcd5ec510ea5c77ca4f35adfc47 100644
--- a/demo/gan/gan_conf_image.py
+++ b/demo/gan/gan_conf_image.py
@@ -232,8 +232,13 @@ if is_discriminator_training:
     sample = data_layer(name="sample", size=sample_dim * sample_dim*c_dim)
 
 if is_generator_training or is_discriminator_training:
+    sample_noise = data_layer(name="sample_noise",
+                              size=sample_dim * sample_dim * c_dim)
     label = data_layer(name="label", size=1)
-    prob = discriminator(sample)
+    prob = discriminator(addto_layer([sample, sample_noise],
+                                     act=LinearActivation(),
+                                     name="add",
+                                     bias_attr=False))
     cost = cross_entropy(input=prob, label=label)
     classification_error_evaluator(input=prob, label=label, name=mode+'_error')
     outputs(cost)
diff --git a/demo/gan/gan_trainer_image.py b/demo/gan/gan_trainer_image.py
index 51bbbe8f1f4e1a860a076b41402b45a026be189b..9c7ddd4796fd62c2e1f6877cb3a74378be5dcab1 100644
--- a/demo/gan/gan_trainer_image.py
+++ b/demo/gan/gan_trainer_image.py
@@ -115,9 +115,13 @@ def get_real_samples(batch_size, data_np):
 def get_noise(batch_size, noise_dim):
     return numpy.random.normal(size=(batch_size, noise_dim)).astype('float32')
 
+def get_sample_noise(batch_size):
+    return numpy.random.normal(size=(batch_size, 28*28),
+                               scale=0.1).astype('float32')
+
 def get_fake_samples(generator_machine, batch_size, noise):
-    gen_inputs = prepare_generator_data_batch(batch_size, noise)
-    gen_inputs.resize(1)
+    gen_inputs = api.Arguments.createArguments(1)
+    gen_inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(noise))
     gen_outputs = api.Arguments.createArguments(0)
     generator_machine.forward(gen_inputs, gen_outputs, api.PASS_TEST)
     fake_samples = gen_outputs.getSlotValue(0).copyToNumpyMat()
@@ -129,29 +133,33 @@ def get_training_loss(training_machine, inputs):
     loss = outputs.getSlotValue(0).copyToNumpyMat()
     return numpy.mean(loss)
 
-def prepare_discriminator_data_batch_pos(batch_size, data_np):
+def prepare_discriminator_data_batch_pos(batch_size, data_np, sample_noise):
     real_samples = get_real_samples(batch_size, data_np)
     labels = numpy.ones(batch_size, dtype='int32')
-    inputs = api.Arguments.createArguments(2)
+    inputs = api.Arguments.createArguments(3)
     inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(real_samples))
-    inputs.setSlotIds(1, api.IVector.createGpuVectorFromNumpy(labels))
+    inputs.setSlotValue(1, api.Matrix.createGpuDenseFromNumpy(sample_noise))
+    inputs.setSlotIds(2, api.IVector.createGpuVectorFromNumpy(labels))
     return inputs
 
-def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise):
+def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise,
+                                         sample_noise):
     fake_samples = get_fake_samples(generator_machine, batch_size, noise)
     #print fake_samples.shape
     labels = numpy.zeros(batch_size, dtype='int32')
-    inputs = api.Arguments.createArguments(2)
+    inputs = api.Arguments.createArguments(3)
     inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(fake_samples))
-    inputs.setSlotIds(1, api.IVector.createGpuVectorFromNumpy(labels))
+    inputs.setSlotValue(1, api.Matrix.createGpuDenseFromNumpy(sample_noise))
+    inputs.setSlotIds(2, api.IVector.createGpuVectorFromNumpy(labels))
     return inputs
 
-def prepare_generator_data_batch(batch_size, noise):
+def prepare_generator_data_batch(batch_size, noise, sample_noise):
     label = numpy.ones(batch_size, dtype='int32')
     #label = numpy.zeros(batch_size, dtype='int32')
-    inputs = api.Arguments.createArguments(2)
+    inputs = api.Arguments.createArguments(3)
     inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(noise))
-    inputs.setSlotIds(1, api.IVector.createGpuVectorFromNumpy(label))
+    inputs.setSlotValue(1, api.Matrix.createGpuDenseFromNumpy(sample_noise))
+    inputs.setSlotIds(2, api.IVector.createGpuVectorFromNumpy(label))
     return inputs
 
 
@@ -216,25 +224,27 @@ def main():
             #     generator_machine, batch_size, noise_dim, sample_dim)
             # dis_loss = get_training_loss(dis_training_machine, data_batch_dis)
             noise = get_noise(batch_size, noise_dim)
+            sample_noise = get_sample_noise(batch_size)
             data_batch_dis_pos = prepare_discriminator_data_batch_pos(
-                batch_size, data_np)
+                batch_size, data_np, sample_noise)
             dis_loss_pos = get_training_loss(dis_training_machine, data_batch_dis_pos)
-
+
+            sample_noise = get_sample_noise(batch_size)
             data_batch_dis_neg = prepare_discriminator_data_batch_neg(
-                generator_machine, batch_size, noise)
+                generator_machine, batch_size, noise, sample_noise)
             dis_loss_neg = get_training_loss(dis_training_machine, data_batch_dis_neg)
 
             dis_loss = (dis_loss_pos + dis_loss_neg) / 2.0
 
             data_batch_gen = prepare_generator_data_batch(
-                batch_size, noise)
+                batch_size, noise, sample_noise)
             gen_loss = get_training_loss(gen_training_machine, data_batch_gen)
 
             if i % 100 == 0:
                 print "d_pos_loss is %s d_neg_loss is %s" % (dis_loss_pos, dis_loss_neg)
                 print "d_loss is %s g_loss is %s" % (dis_loss, gen_loss)
 
-            if (not (curr_train == "dis" and curr_strike == MAX_strike)) and ((curr_train == "gen" and curr_strike == MAX_strike) or dis_loss > gen_loss):
+            if (not (curr_train == "dis" and curr_strike == MAX_strike)) and ((curr_train == "gen" and curr_strike == MAX_strike) or dis_loss_neg > gen_loss):
                 if curr_train == "dis":
                     curr_strike += 1
                 else: