diff --git a/official/vision/classification/shufflenet/train.py b/official/vision/classification/shufflenet/train.py index 65d07a42cc49dc1fd94ba7fefa8a7705247b69cf..f4a18081998ecff7a60837446850f16dbcab0fe1 100644 --- a/official/vision/classification/shufflenet/train.py +++ b/official/vision/classification/shufflenet/train.py @@ -112,16 +112,6 @@ def get_parameters(model): return groups -def infinite_iter(loader): - iterator = iter(loader) - while True: - try: - yield next(iterator) - except StopIteration: - iterator = iter(loader) - yield next(iterator) - - def worker(rank, world_size, args): if world_size > 1: # Initialize distributed process group @@ -174,9 +164,9 @@ def worker(rank, world_size, args): # Build train and valid datasets logger.info("preparing dataset..") train_dataset = data.dataset.ImageNet(args.data, train=True) - train_sampler = data.RandomSampler( + train_sampler = data.Infinite(data.RandomSampler( train_dataset, batch_size=args.batch_size, drop_last=True - ) + )) train_queue = data.DataLoader( train_dataset, sampler=train_sampler, @@ -193,7 +183,6 @@ def worker(rank, world_size, args): ), num_workers=args.workers, ) - train_queue = infinite_iter(train_queue) valid_dataset = data.dataset.ImageNet(args.data, train=False) valid_sampler = data.SequentialSampler(