From f9f0d30ef50afd28c812bafb5cfb13d18a9c165e Mon Sep 17 00:00:00 2001 From: arlesniak Date: Tue, 1 Sep 2020 04:20:35 +0200 Subject: [PATCH] Enable CPU training for DyGraph MNIST Resnet (#4824) --- dygraph/mnist/train.py | 24 ++++++++++++++++++++---- dygraph/resnet/train.py | 15 +++++++++++++-- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/dygraph/mnist/train.py b/dygraph/mnist/train.py index 58db6f1d..df7da1db 100644 --- a/dygraph/mnist/train.py +++ b/dygraph/mnist/train.py @@ -35,6 +35,12 @@ def parse_args(): ) parser.add_argument("-e", "--epoch", default=5, type=int, help="set epoch") parser.add_argument("--ce", action="store_true", help="run ce") + parser.add_argument( + '--use_gpu', + type=ast.literal_eval, + default=True, + help='default use gpu.') + args = parser.parse_args() return args @@ -149,8 +155,13 @@ def test_mnist(reader, model, batch_size): def inference_mnist(): - place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \ - if args.use_data_parallel else fluid.CUDAPlace(0) + if not args.use_gpu: + place = fluid.CPUPlace() + elif not args.use_data_parallel: + place = fluid.CUDAPlace(0) + else: + place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) + with fluid.dygraph.guard(place): mnist_infer = MNIST() # load checkpoint @@ -180,8 +191,13 @@ def train_mnist(args): epoch_num = args.epoch BATCH_SIZE = 64 - place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \ - if args.use_data_parallel else fluid.CUDAPlace(0) + if not args.use_gpu: + place = fluid.CPUPlace() + elif not args.use_data_parallel: + place = fluid.CUDAPlace(0) + else: + place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) + with fluid.dygraph.guard(place): if args.ce: print("ce mode") diff --git a/dygraph/resnet/train.py b/dygraph/resnet/train.py index c6a9c888..a31e7455 100644 --- a/dygraph/resnet/train.py +++ b/dygraph/resnet/train.py @@ -130,6 +130,11 @@ def parse_args(): type=float, default=[0.229, 0.224, 0.225], help="The std of input image data") + parser.add_argument( + '--use_gpu', + type=ast.literal_eval, + default=True, + help='default use gpu.') args = parser.parse_args() return args @@ -354,8 +359,14 @@ def eval(model, data): def train_resnet(): epoch = args.epoch - place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \ - if args.use_data_parallel else fluid.CUDAPlace(0) + + if not args.use_gpu: + place = fluid.CPUPlace() + elif not args.use_data_parallel: + place = fluid.CUDAPlace(0) + else: + place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) + with fluid.dygraph.guard(place): if args.ce: print("ce mode") -- GitLab