未验证 提交 f3b2e8aa 编写于 作者: R ruri 提交者: GitHub

Merge pull request #137 from shippingwang/add_visualdl

add visualdl
...@@ -19,9 +19,10 @@ import datetime ...@@ -19,9 +19,10 @@ import datetime
from imp import reload from imp import reload
reload(logging) reload(logging)
logging.basicConfig(level=logging.INFO, logging.basicConfig(
format="%(asctime)s %(levelname)s: %(message)s", level=logging.INFO,
datefmt = "%Y-%m-%d %H:%M:%S") format="%(asctime)s %(levelname)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S")
def time_zone(sec, fmt): def time_zone(sec, fmt):
...@@ -32,22 +33,22 @@ def time_zone(sec, fmt): ...@@ -32,22 +33,22 @@ def time_zone(sec, fmt):
logging.Formatter.converter = time_zone logging.Formatter.converter = time_zone
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
Color = {
Color= { 'RED': '\033[31m',
'RED' : '\033[31m' , 'HEADER': '\033[35m', # deep purple
'HEADER' : '\033[35m' , # deep purple 'PURPLE': '\033[95m', # purple
'PURPLE' : '\033[95m' ,# purple 'OKBLUE': '\033[94m',
'OKBLUE' : '\033[94m' , 'OKGREEN': '\033[92m',
'OKGREEN' : '\033[92m' , 'WARNING': '\033[93m',
'WARNING' : '\033[93m' , 'FAIL': '\033[91m',
'FAIL' : '\033[91m' , 'ENDC': '\033[0m'
'ENDC' : '\033[0m' } }
def coloring(message, color="OKGREEN"): def coloring(message, color="OKGREEN"):
assert color in Color.keys() assert color in Color.keys()
if os.environ.get('PADDLECLAS_COLORING', False): if os.environ.get('PADDLECLAS_COLORING', False):
return Color[color]+str(message)+Color["ENDC"] return Color[color] + str(message) + Color["ENDC"]
else: else:
return message return message
...@@ -80,6 +81,10 @@ def error(fmt, *args): ...@@ -80,6 +81,10 @@ def error(fmt, *args):
_logger.error(coloring(fmt, "FAIL"), *args) _logger.error(coloring(fmt, "FAIL"), *args)
def scaler(name, value, step, writer):
writer.add_scalar(name, value, step)
def advertise(): def advertise():
""" """
Show the advertising message like the following: Show the advertising message like the following:
...@@ -99,12 +104,13 @@ def advertise(): ...@@ -99,12 +104,13 @@ def advertise():
website = "https://github.com/PaddlePaddle/PaddleClas" website = "https://github.com/PaddlePaddle/PaddleClas"
AD_LEN = 6 + len(max([copyright, ad, website], key=len)) AD_LEN = 6 + len(max([copyright, ad, website], key=len))
info(coloring("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format( info(
"=" * (AD_LEN + 4), coloring("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
"=={}==".format(copyright.center(AD_LEN)), "=" * (AD_LEN + 4),
"=" * (AD_LEN + 4), "=={}==".format(copyright.center(AD_LEN)),
"=={}==".format(' ' * AD_LEN), "=" * (AD_LEN + 4),
"=={}==".format(ad.center(AD_LEN)), "=={}==".format(' ' * AD_LEN),
"=={}==".format(' ' * AD_LEN), "=={}==".format(ad.center(AD_LEN)),
"=={}==".format(website.center(AD_LEN)), "=={}==".format(' ' * AD_LEN),
"=" * (AD_LEN + 4), ),"RED")) "=={}==".format(website.center(AD_LEN)),
"=" * (AD_LEN + 4), ), "RED"))
...@@ -384,7 +384,10 @@ def compile(config, program, loss_name=None): ...@@ -384,7 +384,10 @@ def compile(config, program, loss_name=None):
return compiled_program return compiled_program
def run(dataloader, exe, program, fetchs, epoch=0, mode='train'): total_step = 0
def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_writer=None):
""" """
Feed data to the model and fetch the measures and loss Feed data to the model and fetch the measures and loss
...@@ -412,6 +415,10 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'): ...@@ -412,6 +415,10 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
metric_list[i].update(m[0], len(batch[0])) metric_list[i].update(m[0], len(batch[0]))
fetchs_str = ''.join([str(m.value) + ' ' fetchs_str = ''.join([str(m.value) + ' '
for m in metric_list] + [batch_time.value]) + 's' for m in metric_list] + [batch_time.value]) + 's'
if vdl_writer:
global total_step
logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
total_step += 1
if mode == 'eval': if mode == 'eval':
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str)) logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
else: else:
......
...@@ -19,6 +19,7 @@ from __future__ import print_function ...@@ -19,6 +19,7 @@ from __future__ import print_function
import argparse import argparse
import os import os
from visualdl import LogWriter
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.base import role_maker from paddle.fluid.incubate.fleet.base import role_maker
from paddle.fluid.incubate.fleet.collective import fleet from paddle.fluid.incubate.fleet.collective import fleet
...@@ -38,6 +39,11 @@ def parse_args(): ...@@ -38,6 +39,11 @@ def parse_args():
type=str, type=str,
default='configs/ResNet/ResNet50.yaml', default='configs/ResNet/ResNet50.yaml',
help='config file path') help='config file path')
parser.add_argument(
'--vdl_dir',
type=str,
default=None,
help='VisualDL logging directory for image.')
parser.add_argument( parser.add_argument(
'-o', '-o',
'--override', '--override',
...@@ -91,10 +97,12 @@ def main(args): ...@@ -91,10 +97,12 @@ def main(args):
compiled_valid_prog = program.compile(config, valid_prog) compiled_valid_prog = program.compile(config, valid_prog)
compiled_train_prog = fleet.main_program compiled_train_prog = fleet.main_program
vdl_writer = LogWriter(args.vdl_dir) if args.vdl_dir else None
for epoch_id in range(config.epochs): for epoch_id in range(config.epochs):
# 1. train with train dataset # 1. train with train dataset
program.run(train_dataloader, exe, compiled_train_prog, train_fetchs, program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
epoch_id, 'train') epoch_id, 'train', vdl_writer)
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0: if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
# 2. validate with validate dataset # 2. validate with validate dataset
if config.validate and epoch_id % config.valid_interval == 0: if config.validate and epoch_id % config.valid_interval == 0:
...@@ -103,13 +111,15 @@ def main(args): ...@@ -103,13 +111,15 @@ def main(args):
epoch_id, 'valid') epoch_id, 'valid')
if top1_acc > best_top1_acc: if top1_acc > best_top1_acc:
best_top1_acc = top1_acc best_top1_acc = top1_acc
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(best_top1_acc, epoch_id) message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
best_top1_acc, epoch_id)
logger.info("{:s}".format(logger.coloring(message, "RED"))) logger.info("{:s}".format(logger.coloring(message, "RED")))
if epoch_id % config.save_interval==0: if epoch_id % config.save_interval == 0:
model_path = os.path.join(config.model_save_dir, model_path = os.path.join(config.model_save_dir,
config.ARCHITECTURE["name"]) config.ARCHITECTURE["name"])
save_model(train_prog, model_path, "best_model_in_epoch_"+str(epoch_id)) save_model(train_prog, model_path,
"best_model_in_epoch_" + str(epoch_id))
# 3. save the persistable model # 3. save the persistable model
if epoch_id % config.save_interval == 0: if epoch_id % config.save_interval == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册