未验证 提交 18cbd6a7 编写于 作者: F Feng Ni 提交者: GitHub

[MOT] support xpu npu MOT infer (#5585)

上级 fd9b5c61
...@@ -28,11 +28,11 @@ import warnings ...@@ -28,11 +28,11 @@ import warnings
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
import paddle import paddle
from paddle.distributed import ParallelEnv
from ppdet.core.workspace import load_config, merge_config from ppdet.core.workspace import load_config, merge_config
from ppdet.engine import Tracker from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_version, check_config
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser from ppdet.utils.cli import ArgsParser
from ppdet.engine import Tracker
def parse_args(): def parse_args():
...@@ -104,12 +104,32 @@ def main(): ...@@ -104,12 +104,32 @@ def main():
cfg = load_config(FLAGS.config) cfg = load_config(FLAGS.config)
merge_config(FLAGS.opt) merge_config(FLAGS.opt)
# disable npu in config by default
if 'use_npu' not in cfg:
cfg.use_npu = False
# disable xpu in config by default
if 'use_xpu' not in cfg:
cfg.use_xpu = False
if cfg.use_gpu:
place = paddle.set_device('gpu')
elif cfg.use_npu:
place = paddle.set_device('npu')
elif cfg.use_xpu:
place = paddle.set_device('xpu')
else:
place = paddle.set_device('cpu')
if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn' and not cfg.use_gpu:
cfg['norm_type'] = 'bn'
check_config(cfg) check_config(cfg)
check_gpu(cfg.use_gpu) check_gpu(cfg.use_gpu)
check_npu(cfg.use_npu)
check_xpu(cfg.use_xpu)
check_version() check_version()
place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu'
place = paddle.set_device(place)
run(FLAGS, cfg) run(FLAGS, cfg)
......
...@@ -28,10 +28,9 @@ import warnings ...@@ -28,10 +28,9 @@ import warnings
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
import paddle import paddle
from paddle.distributed import ParallelEnv
from ppdet.core.workspace import load_config, merge_config from ppdet.core.workspace import load_config, merge_config
from ppdet.engine import Tracker from ppdet.engine import Tracker
from ppdet.utils.check import check_gpu, check_version, check_config from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_version, check_config
from ppdet.utils.cli import ArgsParser from ppdet.utils.cli import ArgsParser
...@@ -117,12 +116,32 @@ def main(): ...@@ -117,12 +116,32 @@ def main():
cfg = load_config(FLAGS.config) cfg = load_config(FLAGS.config)
merge_config(FLAGS.opt) merge_config(FLAGS.opt)
# disable npu in config by default
if 'use_npu' not in cfg:
cfg.use_npu = False
# disable xpu in config by default
if 'use_xpu' not in cfg:
cfg.use_xpu = False
if cfg.use_gpu:
place = paddle.set_device('gpu')
elif cfg.use_npu:
place = paddle.set_device('npu')
elif cfg.use_xpu:
place = paddle.set_device('xpu')
else:
place = paddle.set_device('cpu')
if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn' and not cfg.use_gpu:
cfg['norm_type'] = 'bn'
check_config(cfg) check_config(cfg)
check_gpu(cfg.use_gpu) check_gpu(cfg.use_gpu)
check_npu(cfg.use_npu)
check_xpu(cfg.use_xpu)
check_version() check_version()
place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu'
place = paddle.set_device(place)
run(FLAGS, cfg) run(FLAGS, cfg)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册