未验证 提交 e87436a5 编写于 作者: K Kaipeng Deng 提交者: GitHub

DistributedBatchSampler add num_replicas and rank. test=develop (#26315)

上级 241b44db
......@@ -49,6 +49,13 @@ class DistributedBatchSampler(BatchSampler):
`__len__` for BatchSampler to get sample
number of data source.
batch_size(int): sample indice number in a mini-batch indices.
num_replicas(int, optional): porcess number in distributed training.
If :attr:`num_replicas` is None, :attr:`num_replicas` will be
retrieved from :code:`paddle.fluid.dygraph.parallel.ParallenEnv`.
Default None.
rank(int, optional): the rank of the current process among :attr:`num_replicas`
processes. If :attr:`rank` is None, :attr:`rank` is retrieved from
:code:`paddle.fluid.dygraph.parallel.ParallenEnv`. Default None.
shuffle(bool): whther to shuffle indices order before genrating
batch indices. Default False.
drop_last(bool): whether drop the last incomplete batch dataset size
......@@ -84,7 +91,13 @@ class DistributedBatchSampler(BatchSampler):
break
"""
def __init__(self, dataset, batch_size, shuffle=False, drop_last=False):
def __init__(self,
dataset,
batch_size,
num_replicas=None,
rank=None,
shuffle=False,
drop_last=False):
self.dataset = dataset
assert isinstance(batch_size, int) and batch_size > 0, \
......@@ -96,9 +109,21 @@ class DistributedBatchSampler(BatchSampler):
assert isinstance(drop_last, bool), \
"drop_last should be a boolean number"
if num_replicas is not None:
assert isinstance(num_replicas, int) and num_replicas > 0, \
"num_replicas should be a positive integer"
self.nranks = num_replicas
else:
self.nranks = ParallelEnv().nranks
if rank is not None:
assert isinstance(rank, int) and rank >= 0, \
"rank should be a non-negative integer"
self.local_rank = rank
else:
self.local_rank = ParallelEnv().local_rank
self.drop_last = drop_last
self.nranks = ParallelEnv().nranks
self.local_rank = ParallelEnv().local_rank
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
self.total_size = self.num_samples * self.nranks
......
......@@ -169,6 +169,12 @@ class TestModel(unittest.TestCase):
def test_fit_static(self):
self.fit(False)
def test_fit_dynamic_with_rank(self):
self.fit(True, 2, 0)
def test_fit_static_with_rank(self):
self.fit(False, 2, 0)
def test_evaluate_dygraph(self):
self.evaluate(True)
......@@ -184,7 +190,7 @@ class TestModel(unittest.TestCase):
def test_prepare_context(self):
prepare_distributed_context()
def fit(self, dynamic):
def fit(self, dynamic, num_replicas=None, rank=None):
fluid.enable_dygraph(self.device) if dynamic else None
seed = 333
fluid.default_startup_program().random_seed = seed
......@@ -204,9 +210,17 @@ class TestModel(unittest.TestCase):
np.testing.assert_allclose(result['acc'], self.acc1)
train_sampler = DistributedBatchSampler(
self.train_dataset, batch_size=64, shuffle=False)
self.train_dataset,
batch_size=64,
shuffle=False,
num_replicas=num_replicas,
rank=rank)
val_sampler = DistributedBatchSampler(
self.val_dataset, batch_size=64, shuffle=False)
self.val_dataset,
batch_size=64,
shuffle=False,
num_replicas=num_replicas,
rank=rank)
train_loader = fluid.io.DataLoader(
self.train_dataset,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册