From 1eda0ef071f43932d421de7906bafdc0532d13af Mon Sep 17 00:00:00 2001 From: jiangzhiwen Date: Wed, 29 Jul 2020 15:32:47 +0800 Subject: [PATCH] change num_samples definition --- .../minddata/dataset/engine/datasetops/source/csv_op.cc | 4 ++-- mindspore/dataset/engine/datasets.py | 6 +++--- mindspore/dataset/engine/validators.py | 6 +++++- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc index 28eb87a55..37d957ba2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc @@ -27,7 +27,7 @@ namespace mindspore { namespace dataset { CsvOp::Builder::Builder() - : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { + : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(-1), builder_shuffle_files_(false) { std::shared_ptr config_manager = GlobalContext::config_manager(); builder_num_workers_ = config_manager->num_parallel_workers(); builder_op_connector_size_ = config_manager->op_connector_size(); @@ -451,7 +451,7 @@ Status CsvOp::operator()() { RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); if (buffer->eoe()) { workers_done++; - } else if (num_samples_ == 0 || rows_read < num_samples_) { + } else if (num_samples_ == -1 || rows_read < num_samples_) { if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) { int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read); RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index fa9c1099f..c9cd4f609 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -4935,7 +4935,7 @@ class CSVDataset(SourceDataset): columns as string type. column_names (list[str], optional): List of column names of the dataset (default=None). If this is not provided, infers the column_names from the first row of CSV file. - num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset). + num_samples (int, optional): number of samples(rows) to read (default=-1, reads the full dataset). num_parallel_workers (int, optional): number of workers to read the data (default=None, number set in the config). shuffle (Union[bool, Shuffle level], optional): perform reshuffling of the data every epoch @@ -4959,7 +4959,7 @@ class CSVDataset(SourceDataset): """ @check_csvdataset - def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None, + def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=-1, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None): super().__init__(num_parallel_workers) self.dataset_files = self._find_files(dataset_files) @@ -5010,7 +5010,7 @@ class CSVDataset(SourceDataset): if self._dataset_size is None: num_rows = CsvOp.get_num_rows(self.dataset_files, self.column_names is None) num_rows = get_num_rows(num_rows, self.num_shards) - if self.num_samples is None: + if self.num_samples == -1: return num_rows return min(self.num_samples, num_rows) return self._dataset_size diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index a9a61c113..8fbc569bb 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -813,12 +813,16 @@ def check_csvdataset(method): def new_method(self, *args, **kwargs): _, param_dict = parse_user_args(method, *args, **kwargs) - nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] + nreq_param_int = ['num_parallel_workers', 'num_shards', 'shard_id'] # check dataset_files; required argument dataset_files = param_dict.get('dataset_files') type_check(dataset_files, (str, list), "dataset files") + # check num_samples + num_samples = param_dict.get('num_samples') + check_value(num_samples, [-1, INT32_MAX], "num_samples") + # check field_delim field_delim = param_dict.get('field_delim') type_check(field_delim, (str,), 'field delim') -- GitLab