未验证 提交 1e7e9023 编写于 作者: Y Yizhuang Zhou 提交者: GitHub

feat(cls/shufflenet) use native infinite sampler (#9)

上级 b5da0a1a
...@@ -112,16 +112,6 @@ def get_parameters(model): ...@@ -112,16 +112,6 @@ def get_parameters(model):
return groups return groups
def infinite_iter(loader):
iterator = iter(loader)
while True:
try:
yield next(iterator)
except StopIteration:
iterator = iter(loader)
yield next(iterator)
def worker(rank, world_size, args): def worker(rank, world_size, args):
if world_size > 1: if world_size > 1:
# Initialize distributed process group # Initialize distributed process group
...@@ -174,9 +164,9 @@ def worker(rank, world_size, args): ...@@ -174,9 +164,9 @@ def worker(rank, world_size, args):
# Build train and valid datasets # Build train and valid datasets
logger.info("preparing dataset..") logger.info("preparing dataset..")
train_dataset = data.dataset.ImageNet(args.data, train=True) train_dataset = data.dataset.ImageNet(args.data, train=True)
train_sampler = data.RandomSampler( train_sampler = data.Infinite(data.RandomSampler(
train_dataset, batch_size=args.batch_size, drop_last=True train_dataset, batch_size=args.batch_size, drop_last=True
) ))
train_queue = data.DataLoader( train_queue = data.DataLoader(
train_dataset, train_dataset,
sampler=train_sampler, sampler=train_sampler,
...@@ -193,7 +183,6 @@ def worker(rank, world_size, args): ...@@ -193,7 +183,6 @@ def worker(rank, world_size, args):
), ),
num_workers=args.workers, num_workers=args.workers,
) )
train_queue = infinite_iter(train_queue)
valid_dataset = data.dataset.ImageNet(args.data, train=False) valid_dataset = data.dataset.ImageNet(args.data, train=False)
valid_sampler = data.SequentialSampler( valid_sampler = data.SequentialSampler(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册