未验证 提交 0f27cccc 编写于 作者: D dyning 提交者: GitHub

Merge pull request #100 from shippingwang/update_log

Refine code and recover coloring function
...@@ -64,14 +64,14 @@ def print_dict(d, delimiter=0): ...@@ -64,14 +64,14 @@ def print_dict(d, delimiter=0):
placeholder = "-" * 60 placeholder = "-" * 60
for k, v in sorted(d.items()): for k, v in sorted(d.items()):
if isinstance(v, dict): if isinstance(v, dict):
logger.info("{}{} : ".format(delimiter * " ", k)) logger.info("{}{} : ".format(delimiter * " ", logger.coloring(k, "HEADER")))
print_dict(v, delimiter + 4) print_dict(v, delimiter + 4)
elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict): elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
logger.info("{}{} : ".format(delimiter * " ", k)) logger.info("{}{} : ".format(delimiter * " ", logger.coloring(str(k),"HEADER")))
for value in v: for value in v:
print_dict(value, delimiter + 4) print_dict(value, delimiter + 4)
else: else:
logger.info("{}{} : {}".format(delimiter * " ", k, v)) logger.info("{}{} : {}".format(delimiter * " ", logger.coloring(k,"HEADER"), logger.coloring(v,"OKGREEN")))
if k.isupper(): if k.isupper():
logger.info(placeholder) logger.info(placeholder)
......
...@@ -14,15 +14,48 @@ ...@@ -14,15 +14,48 @@
import logging import logging
import os import os
import datetime
logging.basicConfig(level=logging.INFO) from imp import reload
reload(logging)
logging.basicConfig(level=logging.INFO,
format="%(asctime)s %(levelname)s: %(message)s",
datefmt = "%Y-%m-%d %H:%M:%S")
def time_zone(sec, fmt):
real_time = datetime.datetime.now() + datetime.timedelta(hours=8)
return real_time.timetuple()
logging.Formatter.converter = time_zone
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
Color= {
'RED' : '\033[31m' ,
'HEADER' : '\033[35m' , # deep purple
'PURPLE' : '\033[95m' ,# purple
'OKBLUE' : '\033[94m' ,
'OKGREEN' : '\033[92m' ,
'WARNING' : '\033[93m' ,
'FAIL' : '\033[91m' ,
'ENDC' : '\033[0m' }
def coloring(message, color="OKGREEN"):
assert color in Color.keys()
if os.environ.get('PADDLECLAS_COLORING', False):
return Color[color]+str(message)+Color["ENDC"]
else:
return message
def anti_fleet(log): def anti_fleet(log):
""" """
Because of the fucking Fleet, logs will print multi-times. logs will print multi-times when calling Fleet API.
So we only display one of them and ignore the others. Only display single log and ignore the others.
""" """
def wrapper(fmt, *args): def wrapper(fmt, *args):
...@@ -39,12 +72,12 @@ def info(fmt, *args): ...@@ -39,12 +72,12 @@ def info(fmt, *args):
@anti_fleet @anti_fleet
def warning(fmt, *args): def warning(fmt, *args):
_logger.warning(fmt, *args) _logger.warning(coloring(fmt, "RED"), *args)
@anti_fleet @anti_fleet
def error(fmt, *args): def error(fmt, *args):
_logger.error(fmt, *args) _logger.error(coloring(fmt, "FAIL"), *args)
def advertise(): def advertise():
...@@ -66,7 +99,7 @@ def advertise(): ...@@ -66,7 +99,7 @@ def advertise():
website = "https://github.com/PaddlePaddle/PaddleClas" website = "https://github.com/PaddlePaddle/PaddleClas"
AD_LEN = 6 + len(max([copyright, ad, website], key=len)) AD_LEN = 6 + len(max([copyright, ad, website], key=len))
info("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format( info(coloring("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
"=" * (AD_LEN + 4), "=" * (AD_LEN + 4),
"=={}==".format(copyright.center(AD_LEN)), "=={}==".format(copyright.center(AD_LEN)),
"=" * (AD_LEN + 4), "=" * (AD_LEN + 4),
...@@ -74,4 +107,4 @@ def advertise(): ...@@ -74,4 +107,4 @@ def advertise():
"=={}==".format(ad.center(AD_LEN)), "=={}==".format(ad.center(AD_LEN)),
"=={}==".format(' ' * AD_LEN), "=={}==".format(' ' * AD_LEN),
"=={}==".format(website.center(AD_LEN)), "=={}==".format(website.center(AD_LEN)),
"=" * (AD_LEN + 4), )) "=" * (AD_LEN + 4), ),"RED"))
...@@ -46,7 +46,6 @@ def _mkdir_if_not_exist(path): ...@@ -46,7 +46,6 @@ def _mkdir_if_not_exist(path):
def _load_state(path): def _load_state(path):
print("path: ", path)
if os.path.exists(path + '.pdopt'): if os.path.exists(path + '.pdopt'):
# XXX another hack to ignore the optimizer state # XXX another hack to ignore the optimizer state
tmp = tempfile.mkdtemp() tmp = tempfile.mkdtemp()
...@@ -55,7 +54,6 @@ def _load_state(path): ...@@ -55,7 +54,6 @@ def _load_state(path):
state = fluid.io.load_program_state(dst) state = fluid.io.load_program_state(dst)
shutil.rmtree(tmp) shutil.rmtree(tmp)
else: else:
print("path: ", path)
state = fluid.io.load_program_state(path) state = fluid.io.load_program_state(path)
return state return state
...@@ -75,7 +73,7 @@ def load_params(exe, prog, path, ignore_params=[]): ...@@ -75,7 +73,7 @@ def load_params(exe, prog, path, ignore_params=[]):
raise ValueError("Model pretrain path {} does not " raise ValueError("Model pretrain path {} does not "
"exists.".format(path)) "exists.".format(path))
logger.info('Loading parameters from {}...'.format(path)) logger.info(logger.coloring('Loading parameters from {}...'.format(path), 'HEADER'))
ignore_set = set() ignore_set = set()
state = _load_state(path) state = _load_state(path)
...@@ -101,7 +99,7 @@ def load_params(exe, prog, path, ignore_params=[]): ...@@ -101,7 +99,7 @@ def load_params(exe, prog, path, ignore_params=[]):
if len(ignore_set) > 0: if len(ignore_set) > 0:
for k in ignore_set: for k in ignore_set:
if k in state: if k in state:
logger.warning('variable {} not used'.format(k)) logger.warning('variable {} is already excluded automatically'.format(k))
del state[k] del state[k]
fluid.io.set_program_state(prog, state) fluid.io.set_program_state(prog, state)
...@@ -113,7 +111,7 @@ def init_model(config, program, exe): ...@@ -113,7 +111,7 @@ def init_model(config, program, exe):
checkpoints = config.get('checkpoints') checkpoints = config.get('checkpoints')
if checkpoints: if checkpoints:
fluid.load(program, checkpoints, exe) fluid.load(program, checkpoints, exe)
logger.info("Finish initing model from {}".format(checkpoints)) logger.info(logger.coloring("Finish initing model from {}".format(checkpoints),"HEADER"))
return return
pretrained_model = config.get('pretrained_model') pretrained_model = config.get('pretrained_model')
...@@ -122,7 +120,7 @@ def init_model(config, program, exe): ...@@ -122,7 +120,7 @@ def init_model(config, program, exe):
pretrained_model = [pretrained_model] pretrained_model = [pretrained_model]
for pretrain in pretrained_model: for pretrain in pretrained_model:
load_params(exe, program, pretrain) load_params(exe, program, pretrain)
logger.info("Finish initing model from {}".format(pretrained_model)) logger.info(logger.coloring("Finish initing model from {}".format(pretrained_model),"HEADER"))
def save_model(program, model_path, epoch_id, prefix='ppcls'): def save_model(program, model_path, epoch_id, prefix='ppcls'):
...@@ -133,4 +131,4 @@ def save_model(program, model_path, epoch_id, prefix='ppcls'): ...@@ -133,4 +131,4 @@ def save_model(program, model_path, epoch_id, prefix='ppcls'):
_mkdir_if_not_exist(model_path) _mkdir_if_not_exist(model_path)
model_prefix = os.path.join(model_path, prefix) model_prefix = os.path.join(model_path, prefix)
fluid.save(program, model_prefix) fluid.save(program, model_prefix)
logger.info("Already save model in {}".format(model_path)) logger.info(logger.coloring("Already save model in {}".format(model_path),"HEADER"))
...@@ -396,19 +396,24 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'): ...@@ -396,19 +396,24 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
for i, m in enumerate(metrics): for i, m in enumerate(metrics):
metric_list[i].update(m[0], len(batch[0])) metric_list[i].update(m[0], len(batch[0]))
fetchs_str = ''.join([str(m.value) + ' ' fetchs_str = ''.join([str(m.value) + ' '
for m in metric_list] + [batch_time.value]) for m in metric_list] + [batch_time.value])+'s'
if mode == 'eval': if mode == 'eval':
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str)) logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
else: else:
logger.info("epoch:{:<3d} {:s} step:{:<4d} {:s}s".format( epoch_str = "epoch:{:<3d}".format(epoch)
epoch, mode, idx, fetchs_str)) step_str = "{:s} step:{:<4d}".format(mode, idx)
logger.info("{:s} {:s} {:s}".format(
logger.coloring(epoch_str, "HEADER") if idx==0 else epoch_str, logger.coloring(step_str,"PURPLE"), logger.coloring(fetchs_str,'OKGREEN')))
end_str = ''.join([str(m.mean) + ' ' end_str = ''.join([str(m.mean) + ' '
for m in metric_list] + [batch_time.total]) for m in metric_list] + [batch_time.total])+'s'
if mode == 'eval': if mode == 'eval':
logger.info("END {:s} {:s}s".format(mode, end_str)) logger.info("END {:s} {:s}s".format(mode, end_str))
else: else:
logger.info("END epoch:{:<3d} {:s} {:s}s".format(epoch, mode, end_str)) end_epoch_str = "END epoch:{:<3d}".format(epoch)
logger.info("{:s} {:s} {:s}".format(logger.coloring(end_epoch_str,"RED"), logger.coloring(mode,"PURPLE"), logger.coloring(end_str,"OKGREEN")))
# return top1_acc in order to save the best model # return top1_acc in order to save the best model
if mode == 'valid': if mode == 'valid':
......
...@@ -62,7 +62,7 @@ def main(args): ...@@ -62,7 +62,7 @@ def main(args):
startup_prog = fluid.Program() startup_prog = fluid.Program()
train_prog = fluid.Program() train_prog = fluid.Program()
best_top1_acc_list = (0.0, -1) # (top1_acc, epoch_id) best_top1_acc = 0.0 # best top1 acc record
train_dataloader, train_fetchs = program.build( train_dataloader, train_fetchs = program.build(
config, train_prog, startup_prog, is_train=True) config, train_prog, startup_prog, is_train=True)
...@@ -101,13 +101,15 @@ def main(args): ...@@ -101,13 +101,15 @@ def main(args):
top1_acc = program.run(valid_dataloader, exe, top1_acc = program.run(valid_dataloader, exe,
compiled_valid_prog, valid_fetchs, compiled_valid_prog, valid_fetchs,
epoch_id, 'valid') epoch_id, 'valid')
if top1_acc > best_top1_acc_list[0]: if top1_acc > best_top1_acc:
best_top1_acc_list = (top1_acc, epoch_id) best_top1_acc = top1_acc
logger.info("Best top1 acc: {}, in epoch: {}".format( message = "The best top1 acc {:.5f}, in epoch: {:d}".format(best_top1_acc, epoch_id)
*best_top1_acc_list)) logger.info("{:s}".format(logger.coloring(message, "RED")))
model_path = os.path.join(config.model_save_dir, if epoch_id % config.save_interval==0:
model_path = os.path.join(config.model_save_dir,
config.ARCHITECTURE["name"]) config.ARCHITECTURE["name"])
save_model(train_prog, model_path, "best_model") save_model(train_prog, model_path, "best_model_in_epoch_"+str(epoch_id))
# 3. save the persistable model # 3. save the persistable model
if epoch_id % config.save_interval == 0: if epoch_id % config.save_interval == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册