diff --git a/tools/program.py b/tools/program.py index 33bd35c23d0b6676edfb9c537e497eb68b828c4a..ad6fcbd9b5a1213e7e88ef7c82fde07ff29bcb80 100755 --- a/tools/program.py +++ b/tools/program.py @@ -197,9 +197,11 @@ def train(config, train_reader_cost = 0.0 batch_sum = 0 batch_start = time.time() + max_iter = len(train_dataloader) - 1 if platform.system( + ) == "Windows" else len(train_dataloader) for idx, batch in enumerate(train_dataloader): train_reader_cost += time.time() - batch_start - if idx >= len(train_dataloader): + if idx >= max_iter: break lr = optimizer.get_lr() images = batch[0]