diff --git a/tools/train.py b/tools/train.py index fab10b6437cfe5ad30bcb49d7b1b884ba9665a79..c12cf005638ad8be8f8b8fce42e682fed2f29be7 100755 --- a/tools/train.py +++ b/tools/train.py @@ -89,8 +89,10 @@ def main(config, device, logger, vdl_writer): # load pretrain model pre_best_model_dict = init_model(config, model, logger, optimizer) - logger.info('train dataloader has {} iters, valid dataloader has {} iters'. - format(len(train_dataloader), len(valid_dataloader))) + logger.info('train dataloader has {} iters'.format(len(train_dataloader))) + if valid_dataloader is not None: + logger.info('valid dataloader has {} iters'.format( + len(valid_dataloader))) # start train program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class,