提交 db80f4ff 编写于 作者: Q qianlong

The num_samples parameter and the numRows field in the schema for TFRecordDataset conflict with each other

上级 46acf238
......@@ -162,7 +162,11 @@ Status StorageClient::numRowsFromFile(uint32_t &num_rows) const {
std::ifstream in(schemaFile);
nlohmann::json js;
in >> js;
num_rows = js.value("numRows", 0);
if (js.find("numRows") == js.end()) {
num_rows = MAX_INTEGER_INT32;
} else {
num_rows = js.value("numRows", 0);
}
if (num_rows == 0) {
std::string err_msg =
"Storage client has not properly done dataset "
......
......@@ -163,6 +163,9 @@ Status TFReaderOp::Init() {
if (total_rows_ == 0) {
total_rows_ = data_schema_->num_rows();
}
if (total_rows_ < 0) {
RETURN_STATUS_UNEXPECTED("The num_sample or numRows for TFRecordDataset should be greater than 0");
}
// Build the index with our files such that each file corresponds to a key id.
RETURN_IF_NOT_OK(filename_index_->insert(dataset_files_list_));
......
......@@ -1455,7 +1455,7 @@ class StorageDataset(SourceDataset):
Args:
dataset_files (list[str]): List of files to be read.
schema (str): Path to the json schema file.
schema (str): Path to the json schema file. If numRows (parsed from the schema) does not exist, the full dataset is read.
distribution (str, optional): Path of distribution config file (default="").
columns_list (list[str], optional): List of columns to be read (default=None, read all columns).
num_parallel_workers (int, optional): Number of parallel working threads (default=None).
......@@ -2193,7 +2193,10 @@ class TFRecordDataset(SourceDataset):
schema (str or Schema, optional): Path to the json schema file or schema object (default=None).
If the schema is not provided, the meta data from the TFData file is considered the schema.
columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
num_samples (int, optional): number of samples (rows) to read (default=None).
    If num_samples is None and numRows (parsed from the schema) does not exist, read the full dataset;
    If num_samples is None and numRows (parsed from the schema) is greater than 0, read numRows rows;
    If both num_samples and numRows (parsed from the schema) are greater than 0, read num_samples rows.
num_parallel_workers (int, optional): number of workers to read the data
(default=None, number set in the config).
shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
......@@ -2711,10 +2714,10 @@ class Schema:
"""
def __init__(self, schema_file=None):
self.num_rows = None
if schema_file is None:
self.columns = []
self.dataset_type = ''
self.num_rows = 0
else:
if not os.path.isfile(schema_file) or not os.access(schema_file, os.R_OK):
raise ValueError("The file %s does not exist or permission denied!" % schema_file)
......@@ -2859,6 +2862,9 @@ class Schema:
raise RuntimeError("DatasetType field is missing.")
if self.columns is None:
raise RuntimeError("Columns are missing.")
if self.num_rows is not None:
if not isinstance(self.num_rows, int) or self.num_rows <= 0:
raise ValueError("numRows must be greater than 0")
def __str__(self):
    """Return the schema rendered as its JSON string form (delegates to to_json)."""
    return self.to_json()
......
{
"datasetType": "TF",
"columns": {
"col_sint16": {
"type": "int16",
"rank": 1,
"shape": [1]
},
"col_sint32": {
"type": "int32",
"rank": 1,
"shape": [1]
},
"col_sint64": {
"type": "int64",
"rank": 1,
"shape": [1]
},
"col_float": {
"type": "float32",
"rank": 1,
"shape": [1]
},
"col_1d": {
"type": "int64",
"rank": 1,
"shape": [2]
},
"col_2d": {
"type": "int64",
"rank": 2,
"shape": [2, 2]
},
"col_3d": {
"type": "int64",
"rank": 3,
"shape": [2, 2, 2]
},
"col_binary": {
"type": "uint8",
"rank": 1,
"shape": [1]
}
}
}
{
"datasetType": "TF",
"columns": {
"image": {
"type": "uint8",
"rank": 1,
"t_impl": "cvmat"
},
"label" : {
"type": "uint64",
"rank": 1,
"t_impl": "flex"
}
}
}
......@@ -37,3 +37,15 @@ def test_case_storage():
filename = "storage_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_case_no_rows():
    """A StorageDataset whose schema omits numRows should expose every row in the file."""
    data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    schema_dir = "../data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json"
    dataset = ds.StorageDataset(data_dir, schema_dir, columns_list=["image"])
    # With no numRows in the schema, the reader falls back to the full file contents.
    assert dataset.get_dataset_size() == 3
    rows_seen = sum(1 for _ in dataset.create_tuple_iterator())
    assert rows_seen == 3
......@@ -37,6 +37,36 @@ def test_case_tf_shape():
assert (len(output_shape[-1]) == 1)
def test_case_tf_read_all_dataset():
    """With no num_samples and no numRows in the schema, the full dataset is read."""
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json"
    dataset = ds.TFRecordDataset(FILES, schema_file)
    assert dataset.get_dataset_size() == 12
    rows_seen = 0
    for _ in dataset.create_tuple_iterator():
        rows_seen += 1
    assert rows_seen == 12
def test_case_num_samples():
    """An explicit num_samples takes precedence over the schema's numRows value."""
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
    dataset = ds.TFRecordDataset(FILES, schema_file, num_samples=8)
    # Schema declares 7 rows, but num_samples=8 wins.
    assert dataset.get_dataset_size() == 8
    rows_seen = sum(1 for _ in dataset.create_dict_iterator())
    assert rows_seen == 8
def test_case_num_samples2():
    """When num_samples is omitted, the numRows value from the schema is used."""
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
    dataset = ds.TFRecordDataset(FILES, schema_file)
    # Schema declares 7 rows and no num_samples overrides it.
    assert dataset.get_dataset_size() == 7
    rows_seen = 0
    for _ in dataset.create_dict_iterator():
        rows_seen += 1
    assert rows_seen == 7
def test_case_tf_shape_2():
ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE)
ds1 = ds1.batch(2)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册