diff --git a/tools/program.py b/tools/program.py index db8e44df360d8d255406a13ae697f598d7a96a3a..7e54a2f8c2f1db8881aa476a309c8a8c563fcae5 100755 --- a/tools/program.py +++ b/tools/program.py @@ -199,8 +199,12 @@ def train(config, train_reader_cost = 0.0 batch_sum = 0 batch_start = time.time() - for idx, batch in enumerate(train_dataloader()): + max_iter = len(train_dataloader) - 1 if platform.system( + ) == "Windows" else len(train_dataloader) + for idx, batch in enumerate(train_dataloader): train_reader_cost += time.time() - batch_start + if idx >= max_iter: + break lr = optimizer.get_lr() images = batch[0] if use_srn: