提交 d6e52889 编写于 作者: W wangguanzhong 提交者: GitHub

Reconstruction for cfg printer (#3362)

上级 81fd696e
......@@ -15,6 +15,8 @@
from argparse import ArgumentParser, RawDescriptionHelpFormatter
import yaml
import re
from ppdet.core.workspace import get_registered_modules
__all__ = ['ColorTTY', 'ArgsParser']
......@@ -42,13 +44,12 @@ class ColorTTY(object):
class ArgsParser(ArgumentParser):
def __init__(self):
super(ArgsParser, self).__init__(
formatter_class=RawDescriptionHelpFormatter)
self.add_argument("-c", "--config", help="configuration file to use")
self.add_argument("-o", "--opt", nargs='*',
help="set configuration options")
self.add_argument(
"-o", "--opt", nargs='*', help="set configuration options")
def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv)
......@@ -78,3 +79,73 @@ class ArgsParser(ArgumentParser):
cur[key] = {}
cur = cur[key]
return config
def print_total_cfg(config):
modules = get_registered_modules()
color_tty = ColorTTY()
green = '___{}___'.format(color_tty.colors.index('green') + 31)
styled = {}
for key in config.keys():
if not config[key]: # empty schema
continue
if key not in modules and not hasattr(config[key], '__dict__'):
styled[key] = config[key]
continue
elif key in modules:
module = modules[key]
else:
type_name = type(config[key]).__name__
if type_name in modules:
module = modules[type_name].copy()
module.update({
k: v
for k, v in config[key].__dict__.items()
if k in module.schema
})
key += " ({})".format(type_name)
default = module.find_default_keys()
missing = module.find_missing_keys()
mismatch = module.find_mismatch_keys()
extra = module.find_extra_keys()
dep_missing = []
for dep in module.inject:
if isinstance(module[dep], str) and module[dep] != '<value>':
if module[dep] not in modules: # not a valid module
dep_missing.append(dep)
else:
dep_mod = modules[module[dep]]
# empty dict but mandatory
if not dep_mod and dep_mod.mandatory():
dep_missing.append(dep)
override = list(
set(module.keys()) - set(default) - set(extra) - set(dep_missing))
replacement = {}
for name in set(override + default + extra + mismatch + missing):
new_name = name
if name in missing:
value = "<missing>"
else:
value = module[name]
if name in extra:
value = dump_value(value) + " <extraneous>"
elif name in mismatch:
value = dump_value(value) + " <type mismatch>"
elif name in dep_missing:
value = dump_value(value) + " <module config missing>"
elif name in override and value != '<missing>':
mark = green
new_name = mark + name
replacement[new_name] = value
styled[key] = replacement
buffer = yaml.dump(styled, default_flow_style=False, default_style='')
buffer = (re.sub(r"<missing>", r"[31m<missing>[0m", buffer))
buffer = (re.sub(r"<extraneous>", r"[33m<extraneous>[0m", buffer))
buffer = (re.sub(r"<type mismatch>", r"[31m<type mismatch>[0m", buffer))
buffer = (re.sub(r"<module config missing>",
r"[31m<module config missing>[0m", buffer))
buffer = re.sub(r"___(\d+)___(.*?):", r"[\1m\2[0m:", buffer)
print(buffer)
......@@ -14,14 +14,13 @@
from __future__ import print_function
import re
import sys
from argparse import ArgumentParser, RawDescriptionHelpFormatter
import yaml
from ppdet.core.workspace import get_registered_modules, load_config
from ppdet.utils.cli import ColorTTY
from ppdet.utils.cli import ColorTTY, print_total_cfg
color_tty = ColorTTY()
......@@ -151,75 +150,6 @@ def generate_config(**kwargs):
print(dump_config(s, minimal))
def print_total_cfg(config):
modules = get_registered_modules()
green = '___{}___'.format(color_tty.colors.index('green') + 31)
styled = {}
for key in config.keys():
if not config[key]: # empty schema
continue
if key not in modules and not hasattr(config[key], '__dict__'):
styled[key] = config[key]
continue
elif key in modules:
module = modules[key]
else:
type_name = type(config[key]).__name__
if type_name in modules:
module = modules[type_name].copy()
module.update({
k: v
for k, v in config[key].__dict__.items()
if k in module.schema
})
key += " ({})".format(type_name)
default = module.find_default_keys()
missing = module.find_missing_keys()
mismatch = module.find_mismatch_keys()
extra = module.find_extra_keys()
dep_missing = []
for dep in module.inject:
if isinstance(module[dep], str) and module[dep] != '<value>':
if module[dep] not in modules: # not a valid module
dep_missing.append(dep)
else:
dep_mod = modules[module[dep]]
# empty dict but mandatory
if not dep_mod and dep_mod.mandatory():
dep_missing.append(dep)
override = list(
set(module.keys()) - set(default) - set(extra) - set(dep_missing))
replacement = {}
for name in set(override + default + extra + mismatch + missing):
new_name = name
if name in missing:
value = "<missing>"
else:
value = module[name]
if name in extra:
value = dump_value(value) + " <extraneous>"
elif name in mismatch:
value = dump_value(value) + " <type mismatch>"
elif name in dep_missing:
value = dump_value(value) + " <module config missing>"
elif name in override and value != '<missing>':
mark = green
new_name = mark + name
replacement[new_name] = value
styled[key] = replacement
buffer = yaml.dump(styled, default_flow_style=False, default_style='')
buffer = (re.sub(r"<missing>", r"<missing>", buffer))
buffer = (re.sub(r"<extraneous>", r"<extraneous>", buffer))
buffer = (re.sub(r"<type mismatch>", r"<type mismatch>", buffer))
buffer = (re.sub(r"<module config missing>",
r"<module config missing>", buffer))
buffer = re.sub(r"___(\d+)___(.*?):", r"[\1m\2:", buffer)
print(buffer)
# FIXME this is pretty hackish, maybe implement a custom YAML printer?
def analyze_config(**kwargs):
config = load_config(kwargs['file'])
......
......@@ -33,7 +33,7 @@ set_paddle_flags(
import paddle.fluid as fluid
from tools.configure import print_total_cfg
from ppdet.utils.cli import print_total_cfg
from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results, json_eval_results
import ppdet.utils.checkpoint as checkpoint
from ppdet.utils.cli import ArgsParser
......
......@@ -37,7 +37,7 @@ set_paddle_flags(
from paddle import fluid
from tools.configure import print_total_cfg
from ppdet.utils.cli import print_total_cfg
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.modeling.model_input import create_feed
from ppdet.data.data_feed import create_reader
......
......@@ -21,7 +21,6 @@ import time
import numpy as np
import datetime
from collections import deque
from tools.configure import print_total_cfg
def set_paddle_flags(**kwargs):
......@@ -40,6 +39,7 @@ from paddle import fluid
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.data.data_feed import create_reader
from ppdet.utils.cli import print_total_cfg
from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results
from ppdet.utils.stats import TrainingStats
from ppdet.utils.cli import ArgsParser
......@@ -142,7 +142,8 @@ def main():
train_compile_program = fluid.compiler.CompiledProgram(
train_prog).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy,
loss_name=loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
if FLAGS.eval:
eval_compile_program = fluid.compiler.CompiledProgram(eval_prog)
......@@ -159,13 +160,10 @@ def main():
elif cfg.pretrain_weights:
checkpoint.load_pretrain(exe, train_prog, cfg.pretrain_weights)
train_reader = create_reader(
train_feed,
(cfg.max_iters - start_iter) * devices_num,
FLAGS.dataset_dir)
train_reader = create_reader(train_feed, (cfg.max_iters - start_iter) *
devices_num, FLAGS.dataset_dir)
train_pyreader.decorate_sample_list_generator(train_reader, place)
# whether output bbox is normalized in model output layer
is_bbox_normalized = False
if hasattr(model, 'is_bbox_normalized') and \
......@@ -230,12 +228,12 @@ def main():
box_ap_stats = eval_results(
results, eval_feed, cfg.metric, cfg.num_classes, resolution,
is_bbox_normalized, FLAGS.output_eval, map_type)
# use tb_paddle to log mAP
if FLAGS.use_tb:
tb_writer.add_scalar("mAP", box_ap_stats[0], tb_mAP_step)
tb_mAP_step += 1
if box_ap_stats[0] > best_box_ap_list[0]:
best_box_ap_list[0] = box_ap_stats[0]
best_box_ap_list[1] = it
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册