diff --git a/configs/det/det_mv3_db.yml b/configs/det/det_mv3_db.yml
index 2f39fbd232fa4bcab4cd30622d21c56d11a72d31..8f5685ec2a314c4b6a00c6c636f36b9c9c5daf00 100644
--- a/configs/det/det_mv3_db.yml
+++ b/configs/det/det_mv3_db.yml
@@ -1,6 +1,7 @@
 Global:
   use_gpu: true
   use_xpu: false
+  use_mlu: false
   epoch_num: 1200
   log_smooth_window: 20
   print_batch_step: 10
diff --git a/tools/program.py b/tools/program.py
index 9117d51b95b343c46982f212d4e5faa069b7b44a..8bbd233550a165e7ab4206209dcce5c539b1cedc 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -114,7 +114,7 @@ def merge_config(config, opts):
     return config


-def check_device(use_gpu, use_xpu=False, use_npu=False):
+def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
     """
     Log error and exit when set use_gpu=true in paddlepaddle
     cpu version.
@@ -137,6 +137,9 @@ def check_device(use_gpu, use_xpu=False, use_npu=False):
         if use_npu and not paddle.device.is_compiled_with_npu():
             print(err.format("use_npu", "npu", "npu", "use_npu"))
             sys.exit(1)
+        if use_mlu and not paddle.device.is_compiled_with_mlu():
+            print(err.format("use_mlu", "mlu", "mlu", "use_mlu"))
+            sys.exit(1)
     except Exception as e:
         pass

@@ -618,6 +621,7 @@
     use_gpu = config['Global'].get('use_gpu', False)
     use_xpu = config['Global'].get('use_xpu', False)
     use_npu = config['Global'].get('use_npu', False)
+    use_mlu = config['Global'].get('use_mlu', False)

     alg = config['Architecture']['algorithm']
     assert alg in [
@@ -632,10 +636,12 @@
         device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
     elif use_npu:
         device = 'npu:{0}'.format(os.getenv('FLAGS_selected_npus', 0))
+    elif use_mlu:
+        device = 'mlu:{0}'.format(os.getenv('FLAGS_selected_mlus', 0))
     else:
         device = 'gpu:{}'.format(dist.ParallelEnv()
                                  .dev_id) if use_gpu else 'cpu'
-    check_device(use_gpu, use_xpu, use_npu)
+    check_device(use_gpu, use_xpu, use_npu, use_mlu)

     device = paddle.set_device(device)
diff --git a/tools/train.py b/tools/train.py
index 970a52624af7b2831d88956f857cd4271086bcca..ff261e85fec10ec974ff763d6c3747faaa47c8d9 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -149,10 +149,11 @@
     amp_level = config["Global"].get("amp_level", 'O2')
     amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
     if use_amp:
-        AMP_RELATED_FLAGS_SETTING = {
-            'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
-            'FLAGS_max_inplace_grad_add': 8,
-        }
+        AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
+        if paddle.is_compiled_with_cuda():
+            AMP_RELATED_FLAGS_SETTING.update({
+                'FLAGS_cudnn_batchnorm_spatial_persistent': 1
+            })
         paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
         scale_loss = config["Global"].get("scale_loss", 1.0)
         use_dynamic_loss_scaling = config["Global"].get(