From 49ecf9c3bc9e6154360a84f402d8b669580b6dd3 Mon Sep 17 00:00:00 2001 From: zhangyikun02 <1129622649@qq.com> Date: Wed, 23 Feb 2022 08:31:16 +0000 Subject: [PATCH] add use_xpu config for det_mv3_db.yml --- configs/det/det_mv3_db.yml | 1 + tools/program.py | 31 ++++++++++++++++++++++++++++++- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/configs/det/det_mv3_db.yml b/configs/det/det_mv3_db.yml index 1fab509d..6edf0b91 100644 --- a/configs/det/det_mv3_db.yml +++ b/configs/det/det_mv3_db.yml @@ -1,5 +1,6 @@ Global: use_gpu: true + use_xpu: false epoch_num: 1200 log_smooth_window: 20 print_batch_step: 10 diff --git a/tools/program.py b/tools/program.py index c5b0e69b..e92bef33 100755 --- a/tools/program.py +++ b/tools/program.py @@ -130,6 +130,25 @@ def check_gpu(use_gpu): pass +def check_xpu(use_xpu): + """ + Log error and exit when set use_xpu=true in paddlepaddle + cpu/gpu version. + """ + err = "Config use_xpu cannot be set as true while you are " \ + "using paddlepaddle cpu/gpu version ! \nPlease try: \n" \ + "\t1. Install paddlepaddle-xpu to run model on XPU \n" \ + "\t2. Set use_xpu as false in config file to run " \ + "model on CPU/GPU" + + try: + if use_xpu and not paddle.is_compiled_with_xpu(): + print(err) + sys.exit(1) + except Exception as e: + pass + + def train(config, train_dataloader, valid_dataloader, @@ -512,6 +531,12 @@ def preprocess(is_train=False): use_gpu = config['Global']['use_gpu'] check_gpu(use_gpu) + # check if set use_xpu=True in paddlepaddle cpu/gpu version + use_xpu = False + if 'use_xpu' in config['Global']: + use_xpu = config['Global']['use_xpu'] + check_xpu(use_xpu) + alg = config['Architecture']['algorithm'] assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', @@ -519,7 +544,11 @@ def preprocess(is_train=False): 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM' ] - device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' + device = 'cpu' + if use_gpu: + device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) + if use_xpu: + device = 'xpu' device = paddle.set_device(device) config['Global']['distributed'] = dist.get_world_size() != 1 -- GitLab