# test_distributed_sampler.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import unittest

from paddle.incubate.hapi.distributed import DistributedBatchSampler


class FakeDataset(object):
    """Minimal in-memory dataset stub with a fixed length of 10.

    Each item is simply the index it was requested with, so no real data
    is needed to drive a sampler.
    """

    def __len__(self):
        """Return the fixed dataset size."""
        return 10

    def __getitem__(self, index):
        """Return ``index`` unchanged as the sample."""
        return index

class TestDistributedBatchSampler(unittest.TestCase):
    """Smoke tests that iterate ``DistributedBatchSampler`` to exhaustion.

    No assertions are made on the yielded batches; a test passes as long as
    constructing the sampler and iterating over it raise nothing.
    """

    @staticmethod
    def _override_rank(sampler, dataset, nranks, local_rank):
        """Set the rank attributes on ``sampler`` and recompute its sizes.

        ``num_samples`` becomes the ceiling of ``len(dataset) / nranks`` and
        ``total_size`` becomes ``num_samples * nranks``.
        """
        sampler.nranks = nranks
        sampler.local_rank = local_rank
        per_rank = len(dataset) * 1.0 / sampler.nranks
        sampler.num_samples = int(math.ceil(per_rank))
        sampler.total_size = sampler.num_samples * sampler.nranks

    @staticmethod
    def _drain(sampler):
        """Iterate ``sampler`` until it is exhausted, discarding batches."""
        for _batch in sampler:
            pass

    def test_sampler(self):
        """Iterate a shuffled sampler with batch size 1 over the fake data."""
        data = FakeDataset()
        self._drain(DistributedBatchSampler(data, batch_size=1, shuffle=True))

    def test_multiple_gpus_sampler(self):
        """Iterate two samplers whose attributes are overridden to ranks 0/1 of 2."""
        data = FakeDataset()
        options = dict(batch_size=4, shuffle=True, drop_last=True)
        rank0 = DistributedBatchSampler(data, **options)
        rank1 = DistributedBatchSampler(data, **options)

        # Both samplers are built first, then patched, then iterated in rank
        # order, so the sequence of operations stays fixed.
        self._override_rank(rank0, data, nranks=2, local_rank=0)
        self._override_rank(rank1, data, nranks=2, local_rank=1)

        self._drain(rank0)
        self._drain(rank1)


# Run the unittest runner when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()