From 083ff38550a5c232e8472f1bbe93b3b9c7ec6bd5 Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Thu, 12 May 2022 11:08:04 +0800 Subject: [PATCH] polish args (#5941) * polish args * fix import bugs --- ppdet/engine/trainer.py | 2 +- ppdet/utils/cli.py | 7 +++++++ tools/eval.py | 11 ++--------- tools/eval_mot.py | 3 --- tools/export_model.py | 3 --- tools/infer.py | 8 ++------ tools/infer_mot.py | 3 --- tools/train.py | 16 +++------------- 8 files changed, 15 insertions(+), 38 deletions(-) diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index e124855f9..de188e8e1 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -204,7 +204,7 @@ class Trainer(object): classwise = self.cfg['classwise'] if 'classwise' in self.cfg else False if self.cfg.metric == 'COCO' or self.cfg.metric == "SNIPERCOCO": # TODO: bias should be unified - bias = self.cfg['bias'] if 'bias' in self.cfg else 0 + bias = 1 if self.cfg.get('bias', False) else 0 output_eval = self.cfg['output_eval'] \ if 'output_eval' in self.cfg else None save_prediction_only = self.cfg.get('save_prediction_only', False) diff --git a/ppdet/utils/cli.py b/ppdet/utils/cli.py index b8ba59d78..2c5acc0e5 100644 --- a/ppdet/utils/cli.py +++ b/ppdet/utils/cli.py @@ -81,6 +81,13 @@ class ArgsParser(ArgumentParser): return config +def merge_args(config, args, exclude_args=['config', 'opt', 'slim_config']): + for k, v in vars(args).items(): + if k not in exclude_args: + config[k] = v + return config + + def print_total_cfg(config): modules = get_registered_modules() color_tty = ColorTTY() diff --git a/tools/eval.py b/tools/eval.py index 308dd8fbb..3128261b0 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -31,7 +31,7 @@ import paddle from ppdet.core.workspace import load_config, merge_config from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_version, check_config -from ppdet.utils.cli import ArgsParser +from ppdet.utils.cli import ArgsParser, merge_args from ppdet.engine import Trainer, init_parallel_env from ppdet.metrics.coco_utils import json_eval_results from ppdet.slim import build_slim_model @@ -109,11 +109,7 @@ def run(FLAGS, cfg): def main(): FLAGS = parse_args() cfg = load_config(FLAGS.config) - # TODO: bias should be unified - cfg['bias'] = 1 if FLAGS.bias else 0 - cfg['classwise'] = True if FLAGS.classwise else False - cfg['output_eval'] = FLAGS.output_eval - cfg['save_prediction_only'] = FLAGS.save_prediction_only + merge_args(cfg, FLAGS) merge_config(FLAGS.opt) # disable npu in config by default @@ -133,9 +129,6 @@ def main(): else: place = paddle.set_device('cpu') - if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn' and not cfg.use_gpu: - cfg['norm_type'] = 'bn' - if FLAGS.slim_config: cfg = build_slim_model(cfg, FLAGS.slim_config, mode='eval') diff --git a/tools/eval_mot.py b/tools/eval_mot.py index 776989273..a9ca51703 100644 --- a/tools/eval_mot.py +++ b/tools/eval_mot.py @@ -121,9 +121,6 @@ def main(): else: place = paddle.set_device('cpu') - if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn' and not cfg.use_gpu: - cfg['norm_type'] = 'bn' - check_config(cfg) check_gpu(cfg.use_gpu) check_npu(cfg.use_npu) diff --git a/tools/export_model.py b/tools/export_model.py index 3a417a37c..1eaac7a76 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -90,9 +90,6 @@ def main(): paddle.set_device("cpu") FLAGS = parse_args() cfg = load_config(FLAGS.config) - # TODO: to be refined in the future - if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn': - FLAGS.opt['norm_type'] = 'bn' merge_config(FLAGS.opt) if FLAGS.slim_config: diff --git a/tools/infer.py b/tools/infer.py index 1980ced03..3a5674e7b 100755 --- a/tools/infer.py +++ b/tools/infer.py @@ -32,7 +32,7 @@ import paddle from ppdet.core.workspace import load_config, merge_config from ppdet.engine import Trainer from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_version, check_config -from ppdet.utils.cli import ArgsParser +from ppdet.utils.cli import ArgsParser, merge_args from ppdet.slim import build_slim_model from ppdet.utils.logger import setup_logger @@ -137,8 +137,7 @@ def run(FLAGS, cfg): def main(): FLAGS = parse_args() cfg = load_config(FLAGS.config) - cfg['use_vdl'] = FLAGS.use_vdl - cfg['vdl_log_dir'] = FLAGS.vdl_log_dir + merge_args(cfg, FLAGS) merge_config(FLAGS.opt) # disable npu in config by default @@ -158,9 +157,6 @@ def main(): else: place = paddle.set_device('cpu') - if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn' and not cfg.use_gpu: - cfg['norm_type'] = 'bn' - if FLAGS.slim_config: cfg = build_slim_model(cfg, FLAGS.slim_config, mode='test') diff --git a/tools/infer_mot.py b/tools/infer_mot.py index 9ede62b61..ef13bff93 100644 --- a/tools/infer_mot.py +++ b/tools/infer_mot.py @@ -133,9 +133,6 @@ def main(): else: place = paddle.set_device('cpu') - if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn' and not cfg.use_gpu: - cfg['norm_type'] = 'bn' - check_config(cfg) check_gpu(cfg.use_gpu) check_npu(cfg.use_npu) diff --git a/tools/train.py b/tools/train.py index ddbf24fda..8e4977e77 100755 --- a/tools/train.py +++ b/tools/train.py @@ -33,14 +33,14 @@ from ppdet.core.workspace import load_config, merge_config from ppdet.engine import Trainer, init_parallel_env, set_random_seed, init_fleet_env from ppdet.slim import build_slim_model -import ppdet.utils.cli as cli +from ppdet.utils.cli import ArgsParser, merge_args import ppdet.utils.check as check from ppdet.utils.logger import setup_logger logger = setup_logger('train') def parse_args(): - parser = cli.ArgsParser() + parser = ArgsParser() parser.add_argument( "--eval", action='store_true', @@ -130,14 +130,7 @@ def run(FLAGS, cfg): def main(): FLAGS = parse_args() cfg = load_config(FLAGS.config) - cfg['amp'] = FLAGS.amp - cfg['fleet'] = FLAGS.fleet - cfg['use_vdl'] = FLAGS.use_vdl - cfg['vdl_log_dir'] = FLAGS.vdl_log_dir - cfg['save_prediction_only'] = FLAGS.save_prediction_only - cfg['profiler_options'] = FLAGS.profiler_options - cfg['save_proposals'] = FLAGS.save_proposals - cfg['proposals_path'] = FLAGS.proposals_path + merge_args(cfg, FLAGS) merge_config(FLAGS.opt) # disable npu in config by default @@ -157,9 +150,6 @@ def main(): else: place = paddle.set_device('cpu') - if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn' and not cfg.use_gpu: - cfg['norm_type'] = 'bn' - if FLAGS.slim_config: cfg = build_slim_model(cfg, FLAGS.slim_config) -- GitLab