From d5f0a14126500d7487755d86f53d6f0c0ee67570 Mon Sep 17 00:00:00 2001 From: shippingwang Date: Wed, 6 May 2020 08:35:11 +0000 Subject: [PATCH] Coloring and refine code --- ppcls/utils/config.py | 6 +++--- ppcls/utils/logger.py | 30 +++++++++++++++++++++++------- ppcls/utils/save_load.py | 12 +++++------- tools/program.py | 15 ++++++++++----- tools/train.py | 16 +++++++++------- 5 files changed, 50 insertions(+), 29 deletions(-) diff --git a/ppcls/utils/config.py b/ppcls/utils/config.py index b1c1be4e..93b11569 100644 --- a/ppcls/utils/config.py +++ b/ppcls/utils/config.py @@ -64,14 +64,14 @@ def print_dict(d, delimiter=0): placeholder = "-" * 60 for k, v in sorted(d.items()): if isinstance(v, dict): - logger.info("{}{} : ".format(delimiter * " ", k)) + logger.info("{}{} : ".format(delimiter * " ", logger.coloring(k, "HEADER"))) print_dict(v, delimiter + 4) elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict): - logger.info("{}{} : ".format(delimiter * " ", k)) + logger.info("{}{} : ".format(delimiter * " ", logger.coloring(str(k),"HEADER"))) for value in v: print_dict(value, delimiter + 4) else: - logger.info("{}{} : {}".format(delimiter * " ", k, v)) + logger.info("{}{} : {}".format(delimiter * " ", logger.coloring(k,"HEADER"), logger.coloring(v,"OKGREEN"))) if k.isupper(): logger.info(placeholder) diff --git a/ppcls/utils/logger.py b/ppcls/utils/logger.py index 5b4ae2ca..22e4f63b 100644 --- a/ppcls/utils/logger.py +++ b/ppcls/utils/logger.py @@ -15,14 +15,30 @@ import logging import os -logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.INFO, format='%(message)s') _logger = logging.getLogger(__name__) +Color= { + 'RED' : '\033[31m' , + 'HEADER' : '\033[35m' , # deep purple + 'PURPLE' : '\033[95m' ,# purple + 'OKBLUE' : '\033[94m' , + 'OKGREEN' : '\033[92m' , + 'WARNING' : '\033[93m' , + 'FAIL' : '\033[91m' , + 'ENDC' : '\033[0m' } + +def coloring(message, color="OKGREEN"): + assert color in Color.keys() + if os.environ.get('PADDLECLAS_COLORING', False): + return Color[color]+str(message)+Color["ENDC"] + else: + return message def anti_fleet(log): """ - Because of the fucking Fleet, logs will print multi-times. - So we only display one of them and ignore the others. + logs will print multi-times when calling Fleet API. + Only display single log and ignore the others. """ def wrapper(fmt, *args): @@ -39,12 +55,12 @@ def info(fmt, *args): @anti_fleet def warning(fmt, *args): - _logger.warning(fmt, *args) + _logger.warning(coloring(fmt, "RED"), *args) @anti_fleet def error(fmt, *args): - _logger.error(fmt, *args) + _logger.error(coloring(fmt, "FAIL"), *args) def advertise(): @@ -66,7 +82,7 @@ def advertise(): website = "https://github.com/PaddlePaddle/PaddleClas" AD_LEN = 6 + len(max([copyright, ad, website], key=len)) - info("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format( + info(coloring("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format( "=" * (AD_LEN + 4), "=={}==".format(copyright.center(AD_LEN)), "=" * (AD_LEN + 4), @@ -74,4 +90,4 @@ def advertise(): "=={}==".format(ad.center(AD_LEN)), "=={}==".format(' ' * AD_LEN), "=={}==".format(website.center(AD_LEN)), - "=" * (AD_LEN + 4), )) + "=" * (AD_LEN + 4), ),"RED")) diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py index 673e5430..aed3b88c 100644 --- a/ppcls/utils/save_load.py +++ b/ppcls/utils/save_load.py @@ -46,7 +46,6 @@ def _mkdir_if_not_exist(path): def _load_state(path): - print("path: ", path) if os.path.exists(path + '.pdopt'): # XXX another hack to ignore the optimizer state tmp = tempfile.mkdtemp() @@ -55,7 +54,6 @@ def _load_state(path): state = fluid.io.load_program_state(dst) shutil.rmtree(tmp) else: - print("path: ", path) state = fluid.io.load_program_state(path) return state @@ -75,7 +73,7 @@ def load_params(exe, prog, path, ignore_params=[]): raise ValueError("Model pretrain path {} does not " "exists.".format(path)) - logger.info('Loading parameters from {}...'.format(path)) + logger.info(logger.coloring('Loading parameters from {}...'.format(path), 'HEADER')) ignore_set = set() state = _load_state(path) @@ -101,7 +99,7 @@ def load_params(exe, prog, path, ignore_params=[]): if len(ignore_set) > 0: for k in ignore_set: if k in state: - logger.warning('variable {} not used'.format(k)) + logger.warning('variable {} is already excluded automatically'.format(k)) del state[k] fluid.io.set_program_state(prog, state) @@ -113,7 +111,7 @@ def init_model(config, program, exe): checkpoints = config.get('checkpoints') if checkpoints: fluid.load(program, checkpoints, exe) - logger.info("Finish initing model from {}".format(checkpoints)) + logger.info(logger.coloring("Finish initing model from {}".format(checkpoints),"HEADER")) return pretrained_model = config.get('pretrained_model') @@ -122,7 +120,7 @@ def init_model(config, program, exe): pretrained_model = [pretrained_model] for pretrain in pretrained_model: load_params(exe, program, pretrain) - logger.info("Finish initing model from {}".format(pretrained_model)) + logger.info(logger.coloring("Finish initing model from {}".format(pretrained_model),"HEADER")) def save_model(program, model_path, epoch_id, prefix='ppcls'): @@ -133,4 +131,4 @@ def save_model(program, model_path, epoch_id, prefix='ppcls'): _mkdir_if_not_exist(model_path) model_prefix = os.path.join(model_path, prefix) fluid.save(program, model_prefix) - logger.info("Already save model in {}".format(model_path)) + logger.info(logger.coloring("Already save model in {}".format(model_path),"HEADER")) diff --git a/tools/program.py b/tools/program.py index 50c609de..c8e9556e 100644 --- a/tools/program.py +++ b/tools/program.py @@ -396,19 +396,24 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'): for i, m in enumerate(metrics): metric_list[i].update(m[0], len(batch[0])) fetchs_str = ''.join([str(m.value) + ' ' - for m in metric_list] + [batch_time.value]) + for m in metric_list] + [batch_time.value])+'s' if mode == 'eval': logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str)) else: - logger.info("epoch:{:<3d} {:s} step:{:<4d} {:s}s".format( - epoch, mode, idx, fetchs_str)) + epoch_str = "epoch:{:<3d}".format(epoch) + step_str = "{:s} step:{:<4d}".format(mode, idx) + + logger.info("{:s} {:s} {:s}".format( + logger.coloring(epoch_str, "HEADER") if idx==0 else epoch_str, logger.coloring(step_str,"PURPLE"), logger.coloring(fetchs_str,'OKGREEN'))) end_str = ''.join([str(m.mean) + ' ' - for m in metric_list] + [batch_time.total]) + for m in metric_list] + [batch_time.total])+'s' if mode == 'eval': logger.info("END {:s} {:s}s".format(mode, end_str)) else: - logger.info("END epoch:{:<3d} {:s} {:s}s".format(epoch, mode, end_str)) + end_epoch_str = "END epoch:{:<3d}".format(epoch) + + logger.info("{:s} {:s} {:s}".format(logger.coloring(end_epoch_str,"RED"), logger.coloring(mode,"PURPLE"), logger.coloring(end_str,"OKGREEN"))) # return top1_acc in order to save the best model if mode == 'valid': diff --git a/tools/train.py b/tools/train.py index ab7752fd..cd5b7d25 100644 --- a/tools/train.py +++ b/tools/train.py @@ -62,7 +62,7 @@ def main(args): startup_prog = fluid.Program() train_prog = fluid.Program() - best_top1_acc_list = (0.0, -1) # (top1_acc, epoch_id) + best_top1_acc = 0.0 # best top1 acc record train_dataloader, train_fetchs = program.build( config, train_prog, startup_prog, is_train=True) @@ -101,13 +101,15 @@ def main(args): top1_acc = program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, epoch_id, 'valid') - if top1_acc > best_top1_acc_list[0]: - best_top1_acc_list = (top1_acc, epoch_id) - logger.info("Best top1 acc: {}, in epoch: {}".format( - *best_top1_acc_list)) - model_path = os.path.join(config.model_save_dir, + if top1_acc > best_top1_acc: + best_top1_acc = top1_acc + message = "The best top1 acc {:.5f}, in epoch: {:d}".format(best_top1_acc, epoch_id) + logger.info("{:s}".format(logger.coloring(message, "RED"))) + if epoch_id % config.save_interval==0: + + model_path = os.path.join(config.model_save_dir, config.ARCHITECTURE["name"]) - save_model(train_prog, model_path, "best_model") + save_model(train_prog, model_path, "best_model_in_epoch_"+str(epoch_id)) # 3. save the persistable model if epoch_id % config.save_interval == 0: -- GitLab