提交 5aecef5d 编写于 作者: M Megvii Engine Team

fix(mge/data): fix ReplacementSampler for distributed training

GitOrigin-RevId: 1142941cc09ec1595b92f767aaa9f855829efeeb
上级 e137cb82
......@@ -58,10 +58,9 @@ class MapSampler(Sampler):
"drop_last should be a boolean value, but got "
"drop_last={}".format(drop_last)
)
if num_samples is not None and (
not isinstance(num_samples, int)
or isinstance(num_samples, bool)
or num_samples <= 0
not isinstance(num_samples, int) or num_samples <= 0
):
raise ValueError(
"num_samples should be a positive integer "
......@@ -83,6 +82,14 @@ class MapSampler(Sampler):
num_samples = len(self.dataset)
self.num_samples = int(math.ceil(num_samples / self.world_size))
if self.num_samples < self.batch_size:
raise ValueError(
"num_samples should be greater than batch_size "
", but got num_samples={} and batch_size={}".format(
self.num_samples, self.batch_size
)
)
# Make sure seeds are the same at each rank
if seed is None and self.world_size > 1:
seed = 0
......@@ -297,13 +304,12 @@ class ReplacementSampler(MapSampler):
def sample(self) -> List:
n = len(self.dataset)
indices = self.rng.choice(
n, size=self.num_samples, replace=True, p=self.weights
)
sample_size = self.num_samples * self.world_size
indices = self.rng.choice(n, size=sample_size, replace=True, p=self.weights)
return indices.tolist()
class Infinite(MapSampler):
class Infinite(Sampler):
r"""Infinite Sampler warper for basic sampler."""
def sample(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册