Commit c6882656 authored by jonyguo

fix: MindDataset hangs when block_reader=True

Parent 9e17b996
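For reference, the kind of call that used to hang is the block-reader read path exercised by the new unit test further down in this diff. The sketch below is illustrative only and not part of the change: the file name is a placeholder, and the column list mirrors the new test, where "id" and "data" are requested columns that are not index fields.

import mindspore.dataset as ds

# Placeholder file name; the unit test below builds its own MindRecord files.
CV_FILE_NAME = "./imagenet.mindrecord"
# "id" and "data" are requested columns that are not index fields.
columns_list = ["id", "data", "label"]
num_readers = 4
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                          shuffle=False, block_reader=True)
for item in data_set.create_dict_iterator():
    print(item["label"])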
@@ -785,6 +785,8 @@ vector<std::string> ShardReader::GetAllColumns() {
 MSRStatus ShardReader::CreateTasksByBlock(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                                           const std::vector<std::shared_ptr<ShardOperator>> &operators) {
+  vector<std::string> columns = GetAllColumns();
+  CheckIfColumnInIndex(columns);
   for (const auto &rg : row_group_summary) {
     auto shard_id = std::get<0>(rg);
     auto group_id = std::get<1>(rg);
......
@@ -143,6 +143,7 @@ class FileWriter:
             ParamTypeError: If index field is invalid.
             MRMDefineIndexError: If index field is not primitive type.
             MRMAddIndexError: If failed to add index field.
+            MRMGetMetaError: If the schema is not set or failed to get meta info.
         """
         if not index_fields or not isinstance(index_fields, list):
             raise ParamTypeError('index_fields', 'list')
......
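The newly documented MRMGetMetaError is what add_index raises when no schema has been set yet (see the new unit test at the end of this diff). A minimal sketch of that behaviour, using a hypothetical output file name:

from mindspore.mindrecord import FileWriter, MRMGetMetaError

writer = FileWriter("./demo.mindrecord")   # hypothetical output file
try:
    writer.add_index(["label"])            # no add_schema() call yet
except MRMGetMetaError:
    print("add_index without a schema raises MRMGetMetaError")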
@@ -24,7 +24,7 @@ from mindspore import log as logger
 from .cifar100 import Cifar100
 from ..common.exceptions import PathNotExistsError
 from ..filewriter import FileWriter
-from ..shardutils import check_filename
+from ..shardutils import check_filename, SUCCESS
 try:
     cv2 = import_module("cv2")
 except ModuleNotFoundError:
@@ -98,8 +98,11 @@ class Cifar100ToMR:
         data_list = _construct_raw_data(images, fine_labels, coarse_labels)
         test_data_list = _construct_raw_data(test_images, test_fine_labels, test_coarse_labels)
-        _generate_mindrecord(self.destination, data_list, fields, "img_train")
-        _generate_mindrecord(self.destination + "_test", test_data_list, fields, "img_test")
+        if _generate_mindrecord(self.destination, data_list, fields, "img_train") != SUCCESS:
+            return FAILED
+        if _generate_mindrecord(self.destination + "_test", test_data_list, fields, "img_test") != SUCCESS:
+            return FAILED
+        return SUCCESS


 def _construct_raw_data(images, fine_labels, coarse_labels):
     """
......
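With this change Cifar100ToMR.transform() reports SUCCESS or FAILED instead of ignoring conversion errors, so callers can check the result, as the updated unit test below does. A sketch of that calling pattern (both paths are placeholders):

from mindspore.mindrecord import Cifar100ToMR, SUCCESS

transformer = Cifar100ToMR("./cifar-100-binary", "./cifar100.mindrecord")
if transformer.transform() != SUCCESS:
    raise RuntimeError("cifar100 to mindrecord conversion failed")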
@@ -47,7 +47,9 @@ def add_and_remove_cv_file():
         os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None
     writer = FileWriter(CV_FILE_NAME, FILES_NUM)
     data = get_data(CV_DIR_NAME)
-    cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"},
+    cv_schema_json = {"id": {"type": "int32"},
+                      "file_name": {"type": "string"},
+                      "label": {"type": "int32"},
                       "data": {"type": "bytes"}}
     writer.add_schema(cv_schema_json, "img_schema")
     writer.add_index(["file_name", "label"])
@@ -226,6 +228,24 @@ def test_cv_minddataset_blockreader_tutorial(add_and_remove_cv_file):
         num_iter += 1
     assert num_iter == 20


+def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_remove_cv_file):
+    """tutorial for cv minddataset."""
+    columns_list = ["id", "data", "label"]
+    num_readers = 4
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, shuffle=False,
+                              block_reader=True)
+    assert data_set.get_dataset_size() == 10
+    repeat_num = 2
+    data_set = data_set.repeat(repeat_num)
+    num_iter = 0
+    for item in data_set.create_dict_iterator():
+        logger.info("-------------- block reader repeat two {} -----------------".format(num_iter))
+        logger.info("-------------- item[id]: {} ----------------------------".format(item["id"]))
+        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
+        logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
+        num_iter += 1
+    assert num_iter == 20


 def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file):
     """tutorial for cv minddataset."""
@@ -359,13 +379,14 @@ def get_data(dir_name):
         lines = file_reader.readlines()
     data_list = []
-    for line in lines:
+    for i, line in enumerate(lines):
         try:
             filename, label = line.split(",")
             label = label.strip("\n")
             with open(os.path.join(img_dir, filename), "rb") as file_reader:
                 img = file_reader.read()
-            data_json = {"file_name": filename,
+            data_json = {"id": i,
+                         "file_name": filename,
                          "data": img,
                          "label": int(label)}
             data_list.append(data_json)
......
@@ -18,6 +18,7 @@ import pytest
 from mindspore.mindrecord import Cifar100ToMR
 from mindspore.mindrecord import FileReader
 from mindspore.mindrecord import MRMOpenError
+from mindspore.mindrecord import SUCCESS
 from mindspore import log as logger

 CIFAR100_DIR = "../data/mindrecord/testCifar100Data"
@@ -26,7 +27,8 @@ MINDRECORD_FILE = "./cifar100.mindrecord"
 def test_cifar100_to_mindrecord_without_index_fields():
     """test transform cifar100 dataset to mindrecord without index fields."""
     cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE)
-    cifar100_transformer.transform()
+    ret = cifar100_transformer.transform()
+    assert ret == SUCCESS, "Failed to transform from cifar100 to mindrecord"
     assert os.path.exists(MINDRECORD_FILE)
     assert os.path.exists(MINDRECORD_FILE + "_test")
     read()
......
@@ -16,7 +16,7 @@
 import os
 import pytest
 from mindspore.mindrecord import FileWriter, FileReader, MindPage
-from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError
+from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError
 from mindspore import log as logger
 from utils import get_data
@@ -280,3 +280,9 @@ def test_cv_file_writer_shard_num_greater_than_1000():
     with pytest.raises(ParamValueError) as err:
         FileWriter(CV_FILE_NAME, 1001)
     assert 'Shard number should between' in str(err.value)
+
+
+def test_add_index_without_add_schema():
+    with pytest.raises(MRMGetMetaError) as err:
+        fw = FileWriter(CV_FILE_NAME)
+        fw.add_index(["label"])
+    assert 'Failed to get meta info' in str(err.value)