From 066d53f8ecec279913f886527472eceb3d7a774c Mon Sep 17 00:00:00 2001 From: QingshuChen Date: Tue, 8 Dec 2020 20:59:23 +0800 Subject: [PATCH] support cpu/xpu/gpu in static graph (#460) --- tools/static/train.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tools/static/train.py b/tools/static/train.py index 2b44befa..e3ece9c7 100644 --- a/tools/static/train.py +++ b/tools/static/train.py @@ -63,15 +63,18 @@ def main(args): config = get_config(args.config, overrides=args.override, show=True) # assign the place - use_gpu = config.get("use_gpu", True) + use_gpu = config.get("use_gpu", False) use_xpu = config.get("use_xpu", False) - assert (use_gpu or use_xpu - ) is True, "gpu or xpu must be true in static mode!" assert ( use_gpu and use_xpu ) is not True, "gpu and xpu can not be true in the same time in static mode!" - place = paddle.set_device('gpu' if use_gpu else 'xpu') + if use_gpu: + place = paddle.set_device('gpu') + elif use_xpu: + place = paddle.set_device('xpu') + else: + place = paddle.set_device('cpu') # startup_prog is used to do some parameter init work, # and train prog is used to hold the network -- GitLab