diff --git a/tools/static/train.py b/tools/static/train.py index 2b44befade66ea40931f1885d5df872547311f6d..e3ece9c7614bc64104e3e2e2150b8c0a9360ca59 100644 --- a/tools/static/train.py +++ b/tools/static/train.py @@ -63,15 +63,18 @@ def main(args): config = get_config(args.config, overrides=args.override, show=True) # assign the place - use_gpu = config.get("use_gpu", True) + use_gpu = config.get("use_gpu", False) use_xpu = config.get("use_xpu", False) - assert (use_gpu or use_xpu - ) is True, "gpu or xpu must be true in static mode!" assert ( use_gpu and use_xpu ) is not True, "gpu and xpu can not be true in the same time in static mode!" - place = paddle.set_device('gpu' if use_gpu else 'xpu') + if use_gpu: + place = paddle.set_device('gpu') + elif use_xpu: + place = paddle.set_device('xpu') + else: + place = paddle.set_device('cpu') # startup_prog is used to do some parameter init work, # and train prog is used to hold the network