提交 bc581d59 编写于 作者: M Megvii Engine Team

fix(mge/data): fix ReplacementSampler for distributed training

GitOrigin-RevId: 1142941cc09ec1595b92f767aaa9f855829efeeb
上级 c946a21f
...@@ -58,10 +58,9 @@ class MapSampler(Sampler): ...@@ -58,10 +58,9 @@ class MapSampler(Sampler):
"drop_last should be a boolean value, but got " "drop_last should be a boolean value, but got "
"drop_last={}".format(drop_last) "drop_last={}".format(drop_last)
) )
if num_samples is not None and ( if num_samples is not None and (
not isinstance(num_samples, int) not isinstance(num_samples, int) or num_samples <= 0
or isinstance(num_samples, bool)
or num_samples <= 0
): ):
raise ValueError( raise ValueError(
"num_samples should be a positive integer " "num_samples should be a positive integer "
...@@ -83,6 +82,14 @@ class MapSampler(Sampler): ...@@ -83,6 +82,14 @@ class MapSampler(Sampler):
num_samples = len(self.dataset) num_samples = len(self.dataset)
self.num_samples = int(math.ceil(num_samples / self.world_size)) self.num_samples = int(math.ceil(num_samples / self.world_size))
if self.num_samples < self.batch_size:
raise ValueError(
"num_samples should be greater than batch_size "
", but got num_samples={} and batch_size={}".format(
self.num_samples, self.batch_size
)
)
# Make sure seeds are the same at each rank # Make sure seeds are the same at each rank
if seed is None and self.world_size > 1: if seed is None and self.world_size > 1:
seed = 0 seed = 0
...@@ -297,13 +304,12 @@ class ReplacementSampler(MapSampler): ...@@ -297,13 +304,12 @@ class ReplacementSampler(MapSampler):
def sample(self) -> List: def sample(self) -> List:
n = len(self.dataset) n = len(self.dataset)
indices = self.rng.choice( sample_size = self.num_samples * self.world_size
n, size=self.num_samples, replace=True, p=self.weights indices = self.rng.choice(n, size=sample_size, replace=True, p=self.weights)
)
return indices.tolist() return indices.tolist()
class Infinite(MapSampler): class Infinite(Sampler):
r"""Infinite Sampler warper for basic sampler.""" r"""Infinite Sampler warper for basic sampler."""
def sample(self): def sample(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册