test_sampler.py 6.9 KB
Newer Older
Z
zhunaipan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger
17
import numpy as np
Z
zhunaipan 已提交
18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90


# test5trainimgs.json contains 5 images whose un-decoded shapes are [172876, 54214, 173673, 54214, 64631]
# and whose labels are [0, 0, 0, 1, 1]. Each image can be uniquely identified by its (shape[0], label)
# pair via the lookup table used in the tests below:
# {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

def test_sequential_sampler(print_res=False):
    """Check that SequentialSampler yields rows in dataset order.

    Rows are identified by the (first dimension of the un-decoded image, label)
    pair through ``lookup``; the expected index lists assume the 5-row manifest.

    Args:
        print_res: if True, log the list of row indices produced by each config.
    """
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    # (image.shape[0], label) -> row index; named so it does not shadow builtin map().
    lookup = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    def test_config(num_samples, num_repeats=None):
        """Run a sequential pass (optionally repeated) and return the row indices seen."""
        sampler = ds.SequentialSampler()
        data1 = ds.ManifestDataset(manifest_file, num_samples=num_samples, sampler=sampler)
        if num_repeats is not None:
            data1 = data1.repeat(num_repeats)
        res = []
        for item in data1.create_dict_iterator():
            logger.info("item[image].shape[0]: {}, item[label].item(): {}"
                        .format(item["image"].shape[0], item["label"].item()))
            res.append(lookup[(item["image"].shape[0], item["label"].item())])
        if print_res:
            logger.info("image.shapes and labels: {}".format(res))
        return res

    # num_samples truncates each epoch to its first rows; repeat() replays that epoch.
    assert test_config(num_samples=3, num_repeats=None) == [0, 1, 2]
    assert test_config(num_samples=None, num_repeats=2) == [0, 1, 2, 3, 4] * 2
    assert test_config(num_samples=4, num_repeats=2) == [0, 1, 2, 3] * 2


def test_random_sampler(print_res=False):
    """Check RandomSampler reshuffling across epochs and the ``replacement`` flag.

    The assertions are probabilistic: they rely on the sampler producing
    different draws across epochs and are expected, not guaranteed, to pass.

    Args:
        print_res: if True, log the list of row indices produced by each config.
    """
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    # (image.shape[0], label) -> row index; named so it does not shadow builtin map().
    lookup = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    def test_config(replacement, num_samples, num_repeats):
        """Draw num_repeats epochs from a RandomSampler and return the row indices seen."""
        sampler = ds.RandomSampler(replacement=replacement, num_samples=num_samples)
        data1 = ds.ManifestDataset(manifest_file, sampler=sampler)
        data1 = data1.repeat(num_repeats)
        res = []
        for item in data1.create_dict_iterator():
            res.append(lookup[(item["image"].shape[0], item["label"].item())])
        if print_res:
            logger.info("image.shapes and labels: {}".format(res))
        return res

    # Each epoch draws only 2 rows, so seeing more than 2 distinct rows over
    # 6 epochs shows that epochs are not all identical.
    assert len(set(test_config(replacement=False, num_samples=2, num_repeats=6))) > 2
    # Without replacement every epoch is a permutation of the 5 rows, so 4 epochs
    # contain each row exactly 4 times. With replacement this fails only if all 4
    # epochs happen to be permutations (~(5!/5**5)**4 for independent uniform draws).
    ordered_res = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4]
    assert sorted(test_config(replacement=False, num_samples=None, num_repeats=4)) == ordered_res
    assert sorted(test_config(replacement=True, num_samples=None, num_repeats=4)) != ordered_res


def test_random_sampler_multi_iter(print_res=False):
    """Check that a RandomSampler with replacement does not always yield a permutation.

    The same dataset is iterated with fresh iterators up to ``num_repeats`` times.
    The check succeeds as soon as one pass, once sorted, differs from ``validate``,
    meaning some row was drawn more than once.

    Args:
        print_res: if True, log the row indices produced by each pass.
    """
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    # (image.shape[0], label) -> row index; named so it does not shadow builtin map().
    lookup = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    def test_config(replacement, num_samples, num_repeats, validate):
        """Fail unless some pass out of num_repeats differs from ``validate`` once sorted."""
        sampler = ds.RandomSampler(replacement=replacement, num_samples=num_samples)
        data1 = ds.ManifestDataset(manifest_file, sampler=sampler)
        for _ in range(num_repeats):
            res = [lookup[(item["image"].shape[0], item["label"].item())]
                   for item in data1.create_dict_iterator()]
            if print_res:
                logger.info("image.shapes and labels: {}".format(res))
            if sorted(res) != validate:
                break
        else:
            raise AssertionError("every pass of {} samples matched {}".format(num_samples, validate))

    # validate must have exactly num_samples entries: a list of any other length can
    # never equal sorted(res), which would make the check pass vacuously on the first pass.
    # NOTE(review): with independent uniform draws a single pass is a permutation with
    # probability 5!/5**5 (~3.8%); if the sampler reused one fixed seed for every fresh
    # iterator, all passes would be identical — confirm seeding if this ever fails.
    test_config(replacement=True, num_samples=5, num_repeats=5, validate=[0, 1, 2, 3, 4])


91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
def test_sampler_py_api():
    """Smoke-test the API of the object returned by each built-in sampler's ``create()``.

    Each sampler is configured for 128 rows and 64 samples, initialized, and
    asked for its indices. The test passes if none of these calls raises.
    """
    sampler_factories = (
        ds.SequentialSampler,
        ds.RandomSampler,
        lambda: ds.DistributedSampler(8, 4),
    )
    for build in sampler_factories:
        created = build().create()
        created.set_num_rows(128)
        created.set_num_samples(64)
        created.initialize()
        created.get_indices()


111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
def test_python_sampler():
    """Exercise user-defined Python samplers (subclasses of ``ds.Sampler``).

    Covers the following cases:

    * a plain in-order index sampler;
    * a stateful sampler whose output changes on each ``reset()``;
    * a generator-based sampler driving ``GeneratorDataset``;
    * the ``create()`` API of a Python sampler.
    """
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    # (image.shape[0], label) -> row index; named so it does not shadow builtin map().
    lookup = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    class Sp1(ds.Sampler):
        """Yield every row index in order: 0, 1, ..., dataset_size - 1."""

        def __iter__(self):
            return iter(range(self.dataset_size))

    class Sp2(ds.Sampler):
        """Yield one index num_samples times; the index advances on every reset()."""

        def __init__(self):
            super().__init__()
            # at this stage, self.dataset_size and self.num_samples are not yet known
            self.cnt = 0

        def __iter__(self):
            # first epoch all 0, second epoch all 1, third all 2, ... (the expected
            # output below relies on reset() being invoked between epochs)
            return iter([self.cnt] * self.num_samples)

        def reset(self):
            self.cnt = (self.cnt + 1) % self.dataset_size

    def test_config(num_samples, num_repeats, sampler):
        """Iterate a ManifestDataset driven by ``sampler`` and return the row indices seen."""
        data1 = ds.ManifestDataset(manifest_file, num_samples=num_samples, sampler=sampler)
        if num_repeats is not None:
            data1 = data1.repeat(num_repeats)
        res = []
        for item in data1.create_dict_iterator():
            logger.info("item[image].shape[0]: {}, item[label].item(): {}"
                        .format(item["image"].shape[0], item["label"].item()))
            res.append(lookup[(item["image"].shape[0], item["label"].item())])
        return res

    def test_generator():
        """A sampler yielding 99..0 must make GeneratorDataset emit rows in that order."""

        class MySampler(ds.Sampler):
            def __iter__(self):
                yield from range(99, -1, -1)

        data1 = ds.GeneratorDataset([(np.array(i),) for i in range(100)], ["data"], sampler=MySampler())
        expected = 99
        for data in data1:
            assert data[0] == (np.array(expected),)
            expected -= 1
        # every one of the 100 rows must have been produced, not just a prefix
        assert expected == -1

    assert test_config(5, 2, Sp1()) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    # cnt wraps modulo dataset_size (5 rows), so the sixth epoch is back to index 0
    assert test_config(2, 6, Sp2()) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0]
    test_generator()

    sp1 = Sp1().create()
    sp1.set_num_rows(5)
    sp1.set_num_samples(5)
    sp1.initialize()
    assert list(sp1.get_indices()) == [0, 1, 2, 3, 4]


Z
zhunaipan 已提交
166 167 168 169
if __name__ == '__main__':
    # Run every test directly; passing True enables logging of the sampled row indices.
    test_sequential_sampler(True)
    test_random_sampler(True)
    test_random_sampler_multi_iter(True)
    test_sampler_py_api()
    test_python_sampler()