提交 82c888f0 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4930 Fix CI cifar hang issue

Merge pull request !4930 from xiefangqi/fix_cifar_nofile_issue
...@@ -336,6 +336,9 @@ Status CifarOp::GetCifarFiles() { ...@@ -336,6 +336,9 @@ Status CifarOp::GetCifarFiles() {
std::string err_msg = "Unable to open directory " + dataset_directory.toString(); std::string err_msg = "Unable to open directory " + dataset_directory.toString();
RETURN_STATUS_UNEXPECTED(err_msg); RETURN_STATUS_UNEXPECTED(err_msg);
} }
if (cifar_files_.size() == 0) {
RETURN_STATUS_UNEXPECTED("No .bin files found under " + folder_path_);
}
std::sort(cifar_files_.begin(), cifar_files_.end()); std::sort(cifar_files_.begin(), cifar_files_.end());
return Status::OK(); return Status::OK();
} }
......
...@@ -24,6 +24,7 @@ from mindspore import log as logger ...@@ -24,6 +24,7 @@ from mindspore import log as logger
DATA_DIR_10 = "../data/dataset/testCifar10Data" DATA_DIR_10 = "../data/dataset/testCifar10Data"
DATA_DIR_100 = "../data/dataset/testCifar100Data" DATA_DIR_100 = "../data/dataset/testCifar100Data"
NO_BIN_DIR = "../data/dataset/testMnistData"
def load_cifar(path, kind="cifar10"): def load_cifar(path, kind="cifar10"):
...@@ -208,6 +209,12 @@ def test_cifar10_exception(): ...@@ -208,6 +209,12 @@ def test_cifar10_exception():
with pytest.raises(ValueError, match=error_msg_6): with pytest.raises(ValueError, match=error_msg_6):
ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_parallel_workers=88) ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_parallel_workers=88)
error_msg_7 = "No .bin files found"
with pytest.raises(RuntimeError, match=error_msg_7):
ds1 = ds.Cifar10Dataset(NO_BIN_DIR)
for _ in ds1.__iter__():
pass
def test_cifar10_visualize(plot=False): def test_cifar10_visualize(plot=False):
""" """
...@@ -352,6 +359,12 @@ def test_cifar100_exception(): ...@@ -352,6 +359,12 @@ def test_cifar100_exception():
with pytest.raises(ValueError, match=error_msg_6): with pytest.raises(ValueError, match=error_msg_6):
ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, num_parallel_workers=88) ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, num_parallel_workers=88)
error_msg_7 = "No .bin files found"
with pytest.raises(RuntimeError, match=error_msg_7):
ds1 = ds.Cifar100Dataset(NO_BIN_DIR)
for _ in ds1.__iter__():
pass
def test_cifar100_visualize(plot=False): def test_cifar100_visualize(plot=False):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册