提交 d5f0a141 编写于 作者: S shippingwang

Coloring and refine code

上级 bfcd9e4d
......@@ -64,14 +64,14 @@ def print_dict(d, delimiter=0):
placeholder = "-" * 60
for k, v in sorted(d.items()):
if isinstance(v, dict):
logger.info("{}{} : ".format(delimiter * " ", k))
logger.info("{}{} : ".format(delimiter * " ", logger.coloring(k, "HEADER")))
print_dict(v, delimiter + 4)
elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
logger.info("{}{} : ".format(delimiter * " ", k))
logger.info("{}{} : ".format(delimiter * " ", logger.coloring(str(k),"HEADER")))
for value in v:
print_dict(value, delimiter + 4)
else:
logger.info("{}{} : {}".format(delimiter * " ", k, v))
logger.info("{}{} : {}".format(delimiter * " ", logger.coloring(k,"HEADER"), logger.coloring(v,"OKGREEN")))
if k.isupper():
logger.info(placeholder)
......
......@@ -15,14 +15,30 @@
import logging
import os
logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=logging.INFO, format='%(message)s')
_logger = logging.getLogger(__name__)
Color= {
'RED' : '\033[31m' ,
'HEADER' : '\033[35m' , # deep purple
'PURPLE' : '\033[95m' ,# purple
'OKBLUE' : '\033[94m' ,
'OKGREEN' : '\033[92m' ,
'WARNING' : '\033[93m' ,
'FAIL' : '\033[91m' ,
'ENDC' : '\033[0m' }
def coloring(message, color="OKGREEN"):
assert color in Color.keys()
if os.environ.get('PADDLECLAS_COLORING', False):
return Color[color]+str(message)+Color["ENDC"]
else:
return message
def anti_fleet(log):
"""
Because of the fucking Fleet, logs will print multi-times.
So we only display one of them and ignore the others.
logs will print multi-times when calling Fleet API.
Only display single log and ignore the others.
"""
def wrapper(fmt, *args):
......@@ -39,12 +55,12 @@ def info(fmt, *args):
@anti_fleet
def warning(fmt, *args):
_logger.warning(fmt, *args)
_logger.warning(coloring(fmt, "RED"), *args)
@anti_fleet
def error(fmt, *args):
_logger.error(fmt, *args)
_logger.error(coloring(fmt, "FAIL"), *args)
def advertise():
......@@ -66,7 +82,7 @@ def advertise():
website = "https://github.com/PaddlePaddle/PaddleClas"
AD_LEN = 6 + len(max([copyright, ad, website], key=len))
info("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
info(coloring("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
"=" * (AD_LEN + 4),
"=={}==".format(copyright.center(AD_LEN)),
"=" * (AD_LEN + 4),
......@@ -74,4 +90,4 @@ def advertise():
"=={}==".format(ad.center(AD_LEN)),
"=={}==".format(' ' * AD_LEN),
"=={}==".format(website.center(AD_LEN)),
"=" * (AD_LEN + 4), ))
"=" * (AD_LEN + 4), ),"RED"))
......@@ -46,7 +46,6 @@ def _mkdir_if_not_exist(path):
def _load_state(path):
print("path: ", path)
if os.path.exists(path + '.pdopt'):
# XXX another hack to ignore the optimizer state
tmp = tempfile.mkdtemp()
......@@ -55,7 +54,6 @@ def _load_state(path):
state = fluid.io.load_program_state(dst)
shutil.rmtree(tmp)
else:
print("path: ", path)
state = fluid.io.load_program_state(path)
return state
......@@ -75,7 +73,7 @@ def load_params(exe, prog, path, ignore_params=[]):
raise ValueError("Model pretrain path {} does not "
"exists.".format(path))
logger.info('Loading parameters from {}...'.format(path))
logger.info(logger.coloring('Loading parameters from {}...'.format(path), 'HEADER'))
ignore_set = set()
state = _load_state(path)
......@@ -101,7 +99,7 @@ def load_params(exe, prog, path, ignore_params=[]):
if len(ignore_set) > 0:
for k in ignore_set:
if k in state:
logger.warning('variable {} not used'.format(k))
logger.warning('variable {} is already excluded automatically'.format(k))
del state[k]
fluid.io.set_program_state(prog, state)
......@@ -113,7 +111,7 @@ def init_model(config, program, exe):
checkpoints = config.get('checkpoints')
if checkpoints:
fluid.load(program, checkpoints, exe)
logger.info("Finish initing model from {}".format(checkpoints))
logger.info(logger.coloring("Finish initing model from {}".format(checkpoints),"HEADER"))
return
pretrained_model = config.get('pretrained_model')
......@@ -122,7 +120,7 @@ def init_model(config, program, exe):
pretrained_model = [pretrained_model]
for pretrain in pretrained_model:
load_params(exe, program, pretrain)
logger.info("Finish initing model from {}".format(pretrained_model))
logger.info(logger.coloring("Finish initing model from {}".format(pretrained_model),"HEADER"))
def save_model(program, model_path, epoch_id, prefix='ppcls'):
......@@ -133,4 +131,4 @@ def save_model(program, model_path, epoch_id, prefix='ppcls'):
_mkdir_if_not_exist(model_path)
model_prefix = os.path.join(model_path, prefix)
fluid.save(program, model_prefix)
logger.info("Already save model in {}".format(model_path))
logger.info(logger.coloring("Already save model in {}".format(model_path),"HEADER"))
......@@ -396,19 +396,24 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
for i, m in enumerate(metrics):
metric_list[i].update(m[0], len(batch[0]))
fetchs_str = ''.join([str(m.value) + ' '
for m in metric_list] + [batch_time.value])
for m in metric_list] + [batch_time.value])+'s'
if mode == 'eval':
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
else:
logger.info("epoch:{:<3d} {:s} step:{:<4d} {:s}s".format(
epoch, mode, idx, fetchs_str))
epoch_str = "epoch:{:<3d}".format(epoch)
step_str = "{:s} step:{:<4d}".format(mode, idx)
logger.info("{:s} {:s} {:s}".format(
logger.coloring(epoch_str, "HEADER") if idx==0 else epoch_str, logger.coloring(step_str,"PURPLE"), logger.coloring(fetchs_str,'OKGREEN')))
end_str = ''.join([str(m.mean) + ' '
for m in metric_list] + [batch_time.total])
for m in metric_list] + [batch_time.total])+'s'
if mode == 'eval':
logger.info("END {:s} {:s}s".format(mode, end_str))
else:
logger.info("END epoch:{:<3d} {:s} {:s}s".format(epoch, mode, end_str))
end_epoch_str = "END epoch:{:<3d}".format(epoch)
logger.info("{:s} {:s} {:s}".format(logger.coloring(end_epoch_str,"RED"), logger.coloring(mode,"PURPLE"), logger.coloring(end_str,"OKGREEN")))
# return top1_acc in order to save the best model
if mode == 'valid':
......
......@@ -62,7 +62,7 @@ def main(args):
startup_prog = fluid.Program()
train_prog = fluid.Program()
best_top1_acc_list = (0.0, -1) # (top1_acc, epoch_id)
best_top1_acc = 0.0 # best top1 acc record
train_dataloader, train_fetchs = program.build(
config, train_prog, startup_prog, is_train=True)
......@@ -101,13 +101,15 @@ def main(args):
top1_acc = program.run(valid_dataloader, exe,
compiled_valid_prog, valid_fetchs,
epoch_id, 'valid')
if top1_acc > best_top1_acc_list[0]:
best_top1_acc_list = (top1_acc, epoch_id)
logger.info("Best top1 acc: {}, in epoch: {}".format(
*best_top1_acc_list))
model_path = os.path.join(config.model_save_dir,
if top1_acc > best_top1_acc:
best_top1_acc = top1_acc
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(best_top1_acc, epoch_id)
logger.info("{:s}".format(logger.coloring(message, "RED")))
if epoch_id % config.save_interval==0:
model_path = os.path.join(config.model_save_dir,
config.ARCHITECTURE["name"])
save_model(train_prog, model_path, "best_model")
save_model(train_prog, model_path, "best_model_in_epoch_"+str(epoch_id))
# 3. save the persistable model
if epoch_id % config.save_interval == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册