提交 e3e78204 编写于 作者: X xiefangqi

fix cifar stuck problem

上级 d00f7d8f
......@@ -336,6 +336,9 @@ Status CifarOp::GetCifarFiles() {
std::string err_msg = "Unable to open directory " + dataset_directory.toString();
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (cifar_files_.size() == 0) {
RETURN_STATUS_UNEXPECTED("No .bin files found under " + folder_path_);
}
std::sort(cifar_files_.begin(), cifar_files_.end());
return Status::OK();
}
......
......@@ -24,6 +24,7 @@ from mindspore import log as logger
DATA_DIR_10 = "../data/dataset/testCifar10Data"
DATA_DIR_100 = "../data/dataset/testCifar100Data"
NO_BIN_DIR = "../data/dataset/testMnistData"
def load_cifar(path, kind="cifar10"):
......@@ -208,6 +209,12 @@ def test_cifar10_exception():
with pytest.raises(ValueError, match=error_msg_6):
ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_parallel_workers=88)
error_msg_7 = "No .bin files found"
with pytest.raises(RuntimeError, match=error_msg_7):
ds1 = ds.Cifar10Dataset(NO_BIN_DIR)
for _ in ds1.__iter__():
pass
def test_cifar10_visualize(plot=False):
"""
......@@ -352,6 +359,12 @@ def test_cifar100_exception():
with pytest.raises(ValueError, match=error_msg_6):
ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, num_parallel_workers=88)
error_msg_7 = "No .bin files found"
with pytest.raises(RuntimeError, match=error_msg_7):
ds1 = ds.Cifar100Dataset(NO_BIN_DIR)
for _ in ds1.__iter__():
pass
def test_cifar100_visualize(plot=False):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册