diff --git a/demo/gan/gan_trainer_image.py b/demo/gan/gan_trainer_image.py index e8ed218663a8ab346373e1e216bfa77c06e7fbd5..11476621c1a22a98ca382f9b6cda222b4e9319bd 100644 --- a/demo/gan/gan_trainer_image.py +++ b/demo/gan/gan_trainer_image.py @@ -184,13 +184,16 @@ def main(): parser.add_argument("-d", "--dataSource", help="mnist or cifar") parser.add_argument("--useGpu", default="1", help="1 means use gpu for training") + parser.add_argument("--gpuId", default="0", + help="the gpu_id parameter") args = parser.parse_args() dataSource = args.dataSource useGpu = args.useGpu assert dataSource in ["mnist", "cifar"] assert useGpu in ["0", "1"] - api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100') + api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100', + '--gpu_id=' + args.gpuId) gen_conf = parse_config("gan_conf_image.py", "mode=generator_training,data=" + dataSource) dis_conf = parse_config("gan_conf_image.py", "mode=discriminator_training,data=" + dataSource) generator_conf = parse_config("gan_conf_image.py", "mode=generator,data=" + dataSource)