diff --git a/ppcls/utils/logger.py b/ppcls/utils/logger.py
index 5b192c61b46d82cc99e88269c5118c3e9e182b66..7dda65e6422451cc1234ea1738bb6c699b86d0f3 100644
--- a/ppcls/utils/logger.py
+++ b/ppcls/utils/logger.py
@@ -19,9 +19,10 @@ import datetime
 from imp import reload
 reload(logging)
-logging.basicConfig(level=logging.INFO,
-                    format="%(asctime)s %(levelname)s: %(message)s",
-                    datefmt = "%Y-%m-%d %H:%M:%S")
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s: %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S")
 
 
 def time_zone(sec, fmt):
@@ -32,22 +33,22 @@ def time_zone(sec, fmt):
 logging.Formatter.converter = time_zone
 
 _logger = logging.getLogger(__name__)
-
-Color= {
-    'RED' : '\033[31m' ,
-    'HEADER' : '\033[35m' , # deep purple
-    'PURPLE' : '\033[95m' ,# purple
-    'OKBLUE' : '\033[94m' ,
-    'OKGREEN' : '\033[92m' ,
-    'WARNING' : '\033[93m' ,
-    'FAIL' : '\033[91m' ,
-    'ENDC' : '\033[0m' }
+Color = {
+    'RED': '\033[31m',
+    'HEADER': '\033[35m',  # deep purple
+    'PURPLE': '\033[95m',  # purple
+    'OKBLUE': '\033[94m',
+    'OKGREEN': '\033[92m',
+    'WARNING': '\033[93m',
+    'FAIL': '\033[91m',
+    'ENDC': '\033[0m'
+}
 
 
 def coloring(message, color="OKGREEN"):
     assert color in Color.keys()
     if os.environ.get('PADDLECLAS_COLORING', False):
-        return Color[color]+str(message)+Color["ENDC"]
+        return Color[color] + str(message) + Color["ENDC"]
     else:
         return message
@@ -80,6 +81,10 @@ def error(fmt, *args):
     _logger.error(coloring(fmt, "FAIL"), *args)
 
 
+def scaler(name, value, step, writer):
+    writer.add_scalar(name, value, step)
+
+
 def advertise():
     """
     Show the advertising message like the following:
@@ -99,12 +104,13 @@ def advertise():
     website = "https://github.com/PaddlePaddle/PaddleClas"
 
     AD_LEN = 6 + len(max([copyright, ad, website], key=len))
 
-    info(coloring("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
-        "=" * (AD_LEN + 4),
-        "=={}==".format(copyright.center(AD_LEN)),
-        "=" * (AD_LEN + 4),
-        "=={}==".format(' ' * AD_LEN),
-        "=={}==".format(ad.center(AD_LEN)),
-        "=={}==".format(' ' * AD_LEN),
-        "=={}==".format(website.center(AD_LEN)),
-        "=" * (AD_LEN + 4), ),"RED"))
+    info(
+        coloring("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
+            "=" * (AD_LEN + 4),
+            "=={}==".format(copyright.center(AD_LEN)),
+            "=" * (AD_LEN + 4),
+            "=={}==".format(' ' * AD_LEN),
+            "=={}==".format(ad.center(AD_LEN)),
+            "=={}==".format(' ' * AD_LEN),
+            "=={}==".format(website.center(AD_LEN)),
+            "=" * (AD_LEN + 4), ), "RED"))
diff --git a/tools/program.py b/tools/program.py
index b73f1064284a7f1929f99d892bee43798a5d52fb..2428409ef70850ce5247c2e448e5716e50f97178 100644
--- a/tools/program.py
+++ b/tools/program.py
@@ -384,7 +384,10 @@ def compile(config, program, loss_name=None):
     return compiled_program
 
 
-def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
+total_step = 0
+
+
+def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_writer=None):
     """
     Feed data to the model and fetch the measures and loss
 
@@ -412,6 +415,10 @@
             metric_list[i].update(m[0], len(batch[0]))
         fetchs_str = ''.join([str(m.value) + ' '
                               for m in metric_list] + [batch_time.value]) + 's'
+        if vdl_writer:
+            global total_step
+            logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
+            total_step += 1
         if mode == 'eval':
             logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
         else:
diff --git a/tools/train.py b/tools/train.py
index cd5b7d25b499751ccc9289fe96139e630b09fee7..993907136656f789dd8841a1636af6ea515177c8 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 import argparse
 import os
 
+from visualdl import LogWriter
 import paddle.fluid as fluid
 from paddle.fluid.incubate.fleet.base import role_maker
 from paddle.fluid.incubate.fleet.collective import fleet
@@ -38,6 +39,11 @@ def parse_args():
         type=str,
         default='configs/ResNet/ResNet50.yaml',
         help='config file path')
+    parser.add_argument(
+        '--vdl_dir',
+        type=str,
+        default=None,
+        help='VisualDL logging directory for image.')
     parser.add_argument(
         '-o',
         '--override',
@@ -91,10 +97,12 @@
         compiled_valid_prog = program.compile(config, valid_prog)
     compiled_train_prog = fleet.main_program
 
+    vdl_writer = LogWriter(args.vdl_dir) if args.vdl_dir else None
+
     for epoch_id in range(config.epochs):
         # 1. train with train dataset
         program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
-                    epoch_id, 'train')
+                    epoch_id, 'train', vdl_writer)
         if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
             # 2. validate with validate dataset
             if config.validate and epoch_id % config.valid_interval == 0:
@@ -103,13 +111,15 @@
                                        epoch_id, 'valid')
                 if top1_acc > best_top1_acc:
                     best_top1_acc = top1_acc
-                    message = "The best top1 acc {:.5f}, in epoch: {:d}".format(best_top1_acc, epoch_id)
+                    message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
+                        best_top1_acc, epoch_id)
                     logger.info("{:s}".format(logger.coloring(message, "RED")))
-                    if epoch_id % config.save_interval==0:
+                    if epoch_id % config.save_interval == 0:
                         model_path = os.path.join(config.model_save_dir,
-                                              config.ARCHITECTURE["name"])
-                        save_model(train_prog, model_path, "best_model_in_epoch_"+str(epoch_id))
+                                                  config.ARCHITECTURE["name"])
+                        save_model(train_prog, model_path,
+                                   "best_model_in_epoch_" + str(epoch_id))
 
             # 3. save the persistable model
             if epoch_id % config.save_interval == 0:
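
The snippet below is a usage sketch, not part of the patch. It assumes VisualDL is installed, that the logger module is importable as ppcls.utils.logger (the file changed above), and that ./vdl_log is an arbitrary example directory. It shows how the pieces introduced by the diff fit together: tools/train.py builds a LogWriter from the new --vdl_dir flag, threads it into program.run, and each training step writes one "loss" scalar through the new logger.scaler helper, which forwards to LogWriter.add_scalar.

# Hypothetical usage of the new VisualDL hook; paths and values are examples only.
from visualdl import LogWriter

from ppcls.utils import logger

# Same kind of writer tools/train.py creates when --vdl_dir is given.
writer = LogWriter("./vdl_log")
# Forwards to writer.add_scalar("loss", 0.98, 0), so the curve appears under the "loss" tag.
logger.scaler("loss", 0.98, 0, writer)

# Command-line equivalent, then serving the logs in the VisualDL frontend:
#   python tools/train.py -c configs/ResNet/ResNet50.yaml --vdl_dir ./vdl_log
#   visualdl --logdir ./vdl_log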