test_cache_map.py 5.9 KB
Newer Older
J
Jesse Lee 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing cache operator with mappable datasets
"""
J
Jesse Lee 已提交
18 19
import os
import pytest
J
Jesse Lee 已提交
20
import mindspore.dataset as ds
N
nhussain 已提交
21
import mindspore.dataset.vision.c_transforms as c_vision
J
Jesse Lee 已提交
22 23 24 25 26 27 28
from mindspore import log as logger
from util import save_and_check_md5

# Relative path to the ImageFolder-format training data used by every test below
# (passed as dataset_dir to ds.ImageFolderDataset); the tests note it holds only 2 images.
DATA_DIR = "../data/dataset/testImageNetData/train/"

# Forwarded as generate_golden= to save_and_check_md5; presumably True rewrites the
# golden result files instead of checking against them — confirm in util.
GENERATE_GOLDEN = False

G
guansongsong 已提交
29

J
Jesse Lee 已提交
30
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_basic1():
    """
    Cache attached directly to the mappable leaf, with decode and repeat above it.

       Repeat
         |
     Map(decode)
         |
       Cache
         |
     ImageFolder

    The resulting pipeline is handed to save_and_check_md5 with the golden
    file name cache_map_01_result.npz.
    """
    logger.info("Test cache map basic 1")

    leaf_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

    # The cache sits on the leaf itself; DATA_DIR holds only 2 images.
    pipeline = (
        ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=leaf_cache)
        .map(input_columns=["image"], operations=c_vision.Decode())
        .repeat(4)
    )

    save_and_check_md5(pipeline, "cache_map_01_result.npz", generate_golden=GENERATE_GOLDEN)

    logger.info("test_cache_map_basic1 Ended.\n")

J
Jesse Lee 已提交
59
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_basic2():
    """
    Cache attached to the map(decode) node rather than to the leaf.

       Repeat
         |
       Cache
         |
     Map(decode)
         |
     ImageFolder

    The resulting pipeline is handed to save_and_check_md5 with the golden
    file name cache_map_02_result.npz.
    """
    logger.info("Test cache map basic 2")

    map_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

    # The leaf is left uncached here; DATA_DIR holds only 2 images.
    leaf = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decoded = leaf.map(operations=c_vision.Decode(), input_columns=["image"], cache=map_cache)
    repeated = decoded.repeat(4)

    save_and_check_md5(repeated, "cache_map_02_result.npz", generate_golden=GENERATE_GOLDEN)

    logger.info("test_cache_map_basic2 Ended.\n")

J
Jesse Lee 已提交
88
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_basic3():
    """
    Test a repeat under mappable cache

        Cache
          |
      Map(decode)
          |
        Repeat
          |
      ImageFolder

    The cache is attached to the map node, which sits above a repeat(4) of the
    uncached leaf. DATA_DIR holds 2 images, so one epoch must yield 2 * 4 = 8 rows.
    """
    logger.info("Test cache basic 3")

    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.repeat(4)
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    # Format the value into the message: passing it as a bare positional arg with
    # no %-placeholder makes logging's `msg % args` raise when INFO is emitted.
    logger.info("ds1.dataset_size is {}".format(ds1.get_dataset_size()))

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        logger.info("get data from dataset")
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info('test_cache_basic3 Ended.\n')

J
Jesse Lee 已提交
122
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_basic4():
    """
    Test different rows result in core dump

      Map(decode)
          |
        Repeat
          |
        Cache
          |
      ImageFolder

    The cache is on the leaf, with repeat(4) and then map(decode) above it.
    output_shapes() is queried before iterating. NOTE(review): this appears to be
    a regression test for a crash when rows differ (e.g. image sizes) — confirm
    against the original bug report.
    DATA_DIR holds 2 images, so one epoch must yield 2 * 4 = 8 rows.
    """
    logger.info("Test cache basic 4")
    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.repeat(4)
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    # Format the value into the message: passing it as a bare positional arg with
    # no %-placeholder makes logging's `msg % args` raise when INFO is emitted.
    logger.info("ds1.dataset_size is {}".format(ds1.get_dataset_size()))
    shape = ds1.output_shapes()
    logger.info(shape)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        logger.info("get data from dataset")
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info('test_cache_basic4 Ended.\n')
Q
qianlong 已提交
146

J
Jesse Lee 已提交
147
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_failure1():
    """
    Test nested cache (failure)

        Repeat
          |
        Cache
          |
      Map(decode)
          |
        Cache
          |
      ImageFolder

    The same cache is attached to both the leaf and the map node above it.
    Iterating must raise RuntimeError mentioning nested caches before any
    row is produced.
    """
    logger.info("Test cache failure 1")

    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    # pytest.raises reports "DID NOT RAISE" if the pipeline unexpectedly runs,
    # instead of failing later on an unrelated row-count assertion.
    with pytest.raises(RuntimeError) as e:
        for _ in ds1.create_dict_iterator(num_epochs=1):
            num_iter += 1

    logger.info("Got an exception in DE: {}".format(str(e.value)))
    assert "Nested cache operations is not supported!" in str(e.value)
    # The error must surface before any data is delivered.
    assert num_iter == 0
    logger.info('test_cache_failure1 Ended.\n')

G
guansongsong 已提交
184

J
Jesse Lee 已提交
185 186
if __name__ == '__main__':
    # Direct execution bypasses pytest, so the RUN_CACHE_TEST skipif marks do not
    # apply; every test runs and expects a reachable cache server.
    # Success messages are derived from each function's name so they cannot drift
    # out of sync with the test actually run.
    for _test in (test_cache_map_basic1,
                  test_cache_map_basic2,
                  test_cache_map_basic3,
                  test_cache_map_basic4,
                  test_cache_map_failure1):
        _test()
        logger.info("{} success.".format(_test.__name__))