未验证 提交 f9f0d30e 编写于 作者: A arlesniak 提交者: GitHub

Enable CPU training for DyGraph MNIST Resnet (#4824)

上级 64cde5d1
...@@ -35,6 +35,12 @@ def parse_args(): ...@@ -35,6 +35,12 @@ def parse_args():
) )
parser.add_argument("-e", "--epoch", default=5, type=int, help="set epoch") parser.add_argument("-e", "--epoch", default=5, type=int, help="set epoch")
parser.add_argument("--ce", action="store_true", help="run ce") parser.add_argument("--ce", action="store_true", help="run ce")
parser.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=True,
help='default use gpu.')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -149,8 +155,13 @@ def test_mnist(reader, model, batch_size): ...@@ -149,8 +155,13 @@ def test_mnist(reader, model, batch_size):
def inference_mnist(): def inference_mnist():
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \ if not args.use_gpu:
if args.use_data_parallel else fluid.CUDAPlace(0) place = fluid.CPUPlace()
elif not args.use_data_parallel:
place = fluid.CUDAPlace(0)
else:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
mnist_infer = MNIST() mnist_infer = MNIST()
# load checkpoint # load checkpoint
...@@ -180,8 +191,13 @@ def train_mnist(args): ...@@ -180,8 +191,13 @@ def train_mnist(args):
epoch_num = args.epoch epoch_num = args.epoch
BATCH_SIZE = 64 BATCH_SIZE = 64
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \ if not args.use_gpu:
if args.use_data_parallel else fluid.CUDAPlace(0) place = fluid.CPUPlace()
elif not args.use_data_parallel:
place = fluid.CUDAPlace(0)
else:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
if args.ce: if args.ce:
print("ce mode") print("ce mode")
......
...@@ -130,6 +130,11 @@ def parse_args(): ...@@ -130,6 +130,11 @@ def parse_args():
type=float, type=float,
default=[0.229, 0.224, 0.225], default=[0.229, 0.224, 0.225],
help="The std of input image data") help="The std of input image data")
parser.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=True,
help='default use gpu.')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -354,8 +359,14 @@ def eval(model, data): ...@@ -354,8 +359,14 @@ def eval(model, data):
def train_resnet(): def train_resnet():
epoch = args.epoch epoch = args.epoch
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0) if not args.use_gpu:
place = fluid.CPUPlace()
elif not args.use_data_parallel:
place = fluid.CUDAPlace(0)
else:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
if args.ce: if args.ce:
print("ce mode") print("ce mode")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册