diff --git a/demo/gan/gan_conf_image.py b/demo/gan/gan_conf_image.py index 00711730d56d54ab8ef776d25c33fbf2e9445e89..83fa34fabcf83b5c55a853d2800620e086b8045a 100644 --- a/demo/gan/gan_conf_image.py +++ b/demo/gan/gan_conf_image.py @@ -36,9 +36,9 @@ s2, s4 = int(sample_dim/2), int(sample_dim/4), s8, s16 = int(sample_dim/8), int(sample_dim/16) settings( - batch_size=100, - learning_rate=1e-4, - learning_method=AdamOptimizer() + batch_size=128, + learning_rate=2e-4, + learning_method=AdamOptimizer(beta1=0.5) ) def conv_bn(input, channels, imgSize, num_filters, output_x, stride, name, diff --git a/demo/gan/gan_trainer_image.py b/demo/gan/gan_trainer_image.py index 8f1e17b9c74631b0b8d423f892c0bc44e41507ae..51bbbe8f1f4e1a860a076b41402b45a026be189b 100644 --- a/demo/gan/gan_trainer_image.py +++ b/demo/gan/gan_trainer_image.py @@ -88,7 +88,7 @@ def load_mnist_data(imageFile): for i in range(n): pixels = [] for j in range(28 * 28): - pixels.append(float(ord(f.read(1))) / 255.0) + pixels.append(float(ord(f.read(1))) / 255.0 * 2.0 - 1.0) data[i, :] = pixels f.close()