From 1e7e9023d52b7ecf79e306bfbce4865344a577f5 Mon Sep 17 00:00:00 2001
From: Yizhuang Zhou <62599194+zhouyizhuang-megvii@users.noreply.github.com>
Date: Sun, 12 Apr 2020 18:23:59 +0800
Subject: [PATCH] feat(cls/shufflenet) use native infinite sampler (#9)

---
 .../vision/classification/shufflenet/train.py | 15 ++-------------
 1 file changed, 2 insertions(+), 13 deletions(-)

diff --git a/official/vision/classification/shufflenet/train.py b/official/vision/classification/shufflenet/train.py
index 65d07a4..f4a1808 100644
--- a/official/vision/classification/shufflenet/train.py
+++ b/official/vision/classification/shufflenet/train.py
@@ -112,16 +112,6 @@ def get_parameters(model):
     return groups
 
 
-def infinite_iter(loader):
-    iterator = iter(loader)
-    while True:
-        try:
-            yield next(iterator)
-        except StopIteration:
-            iterator = iter(loader)
-            yield next(iterator)
-
-
 def worker(rank, world_size, args):
     if world_size > 1:
         # Initialize distributed process group
@@ -174,9 +164,9 @@ def worker(rank, world_size, args):
     # Build train and valid datasets
     logger.info("preparing dataset..")
     train_dataset = data.dataset.ImageNet(args.data, train=True)
-    train_sampler = data.RandomSampler(
+    train_sampler = data.Infinite(data.RandomSampler(
         train_dataset, batch_size=args.batch_size, drop_last=True
-    )
+    ))
     train_queue = data.DataLoader(
         train_dataset,
         sampler=train_sampler,
@@ -193,7 +183,6 @@ def worker(rank, world_size, args):
         ),
         num_workers=args.workers,
     )
-    train_queue = infinite_iter(train_queue)
 
     valid_dataset = data.dataset.ImageNet(args.data, train=False)
     valid_sampler = data.SequentialSampler(
-- 
GitLab
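
Note: the usage pattern this patch switches to looks roughly like the sketch below. It is a minimal sketch, assuming the megengine.data API exactly as it appears in train.py (Infinite, RandomSampler, DataLoader, dataset.ImageNet); the dataset path, batch size, and step count are illustrative placeholders, not values taken from the patch.

    # Minimal sketch of replacing a hand-rolled infinite generator with
    # megengine.data.Infinite wrapping a finite batch sampler.
    import megengine.data as data

    # Placeholder path and batch size, for illustration only.
    train_dataset = data.dataset.ImageNet("/path/to/imagenet", train=True)

    # Infinite re-seeds the wrapped sampler when it is exhausted, so the
    # DataLoader can be drawn from indefinitely without StopIteration.
    train_sampler = data.Infinite(
        data.RandomSampler(train_dataset, batch_size=128, drop_last=True)
    )
    train_queue = data.DataLoader(train_dataset, sampler=train_sampler)

    train_iter = iter(train_queue)
    for step in range(10):  # stands in for total_steps in a real training loop
        image, label = next(train_iter)  # assuming the dataset yields (image, label)

Compared with the removed infinite_iter generator, the wrapping happens at the sampler level rather than around the DataLoader, so the loader itself stays a plain DataLoader and no extra generator needs to be threaded through the training code.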