diff --git a/tools/eval_mot.py b/tools/eval_mot.py index cce7085171e296760280532bfb8b6acacf4c8011..7769892730adbb52e71751d3ed9de97fe70e6ecb 100644 --- a/tools/eval_mot.py +++ b/tools/eval_mot.py @@ -28,11 +28,11 @@ import warnings warnings.filterwarnings('ignore') import paddle -from paddle.distributed import ParallelEnv + from ppdet.core.workspace import load_config, merge_config -from ppdet.engine import Tracker -from ppdet.utils.check import check_gpu, check_version, check_config +from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_version, check_config from ppdet.utils.cli import ArgsParser +from ppdet.engine import Tracker def parse_args(): @@ -104,9 +104,9 @@ def main(): cfg = load_config(FLAGS.config) merge_config(FLAGS.opt) - check_config(cfg) - check_gpu(cfg.use_gpu) - check_version() + # disable npu in config by default + if 'use_npu' not in cfg: + cfg.use_npu = False # disable xpu in config by default if 'use_xpu' not in cfg: @@ -114,11 +114,22 @@ def main(): if cfg.use_gpu: place = paddle.set_device('gpu') + elif cfg.use_npu: + place = paddle.set_device('npu') elif cfg.use_xpu: place = paddle.set_device('xpu') else: place = paddle.set_device('cpu') + if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn' and not cfg.use_gpu: + cfg['norm_type'] = 'bn' + + check_config(cfg) + check_gpu(cfg.use_gpu) + check_npu(cfg.use_npu) + check_xpu(cfg.use_xpu) + check_version() + run(FLAGS, cfg) diff --git a/tools/infer_mot.py b/tools/infer_mot.py index aa2e3f88fa56c88d46fa4e83f4bd781d367a6e84..9ede62b617b30b9ffd7aea701eb85f04747950ac 100644 --- a/tools/infer_mot.py +++ b/tools/infer_mot.py @@ -28,10 +28,9 @@ import warnings warnings.filterwarnings('ignore') import paddle -from paddle.distributed import ParallelEnv from ppdet.core.workspace import load_config, merge_config from ppdet.engine import Tracker -from ppdet.utils.check import check_gpu, check_version, check_config +from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_version, check_config from ppdet.utils.cli import ArgsParser @@ -117,12 +116,32 @@ def main(): cfg = load_config(FLAGS.config) merge_config(FLAGS.opt) + # disable npu in config by default + if 'use_npu' not in cfg: + cfg.use_npu = False + + # disable xpu in config by default + if 'use_xpu' not in cfg: + cfg.use_xpu = False + + if cfg.use_gpu: + place = paddle.set_device('gpu') + elif cfg.use_npu: + place = paddle.set_device('npu') + elif cfg.use_xpu: + place = paddle.set_device('xpu') + else: + place = paddle.set_device('cpu') + + if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn' and not cfg.use_gpu: + cfg['norm_type'] = 'bn' + check_config(cfg) check_gpu(cfg.use_gpu) + check_npu(cfg.use_npu) + check_xpu(cfg.use_xpu) check_version() - place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu' - place = paddle.set_device(place) run(FLAGS, cfg)