提交 62772c11 编写于 作者: S shippingwang

add visualdl

上级 bd67368c
......@@ -19,9 +19,10 @@ import datetime
from imp import reload
reload(logging)
logging.basicConfig(level=logging.INFO,
format="%(asctime)s %(levelname)s: %(message)s",
datefmt = "%Y-%m-%d %H:%M:%S")
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S")
def time_zone(sec, fmt):
......@@ -32,22 +33,22 @@ def time_zone(sec, fmt):
logging.Formatter.converter = time_zone
_logger = logging.getLogger(__name__)
Color= {
'RED' : '\033[31m' ,
'HEADER' : '\033[35m' , # deep purple
'PURPLE' : '\033[95m' ,# purple
'OKBLUE' : '\033[94m' ,
'OKGREEN' : '\033[92m' ,
'WARNING' : '\033[93m' ,
'FAIL' : '\033[91m' ,
'ENDC' : '\033[0m' }
Color = {
'RED': '\033[31m',
'HEADER': '\033[35m', # deep purple
'PURPLE': '\033[95m', # purple
'OKBLUE': '\033[94m',
'OKGREEN': '\033[92m',
'WARNING': '\033[93m',
'FAIL': '\033[91m',
'ENDC': '\033[0m'
}
def coloring(message, color="OKGREEN"):
assert color in Color.keys()
if os.environ.get('PADDLECLAS_COLORING', False):
return Color[color]+str(message)+Color["ENDC"]
return Color[color] + str(message) + Color["ENDC"]
else:
return message
......@@ -80,6 +81,12 @@ def error(fmt, *args):
_logger.error(coloring(fmt, "FAIL"), *args)
def scaler(name, value, step, path):
from visualdl import LogWriter
vdl_writer = LogWriter(path)
vdl_writer.add_scalar(name, value, step)
def advertise():
"""
Show the advertising message like the following:
......@@ -99,12 +106,13 @@ def advertise():
website = "https://github.com/PaddlePaddle/PaddleClas"
AD_LEN = 6 + len(max([copyright, ad, website], key=len))
info(coloring("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
"=" * (AD_LEN + 4),
"=={}==".format(copyright.center(AD_LEN)),
"=" * (AD_LEN + 4),
"=={}==".format(' ' * AD_LEN),
"=={}==".format(ad.center(AD_LEN)),
"=={}==".format(' ' * AD_LEN),
"=={}==".format(website.center(AD_LEN)),
"=" * (AD_LEN + 4), ),"RED"))
info(
coloring("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
"=" * (AD_LEN + 4),
"=={}==".format(copyright.center(AD_LEN)),
"=" * (AD_LEN + 4),
"=={}==".format(' ' * AD_LEN),
"=={}==".format(ad.center(AD_LEN)),
"=={}==".format(' ' * AD_LEN),
"=={}==".format(website.center(AD_LEN)),
"=" * (AD_LEN + 4), ), "RED"))
......@@ -384,7 +384,10 @@ def compile(config, program, loss_name=None):
return compiled_program
def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
total_step = 0
def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_dir=None):
"""
Feed data to the model and fetch the measures and loss
......@@ -412,6 +415,10 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
metric_list[i].update(m[0], len(batch[0]))
fetchs_str = ''.join([str(m.value) + ' '
for m in metric_list] + [batch_time.value]) + 's'
if vdl_dir:
global total_step
logger.scaler('loss', metrics[0][0], total_step, vdl_dir)
total_step += 1
if mode == 'eval':
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
else:
......
......@@ -38,6 +38,11 @@ def parse_args():
type=str,
default='configs/ResNet/ResNet50.yaml',
help='config file path')
parser.add_argument(
'--vdl_dir',
type=str,
default="scaler",
help='VisualDL logging directory for image.')
parser.add_argument(
'-o',
'--override',
......@@ -94,7 +99,7 @@ def main(args):
for epoch_id in range(config.epochs):
# 1. train with train dataset
program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
epoch_id, 'train')
epoch_id, 'train', args.vdl_dir)
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
# 2. validate with validate dataset
if config.validate and epoch_id % config.valid_interval == 0:
......@@ -103,13 +108,15 @@ def main(args):
epoch_id, 'valid')
if top1_acc > best_top1_acc:
best_top1_acc = top1_acc
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(best_top1_acc, epoch_id)
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
best_top1_acc, epoch_id)
logger.info("{:s}".format(logger.coloring(message, "RED")))
if epoch_id % config.save_interval==0:
if epoch_id % config.save_interval == 0:
model_path = os.path.join(config.model_save_dir,
config.ARCHITECTURE["name"])
save_model(train_prog, model_path, "best_model_in_epoch_"+str(epoch_id))
config.ARCHITECTURE["name"])
save_model(train_prog, model_path,
"best_model_in_epoch_" + str(epoch_id))
# 3. save the persistable model
if epoch_id % config.save_interval == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册