diff --git a/tools/train.py b/tools/train.py index 7886316b0c417b0944d4d9646030d1baafaa9439..3e45600657245e466f60312a476c3644c4fa5092 100755 --- a/tools/train.py +++ b/tools/train.py @@ -88,20 +88,23 @@ def main(config, device, logger, vdl_writer): best_model_dict, logger, vdl_writer) -def test_reader(config, place, logger): - train_loader = build_dataloader(config['TRAIN'], place) +def test_reader(config, place, logger, global_config): + train_loader, _ = build_dataloader( + config['TRAIN'], place, global_config=global_config) import time starttime = time.time() count = 0 try: - for data in train_loader(): + for data in train_loader: count += 1 if count % 1 == 0: batch_time = time.time() - starttime starttime = time.time() - logger.info("reader: {}, {}, {}".format(count, - len(data), batch_time)) + logger.info("reader: {}, {}, {}".format( + count, len(data[0]), batch_time)) except Exception as e: + import traceback + traceback.print_exc() logger.info(e) logger.info("finish reader: {}, Success!".format(count)) @@ -130,7 +133,7 @@ def dis_main(): device)) main(config, device, logger, vdl_writer) - # test_reader(config, place, logger) + # test_reader(config, device, logger, config['Global']) if __name__ == '__main__':