From 78582b997e1444b5e51908e640611fdf55fef6fd Mon Sep 17 00:00:00 2001
From: gaotingquan
Date: Fri, 28 Oct 2022 11:28:11 +0000
Subject: [PATCH] fix: replace use_gpu, etc. by device

---
 ppcls/static/train.py | 24 +++++-------------------
 1 file changed, 5 insertions(+), 19 deletions(-)

diff --git a/ppcls/static/train.py b/ppcls/static/train.py
index c58aaf52..eed68d38 100644
--- a/ppcls/static/train.py
+++ b/ppcls/static/train.py
@@ -76,8 +76,12 @@ def main(args):
     if global_config.get("is_distributed", True):
         fleet.init(is_collective=True)
 
+    # assign the device
-    use_gpu = global_config.get("use_gpu", True)
+    assert global_config[
+        "device"] in ["cpu", "gpu", "xpu", "npu", "mlu", "ascend"]
+    device = paddle.set_device(global_config["device"])
+
 
     # amp related config
     if 'AMP' in config:
         AMP_RELATED_FLAGS_SETTING = {
@@ -89,24 +93,6 @@
         os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '1'
         paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
 
-    use_xpu = global_config.get("use_xpu", False)
-    use_npu = global_config.get("use_npu", False)
-    use_mlu = global_config.get("use_mlu", False)
-    assert (
-        use_gpu + use_xpu + use_npu + use_mlu + use_ascend <= 1
-    ), "gpu, xpu, npu, mlu and ascend can not be true in the same time in static mode!"
-
-    if use_gpu:
-        device = paddle.set_device('gpu')
-    elif use_xpu:
-        device = paddle.set_device('xpu')
-    elif use_npu:
-        device = paddle.set_device('npu')
-    elif use_mlu:
-        device = paddle.set_device('mlu')
-    else:
-        device = paddle.set_device('cpu')
-
     # visualDL
     vdl_writer = None
     if global_config["use_visualdl"]:
--
GitLab
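
Note for anyone migrating their own configs past this commit: the patch
replaces the per-backend booleans (use_gpu, use_xpu, use_npu, use_mlu) with a
single Global "device" string that is validated once and passed straight to
paddle.set_device. Below is a minimal, self-contained sketch of that new
selection logic, assuming only a plain dict as a stand-in for
config["Global"] (the sample value "cpu" is hypothetical):

    import paddle

    # Hypothetical stand-in for config["Global"], normally parsed from YAML.
    global_config = {"device": "cpu"}

    # Validate the single "device" key, mirroring the assert added in the patch.
    assert global_config["device"] in ["cpu", "gpu", "xpu", "npu", "mlu", "ascend"]

    # paddle.set_device binds the global place and returns it,
    # e.g. Place(cpu) or Place(gpu:0).
    device = paddle.set_device(global_config["device"])
    print(device)

One string key replaces four mutually exclusive flags, so the if/elif chain
and the cross-flag assert in the old code are no longer needed.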