未验证 提交 f9f0d30e 编写于 作者: A arlesniak 提交者: GitHub

Enable CPU training for DyGraph MNIST Resnet (#4824)

上级 64cde5d1
......@@ -35,6 +35,12 @@ def parse_args():
)
parser.add_argument("-e", "--epoch", default=5, type=int, help="set epoch")
parser.add_argument("--ce", action="store_true", help="run ce")
parser.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=True,
help='default use gpu.')
args = parser.parse_args()
return args
......@@ -149,8 +155,13 @@ def test_mnist(reader, model, batch_size):
def inference_mnist():
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)
if not args.use_gpu:
place = fluid.CPUPlace()
elif not args.use_data_parallel:
place = fluid.CUDAPlace(0)
else:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
with fluid.dygraph.guard(place):
mnist_infer = MNIST()
# load checkpoint
......@@ -180,8 +191,13 @@ def train_mnist(args):
epoch_num = args.epoch
BATCH_SIZE = 64
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)
if not args.use_gpu:
place = fluid.CPUPlace()
elif not args.use_data_parallel:
place = fluid.CUDAPlace(0)
else:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
with fluid.dygraph.guard(place):
if args.ce:
print("ce mode")
......
......@@ -130,6 +130,11 @@ def parse_args():
type=float,
default=[0.229, 0.224, 0.225],
help="The std of input image data")
parser.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=True,
help='default use gpu.')
args = parser.parse_args()
return args
......@@ -354,8 +359,14 @@ def eval(model, data):
def train_resnet():
epoch = args.epoch
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)
if not args.use_gpu:
place = fluid.CPUPlace()
elif not args.use_data_parallel:
place = fluid.CUDAPlace(0)
else:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
with fluid.dygraph.guard(place):
if args.ce:
print("ce mode")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册