“ef9a264b7a288a07c43ddb244c4f9ab0e8df90e4”上不存在“drivers/git@gitcode.net:openanolis/cloud-kernel.git”
test_minddataset_exception.py 8.7 KB
Newer Older
Z
zhunaipan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
#!/usr/bin/env python
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os
import pytest

import mindspore.dataset as ds
from mindspore.mindrecord import FileWriter

# MindRecord files written by the helpers below, relative to the current
# working directory. Each helper also cleans up a sibling "<name>.db" file.
CV_FILE_NAME = "./imagenet.mindrecord"
CV1_FILE_NAME = "./imagenet1.mindrecord"


def create_cv_mindrecord(files_num):
    """Write a one-row CV MindRecord dataset to CV_FILE_NAME.

    Stale copies of the output file and its ".db" companion left by a
    previous run are removed first. The schema has "file_name" (string),
    "label" (int32) and "data" (bytes) fields; "file_name" and "label"
    are indexed.

    Args:
        files_num: Forwarded as the second argument of ``FileWriter``
            (presumably the number of shard files -- TODO confirm).
    """
    for stale in (CV_FILE_NAME, "{}.db".format(CV_FILE_NAME)):
        if os.path.exists(stale):
            os.remove(stale)
    writer = FileWriter(CV_FILE_NAME, files_num)
    cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
    data = [{"file_name": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}]
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()


def create_diff_schema_cv_mindrecord(files_num):
    """Write a one-row CV MindRecord dataset to CV1_FILE_NAME with a different schema.

    Identical to ``create_cv_mindrecord`` except that the string field is
    named "file_name_1" instead of "file_name". This lets readers be tested
    against files whose schemas disagree. Stale copies of the output file
    and its ".db" companion are removed first.

    Args:
        files_num: Forwarded as the second argument of ``FileWriter``
            (presumably the number of shard files -- TODO confirm).
    """
    for stale in (CV1_FILE_NAME, "{}.db".format(CV1_FILE_NAME)):
        if os.path.exists(stale):
            os.remove(stale)
    writer = FileWriter(CV1_FILE_NAME, files_num)
    cv_schema_json = {"file_name_1": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
    data = [{"file_name_1": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}]
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name_1", "label"])
    writer.write_raw_data(data)
    writer.commit()


def create_diff_page_size_cv_mindrecord(files_num):
    """Write a one-row CV MindRecord dataset to CV1_FILE_NAME with a 64 MB page size.

    Schema and row match ``create_cv_mindrecord``; only the page size
    differs (``1 << 26`` bytes). This lets readers be tested against files
    whose page sizes disagree. Stale copies of the output file and its
    ".db" companion are removed first.

    Args:
        files_num: Forwarded as the second argument of ``FileWriter``
            (presumably the number of shard files -- TODO confirm).
    """
    for stale in (CV1_FILE_NAME, "{}.db".format(CV1_FILE_NAME)):
        if os.path.exists(stale):
            os.remove(stale)
    writer = FileWriter(CV1_FILE_NAME, files_num)
    writer.set_page_size(1 << 26)  # 64MB
    cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
    data = [{"file_name": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}]
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()


def test_cv_lack_json():
    """MindDataset must raise when given a "no_exist.json" argument.

    NOTE(review): the other MindDataset calls in this file pass
    ``(file, columns_list, num_readers)`` positionally, so "no_exist.json"
    lands in the columns_list slot here. The exception may therefore come
    from argument validation rather than a missing JSON file -- confirm
    the intent of this test.
    """
    create_cv_mindrecord(1)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    try:
        with pytest.raises(Exception):
            ds.MindDataset(CV_FILE_NAME, "no_exist.json", columns_list, num_readers)
    finally:
        # Clean up even if the expectation fails, so no files are left behind.
        for path in (CV_FILE_NAME, "{}.db".format(CV_FILE_NAME)):
            if os.path.exists(path):
                os.remove(path)


def test_cv_lack_mindrecord():
    """Opening a MindRecord path that does not exist must raise a descriptive error."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    with pytest.raises(Exception, match="does not exist or permission denied"):
        _ = ds.MindDataset("no_exist.mindrecord", columns_list, num_readers)


def test_invalid_mindrecord():
    """A plain-text file must be rejected as an invalid MindRecord file.

    Both construction and iteration run inside ``pytest.raises``, so the
    expected error is caught wherever it surfaces.
    """
    with open('dummy.mindrecord', 'w', encoding='utf-8') as f:
        f.write('just for test')
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    try:
        with pytest.raises(Exception, match="MindRecordOp init failed"):
            data_set = ds.MindDataset('dummy.mindrecord', columns_list, num_readers)
            # No row-count assertion here: it would be unreachable once the
            # expected exception fires, and swallowed by raises(Exception) otherwise.
            for _ in data_set.create_dict_iterator():
                pass
    finally:
        if os.path.exists('dummy.mindrecord'):
            os.remove('dummy.mindrecord')


def test_minddataset_lack_db():
    """Loading a MindRecord file whose ".db" companion was deleted must fail."""
    create_cv_mindrecord(1)
    os.remove("{}.db".format(CV_FILE_NAME))
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    try:
        with pytest.raises(Exception, match="MindRecordOp init failed"):
            data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers)
            # Drain the iterator in case the error only surfaces on read; a
            # row-count assert here would be unreachable on the expected path.
            for _ in data_set.create_dict_iterator():
                pass
    finally:
        if os.path.exists(CV_FILE_NAME):
            os.remove(CV_FILE_NAME)


def test_cv_minddataset_pk_sample_error_class_column():
    """A PKSampler keyed on a column missing from the dataset must fail at launch."""
    create_cv_mindrecord(1)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    # 'no_exsit_column' is not a field of the schema written by
    # create_cv_mindrecord (file_name, label, data).
    sampler = ds.PKSampler(5, None, True, 'no_exsit_column')
    try:
        with pytest.raises(Exception, match="MindRecordOp launch failed"):
            data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, sampler=sampler)
            for _ in data_set.create_dict_iterator():
                pass
    finally:
        for path in (CV_FILE_NAME, "{}.db".format(CV_FILE_NAME)):
            if os.path.exists(path):
                os.remove(path)


def test_cv_minddataset_pk_sample_exclusive_shuffle():
    """MindDataset must reject being given both ``sampler`` and ``shuffle``."""
    create_cv_mindrecord(1)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(2)
    try:
        with pytest.raises(Exception, match="sampler and shuffle cannot be specified at the same time."):
            data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
                                      sampler=sampler, shuffle=False)
            for _ in data_set.create_dict_iterator():
                pass
    finally:
        for path in (CV_FILE_NAME, "{}.db".format(CV_FILE_NAME)):
            if os.path.exists(path):
                os.remove(path)


def test_cv_minddataset_reader_different_schema():
    """Reading two MindRecord files whose schemas differ must fail.

    CV_FILE_NAME has a "file_name" field; CV1_FILE_NAME has "file_name_1".
    """
    create_cv_mindrecord(1)
    create_diff_schema_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4
    try:
        with pytest.raises(Exception, match="MindRecordOp init failed"):
            data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
                                      num_readers)
            for _ in data_set.create_dict_iterator():
                pass
    finally:
        for path in (CV_FILE_NAME, "{}.db".format(CV_FILE_NAME),
                     CV1_FILE_NAME, "{}.db".format(CV1_FILE_NAME)):
            if os.path.exists(path):
                os.remove(path)


def test_cv_minddataset_reader_different_page_size():
    """Reading two MindRecord files written with different page sizes must fail.

    CV1_FILE_NAME is written with a 64 MB page size; CV_FILE_NAME uses the
    writer's default.
    """
    create_cv_mindrecord(1)
    create_diff_page_size_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4
    try:
        with pytest.raises(Exception, match="MindRecordOp init failed"):
            data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
                                      num_readers)
            for _ in data_set.create_dict_iterator():
                pass
    finally:
        for path in (CV_FILE_NAME, "{}.db".format(CV_FILE_NAME),
                     CV1_FILE_NAME, "{}.db".format(CV1_FILE_NAME)):
            if os.path.exists(path):
                os.remove(path)


def test_minddataset_invalidate_num_shards():
    """shard_id 2 is out of range when num_shards is 1 and must be rejected."""
    create_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4
    try:
        with pytest.raises(Exception) as error_info:
            # Trailing positionals True, 1, 2 -- presumably shuffle, num_shards,
            # shard_id (matches the "(0 to 0)" interval below); TODO confirm.
            data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, 2)
            for _ in data_set.create_dict_iterator():
                pass
        # Inspect the exception text directly: repr(ExceptionInfo) goes through
        # saferepr (truncated) and older pytest versions omit the message.
        assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info.value)
    finally:
        for path in (CV_FILE_NAME, "{}.db".format(CV_FILE_NAME)):
            if os.path.exists(path):
                os.remove(path)


def test_minddataset_invalidate_shard_id():
    """A negative shard_id must be rejected."""
    create_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4
    try:
        with pytest.raises(Exception) as error_info:
            # Trailing positionals True, 1, -1 -- presumably shuffle, num_shards,
            # shard_id; TODO confirm against the MindDataset signature.
            data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, -1)
            for _ in data_set.create_dict_iterator():
                pass
        # Inspect the exception text directly: repr(ExceptionInfo) goes through
        # saferepr (truncated) and older pytest versions omit the message.
        assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info.value)
    finally:
        for path in (CV_FILE_NAME, "{}.db".format(CV_FILE_NAME)):
            if os.path.exists(path):
                os.remove(path)


def test_minddataset_shard_id_bigger_than_num_shard():
    """shard_id values >= num_shards (here 2 and 5 with num_shards=2) must be rejected."""
    create_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4
    try:
        # Trailing positionals are presumably shuffle, num_shards, shard_id;
        # TODO confirm against the MindDataset signature.
        for shard_id in (2, 5):
            with pytest.raises(Exception) as error_info:
                data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, shard_id)
                for _ in data_set.create_dict_iterator():
                    pass
            # Inspect the exception text directly: repr(ExceptionInfo) goes through
            # saferepr (truncated) and older pytest versions omit the message.
            assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info.value)
    finally:
        for path in (CV_FILE_NAME, "{}.db".format(CV_FILE_NAME)):
            if os.path.exists(path):
                os.remove(path)