From db80f4ff928213b08ec2e49b21a90c2e707a7467 Mon Sep 17 00:00:00 2001
From: qianlong
Date: Mon, 20 Apr 2020 21:27:11 +0800
Subject: [PATCH] Fix the conflict between num_samples and numRows in schema
 for TFRecordDataset

---
 .../datasetops/source/storage_client.cc      |  6 ++-
 .../engine/datasetops/source/tf_reader_op.cc |  3 ++
 mindspore/dataset/engine/datasets.py         | 12 +++--
 .../datasetSchemaNoRow.json                  | 45 +++++++++++++++++++
 .../datasetNoRowsSchema.json                 | 15 +++++++
 tests/ut/python/dataset/test_storage.py      | 12 +++++
 tests/ut/python/dataset/test_tfreader_op.py  | 30 +++++++++++++
 7 files changed, 119 insertions(+), 4 deletions(-)
 create mode 100644 tests/ut/data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json
 create mode 100644 tests/ut/data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json

diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/storage_client.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/storage_client.cc
index 862edcf63..7f081af2b 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/storage_client.cc
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/storage_client.cc
@@ -162,7 +162,11 @@ Status StorageClient::numRowsFromFile(uint32_t &num_rows) const {
   std::ifstream in(schemaFile);
   nlohmann::json js;
   in >> js;
-  num_rows = js.value("numRows", 0);
+  if (js.find("numRows") == js.end()) {
+    num_rows = MAX_INTEGER_INT32;
+  } else {
+    num_rows = js.value("numRows", 0);
+  }
   if (num_rows == 0) {
     std::string err_msg =
       "Storage client has not properly done dataset "
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc
index a72be1f70..6132f628d 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc
@@ -163,6 +163,9 @@ Status TFReaderOp::Init() {
   if (total_rows_ == 0) {
     total_rows_ = data_schema_->num_rows();
   }
+  if (total_rows_ < 0) {
+    RETURN_STATUS_UNEXPECTED("The num_samples or numRows for TFRecordDataset must be greater than 0");
+  }
 
   // Build the index with our files such that each file corresponds to a key id.
   RETURN_IF_NOT_OK(filename_index_->insert(dataset_files_list_));
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index 28697a6c4..855e4609b 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -1455,7 +1455,7 @@ class StorageDataset(SourceDataset):
 
     Args:
         dataset_files (list[str]): List of files to be read.
-        schema (str): Path to the json schema file.
+        schema (str): Path to the json schema file. If numRows (parsed from the schema) does not exist, read the full dataset.
         distribution (str, optional): Path of distribution config file (default="").
         columns_list (list[str], optional): List of columns to be read (default=None, read all columns).
         num_parallel_workers (int, optional): Number of parallel working threads (default=None).
@@ -2193,7 +2193,10 @@ class TFRecordDataset(SourceDataset):
         schema (str or Schema, optional): Path to the json schema file or schema object (default=None).
             If the schema is not provided, the meta data from the TFData file is considered the schema.
         columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
-        num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
+        num_samples (int, optional): Number of samples (rows) to read (default=None).
+            If num_samples is None and numRows (parsed from the schema) does not exist, read the full dataset;
+            if num_samples is None and numRows (parsed from the schema) is greater than 0, read numRows rows;
+            if both num_samples and numRows (parsed from the schema) are greater than 0, read num_samples rows.
         num_parallel_workers (int, optional): number of workers to read the data
             (default=None, number set in the config).
         shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
@@ -2711,10 +2714,10 @@ class Schema:
     """

     def __init__(self, schema_file=None):
+        self.num_rows = None
         if schema_file is None:
             self.columns = []
             self.dataset_type = ''
-            self.num_rows = 0
         else:
             if not os.path.isfile(schema_file) or not os.access(schema_file, os.R_OK):
                 raise ValueError("The file %s does not exist or permission denied!" % schema_file)
@@ -2859,6 +2862,9 @@ class Schema:
             raise RuntimeError("DatasetType field is missing.")
         if self.columns is None:
             raise RuntimeError("Columns are missing.")
+        if self.num_rows is not None:
+            if not isinstance(self.num_rows, int) or self.num_rows <= 0:
+                raise ValueError("numRows must be a positive integer.")
 
     def __str__(self):
         return self.to_json()
diff --git a/tests/ut/data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json b/tests/ut/data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json
new file mode 100644
index 000000000..92abf66ef
--- /dev/null
+++ b/tests/ut/data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json
@@ -0,0 +1,45 @@
+{
+  "datasetType": "TF",
+  "columns": {
+    "col_sint16": {
+      "type": "int16",
+      "rank": 1,
+      "shape": [1]
+    },
+    "col_sint32": {
+      "type": "int32",
+      "rank": 1,
+      "shape": [1]
+    },
+    "col_sint64": {
+      "type": "int64",
+      "rank": 1,
+      "shape": [1]
+    },
+    "col_float": {
+      "type": "float32",
+      "rank": 1,
+      "shape": [1]
+    },
+    "col_1d": {
+      "type": "int64",
+      "rank": 1,
+      "shape": [2]
+    },
+    "col_2d": {
+      "type": "int64",
+      "rank": 2,
+      "shape": [2, 2]
+    },
+    "col_3d": {
+      "type": "int64",
+      "rank": 3,
+      "shape": [2, 2, 2]
+    },
+    "col_binary": {
+      "type": "uint8",
+      "rank": 1,
+      "shape": [1]
+    }
+  }
+}
diff --git a/tests/ut/data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json b/tests/ut/data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json
new file mode 100644
index 000000000..e00fd39c1
--- /dev/null
+++ b/tests/ut/data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json
@@ -0,0 +1,15 @@
+{
+  "datasetType": "TF",
+  "columns": {
+    "image": {
+      "type": "uint8",
+      "rank": 1,
+      "t_impl": "cvmat"
+    },
+    "label" : {
+      "type": "uint64",
+      "rank": 1,
+      "t_impl": "flex"
+    }
+  }
+}
diff --git a/tests/ut/python/dataset/test_storage.py b/tests/ut/python/dataset/test_storage.py
index b37a52f37..92a689a68 100644
--- a/tests/ut/python/dataset/test_storage.py
+++ b/tests/ut/python/dataset/test_storage.py
@@ -37,3 +37,15 @@ def test_case_storage():
     filename = "storage_result.npz"
     save_and_check(data1, parameters, filename,
                    generate_golden=GENERATE_GOLDEN)
+
+
+def test_case_no_rows():
+    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
+    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json"
+
+    dataset = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
+    assert dataset.get_dataset_size() == 3
+    count = 0
+    for data in dataset.create_tuple_iterator():
+        count += 1
+    assert count == 3
diff --git a/tests/ut/python/dataset/test_tfreader_op.py b/tests/ut/python/dataset/test_tfreader_op.py
index 3add50e1c..c5d9471f8 100644
--- a/tests/ut/python/dataset/test_tfreader_op.py
+++ b/tests/ut/python/dataset/test_tfreader_op.py
@@ -37,6 +37,36 @@ def test_case_tf_shape():
     assert (len(output_shape[-1]) == 1)
 
 
+def test_case_tf_read_all_dataset():
+    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json"
+    ds1 = ds.TFRecordDataset(FILES, schema_file)
+    assert ds1.get_dataset_size() == 12
+    count = 0
+    for data in ds1.create_tuple_iterator():
+        count += 1
+    assert count == 12
+
+
+def test_case_num_samples():
+    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
+    ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8)
+    assert ds1.get_dataset_size() == 8
+    count = 0
+    for data in ds1.create_dict_iterator():
+        count += 1
+    assert count == 8
+
+
+def test_case_num_samples2():
+    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
+    ds1 = ds.TFRecordDataset(FILES, schema_file)
+    assert ds1.get_dataset_size() == 7
+    count = 0
+    for data in ds1.create_dict_iterator():
+        count += 1
+    assert count == 7
+
+
 def test_case_tf_shape_2():
     ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE)
     ds1 = ds1.batch(2)
--
GitLab
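
A note on the semantics (illustrative, not part of the patch): the precedence
between num_samples and the schema's numRows that this change documents can be
summarized in a short Python sketch. The helper resolve_num_rows and the
MAX_INTEGER_INT32 constant below are hypothetical stand-ins for the C++ logic
in storage_client.cc and tf_reader_op.cc, not actual MindSpore APIs:

    import json

    # Stand-in for the C++ MAX_INTEGER_INT32 sentinel used when numRows is absent.
    MAX_INTEGER_INT32 = 2**31 - 1

    def resolve_num_rows(num_samples, schema_path):
        """Hypothetical helper mirroring the documented precedence."""
        with open(schema_path) as f:
            num_rows = json.load(f).get("numRows")  # None when the key is absent
        if num_samples is not None and num_samples > 0:
            return num_samples  # an explicit num_samples always wins
        if num_rows is not None and num_rows > 0:
            return num_rows  # otherwise fall back to numRows from the schema
        return MAX_INTEGER_INT32  # neither given: read the full dataset

For example, with datasetSchema7Rows.json (numRows = 7), resolve_num_rows(8, path)
returns 8 and resolve_num_rows(None, path) returns 7, matching
test_case_num_samples and test_case_num_samples2 above.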