提交 45a9977d 编写于 作者: M Megvii Engine Team

fix(mge/data): fix weighted random sampler

GitOrigin-RevId: d09cbbfffd2434dea3f56b3db2f642b5c236141a
上级 1918ed2c
......@@ -297,10 +297,10 @@ class ReplacementSampler(MapSampler):
def sample(self) -> List:
n = len(self.dataset)
if self.weights is None:
return self.rng.randint(n, size=self.num_samples).tolist()
else:
return self.rng.multinomial(n, self.weights, self.num_samples).tolist()
indices = self.rng.choice(
n, size=self.num_samples, replace=True, p=self.weights
)
return indices.tolist()
class Infinite(MapSampler):
......
......@@ -58,8 +58,34 @@ def test_random_sampler_seed():
def test_ReplacementSampler():
num_samples = 30
indices = list(range(20))
weights = list(range(20))
num_data = 20
indices = list(range(num_data))
sampler = ReplacementSampler(
ArrayDataset(indices), num_samples=num_samples, weights=None
)
assert len(list(each[0] for each in sampler)) == num_samples
num_data = 8
weights = list(range(num_data))
indices = list(range(num_data))
sampler = ReplacementSampler(
ArrayDataset(indices), num_samples=num_samples, weights=weights
)
assert len(list(each[0] for each in sampler)) == num_samples
iter = 1000
hist = [0 for _ in range(num_data)]
for _ in range(iter):
for index in sampler:
index = index[0]
hist[index] += 1
actual_weights = np.array(hist) / sum(hist)
desired_weights = np.array(weights) / sum(weights)
np.testing.assert_allclose(actual_weights, desired_weights, rtol=8e-2)
num_data = 50000
num_samples = 50000 * 30
weights = list(range(num_data))
indices = list(range(num_data))
sampler = ReplacementSampler(
ArrayDataset(indices), num_samples=num_samples, weights=weights
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册