From 91dee973dac6fbb00fb91c5decc4e390bfa2b2fa Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Fri, 6 Nov 2020 19:11:35 +0800 Subject: [PATCH] change log name to root --- tools/program.py | 3 ++- tools/train.py | 9 ++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tools/program.py b/tools/program.py index 33bf1adc..41acb866 100755 --- a/tools/program.py +++ b/tools/program.py @@ -352,7 +352,8 @@ def preprocess(): with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f: yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False) - logger = get_logger(log_file='{}/train.log'.format(save_model_dir)) + logger = get_logger( + name='root', log_file='{}/train.log'.format(save_model_dir)) if config['Global']['use_visualdl']: from visualdl import LogWriter vdl_writer_path = '{}/vdl/'.format(save_model_dir) diff --git a/tools/train.py b/tools/train.py index bdba7dba..32abe0fc 100755 --- a/tools/train.py +++ b/tools/train.py @@ -36,7 +36,6 @@ from ppocr.optimizer import build_optimizer from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric from ppocr.utils.save_load import init_model -from ppocr.utils.utility import print_dict import tools.program as program dist.get_world_size() @@ -61,7 +60,7 @@ def main(config, device, logger, vdl_writer): global_config) # build model - #for rec algorithm + # for rec algorithm if hasattr(post_process_class, 'character'): char_num = len(getattr(post_process_class, 'character')) config['Architecture']["Head"]['out_channels'] = char_num @@ -81,10 +80,11 @@ def main(config, device, logger, vdl_writer): # build metric eval_class = build_metric(config['Metric']) - # load pretrain model pre_best_model_dict = init_model(config, model, logger, optimizer) + logger.info('train dataloader has {} iter, valid dataloader has {} iter'. + format(len(train_dataloader), len(valid_dataloader))) # start train program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, @@ -92,8 +92,7 @@ def main(config, device, logger, vdl_writer): def test_reader(config, device, logger): - loader = build_dataloader(config, 'Train', device) - # loader = build_dataloader(config, 'Eval', device) + loader = build_dataloader(config, 'Train', device, logger) import time starttime = time.time() count = 0 -- GitLab