From 4878f0783b91ea64cf553b729eba3cb56f37eac6 Mon Sep 17 00:00:00 2001 From: wangyang59 Date: Wed, 16 Nov 2016 09:54:40 -0800 Subject: [PATCH] add gpu_id flag in demo/gan --- demo/gan/gan_trainer_image.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/demo/gan/gan_trainer_image.py b/demo/gan/gan_trainer_image.py index e8ed218663a..11476621c1a 100644 --- a/demo/gan/gan_trainer_image.py +++ b/demo/gan/gan_trainer_image.py @@ -184,13 +184,16 @@ def main(): parser.add_argument("-d", "--dataSource", help="mnist or cifar") parser.add_argument("--useGpu", default="1", help="1 means use gpu for training") + parser.add_argument("--gpuId", default="0", + help="the gpu_id parameter") args = parser.parse_args() dataSource = args.dataSource useGpu = args.useGpu assert dataSource in ["mnist", "cifar"] assert useGpu in ["0", "1"] - api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100') + api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100', + '--gpu_id=' + args.gpuId) gen_conf = parse_config("gan_conf_image.py", "mode=generator_training,data=" + dataSource) dis_conf = parse_config("gan_conf_image.py", "mode=discriminator_training,data=" + dataSource) generator_conf = parse_config("gan_conf_image.py", "mode=generator,data=" + dataSource) -- GitLab