From cf1b5ce991cd435fdcf8a84d117e119d5d2eb919 Mon Sep 17 00:00:00 2001 From: Feng Ni Date: Tue, 5 Apr 2022 21:08:51 +0800 Subject: [PATCH] [MOT] support xpu npu MOT infer (#5584) --- tools/eval_mot.py | 23 +++++++++++++++++------ tools/infer_mot.py | 27 +++++++++++++++++++++++---- 2 files changed, 40 insertions(+), 10 deletions(-) diff --git a/tools/eval_mot.py b/tools/eval_mot.py index cce708517..776989273 100644 --- a/tools/eval_mot.py +++ b/tools/eval_mot.py @@ -28,11 +28,11 @@ import warnings warnings.filterwarnings('ignore') import paddle -from paddle.distributed import ParallelEnv + from ppdet.core.workspace import load_config, merge_config -from ppdet.engine import Tracker -from ppdet.utils.check import check_gpu, check_version, check_config +from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_version, check_config from ppdet.utils.cli import ArgsParser +from ppdet.engine import Tracker def parse_args(): @@ -104,9 +104,9 @@ def main(): cfg = load_config(FLAGS.config) merge_config(FLAGS.opt) - check_config(cfg) - check_gpu(cfg.use_gpu) - check_version() + # disable npu in config by default + if 'use_npu' not in cfg: + cfg.use_npu = False # disable xpu in config by default if 'use_xpu' not in cfg: @@ -114,11 +114,22 @@ def main(): if cfg.use_gpu: place = paddle.set_device('gpu') + elif cfg.use_npu: + place = paddle.set_device('npu') elif cfg.use_xpu: place = paddle.set_device('xpu') else: place = paddle.set_device('cpu') + if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn' and not cfg.use_gpu: + cfg['norm_type'] = 'bn' + + check_config(cfg) + check_gpu(cfg.use_gpu) + check_npu(cfg.use_npu) + check_xpu(cfg.use_xpu) + check_version() + run(FLAGS, cfg) diff --git a/tools/infer_mot.py b/tools/infer_mot.py index aa2e3f88f..9ede62b61 100644 --- a/tools/infer_mot.py +++ b/tools/infer_mot.py @@ -28,10 +28,9 @@ import warnings warnings.filterwarnings('ignore') import paddle -from paddle.distributed import ParallelEnv from ppdet.core.workspace import load_config, merge_config from ppdet.engine import Tracker -from ppdet.utils.check import check_gpu, check_version, check_config +from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_version, check_config from ppdet.utils.cli import ArgsParser @@ -117,12 +116,32 @@ def main(): cfg = load_config(FLAGS.config) merge_config(FLAGS.opt) + # disable npu in config by default + if 'use_npu' not in cfg: + cfg.use_npu = False + + # disable xpu in config by default + if 'use_xpu' not in cfg: + cfg.use_xpu = False + + if cfg.use_gpu: + place = paddle.set_device('gpu') + elif cfg.use_npu: + place = paddle.set_device('npu') + elif cfg.use_xpu: + place = paddle.set_device('xpu') + else: + place = paddle.set_device('cpu') + + if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn' and not cfg.use_gpu: + cfg['norm_type'] = 'bn' + check_config(cfg) check_gpu(cfg.use_gpu) + check_npu(cfg.use_npu) + check_xpu(cfg.use_xpu) check_version() - place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu' - place = paddle.set_device(place) run(FLAGS, cfg) -- GitLab