diff --git a/demo/gan/gan_conf_image.py b/demo/gan/gan_conf_image.py index defeccb7eb553af0d303392043cbefe77f794a58..58bed2b18941bfb872c54dc71a539c4c7d13df14 100644 --- a/demo/gan/gan_conf_image.py +++ b/demo/gan/gan_conf_image.py @@ -117,10 +117,12 @@ def generator(noise): """ generator generates a sample given noise """ - param_attr = ParamAttr(is_static=is_discriminator_training) + param_attr = ParamAttr(is_static=is_discriminator_training, + initial_mean=0.0, + initial_std=0.02) bias_attr = ParamAttr(is_static=is_discriminator_training, - initial_mean=1.0, - initial_std=0) + initial_mean=0.0, + initial_std=0.0) param_attr_bn=ParamAttr(is_static=is_discriminator_training, initial_mean=1.0, diff --git a/demo/gan/gan_trainer_image.py b/demo/gan/gan_trainer_image.py index f6c3d2891b0ee3999786dbec7fc904f4e902a51a..536abab9210bfb6bc3e93d3a327aaf2ebb322d96 100644 --- a/demo/gan/gan_trainer_image.py +++ b/demo/gan/gan_trainer_image.py @@ -197,7 +197,7 @@ def main(): curr_train = "dis" curr_strike = 0 - MAX_strike = 100 + MAX_strike = 10 for train_pass in xrange(100): dis_trainer.startTrainPass()