test_sampler.py 9.6 KB
Newer Older
Z
zhunaipan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
J
jinyaohui 已提交
15
import numpy as np
16
import pytest
J
jinyaohui 已提交
17

Z
zhunaipan 已提交
18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
import mindspore.dataset as ds
from mindspore import log as logger


# test5trainimgs.json contains 5 images whose un-decoded sizes are [172876, 54214, 54214, 173673, 64631]
# and whose labels are [0, 0, 1, 0, 1], so each image can be uniquely identified by its (size, label) pair
# via the following lookup table (dict): {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

def test_sequential_sampler(print_res=False):
    """Check the row order produced by a SequentialSampler on the manifest dataset.

    Three configurations are asserted: a truncated read (num_samples only),
    a repeated full read (num_repeats only), and a truncated read that is repeated.

    Args:
        print_res: when True, the list of decoded indices of each run is logged.
    """
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    # (un-decoded image size, label) -> index the assertions below expect for that image
    index_of = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    def test_config(num_samples, num_repeats=None):
        dataset = ds.ManifestDataset(manifest_file, num_samples=num_samples, sampler=ds.SequentialSampler())
        if num_repeats is not None:
            dataset = dataset.repeat(num_repeats)
        indices = []
        for row in dataset.create_dict_iterator():
            size = row["image"].shape[0]
            label = row["label"].item()
            logger.info("item[image].shape[0]: {}, item[label].item(): {}"
                        .format(size, label))
            indices.append(index_of[(size, label)])
        if print_res:
            logger.info("image.shapes and labels: {}".format(indices))
        return indices

    first_three = test_config(num_samples=3, num_repeats=None)
    assert first_three == [0, 1, 2]

    all_rows_twice = test_config(num_samples=None, num_repeats=2)
    assert all_rows_twice == [0, 1, 2, 3, 4] * 2

    first_four_twice = test_config(num_samples=4, num_repeats=2)
    assert first_four_twice == [0, 1, 2, 3] * 2


def test_random_sampler(print_res=False):
    """Check RandomSampler with and without replacement on the manifest dataset.

    Args:
        print_res: when True, the list of decoded indices of each run is logged.
    """
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    # (un-decoded image size, label) -> index used to identify the image
    index_of = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    def test_config(replacement, num_samples, num_repeats):
        random_sampler = ds.RandomSampler(replacement=replacement, num_samples=num_samples)
        dataset = ds.ManifestDataset(manifest_file, sampler=random_sampler).repeat(num_repeats)
        indices = [index_of[(row["image"].shape[0], row["label"].item())]
                   for row in dataset.create_dict_iterator()]
        if print_res:
            logger.info("image.shapes and labels: {}".format(indices))
        return indices

    # two samples per epoch over six epochs: seeing more than two distinct rows
    # shows that an epoch COULD return different samples than the previous one
    distinct_rows = set(test_config(replacement=False, num_samples=2, num_repeats=6))
    assert len(distinct_rows) > 2

    # the following two checks test that replacement works: without replacement
    # four epochs contain every row exactly four times, with replacement they do not
    ordered_res = [idx for idx in range(5) for _ in range(4)]
    without_replacement = test_config(replacement=False, num_samples=None, num_repeats=4)
    assert sorted(without_replacement) == ordered_res
    with_replacement = test_config(replacement=True, num_samples=None, num_repeats=4)
    assert sorted(with_replacement) != ordered_res


def test_random_sampler_multi_iter(print_res=False):
    """Check that sampling with replacement shows up when one dataset is iterated repeatedly.

    The same dataset object is iterated up to ``num_repeats`` times. The loop stops at the
    first pass whose sorted indices differ from ``validate``; the test passes only if such
    a pass occurred before the repeat budget ran out.

    Args:
        print_res: when True, the list of decoded indices of each pass is logged.
    """
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    # (un-decoded image size, label) -> index used to identify the image
    index_of = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    def test_config(replacement, num_samples, num_repeats, validate):
        sampler = ds.RandomSampler(replacement=replacement, num_samples=num_samples)
        data1 = ds.ManifestDataset(manifest_file, sampler=sampler)
        while num_repeats > 0:
            res = [index_of[(item["image"].shape[0], item["label"].item())]
                   for item in data1.create_dict_iterator()]
            if print_res:
                logger.info("image.shapes and labels: {}".format(res))
            if validate != sorted(res):
                break
            num_repeats -= 1
        # num_repeats only reaches 0 when every pass matched `validate`
        assert num_repeats > 0, "every pass returned exactly the rows in `validate`"

    # `validate` must have as many entries as samples drawn (5). The previous six-entry
    # list could never equal a five-element result, which made the check vacuous.
    # NOTE(review): assuming uniform draws, five samples with replacement form a full
    # permutation with probability 5!/5**5 (~3.8%), so five matching passes in a row
    # are very unlikely — confirm against the sampler's seeding behaviour.
    test_config(replacement=True, num_samples=5, num_repeats=5, validate=[0, 1, 2, 3, 4])


93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
def test_sampler_py_api():
    """Smoke-test the objects returned by ``create()`` on three built-in samplers.

    For a SequentialSampler, a RandomSampler and a DistributedSampler(8, 4) the created
    object is configured with 128 rows and 64 samples, initialized, and asked for its
    indices. Nothing is asserted on the returned indices; the calls only have to succeed.
    """

    def exercise(created):
        created.set_num_rows(128)
        created.set_num_samples(64)
        created.initialize()
        created.get_indices()

    exercise(ds.SequentialSampler().create())
    exercise(ds.RandomSampler().create())
    exercise(ds.DistributedSampler(8, 4).create())


113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
def test_python_sampler():
    """Check samplers written in Python by subclassing ``ds.Sampler``.

    Covers a sampler over ManifestDataset that walks all rows in order, a stateful sampler
    whose ``reset`` changes the emitted index, a reverse-order sampler over
    GeneratorDataset, and the ``create()`` / ``get_indices()`` path of a Python sampler.
    """
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    # (un-decoded image size, label) -> index used to identify the image
    index_of = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    class Sp1(ds.Sampler):
        def __iter__(self):
            # every row index, ascending
            return iter(range(self.dataset_size))

    class Sp2(ds.Sampler):
        def __init__(self):
            super(Sp2, self).__init__()
            # at this stage, self.dataset_size and self.num_samples are not yet known
            self.cnt = 0

        def __iter__(self):  # first epoch, all 0, second epoch all 1, third all 2 etc.. ...
            return iter([self.cnt] * self.num_samples)

        def reset(self):
            # advance to the next row index, wrapping around at the dataset size
            self.cnt = (self.cnt + 1) % self.dataset_size

    def test_config(num_samples, num_repeats, sampler):
        data1 = ds.ManifestDataset(manifest_file, num_samples=num_samples, sampler=sampler)
        if num_repeats is not None:
            data1 = data1.repeat(num_repeats)
        res = []
        for item in data1.create_dict_iterator():
            logger.info("item[image].shape[0]: {}, item[label].item(): {}"
                        .format(item["image"].shape[0], item["label"].item()))
            res.append(index_of[(item["image"].shape[0], item["label"].item())])
        return res

    def test_generator():
        class MySampler(ds.Sampler):
            def __iter__(self):
                # indices 99 down to 0, i.e. the source list in reverse
                yield from range(99, -1, -1)

        data1 = ds.GeneratorDataset([(np.array(i),) for i in range(100)], ["data"], sampler=MySampler())
        i = 99
        for data in data1:
            assert data[0] == (np.array(i),)
            i = i - 1

    assert test_config(5, 2, Sp1()) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    assert test_config(2, 6, Sp2()) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0]
    test_generator()

    sp1 = Sp1().create()
    sp1.set_num_rows(5)
    sp1.set_num_samples(5)
    sp1.initialize()
    assert list(sp1.get_indices()) == [0, 1, 2, 3, 4]


168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
def test_subset_sampler():
    """Check SubsetSampler(start_index, subset_size) on the five-image manifest dataset.

    A subset size of 0 is expected to raise RuntimeError mentioning "subset_size <= 0";
    every other case is expected to return ``subset_size`` consecutive indices beginning
    at ``start_index``.
    """
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    # (un-decoded image size, label) -> index used to identify the image
    index_of = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    def test_config(num_samples, start_index, subset_size):
        # num_samples is accepted but not referenced by this helper
        subset = ds.SubsetSampler(start_index, subset_size)
        dataset = ds.ManifestDataset(manifest_file, sampler=subset)
        return [index_of[(row["image"].shape[0], row["label"].item())]
                for row in dataset.create_dict_iterator()]

    with pytest.raises(RuntimeError) as info:
        test_config(5, 0, 0)
    assert "subset_size <= 0" in str(info.value)

    # (start_index, subset_size, expected indices)
    cases = [
        (0, 1, [0]),
        (0, 2, [0, 1]),
        (0, 3, [0, 1, 2]),
        (0, 4, [0, 1, 2, 3]),
        (0, 5, [0, 1, 2, 3, 4]),
        (1, 1, [1]),
        (2, 3, [2, 3, 4]),
        (3, 2, [3, 4]),
        (4, 1, [4]),
    ]
    for start_index, subset_size, expected in cases:
        assert test_config(5, start_index, subset_size) == expected


197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222
def test_sampler_chain():
    """Check a DistributedSampler that has a SequentialSampler attached as its child.

    Each (num_shards, shard_id) pair is run against the five-image manifest dataset and
    the decoded indices are compared with the expected per-shard list.
    """
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    # (un-decoded image size, label) -> index used to identify the image
    index_of = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    def test_config(num_shards, shard_id):
        # NOTE(review): the third positional argument is presumably the shuffle flag — confirm
        parent = ds.DistributedSampler(num_shards, shard_id, False)
        parent.add_child(ds.SequentialSampler())

        dataset = ds.ManifestDataset(manifest_file, num_samples=5, sampler=parent)

        indices = []
        for row in dataset.create_dict_iterator():
            logger.info("item[image].shape[0]: {}, item[label].item(): {}"
                        .format(row["image"].shape[0], row["label"].item()))
            indices.append(index_of[(row["image"].shape[0], row["label"].item())])
        return indices

    # (num_shards, shard_id, expected indices)
    expectations = [
        (2, 0, [0, 2, 4]),
        (2, 1, [1, 3, 0]),
        (5, 0, [0]),
        (5, 1, [1]),
        (5, 2, [2]),
        (5, 3, [3]),
        (5, 4, [4]),
    ]
    for num_shards, shard_id, expected in expectations:
        assert test_config(num_shards, shard_id) == expected

223 224 225 226 227 228 229 230 231 232 233 234 235 236
def test_add_sampler_invalid_input():
    """Check that ``use_sampler`` rejects arguments that are not sampler objects.

    An int and a str are each expected to raise TypeError whose message contains
    "not an instance of a sampler".
    """
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    data1 = ds.ManifestDataset(manifest_file)

    for not_a_sampler in (1, "sampler"):
        with pytest.raises(TypeError) as info:
            data1.use_sampler(not_a_sampler)
        assert "not an instance of a sampler" in str(info.value)


Z
zhunaipan 已提交
237 238 239 240
# Run every test in this module directly (without pytest); the first three are
# called with print_res=True so that their sampled indices are logged.
if __name__ == '__main__':
    test_sequential_sampler(True)
    test_random_sampler(True)
    test_random_sampler_multi_iter(True)
    test_sampler_py_api()
    test_python_sampler()
    test_subset_sampler()
    test_sampler_chain()
    test_add_sampler_invalid_input()