diff --git a/tools/program.py b/tools/program.py index 33bf1adc98fc48289e9873d073501d76547c011d..41acb8665a345d1251fbf721e557212cd2771e04 100755 --- a/tools/program.py +++ b/tools/program.py @@ -352,7 +352,8 @@ def preprocess(): with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f: yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False) - logger = get_logger(log_file='{}/train.log'.format(save_model_dir)) + logger = get_logger( + name='root', log_file='{}/train.log'.format(save_model_dir)) if config['Global']['use_visualdl']: from visualdl import LogWriter vdl_writer_path = '{}/vdl/'.format(save_model_dir) diff --git a/tools/train.py b/tools/train.py index bdba7dba772891d740baf03bc0adb963ae661332..32abe0fc76a585a00daba6709b5e819952fcd466 100755 --- a/tools/train.py +++ b/tools/train.py @@ -36,7 +36,6 @@ from ppocr.optimizer import build_optimizer from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric from ppocr.utils.save_load import init_model -from ppocr.utils.utility import print_dict import tools.program as program dist.get_world_size() @@ -61,7 +60,7 @@ def main(config, device, logger, vdl_writer): global_config) # build model - #for rec algorithm + # for rec algorithm if hasattr(post_process_class, 'character'): char_num = len(getattr(post_process_class, 'character')) config['Architecture']["Head"]['out_channels'] = char_num @@ -81,10 +80,11 @@ def main(config, device, logger, vdl_writer): # build metric eval_class = build_metric(config['Metric']) - # load pretrain model pre_best_model_dict = init_model(config, model, logger, optimizer) + logger.info('train dataloader has {} iter, valid dataloader has {} iter'. + format(len(train_dataloader), len(valid_dataloader))) # start train program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, @@ -92,8 +92,7 @@ def main(config, device, logger, vdl_writer): def test_reader(config, device, logger): - loader = build_dataloader(config, 'Train', device) - # loader = build_dataloader(config, 'Eval', device) + loader = build_dataloader(config, 'Train', device, logger) import time starttime = time.time() count = 0