diff --git a/ppdet/utils/cli.py b/ppdet/utils/cli.py index 7b7a89060a7de3fff7168d9773b5aae056fad011..1bec22894bbca2e87a6de75fc03c699ef31e89ce 100644 --- a/ppdet/utils/cli.py +++ b/ppdet/utils/cli.py @@ -15,6 +15,8 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter import yaml +import re +from ppdet.core.workspace import get_registered_modules __all__ = ['ColorTTY', 'ArgsParser'] @@ -42,13 +44,12 @@ class ColorTTY(object): class ArgsParser(ArgumentParser): - def __init__(self): super(ArgsParser, self).__init__( formatter_class=RawDescriptionHelpFormatter) self.add_argument("-c", "--config", help="configuration file to use") - self.add_argument("-o", "--opt", nargs='*', - help="set configuration options") + self.add_argument( + "-o", "--opt", nargs='*', help="set configuration options") def parse_args(self, argv=None): args = super(ArgsParser, self).parse_args(argv) @@ -78,3 +79,73 @@ class ArgsParser(ArgumentParser): cur[key] = {} cur = cur[key] return config + + +def print_total_cfg(config): + modules = get_registered_modules() + color_tty = ColorTTY() + green = '___{}___'.format(color_tty.colors.index('green') + 31) + + styled = {} + for key in config.keys(): + if not config[key]: # empty schema + continue + + if key not in modules and not hasattr(config[key], '__dict__'): + styled[key] = config[key] + continue + elif key in modules: + module = modules[key] + else: + type_name = type(config[key]).__name__ + if type_name in modules: + module = modules[type_name].copy() + module.update({ + k: v + for k, v in config[key].__dict__.items() + if k in module.schema + }) + key += " ({})".format(type_name) + default = module.find_default_keys() + missing = module.find_missing_keys() + mismatch = module.find_mismatch_keys() + extra = module.find_extra_keys() + dep_missing = [] + for dep in module.inject: + if isinstance(module[dep], str) and module[dep] != '': + if module[dep] not in modules: # not a valid module + dep_missing.append(dep) + else: + dep_mod = modules[module[dep]] + # empty dict but mandatory + if not dep_mod and dep_mod.mandatory(): + dep_missing.append(dep) + override = list( + set(module.keys()) - set(default) - set(extra) - set(dep_missing)) + replacement = {} + for name in set(override + default + extra + mismatch + missing): + new_name = name + if name in missing: + value = "" + else: + value = module[name] + + if name in extra: + value = dump_value(value) + " " + elif name in mismatch: + value = dump_value(value) + " " + elif name in dep_missing: + value = dump_value(value) + " " + elif name in override and value != '': + mark = green + new_name = mark + name + replacement[new_name] = value + styled[key] = replacement + buffer = yaml.dump(styled, default_flow_style=False, default_style='') + buffer = (re.sub(r"", r"[31m[0m", buffer)) + buffer = (re.sub(r"", r"[33m[0m", buffer)) + buffer = (re.sub(r"", r"[31m[0m", buffer)) + buffer = (re.sub(r"", + r"[31m[0m", buffer)) + buffer = re.sub(r"___(\d+)___(.*?):", r"[\1m\2[0m:", buffer) + print(buffer) diff --git a/tools/configure.py b/tools/configure.py index 4a565aed9da87d38af7442771160f9c18ee21f73..3bfc8b83ad21cb94149c51fb698ef7ced7e72f86 100644 --- a/tools/configure.py +++ b/tools/configure.py @@ -14,14 +14,13 @@ from __future__ import print_function -import re import sys from argparse import ArgumentParser, RawDescriptionHelpFormatter import yaml from ppdet.core.workspace import get_registered_modules, load_config -from ppdet.utils.cli import ColorTTY +from ppdet.utils.cli import ColorTTY, print_total_cfg color_tty = ColorTTY() @@ -151,75 +150,6 @@ def generate_config(**kwargs): print(dump_config(s, minimal)) -def print_total_cfg(config): - modules = get_registered_modules() - green = '___{}___'.format(color_tty.colors.index('green') + 31) - - styled = {} - for key in config.keys(): - if not config[key]: # empty schema - continue - - if key not in modules and not hasattr(config[key], '__dict__'): - styled[key] = config[key] - continue - elif key in modules: - module = modules[key] - else: - type_name = type(config[key]).__name__ - if type_name in modules: - module = modules[type_name].copy() - module.update({ - k: v - for k, v in config[key].__dict__.items() - if k in module.schema - }) - key += " ({})".format(type_name) - default = module.find_default_keys() - missing = module.find_missing_keys() - mismatch = module.find_mismatch_keys() - extra = module.find_extra_keys() - dep_missing = [] - for dep in module.inject: - if isinstance(module[dep], str) and module[dep] != '': - if module[dep] not in modules: # not a valid module - dep_missing.append(dep) - else: - dep_mod = modules[module[dep]] - # empty dict but mandatory - if not dep_mod and dep_mod.mandatory(): - dep_missing.append(dep) - override = list( - set(module.keys()) - set(default) - set(extra) - set(dep_missing)) - replacement = {} - for name in set(override + default + extra + mismatch + missing): - new_name = name - if name in missing: - value = "" - else: - value = module[name] - - if name in extra: - value = dump_value(value) + " " - elif name in mismatch: - value = dump_value(value) + " " - elif name in dep_missing: - value = dump_value(value) + " " - elif name in override and value != '': - mark = green - new_name = mark + name - replacement[new_name] = value - styled[key] = replacement - buffer = yaml.dump(styled, default_flow_style=False, default_style='') - buffer = (re.sub(r"", r"", buffer)) - buffer = (re.sub(r"", r"", buffer)) - buffer = (re.sub(r"", r"", buffer)) - buffer = (re.sub(r"", - r"", buffer)) - buffer = re.sub(r"___(\d+)___(.*?):", r"[\1m\2:", buffer) - print(buffer) - - # FIXME this is pretty hackish, maybe implement a custom YAML printer? def analyze_config(**kwargs): config = load_config(kwargs['file']) diff --git a/tools/eval.py b/tools/eval.py index 089beec22701facabd1dc6ea2b5373fc9069e0a0..5e596f020ec3018690cd24179ae00611c76e1910 100644 --- a/tools/eval.py +++ b/tools/eval.py @@ -33,7 +33,7 @@ set_paddle_flags( import paddle.fluid as fluid -from tools.configure import print_total_cfg +from ppdet.utils.cli import print_total_cfg from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results, json_eval_results import ppdet.utils.checkpoint as checkpoint from ppdet.utils.cli import ArgsParser diff --git a/tools/infer.py b/tools/infer.py index b6a4d66a0ac2e007b40521dd3761d726d5a51aa5..32e6040d2f3db5768d5191babc45e65368517466 100644 --- a/tools/infer.py +++ b/tools/infer.py @@ -37,7 +37,7 @@ set_paddle_flags( from paddle import fluid -from tools.configure import print_total_cfg +from ppdet.utils.cli import print_total_cfg from ppdet.core.workspace import load_config, merge_config, create from ppdet.modeling.model_input import create_feed from ppdet.data.data_feed import create_reader diff --git a/tools/train.py b/tools/train.py index 52062f33853bff574791c580cbfacae2a974bf77..fc14fa13ead7f7daea967abdb1f33c99c51d7089 100644 --- a/tools/train.py +++ b/tools/train.py @@ -21,7 +21,6 @@ import time import numpy as np import datetime from collections import deque -from tools.configure import print_total_cfg def set_paddle_flags(**kwargs): @@ -40,6 +39,7 @@ from paddle import fluid from ppdet.core.workspace import load_config, merge_config, create from ppdet.data.data_feed import create_reader +from ppdet.utils.cli import print_total_cfg from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results from ppdet.utils.stats import TrainingStats from ppdet.utils.cli import ArgsParser @@ -142,7 +142,8 @@ def main(): train_compile_program = fluid.compiler.CompiledProgram( train_prog).with_data_parallel( - loss_name=loss.name, build_strategy=build_strategy, + loss_name=loss.name, + build_strategy=build_strategy, exec_strategy=exec_strategy) if FLAGS.eval: eval_compile_program = fluid.compiler.CompiledProgram(eval_prog) @@ -159,13 +160,10 @@ def main(): elif cfg.pretrain_weights: checkpoint.load_pretrain(exe, train_prog, cfg.pretrain_weights) - train_reader = create_reader( - train_feed, - (cfg.max_iters - start_iter) * devices_num, - FLAGS.dataset_dir) + train_reader = create_reader(train_feed, (cfg.max_iters - start_iter) * + devices_num, FLAGS.dataset_dir) train_pyreader.decorate_sample_list_generator(train_reader, place) - # whether output bbox is normalized in model output layer is_bbox_normalized = False if hasattr(model, 'is_bbox_normalized') and \ @@ -230,12 +228,12 @@ def main(): box_ap_stats = eval_results( results, eval_feed, cfg.metric, cfg.num_classes, resolution, is_bbox_normalized, FLAGS.output_eval, map_type) - + # use tb_paddle to log mAP if FLAGS.use_tb: tb_writer.add_scalar("mAP", box_ap_stats[0], tb_mAP_step) tb_mAP_step += 1 - + if box_ap_stats[0] > best_box_ap_list[0]: best_box_ap_list[0] = box_ap_stats[0] best_box_ap_list[1] = it