提交 7d83a9ad 编写于 作者: M Megvii Engine Team

fix(imperative): infinite sampler support get batchsize

GitOrigin-RevId: 52e3d6524932e74432ad5c989cab3f7fddd9192f
上级 fa488338
......@@ -326,3 +326,10 @@ class Infinite(MapSampler):
def __len__(self):
return np.iinfo(np.int64).max
def __getattr__(self, name):
# if attribute could not be found in Infinite,
# try to find it in self.sampler
if name not in self.__dict__:
return getattr(self.sampler, name)
return self.__dict__[name]
......@@ -7,7 +7,12 @@ import numpy as np
import pytest
from megengine.data.dataset import ArrayDataset
from megengine.data.sampler import RandomSampler, ReplacementSampler, SequentialSampler
from megengine.data.sampler import (
Infinite,
RandomSampler,
ReplacementSampler,
SequentialSampler,
)
def test_sequential_sampler():
......@@ -25,6 +30,13 @@ def test_RandomSampler():
assert indices == sorted(list(each[0] for each in sample_indices))
def test_InfiniteSampler():
indices = list(range(20))
seque_sampler = SequentialSampler(ArrayDataset(indices), batch_size=2)
inf_sampler = Infinite(seque_sampler)
assert inf_sampler.batch_size == seque_sampler.batch_size
def test_random_sampler_seed():
seed = [0, 1]
indices = list(range(20))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册