diff --git a/configs/det/det_mv3_db.yml b/configs/det/det_mv3_db.yml index 1fab509d12167f0cfa3bb77cf21173c68af55737..6edf0b9194ee59143e287394f505b60010ec6644 100644 --- a/configs/det/det_mv3_db.yml +++ b/configs/det/det_mv3_db.yml @@ -1,5 +1,6 @@ Global: use_gpu: true + use_xpu: false epoch_num: 1200 log_smooth_window: 20 print_batch_step: 10 diff --git a/tools/program.py b/tools/program.py index c5b0e69b2d7256a1efe6b13efeea265cfcb3f5df..e92bef330056a2fe5ca53ed31f02422f43bbee4c 100755 --- a/tools/program.py +++ b/tools/program.py @@ -130,6 +130,25 @@ def check_gpu(use_gpu): pass +def check_xpu(use_xpu): + """ + Log error and exit when set use_xpu=true in paddlepaddle + cpu/gpu version. + """ + err = "Config use_xpu cannot be set as true while you are " \ + "using paddlepaddle cpu/gpu version ! \nPlease try: \n" \ + "\t1. Install paddlepaddle-xpu to run model on XPU \n" \ + "\t2. Set use_xpu as false in config file to run " \ + "model on CPU/GPU" + + try: + if use_xpu and not paddle.is_compiled_with_xpu(): + print(err) + sys.exit(1) + except Exception as e: + pass + + def train(config, train_dataloader, valid_dataloader, @@ -512,6 +531,12 @@ def preprocess(is_train=False): use_gpu = config['Global']['use_gpu'] check_gpu(use_gpu) + # check if set use_xpu=True in paddlepaddle cpu/gpu version + use_xpu = False + if 'use_xpu' in config['Global']: + use_xpu = config['Global']['use_xpu'] + check_xpu(use_xpu) + alg = config['Architecture']['algorithm'] assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', @@ -519,7 +544,11 @@ def preprocess(is_train=False): 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM' ] - device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' + device = 'cpu' + if use_gpu: + device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) + if use_xpu: + device = 'xpu' device = paddle.set_device(device) config['Global']['distributed'] = dist.get_world_size() != 1