From cf40ed6f1fb62466bda03cac657981be647b9c2f Mon Sep 17 00:00:00 2001 From: shippingwang Date: Fri, 29 May 2020 13:52:43 +0000 Subject: [PATCH] fixed --- ppcls/utils/logger.py | 6 ++---- tools/program.py | 6 +++--- tools/train.py | 5 ++++- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/ppcls/utils/logger.py b/ppcls/utils/logger.py index 5e2c2b58..7dda65e6 100644 --- a/ppcls/utils/logger.py +++ b/ppcls/utils/logger.py @@ -81,10 +81,8 @@ def error(fmt, *args): _logger.error(coloring(fmt, "FAIL"), *args) -def scaler(name, value, step, path): - from visualdl import LogWriter - vdl_writer = LogWriter(path) - vdl_writer.add_scalar(name, value, step) +def scaler(name, value, step, writer): + writer.add_scalar(name, value, step) def advertise(): diff --git a/tools/program.py b/tools/program.py index 251d7f81..2428409e 100644 --- a/tools/program.py +++ b/tools/program.py @@ -387,7 +387,7 @@ def compile(config, program, loss_name=None): total_step = 0 -def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_dir=None): +def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_writer=None): """ Feed data to the model and fetch the measures and loss @@ -415,9 +415,9 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_dir=None): metric_list[i].update(m[0], len(batch[0])) fetchs_str = ''.join([str(m.value) + ' ' for m in metric_list] + [batch_time.value]) + 's' - if vdl_dir: + if vdl_writer: global total_step - logger.scaler('loss', metrics[0][0], total_step, vdl_dir) + logger.scaler('loss', metrics[0][0], total_step, vdl_writer) total_step += 1 if mode == 'eval': logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str)) diff --git a/tools/train.py b/tools/train.py index 57832389..922d871f 100644 --- a/tools/train.py +++ b/tools/train.py @@ -19,6 +19,7 @@ from __future__ import print_function import argparse import os +from visualdl import LogWriter import paddle.fluid as fluid from paddle.fluid.incubate.fleet.base import role_maker from paddle.fluid.incubate.fleet.collective import fleet @@ -96,10 +97,12 @@ def main(args): compiled_valid_prog = program.compile(config, valid_prog) compiled_train_prog = fleet.main_program + vdl_writer = LogWriter(args.vdl_dir) if args.vdl_dir else None + for epoch_id in range(config.epochs): # 1. train with train dataset program.run(train_dataloader, exe, compiled_train_prog, train_fetchs, - epoch_id, 'train', args.vdl_dir) + epoch_id, 'train', vdl_writer) if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0: # 2. validate with validate dataset if config.validate and epoch_id % config.valid_interval == 0: -- GitLab