提交 4878f078 编写于 作者: W wangyang59

add gpu_id flag in demo/gan

上级 d8aada07
......@@ -184,13 +184,16 @@ def main():
parser.add_argument("-d", "--dataSource", help="mnist or cifar")
parser.add_argument("--useGpu", default="1",
help="1 means use gpu for training")
parser.add_argument("--gpuId", default="0",
help="the gpu_id parameter")
args = parser.parse_args()
dataSource = args.dataSource
useGpu = args.useGpu
assert dataSource in ["mnist", "cifar"]
assert useGpu in ["0", "1"]
api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100')
api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100',
'--gpu_id=' + args.gpuId)
gen_conf = parse_config("gan_conf_image.py", "mode=generator_training,data=" + dataSource)
dis_conf = parse_config("gan_conf_image.py", "mode=discriminator_training,data=" + dataSource)
generator_conf = parse_config("gan_conf_image.py", "mode=generator,data=" + dataSource)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册